Skip to content

Commit

Permalink
Merge pull request #100 from astronomer/starship_dag_migration_operator
Browse files Browse the repository at this point in the history
Starship Operator
  • Loading branch information
fritz-astronomer authored May 31, 2024
2 parents f5ac3e9 + 559da6d commit ad571dc
Show file tree
Hide file tree
Showing 15 changed files with 698 additions and 295 deletions.
Empty file.
Empty file.
Empty file.
459 changes: 247 additions & 212 deletions astronomer_starship/providers/starship/hooks/starship.py

Large diffs are not rendered by default.

Empty file.
350 changes: 339 additions & 11 deletions astronomer_starship/providers/starship/operators/starship.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,349 @@
"""Operators, TaskGroups, and DAGs for interacting with the Starship migrations."""
from datetime import datetime
from typing import Any, Union, List

import airflow
from airflow import DAG
from airflow.decorators import task
from airflow.exceptions import AirflowSkipException
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup

from astronomer_starship.providers.starship.hooks.starship import (
StarshipLocalHook,
StarshipHttpHook,
)

# Compatibility Notes:
# - @task() is >=AF2.0
# - @task_group is >=AF2.1
# - Dynamic Task Mapping is >=AF2.3
# - Dynamic Task Mapping labelling is >=AF2.9


class StarshipMigrationOperator(BaseOperator):
    """Base operator wiring up the Starship source (local) and target (HTTP) hooks.

    :param http_conn_id: Airflow HTTP connection id pointing at the target deployment.
    """

    def __init__(self, http_conn_id=None, **kwargs):
        super().__init__(**kwargs)
        # Source is the Airflow instance this operator runs in; target is remote.
        self.source_hook = StarshipLocalHook()
        self.target_hook = StarshipHttpHook(http_conn_id=http_conn_id)


class StarshipVariableMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single Variable from one Airflow instance to another.

    :param variable_key: key of the Variable to fetch from the source and set on the target.
    :raises RuntimeError: if the key is not found among the source's Variables.
    """

    def __init__(self, variable_key: Union[str, None] = None, **kwargs):
        super().__init__(**kwargs)
        self.variable_key = variable_key

    def execute(self, context: Context) -> Any:
        print("Getting Variable", self.variable_key)
        variables = self.source_hook.get_variables()
        # next() takes the first match without building the full filtered list
        variable: Union[dict, None] = next(
            (v for v in variables if v["key"] == self.variable_key), None
        )
        if variable is None:
            # f-string instead of `+` so a None variable_key can't raise
            # TypeError while constructing the error message
            raise RuntimeError(f"Variable not found! {self.variable_key}")
        print("Migrating Variable", self.variable_key)
        self.target_hook.set_variable(**variable)


def starship_variables_migration(variables: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate Variables from one Airflow instance to another.

    :param variables: Variable keys to migrate; None migrates all, [] skips (via AirflowSkipException).
    :param kwargs: passed through to StarshipVariableMigrationOperator (e.g. http_conn_id).
    """
    with TaskGroup("variables") as tg:

        @task()
        def get_variables():
            _variables = StarshipLocalHook().get_variables()

            _variables = (
                [k["key"] for k in _variables if k["key"] in variables]
                if variables is not None
                else [k["key"] for k in _variables]
            )

            if not _variables:
                raise AirflowSkipException("Nothing to migrate")
            return _variables

        variables_results = get_variables()
        # Compare numeric (major, minor) instead of the raw version string:
        # lexicographically "2.10.0" < "2.3.0", which would wrongly send
        # newer Airflows down the legacy branch.
        af_version = tuple(int(p) for p in airflow.__version__.split(".")[:2])
        if af_version >= (2, 3):
            # Dynamic Task Mapping (>=2.3): one mapped task over all keys
            StarshipVariableMigrationOperator.partial(
                task_id="migrate_variables", **kwargs
            ).expand(variable_key=variables_results)
        else:
            # Pre-2.3: declare one task per key statically
            for variable in variables_results.output:
                variables_results >> StarshipVariableMigrationOperator(
                    task_id="migrate_variable_" + variable,
                    variable_key=variable,
                    **kwargs,
                )
    return tg


class StarshipPoolMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single Pool from one Airflow instance to another.

    :param pool_name: name of the Pool to fetch from the source and set on the target.
    :raises RuntimeError: if no Pool with that name exists on the source.
    """

    def __init__(self, pool_name: Union[str, None] = None, **kwargs):
        super().__init__(**kwargs)
        self.pool_name = pool_name

    def execute(self, context: Context) -> Any:
        print("Getting Pool", self.pool_name)
        matches = [
            p for p in self.source_hook.get_pools() if p["name"] == self.pool_name
        ]
        pool: Union[dict, None] = matches[0] if matches else None
        # Guard clause: fail fast when the pool isn't present on the source.
        if pool is None:
            raise RuntimeError("Pool not found!")
        print("Migrating Pool", self.pool_name)
        self.target_hook.set_pool(**pool)


def starship_pools_migration(pools: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate Pools from one Airflow instance to another.

    :param pools: Pool names to migrate; None migrates all, [] skips (via AirflowSkipException).
    :param kwargs: passed through to StarshipPoolMigrationOperator (e.g. http_conn_id).
    """
    with TaskGroup("pools") as tg:

        @task()
        def get_pools():
            _pools = StarshipLocalHook().get_pools()
            _pools = (
                [k["name"] for k in _pools if k["name"] in pools]
                if pools is not None
                else [k["name"] for k in _pools]
            )

            if not _pools:
                raise AirflowSkipException("Nothing to migrate")
            return _pools

        pools_result = get_pools()
        # Numeric (major, minor) compare — string compare breaks on "2.10.0"
        # since "2.10.0" < "2.3.0" lexicographically.
        af_version = tuple(int(p) for p in airflow.__version__.split(".")[:2])
        if af_version >= (2, 3):
            # Dynamic Task Mapping (>=2.3)
            StarshipPoolMigrationOperator.partial(
                task_id="migrate_pools", **kwargs
            ).expand(pool_name=pools_result)
        else:
            # Pre-2.3: one statically-declared task per pool
            for pool in pools_result.output:
                pools_result >> StarshipPoolMigrationOperator(
                    task_id="migrate_pool_" + pool, pool_name=pool, **kwargs
                )
    return tg


class StarshipConnectionMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single Connection from one Airflow instance to another.

    :param connection_id: conn_id of the Connection to fetch from the source and set on the target.
    :raises RuntimeError: if no Connection with that conn_id exists on the source.
    """

    # NOTE(review): the scraped source had two stray lines interleaved here
    # (an import of `astronomer_starship.compat.starship_hook` and an orphan
    # `class StarshipOperator(BaseOperator):` header) left over from a diff
    # render of the old file version; they split this class mid-definition
    # and are removed.

    def __init__(self, connection_id: Union[str, None] = None, **kwargs):
        super().__init__(**kwargs)
        self.connection_id = connection_id

    def execute(self, context: Context) -> Any:
        print("Getting Connection", self.connection_id)
        connection: Union[dict, None] = (
            [
                v
                for v in self.source_hook.get_connections()
                if v["conn_id"] == self.connection_id
            ]
            or [None]
        )[0]
        if connection is not None:
            print("Migrating Connection", self.connection_id)
            self.target_hook.set_connection(**connection)
        else:
            raise RuntimeError("Connection not found!")


def starship_connections_migration(connections: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate Connections from one Airflow instance to another.

    :param connections: conn_ids to migrate; None migrates all, [] skips (via AirflowSkipException).
    :param kwargs: passed through to StarshipConnectionMigrationOperator (e.g. http_conn_id).
    """
    with TaskGroup("connections") as tg:

        @task()
        def get_connections():
            _connections = StarshipLocalHook().get_connections()
            _connections = (
                [k["conn_id"] for k in _connections if k["conn_id"] in connections]
                if connections is not None
                else [k["conn_id"] for k in _connections]
            )

            if not _connections:
                raise AirflowSkipException("Nothing to migrate")
            return _connections

        connections_result = get_connections()
        # Numeric (major, minor) compare — lexicographic string compare is
        # wrong for "2.10.0" (< "2.3.0" as a string).
        af_version = tuple(int(p) for p in airflow.__version__.split(".")[:2])
        if af_version >= (2, 3):
            StarshipConnectionMigrationOperator.partial(
                task_id="migrate_connections", **kwargs
            ).expand(connection_id=connections_result)
        else:
            for connection in connections_result.output:
                # `connection` is already the conn_id string (get_dags returns
                # plain conn_ids above); the original `connection.conn_id`
                # would raise AttributeError on this pre-2.3 branch.
                connections_result >> StarshipConnectionMigrationOperator(
                    task_id="migrate_connection_" + connection,
                    connection_id=connection,
                    **kwargs,
                )
    return tg


