Skip to content

Commit

Permalink
Fix: create task group and test task only if the model has a test (#543)
Browse files Browse the repository at this point in the history
This PR creates the task group only if a `dbt` test exists, by adding
the property `has_test` to `DbtNode`. The test dependency is updated
after the dbt project is loaded into `DbtGraph`.

Closes: #531
  • Loading branch information
raffifu authored Sep 25, 2023
1 parent 502051c commit d70e6d7
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 32 deletions.
22 changes: 13 additions & 9 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def create_test_task_metadata(


def create_task_metadata(
node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_name_as_task_id_prefix: bool = True
node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_task_group: bool = False
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand All @@ -106,9 +106,9 @@ def create_task_metadata(

if hasattr(node.resource_type, "value") and node.resource_type in dbt_resource_to_class:
if node.resource_type == DbtResourceType.MODEL:
if use_name_as_task_id_prefix:
task_id = f"{node.name}_run"
else:
task_id = f"{node.name}_run"

if use_task_group is True:
task_id = "run"
else:
task_id = f"{node.name}_{node.resource_type.value}"
Expand Down Expand Up @@ -167,14 +167,18 @@ def build_airflow_graph(
# The exception are the test nodes, since it would be too slow to run test tasks individually.
# If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup
for node_id, node in nodes.items():
use_task_group = (
node.resource_type == DbtResourceType.MODEL
and test_behavior == TestBehavior.AFTER_EACH
and node.has_test is True
)

task_meta = create_task_metadata(
node=node,
execution_mode=execution_mode,
args=task_args,
use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH,
node=node, execution_mode=execution_mode, args=task_args, use_task_group=use_task_group
)

if task_meta and node.resource_type != DbtResourceType.TEST:
if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH:
if use_task_group is True:
with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
task = create_airflow_task(task_meta, dag, task_group=model_task_group)
test_meta = create_test_task_metadata(
Expand Down
22 changes: 22 additions & 0 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class DbtNode:
file_path: Path
tags: list[str] = field(default_factory=lambda: [])
config: dict[str, Any] = field(default_factory=lambda: {})
has_test: bool = False


class DbtGraph:
Expand Down Expand Up @@ -262,6 +263,8 @@ def load_via_dbt_ls(self) -> None:
self.nodes = nodes
self.filtered_nodes = nodes

self.update_node_dependency()

logger.info("Total nodes: %i", len(self.nodes))
logger.info("Total filtered nodes: %i", len(self.nodes))

Expand Down Expand Up @@ -306,6 +309,8 @@ def load_via_custom_parser(self) -> None:
project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
)

self.update_node_dependency()

logger.info("Total nodes: %i", len(self.nodes))
logger.info("Total filtered nodes: %i", len(self.nodes))

Expand Down Expand Up @@ -335,11 +340,28 @@ def load_from_dbt_manifest(self) -> None:
tags=node_dict["tags"],
config=node_dict["config"],
)

nodes[node.unique_id] = node

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
)

self.update_node_dependency()

logger.info("Total nodes: %i", len(self.nodes))
logger.info("Total filtered nodes: %i", len(self.nodes))

def update_node_dependency(self) -> None:
    """
    Set the property `has_test` on every filtered node that a `dbt` test depends on.

    For each test node found in ``self.filtered_nodes``, every node id listed in
    its ``depends_on`` is looked up in ``self.filtered_nodes`` and, when present,
    that node's ``has_test`` is set to ``True``.

    Updates in-place:
    * self.filtered_nodes
    """
    for node in self.filtered_nodes.values():
        # Only test nodes mark other nodes as tested.
        if node.resource_type != DbtResourceType.TEST:
            continue
        for node_id in node.depends_on:
            # A dependency may be absent from filtered_nodes; skip it instead of raising KeyError.
            if node_id in self.filtered_nodes:
                self.filtered_nodes[node_id].has_test = True
35 changes: 12 additions & 23 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql",
tags=["has_child"],
config={"materialized": "view"},
has_test=True,
)
test_parent_node = DbtNode(
name="test_parent", unique_id="test_parent", resource_type=DbtResourceType.TEST, depends_on=["parent"], file_path=""
Expand All @@ -49,15 +50,8 @@
tags=["nightly"],
config={"materialized": "table"},
)
test_child_node = DbtNode(
name="test_child",
unique_id="test_child",
resource_type=DbtResourceType.TEST,
depends_on=["child"],
file_path="",
)

sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, test_child_node]
sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node]
sample_nodes = {node.unique_id: node for node in sample_nodes_list}


