diff --git a/docs/about/features_index/taskrunner.rst b/docs/about/features_index/taskrunner.rst
index 2097c72f32..f8730f463b 100644
--- a/docs/about/features_index/taskrunner.rst
+++ b/docs/about/features_index/taskrunner.rst
@@ -44,8 +44,9 @@ Configurable Settings
     - :code:`best_state_path`: (str:path) Defines the weight protobuf file path that will be saved to for the highest accuracy model during the experiment.
     - :code:`last_state_path`: (str:path) Defines the weight protobuf file path that will be saved to during the last round completed in each experiment.
    - :code:`rounds_to_train`: (int) Specifies the number of rounds in a federation. A federated learning round is defined as one complete iteration when the collaborators train the model and send the updated model weights back to the aggregator to form a new global model. Within a round, collaborators can train the model for multiple iterations called epochs.
-    - :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard `_ but users can also use custom metric logging function for each task.
-
+    - :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard `_ but users can also use a custom metric logging function for each task.
+    - :code:`persist_checkpoint`: (boolean) Specifies whether to store a persistent checkpoint in non-volatile storage for recovery purposes. When enabled, the aggregator restores its state after a restart to what it was beforehand, ensuring continuity.
+    - :code:`persistent_db_path`: (str:path) Defines the persisted database path.
 
 - :class:`Collaborator ` `openfl.component.Collaborator `_
 
diff --git a/openfl-workspace/workspace/plan/defaults/aggregator.yaml b/openfl-workspace/workspace/plan/defaults/aggregator.yaml
index 43d923b996..d4204e2b0c 100644
--- a/openfl-workspace/workspace/plan/defaults/aggregator.yaml
+++ b/openfl-workspace/workspace/plan/defaults/aggregator.yaml
@@ -1,3 +1,5 @@
 template : openfl.component.Aggregator
 settings :
   db_store_rounds : 2
+  persist_checkpoint: True
+  persistent_db_path: save/tensor.db
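For context, the two new plan settings map one-to-one onto the new Aggregator constructor arguments introduced below. A minimal sketch of the gating and fallback behavior they drive; init_persistent_db is a hypothetical stand-alone helper written only for illustration, not part of this patch:

from typing import Optional

from openfl.databases import PersistentTensorDB

def init_persistent_db(persist_checkpoint: bool = True,
                       persistent_db_path: Optional[str] = None):
    # Hypothetical helper mirroring the constructor logic added in this patch:
    # persistence is on by default, and the path falls back to "tensor.db"
    # when none is configured in the plan.
    if not persist_checkpoint:
        return None
    return PersistentTensorDB(persistent_db_path or "tensor.db")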
""" self.round_number = 0 + self.next_model_round_number = 0 if single_col_cert_common_name: logger.warning( @@ -137,6 +141,16 @@ def __init__( self.quit_job_sent_to = [] self.tensor_db = TensorDB() + if persist_checkpoint: + persistent_db_path = persistent_db_path or "tensor.db" + logger.info( + "Persistent checkpoint is enabled, setting persistent db at path %s", + persistent_db_path, + ) + self.persistent_db = PersistentTensorDB(persistent_db_path) + else: + logger.info("Persistent checkpoint is disabled") + self.persistent_db = None # FIXME: I think next line generates an error on the second round # if it is set to 1 for the aggregator. self.db_store_rounds = db_store_rounds @@ -154,8 +168,25 @@ def __init__( # TODO: Remove. Used in deprecated interactive and native APIs self.best_tensor_dict: dict = {} self.last_tensor_dict: dict = {} + # these enable getting all tensors for a task + self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys} + self.collaborator_task_weight = {} # {TaskResultKey: data_size} - if initial_tensor_dict: + # maintain a list of collaborators that have completed task and + # reported results in a given round + self.collaborators_done = [] + # Initialize a lock for thread safety + self.lock = Lock() + self.use_delta_updates = use_delta_updates + + self.model = None # Initialize the model attribute to None + if self.persistent_db and self._recover(): + logger.info("recovered state of aggregator") + + # The model is built by recovery if at least one round has finished + if self.model: + logger.info("Model was loaded by recovery") + elif initial_tensor_dict: self._load_initial_tensors_from_dict(initial_tensor_dict) self.model = utils.construct_model_proto( tensor_dict=initial_tensor_dict, @@ -168,20 +199,6 @@ def __init__( self.collaborator_tensor_results = {} # {TensorKey: nparray}} - # these enable getting all tensors for a task - self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys} - - self.collaborator_task_weight = {} # {TaskResultKey: data_size} - - # maintain a list of collaborators that have completed task and - # reported results in a given round - self.collaborators_done = [] - - # Initialize a lock for thread safety - self.lock = Lock() - - self.use_delta_updates = use_delta_updates - # Callbacks self.callbacks = callbacks_module.CallbackList( callbacks, @@ -195,6 +212,79 @@ def __init__( self.callbacks.on_experiment_begin() self.callbacks.on_round_begin(self.round_number) + def _recover(self): + """Populates the aggregator state to the state it was prior a restart""" + recovered = False + # load tensors persistent DB + tensor_key_dict = self.persistent_db.load_tensors( + self.persistent_db.get_tensors_table_name() + ) + if len(tensor_key_dict) > 0: + logger.info(f"Recovering {len(tensor_key_dict)} model tensors") + recovered = True + self.tensor_db.cache_tensor(tensor_key_dict) + committed_round_number, self.best_model_score = ( + self.persistent_db.get_round_and_best_score() + ) + logger.info("Recovery - Setting model proto") + to_proto_tensor_dict = {} + for tk in tensor_key_dict: + tk_name, _, _, _, _ = tk + to_proto_tensor_dict[tk_name] = tensor_key_dict[tk] + self.model = utils.construct_model_proto( + to_proto_tensor_dict, committed_round_number, self.compression_pipeline + ) + # round number is the current round which is still in process + # i.e. 
@@ -255,9 +345,12 @@ def _save_model(self, round_number, file_path):
             for k, v in og_tensor_dict.items()
         ]
         tensor_dict = {}
+        tensor_tuple_dict = {}
         for tk in tensor_keys:
             tk_name, _, _, _, _ = tk
-            tensor_dict[tk_name] = self.tensor_db.get_tensor_from_cache(tk)
+            tensor_value = self.tensor_db.get_tensor_from_cache(tk)
+            tensor_dict[tk_name] = tensor_value
+            tensor_tuple_dict[tk] = tensor_value
             if tensor_dict[tk_name] is None:
                 logger.info(
                     "Cannot save model for round %s. Continuing...",
@@ -267,6 +360,19 @@ def _save_model(self, round_number, file_path):
         if file_path == self.best_state_path:
             self.best_tensor_dict = tensor_dict
         if file_path == self.last_state_path:
+            # Transaction to persist/delete all data needed to increment the round
+            if self.persistent_db:
+                if self.next_model_round_number > 0:
+                    next_round_tensors = self.tensor_db.get_tensors_by_round_and_tags(
+                        self.next_model_round_number, ("model",)
+                    )
+                    self.persistent_db.finalize_round(
+                        tensor_tuple_dict, next_round_tensors, self.round_number, self.best_model_score
+                    )
+                    logger.info(
+                        "Persisted model and cleaned task results for round %s",
+                        round_number,
+                    )
             self.last_tensor_dict = tensor_dict
             self.model = utils.construct_model_proto(
                 tensor_dict, round_number, self.compression_pipeline
             )
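The persistence step above is only safe because finalize_round (defined later in persistent_db.py) is transactional: the committed tensors, the next-round tensors, the cleared task table, and the round metadata either all land together or none do. A self-contained sketch of the underlying SQLite pattern, with an in-memory database and made-up keys:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE kv (key TEXT PRIMARY KEY, value REAL)")
try:
    cur = conn.cursor()
    cur.execute("BEGIN TRANSACTION")
    cur.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", ("round_number", 3.0))
    cur.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", ("best_score", 0.91))
    conn.commit()  # both writes become visible atomically
except Exception as exc:
    conn.rollback()  # leaves the previous round's state intact
    raise RuntimeError(f"Failed to finalize round: {exc}")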
@@ -606,6 +712,31 @@ def send_local_task_results(
         Returns:
             None
         """
+        # Save task and its metadata for recovery
+        serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
+        if self.persistent_db:
+            self.persistent_db.save_task_results(
+                collaborator_name, round_number, task_name, data_size, serialized_tensors
+            )
+            logger.debug(
+                f"Persisting task results {task_name} from {collaborator_name} round {round_number}"
+            )
+        logger.info(
+            f"Collaborator {collaborator_name} is sending task results "
+            f"for {task_name}, round {round_number}"
+        )
+        self.process_task_results(
+            collaborator_name, round_number, task_name, data_size, named_tensors
+        )
+
+    def process_task_results(
+        self,
+        collaborator_name,
+        round_number,
+        task_name,
+        data_size,
+        named_tensors,
+    ):
         if self._time_to_quit() or collaborator_name in self.stragglers:
             logger.warning(
                 f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
@@ -620,11 +751,6 @@ def send_local_task_results(
             )
             return
 
-        logger.info(
-            f"Collaborator {collaborator_name} is sending task results "
-            f"for {task_name}, round {round_number}"
-        )
-
         task_key = TaskResultKey(task_name, collaborator_name, round_number)
 
         # we mustn't have results already
@@ -864,7 +990,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
             new_model_report,
             ("model",),
         )
-
+        self.next_model_round_number = new_model_round_number
         # Finally, cache the updated model tensor
         self.tensor_db.cache_tensor({final_model_tk: new_model_nparray})
diff --git a/openfl/databases/__init__.py b/openfl/databases/__init__.py
index 849fcde7c9..0e64082d5f 100644
--- a/openfl/databases/__init__.py
+++ b/openfl/databases/__init__.py
@@ -2,4 +2,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from openfl.databases.persistent_db import PersistentTensorDB
 from openfl.databases.tensor_db import TensorDB
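With PersistentTensorDB now exported from openfl.databases, the store can be exercised on its own. A quick smoke test, assuming this patch is installed (an in-memory path is used so nothing touches disk):

from openfl.databases import PersistentTensorDB

db = PersistentTensorDB(":memory:")
print(db.is_task_table_empty())  # True on a fresh database
db.close()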
+ """ + + TENSORS_TABLE = "tensors" + NEXT_ROUND_TENSORS_TABLE = "next_round_tensors" + TASK_RESULT_TABLE = "task_results" + KEY_VALUE_TABLE = "key_value_store" + + def __init__(self, db_path) -> None: + """Initializes a new instance of the PersistentTensorDB class.""" + + logger.info("Initializing persistent db at %s", db_path) + self.conn = sqlite3.connect(db_path, check_same_thread=False) + self.lock = Lock() + + cursor = self.conn.cursor() + self._create_model_tensors_table(cursor, PersistentTensorDB.TENSORS_TABLE) + self._create_model_tensors_table(cursor, PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE) + self._create_task_results_table(cursor) + self._create_key_value_store(cursor) + self.conn.commit() + + def _create_model_tensors_table(self, cursor, table_name) -> None: + """Create the database table for storing tensors if it does not exist.""" + query = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tensor_name TEXT NOT NULL, + origin TEXT NOT NULL, + round INTEGER NOT NULL, + report INTEGER NOT NULL, + tags TEXT, + nparray BLOB NOT NULL + ) + """ + cursor.execute(query) + + def _create_task_results_table(self, cursor) -> None: + """Creates a table for storing task results.""" + query = f""" + CREATE TABLE IF NOT EXISTS {PersistentTensorDB.TASK_RESULT_TABLE} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + collaborator_name TEXT NOT NULL, + round_number INTEGER NOT NULL, + task_name TEXT NOT NULL, + data_size INTEGER NOT NULL, + named_tensors BLOB NOT NULL + ) + """ + cursor.execute(query) + + def _create_key_value_store(self, cursor) -> None: + """Create a key-value store table for storing additional metadata.""" + query = f""" + CREATE TABLE IF NOT EXISTS {PersistentTensorDB.KEY_VALUE_TABLE} ( + key TEXT PRIMARY KEY, + value REAL NOT NULL + ) + """ + cursor.execute(query) + + def save_task_results( + self, + collaborator_name: str, + round_number: int, + task_name: str, + data_size: int, + named_tensors, + ): + """ + Saves task results to the task_results table. + + Args: + collaborator_name (str): Collaborator name. + round_number (int): Round number. + task_name (str): Task name. + data_size (int): Data size. + named_tensors(List): list of binary representation of tensors. + """ + serialized_blob = pickle.dumps(named_tensors) + + # Insert into the database + insert_query = f""" + INSERT INTO {PersistentTensorDB.TASK_RESULT_TABLE} + (collaborator_name, round_number, task_name, data_size, named_tensors) + VALUES (?, ?, ?, ?, ?); + """ + with self.lock: + cursor = self.conn.cursor() + cursor.execute( + insert_query, + (collaborator_name, round_number, task_name, data_size, serialized_blob), + ) + self.conn.commit() + + def get_task_result_by_id(self, task_result_id: int): + """ + Retrieve a task result by its ID. + + Args: + task_result_id (int): The ID of the task result to retrieve. + + Returns: + A dictionary containing the task result details, or None if not found. + """ + with self.lock: + cursor = self.conn.cursor() + query = f""" + SELECT collaborator_name, round_number, task_name, data_size, named_tensors + FROM {PersistentTensorDB.TASK_RESULT_TABLE} + WHERE id = ? 
+ """ + cursor.execute(query, (task_result_id,)) + result = cursor.fetchone() + if result: + collaborator_name, round_number, task_name, data_size, serialized_blob = result + serialized_tensors = pickle.loads(serialized_blob) + return { + "collaborator_name": collaborator_name, + "round_number": round_number, + "task_name": task_name, + "data_size": data_size, + "named_tensors": serialized_tensors, + } + return None + + def _serialize_array(self, array: np.ndarray) -> bytes: + """Serialize a NumPy array into bytes for storing in SQLite. + note: using pickle since in some cases the array is actually a scalar. + """ + return pickle.dumps(array) + + def _deserialize_array(self, blob: bytes, dtype: Optional[np.dtype] = None) -> np.ndarray: + """Deserialize bytes from SQLite into a NumPy array.""" + try: + return pickle.loads(blob) + except Exception as e: + raise ValueError(f"Failed to deserialize array: {e}") + + def __repr__(self) -> str: + """Returns a string representation of the PersistentTensorDB.""" + with self.lock: + cursor = self.conn.cursor() + cursor.execute("SELECT tensor_name, origin, round, report, tags FROM tensors") + rows = cursor.fetchall() + return f"PersistentTensorDB contents:\n{rows}" + + def finalize_round( + self, + tensor_key_dict: Dict[TensorKey, np.ndarray], + next_round_tensor_key_dict: Dict[TensorKey, np.ndarray], + round_number: int, + best_score: float, + ): + """Finalize a training round by saving tensors, preparing for the next round, + and updating metadata in the database. + + This function performs the following steps as a single transaction: + 1. Persist the tensors of the current round into the database. + 2. Persist the tensors for the next training round into the database. + 3. Reinitialize the task results table to prepare for new tasks. + 4. Update the round number and best score in the key-value store. + + If any step fails, the transaction is rolled back to ensure data integrity. + + Args: + tensor_key_dict (Dict[TensorKey, np.ndarray]): + A dictionary mapping tensor keys to their corresponding + NumPy arrays for the current round. + next_round_tensor_key_dict (Dict[TensorKey, np.ndarray]): + A dictionary mapping tensor keys to their corresponding + NumPy arrays for the next round. + round_number (int): + The current training round number. + best_score (float): + The best score achieved during the current round. + + Raises: + RuntimeError: If an error occurs during the transaction, the transaction is rolled back, + and a RuntimeError is raised with the details of the failure. 
+ """ + with self.lock: + try: + # Begin transaction + cursor = self.conn.cursor() + cursor.execute("BEGIN TRANSACTION") + self._persist_tensors(cursor, PersistentTensorDB.TENSORS_TABLE, tensor_key_dict) + self._persist_next_round_tensors(cursor, next_round_tensor_key_dict) + self._init_task_results_table(cursor) + self._save_round_and_best_score(cursor, round_number, best_score) + # Commit transaction + self.conn.commit() + logger.info( + f"Committed model for round {round_number}, saved {len(tensor_key_dict)}" + f" model tensors and {len(next_round_tensor_key_dict)}" + f" next round model tensors with best_score {best_score}" + ) + except Exception as e: + # Rollback transaction in case of an error + self.conn.rollback() + raise RuntimeError(f"Failed to finalize round: {e}") + + def _persist_tensors( + self, cursor, table_name, tensor_key_dict: Dict[TensorKey, np.ndarray] + ) -> None: + """Insert a dictionary of tensors into the SQLite as part of transaction""" + for tensor_key, nparray in tensor_key_dict.items(): + tensor_name, origin, fl_round, report, tags = tensor_key + serialized_array = self._serialize_array(nparray) + serialized_tags = json.dumps(tags) + query = f""" + INSERT INTO {table_name} (tensor_name, origin, round, report, tags, nparray) + VALUES (?, ?, ?, ?, ?, ?) + """ + cursor.execute( + query, + (tensor_name, origin, fl_round, int(report), serialized_tags, serialized_array), + ) + + def _persist_next_round_tensors( + self, cursor, tensor_key_dict: Dict[TensorKey, np.ndarray] + ) -> None: + """Persisting the last round next_round tensors.""" + drop_table_query = f"DROP TABLE IF EXISTS {PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE}" + cursor.execute(drop_table_query) + self._create_model_tensors_table(cursor, PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE) + self._persist_tensors(cursor, PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE, tensor_key_dict) + + def _init_task_results_table(self, cursor): + """ + Creates a table for storing task results. Drops the table first if it already exists. + """ + drop_table_query = "DROP TABLE IF EXISTS task_results" + cursor.execute(drop_table_query) + self._create_task_results_table(cursor) + + def _save_round_and_best_score(self, cursor, round_number: int, best_score: float) -> None: + """Save the round number and best score as key-value pairs in the database.""" + # Create a table with key-value structure where values can be integer or float + # Insert or update the round_number + cursor.execute( + """ + INSERT OR REPLACE INTO key_value_store (key, value) + VALUES (?, ?) + """, + ("round_number", float(round_number)), + ) + + # Insert or update the best_score + cursor.execute( + """ + INSERT OR REPLACE INTO key_value_store (key, value) + VALUES (?, ?) 
+ """, + ("best_score", float(best_score)), + ) + + def get_tensors_table_name(self) -> str: + return PersistentTensorDB.TENSORS_TABLE + + def get_next_round_tensors_table_name(self) -> str: + return PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE + + def load_tensors(self, tensor_table) -> Dict[TensorKey, np.ndarray]: + """Load all tensors from the SQLite database and return them as a dictionary.""" + tensor_dict = {} + with self.lock: + cursor = self.conn.cursor() + query = f"SELECT tensor_name, origin, round, report, tags, nparray FROM {tensor_table}" + cursor.execute(query) + rows = cursor.fetchall() + for row in rows: + tensor_name, origin, fl_round, report, tags, nparray = row + # Deserialize the JSON string back to a Python list + deserialized_tags = tuple(json.loads(tags)) + tensor_key = TensorKey(tensor_name, origin, fl_round, report, deserialized_tags) + tensor_dict[tensor_key] = self._deserialize_array(nparray) + return tensor_dict + + def get_round_and_best_score(self) -> tuple[int, float]: + """Retrieve the round number and best score from the database.""" + with self.lock: + cursor = self.conn.cursor() + # Fetch the round_number + cursor.execute( + """ + SELECT value FROM key_value_store WHERE key = ? + """, + ("round_number",), + ) + round_number = cursor.fetchone() + if round_number is None: + round_number = -1 + else: + round_number = int(round_number[0]) # Cast to int + + # Fetch the best_score + cursor.execute( + """ + SELECT value FROM key_value_store WHERE key = ? + """, + ("best_score",), + ) + best_score = cursor.fetchone() + if best_score is None: + best_score = 0 + else: + best_score = float(best_score[0]) # Cast to float + return round_number, best_score + + def clean_up(self, remove_older_than: int = 1) -> None: + """Remove old entries from the database.""" + if remove_older_than < 0: + return + with self.lock: + cursor = self.conn.cursor() + query = f"SELECT MAX(round) FROM {PersistentTensorDB.TENSORS_TABLE}" + cursor.execute(query) + current_round = cursor.fetchone()[0] + if current_round is None: + return + cursor.execute( + """ + DELETE FROM tensors + WHERE round <= ? AND report = 0 + """, + (current_round - remove_older_than,), + ) + self.conn.commit() + + def close(self) -> None: + """Close the SQLite database connection.""" + self.conn.close() + + def is_task_table_empty(self) -> bool: + """Check if the task table is empty.""" + with self.lock: + cursor = self.conn.cursor() + cursor.execute("SELECT COUNT(*) FROM task_results") + count = cursor.fetchone()[0] + return count == 0 diff --git a/openfl/databases/tensor_db.py b/openfl/databases/tensor_db.py index 1b9d5ea132..5f9ffe78c6 100644 --- a/openfl/databases/tensor_db.py +++ b/openfl/databases/tensor_db.py @@ -151,6 +151,39 @@ def get_tensor_from_cache(self, tensor_key: TensorKey) -> Optional[np.ndarray]: return None return np.array(df["nparray"].iloc[0]) + def get_tensors_by_round_and_tags(self, fl_round: int, tags: tuple) -> dict: + """Retrieve all tensors that match the specified round and tags. + + Args: + fl_round (int): The round number to filter tensors. + tags (tuple): The tags to filter tensors. + + Returns: + dict: A dictionary where the keys are TensorKey objects and the values are numpy arrays. 
+ """ + # Filter the DataFrame based on the round and tags + df = self.tensor_db[ + (self.tensor_db["round"] == fl_round) & (self.tensor_db["tags"] == tags) + ] + + # Check if any tensors match the criteria + if len(df) == 0: + return {} + + # Construct a dictionary mapping TensorKey to np.ndarray + tensor_dict = {} + for _, row in df.iterrows(): + tensor_key = TensorKey( + tensor_name=row["tensor_name"], + origin=row["origin"], + round_number=row["round"], + report=row["report"], + tags=row["tags"], + ) + tensor_dict[tensor_key] = np.array(row["nparray"]) + + return tensor_dict + def get_aggregated_tensor( self, tensor_key: TensorKey,