Recover missing state after collaborator restart #1268

Merged
24 changes: 19 additions & 5 deletions openfl/component/collaborator/collaborator.py
@@ -382,15 +382,17 @@ def get_data_for_tensorkey(self, tensor_key):
return nparray
prior_round -= 1
logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...")
- logger.debug(
- "Unable to get tensor from local store..." "attempting to retrieve from client"
- )
# Determine whether there are additional compression related
# dependencies.
# Typically, dependencies are only relevant to model layers
tensor_dependencies = self.tensor_codec.find_dependencies(
tensor_key, self.delta_updates
)
+ logger.debug(
+ "Unable to get tensor from local store..."
+ "attempting to retrieve from client len tensor_dependencies"
+ f" tensor_key {tensor_key}"
+ )
if len(tensor_dependencies) > 0:
# Resolve dependencies
# tensor_dependencies[0] corresponds to the prior version
@@ -411,10 +413,10 @@ def get_data_for_tensorkey(self, tensor_key):
self.tensor_db.cache_tensor({new_model_tk: nparray})
else:
logger.info(
"Count not find previous model layer."
"Could not find previous model layer."
"Fetching latest layer from aggregator"
)
- # The original model tensor should be fetched from client
+ # The original model tensor should be fetched from aggregator
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
@@ -423,6 +425,18 @@ def get_data_for_tensorkey(self, tensor_key):
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
+ else:
+ # we should try fetching the tensor from aggregator
+ tensor_name, origin, round_number, report, tags = tensor_key
+ tags = (self.collaborator_name,) + tags
+ tensor_key = (tensor_name, origin, round_number, report, tags)
+ logger.info(
+ "Could not find previous model layer."
+ f"Fetching latest layer from aggregator {tensor_key}"
+ )
+ nparray = self.get_aggregated_tensor_from_aggregator(
+ tensor_key, require_lossless=True
+ )
else:
logger.debug("Found tensor %s in local TensorDB", tensor_key)

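A minimal sketch of the new fallback above, for reference: when no prior version or dependency can be resolved locally, the patch rebuilds the requested TensorKey with the collaborator's own name prepended to its tags before calling get_aggregated_tensor_from_aggregator. The helper name and example values below are illustrative only; the real code does this inline in Collaborator.get_data_for_tensorkey using self.collaborator_name.

# Sketch only, not openfl code: retag a 5-tuple TensorKey for an aggregator fetch.
def retag_for_collaborator(tensor_key, collaborator_name):
    """Prepend the collaborator name to the tags of (name, origin, round, report, tags)."""
    tensor_name, origin, round_number, report, tags = tensor_key
    tags = (collaborator_name,) + tags
    return (tensor_name, origin, round_number, report, tags)

# Hypothetical example values:
original_key = ("conv1/kernel", "aggregator", 3, False, ("trained",))
print(retag_for_collaborator(original_key, "collaborator_1"))
# -> ('conv1/kernel', 'aggregator', 3, False, ('collaborator_1', 'trained'))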
11 changes: 10 additions & 1 deletion openfl/federated/task/runner_keras.py
@@ -182,7 +182,16 @@ def train_(self, batch_generator, metrics: list = None, **kwargs):
# initialization (build_model).
# If metrics are added (i.e. not a subset of what was originally
# defined) then the model must be recompiled.
- results = self.model.get_metrics_result()
+ try:
+ results = self.model.get_metrics_result()
+ except ValueError:
+ if "batch_size" in kwargs:
+ batch_size = kwargs["batch_size"]
+ else:
+ batch_size = 1
+ # evaluation needed before metrics can be resolved
+ self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1)
Contributor comment:

For other reviewers' reference, this is needed for the scenario where the collaborator dies between the aggregated_model_validation and train tasks. If the collaborator restarts, it is immediately assigned train and will not be able to resolve the necessary metrics. The names of the needed metrics can be resolved by running the evaluate function on the model, but these metric values will not be sent to the aggregator.
+ results = self.model.get_metrics_result()

# TODO if there are new metrics in the flplan that were not included
# in the originally
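The contributor comment above explains the failure mode this guards against: a collaborator that dies between aggregated_model_validation and train is handed the train task immediately after restart, before Keras has any metric results to report, so a throwaway evaluate() pass is used purely to build the metric objects. A standalone sketch of that recovery pattern follows, assuming TensorFlow 2.11+ (where Model.get_metrics_result() is available); the toy model, data, and metrics_after_restart helper are illustrative stand-ins for KerasTaskRunner.train_ and its data loader.

# Sketch only, not openfl code.
import numpy as np
import tensorflow as tf

def metrics_after_restart(model, valid_x, valid_y, batch_size=1):
    """Return the model's metric results, priming them with evaluate() if needed."""
    try:
        return model.get_metrics_result()
    except ValueError:
        # Metrics can only be resolved after at least one evaluation/training pass.
        model.evaluate(valid_x, valid_y, batch_size=batch_size, verbose=1)
        return model.get_metrics_result()

# Hypothetical toy model and validation data.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])
x, y = np.random.rand(8, 4), np.random.rand(8, 1)
print(metrics_after_restart(model, x, y, batch_size=4))

If a given Keras version does not raise ValueError here, the try branch simply returns the results directly and the extra evaluate() pass is skipped, which matches the behavior of the patch.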