From 0eb2c5ecd5f96373a2615072441be5b7f24da46a Mon Sep 17 00:00:00 2001 From: Josh Fell Date: Tue, 7 Feb 2023 00:05:13 -0500 Subject: [PATCH] Persist DAG and task doc values in TaskFlow API if explicitly set If users set `doc_md` arg on `@dag`- or `@task`-decorated TaskFlow functions and those functions have a docstring, the `doc_md` value is not respected. Instead this PR will enforce only using the TaskFlow function docstring as the documentation if `doc_md` (or any `doc*` attrs for tasks) is not set. --- airflow/decorators/base.py | 4 +- airflow/models/dag.py | 4 +- tests/decorators/test_python.py | 25 +++++++---- tests/models/test_dag.py | 74 ++++++++++++++------------------- 4 files changed, 54 insertions(+), 53 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 02db011801c07f..ac12adca47771e 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -333,7 +333,9 @@ def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg: multiple_outputs=self.multiple_outputs, **self.kwargs, ) - if self.function.__doc__: + op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml] + # Set the task's doc_md to the function's docstring if it exists and no other doc* args are set. + if self.function.__doc__ and not any(op_doc_attrs): op.doc_md = self.function.__doc__ return XComArg(op) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 55c6c81eec57bd..268637fe256150 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -3546,8 +3546,8 @@ def factory(*args, **kwargs): owner_links=owner_links, auto_register=auto_register, ) as dag_obj: - # Set DAG documentation from function documentation. - if f.__doc__: + # Set DAG documentation from function documentation if it exists and doc_md is not set. + if f.__doc__ and not dag_obj.doc_md: dag_obj.doc_md = f.__doc__ # Generate DAGParam for each function arg/kwarg and replace it for calling the function. diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 4db7d2f4479166..20dc883106ed0b 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -461,21 +461,32 @@ def add_2(number: int): assert "add_2" in self.dag.task_ids - def test_task_documentation(self): - """Tests that task_decorator loads doc_md from function doc""" + @pytest.mark.parametrize( + argnames=["op_doc_attr", "op_doc_value", "expected_doc_md"], + argvalues=[ + pytest.param("doc", "task docs.", None, id="set_doc"), + pytest.param("doc_json", '{"task": "docs."}', None, id="set_doc_json"), + pytest.param("doc_md", "task docs.", "task docs.", id="set_doc_md"), + pytest.param("doc_rst", "task docs.", None, id="set_doc_rst"), + pytest.param("doc_yaml", "task:\n\tdocs", None, id="set_doc_yaml"), + pytest.param("doc_md", None, "Adds 2 to number.", id="no_doc_md_use_docstring"), + ], + ) + def test_task_documentation(self, op_doc_attr, op_doc_value, expected_doc_md): + """Tests that task_decorator loads doc_md from function doc if doc_md is not explicitly provided.""" + kwargs = {} + kwargs[op_doc_attr] = op_doc_value - @task_decorator + @task_decorator(**kwargs) def add_2(number: int): - """ - Adds 2 to number. - """ + """Adds 2 to number.""" return number + 2 test_number = 10 with self.dag: ret = add_2(test_number) - assert ret.operator.doc_md.strip(), "Adds 2 to number." + assert ret.operator.doc_md == expected_doc_md def test_user_provided_task_id_in_a_loop_is_used(self): """Tests that when looping that user provided task_id is used""" diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index f1a043e1ea021d..a99e1b585ff864 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -2565,7 +2565,7 @@ def noop_pipeline(): dag = noop_pipeline() assert isinstance(dag, DAG) - assert dag.dag_id, "noop_pipeline" + assert dag.dag_id == "noop_pipeline" assert dag.fileloc == __file__ def test_set_dag_id(self): @@ -2573,50 +2573,41 @@ def test_set_dag_id(self): @dag_decorator("test", default_args=self.DEFAULT_ARGS) def noop_pipeline(): - @task_decorator - def return_num(num): - return num - - return_num(4) + ... dag = noop_pipeline() assert isinstance(dag, DAG) - assert dag.dag_id, "test" + assert dag.dag_id == "test" def test_default_dag_id(self): """Test that @dag uses function name as default dag id.""" @dag_decorator(default_args=self.DEFAULT_ARGS) def noop_pipeline(): - @task_decorator - def return_num(num): - return num - - return_num(4) + ... dag = noop_pipeline() assert isinstance(dag, DAG) - assert dag.dag_id, "noop_pipeline" + assert dag.dag_id == "noop_pipeline" - def test_documentation_added(self): - """Test that @dag uses function docs as doc_md for DAG object""" + @pytest.mark.parametrize( + argnames=["dag_doc_md", "expected_doc_md"], + argvalues=[ + pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"), + pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"), + ], + ) + def test_documentation_added(self, dag_doc_md, expected_doc_md): + """Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set.""" - @dag_decorator(default_args=self.DEFAULT_ARGS) + @dag_decorator(default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md) def noop_pipeline(): - """ - Regular DAG documentation - """ - - @task_decorator - def return_num(num): - return num - - return_num(4) + """Regular DAG documentation""" dag = noop_pipeline() assert isinstance(dag, DAG) - assert dag.dag_id, "test" - assert dag.doc_md.strip(), "Regular DAG documentation" + assert dag.dag_id == "noop_pipeline" + assert dag.doc_md == expected_doc_md def test_documentation_template_rendered(self): """Test that @dag uses function docs as doc_md for DAG object""" @@ -2629,16 +2620,10 @@ def noop_pipeline(): {% endif %} """ - @task_decorator - def return_num(num): - return num - - return_num(4) - dag = noop_pipeline() assert isinstance(dag, DAG) - assert dag.dag_id, "test" - assert dag.doc_md.strip(), "Regular DAG documentation" + assert dag.dag_id == "noop_pipeline" + assert "Regular DAG documentation" in dag.doc_md def test_resolve_documentation_template_file_rendered(self): """Test that @dag uses function docs as doc_md for DAG object""" @@ -2652,16 +2637,19 @@ def test_resolve_documentation_template_file_rendered(self): """ ) f.flush() + template_dir = os.path.dirname(f.name) template_file = os.path.basename(f.name) - with DAG("test-dag", start_date=DEFAULT_DATE, doc_md=template_file) as dag: - task = EmptyOperator(task_id="op1") - - task + @dag_decorator( + "test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir, doc_md=template_file + ) + def markdown_docs(): + ... - assert isinstance(dag, DAG) - assert dag.dag_id, "test" - assert dag.doc_md.strip(), "External Markdown DAG documentation" + dag = markdown_docs() + assert isinstance(dag, DAG) + assert dag.dag_id == "test-dag" + assert dag.doc_md.strip() == "External Markdown DAG documentation" def test_fails_if_arg_not_set(self): """Test that @dag decorated function fails if positional argument is not set""" @@ -2731,7 +2719,7 @@ def return_num(num): self.operator.run(start_date=self.DEFAULT_DATE, end_date=self.DEFAULT_DATE) ti = dr.get_task_instances()[0] - assert ti.xcom_pull(), new_value + assert ti.xcom_pull() == new_value @pytest.mark.parametrize("value", [VALUE, 0]) def test_set_params_for_dag(self, value):