Skip to content

Commit

Permalink
Add Task Feature (experimental-design#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt authored Feb 28, 2024
1 parent 4790d63 commit ccb0bc2
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 6 deletions.
22 changes: 20 additions & 2 deletions bofire/data_models/domain/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Input,
MolecularInput,
Output,
TaskInput,
)
from bofire.data_models.filters import filter_by_attribute, filter_by_class
from bofire.data_models.molfeatures.api import MolFeatures
Expand Down Expand Up @@ -181,6 +182,19 @@ class Inputs(Features):
type: Literal["Inputs"] = "Inputs"
features: Sequence[AnyInput] = Field(default_factory=lambda: [])

@field_validator("features")
@classmethod
def validate_only_one_task_input(cls, features: Sequence[AnyInput]):
    """Reject feature collections that hold more than one `TaskInput`.

    Args:
        features: The input features being validated.

    Returns:
        The unchanged `features` sequence when at most one `TaskInput`
        is found.

    Raises:
        ValueError: If `filter_by_class` (called with `exact=False`)
            returns more than one `TaskInput` match.
    """
    filtered = filter_by_class(
        features, includes=TaskInput, excludes=None, exact=False
    )
    # Zero or one task input is fine; hand the features back untouched.
    if len(filtered) <= 1:
        return features
    raise ValueError(f"Only one `TaskInput` is allowed, got {len(filtered)}.")

def get_fixed(self) -> "Inputs":
"""Gets all features in `self` that are fixed and returns them as new `Inputs` object.
Expand Down Expand Up @@ -702,7 +716,10 @@ def __call__(
]
+ [
(
pd.Series(data=feat(experiments.filter(regex=f"{feat.key}(.*)_prob")), name=f"{feat.key}_pred") # type: ignore
pd.Series(
data=feat(experiments.filter(regex=f"{feat.key}(.*)_prob")),
name=f"{feat.key}_pred",
) # type: ignore
if predictions
else experiments[feat.key]
)
Expand Down Expand Up @@ -766,7 +783,8 @@ def validate_candidates(self, candidates: pd.DataFrame) -> pd.DataFrame:
+ [
[f"{key}_pred", f"{key}_sd"]
for key in self.get_keys_by_objective(
excludes=Objective, includes=None # type: ignore
excludes=Objective,
includes=None, # type: ignore
)
]
)
Expand Down
3 changes: 3 additions & 0 deletions bofire/data_models/features/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
MolecularInput,
)
from bofire.data_models.features.numerical import NumericalInput
from bofire.data_models.features.task import TaskInput

AbstractFeature = Union[
Feature,
Expand All @@ -32,6 +33,7 @@
CategoricalDescriptorInput,
MolecularInput,
CategoricalMolecularInput,
TaskInput,
]

AnyInput = Union[
Expand All @@ -42,6 +44,7 @@
CategoricalDescriptorInput,
MolecularInput,
CategoricalMolecularInput,
TaskInput,
]

AnyOutput = Union[ContinuousOutput, CategoricalOutput]
3 changes: 1 addition & 2 deletions bofire/data_models/features/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ def __str__(self) -> str:

class CategoricalOutput(Output):
type: Literal["CategoricalOutput"] = "CategoricalOutput"
# order_id: ClassVar[int] = 8
order_id: ClassVar[int] = 9
order_id: ClassVar[int] = 10

categories: TCategoryVals
objective: AnyCategoricalObjective
Expand Down
3 changes: 1 addition & 2 deletions bofire/data_models/features/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ class ContinuousOutput(Output):
"""

type: Literal["ContinuousOutput"] = "ContinuousOutput"
# order_id: ClassVar[int] = 7
order_id: ClassVar[int] = 8
order_id: ClassVar[int] = 9
unit: Optional[str] = None

