Skip to content

Commit

Permalink
AIP-82 Save references between assets and triggers (apache#43826)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored Nov 25, 2024
1 parent b4c4806 commit bee7f0c
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 16 deletions.
96 changes: 96 additions & 0 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
)
from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag
from airflow.models.dagrun import DagRun
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.triggers.base import BaseTrigger
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType
Expand Down Expand Up @@ -425,3 +427,97 @@ def add_task_asset_references(
for task_id, asset_id in referenced_outlets
if (task_id, asset_id) not in orm_refs
)

def add_asset_trigger_references(
    self, assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
    """
    Sync asset-to-trigger references in the DB with the watchers declared on the collected assets.

    For every asset in ``self.assets``, the watcher triggers (identified by their
    ``repr``) are compared with the triggers attached to the matching ``AssetModel``.
    Missing references are added (creating ``Trigger`` rows when no matching row is
    found), stale references are removed, and asset models that have triggers but no
    consuming DAG and are absent from ``self.assets`` lose all their triggers.

    :param assets: ORM asset models keyed by ``(name, uri)``; must contain an entry
        for every key of ``self.assets``.
    :param session: SQLAlchemy session used to query and stage the changes. This
        method does not call ``flush`` or ``commit`` itself.
    """
    # Update references from assets being used
    refs_to_add: dict[tuple[str, str], set[str]] = {}
    refs_to_remove: dict[tuple[str, str], set[str]] = {}
    triggers: dict[str, BaseTrigger] = {}
    for name_uri, asset in self.assets.items():
        asset_model = assets[name_uri]
        trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = {
            repr(trigger): trigger for trigger in asset.watchers
        }
        triggers.update(trigger_repr_to_trigger_dict)
        trigger_repr_from_asset: set[str] = set(trigger_repr_to_trigger_dict)

        trigger_repr_from_asset_model: set[str] = {
            BaseTrigger.repr(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
        }

        # Optimization: no diff between the DB and DAG definitions, no update needed
        if trigger_repr_from_asset == trigger_repr_from_asset_model:
            continue

        diff_to_add = trigger_repr_from_asset - trigger_repr_from_asset_model
        diff_to_remove = trigger_repr_from_asset_model - trigger_repr_from_asset
        if diff_to_add:
            refs_to_add[name_uri] = diff_to_add
        if diff_to_remove:
            refs_to_remove[name_uri] = diff_to_remove

    if refs_to_add:
        all_trigger_reprs: set[str] = set().union(*refs_to_add.values())

        # NOTE(review): this lookup only finds existing rows if Trigger.encrypt_kwargs
        # yields the same ciphertext for the same kwargs on every call. If the
        # encryption is non-deterministic, no row will ever match and a new Trigger
        # is created each time -- confirm against the encrypt_kwargs implementation.
        all_trigger_keys: set[tuple[str, str]] = {
            self._encrypt_trigger_kwargs(triggers[trigger_repr]) for trigger_repr in all_trigger_reprs
        }
        orm_triggers: dict[str, Trigger] = {
            BaseTrigger.repr(trigger.classpath, trigger.kwargs): trigger
            for trigger in session.scalars(
                select(Trigger).where(
                    tuple_(Trigger.classpath, Trigger.encrypted_kwargs).in_(all_trigger_keys)
                )
            )
        }

        # Create new triggers. They are keyed by the repr they were requested under
        # (rather than a repr recomputed from the stored kwargs) so that the lookup
        # below can never miss and silently append ``None`` to the relationship.
        new_trigger_models: dict[str, Trigger] = {
            trigger_repr: Trigger.from_object(triggers[trigger_repr])
            for trigger_repr in all_trigger_reprs
            if trigger_repr not in orm_triggers
        }
        session.add_all(new_trigger_models.values())
        orm_triggers.update(new_trigger_models)

        # Add new references
        for name_uri, trigger_reprs in refs_to_add.items():
            asset_model = assets[name_uri]
            asset_model.triggers.extend([orm_triggers[trigger_repr] for trigger_repr in trigger_reprs])

    # Remove old references
    for name_uri, trigger_reprs in refs_to_remove.items():
        asset_model = assets[name_uri]
        asset_model.triggers = [
            trigger
            for trigger in asset_model.triggers
            if BaseTrigger.repr(trigger.classpath, trigger.kwargs) not in trigger_reprs
        ]

    # Remove references from assets no longer used
    orphan_assets = session.scalars(
        select(AssetModel).filter(~AssetModel.consuming_dags.any()).filter(AssetModel.triggers.any())
    )
    for asset_model in orphan_assets:
        if (asset_model.name, asset_model.uri) not in self.assets:
            asset_model.triggers = []

@staticmethod
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
    """
    Build the ``(classpath, encrypted kwargs)`` pair for a trigger.

    The classpath and kwargs come from ``trigger.serialize()``; the kwargs are
    then passed through ``Trigger.encrypt_kwargs``.
    """
    trigger_classpath, trigger_kwargs = trigger.serialize()
    encrypted = Trigger.encrypt_kwargs(trigger_kwargs)
    return trigger_classpath, encrypted
1 change: 1 addition & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,7 @@ def bulk_write_to_db(
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_asset_trigger_references(orm_assets, session=session)
session.flush()

@provide_session
Expand Down
10 changes: 4 additions & 6 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.serialization.pydantic.trigger import TriggerPydantic
from airflow.triggers.base import BaseTrigger


Expand Down Expand Up @@ -89,7 +88,7 @@ def __init__(
) -> None:
super().__init__()
self.classpath = classpath
self.encrypted_kwargs = self._encrypt_kwargs(kwargs)
self.encrypted_kwargs = self.encrypt_kwargs(kwargs)
self.created_date = created_date or timezone.utcnow()

@property
Expand All @@ -100,10 +99,10 @@ def kwargs(self) -> dict[str, Any]:
@kwargs.setter
def kwargs(self, kwargs: dict[str, Any]) -> None:
"""Set the encrypted kwargs of the trigger."""
self.encrypted_kwargs = self._encrypt_kwargs(kwargs)
self.encrypted_kwargs = self.encrypt_kwargs(kwargs)

@staticmethod
def _encrypt_kwargs(kwargs: dict[str, Any]) -> str:
def encrypt_kwargs(kwargs: dict[str, Any]) -> str:
"""Encrypt the kwargs of the trigger."""
import json

Expand Down Expand Up @@ -141,8 +140,7 @@ def rotate_fernet_key(self):

@classmethod
@internal_api_call
@provide_session
def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger | TriggerPydantic:
def from_object(cls, trigger: BaseTrigger) -> Trigger:
"""Alternative constructor that creates a trigger row based directly off of a Trigger object."""
classpath, kwargs = trigger.serialize()
return cls(classpath=classpath, kwargs=kwargs)
Expand Down
4 changes: 2 additions & 2 deletions airflow/serialization/pydantic/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, **kwargs) -> None:
# created_date
if "kwargs" in kwargs:
self.classpath = kwargs.pop("classpath")
self.encrypted_kwargs = Trigger._encrypt_kwargs(kwargs.pop("kwargs"))
self.encrypted_kwargs = Trigger.encrypt_kwargs(kwargs.pop("kwargs"))
self.created_date = kwargs.pop("created_date", timezone.utcnow())
super().__init__(**kwargs)

Expand All @@ -60,4 +60,4 @@ def kwargs(self, kwargs: dict[str, Any]) -> None:
"""Set the encrypted kwargs of the trigger."""
from airflow.models import Trigger

self.encrypted_kwargs = Trigger._encrypt_kwargs(kwargs)
self.encrypted_kwargs = Trigger.encrypt_kwargs(kwargs)
8 changes: 6 additions & 2 deletions airflow/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,15 @@ async def cleanup(self) -> None:
and handle it appropriately (in async-compatible way).
"""

def __repr__(self) -> str:
classpath, kwargs = self.serialize()
@staticmethod
def repr(classpath: str, kwargs: dict[str, Any]):
    """
    Render a trigger as ``<classpath key1=value1, key2=value2>``.

    :param classpath: Text placed right after the opening ``<``.
    :param kwargs: Mapping rendered as comma-separated ``key=value`` pairs, in
        the mapping's iteration order.
    """
    rendered_pairs = [f"{k}={v}" for k, v in kwargs.items()]
    kwargs_str = ", ".join(rendered_pairs)
    return f"<{classpath} {kwargs_str}>"

def __repr__(self) -> str:
    """Return the output of ``self.repr`` applied to the two values from ``self.serialize()``."""
    serialized_classpath, serialized_kwargs = self.serialize()
    return self.repr(serialized_classpath, serialized_kwargs)


class TriggerEvent:
"""
Expand Down
33 changes: 30 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

from sqlalchemy.orm.session import Session

from airflow.triggers.base import BaseTrigger


__all__ = [
"Asset",
Expand Down Expand Up @@ -223,20 +225,43 @@ class Asset(os.PathLike, BaseAsset):
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger]

asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1

@overload
def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
uri: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""Canonical; both name and uri are provided."""

@overload
def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

@overload
def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
*,
uri: str,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

def __init__(
Expand All @@ -246,6 +271,7 @@ def __init__(
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand All @@ -258,6 +284,7 @@ def __init__(
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

def __fspath__(self) -> str:
return self.uri
Expand Down
3 changes: 3 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

if TYPE_CHECKING:
from airflow.io.path import ObjectStoragePath
from airflow.triggers.base import BaseTrigger


class _AssetMainOperator(PythonOperator):
Expand Down Expand Up @@ -121,6 +122,7 @@ class asset:
uri: str | ObjectStoragePath | None = None
group: str = ""
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)

def __call__(self, f: Callable) -> AssetDefinition:
if (name := f.__name__) != f.__qualname__:
Expand All @@ -131,6 +133,7 @@ def __call__(self, f: Callable) -> AssetDefinition:
uri=name if self.uri is None else str(self.uri),
group=self.group,
extra=self.extra,
watchers=self.watchers,
function=f,
schedule=self.schedule,
)
6 changes: 3 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,9 +552,9 @@ def __lt__(self, other):
def __hash__(self):
hash_components: list[Any] = [type(self)]
for c in _DAG_HASH_ATTRS:
# task_ids returns a list and lists can't be hashed
if c == "task_ids":
val = tuple(self.task_dict)
# If it is a list, convert to tuple because lists can't be hashed
if isinstance(getattr(self, c, None), list):
val = tuple(getattr(self, c))
else:
val = getattr(self, c, None)
try:
Expand Down

0 comments on commit bee7f0c

Please sign in to comment.