Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make task to go to terminal state from triggerer without needing a worker #1509

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.utils.state import TaskInstanceState


class CallbackRequest:
Expand Down Expand Up @@ -71,6 +72,7 @@ class TaskCallbackRequest(CallbackRequest):
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging to determine failure/zombie
:param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
:param task_callback_type: the type of callback to run, e.g. on success, on failure, or on retry.
"""

def __init__(
Expand All @@ -80,10 +82,12 @@ def __init__(
is_failure_callback: bool | None = True,
processor_subdir: str | None = None,
msg: str | None = None,
task_callback_type: TaskInstanceState | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.is_failure_callback = is_failure_callback
self.task_callback_type = task_callback_type

def to_json(self) -> str:
from airflow.serialization.serialized_objects import BaseSerialization
Expand Down
7 changes: 5 additions & 2 deletions airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@

from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


class DagParsingStat(NamedTuple):
"""Information on processing progress."""
Expand Down Expand Up @@ -574,6 +576,7 @@ def _run_parsing_loop(self):
pass
elif isinstance(agent_signal, CallbackRequest):
self._add_callback_to_queue(agent_signal)
self.log.warning("_add_callback_to_queue; agent signal; %s", agent_signal)
else:
raise ValueError(f"Invalid message {type(agent_signal)}")

Expand Down Expand Up @@ -676,7 +679,7 @@ def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION):
@retry_db_transaction
def _fetch_callbacks_with_retries(self, max_callbacks: int, session: Session):
"""Fetch callbacks from database and add them to the internal queue for execution."""
self.log.debug("Fetching callbacks from the database.")
self.log.warning("Fetching callbacks from the database.")
with prohibit_commit(session) as guard:
query = select(DbCallbackRequest)
if self.standalone_dag_processor:
Expand Down Expand Up @@ -761,7 +764,7 @@ def _refresh_dag_dir(self) -> bool:
self.set_file_paths(self._file_paths)

try:
self.log.debug("Removing old import errors")
self.log.warning("Removing old import errors")
DagFileProcessorManager.clear_nonexistent_import_errors(
file_paths=self._file_paths, processor_subdir=self.get_dag_directory()
)
Expand Down
41 changes: 36 additions & 5 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, _run_finished_callback
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.email import get_email_address_list, send_email
Expand Down Expand Up @@ -762,8 +762,29 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se
if callbacks and context:
DAG.execute_callback(callbacks, context, dag.dag_id)

def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session):
if not request.is_failure_callback:
def _execute_task_callbacks(
self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session
) -> None:
"""
Execute the task callbacks.

:param dagbag: the DagBag to use to get the task instance
:param request: the task callback request
:param session: the session to use
"""
try:
callback_type = TaskInstanceState(request.task_callback_type)
except Exception:
callback_type = None
is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)

# Previously we ignored any request besides failures. Now, if a callback type is
# supplied explicitly, we respect it and execute it. Additionally, because in this
# scenario the callback is submitted remotely, we assume there is no need to modify
# task state; we simply run the callback.

if not is_remote and not request.is_failure_callback:
self.log.warning("not failure callback: %s", request)
return

simple_ti = request.simple_task_instance
Expand All @@ -774,6 +795,7 @@ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRe
map_index=simple_ti.map_index,
session=session,
)

if not ti:
return

Expand All @@ -795,8 +817,17 @@ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRe
if task:
ti.refresh_from_task(task)

ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
self.log.info("Executed failure callback for %s in state %s", ti, ti.state)
if callback_type is TaskInstanceState.SUCCESS:
context = ti.get_template_context(session=session)
if callback_type is TaskInstanceState.SUCCESS:
if not ti.task:
return
callbacks = ti.task.on_success_callback
_run_finished_callback(callbacks=callbacks, context=context)
self.log.info("Executed callback for %s in state %s", ti, ti.state)
elif not is_remote or callback_type is TaskInstanceState.FAILED:
ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
self.log.info("Executed callback for %s in state %s", ti, ti.state)
session.flush()

@classmethod
Expand Down
10 changes: 8 additions & 2 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,20 @@ class TaskDeferred(BaseException):
Signal an operator moving to deferred state.

Special exception raised to signal that the operator it was raised from
wishes to defer until a trigger fires.
wishes to defer until a trigger fires. Triggers can send execution back to the task or end the task
instance directly. If the trigger will end the task instance itself, leave ``method_name`` unset so
it defaults to the trigger-exit sentinel; otherwise, provide the name of the method that should be
used when resuming execution in the task.
"""

