Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Render YML DAG config as DAG Docs #305

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,15 @@ class DagBuilder:
in the YAML file
"""

def __init__(self, dag_name: str, dag_config: Dict[str, Any], default_config: Dict[str, Any]) -> None:
def __init__(
    self,
    dag_name: str,
    dag_config: Dict[str, Any],
    default_config: Dict[str, Any],
    yml_dag: str = "",
) -> None:
    """
    Store the inputs needed to build a single DAG.

    :param dag_name: name of the DAG this builder is responsible for
    :param dag_config: configuration of that DAG; kept as a deep copy
    :param default_config: default configuration; kept as a deep copy
    :param yml_dag: text kept on ``self._yml_dag``; defaults to an empty string
    """
    # Both counters start at zero; nothing in the constructor advances them.
    self.tasks_count: int = 0
    self.taskgroups_count: int = 0
    self._yml_dag = yml_dag

    self.dag_name: str = dag_name
    # Deep copies, so later changes made through this instance cannot reach
    # the dictionaries owned by the caller.
    self.dag_config: Dict[str, Any] = deepcopy(dag_config)
    self.default_config: Dict[str, Any] = deepcopy(default_config)

# pylint: disable=too-many-branches,too-many-statements
def get_dag_params(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -795,6 +798,15 @@ def build(self) -> Dict[str, Union[str, DAG]]:
)
dag.doc_md = doc_md_callable(**dag_params.get("doc_md_python_arguments", {}))

# Render YML DAG in DAG Docs
if self._yml_dag:
subtitle = "## YML DAG"

if dag.doc_md is None:
dag.doc_md = f"{subtitle}\n```yaml\n{self._yml_dag}\n```"
else:
dag.doc_md += f"\n{subtitle}\n```yaml\n{self._yml_dag}\n```"

# tags parameter introduced in Airflow 1.10.8
if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.8"):
dag.tags = dag_params.get("tags", None)
Expand Down
22 changes: 22 additions & 0 deletions dagfactory/dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ def __init__(self, config_filepath: Optional[str] = None, config: Optional[dict]
if config:
self.config: Dict[str, Any] = config

@staticmethod
def _serialise_config_md(dag_name: str, dag_config: Dict[str, Any], default_config: Dict[str, Any]) -> str:
    """
    Render the default and per-DAG configuration as a single YAML string.

    The result is the ``default`` section followed by a newline and the
    section keyed by ``dag_name``. Neither input dictionary is modified.

    :param dag_name: name used as the top-level key of the DAG section
    :param dag_config: configuration of the DAG
    :param default_config: configuration placed under the ``default`` key
    :returns: both sections dumped in block style, keys in insertion order
    """
    # Leave an empty task_groups out of the rendered YAML: it is injected
    # when the user does not supply one.
    # https://github.com/astronomer/dag-factory/blob/e53b456d25917b746d28eecd1e896595ae0ee62b/dagfactory/dagfactory.py#L102
    # The key is filtered into a new dict rather than deleted in place, because
    # the caller hands the very same dag_config object to DagBuilder.
    if dag_config.get("task_groups") == {}:
        dag_config = {key: value for key, value in dag_config.items() if key != "task_groups"}

    dump_options = {"default_flow_style": False, "allow_unicode": True, "sort_keys": False}

    # Convert default_config to YAML format
    default_config_yaml = yaml.dump({"default": default_config}, **dump_options)

    # Convert dag_config to YAML format
    dag_config_yaml = yaml.dump({dag_name: dag_config}, **dump_options)

    # Combine the two YAML outputs, separated by a newline
    return default_config_yaml + "\n" + dag_config_yaml

@staticmethod
def _validate_config_filepath(config_filepath: str) -> None:
"""
Expand Down Expand Up @@ -104,6 +125,7 @@ def build_dags(self) -> Dict[str, DAG]:
dag_name=dag_name,
dag_config=dag_config,
default_config=default_config,
yml_dag=self._serialise_config_md(dag_name, dag_config, default_config),
)
try:
dag: Dict[str, Union[str, DAG]] = dag_builder.build()
Expand Down
30 changes: 30 additions & 0 deletions tests/fixtures/dag_md_docs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
default:
  concurrency: 1
  dagrun_timeout_sec: 600
  default_args:
    end_date: 2018-03-05
    owner: default_owner
    retries: 1
    retry_delay_sec: 300
    start_date: 2018-03-01
  default_view: tree
  max_active_runs: 1
  orientation: LR
  schedule_interval: 0 1 * * *

