Skip to content

Commit

Permalink
add include flag (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt authored Nov 21, 2024
1 parent 1b99e8a commit d2bd87d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
10 changes: 8 additions & 2 deletions bofire/data_models/domain/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,23 @@ def get_by_key(self, key: str) -> F:
"""
return {f.key: f for f in self.features}[key]

def get_by_keys(self, keys: Sequence[str]) -> Self:
def get_by_keys(self, keys: Sequence[str], include: bool = True) -> Self:
"""Get features of the domain specified by its keys.
Args:
keys: List of the keys of the features that should be returned.
include: Boolean to distinguish if the features with the keys in the
list should be included or excluded.
Returns:
Features: Features object with the requested features.
"""
return self.__class__(features=sorted([self.get_by_key(key) for key in keys]))
if include:
features = [self.get_by_key(key) for key in keys]
else:
features = [f for f in self.features if f.key not in keys]
return self.__class__(features=sorted(features))

def get(
self,
Expand Down
9 changes: 8 additions & 1 deletion tests/bofire/data_models/domain/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,20 @@ def test_features_get_by_key(features, key, expected):
assert id(returned) == id(expected)


def test_features_get_by_keys():
def test_features_get_by_keys_include():
keys = ["of2", "if1"]
feats = features.get_by_keys(keys)
assert feats[0].key == "if1"
assert feats[1].key == "of2"


def test_features_get_by_keys_exclude():
keys = ["of2", "if1"]
feats = features.get_by_keys(keys, include=False)
assert feats[0].key == "if2"
assert feats[1].key == "of1"


@pytest.mark.parametrize(
"features, key",
[
Expand Down

0 comments on commit d2bd87d

Please sign in to comment.