Skip to content

Commit

Permalink
Add dag arg_key as specific_args_keys (astronomer#916)
Browse files Browse the repository at this point in the history
This code fixes the problem that dag parameter cannot be passed as
argument at `DbtTaskGroup`.

Previously,  this would work:
```
with DAG(...):
    DbtTaskGroup(...) 
```

But this would fail:
```
dag = DAG(...)
DbtTaskGroup(dag=dag...) 
```

Both work now.

This change has been made to not affecting
`DbtToAirflowConverter` class at all.

Closes: astronomer#915
  • Loading branch information
tboutaour authored and arojasb3 committed Jul 14, 2024
1 parent 581fea7 commit c654123
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def airflow_kwargs(**kwargs: dict[str, Any]) -> dict[str, Any]:
new_kwargs = {}
non_airflow_kwargs = specific_kwargs(**kwargs)
for arg_key, arg_value in kwargs.items():
if arg_key not in non_airflow_kwargs:
if arg_key not in non_airflow_kwargs or arg_key == "dag":
new_kwargs[arg_key] = arg_value
return new_kwargs

Expand Down
26 changes: 26 additions & 0 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
TestBehavior,
TestIndirectSelection,
)
from cosmos.converter import airflow_kwargs
from cosmos.dbt.graph import DbtNode
from cosmos.profiles import PostgresUserPasswordProfileMapping

Expand Down Expand Up @@ -431,3 +432,28 @@ def test_create_test_task_metadata(node_type, node_unique_id, test_indirect_sele
)
def test_snake_case_to_camelcase(input, expected):
assert _snake_case_to_camelcase(input) == expected


def test_airflow_kwargs_generation():
"""
airflow_kwargs_generation should always contain dag.
"""
task_args = {
"group_id": "fake_group_id",
"project_dir": SAMPLE_PROJ_PATH,
"conn_id": "fake_conn",
"render_config": RenderConfig(select=["fake-render"]),
"default_args": {"retries": 2},
"profile_config": ProfileConfig(
profile_name="default",
target_name="default",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="fake_conn",
profile_args={"schema": "public"},
),
),
"dag": DAG(dag_id="fake_dag_name"),
}
result = airflow_kwargs(**task_args)

assert "dag" in result

0 comments on commit c654123

Please sign in to comment.