Add persistent db module and recovery logic (#1229)
* Add persistent db module and recovery logic

Signed-off-by: Lerer, Eran <[email protected]>

* Address code review comments

Signed-off-by: Lerer, Eran <[email protected]>

* Adding persist_checkpoint flag to the plan

Signed-off-by: Lerer, Eran <[email protected]>

* Handling next round model tensors

Signed-off-by: Lerer, Eran <[email protected]>

* Changing persisted path to be saved along with the proto model files, for Gramine as well

Signed-off-by: Lerer, Eran <[email protected]>

---------

Signed-off-by: Lerer, Eran <[email protected]>
cloudnoize authored Jan 14, 2025
1 parent 4867947 commit 33004f6
Showing 6 changed files with 553 additions and 25 deletions.
5 changes: 3 additions & 2 deletions docs/about/features_index/taskrunner.rst
@@ -44,8 +44,9 @@ Configurable Settings
- :code:`best_state_path`: (str:path) Defines the weight protobuf file path that will be saved to for the highest accuracy model during the experiment.
- :code:`last_state_path`: (str:path) Defines the weight protobuf file path that will be saved to during the last round completed in each experiment.
- :code:`rounds_to_train`: (int) Specifies the number of rounds in a federation. A federated learning round is defined as one complete iteration when the collaborators train the model and send the updated model weights back to the aggregator to form a new global model. Within a round, collaborators can train the model for multiple iterations called epochs.
- :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard <https://www.tensorflow.org/tensorboard/get_started>`_, but users can also use a custom metric logging function for each task.
- :code:`persist_checkpoint`: (boolean) Specifies whether to store a persistent checkpoint in non-volatile storage for recovery purposes. When enabled, the aggregator restores its state to what it was before the restart, ensuring continuity across restarts.
- :code:`persistent_db_path`: (str:path) Defines the path of the persistent checkpoint database.

- :class:`Collaborator <openfl.component.Collaborator>`
`openfl.component.Collaborator <https://github.com/intel/openfl/blob/develop/openfl/component/collaborator/collaborator.py>`_
2 changes: 2 additions & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -1,3 +1,5 @@
template : openfl.component.Aggregator
settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: save/tensor.db
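
For reference, a minimal sketch (not part of this commit) of how these defaults flow into the new Aggregator arguments shown in the next file, including the "tensor.db" fallback applied when persistent_db_path is omitted; the settings dict below simply stands in for the parsed plan.

settings = {
    "db_store_rounds": 2,
    "persist_checkpoint": True,
    "persistent_db_path": "save/tensor.db",
}

if settings["persist_checkpoint"]:
    # Mirrors Aggregator.__init__: use the configured path or fall back to "tensor.db".
    db_path = settings.get("persistent_db_path") or "tensor.db"
    print(f"Persistent checkpoint enabled, database at {db_path}")
else:
    print("Persistent checkpoint is disabled")
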
172 changes: 149 additions & 23 deletions openfl/component/aggregator/aggregator.py
@@ -12,10 +12,11 @@

import openfl.callbacks as callbacks_module
from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
from openfl.databases import TensorDB
from openfl.databases import PersistentTensorDB, TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
from openfl.protocols.base_pb2 import NamedTensor
from openfl.utilities import TaskResultKey, TensorKey, change_tags

logger = logging.getLogger(__name__)
@@ -82,6 +83,8 @@ def __init__(
log_memory_usage=False,
write_logs=False,
callbacks: Optional[List] = None,
persist_checkpoint=True,
persistent_db_path=None,
):
"""Initializes the Aggregator.
@@ -110,6 +113,7 @@
callbacks: List of callbacks to be used during the experiment.
"""
self.round_number = 0
self.next_model_round_number = 0

if single_col_cert_common_name:
logger.warning(
@@ -137,6 +141,16 @@
self.quit_job_sent_to = []

self.tensor_db = TensorDB()
if persist_checkpoint:
persistent_db_path = persistent_db_path or "tensor.db"
logger.info(
"Persistent checkpoint is enabled, setting persistent db at path %s",
persistent_db_path,
)
self.persistent_db = PersistentTensorDB(persistent_db_path)
else:
logger.info("Persistent checkpoint is disabled")
self.persistent_db = None
# FIXME: I think next line generates an error on the second round
# if it is set to 1 for the aggregator.
self.db_store_rounds = db_store_rounds
@@ -154,8 +168,25 @@ def __init__(
# TODO: Remove. Used in deprecated interactive and native APIs
self.best_tensor_dict: dict = {}
self.last_tensor_dict: dict = {}
# these enable getting all tensors for a task
self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys}
self.collaborator_task_weight = {} # {TaskResultKey: data_size}

if initial_tensor_dict:
# maintain a list of collaborators that have completed task and
# reported results in a given round
self.collaborators_done = []
# Initialize a lock for thread safety
self.lock = Lock()
self.use_delta_updates = use_delta_updates

self.model = None # Initialize the model attribute to None
if self.persistent_db and self._recover():
logger.info("recovered state of aggregator")

# The model is built by recovery if at least one round has finished
if self.model:
logger.info("Model was loaded by recovery")
elif initial_tensor_dict:
self._load_initial_tensors_from_dict(initial_tensor_dict)
self.model = utils.construct_model_proto(
tensor_dict=initial_tensor_dict,
@@ -168,20 +199,6 @@

self.collaborator_tensor_results = {} # {TensorKey: nparray}}

# these enable getting all tensors for a task
self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys}

self.collaborator_task_weight = {} # {TaskResultKey: data_size}

# maintain a list of collaborators that have completed task and
# reported results in a given round
self.collaborators_done = []

# Initialize a lock for thread safety
self.lock = Lock()

self.use_delta_updates = use_delta_updates

# Callbacks
self.callbacks = callbacks_module.CallbackList(
callbacks,
@@ -195,6 +212,79 @@ def __init__(
self.callbacks.on_experiment_begin()
self.callbacks.on_round_begin(self.round_number)

def _recover(self):
"""Populates the aggregator state to the state it was prior a restart"""
recovered = False
# load tensors persistent DB
tensor_key_dict = self.persistent_db.load_tensors(
self.persistent_db.get_tensors_table_name()
)
if len(tensor_key_dict) > 0:
logger.info(f"Recovering {len(tensor_key_dict)} model tensors")
recovered = True
self.tensor_db.cache_tensor(tensor_key_dict)
committed_round_number, self.best_model_score = (
self.persistent_db.get_round_and_best_score()
)
logger.info("Recovery - Setting model proto")
to_proto_tensor_dict = {}
for tk in tensor_key_dict:
tk_name, _, _, _, _ = tk
to_proto_tensor_dict[tk_name] = tensor_key_dict[tk]
self.model = utils.construct_model_proto(
to_proto_tensor_dict, committed_round_number, self.compression_pipeline
)
# round number is the current round, which is still in progress,
# i.e. committed_round_number + 1
self.round_number = committed_round_number + 1
logger.info(
"Recovery - loaded round number %s and best score %s",
self.round_number,
self.best_model_score,
)

next_round_tensor_key_dict = self.persistent_db.load_tensors(
self.persistent_db.get_next_round_tensors_table_name()
)
if len(next_round_tensor_key_dict) > 0:
logger.info(f"Recovering {len(next_round_tensor_key_dict)} next round model tensors")
recovered = True
self.tensor_db.cache_tensor(next_round_tensor_key_dict)

logger.debug("Recovery - this is the tensor_db after recovery: %s", self.tensor_db)

if self.persistent_db.is_task_table_empty():
logger.debug("task table is empty")
return recovered

logger.info("Recovery - Replaying saved task results")
task_id = 1
while True:
task_result = self.persistent_db.get_task_result_by_id(task_id)
if not task_result:
break
recovered = True
collaborator_name = task_result["collaborator_name"]
round_number = task_result["round_number"]
task_name = task_result["task_name"]
data_size = task_result["data_size"]
serialized_tensors = task_result["named_tensors"]
named_tensors = [
NamedTensor.FromString(serialized_tensor)
for serialized_tensor in serialized_tensors
]
logger.info(
"Recovery - Replaying task results %s %s %s",
collaborator_name,
round_number,
task_name,
)
self.process_task_results(
collaborator_name, round_number, task_name, data_size, named_tensors
)
task_id += 1
return recovered

def _load_initial_tensors(self):
"""Load all of the tensors required to begin federated learning.
@@ -255,9 +345,12 @@ def _save_model(self, round_number, file_path):
for k, v in og_tensor_dict.items()
]
tensor_dict = {}
tensor_tuple_dict = {}
for tk in tensor_keys:
tk_name, _, _, _, _ = tk
tensor_dict[tk_name] = self.tensor_db.get_tensor_from_cache(tk)
tensor_value = self.tensor_db.get_tensor_from_cache(tk)
tensor_dict[tk_name] = tensor_value
tensor_tuple_dict[tk] = tensor_value
if tensor_dict[tk_name] is None:
logger.info(
"Cannot save model for round %s. Continuing...",
@@ -267,6 +360,19 @@ def _save_model(self, round_number, file_path):
if file_path == self.best_state_path:
self.best_tensor_dict = tensor_dict
if file_path == self.last_state_path:
# Transaction to persist/delete all data needed to increment the round
if self.persistent_db:
if self.next_model_round_number > 0:
next_round_tensors = self.tensor_db.get_tensors_by_round_and_tags(
self.next_model_round_number, ("model",)
)
self.persistent_db.finalize_round(
tensor_tuple_dict, next_round_tensors, self.round_number, self.best_model_score
)
logger.info(
"Persist model and clean task result for round %s",
round_number,
)
self.last_tensor_dict = tensor_dict
self.model = utils.construct_model_proto(
tensor_dict, round_number, self.compression_pipeline
@@ -606,6 +712,31 @@ def send_local_task_results(
Returns:
None
"""
# Save task and its metadata for recovery
serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
if self.persistent_db:
self.persistent_db.save_task_results(
collaborator_name, round_number, task_name, data_size, serialized_tensors
)
logger.debug(
f"Persisting task results {task_name} from {collaborator_name} round {round_number}"
)
logger.info(
f"Collaborator {collaborator_name} is sending task results "
f"for {task_name}, round {round_number}"
)
self.process_task_results(
collaborator_name, round_number, task_name, data_size, named_tensors
)

def process_task_results(
self,
collaborator_name,
round_number,
task_name,
data_size,
named_tensors,
):
if self._time_to_quit() or collaborator_name in self.stragglers:
logger.warning(
f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
@@ -620,11 +751,6 @@ def send_local_task_results(
)
return

logger.info(
f"Collaborator {collaborator_name} is sending task results "
f"for {task_name}, round {round_number}"
)

task_key = TaskResultKey(task_name, collaborator_name, round_number)

# we mustn't have results already
@@ -864,7 +990,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
new_model_report,
("model",),
)

self.next_model_round_number = new_model_round_number
# Finally, cache the updated model tensor
self.tensor_db.cache_tensor({final_model_tk: new_model_nparray})

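As a side note, here is a minimal sketch (not part of this commit) of the protobuf round trip the aggregator now relies on: send_local_task_results serializes each NamedTensor before handing it to the persistent DB, and _recover parses the stored bytes back and replays them through process_task_results. The field value below is purely illustrative.

from openfl.protocols.base_pb2 import NamedTensor

# Illustrative message; only the "name" field is populated here.
named_tensor = NamedTensor(name="conv1_weight")

# send_local_task_results: serialize before writing to the task-results table.
serialized = named_tensor.SerializeToString()

# _recover: parse the stored bytes and replay them via process_task_results.
restored = NamedTensor.FromString(serialized)
assert restored.name == "conv1_weight"
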
1 change: 1 addition & 0 deletions openfl/databases/__init__.py
@@ -2,4 +2,5 @@
# SPDX-License-Identifier: Apache-2.0


from openfl.databases.persistent_db import PersistentTensorDB
from openfl.databases.tensor_db import TensorDB
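
A minimal usage sketch (not part of this commit) of the newly exported class, using only the calls that appear in the aggregator diff above; the path matches the new plan default.

from openfl.databases import PersistentTensorDB

# Open (or create) the persistent checkpoint database; "save/tensor.db" is the
# default set in plan/defaults/aggregator.yaml.
persistent_db = PersistentTensorDB("save/tensor.db")

# Load any model tensors committed before a restart.
tensors = persistent_db.load_tensors(persistent_db.get_tensors_table_name())
print(f"Recovered {len(tensors)} committed model tensors")

# The committed round number and best model score are restored alongside them.
round_number, best_score = persistent_db.get_round_and_best_score()
print(f"Last committed round: {round_number}, best score: {best_score}")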