example_dag2:
  schedule_interval: None
  tasks:
    task_1:
      bash_command: echo 1
      operator: airflow.operators.bash_operator.BashOperator
    task_2:
      bash_command: echo 2
      dependencies:
      - task_1
      operator: airflow.operators.bash_operator.BashOperator
    task_3:
      bash_command: echo 3
      dependencies:
      - task_1
      operator: airflow.operators.bash_operator.BashOperator
51 changes: 49 additions & 2 deletions tests/test_dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,19 +334,56 @@ def test_variables_as_arguments_dag():


# Asserts that example_dag2's doc_md matches the contents of
# DOC_MD_FIXTURE_FILE followed by the "## YML DAG" section defined below.
def test_doc_md_file_path():
# Expected "## YML DAG" section: a fenced YAML block with the `default`
# section and the `example_dag2` section; DOC_MD_FIXTURE_FILE is interpolated
# as the value of `doc_md_file_path`.
dag_config = f"""
## YML DAG
```yaml
default:
concurrency: 1
dagrun_timeout_sec: 600
default_args:
end_date: 2018-03-05
owner: default_owner
retries: 1
retry_delay_sec: 300
start_date: 2018-03-01
default_view: tree
max_active_runs: 1
orientation: LR
schedule_interval: 0 1 * * *

example_dag2:
doc_md_file_path: {DOC_MD_FIXTURE_FILE}
schedule_interval: None
tasks:
task_1:
bash_command: echo 1
operator: airflow.operators.bash_operator.BashOperator
task_2:
bash_command: echo 2
dependencies:
- task_1
operator: airflow.operators.bash_operator.BashOperator
task_3:
bash_command: echo 3
dependencies:
- task_1
operator: airflow.operators.bash_operator.BashOperator

```"""

# Generate the DAGs from TEST_DAG_FACTORY into this module's globals, then
# read the doc_md of the resulting example_dag2 object.
td = dagfactory.DagFactory(TEST_DAG_FACTORY)
td.generate_dags(globals())
generated_doc_md = globals()["example_dag2"].doc_md
with open(DOC_MD_FIXTURE_FILE, "r") as file:
# NOTE(review): this listing is a rendered diff; the next two lines appear to
# be the removed and the added version of the same statement - confirm that
# only the second (file contents + dag_config) is live.
expected_doc_md = file.read()
expected_doc_md = file.read() + dag_config
assert generated_doc_md == expected_doc_md


# Compares example_dag3's doc_md with the string form of the
# `doc_md_python_arguments` entry of its configuration.
def test_doc_md_callable():
td = dagfactory.DagFactory(TEST_DAG_FACTORY)
td.generate_dags(globals())
# Despite its name, expected_doc_md holds the doc_md read back from the
# generated DAG.
expected_doc_md = globals()["example_dag3"].doc_md
# NOTE(review): this listing is a rendered diff; the two asserts below appear
# to be the removed (`==`) and the added (`in`) version of the same check -
# confirm that only the containment check is live.
assert str(td.get_dag_configs()["example_dag3"]["doc_md_python_arguments"]) == expected_doc_md
assert str(td.get_dag_configs()["example_dag3"]["doc_md_python_arguments"]) in expected_doc_md


def test_schedule_interval():
Expand Down Expand Up @@ -443,3 +480,13 @@ def test_load_yaml_dags_default_suffix_succeed(caplog):
dags_folder="tests/fixtures",
)
assert "Loading DAGs from tests/fixtures" in caplog.messages


def test_yml_dag_rendering_in_docs():
    """example_dag2's doc_md is the fixture file wrapped in a "## YML DAG" fenced YAML block."""
    fixture_path = os.path.join(here, "fixtures/dag_md_docs.yml")

    # Generate the DAGs from the fixture into this module's globals.
    factory = dagfactory.DagFactory(fixture_path)
    factory.generate_dags(globals())
    rendered_docs = globals()["example_dag2"].doc_md

    # The expected docs are the raw fixture text between the heading/fence
    # opener and the closing fence.
    with open(fixture_path, "r") as fixture:
        yaml_source = fixture.read()
    wrapped_source = "## YML DAG\n```yaml\n" + yaml_source + "\n```"

    assert rendered_docs == wrapped_source
Loading