Skip to content

Commit

Permalink
- implement a new task_group filtering decorator in Assigner class
Browse files Browse the repository at this point in the history
- update all the sub-classes that use task_groups to use the decorator
- update fedeval sample workspace to use default assigner, tasks and aggregator
- use of federated-evaluation/aggregator.yaml for FedEval specific workspace example to use round_number as 1
- removed assigner and tasks yaml from defaults/federated-evaluation, superseded by default assigner/tasks
- Rebase 21-Jan-2025.2
- added additional checks for assigner sub-classes that might not have task_groups
- Addressing review comments
- Updated existing test cases for Assigner sub-classes
- Remove hard-coded setting in assigner for torch_cnn_mnist ws, refer to default as in other Workspaces
- Use aggregator supplied --task_group to override the assinger selected_task_group
Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant committed Jan 21, 2025
1 parent 8104144 commit 1a0ef16
Show file tree
Hide file tree
Showing 13 changed files with 81 additions and 32 deletions.
12 changes: 2 additions & 10 deletions openfl-workspace/torch_cnn_mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,8 @@ aggregator:
rounds_to_train: 2
write_logs: false
template: openfl.component.aggregator.Aggregator
assigner:
settings:
task_groups:
- name: learning
percentage: 1.0
tasks:
- aggregated_model_validation
- train
- locally_tuned_model_validation
template: openfl.component.RandomGroupedAssigner
assigner :
defaults : plan/defaults/assigner.yaml
collaborator:
settings:
db_store_rounds: 1
Expand Down
8 changes: 5 additions & 3 deletions openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ network :
defaults : plan/defaults/network.yaml

assigner :
defaults : plan/defaults/federated-evaluation/assigner.yaml

defaults : plan/defaults/assigner.yaml
settings :
selected_task_group : evaluation

tasks :
defaults : plan/defaults/federated-evaluation/tasks_torch.yaml
defaults : plan/defaults/tasks_torch.yaml

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
5 changes: 5 additions & 0 deletions openfl-workspace/workspace/plan/defaults/assigner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ settings :
- aggregated_model_validation
- train
- locally_tuned_model_validation
- name : evaluation
percentage : 1.0
tasks :
- aggregated_model_validation
selected_task_group: learning

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Component Module."""

from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.assigner.assigner import Assigner
Expand Down
4 changes: 4 additions & 0 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def __init__(
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
self.assigner = assigner
# override the assigner selected_task_group
# FIXME check the case of CustomAssigner as base class Assigner is redefined
# and doesn't have selected_task_group as attribute
self.assigner.selected_task_group = task_group
self.quit_job_sent_to = []

self.tensor_db = TensorDB()
Expand Down
1 change: 1 addition & 0 deletions openfl/component/assigner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Assigner Module."""

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
Expand Down
54 changes: 53 additions & 1 deletion openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

"""Assigner module."""

import logging
from functools import wraps

logger = logging.getLogger(__name__)


class Assigner:
r"""
Expand Down Expand Up @@ -35,18 +40,27 @@ class Assigner:
\* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file.
"""

def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs):
def __init__(
self,
tasks,
authorized_cols,
rounds_to_train,
selected_task_group: str = "learning",
**kwargs,
):
"""Initializes the Assigner.
Args:
tasks (list of object): List of tasks to assign.
authorized_cols (list of str): Collaborators.
rounds_to_train (int): Number of training rounds.
selected_task_group (str, optional): Selected task_group. Defaults to "learning".
**kwargs: Additional keyword arguments.
"""
self.tasks = tasks
self.authorized_cols = authorized_cols
self.rounds = rounds_to_train
self.selected_task_group = selected_task_group
self.all_tasks_in_groups = []

self.task_group_collaborators = {}
Expand Down Expand Up @@ -93,3 +107,41 @@ def get_aggregation_type_for_task(self, task_name):
if "aggregation_type" not in self.tasks[task_name]:
return None
return self.tasks[task_name]["aggregation_type"]

@classmethod
def task_group_filtering(cls, func):
"""Decorator to filter task groups based on selected_task_group.
This decorator should be applied to define_task_assignments() method
in Assigner subclasses to handle task_group filtering.
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# First check if selection of task_group is applicable
if hasattr(self, "selected_task_group"):
# Verify task_groups exists before attempting filtering
if not hasattr(self, "task_groups"):
logger.warning(
"Task group specified for selection but no task_groups found. "
"Skipping filtering. This might be intentional for custom assigners."
)
return func(self, *args, **kwargs)

assert self.task_groups, "No task_groups defined in assigner."

# Perform the filtering
self.task_groups = [
group for group in self.task_groups if group["name"] == self.selected_task_group
]

assert self.task_groups, f"No task groups found for : {self.selected_task_group}"

# Mode-specific validations
if self.selected_task_group == "evaluation":
assert self.rounds == 1, "Number of rounds should be 1 for evaluation"

# Call the original method
return func(self, *args, **kwargs)

return wrapper
7 changes: 5 additions & 2 deletions openfl/component/assigner/random_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner import Assigner


class RandomGroupedAssigner(Assigner):
Expand All @@ -33,16 +33,19 @@ class RandomGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the RandomGroupedAssigner.
Args:
task_groups (list of object): Task groups to assign.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments, including mode.
"""
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
3 changes: 3 additions & 0 deletions openfl/component/assigner/static_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class StaticGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the StaticGroupedAssigner.
Expand All @@ -42,6 +44,7 @@ def __init__(self, task_groups, **kwargs):
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def task_groups():
"""Initialize task groups."""
task_groups = [
{
'name': 'train_and_validate',
'name': 'learning',
'percentage': 1.0,
'tasks': [
'aggregated_model_validation',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def task_groups(authorized_cols):
"""Initialize task groups."""
task_groups = [
{
'name': 'train_and_validate',
'name': 'learning',
'percentage': 1.0,
'collaborators': authorized_cols,
'tasks': [
Expand Down

0 comments on commit 1a0ef16

Please sign in to comment.