feat: option to group in-memory nodes
Signed-off-by: Simon Brugman <[email protected]>
sbrugman committed Dec 13, 2023
1 parent f8f4a7d commit f0a6e4e
Showing 7 changed files with 266 additions and 14 deletions.
13 changes: 13 additions & 0 deletions kedro-airflow/README.md
@@ -154,6 +154,19 @@ See ["What if I want to use a different Jinja2 template?"](#what-if-i-want-to-use-a-different-jinja2-template)
The [rich offering](https://airflow.apache.org/docs/apache-airflow-providers/operators-and-hooks-ref/index.html) of operators means that the `kedro-airflow` plugin is providing templates for specific operators.
The default template provided by `kedro-airflow` uses the `BaseOperator`.

### Can I group nodes together?

When running Kedro nodes with Airflow, MemoryDataSets are generally not shared across operators, because each operator typically runs in its own process.
This causes the DAG run to fail.

MemoryDataSets can be used to provide logical separation between nodes in Kedro, without the overhead of writing to disk (and, in the case of distributed execution, without needing multiple executors).

The `--group-in-memory` flag groups nodes that are connected through MemoryDataSets into a single Airflow task.
This preserves the option of logical separation in Kedro, with little computational overhead.
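
For example, a DAG with grouped tasks for the default pipeline can be generated with:

```bash
kedro airflow create --group-in-memory
```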

It is possible to use [task groups](https://docs.astronomer.io/learn/task-groups) by changing the template.
See ["What if I want to use a different Jinja2 template?"](#what-if-i-want-to-use-a-different-jinja2-template) for instructions on using custom templates.

## Can I contribute?

Yes! Want to help build Kedro-Airflow? Check out our guide to [contributing](https://github.com/kedro-org/kedro-plugins/blob/main/kedro-airflow/CONTRIBUTING.md).
1 change: 1 addition & 0 deletions kedro-airflow/RELEASE.md
@@ -1,5 +1,6 @@
# Upcoming Release
## Community contributions
* Added the `--group-in-memory` option to group nodes connected by MemoryDataSets into the same Airflow task (breaking change for custom templates passed via `--jinja-file`).

# Release 0.7.0
* Added support for Python 3.11
20 changes: 10 additions & 10 deletions kedro-airflow/kedro_airflow/airflow_dag_template.j2
@@ -1,4 +1,5 @@
from __future__ import annotations

from datetime import datetime, timedelta
from pathlib import Path

@@ -16,7 +17,7 @@ class KedroOperator(BaseOperator):
self,
package_name: str,
pipeline_name: str,
node_name: str,
node_name: str | list[str],
project_path: str | Path,
env: str,
*args, **kwargs
@@ -30,11 +31,10 @@ class KedroOperator(BaseOperator):

def execute(self, context):
configure_project(self.package_name)
with KedroSession.create(self.package_name,
self.project_path,
env=self.env) as session:
session.run(self.pipeline_name, node_names=[self.node_name])

with KedroSession.create(self.package_name, self.project_path, env=self.env) as session:
if isinstance(self.node_name, str):
self.node_name = [self.node_name]
session.run(self.pipeline_name, node_names=self.node_name)

# Kedro settings required to run your pipeline
env = "{{ env }}"
@@ -61,17 +61,17 @@ with DAG(
)
) as dag:
tasks = {
{% for node in pipeline.nodes %} "{{ node.name | safe | slugify }}": KedroOperator(
task_id="{{ node.name | safe | slugify }}",
{% for node_name, node_list in nodes.items() %} "{{ node_name | safe | slugify }}": KedroOperator(
task_id="{{ node_name | safe | slugify }}",
package_name=package_name,
pipeline_name=pipeline_name,
node_name="{{ node.name | safe }}",
node_name={% if node_list | length > 1 %}[{% endif %}{% for node in node_list %}"{{ node.name | safe }}"{% if not loop.last %}, {% endif %}{% endfor %}{% if node_list | length > 1 %}]{% endif %},
project_path=project_path,
env=env,
),
{% endfor %} }

{% for parent_node, child_nodes in dependencies.items() -%}
{% for child in child_nodes %} tasks["{{ parent_node.name | safe | slugify }}"] >> tasks["{{ child.name | safe | slugify }}"]
{% for child in child_nodes %} tasks["{{ parent_node | safe | slugify }}"] >> tasks["{{ child | safe | slugify }}"]
{% endfor %}
{%- endfor %}
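
For illustration, with hypothetical node names, a grouped sequence of nodes `f2`, `f3` and `f4` would render roughly as the following task entry (the exact task id depends on `slugify`):

```python
tasks = {
    "f2-f3-f4": KedroOperator(
        task_id="f2-f3-f4",
        package_name=package_name,
        pipeline_name=pipeline_name,
        node_name=["f2", "f3", "f4"],
        project_path=project_path,
        env=env,
    ),
}
```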
90 changes: 90 additions & 0 deletions kedro-airflow/kedro_airflow/grouping.py
@@ -0,0 +1,90 @@
from __future__ import annotations

from collections import defaultdict

from kedro.io import DataCatalog, MemoryDataSet
from kedro.pipeline.node import Node
from kedro.pipeline.pipeline import Pipeline


def _is_memory_dataset(catalog, dataset_name: str) -> bool:
if dataset_name == "parameters" or dataset_name.startswith("params:"):
return False

dataset = catalog._data_sets.get(dataset_name, None)
return dataset is not None and isinstance(dataset, MemoryDataSet)


def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:
"""Gather all datasets in the pipeline that are of type MemoryDataSet, excluding 'parameters'."""
return {
dataset_name
for dataset_name in pipeline.data_sets()
if _is_memory_dataset(catalog, dataset_name)
}


def node_sequence_name(node_sequence: list[Node]) -> str:
return "_".join([node.name for node in node_sequence])


def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline):
# get all memory datasets in the pipeline
ds = get_memory_datasets(catalog, pipeline)

# Node sequences
node_sequences = []

# Mapping from dataset name -> node sequence index
sequence_map = {}
for node in pipeline.nodes:
if all(o not in ds for o in node.inputs + node.outputs):
# standalone node
node_sequences.append([node])
else:
if all(i not in ds for i in node.inputs):
# start of a sequence; create a new sequence and store the id
node_sequences.append([node])
sequence_id = len(node_sequences) - 1
else:
# continuation of a sequence; retrieve sequence_id
sequence_id = None
for i in node.inputs:
if i in ds:
assert sequence_id is None or sequence_id == sequence_map[i]
sequence_id = sequence_map[i]

# Append to map
node_sequences[sequence_id].append(node)

# map outputs to sequence_id
for o in node.outputs:
if o in ds:
sequence_map[o] = sequence_id

# Named node sequences
nodes = {
node_sequence_name(node_sequence): node_sequence
for node_sequence in node_sequences
}

# Inverted mapping
node_mapping = {
node.name: sequence_name
for sequence_name, node_sequence in nodes.items()
for node in node_sequence
}

# Grouped dependencies
dependencies = defaultdict(list)
for node, parent_nodes in pipeline.node_dependencies.items():
for parent in parent_nodes:
parent_name = node_mapping[parent.name]
node_name = node_mapping[node.name]
if parent_name != node_name and (
parent_name not in dependencies
or node_name not in dependencies[parent_name]
):
dependencies[parent_name].append(node_name)

return nodes, dependencies
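
A minimal usage sketch of this helper (dataset and node names are hypothetical; only `ds2` is registered as a `MemoryDataSet`, so the two connected nodes collapse into one grouped task):

```python
from kedro.io import DataCatalog, MemoryDataSet
from kedro.pipeline import node
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from kedro_airflow.grouping import group_memory_nodes

# Only "ds2" is registered as a MemoryDataSet; datasets missing from the
# catalog are treated as non-memory by _is_memory_dataset, so they do not
# trigger grouping.
catalog = DataCatalog(data_sets={"ds2": MemoryDataSet()})

pipe = modular_pipeline(
    [
        node(lambda x: x, inputs="ds1", outputs="ds2", name="first"),
        node(lambda x: x, inputs="ds2", outputs="ds3", name="second"),
    ]
)

nodes, dependencies = group_memory_nodes(catalog, pipe)
# nodes        == {"first_second": [<Node "first">, <Node "second">]}
# dependencies == {}  (the only edge is now internal to the grouped task)
```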
26 changes: 22 additions & 4 deletions kedro-airflow/kedro_airflow/plugin.py
@@ -17,6 +17,8 @@
from kedro.framework.startup import ProjectMetadata, bootstrap_project
from slugify import slugify

from kedro_airflow.grouping import group_memory_nodes

PIPELINE_ARG_HELP = """Name of the registered pipeline to convert.
If not set, the '__default__' pipeline is used. This argument supports
passing multiple values using `--pipeline [p1] --pipeline [p2]`.
@@ -100,6 +102,14 @@ def _get_pipeline_config(config_airflow: dict, params: dict, pipeline_name: str)
default=Path(__file__).parent / "airflow_dag_template.j2",
help="The template file for the generated Airflow dags",
)
@click.option(
"-g",
"--group-in-memory",
is_flag=True,
default=False,
help="Group nodes with at least one MemoryDataSet as input/output together, "
"as they do not persist between Airflow operators.",
)
@click.option(
"--params",
type=click.UNPROCESSED,
@@ -114,6 +124,7 @@ def create( # noqa: PLR0913
env,
target_path,
jinja_file,
group_in_memory,
params,
convert_all: bool,
):
@@ -165,13 +176,20 @@ def create( # noqa: PLR0913
else f"{package_name}_{name}_dag.py"
)

dependencies = defaultdict(list)
for node, parent_nodes in pipeline.node_dependencies.items():
for parent in parent_nodes:
dependencies[parent].append(node)
# group memory nodes
if group_in_memory:
nodes, dependencies = group_memory_nodes(context.catalog, pipeline)
else:
nodes = {node.name: [node] for node in pipeline.nodes}

dependencies = defaultdict(list)
for node, parent_nodes in pipeline.node_dependencies.items():
for parent in parent_nodes:
dependencies[parent.name].append(node.name)

template.stream(
dag_name=package_name,
nodes=nodes,
dependencies=dependencies,
env=env,
pipeline_name=name,
128 changes: 128 additions & 0 deletions kedro-airflow/tests/test_node_grouping.py
@@ -0,0 +1,128 @@
from __future__ import annotations

from typing import Any

import pytest
from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet
from kedro.pipeline import node
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from kedro_airflow.grouping import _is_memory_dataset, group_memory_nodes


class TestDataSet(AbstractDataSet):
def _save(self, data) -> None:
pass

def _describe(self) -> dict[str, Any]:
return {}

def _load(self):
return []


@pytest.mark.parametrize(
"memory_nodes,expected_nodes,expected_dependencies",
[
(
["ds3", "ds6"],
[["f1"], ["f2", "f3", "f4", "f6", "f7"], ["f5"]],
{"f1": ["f2_f3_f4_f6_f7"], "f2_f3_f4_f6_f7": ["f5"]},
),
(
["ds3"],
[["f1"], ["f2", "f3", "f4", "f7"], ["f5"], ["f6"]],
{"f1": ["f2_f3_f4_f7"], "f2_f3_f4_f7": ["f5", "f6"]},
),
(
[],
[["f1"], ["f2"], ["f3"], ["f4"], ["f5"], ["f6"], ["f7"]],
{"f1": ["f2"], "f2": ["f3", "f4", "f5", "f7"], "f4": ["f6", "f7"]},
),
],
)
def test_group_memory_nodes(
memory_nodes: list[str],
expected_nodes: list[list[str]],
expected_dependencies: dict[str, list[str]],
):
"""Check the grouping of memory nodes."""
nodes = [f"ds{i}" for i in range(1, 10)]
assert all(node_name in nodes for node_name in memory_nodes)

mock_catalog = DataCatalog()
for dataset_name in nodes:
if dataset_name in memory_nodes:
dataset = MemoryDataSet()
else:
dataset = TestDataSet()
mock_catalog.add(dataset_name, dataset)

def identity_one_to_one(x):
return x

mock_pipeline = modular_pipeline(
[
node(
func=identity_one_to_one,
inputs="ds1",
outputs="ds2",
name="f1",
),
node(
func=lambda x: (x, x),
inputs="ds2",
outputs=["ds3", "ds4"],
name="f2",
),
node(
func=identity_one_to_one,
inputs="ds3",
outputs="ds5",
name="f3",
),
node(
func=identity_one_to_one,
inputs="ds3",
outputs="ds6",
name="f4",
),
node(
func=identity_one_to_one,
inputs="ds4",
outputs="ds8",
name="f5",
),
node(
func=identity_one_to_one,
inputs="ds6",
outputs="ds7",
name="f6",
),
node(
func=lambda x, y: x,
inputs=["ds3", "ds6"],
outputs="ds9",
name="f7",
),
],
)

nodes, dependencies = group_memory_nodes(mock_catalog, mock_pipeline)
sequence = [
[node_.name for node_ in node_sequence] for node_sequence in nodes.values()
]
assert sequence == expected_nodes
assert dict(dependencies) == expected_dependencies


def test_is_memory_dataset():
catalog = DataCatalog()
catalog.add("parameters", {"hello": "world"})
catalog.add("params:hello", "world")
catalog.add("my_dataset", MemoryDataSet(True))
catalog.add("test_dataset", TestDataSet())
assert not _is_memory_dataset(catalog, "parameters")
assert not _is_memory_dataset(catalog, "params:hello")
assert _is_memory_dataset(catalog, "my_dataset")
assert not _is_memory_dataset(catalog, "test_dataset")
2 changes: 2 additions & 0 deletions kedro-airflow/tests/test_plugin.py
@@ -19,6 +19,8 @@
("hello_world", "__default__", ["airflow", "create"]),
# Test execution with alternate pipeline name
("hello_world", "ds", ["airflow", "create", "--pipeline", "ds"]),
# Test with grouping
("hello_world", "__default__", ["airflow", "create", "--group-in-memory"]),
],
)
def test_create_airflow_dag(dag_name, pipeline_name, command, cli_runner, metadata):
