Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix credentials error for S3ToGCSOperator trigger #37518

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -508,14 +509,18 @@ def __init__(self, project_id: str | None = None, **kwargs: Any) -> None:
self.project_id = project_id
self._client: StorageTransferServiceAsyncClient | None = None

def get_conn(self) -> StorageTransferServiceAsyncClient:
async def get_conn(self) -> StorageTransferServiceAsyncClient:
"""
Return async connection to the Storage Transfer Service.

:return: Google Storage Transfer asynchronous client.
"""
if not self._client:
self._client = StorageTransferServiceAsyncClient()
credentials = (await self.get_sync_hook()).get_credentials()
self._client = StorageTransferServiceAsyncClient(
credentials=credentials,
client_info=CLIENT_INFO,
)
return self._client

async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
Expand All @@ -525,7 +530,7 @@ async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
:param job_names: (Required) List of names of the jobs to be fetched.
:return: Object that yields Transfer jobs.
"""
client = self.get_conn()
client = await self.get_conn()
jobs_list_request = ListTransferJobsRequest(
filter=json.dumps({"project_id": self.project_id, "job_names": job_names})
)
Expand All @@ -540,7 +545,7 @@ async def get_latest_operation(self, job: TransferJob) -> Message | None:
"""
latest_operation_name = job.latest_operation_name
if latest_operation_name:
client = self.get_conn()
client = await self.get_conn()
response_operation = await client.transport.operations_client.get_operation(latest_operation_name)
operation = TransferOperation.deserialize(response_operation.metadata.value)
return operation
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/transfers/s3_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def transfer_files_async(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3H
self.defer(
trigger=CloudStorageTransferServiceCreateJobsTrigger(
project_id=gcs_hook.project_id,
gcp_conn_id=self.gcp_conn_id,
eladkal marked this conversation as resolved.
Show resolved Hide resolved
job_names=job_names,
poll_interval=self.poll_interval,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,19 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger):
:param job_names: List of transfer jobs names.
:param project_id: GCP project id.
:param poll_interval: Interval in seconds between polls.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
"""

def __init__(self, job_names: list[str], project_id: str | None = None, poll_interval: int = 10) -> None:
def __init__(
self,
job_names: list[str],
project_id: str | None = None,
poll_interval: int = 10,
gcp_conn_id: str = "google_cloud_default",
) -> None:
super().__init__()
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.job_names = job_names
self.poll_interval = poll_interval

Expand All @@ -53,6 +61,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"job_names": self.job_names,
"poll_interval": self.poll_interval,
"gcp_conn_id": self.gcp_conn_id,
},
)

Expand Down Expand Up @@ -117,4 +126,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
await asyncio.sleep(self.poll_interval)

def get_async_hook(self) -> CloudDataTransferServiceAsyncHook:
return CloudDataTransferServiceAsyncHook(project_id=self.project_id)
return CloudDataTransferServiceAsyncHook(
project_id=self.project_id,
gcp_conn_id=self.gcp_conn_id,
)
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,20 @@ def hook_async():


class TestCloudDataTransferServiceAsyncHook:
@pytest.mark.asyncio
@mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
@mock.patch(f"{TRANSFER_HOOK_PATH}.StorageTransferServiceAsyncClient")
def test_get_conn(self, mock_async_client):
async def test_get_conn(self, mock_async_client, mock_get_conn):
expected_value = "Async Hook"
mock_async_client.return_value = expected_value
mock_get_conn.return_value = expected_value

hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
conn_0 = hook.get_conn()

conn_0 = await hook.get_conn()
assert conn_0 == expected_value

conn_1 = hook.get_conn()
conn_1 = await hook.get_conn()
assert conn_1 == expected_value
assert id(conn_0) == id(conn_1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.triggers.base import TriggerEvent

PROJECT_ID = "test-project"
GCP_CONN_ID = "google-cloud-default-id"
JOB_0 = "test-job-0"
JOB_1 = "test-job-1"
JOB_NAMES = [JOB_0, JOB_1]
Expand All @@ -51,7 +52,10 @@
@pytest.fixture(scope="session")
def trigger():
return CloudStorageTransferServiceCreateJobsTrigger(
project_id=PROJECT_ID, job_names=JOB_NAMES, poll_interval=POLL_INTERVAL
project_id=PROJECT_ID,
job_names=JOB_NAMES,
poll_interval=POLL_INTERVAL,
gcp_conn_id=GCP_CONN_ID,
)


Expand Down Expand Up @@ -80,6 +84,7 @@ def test_serialize(self, trigger):
"project_id": PROJECT_ID,
"job_names": JOB_NAMES,
"poll_interval": POLL_INTERVAL,
"gcp_conn_id": GCP_CONN_ID,
}

def test_get_async_hook(self, trigger):
Expand Down