diff --git a/airflow/providers/google/cloud/links/automl.py b/airflow/providers/google/cloud/links/automl.py
index 79561d5b481320..3e62a1db30b73a 100644
--- a/airflow/providers/google/cloud/links/automl.py
+++ b/airflow/providers/google/cloud/links/automl.py
@@ -21,6 +21,9 @@
 
 from typing import TYPE_CHECKING
 
+from deprecated import deprecated
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
 if TYPE_CHECKING:
@@ -44,6 +47,13 @@
 )
 
 
+@deprecated(
+    reason=(
+        "Class `AutoMLDatasetLink` has been deprecated and will be removed after 31.12.2024. "
+        "Please use `TranslationLegacyDatasetLink` from `airflow/providers/google/cloud/links/translate.py` instead."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class AutoMLDatasetLink(BaseGoogleLink):
     """Helper class for constructing AutoML Dataset link."""
 
@@ -65,6 +75,13 @@ def persist(
         )
 
 
+@deprecated(
+    reason=(
+        "Class `AutoMLDatasetListLink` has been deprecated and will be removed after 31.12.2024. "
+        "Please use `TranslationDatasetListLink` from `airflow/providers/google/cloud/links/translate.py` instead."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class AutoMLDatasetListLink(BaseGoogleLink):
     """Helper class for constructing AutoML Dataset List link."""
 
@@ -87,6 +104,13 @@ def persist(
         )
 
 
+@deprecated(
+    reason=(
+        "Class `AutoMLModelLink` has been deprecated and will be removed after 31.12.2024. "
+        "Please use `TranslationLegacyModelLink` from `airflow/providers/google/cloud/links/translate.py` instead."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class AutoMLModelLink(BaseGoogleLink):
     """Helper class for constructing AutoML Model link."""
 
@@ -114,6 +138,13 @@ def persist(
         )
 
 
+@deprecated(
+    reason=(
+        "Class `AutoMLModelTrainLink` has been deprecated and will be removed after 31.12.2024. "
+        "Please use `TranslationLegacyModelTrainLink` from `airflow/providers/google/cloud/links/translate.py` instead."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class AutoMLModelTrainLink(BaseGoogleLink):
     """Helper class for constructing AutoML Model Train link."""
 
@@ -138,6 +169,13 @@ def persist(
         )
 
 
+@deprecated(
+    reason=(
+        "Class `AutoMLModelPredictLink` has been deprecated and will be removed after 31.12.2024. "
+        "Please use `TranslationLegacyModelPredictLink` from `airflow/providers/google/cloud/links/translate.py` instead."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class AutoMLModelPredictLink(BaseGoogleLink):
     """Helper class for constructing AutoML Model Predict link."""
 
diff --git a/airflow/providers/google/cloud/links/translate.py b/airflow/providers/google/cloud/links/translate.py
new file mode 100644
index 00000000000000..074ce637671323
--- /dev/null
+++ b/airflow/providers/google/cloud/links/translate.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Translate links."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+TRANSLATION_BASE_LINK = BASE_LINK + "/translation"
+TRANSLATION_LEGACY_DATASET_LINK = (
+    TRANSLATION_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/sentences?project={project_id}"
+)
+TRANSLATION_DATASET_LIST_LINK = TRANSLATION_BASE_LINK + "/datasets?project={project_id}"
+TRANSLATION_LEGACY_MODEL_LINK = (
+    TRANSLATION_BASE_LINK
+    + "/locations/{location}/datasets/{dataset_id}/evaluate;modelId={model_id}?project={project_id}"
+)
+TRANSLATION_LEGACY_MODEL_TRAIN_LINK = (
+    TRANSLATION_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/train?project={project_id}"
+)
+TRANSLATION_LEGACY_MODEL_PREDICT_LINK = (
+    TRANSLATION_BASE_LINK
+    + "/locations/{location}/datasets/{dataset_id}/predict;modelId={model_id}?project={project_id}"
+)
+
+
+class TranslationLegacyDatasetLink(BaseGoogleLink):
+    """
+    Helper class for constructing Legacy Translation Dataset link.
+
+    Legacy Datasets are created and managed by AutoML API.
+    """
+
+    name = "Translation Legacy Dataset"
+    key = "translation_legacy_dataset"
+    format_str = TRANSLATION_LEGACY_DATASET_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        dataset_id: str,
+        project_id: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=TranslationLegacyDatasetLink.key,
+            value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id},
+        )
+
+
+class TranslationDatasetListLink(BaseGoogleLink):
+    """Helper class for constructing Translation Dataset List link."""
+
+    name = "Translation Dataset List"
+    key = "translation_dataset_list"
+    format_str = TRANSLATION_DATASET_LIST_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        project_id: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=TranslationDatasetListLink.key,
+            value={
+                "project_id": project_id,
+            },
+        )
+
+
+class TranslationLegacyModelLink(BaseGoogleLink):
+    """
+    Helper class for constructing Translation Legacy Model link.
+
+    Legacy Models are created and managed by AutoML API.
+    """
+
+    name = "Translation Legacy Model"
+    key = "translation_legacy_model"
+    format_str = TRANSLATION_LEGACY_MODEL_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        dataset_id: str,
+        model_id: str,
+        project_id: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=TranslationLegacyModelLink.key,
+            value={
+                "location": task_instance.location,
+                "dataset_id": dataset_id,
+                "model_id": model_id,
+                "project_id": project_id,
+            },
+        )
+
+
+class TranslationLegacyModelTrainLink(BaseGoogleLink):
+    """
+    Helper class for constructing Translation Legacy Model Train link.
+
+    Legacy Models are created and managed by AutoML API.
+    """
+
+    name = "Translation Legacy Model Train"
+    key = "translation_legacy_model_train"
+    format_str = TRANSLATION_LEGACY_MODEL_TRAIN_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        project_id: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=TranslationLegacyModelTrainLink.key,
+            value={
+                "location": task_instance.location,
+                "dataset_id": task_instance.model["dataset_id"],
+                "project_id": project_id,
+            },
+        )
+
+
+class TranslationLegacyModelPredictLink(BaseGoogleLink):
+    """
+    Helper class for constructing Translation Legacy Model Predict link.
+
+    Legacy Models are created and managed by AutoML API.
+    """
+
+    name = "Translation Legacy Model Predict"
+    key = "translation_legacy_model_predict"
+    format_str = TRANSLATION_LEGACY_MODEL_PREDICT_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        model_id: str,
+        project_id: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=TranslationLegacyModelPredictLink.key,
+            value={
+                "location": task_instance.location,
+                "dataset_id": task_instance.model.dataset_id,
+                "model_id": model_id,
+                "project_id": project_id,
+            },
+        )
diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py
index d5dbb3b9209967..64c1c381519c46 100644
--- a/airflow/providers/google/cloud/operators/automl.py
+++ b/airflow/providers/google/cloud/operators/automl.py
@@ -38,12 +38,12 @@
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook
-from airflow.providers.google.cloud.links.automl import (
-    AutoMLDatasetLink,
-    AutoMLDatasetListLink,
-    AutoMLModelLink,
-    AutoMLModelPredictLink,
-    AutoMLModelTrainLink,
+from airflow.providers.google.cloud.links.translate import (
+    TranslationDatasetListLink,
+    TranslationLegacyDatasetLink,
+    TranslationLegacyModelLink,
+    TranslationLegacyModelPredictLink,
+    TranslationLegacyModelTrainLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
@@ -119,8 +119,8 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
         "impersonation_chain",
     )
     operator_extra_links = (
-        AutoMLModelTrainLink(),
-        AutoMLModelLink(),
+        TranslationLegacyModelTrainLink(),
+        TranslationLegacyModelLink(),
     )
 
     def __init__(
@@ -173,7 +173,9 @@ def execute(self, context: Context):
         )
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id)
+            TranslationLegacyModelTrainLink.persist(
+                context=context, task_instance=self, project_id=project_id
+            )
         operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
         result = Model.to_dict(operation_result)
         model_id = hook.extract_object_id(result)
@@ -181,7 +183,7 @@ def execute(self, context: Context):
         self.xcom_push(context, key="model_id", value=model_id)
 
         if project_id:
-            AutoMLModelLink.persist(
+            TranslationLegacyModelLink.persist(
                 context=context,
                 task_instance=self,
                 dataset_id=self.model["dataset_id"] or "-",
@@ -195,6 +197,9 @@ class AutoMLPredictOperator(GoogleCloudBaseOperator):
     """
     Runs prediction operation on Google Cloud AutoML.
 
+    AutoMLPredictOperator for text, image, and video prediction has been deprecated.
+    Please use endpoint_id param instead of model_id param.
+
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:AutoMLPredictOperator`
@@ -228,7 +233,7 @@ class AutoMLPredictOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLModelPredictLink(),)
+    operator_extra_links = (TranslationLegacyModelPredictLink(),)
 
     def __init__(
         self,
@@ -325,7 +330,7 @@ def execute(self, context: Context):
 
         project_id = self.project_id or hook.project_id
         if project_id and self.model_id:
-            AutoMLModelPredictLink.persist(
+            TranslationLegacyModelPredictLink.persist(
                 context=context,
                 task_instance=self,
                 model_id=self.model_id,
@@ -389,7 +394,7 @@ class AutoMLBatchPredictOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLModelPredictLink(),)
+    operator_extra_links = (TranslationLegacyModelPredictLink(),)
 
     def __init__(
         self,
@@ -426,7 +431,7 @@ def execute(self, context: Context):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-        model: Model = hook.get_model(
+        self.model: Model = hook.get_model(
             model_id=self.model_id,
             location=self.location,
             project_id=self.project_id,
@@ -435,7 +440,7 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
 
-        if not hasattr(model, "translation_model_metadata"):
+        if not hasattr(self.model, "translation_model_metadata"):
             _raise_exception_for_deprecated_operator(
                 self.__class__.__name__,
                 [
@@ -462,7 +467,7 @@ def execute(self, context: Context):
         self.log.info("Batch prediction is ready.")
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLModelPredictLink.persist(
+            TranslationLegacyModelPredictLink.persist(
                 context=context,
                 task_instance=self,
                 model_id=self.model_id,
@@ -511,7 +516,7 @@ class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLDatasetLink(),)
+    operator_extra_links = (TranslationLegacyDatasetLink(),)
 
     def __init__(
         self,
@@ -560,7 +565,7 @@ def execute(self, context: Context):
         self.xcom_push(context, key="dataset_id", value=dataset_id)
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLDatasetLink.persist(
+            TranslationLegacyDatasetLink.persist(
                 context=context,
                 task_instance=self,
                 dataset_id=dataset_id,
@@ -611,7 +616,7 @@ class AutoMLImportDataOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLDatasetLink(),)
+    operator_extra_links = (TranslationLegacyDatasetLink(),)
 
     def __init__(
         self,
@@ -668,7 +673,7 @@ def execute(self, context: Context):
         self.log.info("Import is completed")
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLDatasetLink.persist(
+            TranslationLegacyDatasetLink.persist(
                 context=context,
                 task_instance=self,
                 dataset_id=self.dataset_id,
@@ -722,7 +727,7 @@ class AutoMLTablesListColumnSpecsOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLDatasetLink(),)
+    operator_extra_links = (TranslationLegacyDatasetLink(),)
 
     def __init__(
         self,
@@ -777,7 +782,7 @@ def execute(self, context: Context):
         self.log.info("Columns specs obtained.")
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLDatasetLink.persist(
+            TranslationLegacyDatasetLink.persist(
                 context=context,
                 task_instance=self,
                 dataset_id=self.dataset_id,
@@ -834,7 +839,7 @@ class AutoMLTablesUpdateDatasetOperator(GoogleCloudBaseOperator):
         "location",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLDatasetLink(),)
+    operator_extra_links = (TranslationLegacyDatasetLink(),)
 
     def __init__(
         self,
@@ -876,7 +881,7 @@ def execute(self, context: Context):
         self.log.info("Dataset updated.")
         project_id = hook.project_id
         if project_id:
-            AutoMLDatasetLink.persist(
+            TranslationLegacyDatasetLink.persist(
                 context=context,
                 task_instance=self,
                 dataset_id=hook.extract_object_id(self.dataset),
@@ -924,7 +929,7 @@ class AutoMLGetModelOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLModelLink(),)
+    operator_extra_links = (TranslationLegacyModelLink(),)
 
     def __init__(
         self,
@@ -968,7 +973,7 @@ def execute(self, context: Context):
         model = Model.to_dict(result)
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLModelLink.persist(
+            TranslationLegacyModelLink.persist(
                 context=context,
                 task_instance=self,
                 dataset_id=model["dataset_id"],
@@ -1223,7 +1228,7 @@ class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLDatasetLink(),)
+    operator_extra_links = (TranslationLegacyDatasetLink(),)
 
     def __init__(
         self,
@@ -1273,7 +1278,7 @@ def execute(self, context: Context):
         self.log.info("Table specs obtained.")
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLDatasetLink.persist(
+            TranslationLegacyDatasetLink.persist(
                 context=context,
                 task_instance=self,
                 dataset_id=self.dataset_id,
@@ -1318,7 +1323,7 @@ class AutoMLListDatasetOperator(GoogleCloudBaseOperator):
         "project_id",
         "impersonation_chain",
     )
-    operator_extra_links = (AutoMLDatasetListLink(),)
+    operator_extra_links = (TranslationDatasetListLink(),)
 
     def __init__(
         self,
@@ -1373,7 +1378,7 @@ def execute(self, context: Context):
         )
         project_id = self.project_id or hook.project_id
         if project_id:
-            AutoMLDatasetListLink.persist(context=context, task_instance=self, project_id=project_id)
+            TranslationDatasetListLink.persist(context=context, task_instance=self, project_id=project_id)
 
         return result
 
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index d5a656a29b401e..549f59852036a6 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -1274,6 +1274,12 @@ extra-links:
   - airflow.providers.google.common.links.storage.StorageLink
   - airflow.providers.google.common.links.storage.FileDetailsLink
   - airflow.providers.google.marketing_platform.links.analytics_admin.GoogleAnalyticsPropertyLink
+  - airflow.providers.google.cloud.links.translate.TranslationLegacyDatasetLink
+  - airflow.providers.google.cloud.links.translate.TranslationDatasetListLink
+  - airflow.providers.google.cloud.links.translate.TranslationLegacyModelLink
+  - airflow.providers.google.cloud.links.translate.TranslationLegacyModelTrainLink
+  - airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink
+
 
 secrets-backends:
   - airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend
diff --git a/tests/providers/google/cloud/links/__init__.py b/tests/providers/google/cloud/links/__init__.py
new file mode 100644
index 00000000000000..13a83393a9124b
--- /dev/null
+++ b/tests/providers/google/cloud/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/google/cloud/links/test_translate.py b/tests/providers/google/cloud/links/test_translate.py
new file mode 100644
index 00000000000000..b6ecd16ce4c8d8
--- /dev/null
+++ b/tests/providers/google/cloud/links/test_translate.py
@@ -0,0 +1,150 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from google.cloud.automl_v1beta1 import Model
+
+from airflow.providers.google.cloud.links.translate import (
+    TRANSLATION_BASE_LINK,
+    TranslationDatasetListLink,
+    TranslationLegacyDatasetLink,
+    TranslationLegacyModelLink,
+    TranslationLegacyModelPredictLink,
+    TranslationLegacyModelTrainLink,
+)
+from airflow.providers.google.cloud.operators.automl import (
+    AutoMLBatchPredictOperator,
+    AutoMLCreateDatasetOperator,
+    AutoMLListDatasetOperator,
+    AutoMLTrainModelOperator,
+)
+
+GCP_LOCATION = "test-location"
+GCP_PROJECT_ID = "test-project"
+DATASET = "test-dataset"
+MODEL = "test-model"
+
+
+class TestTranslationLegacyDatasetLink:
+    @pytest.mark.db_test
+    def test_get_link(self, create_task_instance_of_operator):
+        expected_url = f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/sentences?project={GCP_PROJECT_ID}"
+        link = TranslationLegacyDatasetLink()
+        ti = create_task_instance_of_operator(
+            AutoMLCreateDatasetOperator,
+            dag_id="test_legacy_dataset_link_dag",
+            task_id="test_legacy_dataset_link_task",
+            dataset=DATASET,
+            location=GCP_LOCATION,
+        )
+        link.persist(context={"ti": ti}, task_instance=ti.task, dataset_id=DATASET, project_id=GCP_PROJECT_ID)
+        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
+        assert actual_url == expected_url
+
+
+class TestTranslationDatasetListLink:
+    @pytest.mark.db_test
+    def test_get_link(self, create_task_instance_of_operator):
+        expected_url = f"{TRANSLATION_BASE_LINK}/datasets?project={GCP_PROJECT_ID}"
+        link = TranslationDatasetListLink()
+        ti = create_task_instance_of_operator(
+            AutoMLListDatasetOperator,
+            dag_id="test_dataset_list_link_dag",
+            task_id="test_dataset_list_link_task",
+            location=GCP_LOCATION,
+        )
+        link.persist(context={"ti": ti}, task_instance=ti.task, project_id=GCP_PROJECT_ID)
+        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
+        assert actual_url == expected_url
+
+
+class TestTranslationLegacyModelLink:
+    @pytest.mark.db_test
+    def test_get_link(self, create_task_instance_of_operator):
+        expected_url = (
+            f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
+            f"evaluate;modelId={MODEL}?project={GCP_PROJECT_ID}"
+        )
+        link = TranslationLegacyModelLink()
+        ti = create_task_instance_of_operator(
+            AutoMLTrainModelOperator,
+            dag_id="test_legacy_model_link_dag",
+            task_id="test_legacy_model_link_task",
+            model=MODEL,
+            project_id=GCP_PROJECT_ID,
+            location=GCP_LOCATION,
+        )
+        link.persist(
+            context={"ti": ti},
+            task_instance=ti.task,
+            dataset_id=DATASET,
+            model_id=MODEL,
+            project_id=GCP_PROJECT_ID,
+        )
+        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
+        assert actual_url == expected_url
+
+
+class TestTranslationLegacyModelTrainLink:
+    @pytest.mark.db_test
+    def test_get_link(self, create_task_instance_of_operator):
+        expected_url = (
+            f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
+            f"train?project={GCP_PROJECT_ID}"
+        )
+        link = TranslationLegacyModelTrainLink()
+        ti = create_task_instance_of_operator(
+            AutoMLTrainModelOperator,
+            dag_id="test_legacy_model_train_link_dag",
+            task_id="test_legacy_model_train_link_task",
+            model={"dataset_id": DATASET},
+            project_id=GCP_PROJECT_ID,
+            location=GCP_LOCATION,
+        )
+        link.persist(
+            context={"ti": ti},
+            task_instance=ti.task,
+            project_id=GCP_PROJECT_ID,
+        )
+        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
+        assert actual_url == expected_url
+
+
+class TestTranslationLegacyModelPredictLink:
+    @pytest.mark.db_test
+    def test_get_link(self, create_task_instance_of_operator):
+        expected_url = (
+            f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
+            f"predict;modelId={MODEL}?project={GCP_PROJECT_ID}"
+        )
+        link = TranslationLegacyModelPredictLink()
+        ti = create_task_instance_of_operator(
+            AutoMLBatchPredictOperator,
+            dag_id="test_legacy_model_predict_link_dag",
+            task_id="test_legacy_model_predict_link_task",
+            model_id=MODEL,
+            project_id=GCP_PROJECT_ID,
+            location=GCP_LOCATION,
+            input_config="input_config",
+            output_config="input_config",
+        )
+        ti.task.model = Model(dataset_id=DATASET, display_name=MODEL)
+        link.persist(context={"ti": ti}, task_instance=ti.task, model_id=MODEL, project_id=GCP_PROJECT_ID)
+        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
+        assert actual_url == expected_url
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index cecda2bf23a21f..f7bef5452193fb 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -139,12 +139,13 @@ def test_templating(self, create_task_instance_of_operator):
 
 
 class TestAutoMLBatchPredictOperator:
+    @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist")
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
+    def test_execute(self, mock_hook, mock_link_persist):
         mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
         mock_hook.return_value.extract_object_id = extract_object_id
         mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult()
-
+        mock_context = {"ti": mock.MagicMock()}
         op = AutoMLBatchPredictOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,
@@ -154,7 +155,7 @@ def test_execute(self, mock_hook):
             task_id=TASK_ID,
             prediction_params={},
         )
-        op.execute(context=mock.MagicMock())
+        op.execute(context=mock_context)
         mock_hook.return_value.batch_predict.assert_called_once_with(
             input_config=INPUT_CONFIG,
             location=GCP_LOCATION,
@@ -166,6 +167,12 @@ def test_execute(self, mock_hook):
             retry=DEFAULT,
             timeout=None,
         )
+        mock_link_persist.assert_called_once_with(
+            context=mock_context,
+            task_instance=op,
+            model_id=MODEL_ID,
+            project_id=GCP_PROJECT_ID,
+        )
 
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute_deprecated(self, mock_hook):
@@ -226,10 +233,11 @@ def test_templating(self, create_task_instance_of_operator):
 
 
 class TestAutoMLPredictOperator:
+    @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist")
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
+    def test_execute(self, mock_hook, mock_link_persist):
         mock_hook.return_value.predict.return_value = PredictResponse()
-
+        mock_context = {"ti": mock.MagicMock()}
         op = AutoMLPredictOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,
@@ -238,7 +246,7 @@ def test_execute(self, mock_hook):
             task_id=TASK_ID,
             operation_params={"TEST_KEY": "TEST_VALUE"},
         )
-        op.execute(context=mock.MagicMock())
+        op.execute(context=mock_context)
         mock_hook.return_value.predict.assert_called_once_with(
             location=GCP_LOCATION,
             metadata=(),
@@ -249,6 +257,12 @@ def test_execute(self, mock_hook):
             retry=DEFAULT,
             timeout=None,
         )
+        mock_link_persist.assert_called_once_with(
+            context=mock_context,
+            task_instance=op,
+            model_id=MODEL_ID,
+            project_id=GCP_PROJECT_ID,
+        )
 
     @pytest.mark.db_test
     def test_templating(self, create_task_instance_of_operator):