Skip to content

Commit

Permalink
Move triggers to standard provider (apache#43608)
Browse files Browse the repository at this point in the history
* move triggers to standard provider

* move triggers to standard provider

* remove tests/triggers

* fix static checks

* ignore TaskSuccessEvent for 2.8 and 2.9 versions

* fix compat tests related to logical dates and TaskSuccessEvent

* fix test_external_task test

* move sensor helper inside standard provider

* fix external task test
  • Loading branch information
gopidesupavan authored Nov 25, 2024
1 parent 222dbdc commit 4404e64
Show file tree
Hide file tree
Showing 31 changed files with 217 additions and 86 deletions.
4 changes: 2 additions & 2 deletions .github/boring-cyborg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -652,12 +652,12 @@ labelPRBasedOnFilePath:
- airflow/cli/commands/triggerer_command.py
- airflow/jobs/triggerer_job_runner.py
- airflow/models/trigger.py
- airflow/triggers/**/*
- providers/src/airflow/providers/standard/triggers/**/*
- tests/cli/commands/test_triggerer_command.py
- tests/jobs/test_triggerer_job.py
- tests/models/test_trigger.py
- tests/jobs/test_triggerer_job_logging.py
- tests/triggers/**/*
- providers/tests/standard/triggers/**/*

area:Serialization:
- airflow/serialization/**/*
Expand Down
4 changes: 2 additions & 2 deletions airflow/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.external_task import WorkflowTrigger
from airflow.providers.standard.utils.sensor_helper import _get_count, _get_external_task_group_task_ids
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.external_task import WorkflowTrigger
from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.sensor_helper import _get_count, _get_external_task_group_task_ids
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState

Expand Down
1 change: 0 additions & 1 deletion dev/breeze/tests/test_pytest_args_for_test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@
"tests/template",
"tests/testconfig",
"tests/timetables",
"tests/triggers",
],
),
(
Expand Down
12 changes: 6 additions & 6 deletions docs/apache-airflow/authoring-and-scheduling/deferring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ When writing a deferrable operators these are the main points to consider:
from airflow.configuration import conf
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.utils.context import Context
Expand Down Expand Up @@ -122,7 +122,7 @@ This example shows the structure of a basic trigger, a very simplified version o
self.moment = moment
def serialize(self):
return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment})
return ("airflow.providers.standard.triggers.temporal.DateTimeTrigger", {"moment": self.moment})
async def run(self):
while self.moment > timezone.utcnow():
Expand Down Expand Up @@ -177,7 +177,7 @@ Here's a basic example of how a sensor might trigger deferral:
from typing import TYPE_CHECKING, Any
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -237,7 +237,7 @@ In the sensor part, we'll need to provide the path to ``TimeDeltaTrigger`` as ``
class WaitOneHourSensor(BaseSensorOperator):
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={"moment": timedelta(hours=1)},
next_method="execute_complete",
next_kwargs=None,
Expand Down Expand Up @@ -268,7 +268,7 @@ In the sensor part, we'll need to provide the path to ``TimeDeltaTrigger`` as ``
class WaitHoursSensor(BaseSensorOperator):
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={"moment": timedelta(hours=1)},
next_method="execute_complete",
next_kwargs=None,
Expand Down Expand Up @@ -307,7 +307,7 @@ After the trigger has finished executing, the task may be sent back to the worke
class WaitHoursSensor(BaseSensorOperator):
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={"moment": timedelta(hours=1)},
next_method="execute_complete",
next_kwargs=None,
Expand Down
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@
"plugins": [],
"cross-providers-deps": [
"amazon",
"common.compat",
"google",
"oracle",
"sftp"
Expand Down
31 changes: 31 additions & 0 deletions providers/src/airflow/providers/common/compat/standard/triggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Compatibility re-export of ``TimeDeltaTrigger``.

Imports the trigger from its new home in the standard provider
(``airflow.providers.standard.triggers.temporal``) and falls back to the
legacy core location (``airflow.triggers.temporal``) when the standard
provider module is not available — presumably on older Airflow
installations that predate the move (TODO confirm supported versions).
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # For type checkers, always resolve against the new provider location.
    from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
else:
    try:
        from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
    except ModuleNotFoundError:
        # Standard provider not installed: fall back to the legacy core module.
        from airflow.triggers.temporal import TimeDeltaTrigger


__all__ = ["TimeDeltaTrigger"]
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from typing import TYPE_CHECKING, Any, Callable

from airflow.exceptions import AirflowException
from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger

if TYPE_CHECKING:
from datetime import timedelta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.xcom import XCom
from airflow.triggers.external_task import DagStateTrigger
from airflow.providers.standard.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import provide_session
Expand Down
7 changes: 7 additions & 0 deletions providers/src/airflow/providers/standard/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ hooks:
- airflow.providers.standard.hooks.package_index
- airflow.providers.standard.hooks.subprocess

triggers:
- integration-name: Standard
python-modules:
- airflow.providers.standard.triggers.external_task
- airflow.providers.standard.triggers.file
- airflow.providers.standard.triggers.temporal

config:
standard:
description: Options for the standard provider operators.
Expand Down
4 changes: 2 additions & 2 deletions providers/src/airflow/providers/standard/sensors/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.providers.standard.triggers.temporal import DateTimeTrigger
from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS
from airflow.sensors.base import BaseSensorOperator

Expand All @@ -40,7 +41,6 @@ class StartTriggerArgs: # type: ignore[no-redef]
timeout: datetime.timedelta | None = None


from airflow.triggers.temporal import DateTimeTrigger
from airflow.utils import timezone

if TYPE_CHECKING:
Expand Down Expand Up @@ -111,7 +111,7 @@ class DateTimeSensorAsync(DateTimeSensor):
"""

start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.temporal.DateTimeTrigger",
trigger_cls="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
trigger_kwargs={"moment": "", "end_from_trigger": False},
next_method="execute_complete",
next_kwargs=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.standard.hooks.filesystem import FSHook
from airflow.providers.standard.triggers.file import FileTrigger
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.base import StartTriggerArgs
from airflow.triggers.file import FileTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -64,7 +64,7 @@ class FileSensor(BaseSensorOperator):
template_fields: Sequence[str] = ("filepath",)
ui_color = "#91818a"
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.file.FileTrigger",
trigger_cls="airflow.providers.standard.triggers.file.FileTrigger",
trigger_kwargs={},
next_method="execute_complete",
next_kwargs=None,
Expand Down
4 changes: 2 additions & 2 deletions providers/src/airflow/providers/standard/sensors/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.providers.standard.triggers.temporal import DateTimeTrigger
from airflow.providers.standard.utils.version_references import AIRFLOW_V_2_10_PLUS
from airflow.sensors.base import BaseSensorOperator

Expand All @@ -39,7 +40,6 @@ class StartTriggerArgs: # type: ignore[no-redef]
timeout: datetime.timedelta | None = None


from airflow.triggers.temporal import DateTimeTrigger
from airflow.utils import timezone

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,7 +85,7 @@ class TimeSensorAsync(BaseSensorOperator):
"""

start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.temporal.DateTimeTrigger",
trigger_cls="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
trigger_kwargs={"moment": "", "end_from_trigger": False},
next_method="execute_complete",
next_kwargs=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowSkipException
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.utils import timezone

if TYPE_CHECKING:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from sqlalchemy import func

from airflow.models import DagRun
from airflow.providers.standard.utils.sensor_helper import _get_count
from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.sensor_helper import _get_count
from airflow.utils.session import NEW_SESSION, provide_session

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -54,7 +55,8 @@ class WorkflowTrigger(BaseTrigger):
def __init__(
self,
external_dag_id: str,
logical_dates: list,
logical_dates: list[datetime] | None = None,
execution_dates: list[datetime] | None = None,
external_task_ids: typing.Collection[str] | None = None,
external_task_group_id: str | None = None,
failed_states: typing.Iterable[str] | None = None,
Expand All @@ -73,20 +75,26 @@ def __init__(
self.logical_dates = logical_dates
self.poke_interval = poke_interval
self.soft_fail = soft_fail
self.execution_dates = execution_dates
super().__init__(**kwargs)

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize the trigger param and module path."""
_dates = (
{"logical_dates": self.logical_dates}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": self.execution_dates}
)
return (
"airflow.triggers.external_task.WorkflowTrigger",
"airflow.providers.standard.triggers.external_task.WorkflowTrigger",
{
"external_dag_id": self.external_dag_id,
"external_task_ids": self.external_task_ids,
"external_task_group_id": self.external_task_group_id,
"failed_states": self.failed_states,
"skipped_states": self.skipped_states,
"allowed_states": self.allowed_states,
"logical_dates": self.logical_dates,
**_dates,
"poke_interval": self.poke_interval,
"soft_fail": self.soft_fail,
},
Expand All @@ -109,7 +117,8 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
yield TriggerEvent({"status": "skipped"})
return
allowed_count = await self._get_count(self.allowed_states)
if allowed_count == len(self.logical_dates):
_dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates
if allowed_count == len(_dates): # type: ignore[arg-type]
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
Expand All @@ -124,7 +133,7 @@ def _get_count(self, states: typing.Iterable[str] | None) -> int:
:return The count of records.
"""
return _get_count(
dttm_filter=self.logical_dates,
dttm_filter=self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates,
external_task_ids=self.external_task_ids,
external_task_group_id=self.external_task_group_id,
external_dag_id=self.external_dag_id,
Expand All @@ -147,23 +156,30 @@ def __init__(
self,
dag_id: str,
states: list[DagRunState],
logical_dates: list[datetime],
logical_dates: list[datetime] | None = None,
execution_dates: list[datetime] | None = None,
poll_interval: float = 5.0,
):
super().__init__()
self.dag_id = dag_id
self.states = states
self.logical_dates = logical_dates
self.execution_dates = execution_dates
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, typing.Any]]:
"""Serialize DagStateTrigger arguments and classpath."""
_dates = (
{"logical_dates": self.logical_dates}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": self.execution_dates}
)
return (
"airflow.triggers.external_task.DagStateTrigger",
"airflow.providers.standard.triggers.external_task.DagStateTrigger",
{
"dag_id": self.dag_id,
"states": self.states,
"logical_dates": self.logical_dates,
**_dates,
"poll_interval": self.poll_interval,
},
)
Expand All @@ -173,7 +189,8 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
while True:
# mypy confuses typing here
num_dags = await self.count_dags() # type: ignore[call-arg]
if num_dags == len(self.logical_dates):
_dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates
if num_dags == len(_dates): # type: ignore[arg-type]
yield TriggerEvent(self.serialize())
return
await asyncio.sleep(self.poll_interval)
Expand All @@ -182,12 +199,17 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
@provide_session
def count_dags(self, *, session: Session = NEW_SESSION) -> int | None:
"""Count how many dag runs in the database match our criteria."""
_dag_run_date_condition = (
DagRun.logical_date.in_(self.logical_dates)
if AIRFLOW_V_3_0_PLUS
else DagRun.execution_date.in_(self.execution_dates)
)
count = (
session.query(func.count("*")) # .count() is inefficient
.filter(
DagRun.dag_id == self.dag_id,
DagRun.state.in_(self.states),
DagRun.logical_date.in_(self.logical_dates),
_dag_run_date_condition,
)
.scalar()
)
Expand Down
Loading

0 comments on commit 4404e64

Please sign in to comment.