Skip to content

Commit

Permalink
Add concept of AssetWatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Jan 17, 2025
1 parent eba35e0 commit f64e72d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 39 deletions.
10 changes: 6 additions & 4 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from airflow.models.dagwarning import DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef, AssetWatcher
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
Expand Down Expand Up @@ -742,10 +742,12 @@ def add_asset_trigger_references(

for name_uri, asset in self.assets.items():
# If the asset belong to a DAG not active or paused, consider there is no watcher associated to it
asset_watchers: list[dict] = cast(list[dict], asset.watchers) if name_uri in active_assets else []
asset_watchers: list[AssetWatcher] = asset.watchers if name_uri in active_assets else []
trigger_hash_to_trigger_dict: dict[int, dict] = {
self._get_trigger_hash(trigger["classpath"], trigger["kwargs"]): trigger
for trigger in asset_watchers
self._get_trigger_hash(
cast(dict, watcher.trigger)["classpath"], cast(dict, watcher.trigger)["kwargs"]
): cast(dict, watcher.trigger)
for watcher in asset_watchers
}
triggers.update(trigger_hash_to_trigger_dict)
trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())
Expand Down
23 changes: 20 additions & 3 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
AssetAny,
AssetRef,
AssetUniqueKey,
AssetWatcher,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
Expand Down Expand Up @@ -254,6 +255,12 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
"""
if isinstance(var, Asset):

def _encode_watcher(watcher: AssetWatcher):
return {
"name": watcher.name,
"trigger": _encode_trigger(watcher.trigger),
}

def _encode_trigger(trigger: BaseTrigger | dict):
if isinstance(trigger, dict):
return trigger
Expand All @@ -272,7 +279,7 @@ def _encode_trigger(trigger: BaseTrigger | dict):
}

if len(var.watchers) > 0:
asset["watchers"] = [_encode_trigger(trigger) for trigger in var.watchers]
asset["watchers"] = [_encode_watcher(watcher) for watcher in var.watchers]

return asset
if isinstance(var, AssetAlias):
Expand Down Expand Up @@ -300,12 +307,16 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
serialized_watchers = var["watchers"] if "watchers" in var else []
return Asset(
name=var["name"],
uri=var["uri"],
group=var["group"],
extra=var["extra"],
watchers=var["watchers"] if "watchers" in var else [],
watchers=[
AssetWatcher(name=watcher["name"], trigger=watcher["trigger"])
for watcher in serialized_watchers
],
)
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
Expand Down Expand Up @@ -897,7 +908,13 @@ def deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.XCOM_REF:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.ASSET:
return Asset(**var)
watchers = var.pop("watchers", [])
return Asset(
**var,
watchers=[
AssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers
],
)
elif type_ == DAT.ASSET_ALIAS:
return AssetAlias(**var)
elif type_ == DAT.ASSET_ANY:
Expand Down
26 changes: 17 additions & 9 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"AssetNameRef",
"AssetRef",
"AssetUriRef",
"AssetWatcher",
]


Expand Down Expand Up @@ -257,6 +258,17 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
raise NotImplementedError


@attrs.define(frozen=True)
class AssetWatcher:
"""A representation of an asset watcher. The name uniquely identity the watch."""

name: str
# This attribute serves double purpose. For a "normal" asset instance
# loaded from DAG, this holds the trigger used to monitor an external resource.
# For an asset recreated from a serialized DAG, however, this holds the serialized data of the trigger.
trigger: BaseTrigger | dict


@attrs.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""
Expand All @@ -276,11 +288,7 @@ class Asset(os.PathLike, BaseAsset):
factory=dict,
converter=_set_extra_default,
)
# This attribute serves double purpose. For a "normal" asset instance
# loaded from DAG, this holds the list of triggers used to monitor an external resource.
# For an asset recreated from a serialized DAG, however, this holds the serialized data of the list of
# triggers.
watchers: list[BaseTrigger | dict] = attrs.field(
watchers: list[AssetWatcher] = attrs.field(
factory=list,
)

Expand All @@ -295,7 +303,7 @@ def __init__(
*,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger | dict] = ...,
watchers: list[AssetWatcher] = ...,
) -> None:
"""Canonical; both name and uri are provided."""

Expand All @@ -306,7 +314,7 @@ def __init__(
*,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger | dict] = ...,
watchers: list[AssetWatcher] = ...,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

Expand All @@ -317,7 +325,7 @@ def __init__(
uri: str,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger | dict] = ...,
watchers: list[AssetWatcher] = ...,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

Expand All @@ -328,7 +336,7 @@ def __init__(
*,
group: str | None = None,
extra: dict | None = None,
watchers: list[BaseTrigger | dict] | None = None,
watchers: list[AssetWatcher] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand Down
5 changes: 3 additions & 2 deletions tests/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset import Asset, AssetWatcher
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.utils import timezone as tz
from airflow.utils.session import create_session
Expand Down Expand Up @@ -133,7 +133,8 @@ def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_t
trigger = TimeDeltaTrigger(timedelta(seconds=0))
classpath, kwargs = trigger.serialize()
asset = Asset(
"test_add_asset_trigger_references_asset", watchers=[{"classpath": classpath, "kwargs": kwargs}]
"test_add_asset_trigger_references_asset",
watchers=[AssetWatcher(name="test", trigger={"classpath": classpath, "kwargs": kwargs})],
)

with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag:
Expand Down
28 changes: 7 additions & 21 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.triggers.file import FileTrigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, AssetWatcher
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
Expand Down Expand Up @@ -171,24 +171,6 @@ def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor)
return a.key == b.key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events


def equal_asset(a: Asset, b: Asset) -> bool:
def _serialize_trigger(trigger: BaseTrigger | dict):
if isinstance(trigger, dict):
return trigger

classpath, kwargs = trigger.serialize()
return {
"classpath": classpath,
"kwargs": kwargs,
}

a_watchers = [_serialize_trigger(watcher) for watcher in a.watchers]
b_watchers = b.watchers
a.watchers = []
b.watchers = []
return a == b and a_watchers == b_watchers


class MockLazySelectSequence(LazySelectSequence):
_data = ["a", "b", "c"]

Expand Down Expand Up @@ -274,9 +256,13 @@ def __len__(self) -> int:
),
(Asset(uri="test://asset1", name="test"), DAT.ASSET, equals),
(
Asset(uri="test://asset1", name="test", watchers=[FileTrigger(filepath="/tmp")]),
Asset(
uri="test://asset1",
name="test",
watchers=[AssetWatcher(name="test", trigger=FileTrigger(filepath="/tmp"))],
),
DAT.ASSET,
equal_asset,
equals,
),
(SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
(
Expand Down

0 comments on commit f64e72d

Please sign in to comment.