
Commit

Changing persisted path to be saved along with the proto model files, for Gramine as well

Signed-off-by: Lerer, Eran <[email protected]>
cloudnoize committed Jan 14, 2025
1 parent 00392a0 commit 2eae11e
Showing 6 changed files with 164 additions and 105 deletions.
2 changes: 1 addition & 1 deletion docs/about/features_index/taskrunner.rst
@@ -46,7 +46,7 @@ Configurable Settings
- :code:`rounds_to_train`: (int) Specifies the number of rounds in a federation. A federated learning round is defined as one complete iteration when the collaborators train the model and send the updated model weights back to the aggregator to form a new global model. Within a round, collaborators can train the model for multiple iterations called epochs.
- :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard <https://www.tensorflow.org/tensorboard/get_started>`_ but users can also use custom metric logging function for each task.
- :code:`persist_checkpoint`: (boolean) Specifies whether to store a persistent checkpoint in non-volatile storage for recovery purposes. When enabled, the aggregator restores its state to what it was prior to the restart, ensuring continuity.

- :code:`persistent_db_path`: (str:path) Path where the persistent checkpoint database is stored. Defaults to :code:`tensor.db` when not set.

- :class:`Collaborator <openfl.component.Collaborator>`
`openfl.component.Collaborator <https://github.com/intel/openfl/blob/develop/openfl/component/collaborator/collaborator.py>`_
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -2,3 +2,4 @@ template : openfl.component.Aggregator
settings :
  db_store_rounds : 2
  persist_checkpoint: True
  persistent_db_path: save/tensor.db
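
For illustration, a minimal sketch (not part of this commit) of how the two settings above are consumed. It mirrors the fallback in Aggregator.__init__ further down in this diff: when persist_checkpoint is enabled but persistent_db_path is unset, the database defaults to tensor.db.

# Sketch only: reproduces the default-path fallback from Aggregator.__init__.
from openfl.databases import PersistentTensorDB

def open_persistent_db(persist_checkpoint: bool, persistent_db_path: str = None):
    """Return a PersistentTensorDB when checkpointing is enabled, else None."""
    if not persist_checkpoint:
        return None
    path = persistent_db_path or "tensor.db"  # fallback used when no path is configured
    return PersistentTensorDB(path)

# With the plan defaults above, the checkpoint DB sits next to the saved proto models:
# db = open_persistent_db(True, "save/tensor.db")
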
89 changes: 60 additions & 29 deletions openfl/component/aggregator/aggregator.py
@@ -12,8 +12,7 @@

import openfl.callbacks as callbacks_module
from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
from openfl.databases import TensorDB
from openfl.databases import PersistentTensorDB
from openfl.databases import PersistentTensorDB, TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
@@ -85,7 +84,7 @@ def __init__(
write_logs=False,
callbacks: Optional[List] = None,
persist_checkpoint=True,
persistent_db_path=None
persistent_db_path=None,
):
"""Initializes the Aggregator.
@@ -143,7 +142,11 @@ def __init__(

self.tensor_db = TensorDB()
if persist_checkpoint:
logger.info("Persistent checkpoint is enabled")
persistent_db_path = persistent_db_path or "tensor.db"
logger.info(
"Persistent checkpoint is enabled, setting persistent db at path %s",
persistent_db_path,
)
self.persistent_db = PersistentTensorDB(persistent_db_path)
else:
logger.info("Persistent checkpoint is disabled")
@@ -168,7 +171,6 @@ def __init__(
# these enable getting all tensors for a task
self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys}
self.collaborator_task_weight = {} # {TaskResultKey: data_size}


# maintain a list of collaborators that have completed task and
# reported results in a given round
@@ -177,8 +179,13 @@ def __init__(
self.lock = Lock()
self.use_delta_updates = use_delta_updates

self.model = None # Initialize the model attribute to None
if self.persistent_db and self._recover():
logger.info("recovered state of aggregator")

# The model is built by recovery if at least one round has finished
if self.model:
logger.info("Model was loaded by recovery")
elif initial_tensor_dict:
self._load_initial_tensors_from_dict(initial_tensor_dict)
self.model = utils.construct_model_proto(
@@ -204,19 +211,21 @@ def __init__(
# https://github.com/securefederatedai/openfl/pull/1195#discussion_r1879479537
self.callbacks.on_experiment_begin()
self.callbacks.on_round_begin(self.round_number)

def _recover(self):
"""Populates the aggregator state to the state it was prior a restart
"""
"""Populates the aggregator state to the state it was prior a restart"""
recovered = False
# load tensors persistent DB
logger.info("Recovering previous state from persistent storage")
tensor_key_dict = self.persistent_db.load_tensors(self.persistent_db.get_tensors_table_name())
tensor_key_dict = self.persistent_db.load_tensors(
self.persistent_db.get_tensors_table_name()
)
if len(tensor_key_dict) > 0:
logger.info(f"Recovering {len(tensor_key_dict)} model tensors")
recovered = True
self.tensor_db.cache_tensor(tensor_key_dict)
committed_round_number, self.best_model_score = self.persistent_db.get_round_and_best_score()
committed_round_number, self.best_model_score = (
self.persistent_db.get_round_and_best_score()
)
logger.info("Recovery - Setting model proto")
to_proto_tensor_dict = {}
for tk in tensor_key_dict:
@@ -225,24 +234,29 @@ def _recover(self):
self.model = utils.construct_model_proto(
to_proto_tensor_dict, committed_round_number, self.compression_pipeline
)
# round number is the current round which is still in process i.e. committed_round_number + 1
# round number is the current round which is still in process
# i.e. committed_round_number + 1
self.round_number = committed_round_number + 1
logger.info("Recovery - loaded round number %s and best score %s", self.round_number,self.best_model_score)

next_round_tensor_key_dict = self.persistent_db.load_tensors(self.persistent_db.get_next_round_tensors_table_name())
logger.info(
"Recovery - loaded round number %s and best score %s",
self.round_number,
self.best_model_score,
)

next_round_tensor_key_dict = self.persistent_db.load_tensors(
self.persistent_db.get_next_round_tensors_table_name()
)
if len(next_round_tensor_key_dict) > 0:
logger.info(f"Recovering {len(next_round_tensor_key_dict)} next round model tensors")
recovered = True
self.tensor_db.cache_tensor(next_round_tensor_key_dict)


logger.info("Recovery - Finished populating tensor DB")

logger.debug("Recovery - this is the tensor_db after recovery: %s", self.tensor_db)

if self.persistent_db.is_task_table_empty():
logger.debug("task table is empty")
return recovered

logger.info("Recovery - Replaying saved task results")
task_id = 1
while True:
@@ -259,8 +273,15 @@ def _recover(self):
NamedTensor.FromString(serialized_tensor)
for serialized_tensor in serialized_tensors
]
logger.info("Recovery - Replaying task results %s %s %s",collaborator_name ,round_number, task_name )
self.process_task_results(collaborator_name, round_number, task_name, data_size, named_tensors)
logger.info(
"Recovery - Replaying task results %s %s %s",
collaborator_name,
round_number,
task_name,
)
self.process_task_results(
collaborator_name, round_number, task_name, data_size, named_tensors
)
task_id += 1
return recovered

@@ -342,12 +363,16 @@ def _save_model(self, round_number, file_path):
# Transaction to persist/delete all data needed to increment the round
if self.persistent_db:
if self.next_model_round_number > 0:
next_round_tensors = self.tensor_db.get_tensors_by_round_and_tags(self.next_model_round_number,("model",))
self.persistent_db.finalize_round(tensor_tuple_dict,next_round_tensors,self.round_number,self.best_model_score)
logger.info(
"Persist model and clean task result for round %s",
round_number,
next_round_tensors = self.tensor_db.get_tensors_by_round_and_tags(
self.next_model_round_number, ("model",)
)
self.persistent_db.finalize_round(
tensor_tuple_dict, next_round_tensors, self.round_number, self.best_model_score
)
logger.info(
"Persist model and clean task result for round %s",
round_number,
)
self.last_tensor_dict = tensor_dict
self.model = utils.construct_model_proto(
tensor_dict, round_number, self.compression_pipeline
@@ -690,13 +715,19 @@ def send_local_task_results(
# Save task and its metadata for recovery
serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
if self.persistent_db:
self.persistent_db.save_task_results(collaborator_name,round_number,task_name,data_size,serialized_tensors)
logger.debug(f"Persisting task results {task_name} from {collaborator_name} round {round_number}")
self.persistent_db.save_task_results(
collaborator_name, round_number, task_name, data_size, serialized_tensors
)
logger.debug(
f"Persisting task results {task_name} from {collaborator_name} round {round_number}"
)
logger.info(
f"Collaborator {collaborator_name} is sending task results "
f"for {task_name}, round {round_number}"
)
self.process_task_results(collaborator_name,round_number,task_name,data_size,named_tensors)
self.process_task_results(
collaborator_name, round_number, task_name, data_size, named_tensors
)

def process_task_results(
self,
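
For illustration, a small inspection sketch (not part of this commit), assuming only the PersistentTensorDB methods visible in this diff (load_tensors, get_tensors_table_name, get_next_round_tensors_table_name, get_round_and_best_score). It reads back what a saved checkpoint contains, which is the same data _recover replays after a restart.

# Sketch only: inspect a saved checkpoint using the API shown above.
from openfl.databases import PersistentTensorDB

db = PersistentTensorDB("save/tensor.db")  # path set in plan/defaults/aggregator.yaml above
committed_round, best_score = db.get_round_and_best_score()
model_tensors = db.load_tensors(db.get_tensors_table_name())
next_round_tensors = db.load_tensors(db.get_next_round_tensors_table_name())
# On restart, the aggregator resumes at committed_round + 1 (see _recover above).
print(f"committed round: {committed_round}, best model score: {best_score}")
print(f"{len(model_tensors)} model tensors, {len(next_round_tensors)} next-round tensors")
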
2 changes: 1 addition & 1 deletion openfl/databases/__init__.py
@@ -2,5 +2,5 @@
# SPDX-License-Identifier: Apache-2.0


from openfl.databases.tensor_db import TensorDB
from openfl.databases.persistent_db import PersistentTensorDB
from openfl.databases.tensor_db import TensorDB