class StarshipDagHistoryMigrationOperator(StarshipMigrationOperator):
    """Operator to migrate a single DAG from one Airflow instance to another, with its history.

    Pauses the DAG on the source, fetches its DAG Runs and Task Instances,
    pushes them to the target, and optionally unpauses the DAG on the target.
    Skips (AirflowSkipException) when the source has no runs or no task instances.

    :param target_dag_id: dag_id to migrate (assumed identical on source and target)
    :param unpause_dag_in_target: if True, unpause the DAG on the target after migration
    :param dag_run_limit: max number of DAG Runs (and Task Instances) fetched from the source
    """

    def __init__(
        self,
        target_dag_id: str,
        unpause_dag_in_target: bool = False,
        dag_run_limit: int = 10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target_dag_id = target_dag_id
        self.unpause_dag_in_target = unpause_dag_in_target
        self.dag_run_limit = dag_run_limit

    def execute(self, context):
        # Pause first so no new runs start while history is being copied.
        print("Pausing local DAG for", self.target_dag_id)
        self.source_hook.set_dag_is_paused(dag_id=self.target_dag_id, is_paused=True)
        # TODO - Poll until all tasks are done

        print("Getting local DAG Runs for", self.target_dag_id)
        dag_runs = self.source_hook.get_dag_runs(
            dag_id=self.target_dag_id, limit=self.dag_run_limit
        )
        if len(dag_runs["dag_runs"]) == 0:
            # Skip (not fail): a DAG with no history simply has nothing to copy.
            raise AirflowSkipException("No DAG Runs found for " + self.target_dag_id)

        print("Getting local Task Instances for", self.target_dag_id)
        task_instances = self.source_hook.get_task_instances(
            dag_id=self.target_dag_id, limit=self.dag_run_limit
        )
        if len(task_instances["task_instances"]) == 0:
            raise AirflowSkipException(
                "No Task Instances found for " + self.target_dag_id
            )

        # Push runs before task instances; task instances reference their runs.
        print("Setting target DAG Runs for", self.target_dag_id)
        self.target_hook.set_dag_runs(dag_runs=dag_runs["dag_runs"])

        print("Setting target Task Instances for", self.target_dag_id)
        self.target_hook.set_task_instances(
            task_instances=task_instances["task_instances"]
        )

        if self.unpause_dag_in_target:
            print("Unpausing target DAG for", self.target_dag_id)
            self.target_hook.set_dag_is_paused(
                dag_id=self.target_dag_id, is_paused=False
            )


def starship_dag_history_migration(dag_ids: List[str] = None, **kwargs):
    """TaskGroup to fetch and migrate DAGs with their history from one Airflow instance to another.

    :param dag_ids: dag_ids to migrate; None migrates all. The migration DAG itself is always excluded.
    :param kwargs: passed through to StarshipDagHistoryMigrationOperator (e.g. http_conn_id).
    """
    with TaskGroup("dag_history") as tg:

        @task()
        def get_dags():
            _dags = StarshipLocalHook().get_dags()
            # Never migrate the migration DAG itself.
            _dags = (
                [
                    k["dag_id"]
                    for k in _dags
                    if k["dag_id"] in dag_ids
                    and k["dag_id"] != "StarshipAirflowMigrationDAG"
                ]
                if dag_ids is not None
                else [
                    k["dag_id"]
                    for k in _dags
                    if k["dag_id"] != "StarshipAirflowMigrationDAG"
                ]
            )

            if not _dags:
                raise AirflowSkipException("Nothing to migrate")
            return _dags

        dags_result = get_dags()
        # Numeric (major, minor) compare: as strings, "2.10.0" sorts below
        # both "2.3.0" and "2.9.0", so the original checks misfire on 2.10+.
        af_version = tuple(int(p) for p in airflow.__version__.split(".")[:2])
        if af_version >= (2, 3):
            StarshipDagHistoryMigrationOperator.partial(
                task_id="migrate_dag_ids",
                # map_index_template labelling only exists on >=2.9
                **(
                    {"map_index_template": "{{ task.target_dag_id }}"}
                    if af_version >= (2, 9)
                    else {}
                ),
                **kwargs,
            ).expand(target_dag_id=dags_result)
        else:
            for dag_id in dags_result.output:
                dags_result >> StarshipDagHistoryMigrationOperator(
                    task_id="migrate_dag_" + dag_id, target_dag_id=dag_id, **kwargs
                )
    return tg


# noinspection PyPep8Naming
def StarshipAirflowMigrationDAG(
    http_conn_id: str,
    variables: List[str] = None,
    pools: List[str] = None,
    connections: List[str] = None,
    dag_ids: List[str] = None,
    **kwargs,
):
    """
    DAG to fetch and migrate Variables, Pools, Connections, and DAGs with history from one Airflow instance to another.

    For each item type: None migrates everything, a list of names migrates only
    those items, and an empty list skips that type entirely.

    :param http_conn_id: Airflow HTTP connection id pointing at the target deployment.
    """
    # NOTE(review): the scraped source had lines from an older
    # `StarshipOperator` implementation interleaved through this function's
    # docstring and doc_md; they are removed and the doc_md reconstructed.
    dag = DAG(
        dag_id="starship_airflow_migration_dag",
        schedule="@once",
        start_date=datetime(1970, 1, 1),
        tags=["migration", "starship"],
        default_args={"owner": "Astronomer"},
        doc_md="""
# Starship Migration DAG
A DAG to migrate Airflow Variables, Pools, Connections, and DAG History from one Airflow instance to another.

You can use this DAG to migrate all items, or specific items by providing a list of names.

You can skip migration by providing an empty list.

## Setup:
Make a connection in Airflow with the following details:
- **Conn ID**: `starship_default`
- **Conn Type**: `HTTP`
- **Host**: the URL of the homepage of Airflow (excluding `/home` on the end of the URL)
    - For example, if your deployment URL is `https://astronomer.astronomer.run/abcdt4ry/home`, you'll use `https://astronomer.astronomer.run/abcdt4ry`
- **Schema**: `https`
- **Extras**: `{"Authorization": "Bearer <token>"}`

## Usage:
```python
from astronomer_starship.providers.starship.operators.starship import (
    StarshipAirflowMigrationDAG,
)

globals()["starship_airflow_migration_dag"] = StarshipAirflowMigrationDAG(
    http_conn_id="starship_default",
    variables=None,  # None to migrate all, or ["var1", "var2"] to migrate specific items, or empty list to skip all
    pools=None,  # None to migrate all, or ["pool1", "pool2"] to migrate specific items, or empty list to skip all
    connections=None,  # None to migrate all, or ["conn1", "conn2"] to migrate specific items, or empty list to skip all
    dag_ids=None,  # None to migrate all, or ["dag1", "dag2"] to migrate specific items, or empty list to skip all
)
```
""",  # noqa: E501
    )
    with dag:
        starship_variables_migration(
            variables=variables, http_conn_id=http_conn_id, **kwargs
        )
        starship_pools_migration(pools=pools, http_conn_id=http_conn_id, **kwargs)
        starship_connections_migration(
            connections=connections, http_conn_id=http_conn_id, **kwargs
        )
        starship_dag_history_migration(
            dag_ids=dag_ids, http_conn_id=http_conn_id, **kwargs
        )
    return dag
Loading

0 comments on commit ad571dc

Please sign in to comment.