From ab2f0092799b499bc49c70395ad0846ec3a10818 Mon Sep 17 00:00:00 2001 From: John Horan Date: Sat, 21 Sep 2024 11:01:06 +0100 Subject: [PATCH] Fix invalid argument (`full_refresh`) passed to DbtTestAwsEksOperator (and others) (#1175) https://github.com/astronomer/astronomer-cosmos/pull/590/ added a fix to consume the kwargs `full_refresh_ignore` if it wasn't consumed by a higher class as it was preventing the use of test in the DbtTaskGroup if `full_refresh_ignore` was set. The previous patch fixed this by consuming the variable for the `DbtLocalBaseOperator`, leaving a bug in kubernetes and docker operator. Since `AbstractDbtBaseOperator` has been added as a base of `DbtDockerBaseOperator`, `DbtKubernetesBaseOperator` and `DbtLocalBaseOperator`, moving the code there will fix all three. Fixes https://github.com/astronomer/astronomer-cosmos/issues/1062 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tatiana Al-Chueyr --- cosmos/operators/base.py | 1 + cosmos/operators/local.py | 1 - tests/operators/test_kubernetes.py | 65 ++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 9a723383f..d82083a23 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -139,6 +139,7 @@ def __init__( self.dbt_cmd_global_flags = dbt_cmd_global_flags or [] self.cache_dir = cache_dir self.extra_context = extra_context or {} + kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]: diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 49bf45293..2ba2b18ff 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -145,7 +145,6 @@ def __init__( self._dbt_runner: dbtRunner | None = None if self.invocation_mode: self._set_invocation_methods() - kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) # For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment diff --git a/tests/operators/test_kubernetes.py b/tests/operators/test_kubernetes.py index d0be2acad..8e0dda9c0 100644 --- a/tests/operators/test_kubernetes.py +++ b/tests/operators/test_kubernetes.py @@ -285,3 +285,68 @@ def test_created_pod(): ] assert container.args == expected_container_args assert container.command == [] + + +@pytest.mark.parametrize( + "operator_class,kwargs,expected_cmd", + [ + ( + DbtSeedKubernetesOperator, + {"full_refresh": True}, + ["dbt", "seed", "--full-refresh", "--project-dir", "my/dir"], + ), + ( + DbtBuildKubernetesOperator, + {"full_refresh": True}, + ["dbt", "build", "--full-refresh", "--project-dir", "my/dir"], + ), + ( + DbtRunKubernetesOperator, + {"full_refresh": True}, + ["dbt", "run", "--full-refresh", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {}, + ["dbt", "test", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {"select": []}, + ["dbt", "test", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]}, + ["dbt", "test", "--select", "tag:daily", "--exclude", "tag:disabled", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {"full_refresh": True, "selector": "nightly_snowplow"}, + ["dbt", "test", "--selector", "nightly_snowplow", "--project-dir", "my/dir"], + ), + ], +) +def test_operator_execute_with_flags(operator_class, kwargs, expected_cmd): + task = operator_class( + task_id="my-task", + project_dir="my/dir", + **kwargs, + ) + + with patch( + "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.hook", + is_in_cluster=False, + ), patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup"), patch( + "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.get_or_create_pod", + side_effect=ValueError("Mock"), + ) as get_or_create_pod: + try: + task.execute(context={}) + except ValueError as e: + if e != get_or_create_pod.side_effect: + raise + + pod_args = get_or_create_pod.call_args.kwargs["pod_request_obj"].to_dict()["spec"]["containers"][0]["args"] + + assert expected_cmd == pod_args