diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py index 95876b76e665e..7e19d383ebbc5 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -27,10 +27,11 @@ from airflow.utils.session import create_session if TYPE_CHECKING: - from collections.abc import Iterator, Mapping + from collections.abc import Collection, Iterator, Mapping from airflow.io.path import ObjectStoragePath - from airflow.models.dag import ScheduleArg + from airflow.models.dag import DagStateChangeCallback, ScheduleArg + from airflow.models.param import ParamsDict from airflow.triggers.base import BaseTrigger @@ -76,22 +77,32 @@ class AssetDefinition(Asset): :meta private: """ - function: Callable - schedule: ScheduleArg + _function: Callable + _source: asset def __attrs_post_init__(self) -> None: from airflow.models.dag import DAG - with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True): + with DAG( + dag_id=self.name, + schedule=self._source.schedule, + is_paused_upon_creation=self._source.is_paused_upon_creation, + dag_display_name=self._source.display_name or self.name, + description=self._source.description, + params=self._source.params, + on_success_callback=self._source.on_success_callback, + on_failure_callback=self._source.on_failure_callback, + auto_register=True, + ): _AssetMainOperator( task_id="__main__", inlets=[ AssetRef(name=inlet_asset_name) - for inlet_asset_name in inspect.signature(self.function).parameters + for inlet_asset_name in inspect.signature(self._function).parameters if inlet_asset_name not in ("self", "context") ], outlets=[self], - python_callable=self.function, + python_callable=self._function, definition_name=self.name, uri=self.uri, ) @@ -101,12 +112,24 @@ def __attrs_post_init__(self) -> None: class asset: """Create an asset by decorating a materialization function.""" - schedule: ScheduleArg uri: str | ObjectStoragePath | None = None group: str = Asset.asset_type extra: dict[str, Any] = attrs.field(factory=dict) watchers: list[BaseTrigger] = attrs.field(factory=list) + schedule: ScheduleArg + is_paused_upon_creation: bool | None = None + + display_name: str | None = None + description: str | None = None + + params: ParamsDict | None = None + on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None + on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None + + access_control: dict[str, dict[str, Collection[str]]] | None = None + owner_links: dict[str, str] | None = None + def __call__(self, f: Callable) -> AssetDefinition: if self.schedule is not None: raise NotImplementedError("asset scheduling not implemented yet") @@ -117,9 +140,6 @@ def __call__(self, f: Callable) -> AssetDefinition: return AssetDefinition( name=name, uri=name if self.uri is None else str(self.uri), - group=self.group, - extra=self.extra, - watchers=self.watchers, function=f, - schedule=self.schedule, + source=self, )