Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Taskgroup decorator #15034

Merged
merged 18 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow/decorators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Callable, Optional

from airflow.decorators.python import python_task
from airflow.decorators.task_group import task_group # noqa # pylint: disable=unused-import
from airflow.models.dag import dag # noqa # pylint: disable=unused-import


Expand Down
72 changes: 72 additions & 0 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
import functools
from inspect import signature
from typing import Callable, Optional, TypeVar, cast

from airflow.utils.task_group import TaskGroup

# Type variable for the decorated callable, so the decorator's annotation can express
# "takes a callable, returns a callable of the same type".
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name

# Signature of ``TaskGroup.__init__``, computed once at import time; the decorator binds its
# arguments against it.
task_group_sig = signature(TaskGroup.__init__)


def task_group(python_callable: Optional[Callable] = None, *tg_args, **tg_kwargs) -> Callable[[T], T]:
    """
    Python TaskGroup decorator. Wraps a function into an Airflow TaskGroup.

    Can be used bare (``@task_group``) or parametrized (``@task_group(group_id=...)``).
    Every call of the decorated function opens a new ``TaskGroup``, runs the function body
    inside it and returns the ``TaskGroup`` object; the function's own return value is
    discarded.

    If no positional TaskGroup argument and no ``group_id`` keyword is supplied, the name of
    the decorated function is used as ``group_id``. ``add_suffix_on_collision`` defaults to
    ``True`` so the decorated function can be called repeatedly, but an explicit value
    passed by the caller is respected.

    :param python_callable: Function to decorate
    :param tg_args: Arguments for TaskGroup object
    :type tg_args: list
    :param tg_kwargs: Kwargs for TaskGroup object.
    :type tg_kwargs: dict
    :return: The decorated factory when ``python_callable`` is given, otherwise a decorator.
    """

    def wrapper(f: T):
        # Work on a per-function copy: ``tg_kwargs`` is shared by every function decorated
        # with the same ``task_group(...)`` instance, so writing the default group_id into
        # it would leak the first function's name into all later ones.
        group_kwargs = dict(tg_kwargs)

        # Setting group_id as function name if not given in kwarg group_id
        if not tg_args and 'group_id' not in group_kwargs:
            group_kwargs['group_id'] = f.__name__

        # Default to suffixing on collision, but do not pass the keyword twice when the
        # caller supplied it explicitly (that would raise TypeError when building the group).
        group_kwargs.setdefault('add_suffix_on_collision', True)

        # Bind at decoration time so unknown keyword arguments fail early with TypeError.
        task_group_bound_args = task_group_sig.bind_partial(*tg_args, **group_kwargs)

        @functools.wraps(f)
        def factory(*args, **kwargs):
            # Initialize TaskGroup with bound arguments
            with TaskGroup(*task_group_bound_args.args, **task_group_bound_args.kwargs) as tg_obj:
                # Invoke function to create its tasks inside the TaskGroup context
                f(*args, **kwargs)

            # Return the TaskGroup so callers can wire dependencies to it.
            return tg_obj

        return cast(T, factory)

    if callable(python_callable):
        return wrapper(python_callable)
    return wrapper