TRIGGER_EXIT = "__trigger_exit__"
"""Sentinel value to signal the expectation that the trigger will exit the task."""

def __init__(
self,
*,
trigger,
method_name: str,
method_name: str = TRIGGER_EXIT,
kwargs: dict[str, Any] | None = None,
timeout: datetime.timedelta | None = None,
):
Expand Down
7 changes: 5 additions & 2 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,15 +1704,18 @@ def defer(
self,
*,
trigger: BaseTrigger,
method_name: str,
method_name: str = TaskDeferred.TRIGGER_EXIT,
kwargs: dict[str, Any] | None = None,
timeout: timedelta | None = None,
) -> NoReturn:
"""
Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.

This is achieved by raising a special exception (TaskDeferred)
which is caught in the main _execute_task wrapper.
which is caught in the main _execute_task wrapper. Triggers can send execution back to the task or
end the task instance directly. If the trigger will end the task instance itself, leave
``method_name`` unset so it defaults to the trigger-exit sentinel; otherwise, provide the name of
the method that should be used when resuming execution in the task.
"""
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

Expand Down
6 changes: 6 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,12 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
execute_callable = task_to_execute.resume_execution
execute_callable_kwargs["next_method"] = task_instance.next_method
execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs
elif task_instance.next_method == TaskDeferred.TRIGGER_EXIT:
raise AirflowException(
"Task is resuming from deferral without next_method specified. "
"You must either set `method_name` when deferring, or use a trigger "
"that is designed to exit the task."
)
else:
execute_callable = task_to_execute.execute
if execute_callable.__name__ == "execute":
Expand Down
9 changes: 1 addition & 8 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,7 @@ def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
)
):
# Add the event's payload into the kwargs for the task
next_kwargs = task_instance.next_kwargs or {}
next_kwargs["event"] = event.payload
task_instance.next_kwargs = next_kwargs
# Remove ourselves as its trigger
task_instance.trigger_id = None
# Finally, mark it as scheduled so it gets re-queued
task_instance.state = TaskInstanceState.SCHEDULED
event.handle_submit(task_instance=task_instance)

@classmethod
@internal_api_call
Expand Down
14 changes: 3 additions & 11 deletions airflow/sensors/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn, Sequence
from typing import TYPE_CHECKING, Sequence

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
Expand Down Expand Up @@ -90,13 +90,5 @@ class DateTimeSensorAsync(DateTimeSensor):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=timezone.parse(self.target_time))
self.defer(
trigger=trigger,
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
return None
def execute(self, context: Context):
self.defer(trigger=DateTimeTrigger(moment=timezone.parse(self.target_time), end_task=True))
12 changes: 4 additions & 8 deletions airflow/sensors/time_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowSkipException
from airflow.sensors.base import BaseSensorOperator
Expand Down Expand Up @@ -66,18 +66,14 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):

"""

def execute(self, context: Context) -> NoReturn:
def execute(self, context: Context):
target_dttm = context["data_interval_end"]
target_dttm += self.delta
try:
trigger = DateTimeTrigger(moment=target_dttm)
trigger = DateTimeTrigger(moment=target_dttm, end_task=True)
except (TypeError, ValueError) as e:
if self.soft_fail:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
raise

self.defer(trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None) -> None:
"""Execute for when the trigger fires - return immediately."""
return None
self.defer(trigger=trigger)
20 changes: 6 additions & 14 deletions airflow/sensors/time_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
Expand All @@ -40,11 +40,11 @@ class TimeSensor(BaseSensorOperator):

"""

def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
def __init__(self, *, target_time, **kwargs):
super().__init__(**kwargs)
self.target_time = target_time

def poke(self, context: Context) -> bool:
def poke(self, context: Context):
self.log.info("Checking if the time (%s) has come", self.target_time)
return timezone.make_naive(timezone.utcnow(), self.dag.timezone).time() > self.target_time

Expand All @@ -62,7 +62,7 @@ class TimeSensorAsync(BaseSensorOperator):
:ref:`howto/operator:TimeSensorAsync`
"""

def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
def __init__(self, *, target_time, **kwargs):
super().__init__(**kwargs)
self.target_time = target_time

Expand All @@ -72,13 +72,5 @@ def __init__(self, *, target_time: datetime.time, **kwargs) -> None:

self.target_datetime = timezone.convert_to_utc(aware_time)

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=self.target_datetime)
self.defer(
trigger=trigger,
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
return None
def execute(self, context: Context):
self.defer(trigger=DateTimeTrigger(moment=self.target_datetime, end_task=True))
Loading
Loading