objective: Optional[AnyObjective] = Field(
Expand Down
28 changes: 28 additions & 0 deletions bofire/data_models/features/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import ClassVar, List, Literal

import numpy as np
from pydantic import model_validator

from bofire.data_models.features.categorical import CategoricalInput


class TaskInput(CategoricalInput):
    """Categorical input whose categories are tasks, each with a fidelity level.

    `fidelities` holds one integer per entry in `categories`. When it is left
    empty, every task is assigned fidelity 0 during validation.
    """

    order_id: ClassVar[int] = 8
    type: Literal["TaskInput"] = "TaskInput"
    # One fidelity per category; an empty list is filled with zeros on validation.
    fidelities: List[int] = []

    @model_validator(mode="after")
    def validate_fidelities(self) -> "TaskInput":
        """Fill in default fidelities and check that they are well formed.

        Returns:
            The validated instance.

        Raises:
            ValueError: If the number of fidelities differs from the number
                of categories, or if the distinct fidelity values are not
                exactly 0, 1, ..., max(fidelities).
        """
        n_tasks = len(self.categories)
        if self.fidelities == []:
            # Extend in place rather than re-assigning the field, so no
            # attribute assignment happens from inside the validator.
            self.fidelities.extend([0] * n_tasks)
        if len(self.fidelities) != n_tasks:
            raise ValueError(
                "Length of fidelity lists must be equal to the number of tasks"
            )
        # Sort the distinct levels explicitly: set iteration order is not
        # guaranteed, so comparing `list(set(...))` against a range is fragile.
        levels = sorted(set(self.fidelities))
        highest = int(np.max(self.fidelities))
        if levels != list(range(highest + 1)):
            raise ValueError(
                "Fidelities must be a list containing integers, starting from 0 and increasing by 1"
            )
        return self
9 changes: 9 additions & 0 deletions tests/bofire/data_models/features/test_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from bofire.data_models.features.api import TaskInput


def test_validate_fidelities_default_generation():
    """A `TaskInput` built without fidelities gets one zero per category."""
    task_categories = ["p1", "p2"]
    feat = TaskInput(key="task", categories=task_categories)
    assert feat.fidelities == [0, 0]
47 changes: 47 additions & 0 deletions tests/bofire/data_models/specs/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,50 @@
"allowed": [True, True, True, True],
},
)


# Valid TaskInput: three categories, all allowed, with one fidelity per
# category covering the levels 0, 1 and 2.
specs.add_valid(
features.TaskInput,
lambda: {
"key": str(uuid.uuid4()),
"categories": [
"process_1",
"process_2",
"process_3",
],
"allowed": [True, True, True],
"fidelities": [0, 1, 2],
},
)

# Invalid TaskInput: only two fidelities are given for three categories,
# so the length check is expected to raise.
specs.add_invalid(
features.TaskInput,
lambda: {
"key": str(uuid.uuid4()),
"categories": [
"process_1",
"process_2",
"process_3",
],
"allowed": [True, True, True],
"fidelities": [0, 1],
},
error=ValueError,
message="Length of fidelity lists must be equal to the number of tasks",
)

# Invalid TaskInput: the fidelities 0, 1, 3 skip level 2, so the
# "starting from 0 and increasing by 1" check is expected to raise.
specs.add_invalid(
features.TaskInput,
lambda: {
"key": str(uuid.uuid4()),
"categories": [
"process_1",
"process_2",
"process_3",
],
"allowed": [True, True, True],
"fidelities": [0, 1, 3],
},
error=ValueError,
message="Fidelities must be a list containing integers, starting from 0 and increasing by 1",
)
27 changes: 27 additions & 0 deletions tests/bofire/data_models/specs/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
CategoricalInput,
ContinuousInput,
ContinuousOutput,
TaskInput,
)
from tests.bofire.data_models.specs.specs import Specs

Expand All @@ -20,6 +21,19 @@
},
)

# Valid Inputs: a categorical and a continuous feature together with exactly
# one TaskInput, each passed as its `model_dump()` dictionary.
specs.add_valid(
Inputs,
lambda: {
"features": [
CategoricalInput(
key="a", categories=["1", "2"], allowed=[True, True]
).model_dump(),
ContinuousInput(key="b", bounds=(0, 1)).model_dump(),
TaskInput(key="c", categories=["a", "b", "c"]).model_dump(),
],
},
)


specs.add_invalid(
Inputs,
Expand All @@ -43,3 +57,16 @@
error=ValueError,
message="Feature keys are not unique.",
)

# Invalid Inputs: two TaskInput features (keys "b" and "c") are supplied,
# which is expected to be rejected with the "Only one `TaskInput`" error.
specs.add_invalid(
Inputs,
lambda: {
"features": [
CategoricalInput(key="a", categories=["1", "2"]),
TaskInput(key="b", categories=["a", "b", "c"]),
TaskInput(key="c", categories=["a", "b", "c"]),
],
},
error=ValueError,
message="Only one `TaskInput` is allowed, got 2.",
)

0 comments on commit ccb0bc2

Please sign in to comment.