73 changes: 73 additions & 0 deletions airflow/example_dags/example_task_group_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example DAG demonstrating the usage of the @taskgroup decorator."""

from airflow.decorators import task, task_group
from airflow.models.dag import DAG
from airflow.utils.dates import days_ago


# [START howto_task_group_decorator]
# Creating Tasks
@task
def task_start():
    """Return the marker string produced by the first task of the DAG."""
    marker = '[Task_start]'
    return marker


@task
def task_1(value):
    """Return ``value`` wrapped in a ``[ Task1 ... ]`` label."""
    labelled = f'[ Task1 {value} ]'
    return labelled


@task
def task_2(value):
    """Return ``value`` wrapped in a ``[ Task2 ... ]`` label."""
    labelled = f'[ Task2 {value} ]'
    return labelled


@task
def task_3(value):
    """Print ``value`` wrapped in a ``[ Task3 ... ]`` label."""
    message = f'[ Task3 {value} ]'
    print(message)


@task
def task_end():
    """Print the marker of the last task of the DAG."""
    marker = '[ Task_End ]'
    print(marker)


# Creating TaskGroups
@task_group
def task_group_function(value):
    """Chain ``task_1`` -> ``task_2`` -> ``task_3`` on ``value`` inside one TaskGroup."""
    first = task_1(value)
    second = task_2(first)
    third = task_3(second)
    return third


# Executing Tasks and TaskGroups
with DAG(
    dag_id="example_task_group_decorator",
    start_date=days_ago(2),
    tags=["example"],
) as dag:
    # Single entry and exit tasks shared by every group instance.
    start_task = task_start()
    end_task = task_end()

    # Call the decorated function five times; each call yields its own TaskGroup,
    # which is placed between the start and end tasks.
    for i in range(5):
        current_task_group = task_group_function(i)
        start_task >> current_task_group >> end_task

# [END howto_task_group_decorator]
28 changes: 24 additions & 4 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""

import re
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union

from airflow.exceptions import AirflowException, DuplicateTaskIdFound
Expand Down Expand Up @@ -54,6 +54,9 @@ class TaskGroup(TaskMixin):
:type ui_color: str
:param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI
:type ui_fgcolor: str
:param add_suffix_on_collision: If this task group name already exists,
automatically add `__1` etc suffixes
:type add_suffix_on_collision: bool
"""

def __init__(
Expand All @@ -65,6 +68,7 @@ def __init__(
tooltip: str = "",
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
):
from airflow.models.dag import DagContext

Expand Down Expand Up @@ -95,8 +99,24 @@ def __init__(
self.used_group_ids = self._parent_group.used_group_ids

self._group_id = group_id
if self.group_id in self.used_group_ids:
raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG")
# if given group_id already used assign suffix by incrementing largest used suffix integer
# Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3
if group_id in self.used_group_ids:
if not add_suffix_on_collision:
raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG")
base = re.split(r'__\d+$', group_id)[0]
suffixes = sorted(
[
int(re.split(r'^.+__', used_group_id)[1])
for used_group_id in self.used_group_ids
if used_group_id is not None and re.match(rf'^{base}__\d+$', used_group_id)
]
)
if not suffixes:
self._group_id += '__1'
else:
self._group_id = f'{base}__{suffixes[-1] + 1}'

self.used_group_ids.add(self.group_id)
self.used_group_ids.add(self.downstream_join_id)
self.used_group_ids.add(self.upstream_join_id)
Expand Down Expand Up @@ -316,7 +336,7 @@ class TaskGroupContext:
_previous_context_managed_task_groups: List[TaskGroup] = []

@classmethod
def push_context_managed_task_group(cls, task_group: TaskGroup):
def push_context_managed_task_group(cls, task_group: TaskGroup): # pylint: disable=redefined-outer-name
"""Push a TaskGroup into the list of managed TaskGroups."""
if cls._context_managed_task_group:
cls._previous_context_managed_task_groups.append(cls._context_managed_task_group)
Expand Down
8 changes: 8 additions & 0 deletions docs/apache-airflow/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,14 @@ This animated gif shows the UI interactions. TaskGroups are expanded or collapse
.. image:: img/task_group.gif


A TaskGroup can also be created using the ``@task_group`` decorator. It accepts the same arguments as the constructor of the ``TaskGroup`` class; if ``group_id`` is not given, the name of the decorated function is used as the ``group_id``. It works the same way as creating a TaskGroup using the context manager ``with TaskGroup('groupid') as section:``.

.. exampleinclude:: /../../airflow/example_dags/example_task_group_decorator.py
:language: python
:start-after: [START howto_task_group_decorator]
:end-before: [END howto_task_group_decorator]


SLAs
====

Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,7 @@ tablefmt
tagKey
tagValue
tao
task_group
taskflow
taskinstance
tblproperties
Expand Down
Loading