diff --git a/cosmos/operators/azure_container_instance.py b/cosmos/operators/azure_container_instance.py index 397a47551..d8427b2fb 100644 --- a/cosmos/operators/azure_container_instance.py +++ b/cosmos/operators/azure_container_instance.py @@ -28,7 +28,7 @@ ) -class DbtAzureContainerInstanceBaseOperator(AzureContainerInstancesOperator, AbstractDbtBaseOperator): # type: ignore +class DbtAzureContainerInstanceBaseOperator(AbstractDbtBaseOperator, AzureContainerInstancesOperator): # type: ignore """ Executes a dbt core cli command in an Azure Container Instance """ @@ -66,7 +66,7 @@ def __init__( def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") - result = super().execute(context) + result = AzureContainerInstancesOperator.execute(self, context) logger.info(result) def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> None: @@ -79,13 +79,13 @@ def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> self.command: list[str] = dbt_cmd -class DbtLSAzureContainerInstanceOperator(DbtLSMixin, DbtAzureContainerInstanceBaseOperator): +class DbtLSAzureContainerInstanceOperator(DbtLSMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core ls command. """ -class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInstanceBaseOperator): +class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core seed command. @@ -95,14 +95,14 @@ class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInsta template_fields: Sequence[str] = DbtAzureContainerInstanceBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator] -class DbtSnapshotAzureContainerInstanceOperator(DbtSnapshotMixin, DbtAzureContainerInstanceBaseOperator): +class DbtSnapshotAzureContainerInstanceOperator(DbtSnapshotMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core snapshot command. """ -class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanceBaseOperator): +class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core run command. """ @@ -110,7 +110,7 @@ class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanc template_fields: Sequence[str] = DbtAzureContainerInstanceBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator] -class DbtTestAzureContainerInstanceOperator(DbtTestMixin, DbtAzureContainerInstanceBaseOperator): +class DbtTestAzureContainerInstanceOperator(DbtTestMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core test command. """ @@ -121,7 +121,7 @@ def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwar self.on_warning_callback = on_warning_callback -class DbtRunOperationAzureContainerInstanceOperator(DbtRunOperationMixin, DbtAzureContainerInstanceBaseOperator): +class DbtRunOperationAzureContainerInstanceOperator(DbtRunOperationMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core run-operation command. diff --git a/tests/operators/test_azure_container_instance.py b/tests/operators/test_azure_container_instance.py index a99525ac3..84d733ce3 100644 --- a/tests/operators/test_azure_container_instance.py +++ b/tests/operators/test_azure_container_instance.py @@ -145,3 +145,25 @@ def test_dbt_azure_container_instance_build_command(): "start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n", "--no-version-check", ] + + +@patch("cosmos.operators.azure_container_instance.AzureContainerInstancesOperator.execute") +def test_dbt_azure_container_instance_build_and_run_cmd(mock_execute): + dbt_base_operator = ConcreteDbtAzureContainerInstanceOperator( + ci_conn_id="my_airflow_connection", + task_id="my-task", + image="my_image", + region="Mordor", + name="my-aci", + resource_group="my-rg", + project_dir="my/dir", + environment_variables={"FOO": "BAR"}, + ) + mock_build_command = MagicMock() + dbt_base_operator.build_command = mock_build_command + + mock_context = MagicMock() + dbt_base_operator.build_and_run_cmd(context=mock_context) + + mock_build_command.assert_called_with(mock_context, None) + mock_execute.assert_called_once_with(dbt_base_operator, mock_context)