Skip to content

Commit

Permalink
Persist DAG and task doc values in TaskFlow API if explicitly set (#29399)
Browse files Browse the repository at this point in the history

If users set the `doc_md` arg on `@dag`- or `@task`-decorated TaskFlow functions and those functions also have a docstring, the `doc_md` value is not respected: the docstring overwrites it. With this PR, the TaskFlow function's docstring is used as the documentation only if `doc_md` (or, for tasks, any of the `doc*` attrs) is not explicitly set.

(cherry picked from commit e02bfc8)
  • Loading branch information
josh-fell authored and pierrejeambrun committed Mar 7, 2023
1 parent 1dc9d73 commit 0fcfef8
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 53 deletions.
4 changes: 3 additions & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
multiple_outputs=self.multiple_outputs,
**self.kwargs,
)
if self.function.__doc__:
op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml]
# Set the task's doc_md to the function's docstring if it exists and no other doc* args are set.
if self.function.__doc__ and not any(op_doc_attrs):
op.doc_md = self.function.__doc__
return XComArg(op)

Expand Down
4 changes: 2 additions & 2 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3544,8 +3544,8 @@ def factory(*args, **kwargs):
owner_links=owner_links,
auto_register=auto_register,
) as dag_obj:
# Set DAG documentation from function documentation.
if f.__doc__:
# Set DAG documentation from function documentation if it exists and doc_md is not set.
if f.__doc__ and not dag_obj.doc_md:
dag_obj.doc_md = f.__doc__

# Generate DAGParam for each function arg/kwarg and replace it for calling the function.
Expand Down
25 changes: 18 additions & 7 deletions tests/decorators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,32 @@ def add_2(number: int):

assert "add_2" in self.dag.task_ids

def test_task_documentation(self):
"""Tests that task_decorator loads doc_md from function doc"""
@pytest.mark.parametrize(
argnames=["op_doc_attr", "op_doc_value", "expected_doc_md"],
argvalues=[
pytest.param("doc", "task docs.", None, id="set_doc"),
pytest.param("doc_json", '{"task": "docs."}', None, id="set_doc_json"),
pytest.param("doc_md", "task docs.", "task docs.", id="set_doc_md"),
pytest.param("doc_rst", "task docs.", None, id="set_doc_rst"),
pytest.param("doc_yaml", "task:\n\tdocs", None, id="set_doc_yaml"),
pytest.param("doc_md", None, "Adds 2 to number.", id="no_doc_md_use_docstring"),
],
)
def test_task_documentation(self, op_doc_attr, op_doc_value, expected_doc_md):
"""Tests that task_decorator loads doc_md from function doc if doc_md is not explicitly provided."""
kwargs = {}
kwargs[op_doc_attr] = op_doc_value

@task_decorator
@task_decorator(**kwargs)
def add_2(number: int):
"""
Adds 2 to number.
"""
"""Adds 2 to number."""
return number + 2

test_number = 10
with self.dag:
ret = add_2(test_number)

assert ret.operator.doc_md.strip(), "Adds 2 to number."
assert ret.operator.doc_md == expected_doc_md

def test_user_provided_task_id_in_a_loop_is_used(self):
"""Tests that when looping that user provided task_id is used"""
Expand Down
74 changes: 31 additions & 43 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2682,58 +2682,49 @@ def noop_pipeline():

dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, "noop_pipeline"
assert dag.dag_id == "noop_pipeline"
assert dag.fileloc == __file__

def test_set_dag_id(self):
"""Test that checks you can set dag_id from decorator."""

@dag_decorator("test", default_args=self.DEFAULT_ARGS)
def noop_pipeline():
@task_decorator
def return_num(num):
return num

return_num(4)
...

dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, "test"
assert dag.dag_id == "test"

def test_default_dag_id(self):
"""Test that @dag uses function name as default dag id."""

@dag_decorator(default_args=self.DEFAULT_ARGS)
def noop_pipeline():
@task_decorator
def return_num(num):
return num

return_num(4)
...

dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, "noop_pipeline"
assert dag.dag_id == "noop_pipeline"

def test_documentation_added(self):
"""Test that @dag uses function docs as doc_md for DAG object"""
@pytest.mark.parametrize(
argnames=["dag_doc_md", "expected_doc_md"],
argvalues=[
pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"),
pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"),
],
)
def test_documentation_added(self, dag_doc_md, expected_doc_md):
"""Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set."""

@dag_decorator(default_args=self.DEFAULT_ARGS)
@dag_decorator(default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md)
def noop_pipeline():
"""
Regular DAG documentation
"""

@task_decorator
def return_num(num):
return num

return_num(4)
"""Regular DAG documentation"""

dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, "test"
assert dag.doc_md.strip(), "Regular DAG documentation"
assert dag.dag_id == "noop_pipeline"
assert dag.doc_md == expected_doc_md

def test_documentation_template_rendered(self):
"""Test that @dag uses function docs as doc_md for DAG object"""
Expand All @@ -2746,16 +2737,10 @@ def noop_pipeline():
{% endif %}
"""

@task_decorator
def return_num(num):
return num

return_num(4)

dag = noop_pipeline()
assert isinstance(dag, DAG)
assert dag.dag_id, "test"
assert dag.doc_md.strip(), "Regular DAG documentation"
assert dag.dag_id == "noop_pipeline"
assert "Regular DAG documentation" in dag.doc_md

def test_resolve_documentation_template_file_rendered(self):
"""Test that @dag uses function docs as doc_md for DAG object"""
Expand All @@ -2769,16 +2754,19 @@ def test_resolve_documentation_template_file_rendered(self):
"""
)
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)

with DAG("test-dag", start_date=DEFAULT_DATE, doc_md=template_file) as dag:
task = EmptyOperator(task_id="op1")

task
@dag_decorator(
"test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir, doc_md=template_file
)
def markdown_docs():
...

assert isinstance(dag, DAG)
assert dag.dag_id, "test"
assert dag.doc_md.strip(), "External Markdown DAG documentation"
dag = markdown_docs()
assert isinstance(dag, DAG)
assert dag.dag_id == "test-dag"
assert dag.doc_md.strip() == "External Markdown DAG documentation"

def test_fails_if_arg_not_set(self):
"""Test that @dag decorated function fails if positional argument is not set"""
Expand Down Expand Up @@ -2848,7 +2836,7 @@ def return_num(num):

self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull(), new_value
assert ti.xcom_pull() == new_value

@pytest.mark.parametrize("value", [VALUE, 0])
def test_set_params_for_dag(self, value):
Expand Down

0 comments on commit 0fcfef8

Please sign in to comment.