From fa2bb47b93ea446a7bb27d547e9a24c1ec55e860 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 19 Apr 2022 14:35:28 +0800 Subject: [PATCH] Parse error for task added to multiple groups This raises an exception if a task that already belongs to a task group is added to another group (this includes a task already added to a DAG, since such a task is automatically added to the DAG's root task group). Also, according to the issue response, manually calling TaskGroup.add() is not considered a supported way to add a task to a group. So a meta-marker is added to the function docstring to keep it from showing up in the documentation and to discourage users from calling it. --- airflow/exceptions.py | 17 +++++++++++++++++ airflow/utils/task_group.py | 18 ++++++++++++++++-- tests/utils/test_task_group.py | 22 ++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 95fa9e3276545b..fa7acf61da1cec 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -185,6 +185,23 @@ class DuplicateTaskIdFound(AirflowException): """Raise when a Task with duplicate task_id is defined in the same DAG.""" +class TaskAlreadyInTaskGroup(AirflowException): + """Raise when a Task cannot be added to a TaskGroup since it already belongs to another TaskGroup.""" + + def __init__(self, task_id: str, existing_group_id: Optional[str], new_group_id: str) -> None: + super().__init__(task_id, new_group_id) + self.task_id = task_id + self.existing_group_id = existing_group_id + self.new_group_id = new_group_id + + def __str__(self) -> str: + if self.existing_group_id is None: + existing_group = "the DAG's root group" + else: + existing_group = f"group {self.existing_group_id!r}" + return f"cannot add {self.task_id!r} to {self.new_group_id!r} (already in {existing_group})" + + class SerializationError(AirflowException): """A problem occurred when trying to serialize a DAG.""" diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 
11ee806beb9b37..513c31fb781d97 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -24,7 +24,12 @@ import weakref from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Set, Tuple, Union -from airflow.exceptions import AirflowDagCycleException, AirflowException, DuplicateTaskIdFound +from airflow.exceptions import ( + AirflowDagCycleException, + AirflowException, + DuplicateTaskIdFound, + TaskAlreadyInTaskGroup, +) from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.serialization.enums import DagAttributeTypes from airflow.utils.helpers import validate_group_key @@ -186,7 +191,16 @@ def __iter__(self): yield child def add(self, task: DAGNode) -> None: - """Add a task to this TaskGroup.""" + """Add a task to this TaskGroup. + + :meta private: + """ + from airflow.models.abstractoperator import AbstractOperator + + existing_tg = task.task_group + if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self: + raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id) + # Set the TG first, as setting it might change the return value of node_id! 
task.task_group = weakref.proxy(self) key = task.node_id diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 9a65c8d621e90f..708e07bfa97129 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -20,6 +20,7 @@ import pytest from airflow.decorators import dag, task_group as task_group_decorator +from airflow.exceptions import TaskAlreadyInTaskGroup from airflow.models import DAG from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator @@ -1201,3 +1202,24 @@ def nested_topo(group): ], task6, ] + + +def test_add_to_sub_group(): + with DAG("test_dag", start_date=pendulum.parse("20200101")): + tg = TaskGroup("section") + task = EmptyOperator(task_id="task") + with pytest.raises(TaskAlreadyInTaskGroup) as ctx: + tg.add(task) + + assert str(ctx.value) == "cannot add 'task' to 'section' (already in the DAG's root group)" + + +def test_add_to_another_group(): + with DAG("test_dag", start_date=pendulum.parse("20200101")): + tg = TaskGroup("section_1") + with TaskGroup("section_2"): + task = EmptyOperator(task_id="task") + with pytest.raises(TaskAlreadyInTaskGroup) as ctx: + tg.add(task) + + assert str(ctx.value) == "cannot add 'section_2.task' to 'section_1' (already in group 'section_2')"