diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index f559d68f945fd..a9f07573d1046 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -52,7 +52,7 @@ from airflow.models.dagwarning import DagWarningType from airflow.models.errors import ParseImportError from airflow.models.trigger import Trigger -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef, AssetWatcher from airflow.serialization.serialized_objects import BaseSerialization from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries from airflow.utils.sqlalchemy import with_row_locks @@ -742,10 +742,12 @@ def add_asset_trigger_references( for name_uri, asset in self.assets.items(): # If the asset belong to a DAG not active or paused, consider there is no watcher associated to it - asset_watchers: list[dict] = cast(list[dict], asset.watchers) if name_uri in active_assets else [] + asset_watchers: list[AssetWatcher] = asset.watchers if name_uri in active_assets else [] trigger_hash_to_trigger_dict: dict[int, dict] = { - self._get_trigger_hash(trigger["classpath"], trigger["kwargs"]): trigger - for trigger in asset_watchers + self._get_trigger_hash( + cast(dict, watcher.trigger)["classpath"], cast(dict, watcher.trigger)["kwargs"] + ): cast(dict, watcher.trigger) + for watcher in asset_watchers } triggers.update(trigger_hash_to_trigger_dict) trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys()) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d24185387e964..b3a38c5470d7f 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -61,6 +61,7 @@ AssetAny, AssetRef, AssetUniqueKey, + AssetWatcher, BaseAsset, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator @@ -254,6 +255,12 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]: """ if isinstance(var, Asset): + def _encode_watcher(watcher: AssetWatcher): + return { + "name": watcher.name, + "trigger": _encode_trigger(watcher.trigger), + } + def _encode_trigger(trigger: BaseTrigger | dict): if isinstance(trigger, dict): return trigger @@ -272,7 +279,7 @@ def _encode_trigger(trigger: BaseTrigger | dict): } if len(var.watchers) > 0: - asset["watchers"] = [_encode_trigger(trigger) for trigger in var.watchers] + asset["watchers"] = [_encode_watcher(watcher) for watcher in var.watchers] return asset if isinstance(var, AssetAlias): @@ -300,12 +307,16 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: """ dat = var["__type"] if dat == DAT.ASSET: + serialized_watchers = var["watchers"] if "watchers" in var else [] return Asset( name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"], - watchers=var["watchers"] if "watchers" in var else [], + watchers=[ + AssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) + for watcher in serialized_watchers + ], ) if dat == DAT.ASSET_ALL: return AssetAll(*(decode_asset_condition(x) for x in var["objects"])) @@ -897,7 +908,13 @@ def deserialize(cls, encoded_var: Any) -> Any: elif type_ == DAT.XCOM_REF: return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG. elif type_ == DAT.ASSET: - return Asset(**var) + watchers = var.pop("watchers", []) + return Asset( + **var, + watchers=[ + AssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers + ], + ) elif type_ == DAT.ASSET_ALIAS: return AssetAlias(**var) elif type_ == DAT.ASSET_ANY: diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index b9d7f03db0a98..3cfbbcbd24430 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -53,6 +53,7 @@ "AssetNameRef", "AssetRef", "AssetUriRef", + "AssetWatcher", ] @@ -257,6 +258,17 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe raise NotImplementedError +@attrs.define(frozen=True) +class AssetWatcher: + """A representation of an asset watcher. The name uniquely identity the watch.""" + + name: str + # This attribute serves double purpose. For a "normal" asset instance + # loaded from DAG, this holds the trigger used to monitor an external resource. + # For an asset recreated from a serialized DAG, however, this holds the serialized data of the trigger. + trigger: BaseTrigger | dict + + @attrs.define(init=False, unsafe_hash=False) class Asset(os.PathLike, BaseAsset): """A representation of data asset dependencies between workflows.""" @@ -276,11 +288,7 @@ class Asset(os.PathLike, BaseAsset): factory=dict, converter=_set_extra_default, ) - # This attribute serves double purpose. For a "normal" asset instance - # loaded from DAG, this holds the list of triggers used to monitor an external resource. - # For an asset recreated from a serialized DAG, however, this holds the serialized data of the list of - # triggers. - watchers: list[BaseTrigger | dict] = attrs.field( + watchers: list[AssetWatcher] = attrs.field( factory=list, ) @@ -295,7 +303,7 @@ def __init__( *, group: str = ..., extra: dict | None = None, - watchers: list[BaseTrigger | dict] = ..., + watchers: list[AssetWatcher] = ..., ) -> None: """Canonical; both name and uri are provided.""" @@ -306,7 +314,7 @@ def __init__( *, group: str = ..., extra: dict | None = None, - watchers: list[BaseTrigger | dict] = ..., + watchers: list[AssetWatcher] = ..., ) -> None: """It's possible to only provide the name, either by keyword or as the only positional argument.""" @@ -317,7 +325,7 @@ def __init__( uri: str, group: str = ..., extra: dict | None = None, - watchers: list[BaseTrigger | dict] = ..., + watchers: list[AssetWatcher] = ..., ) -> None: """It's possible to only provide the URI as a keyword argument.""" @@ -328,7 +336,7 @@ def __init__( *, group: str | None = None, extra: dict | None = None, - watchers: list[BaseTrigger | dict] | None = None, + watchers: list[AssetWatcher] | None = None, ) -> None: if name is None and uri is None: raise TypeError("Asset() requires either 'name' or 'uri'") diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py index fb4ee5216abc5..a92c0023c8811 100644 --- a/tests/dag_processing/test_collection.py +++ b/tests/dag_processing/test_collection.py @@ -50,7 +50,7 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger -from airflow.sdk.definitions.asset import Asset +from airflow.sdk.definitions.asset import Asset, AssetWatcher from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.utils import timezone as tz from airflow.utils.session import create_session @@ -133,7 +133,8 @@ def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_t trigger = TimeDeltaTrigger(timedelta(seconds=0)) classpath, kwargs = trigger.serialize() asset = Asset( - "test_add_asset_trigger_references_asset", watchers=[{"classpath": classpath, "kwargs": kwargs}] + "test_add_asset_trigger_references_asset", + watchers=[AssetWatcher(name="test", trigger={"classpath": classpath, "kwargs": kwargs})], ) with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag: diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index f48442238f883..8a15564dce17a 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -43,7 +43,7 @@ from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.triggers.file import FileTrigger -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, AssetWatcher from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger @@ -171,24 +171,6 @@ def equal_outlet_event_accessor(a: OutletEventAccessor, b: OutletEventAccessor) return a.key == b.key and a.extra == b.extra and a.asset_alias_events == b.asset_alias_events -def equal_asset(a: Asset, b: Asset) -> bool: - def _serialize_trigger(trigger: BaseTrigger | dict): - if isinstance(trigger, dict): - return trigger - - classpath, kwargs = trigger.serialize() - return { - "classpath": classpath, - "kwargs": kwargs, - } - - a_watchers = [_serialize_trigger(watcher) for watcher in a.watchers] - b_watchers = b.watchers - a.watchers = [] - b.watchers = [] - return a == b and a_watchers == b_watchers - - class MockLazySelectSequence(LazySelectSequence): _data = ["a", "b", "c"] @@ -274,9 +256,13 @@ def __len__(self) -> int: ), (Asset(uri="test://asset1", name="test"), DAT.ASSET, equals), ( - Asset(uri="test://asset1", name="test", watchers=[FileTrigger(filepath="/tmp")]), + Asset( + uri="test://asset1", + name="test", + watchers=[AssetWatcher(name="test", trigger=FileTrigger(filepath="/tmp"))], + ), DAT.ASSET, - equal_asset, + equals, ), (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals), (