diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py index 8295355787b662..7363eb4b8af126 100644 --- a/airflow/decorators/__init__.py +++ b/airflow/decorators/__init__.py @@ -18,6 +18,7 @@ from typing import Callable, Optional from airflow.decorators.python import python_task +from airflow.decorators.task_group import task_group # noqa # pylint: disable=unused-import from airflow.models.dag import dag # noqa # pylint: disable=unused-import diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py new file mode 100644 index 00000000000000..8c169ddc28ac3c --- /dev/null +++ b/airflow/decorators/task_group.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped +together when the DAG is displayed graphically. +""" +import functools +from inspect import signature +from typing import Callable, Optional, TypeVar, cast + +from airflow.utils.task_group import TaskGroup + +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + +task_group_sig = signature(TaskGroup.__init__) + + +def task_group(python_callable: Optional[Callable] = None, *tg_args, **tg_kwargs) -> Callable[[T], T]: + """ + Python TaskGroup decorator. Wraps a function into an Airflow TaskGroup. + Accepts kwargs for operator TaskGroup. Can be used to parametrize TaskGroup. + + :param python_callable: Function to decorate + :param tg_args: Arguments for TaskGroup object + :type tg_args: list + :param tg_kwargs: Kwargs for TaskGroup object. + :type tg_kwargs: dict + """ + + def wrapper(f: T): + # Setting group_id as function name if not given in kwarg group_id + if not tg_args and 'group_id' not in tg_kwargs: + tg_kwargs['group_id'] = f.__name__ + task_group_bound_args = task_group_sig.bind_partial(*tg_args, **tg_kwargs) + + @functools.wraps(f) + def factory(*args, **kwargs): + # Generate signature for decorated function and bind the arguments when called + # we do this to extract parameters so we can annotate them on the DAG object. + # In addition, this fails if we are missing any args/kwargs with TypeError as expected. + # Apply defaults to capture default values if set. + + # Initialize TaskGroup with bound arguments + with TaskGroup( + *task_group_bound_args.args, add_suffix_on_collision=True, **task_group_bound_args.kwargs + ) as tg_obj: + # Invoke function to run Tasks inside the TaskGroup + f(*args, **kwargs) + + # Return task_group object such that it's accessible in Globals. + return tg_obj + + return cast(T, factory) + + if callable(python_callable): + return wrapper(python_callable) + return wrapper diff --git a/airflow/example_dags/example_task_group_decorator.py b/airflow/example_dags/example_task_group_decorator.py new file mode 100644 index 00000000000000..39ee6620cda95f --- /dev/null +++ b/airflow/example_dags/example_task_group_decorator.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example DAG demonstrating the usage of the @taskgroup decorator.""" + +from airflow.decorators import task, task_group +from airflow.models.dag import DAG +from airflow.utils.dates import days_ago + + +# [START howto_task_group_decorator] +# Creating Tasks +@task +def task_start(): + """Dummy Task which is First Task of Dag """ + return '[Task_start]' + + +@task +def task_1(value): + """ Dummy Task1""" + return f'[ Task1 {value} ]' + + +@task +def task_2(value): + """ Dummy Task2""" + return f'[ Task2 {value} ]' + + +@task +def task_3(value): + """ Dummy Task3""" + print(f'[ Task3 {value} ]') + + +@task +def task_end(): + """ Dummy Task which is Last Task of Dag """ + print('[ Task_End ]') + + +# Creating TaskGroups +@task_group +def task_group_function(value): + """ TaskGroup for grouping related Tasks""" + return task_3(task_2(task_1(value))) + + +# Executing Tasks and TaskGroups +with DAG(dag_id="example_task_group_decorator", start_date=days_ago(2), tags=["example"]) as dag: + start_task = task_start() + end_task = task_end() + for i in range(5): + current_task_group = task_group_function(i) + start_task >> current_task_group >> end_task + +# [END howto_task_group_decorator] diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 521c53d94ac866..551eb4886ceb12 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -19,7 +19,7 @@ A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped together when the DAG is displayed graphically. """ - +import re from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union from airflow.exceptions import AirflowException, DuplicateTaskIdFound @@ -54,6 +54,9 @@ class TaskGroup(TaskMixin): :type ui_color: str :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI :type ui_fgcolor: str + :param add_suffix_on_collision: If this task group name already exists, + automatically add `__1` etc suffixes + :type from_decorator: add_suffix_on_collision """ def __init__( @@ -65,6 +68,7 @@ def __init__( tooltip: str = "", ui_color: str = "CornflowerBlue", ui_fgcolor: str = "#000", + add_suffix_on_collision: bool = False, ): from airflow.models.dag import DagContext @@ -95,8 +99,24 @@ def __init__( self.used_group_ids = self._parent_group.used_group_ids self._group_id = group_id - if self.group_id in self.used_group_ids: - raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG") + # if given group_id already used assign suffix by incrementing largest used suffix integer + # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3 + if group_id in self.used_group_ids: + if not add_suffix_on_collision: + raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG") + base = re.split(r'__\d+$', group_id)[0] + suffixes = sorted( + [ + int(re.split(r'^.+__', used_group_id)[1]) + for used_group_id in self.used_group_ids + if used_group_id is not None and re.match(rf'^{base}__\d+$', used_group_id) + ] + ) + if not suffixes: + self._group_id += '__1' + else: + self._group_id = f'{base}__{suffixes[-1] + 1}' + self.used_group_ids.add(self.group_id) self.used_group_ids.add(self.downstream_join_id) self.used_group_ids.add(self.upstream_join_id) @@ -316,7 +336,7 @@ class TaskGroupContext: _previous_context_managed_task_groups: List[TaskGroup] = [] @classmethod - def push_context_managed_task_group(cls, task_group: TaskGroup): + def push_context_managed_task_group(cls, task_group: TaskGroup): # pylint: disable=redefined-outer-name """Push a TaskGroup into the list of managed TaskGroups.""" if cls._context_managed_task_group: cls._previous_context_managed_task_groups.append(cls._context_managed_task_group) diff --git a/docs/apache-airflow/concepts.rst b/docs/apache-airflow/concepts.rst index 8ca63223ccc54f..c286d8a7a93051 100644 --- a/docs/apache-airflow/concepts.rst +++ b/docs/apache-airflow/concepts.rst @@ -1129,6 +1129,14 @@ This animated gif shows the UI interactions. TaskGroups are expanded or collapse .. image:: img/task_group.gif +TaskGroup can be created using ``@task_group`` decorator, it takes one argument ``group_id`` which is same as constructor of TaskGroup class, if not given it copies function name as ``group_id``. It works exactly same as creating TaskGroup using context manager ``with TaskGroup('groupid') as section:``. + +.. exampleinclude:: /../../airflow/example_dags/example_task_group_decorator.py + :language: python + :start-after: [START howto_task_group_decorator] + :end-before: [END howto_task_group_decorator] + + SLAs ==== diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index ab394d3cca5edb..f203300c76a50e 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1317,6 +1317,7 @@ tablefmt tagKey tagValue tao +task_group taskflow taskinstance tblproperties diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 0bc80e32a43c31..ceefdbf56231cf 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -19,8 +19,11 @@ import pendulum import pytest +from airflow.decorators import task_group as task_group_decorator from airflow.models import DAG +from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator +from airflow.operators.python import PythonOperator from airflow.utils.task_group import TaskGroup from airflow.www.views import dag_edges, task_group_to_dict @@ -269,7 +272,7 @@ def test_build_task_group_with_task_decorator(): """ Test that TaskGroup can be used with the @task decorator. """ - from airflow.operators.python import task + from airflow.decorators import task @task def task_1(): @@ -576,3 +579,371 @@ def test_task_without_dag(): assert op1.dag == op2.dag == op3.dag assert dag.task_group.children.keys() == {"op1", "op2", "op3"} assert dag.task_group.children.keys() == dag.task_dict.keys() + + +# taskgroup decorator tests + + +def test_build_task_group_deco_context_manager(): + """ + Tests Following : + 1. Nested TaskGroup creation using taskgroup decorator should create same TaskGroup which can be + created using TaskGroup context manager. + 2. TaskGroup consisting Tasks created using task decorator. + 3. Node Ids of dags created with taskgroup decorator. + """ + + from airflow.decorators import task + + # Creating Tasks + @task + def task_start(): + """Dummy Task which is First Task of Dag """ + return '[Task_start]' + + @task + def task_end(): + """Dummy Task which is Last Task of Dag""" + print('[ Task_End ]') + + @task + def task_1(value): + """ Dummy Task1""" + return f'[ Task1 {value} ]' + + @task + def task_2(value): + """ Dummy Task2""" + print(f'[ Task2 {value} ]') + + @task + def task_3(value): + """ Dummy Task3""" + return f'[ Task3 {value} ]' + + @task + def task_4(value): + """ Dummy Task3""" + print(f'[ Task4 {value} ]') + + # Creating TaskGroups + @task_group_decorator + def section_1(value): + """ TaskGroup for grouping related Tasks""" + + @task_group_decorator() + def section_2(value2): + """ TaskGroup for grouping related Tasks""" + return task_4(task_3(value2)) + + op1 = task_2(task_1(value)) + return section_2(op1) + + execution_date = pendulum.parse("20201109") + with DAG( + dag_id="example_nested_task_group_decorator", start_date=execution_date, tags=["example"] + ) as dag: + t_start = task_start() + sec_1 = section_1(t_start) + sec_1.set_downstream(task_end()) + + # Testing TaskGroup created using taskgroup decorator + assert set(dag.task_group.children.keys()) == {"task_start", "task_end", "section_1"} + assert set(dag.task_group.children['section_1'].children.keys()) == { + 'section_1.task_1', + 'section_1.task_2', + 'section_1.section_2', + } + + # Testing TaskGroup consisting Tasks created using task decorator + assert dag.task_dict['task_start'].downstream_task_ids == {'section_1.task_1'} + assert dag.task_dict['section_1.task_2'].downstream_task_ids == {'section_1.section_2.task_3'} + assert dag.task_dict['section_1.section_2.task_4'].downstream_task_ids == {'task_end'} + + # Node IDs test + node_ids = { + 'id': None, + 'children': [ + { + 'id': 'section_1', + 'children': [ + { + 'id': 'section_1.section_2', + 'children': [ + {'id': 'section_1.section_2.task_3'}, + {'id': 'section_1.section_2.task_4'}, + ], + }, + {'id': 'section_1.task_1'}, + {'id': 'section_1.task_2'}, + {'id': 'section_1.downstream_join_id'}, + ], + }, + {'id': 'task_end'}, + {'id': 'task_start'}, + ], + } + + assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids + + +def test_build_task_group_with_operators(): + """ Tests DAG with Tasks created with *Operators and TaskGroup created with taskgroup decorator """ + + from airflow.decorators import task + + def task_start(): + """Dummy Task which is First Task of Dag """ + return '[Task_start]' + + def task_end(): + """Dummy Task which is Last Task of Dag""" + print('[ Task_End ]') + + # Creating Tasks + @task + def task_1(value): + """ Dummy Task1""" + return f'[ Task1 {value} ]' + + @task + def task_2(value): + """ Dummy Task2""" + return f'[ Task2 {value} ]' + + @task + def task_3(value): + """ Dummy Task3""" + print(f'[ Task3 {value} ]') + + # Creating TaskGroups + @task_group_decorator(group_id='section_1') + def section_a(value): + """ TaskGroup for grouping related Tasks""" + return task_3(task_2(task_1(value))) + + execution_date = pendulum.parse("20201109") + with DAG(dag_id="example_task_group_decorator_mix", start_date=execution_date, tags=["example"]) as dag: + t_start = PythonOperator(task_id='task_start', python_callable=task_start, dag=dag) + sec_1 = section_a(t_start.output) + t_end = PythonOperator(task_id='task_end', python_callable=task_end, dag=dag) + sec_1.set_downstream(t_end) + + # Testing Tasks ing DAG + assert set(dag.task_group.children.keys()) == {'section_1', 'task_start', 'task_end'} + assert set(dag.task_group.children['section_1'].children.keys()) == { + 'section_1.task_2', + 'section_1.task_3', + 'section_1.task_1', + } + + # Testing Tasks downstream + assert dag.task_dict['task_start'].downstream_task_ids == {'section_1.task_1'} + assert dag.task_dict['section_1.task_3'].downstream_task_ids == {'task_end'} + + +def test_task_group_context_mix(): + """ Test cases to check nested TaskGroup context manager with taskgroup decorator""" + + from airflow.decorators import task + + def task_start(): + """Dummy Task which is First Task of Dag """ + return '[Task_start]' + + def task_end(): + """Dummy Task which is Last Task of Dag""" + print('[ Task_End ]') + + # Creating Tasks + @task + def task_1(value): + """ Dummy Task1""" + return f'[ Task1 {value} ]' + + @task + def task_2(value): + """ Dummy Task2""" + return f'[ Task2 {value} ]' + + @task + def task_3(value): + """ Dummy Task3""" + print(f'[ Task3 {value} ]') + + # Creating TaskGroups + @task_group_decorator + def section_2(value): + """ TaskGroup for grouping related Tasks""" + return task_3(task_2(task_1(value))) + + execution_date = pendulum.parse("20201109") + with DAG(dag_id="example_task_group_decorator_mix", start_date=execution_date, tags=["example"]) as dag: + t_start = PythonOperator(task_id='task_start', python_callable=task_start, dag=dag) + + with TaskGroup("section_1", tooltip="section_1") as section_1: + sec_2 = section_2(t_start.output) + task_s1 = DummyOperator(task_id="task_1") + task_s2 = BashOperator(task_id="task_2", bash_command='echo 1') + task_s3 = DummyOperator(task_id="task_3") + + sec_2.set_downstream(task_s1) + task_s1 >> [task_s2, task_s3] + + t_end = PythonOperator(task_id='task_end', python_callable=task_end, dag=dag) + t_start >> section_1 >> t_end + + node_ids = { + 'id': None, + 'children': [ + { + 'id': 'section_1', + 'children': [ + { + 'id': 'section_1.section_2', + 'children': [ + {'id': 'section_1.section_2.task_1'}, + {'id': 'section_1.section_2.task_2'}, + {'id': 'section_1.section_2.task_3'}, + {'id': 'section_1.section_2.downstream_join_id'}, + ], + }, + {'id': 'section_1.task_1'}, + {'id': 'section_1.task_2'}, + {'id': 'section_1.task_3'}, + {'id': 'section_1.upstream_join_id'}, + {'id': 'section_1.downstream_join_id'}, + ], + }, + {'id': 'task_end'}, + {'id': 'task_start'}, + ], + } + + assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids + + +def test_duplicate_task_group_id(): + """ Testing automatic suffix assignment for duplicate group_id""" + + from airflow.decorators import task + + @task(task_id='start_task') + def task_start(): + """Dummy Task which is First Task of Dag """ + print('[Task_start]') + + @task(task_id='end_task') + def task_end(): + """Dummy Task which is Last Task of Dag""" + print('[Task_End]') + + # Creating Tasks + @task(task_id='task') + def task_1(): + """ Dummy Task1""" + print('[Task1]') + + @task(task_id='task') + def task_2(): + """ Dummy Task2""" + print('[Task2]') + + @task(task_id='task1') + def task_3(): + """ Dummy Task3""" + print('[Task3]') + + @task_group_decorator('task_group1') + def task_group1(): + task_start() + task_1() + task_2() + + @task_group_decorator(group_id='task_group1') + def task_group2(): + task_3() + + @task_group_decorator(group_id='task_group1') + def task_group3(): + task_end() + + execution_date = pendulum.parse("20201109") + with DAG(dag_id="example_duplicate_task_group_id", start_date=execution_date, tags=["example"]) as dag: + task_group1() + task_group2() + task_group3() + node_ids = { + 'id': None, + 'children': [ + { + 'id': 'task_group1', + 'children': [ + {'id': 'task_group1.start_task'}, + {'id': 'task_group1.task'}, + {'id': 'task_group1.task__1'}, + ], + }, + {'id': 'task_group1__1', 'children': [{'id': 'task_group1__1.task1'}]}, + {'id': 'task_group1__2', 'children': [{'id': 'task_group1__2.end_task'}]}, + ], + } + + assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids + + +def test_call_taskgroup_twice(): + """Test for using same taskgroup decorated function twice""" + from airflow.decorators import task + + @task(task_id='start_task') + def task_start(): + """Dummy Task which is First Task of Dag """ + print('[Task_start]') + + @task(task_id='end_task') + def task_end(): + """Dummy Task which is Last Task of Dag""" + print('[Task_End]') + + # Creating Tasks + @task(task_id='task') + def task_1(): + """ Dummy Task1""" + print('[Task1]') + + @task_group_decorator + def task_group1(name: str): + print(f'Starting taskgroup {name}') + task_start() + task_1() + task_end() + + execution_date = pendulum.parse("20201109") + with DAG(dag_id="example_multi_call_task_groups", start_date=execution_date, tags=["example"]) as dag: + task_group1('Call1') + task_group1('Call2') + + node_ids = { + 'id': None, + 'children': [ + { + 'id': 'task_group1', + 'children': [ + {'id': 'task_group1.end_task'}, + {'id': 'task_group1.start_task'}, + {'id': 'task_group1.task'}, + ], + }, + { + 'id': 'task_group1__1', + 'children': [ + {'id': 'task_group1__1.end_task'}, + {'id': 'task_group1__1.start_task'}, + {'id': 'task_group1__1.task'}, + ], + }, + ], + } + + assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids