Skip to content

Commit

Permalink
fix: test, style, some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
korolkevich authored and eladkal committed Mar 25, 2024
1 parent e11a8d1 commit 4f2583a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

from packaging.version import Version

from google.cloud.storage_transfer_v1 import (
ListTransferJobsRequest,
StorageTransferServiceAsyncClient,
Expand Down Expand Up @@ -518,31 +516,11 @@ async def get_conn(self) -> StorageTransferServiceAsyncClient:
:return: Google Storage Transfer asynchronous client.
"""
if not self._client:
try:
from airflow.providers.google import __version__

if Version(__version__) >= Version("10.15.0"):
credentials = (await self.get_sync_hook()).get_credentials()
self._client = StorageTransferServiceAsyncClient(
credentials=credentials,
client_info=CLIENT_INFO,
)
else:
self._client = StorageTransferServiceAsyncClient()
warnings.warn(
"Getting credentials from the environment has been deprecated. "
"You should pass gcp_conn_id as parameter.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
except ImportError: # __version__ was added in 10.1.0, so this means it's < 10.15.0
self._client = StorageTransferServiceAsyncClient()
warnings.warn(
"Getting credentials from the environment has been deprecated. "
"You should pass gcp_conn_id as parameter.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
credentials = (await self.get_sync_hook()).get_credentials()
self._client = StorageTransferServiceAsyncClient(
credentials=credentials,
client_info=CLIENT_INFO,
)
return self._client

async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
Expand All @@ -567,7 +545,7 @@ async def get_latest_operation(self, job: TransferJob) -> Message | None:
"""
latest_operation_name = job.latest_operation_name
if latest_operation_name:
client = self.get_conn()
client = await self.get_conn()
response_operation = await client.transport.operations_client.get_operation(latest_operation_name)
operation = TransferOperation.deserialize(response_operation.metadata.value)
return operation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,20 @@ def hook_async():


class TestCloudDataTransferServiceAsyncHook:
@pytest.mark.asyncio
@mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
@mock.patch(f"{TRANSFER_HOOK_PATH}.StorageTransferServiceAsyncClient")
def test_get_conn(self, mock_async_client):
async def test_get_conn(self, mock_async_client, mock_get_conn):
expected_value = "Async Hook"
mock_async_client.return_value = expected_value
mock_get_conn.return_value = expected_value

hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
conn_0 = hook.get_conn()

conn_0 = await hook.get_conn()
assert conn_0 == expected_value

conn_1 = hook.get_conn()
conn_1 = await hook.get_conn()
assert conn_1 == expected_value
assert id(conn_0) == id(conn_1)

Expand Down

0 comments on commit 4f2583a

Please sign in to comment.