From 5c4f6e6f83a1ddd1982951ee2b85c3c3f215cf88 Mon Sep 17 00:00:00 2001 From: Shahar Epstein Date: Mon, 11 Mar 2024 21:24:00 +0200 Subject: [PATCH] Add AutoML templating tests --- .../google/cloud/operators/test_automl.py | 286 ++++++++++++++++++ 1 file changed, 286 insertions(+) diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index 2c9872f4450bb..4f00f76a2dbef 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -20,6 +20,7 @@ import copy from unittest import mock +import pytest from google.api_core.gapic_v1.method import DEFAULT from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse @@ -39,6 +40,7 @@ AutoMLTablesUpdateDatasetOperator, AutoMLTrainModelOperator, ) +from airflow.utils import timezone CREDENTIALS = "test-creds" TASK_ID = "test-automl-hook" @@ -88,6 +90,25 @@ def test_execute(self, mock_hook): metadata=(), ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLTrainModelOperator, + # Templated fields + model="{{ 'model' }}", + location="{{ 'location' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLTrainModelOperator = ti.task + assert task.model == "model" + assert task.location == "location" + assert task.impersonation_chain == "impersonation_chain" + class TestAutoMLBatchPredictOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -118,6 +139,31 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLBatchPredictOperator, + # Templated fields + model_id="{{ 'model' }}", + input_config="{{ 'input-config' }}", + output_config="{{ 'output-config' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLBatchPredictOperator = ti.task + assert task.model_id == "model" + assert task.input_config == "input-config" + assert task.output_config == "output-config" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLPredictOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -144,6 +190,28 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLPredictOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + payload={}, + ) + ti.render_templates() + task: AutoMLPredictOperator = ti.task + assert task.model_id == "model-id" + assert task.project_id == "project-id" + assert task.location == "location" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLCreateImportOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -167,6 +235,27 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLCreateDatasetOperator, + # Templated fields + dataset="{{ 'dataset' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLCreateDatasetOperator = ti.task + assert task.dataset == "dataset" + assert task.project_id == "project-id" + assert task.location == "location" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLListColumnsSpecsOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -199,6 +288,33 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLTablesListColumnSpecsOperator, + # Templated fields + dataset_id="{{ 'dataset-id' }}", + table_spec_id="{{ 'table-spec-id' }}", + field_mask="{{ 'field-mask' }}", + filter_="{{ 'filter-' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLTablesListColumnSpecsOperator = ti.task + assert task.dataset_id == "dataset-id" + assert task.table_spec_id == "table-spec-id" + assert task.field_mask == "field-mask" + assert task.filter_ == "filter-" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLUpdateDatasetOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -223,6 +339,27 @@ def test_execute(self, mock_hook): update_mask=MASK, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLTablesUpdateDatasetOperator, + # Templated fields + dataset="{{ 'dataset' }}", + update_mask="{{ 'update-mask' }}", + location="{{ 'location' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLTablesUpdateDatasetOperator = ti.task + assert task.dataset == "dataset" + assert task.update_mask == "update-mask" + assert task.location == "location" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLGetModelOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -246,6 +383,27 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLGetModelOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLGetModelOperator = ti.task + assert task.model_id == "model-id" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLDeleteModelOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -266,6 +424,27 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLDeleteModelOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLDeleteModelOperator = ti.task + assert task.model_id == "model-id" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLDeployModelOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -289,6 +468,27 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLDeployModelOperator, + # Templated fields + model_id="{{ 'model-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLDeployModelOperator = ti.task + assert task.model_id == "model-id" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLDatasetImportOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -311,6 +511,29 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLImportDataOperator, + # Templated fields + dataset_id="{{ 'dataset-id' }}", + input_config="{{ 'input-config' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLImportDataOperator = ti.task + assert task.dataset_id == "dataset-id" + assert task.input_config == "input-config" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLTablesListTableSpecsOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -338,6 +561,29 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLTablesListTableSpecsOperator, + # Templated fields + dataset_id="{{ 'dataset-id' }}", + filter_="{{ 'filter-' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLTablesListTableSpecsOperator = ti.task + assert task.dataset_id == "dataset-id" + assert task.filter_ == "filter-" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLDatasetListOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -352,6 +598,25 @@ def test_execute(self, mock_hook): timeout=None, ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLListDatasetOperator, + # Templated fields + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLListDatasetOperator = ti.task + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain" + class TestAutoMLDatasetDeleteOperator: @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") @@ -371,3 +636,24 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AutoMLDeleteDatasetOperator, + # Templated fields + dataset_id="{{ 'dataset-id' }}", + location="{{ 'location' }}", + project_id="{{ 'project-id' }}", + impersonation_chain="{{ 'impersonation-chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: AutoMLDeleteDatasetOperator = ti.task + assert task.dataset_id == "dataset-id" + assert task.location == "location" + assert task.project_id == "project-id" + assert task.impersonation_chain == "impersonation-chain"