Skip to content

Commit

Permalink
- Fix the leaky abstraction of task_group to Aggregator
Browse files Browse the repository at this point in the history
- Added is_task_group_evaluation function in Assigner class

Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant committed Jan 23, 2025
1 parent 57f5094 commit f40af12
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
11 changes: 4 additions & 7 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(
callbacks: Optional[List] = None,
persist_checkpoint=True,
persistent_db_path=None,
task_group: str = "learning",
):
"""Initializes the Aggregator.
Expand All @@ -111,9 +110,7 @@ def __init__(
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
task_group (str, optional): Selected task_group for assignment.
"""
self.task_group = task_group
self.round_number = 0
self.next_model_round_number = 0

Expand All @@ -132,10 +129,10 @@ def __init__(
)

self.rounds_to_train = rounds_to_train
if self.task_group == "evaluation":
if self.assigner.is_task_group_evaluation():
self.rounds_to_train = 1
logger.info(
f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}"
f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}"
)

self._end_of_round_check_done = [False] * rounds_to_train
Expand Down Expand Up @@ -311,8 +308,8 @@ def _load_initial_tensors(self):
)

# Check selected task_group before updating round number
if self.task_group == "evaluation":
logger.info(f"Skipping round_number check for {self.task_group} task_group")
if self.assigner.is_task_group_evaluation():
logger.info("Skipping round_number check for evaluation run")
elif round_number > self.round_number:
logger.info(f"Starting training from round {round_number} of previously saved model")
self.round_number = round_number
Expand Down
10 changes: 10 additions & 0 deletions openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def get_collaborators_for_task(self, task_name, round_number):
"""Abstract method."""
raise NotImplementedError

def is_task_group_evaluation(self):
"""Check if the selected task group is for 'evaluation' run.
Returns:
bool: True if the selected task group is 'evaluation', False otherwise.
"""
if hasattr(self, "selected_task_group"):
return self.selected_task_group == 'evaluation'
return False

def get_all_tasks_for_round(self, round_number):
"""Return tasks for the current round.
Expand Down
1 change: 0 additions & 1 deletion openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def start_(plan, authorized_cols, task_group):
# Set task_group in aggregator settings
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

Expand Down

0 comments on commit f40af12

Please sign in to comment.