Changing persisted path to be saved along with the proto model files, for Gramine as well

Signed-off-by: Lerer, Eran <[email protected]>
cloudnoize committed Jan 13, 2025
1 parent 00392a0 commit e727b92
Showing 4 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/about/features_index/taskrunner.rst
@@ -46,7 +46,7 @@ Configurable Settings
 - :code:`rounds_to_train`: (int) Specifies the number of rounds in a federation. A federated learning round is defined as one complete iteration when the collaborators train the model and send the updated model weights back to the aggregator to form a new global model. Within a round, collaborators can train the model for multiple iterations called epochs.
 - :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard <https://www.tensorflow.org/tensorboard/get_started>`_ but users can also use custom metric logging function for each task.
 - :code:`persist_checkpoint`: (boolean) Specifies whether to enable the storage of a persistent checkpoint in non-volatile storage for recovery purposes. When enabled, the aggregator will restore its state to what it was prior to the restart, ensuring continuity after a restart.
-
+- :code:`persistent_db_path`: (str:path) Defines the persisted database path.
 
 - :class:`Collaborator <openfl.component.Collaborator>`
 `openfl.component.Collaborator <https://github.com/intel/openfl/blob/develop/openfl/component/collaborator/collaborator.py>`_
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -2,3 +2,4 @@ template : openfl.component.Aggregator
 settings :
     db_store_rounds : 2
     persist_checkpoint: True
+    persistent_db_path: save/tensor.db
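The new default points into the `save` directory, which matches the commit title's intent of keeping the database alongside the proto model files. A minimal sketch of that path relationship follows; the model file name `save/init.pbuf` is an assumption used only for illustration and does not appear in this commit.

```python
from pathlib import PurePosixPath

# Default value added to aggregator.yaml in this commit.
persistent_db_path = PurePosixPath("save/tensor.db")

# Assumed, illustrative model file path (not taken from this commit).
model_path = PurePosixPath("save/init.pbuf")


def same_directory(a: PurePosixPath, b: PurePosixPath) -> bool:
    """Return True when both paths share the same parent directory."""
    return a.parent == b.parent


print(same_directory(persistent_db_path, model_path))
```

With both files under one directory, a single location needs to be persisted or mounted, which is presumably what the Gramine remark in the title refers to.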
2 changes: 1 addition & 1 deletion openfl/component/aggregator/aggregator.py
@@ -143,7 +143,7 @@ def __init__(
 
         self.tensor_db = TensorDB()
         if persist_checkpoint:
-            logger.info("Persistent checkpoint is enabled")
+            logger.info("Persistent checkpoint is enabled, setting persistent db at path %s",persistent_db_path)
             self.persistent_db = PersistentTensorDB(persistent_db_path)
         else:
             logger.info("Persistent checkpoint is disabled")
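The branch above can be sketched in isolation as follows. This is a simplified stand-in, not OpenFL code: `open_persistent_db` is a hypothetical helper, and a plain `sqlite3` connection takes the place of `PersistentTensorDB`.

```python
import logging
import sqlite3
from typing import Optional

logger = logging.getLogger(__name__)


def open_persistent_db(
    persist_checkpoint: bool, persistent_db_path: str
) -> Optional[sqlite3.Connection]:
    """Hypothetical helper mirroring the persist_checkpoint branch."""
    if persist_checkpoint:
        # The %s argument is formatted lazily, only if the record is emitted.
        logger.info(
            "Persistent checkpoint is enabled, setting persistent db at path %s",
            persistent_db_path,
        )
        return sqlite3.connect(persistent_db_path, check_same_thread=False)
    logger.info("Persistent checkpoint is disabled")
    return None


# ":memory:" keeps this example from touching the filesystem.
conn = open_persistent_db(True, ":memory:")
if conn is not None:
    conn.close()
```

Passing the path as a logging argument rather than formatting it into the string up front is the same pattern the commit uses.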
10 changes: 4 additions & 6 deletions openfl/databases/persistent_db.py
@@ -27,13 +27,11 @@ class PersistentTensorDB:
     NEXT_ROUND_TENSORS_TABLE = "next_round_tensors"
     TASK_RESULT_TABLE = "task_results"
     KEY_VALUE_TABLE = "key_value_store"
-    def __init__(self, db_path: str = "") -> None:
+    def __init__(self, db_path) -> None:
         """Initializes a new instance of the PersistentTensorDB class."""
-        full_path = "tensordb.sqlite"
-        if db_path:
-            full_path = os.path.join(db_path, full_path)
-        logger.info("Initializing persistent db at %s",full_path)
-        self.conn = sqlite3.connect(full_path, check_same_thread=False)
+
+        logger.info("Initializing persistent db at %s",db_path)
+        self.conn = sqlite3.connect(db_path, check_same_thread=False)
         self.lock = Lock()
 
         cursor = self.conn.cursor()
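This change alters what `db_path` means: the `full_path` logic treated it as a directory and appended the fixed file name `tensordb.sqlite`, whereas the new code passes `db_path` straight to `sqlite3.connect` as the database file itself. Since `sqlite3.connect` does not create missing parent directories, a caller may want to ensure the directory exists first. The sketch below illustrates both points; `old_db_file` and `connect_with_parent` are hypothetical helpers and the table schema is an assumption, none of it part of this commit.

```python
import os
import sqlite3
import tempfile


def old_db_file(db_path: str = "") -> str:
    """Path logic of the removed lines: db_path is treated as a directory."""
    full_path = "tensordb.sqlite"
    if db_path:
        full_path = os.path.join(db_path, full_path)
    return full_path


def connect_with_parent(db_path: str) -> sqlite3.Connection:
    """Hypothetical helper: create the parent directory, then connect."""
    parent = os.path.dirname(db_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # New semantics: db_path is the database file itself.
    return sqlite3.connect(db_path, check_same_thread=False)


with tempfile.TemporaryDirectory() as workspace:
    db_file = os.path.join(workspace, "save", "tensor.db")
    conn = connect_with_parent(db_file)
    # Assumed schema, for illustration only.
    conn.execute(
        "CREATE TABLE IF NOT EXISTS key_value_store "
        "(key TEXT PRIMARY KEY, value TEXT)"
    )
    conn.commit()
    conn.close()
    print(os.path.exists(db_file))
```

Because the new constructor has no default for `db_path`, any caller that previously relied on `PersistentTensorDB()` would need to pass a path explicitly.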