diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py index 1d38284f4f14a0..966735e9c21f19 100644 --- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -35,8 +35,6 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence -from packaging.version import Version - from google.cloud.storage_transfer_v1 import ( ListTransferJobsRequest, StorageTransferServiceAsyncClient, @@ -518,31 +516,11 @@ async def get_conn(self) -> StorageTransferServiceAsyncClient: :return: Google Storage Transfer asynchronous client. """ if not self._client: - try: - from airflow.providers.google import __version__ - - if Version(__version__) >= Version("10.15.0"): - credentials = (await self.get_sync_hook()).get_credentials() - self._client = StorageTransferServiceAsyncClient( - credentials=credentials, - client_info=CLIENT_INFO, - ) - else: - self._client = StorageTransferServiceAsyncClient() - warnings.warn( - "Getting credentials from the environment has been deprecated. " - "You should pass gcp_conn_id as parameter.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - except ImportError: # __version__ was added in 10.1.0, so this means it's < 10.15.0 - self._client = StorageTransferServiceAsyncClient() - warnings.warn( - "Getting credentials from the environment has been deprecated. " - "You should pass gcp_conn_id as parameter.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) + credentials = (await self.get_sync_hook()).get_credentials() + self._client = StorageTransferServiceAsyncClient( + credentials=credentials, + client_info=CLIENT_INFO, + ) return self._client async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager: @@ -567,7 +545,7 @@ async def get_latest_operation(self, job: TransferJob) -> Message | None: """ latest_operation_name = job.latest_operation_name if latest_operation_name: - client = self.get_conn() + client = await self.get_conn() response_operation = await client.transport.operations_client.get_operation(latest_operation_name) operation = TransferOperation.deserialize(response_operation.metadata.value) return operation diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py index a27d233fd19f6d..e05bacbbd2c306 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py +++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py @@ -42,16 +42,20 @@ def hook_async(): class TestCloudDataTransferServiceAsyncHook: + @pytest.mark.asyncio + @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn") @mock.patch(f"{TRANSFER_HOOK_PATH}.StorageTransferServiceAsyncClient") - def test_get_conn(self, mock_async_client): + async def test_get_conn(self, mock_async_client, mock_get_conn): expected_value = "Async Hook" mock_async_client.return_value = expected_value + mock_get_conn.return_value = expected_value hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID) - conn_0 = hook.get_conn() + + conn_0 = await hook.get_conn() assert conn_0 == expected_value - conn_1 = hook.get_conn() + conn_1 = await hook.get_conn() assert conn_1 == expected_value assert id(conn_0) == id(conn_1)