diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py
index 8ec0187978db6..42fb5166124e9 100644
--- a/airflow/callbacks/callback_requests.py
+++ b/airflow/callbacks/callback_requests.py
@@ -21,6 +21,7 @@
 if TYPE_CHECKING:
     from airflow.models.taskinstance import SimpleTaskInstance
+    from airflow.utils.state import TaskInstanceState


 class CallbackRequest:
@@ -71,6 +72,7 @@ class TaskCallbackRequest(CallbackRequest):
     :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
     :param msg: Additional Message that can be used for logging to determine failure/zombie
     :param processor_subdir: Directory used by Dag Processor when parsed the dag.
+    :param task_callback_type: the type of callback to run, e.g. on success, on failure, or on retry.
     """

     def __init__(
@@ -80,10 +82,12 @@ def __init__(
         is_failure_callback: bool | None = True,
         processor_subdir: str | None = None,
         msg: str | None = None,
+        task_callback_type: TaskInstanceState | None = None,
     ):
         super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
         self.simple_task_instance = simple_task_instance
         self.is_failure_callback = is_failure_callback
+        self.task_callback_type = task_callback_type

     def to_json(self) -> str:
         from airflow.serialization.serialized_objects import BaseSerialization
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index 074d83585bd9a..bd629e2828392 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -574,6 +574,7 @@ def _run_parsing_loop(self):
                 pass
             elif isinstance(agent_signal, CallbackRequest):
                 self._add_callback_to_queue(agent_signal)
+                self.log.debug("Queuing callback from agent signal: %s", agent_signal)
             else:
                 raise ValueError(f"Invalid message {type(agent_signal)}")
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 8df64c9f1eb3e..eb8b97c7e0424 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -47,7 +47,7 @@
 from airflow.models.dagwarning import DagWarning, DagWarningType
 from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
+from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, _run_finished_callback
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.email import get_email_address_list, send_email
@@ -762,8 +762,29 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se
         if callbacks and context:
             DAG.execute_callback(callbacks, context, dag.dag_id)

-    def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session):
-        if not request.is_failure_callback:
+    def _execute_task_callbacks(
+        self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session
+    ) -> None:
+        """
+        Execute the task callbacks.
+
+        :param dagbag: the DagBag to use to get the task instance
+        :param request: the task callback request
+        :param session: the session to use
+        """
+        try:
+            callback_type = TaskInstanceState(request.task_callback_type)
+        except ValueError:
+            callback_type = None
+        is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)
+
+        # Previously we ignored any request that was not a failure callback. Now, if a callback
+        # type is given explicitly, we respect it and execute it. Because in that scenario the
+        # callback was submitted remotely (by the triggerer), there is no need to touch the task
+        # state here; we simply run the callback.
+
+        if not is_remote and not request.is_failure_callback:
+            self.log.debug("Not a failure callback request, skipping: %s", request)
             return

         simple_ti = request.simple_task_instance
@@ -774,6 +795,7 @@ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRe
             map_index=simple_ti.map_index,
             session=session,
         )
+
         if not ti:
             return

@@ -795,8 +817,16 @@ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRe
         if task:
             ti.refresh_from_task(task)

-        ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
-        self.log.info("Executed failure callback for %s in state %s", ti, ti.state)
+        if callback_type is TaskInstanceState.SUCCESS:
+            if not ti.task:
+                return
+            context = ti.get_template_context(session=session)
+            callbacks = ti.task.on_success_callback
+            _run_finished_callback(callbacks=callbacks, context=context)
+            self.log.info("Executed callback for %s in state %s", ti, ti.state)
+        elif not is_remote or callback_type is TaskInstanceState.FAILED:
+            ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
+            self.log.info("Executed callback for %s in state %s", ti, ti.state)
         session.flush()

     @classmethod
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 960c24b7e0f39..d0a30f7828c07 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -372,14 +372,20 @@ class TaskDeferred(BaseException):
     Signal an operator moving to deferred state.

     Special exception raised to signal that the operator it was raised from
-    wishes to defer until a trigger fires.
+    wishes to defer until a trigger fires. A trigger can send execution back to the task or end the
+    task instance directly. If the trigger is expected to end the task instance itself, ``method_name``
+    does not need to be provided (it defaults to the ``TRIGGER_EXIT`` sentinel); otherwise, provide
+    the name of the method that should be used when resuming execution in the task.
""" + TRIGGER_EXIT = "__trigger_exit__" + """Sentinel value to signal the expectation that the trigger will exit the task.""" + def __init__( self, *, trigger, - method_name: str, + method_name: str = TRIGGER_EXIT, kwargs: dict[str, Any] | None = None, timeout: datetime.timedelta | None = None, ): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 98532d90b0256..84a90e276011f 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1704,7 +1704,7 @@ def defer( self, *, trigger: BaseTrigger, - method_name: str, + method_name: str = TaskDeferred.TRIGGER_EXIT, kwargs: dict[str, Any] | None = None, timeout: timedelta | None = None, ) -> NoReturn: @@ -1712,7 +1712,10 @@ def defer( Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. This is achieved by raising a special exception (TaskDeferred) - which is caught in the main _execute_task wrapper. + which is caught in the main _execute_task wrapper. Triggers can send execution back to task or end + the task instance directly. If the trigger will end the task instance itself, ``method_name`` should + be None; otherwise, provide the name of the method that should be used when resuming execution in + the task. """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 1637c06e58572..1695b0d41e572 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -709,6 +709,12 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C execute_callable = task_to_execute.resume_execution execute_callable_kwargs["next_method"] = task_instance.next_method execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs + elif task_instance.next_method == TaskDeferred.TRIGGER_EXIT: + raise AirflowException( + "Task is resuming from deferral without next_method specified. " + "You must either set `method_name` when deferring, or use a trigger " + "that is designed to exit the task." 
     else:
         execute_callable = task_to_execute.execute
         if execute_callable.__name__ == "execute":
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 670d88d6142bb..a2ec24b6c2cd1 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -202,14 +202,7 @@ def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None
                 TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
             )
         ):
-            # Add the event's payload into the kwargs for the task
-            next_kwargs = task_instance.next_kwargs or {}
-            next_kwargs["event"] = event.payload
-            task_instance.next_kwargs = next_kwargs
-            # Remove ourselves as its trigger
-            task_instance.trigger_id = None
-            # Finally, mark it as scheduled so it gets re-queued
-            task_instance.state = TaskInstanceState.SCHEDULED
+            event.handle_submit(task_instance=task_instance)

     @classmethod
     @internal_api_call
diff --git a/airflow/sensors/date_time.py b/airflow/sensors/date_time.py
index b0763ebd40a87..d3d58356610d3 100644
--- a/airflow/sensors/date_time.py
+++ b/airflow/sensors/date_time.py
@@ -18,7 +18,7 @@
 from __future__ import annotations

 import datetime
-from typing import TYPE_CHECKING, NoReturn, Sequence
+from typing import TYPE_CHECKING, Sequence

 from airflow.sensors.base import BaseSensorOperator
 from airflow.triggers.temporal import DateTimeTrigger
@@ -90,13 +90,5 @@ class DateTimeSensorAsync(DateTimeSensor):

     def __init__(self, **kwargs) -> None:
         super().__init__(**kwargs)

-    def execute(self, context: Context) -> NoReturn:
-        trigger = DateTimeTrigger(moment=timezone.parse(self.target_time))
-        self.defer(
-            trigger=trigger,
-            method_name="execute_complete",
-        )
-
-    def execute_complete(self, context, event=None) -> None:
-        """Execute when the trigger fires - returns immediately."""
-        return None
+    def execute(self, context: Context):
+        self.defer(trigger=DateTimeTrigger(moment=timezone.parse(self.target_time), end_task=True))
diff --git a/airflow/sensors/time_delta.py b/airflow/sensors/time_delta.py
index 82d16bbae6575..ee17b0bfa8ea0 100644
--- a/airflow/sensors/time_delta.py
+++ b/airflow/sensors/time_delta.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, NoReturn
+from typing import TYPE_CHECKING

 from airflow.exceptions import AirflowSkipException
 from airflow.sensors.base import BaseSensorOperator
@@ -66,18 +66,14 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):
     """

-    def execute(self, context: Context) -> NoReturn:
+    def execute(self, context: Context):
         target_dttm = context["data_interval_end"]
         target_dttm += self.delta
         try:
-            trigger = DateTimeTrigger(moment=target_dttm)
+            trigger = DateTimeTrigger(moment=target_dttm, end_task=True)
         except (TypeError, ValueError) as e:
             if self.soft_fail:
                 raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
             raise
-        self.defer(trigger=trigger, method_name="execute_complete")
-
-    def execute_complete(self, context, event=None) -> None:
-        """Execute for when the trigger fires - return immediately."""
-        return None
+        self.defer(trigger=trigger)
diff --git a/airflow/sensors/time_sensor.py b/airflow/sensors/time_sensor.py
index 91c1354782593..24a8655dca15e 100644
--- a/airflow/sensors/time_sensor.py
+++ b/airflow/sensors/time_sensor.py
@@ -18,7 +18,7 @@
 from __future__ import annotations

 import datetime
-from typing import TYPE_CHECKING, NoReturn
+from typing import TYPE_CHECKING

 from airflow.sensors.base import BaseSensorOperator
 from airflow.triggers.temporal import DateTimeTrigger
@@ -72,13 +72,5 @@ def __init__(self, *, target_time: datetime.time, **kwargs) -> None:

         self.target_datetime = timezone.convert_to_utc(aware_time)

-    def execute(self, context: Context) -> NoReturn:
-        trigger = DateTimeTrigger(moment=self.target_datetime)
-        self.defer(
-            trigger=trigger,
-            method_name="execute_complete",
-        )
-
-    def execute_complete(self, context, event=None) -> None:
-        """Execute when the trigger fires - returns immediately."""
-        return None
+    def execute(self, context: Context):
+        self.defer(trigger=DateTimeTrigger(moment=self.target_datetime, end_task=True))
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py
index 0d239af0cafd4..6b07f4092805d 100644
--- a/airflow/triggers/base.py
+++ b/airflow/triggers/base.py
@@ -17,9 +17,22 @@
 from __future__ import annotations

 import abc
-from typing import Any, AsyncIterator
+import logging
+from typing import TYPE_CHECKING, Any, AsyncIterator

+from airflow.callbacks.callback_requests import TaskCallbackRequest
+from airflow.callbacks.database_callback_sink import DatabaseCallbackSink
+from airflow.models.taskinstance import SimpleTaskInstance
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.models import TaskInstance
+
+log = logging.getLogger(__name__)


 class BaseTrigger(abc.ABC, LoggingMixin):
@@ -115,3 +128,105 @@ def __eq__(self, other):
         if isinstance(other, TriggerEvent):
             return other.payload == self.payload
         return False
+
+    @provide_session
+    def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION):
+        """
+        Handle the submit event for a given task instance.
+
+        This adds the event's payload into the ``next_kwargs`` of the task instance,
+        detaches the trigger, and sets the task instance's state to SCHEDULED so it
+        is re-queued with the payload available on resumption.
+
+        :param task_instance: The task instance to handle the submit event for.
+        :param session: The database session.
+        """
+        # Get the next kwargs of the task instance, or an empty dictionary if it doesn't exist
+        next_kwargs = task_instance.next_kwargs or {}
+
+        # Add the event's payload into the kwargs for the task
+        next_kwargs["event"] = self.payload
+
+        # Update the next kwargs of the task instance
+        task_instance.next_kwargs = next_kwargs
+
+        # Remove ourselves as its trigger
+        task_instance.trigger_id = None
+
+        # Set the state of the task instance to scheduled
+        task_instance.state = TaskInstanceState.SCHEDULED
+
+
+class BaseTaskEndEvent(TriggerEvent):
+    """Base event class to end the task without resuming on the worker."""
+
+    task_instance_state: TaskInstanceState
+
+    def __init__(self, *, xcoms: dict[str, Any] | None = None, **kwargs) -> None:
+        """
+        Initialize the class with the specified parameters.
+
+        :param xcoms: A dictionary of XComs or None.
+        :param kwargs: Additional keyword arguments.
+        """
+        if "payload" in kwargs:
+            raise ValueError("Param 'payload' not supported for this class.")
+        super().__init__(payload=self.task_instance_state)
+        self.xcoms = xcoms
+
+    @provide_session
+    def handle_submit(self, *, task_instance: TaskInstance, session: Session = NEW_SESSION):
+        """
+        Submit the event for the given task instance.
+
+        Marks the task with the state ``task_instance_state`` and optionally pushes XComs if applicable.
+
+        :param task_instance: The task instance to be submitted.
+        :param session: The session to be used for the database callback sink.
+ """ + # Mark the task with terminal state and prevent it from resuming on worker + task_instance.trigger_id = None + task_instance.state = self.task_instance_state + + self._submit_callback_if_necessary(task_instance=task_instance, session=session) + self._push_xcoms_if_necessary(task_instance=task_instance) + + def _submit_callback_if_necessary(self, *, task_instance: TaskInstance, session): + """Submit a callback request if the task state is SUCCESS or FAILED.""" + is_failure = self.task_instance_state == TaskInstanceState.FAILED + if self.task_instance_state in [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED]: + request = TaskCallbackRequest( + full_filepath=task_instance.dag_model.fileloc, + simple_task_instance=SimpleTaskInstance.from_ti(task_instance), + is_failure_callback=is_failure, + task_callback_type=self.task_instance_state, + ) + log.warning("Sending callback: %s", request) + try: + DatabaseCallbackSink().send(callback=request, session=session) + except Exception as e: + log.error("Failed to send callback: %s", e) + + def _push_xcoms_if_necessary(self, *, task_instance: TaskInstance): + """Pushes XComs to the database if they are provided.""" + if self.xcoms: + for key, value in self.xcoms.items(): + task_instance.xcom_push(key=key, value=value) + + +class TaskSuccessEvent(BaseTaskEndEvent): + """Yield this event in order to end the task successfully.""" + + task_instance_state = TaskInstanceState.SUCCESS + + +class TaskFailedEvent(BaseTaskEndEvent): + """Yield this event in order to end the task with failure.""" + + task_instance_state = TaskInstanceState.FAILED + + +class TaskSkippedEvent(BaseTaskEndEvent): + """Yield this event in order to end the task with status 'skipped'.""" + + task_instance_state = TaskInstanceState.SKIPPED diff --git a/airflow/triggers/temporal.py b/airflow/triggers/temporal.py index 79e8f39dd76e7..86c2bf16b777e 100644 --- a/airflow/triggers/temporal.py +++ b/airflow/triggers/temporal.py @@ -22,7 +22,7 @@ import pendulum -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.triggers.base import BaseTrigger, TaskSuccessEvent, TriggerEvent from airflow.utils import timezone @@ -34,9 +34,13 @@ class DateTimeTrigger(BaseTrigger): a few seconds. The provided datetime MUST be in UTC. + + :param moment: when to yield event + :param end_task: whether the trigger should mark the task successful after time condition + reached or resume the task after time condition reached. """ - def __init__(self, moment: datetime.datetime): + def __init__(self, moment: datetime.datetime, *, end_task=False): super().__init__() if not isinstance(moment, datetime.datetime): raise TypeError(f"Expected datetime.datetime type for moment. 
Got {type(moment)}") @@ -45,9 +49,13 @@ def __init__(self, moment: datetime.datetime): raise ValueError("You cannot pass naive datetimes") else: self.moment: pendulum.DateTime = timezone.convert_to_utc(moment) + self.end_task = end_task def serialize(self) -> tuple[str, dict[str, Any]]: - return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment}) + return ( + "airflow.triggers.temporal.DateTimeTrigger", + {"moment": self.moment, "end_task": self.end_task}, + ) async def run(self) -> AsyncIterator[TriggerEvent]: """ @@ -70,9 +78,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: while self.moment > pendulum.instance(timezone.utcnow()): self.log.info("sleeping 1 second...") await asyncio.sleep(1) - # Send our single event and then we're done - self.log.info("yielding event with payload %r", self.moment) - yield TriggerEvent(self.moment) + if self.end_task: + self.log.info("Sensor time condition reached; marking task successful and exiting") + yield TaskSuccessEvent() + else: + self.log.info("yielding event with payload %r", self.moment) + yield TriggerEvent(self.moment) class TimeDeltaTrigger(DateTimeTrigger): @@ -84,7 +95,11 @@ class TimeDeltaTrigger(DateTimeTrigger): While this is its own distinct class here, it will serialise to a DateTimeTrigger class, since they're operationally the same. + + :param delta: how long to wait + :param end_task: whether the trigger should mark the task successful after time condition + reached or resume the task after time condition reached. """ - def __init__(self, delta: datetime.timedelta): - super().__init__(moment=timezone.utcnow() + delta) + def __init__(self, delta: datetime.timedelta, *, end_task=False): + super().__init__(moment=timezone.utcnow() + delta, end_task=end_task) diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index 6be2086f34112..76d858afb4f6c 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -26,13 +26,20 @@ from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner -from airflow.models import TaskInstance, Trigger +from airflow.models import TaskInstance, Trigger, XCom from airflow.operators.empty import EmptyOperator from airflow.serialization.serialized_objects import BaseSerialization -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.triggers.base import ( + BaseTrigger, + TaskFailedEvent, + TaskSkippedEvent, + TaskSuccessEvent, + TriggerEvent, +) from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.xcom import XCOM_RETURN_KEY from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -123,7 +130,7 @@ def test_submit_event(session, create_task_instance): def test_submit_failure(session, create_task_instance): """ Tests that failures submitted to a trigger fail their dependent - task instances. + task instances if not using a TaskEndEvent. 
""" # Make a trigger trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) @@ -144,6 +151,58 @@ def test_submit_failure(session, create_task_instance): assert updated_task_instance.next_method == "__fail__" +@pytest.mark.parametrize( + "event_cls, expected", + [ + (TaskSuccessEvent, "success"), + (TaskFailedEvent, "failed"), + (TaskSkippedEvent, "skipped"), + ], +) +def test_submit_event_task_end(session, create_task_instance, event_cls, expected): + """ + Tests that events inheriting BaseTaskEndEvent *don't* re-wake their dependent + but mark them in the appropriate terminal state and send xcom + """ + # Make a trigger + trigger = Trigger(classpath="does.not.matter", kwargs={}) + trigger.id = 1 + session.add(trigger) + session.commit() + # Make a TaskInstance that's deferred and waiting on it + task_instance = create_task_instance( + session=session, execution_date=timezone.utcnow(), state=State.DEFERRED + ) + task_instance.trigger_id = trigger.id + session.commit() + + def get_xcoms(ti): + return XCom.get_many(dag_ids=[ti.dag_id], task_ids=[ti.task_id], run_id=ti.run_id).all() + + # now for the real test + # first check initial state + ti: TaskInstance = session.query(TaskInstance).one() + assert ti.state == "deferred" + assert get_xcoms(ti) == [] + + session.flush() + session.expunge_all() + # now, for each type, submit event + # verify that (1) task ends in right state and (2) xcom is pushed + Trigger.submit_event( + trigger.id, event_cls(xcoms={XCOM_RETURN_KEY: "xcomret", "a": "b", "c": "d"}), session=session + ) + # commit changes made by submit event and expire all cache to read from db. + session.flush() + session.expunge_all() + # Check that the task instance is now correct + ti = session.query(TaskInstance).one() + assert ti.state == expected + assert ti.next_kwargs is None + actual_xcoms = {x.key: x.value for x in get_xcoms(ti)} + assert actual_xcoms == {"return_value": "xcomret", "a": "b", "c": "d"} + + def test_assign_unassigned(session, create_task_instance): """ Tests that unassigned triggers of all appropriate states are assigned. 
diff --git a/tests/sensors/test_time_sensor.py b/tests/sensors/test_time_sensor.py
index 54a0212a247a9..bcf4700742a14 100644
--- a/tests/sensors/test_time_sensor.py
+++ b/tests/sensors/test_time_sensor.py
@@ -63,7 +63,6 @@ def test_task_is_deferred(self):

         assert isinstance(exc_info.value.trigger, DateTimeTrigger)
         assert exc_info.value.trigger.moment == timezone.datetime(2020, 7, 7, 10)
-        assert exc_info.value.method_name == "execute_complete"
         assert exc_info.value.kwargs is None

     def test_target_time_aware(self):
diff --git a/tests/triggers/test_temporal.py b/tests/triggers/test_temporal.py
index 6e8d32c467e63..d1e2f6ad0706d 100644
--- a/tests/triggers/test_temporal.py
+++ b/tests/triggers/test_temporal.py
@@ -26,6 +26,7 @@
 from airflow.triggers.base import TriggerEvent
 from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
 from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
 from airflow.utils.timezone import utcnow

@@ -56,7 +57,7 @@ def test_datetime_trigger_serialization():
     trigger = DateTimeTrigger(moment)
     classpath, kwargs = trigger.serialize()
     assert classpath == "airflow.triggers.temporal.DateTimeTrigger"
-    assert kwargs == {"moment": moment}
+    assert kwargs == {"moment": moment, "end_task": False}


 def test_timedelta_trigger_serialization():
@@ -74,15 +75,16 @@


 @pytest.mark.parametrize(
-    "tz",
+    "tz, end_task",
     [
-        timezone.parse_timezone("UTC"),
-        timezone.parse_timezone("Europe/Paris"),
-        timezone.parse_timezone("America/Toronto"),
+        (pendulum.timezone("UTC"), True),
+        (pendulum.timezone("UTC"), False),  # only really need to test one
+        (pendulum.timezone("Europe/Paris"), True),
+        (pendulum.timezone("America/Toronto"), True),
     ],
 )
 @pytest.mark.asyncio
-async def test_datetime_trigger_timing(tz):
+async def test_datetime_trigger_timing(tz, end_task):
     """
     Tests that the DateTimeTrigger only goes off on or after the appropriate
     time.
@@ -91,7 +93,7 @@
     future_moment = pendulum.instance((timezone.utcnow() + datetime.timedelta(seconds=60)).astimezone(tz))

     # Create a task that runs the trigger for a short time then cancels it
-    trigger = DateTimeTrigger(future_moment)
+    trigger = DateTimeTrigger(future_moment, end_task=end_task)
     trigger_task = asyncio.create_task(trigger.run().__anext__())
     await asyncio.sleep(0.5)

@@ -100,14 +102,15 @@
     trigger_task.cancel()

     # Now, make one waiting for an event in the past and do it again
-    trigger = DateTimeTrigger(past_moment)
+    trigger = DateTimeTrigger(past_moment, end_task=end_task)
     trigger_task = asyncio.create_task(trigger.run().__anext__())
     await asyncio.sleep(0.5)

     assert trigger_task.done() is True
     result = trigger_task.result()
     assert isinstance(result, TriggerEvent)
-    assert result.payload == past_moment
+    expected_payload = TaskInstanceState.SUCCESS if end_task else past_moment
+    assert result.payload == expected_payload


 @mock.patch("airflow.triggers.temporal.timezone.utcnow")
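
A corresponding sketch of the operator side, assuming the defaulted ``method_name`` and the ``end_task`` flag from this diff (the operator class and its five-minute delta are illustrative, not part of the change):

    import datetime

    from airflow.models.baseoperator import BaseOperator
    from airflow.triggers.temporal import TimeDeltaTrigger


    class ExampleAsyncWait(BaseOperator):
        """Hypothetical operator that waits on the triggerer and never resumes on a worker."""

        def execute(self, context):
            # method_name is omitted, so it defaults to TaskDeferred.TRIGGER_EXIT. That signals
            # that the trigger (end_task=True) is expected to end the task itself; if the task
            # were resumed anyway, _execute_task would now raise an informative AirflowException.
            self.defer(trigger=TimeDeltaTrigger(datetime.timedelta(minutes=5), end_task=True))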