Expand Down Expand Up @@ -93,21 +87,18 @@ def test_build_airflow_graph_with_after_each():
"seed_parent_seed",
"parent.run",
"parent.test",
"child.run",
"child.test",
"child_run",
]

assert topological_sort == expected_sort
task_groups = dag.task_group_dict
assert len(task_groups) == 2
assert len(task_groups) == 1

assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"}
assert list(task_groups["parent"].children.keys()) == ["parent.run", "parent.test"]

assert task_groups["child"].upstream_task_ids == {"parent.test"}
assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"]

assert len(dag.leaves) == 1
assert dag.leaves[0].task_id == "child.test"
assert dag.leaves[0].task_id == "child_run"


@pytest.mark.skipif(
Expand Down Expand Up @@ -231,7 +222,7 @@ def test_create_task_metadata_model(caplog):
assert metadata.arguments == {"models": "my_model"}


def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog):
def test_create_task_metadata_model_use_task_group(caplog):
child_node = DbtNode(
name="my_model",
unique_id="my_folder.my_model",
Expand All @@ -241,14 +232,12 @@ def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog):
tags=[],
config={},
)
metadata = create_task_metadata(
child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_name_as_task_id_prefix=False
)
metadata = create_task_metadata(child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_task_group=True)
assert metadata.id == "run"


@pytest.mark.parametrize("use_name_as_task_id_prefix", (None, True, False))
def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix):
@pytest.mark.parametrize("use_task_group", (None, True, False))
def test_create_task_metadata_seed(caplog, use_task_group):
sample_node = DbtNode(
name="my_seed",
unique_id="my_folder.my_seed",
Expand All @@ -258,14 +247,14 @@ def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix):
tags=[],
config={},
)
if use_name_as_task_id_prefix is None:
if use_task_group is None:
metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={})
else:
metadata = create_task_metadata(
sample_node,
execution_mode=ExecutionMode.DOCKER,
args={},
use_name_as_task_id_prefix=use_name_as_task_id_prefix,
use_task_group=use_task_group,
)
assert metadata.id == "my_seed_seed"
assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator"
Expand Down
29 changes: 29 additions & 0 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,32 @@ def test_load_via_load_via_custom_parser(pipeline_name):
assert dbt_graph.nodes == dbt_graph.filtered_nodes
# the custom parser does not add dbt test nodes
assert len(dbt_graph.nodes) == 8


@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency", return_value=None)
def test_update_node_dependency_called(mock_update_node_dependency):
    """Loading a graph must call ``DbtGraph.update_node_dependency`` (patched here)."""
    graph = DbtGraph(
        project=DbtProject(
            name="jaffle_shop",
            root_dir=DBT_PROJECTS_ROOT_DIR,
            manifest_path=SAMPLE_MANIFEST,
        )
    )
    graph.load()

    assert mock_update_node_dependency.called


def test_update_node_dependency_target_exist():
    """Every node that a test node depends on must have ``has_test`` set to True after load."""
    project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST)
    graph = DbtGraph(project=project)
    graph.load()

    for candidate in graph.nodes.values():
        # Only test nodes carry the dependencies checked here.
        if candidate.resource_type != DbtResourceType.TEST:
            continue
        for target_id in candidate.depends_on:
            assert graph.nodes[target_id].has_test is True


def test_update_node_dependency_test_not_exist():
    """With the exclude filter below, no node left in ``filtered_nodes`` may have ``has_test`` set."""
    graph = DbtGraph(
        project=DbtProject(
            name="jaffle_shop",
            root_dir=DBT_PROJECTS_ROOT_DIR,
            manifest_path=SAMPLE_MANIFEST,
        ),
        exclude=["config.materialized:test"],
    )
    graph.load_from_dbt_manifest()

    for remaining_node in graph.filtered_nodes.values():
        assert remaining_node.has_test is False

0 comments on commit d70e6d7

Please sign in to comment.