diff --git a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py index 5e5e20aa9f363..2da2bf33946b8 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py @@ -20,14 +20,15 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault -from google.cloud.aiplatform_v1.types import HyperparameterTuningJob +from google.cloud.aiplatform_v1 import types from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( HyperparameterTuningJobHook, ) @@ -40,7 +41,7 @@ if TYPE_CHECKING: from google.api_core.retry import Retry - from google.cloud.aiplatform import gapic, hyperparameter_tuning + from google.cloud.aiplatform import HyperparameterTuningJob, gapic, hyperparameter_tuning from airflow.utils.context import Context @@ -127,8 +128,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator): `service_account` is required with provided `tensorboard`. For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :param sync: Whether to execute this method synchronously. If False, this method will unblock, and it - will be executed in a concurrent Future. + :param sync: (Deprecated) Whether to execute this method synchronously. If False, this method will + unblock, and it will be executed in a concurrent Future. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -138,8 +139,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). - :param deferrable: Run operator in the deferrable mode. Note that it requires calling the operator - with `sync=False` parameter. + :param deferrable: Run operator in the deferrable mode. :param poll_interval: Interval size which defines how often job status is checked in deferrable mode. """ @@ -221,19 +221,18 @@ def __init__( self.poll_interval = poll_interval def execute(self, context: Context): - if self.deferrable and self.sync: - raise AirflowException( - "Deferrable mode can be used only with sync=False option. " - "If you are willing to run the operator in deferrable mode, please, set sync=False. " - "Otherwise, disable deferrable mode `deferrable=False`." - ) + warnings.warn( + "The 'sync' parameter is deprecated and will be removed after 01.09.2024.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) self.log.info("Creating Hyperparameter Tuning job") self.hook = HyperparameterTuningJobHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - result = self.hook.create_hyperparameter_tuning_job( + hyperparameter_tuning_job: HyperparameterTuningJob = self.hook.create_hyperparameter_tuning_job( project_id=self.project_id, region=self.region, display_name=self.display_name, @@ -259,14 +258,19 @@ def execute(self, context: Context): restart_job_on_worker_restart=self.restart_job_on_worker_restart, enable_web_access=self.enable_web_access, tensorboard=self.tensorboard, - sync=self.sync, - wait_job_completed=not self.deferrable, + sync=False, + wait_job_completed=False, ) - hyperparameter_tuning_job = result.to_dict() - hyperparameter_tuning_job_id = self.hook.extract_hyperparameter_tuning_job_id( - hyperparameter_tuning_job + hyperparameter_tuning_job.wait_for_resource_creation() + hyperparameter_tuning_job_id = hyperparameter_tuning_job.name + self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id) + + self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id) + VertexAITrainingLink.persist( + context=context, task_instance=self, training_id=hyperparameter_tuning_job_id ) + if self.deferrable: self.defer( trigger=CreateHyperparameterTuningJobTrigger( @@ -279,14 +283,10 @@ def execute(self, context: Context): ), method_name="execute_complete", ) + return - self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id) - - self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id) - VertexAITrainingLink.persist( - context=context, task_instance=self, training_id=hyperparameter_tuning_job_id - ) - return hyperparameter_tuning_job + hyperparameter_tuning_job.wait_for_completion() + return hyperparameter_tuning_job.to_dict() def on_kill(self) -> None: """Act as a callback called when the operator is killed; cancel any running job.""" @@ -298,26 +298,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, raise AirflowException(event["message"]) job: dict[str, Any] = event["job"] self.log.info("Hyperparameter tuning job %s created and completed successfully.", job["name"]) - hook = HyperparameterTuningJobHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - job_id = hook.extract_hyperparameter_tuning_job_id(job) - self.xcom_push( - context, - key="hyperparameter_tuning_job_id", - value=job_id, - ) - self.xcom_push( - context, - key="training_conf", - value={ - "training_conf_id": job_id, - "region": self.region, - "project_id": self.project_id, - }, - ) - return event["job"] + return job class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator): @@ -387,7 +368,7 @@ def execute(self, context: Context): context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id ) self.log.info("Hyperparameter tuning job was gotten.") - return HyperparameterTuningJob.to_dict(result) + return types.HyperparameterTuningJob.to_dict(result) except NotFound: self.log.info( "The Hyperparameter tuning job %s does not exist.", self.hyperparameter_tuning_job_id @@ -532,4 +513,4 @@ def execute(self, context: Context): metadata=self.metadata, ) VertexAIHyperparameterTuningJobListLink.persist(context=context, task_instance=self) - return [HyperparameterTuningJob.to_dict(result) for result in results] + return [types.HyperparameterTuningJob.to_dict(result) for result in results] diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 0c52ee737996a..57864a0a0241f 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -1419,7 +1419,7 @@ def test_execute(self, mock_hook): class TestVertexAICreateHyperparameterTuningJobOperator: - @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJob.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.types.HyperparameterTuningJob.to_dict")) @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook")) def test_execute(self, mock_hook, to_dict_mock): op = CreateHyperparameterTuningJobOperator( @@ -1464,7 +1464,7 @@ def test_execute(self, mock_hook, to_dict_mock): enable_web_access=False, tensorboard=None, sync=False, - wait_job_completed=True, + wait_job_completed=False, ) @mock.patch( @@ -1511,11 +1511,8 @@ def test_deferrable_sync_error(self): with pytest.raises(AirflowException): op.execute(context={"ti": mock.MagicMock()}) - @mock.patch( - VERTEX_AI_PATH.format("hyperparameter_tuning_job.CreateHyperparameterTuningJobOperator.xcom_push") - ) @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook")) - def test_execute_complete(self, mock_hook, mock_xcom_push): + def test_execute_complete(self, mock_hook): test_job_id = "test_job_id" test_job = {"name": f"test/{test_job_id}"} event = { @@ -1544,20 +1541,6 @@ def test_execute_complete(self, mock_hook, mock_xcom_push): result = op.execute_complete(context=mock_context, event=event) - mock_xcom_push.assert_has_calls( - [ - call(mock_context, key="hyperparameter_tuning_job_id", value=test_job_id), - call( - mock_context, - key="training_conf", - value={ - "training_conf_id": test_job_id, - "region": GCP_LOCATION, - "project_id": GCP_PROJECT, - }, - ), - ] - ) assert result == test_job def test_execute_complete_error(self): @@ -1587,7 +1570,7 @@ def test_execute_complete_error(self): class TestVertexAIGetHyperparameterTuningJobOperator: - @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJob.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.types.HyperparameterTuningJob.to_dict")) @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook")) def test_execute(self, mock_hook, to_dict_mock): op = GetHyperparameterTuningJobOperator(