diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 20b1ac10c8021d..c3f774443ced8b 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -56,6 +56,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.utils.bigquery import bq_cast +from airflow.providers.google.cloud.utils.credentials_provider import _get_scopes from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field @@ -92,6 +93,8 @@ class BigQueryHook(GoogleBaseHook, DbApiHook): Google BigQuery jobs. :param impersonation_chain: This is the optional service account to impersonate using short term credentials. + :param impersonation_scopes: Optional list of scopes for impersonated account. + Will override scopes from connection. :param labels: The BigQuery resource label. """ @@ -108,6 +111,7 @@ def __init__( priority: str = "INTERACTIVE", api_resource_configs: dict | None = None, impersonation_chain: str | Sequence[str] | None = None, + impersonation_scopes: str | Sequence[str] | None = None, labels: dict | None = None, **kwargs, ) -> None: @@ -127,6 +131,7 @@ def __init__( self.api_resource_configs: dict = api_resource_configs or {} self.labels = labels self.credentials_path = "bigquery_hook_credentials.json" + self.impersonation_scopes = impersonation_scopes def get_conn(self) -> BigQueryConnection: """Get a BigQuery PEP 249 connection object.""" @@ -2335,6 +2340,20 @@ def var_print(var_name): return project_id, dataset_id, table_id + @property + def scopes(self) -> Sequence[str]: + """ + Return OAuth 2.0 scopes. + + :return: Returns the scope defined in impersonation_scopes, the connection configuration, or the default scope + """ + scope_value: str | None + if self.impersonation_chain and self.impersonation_scopes: + scope_value = ",".join(self.impersonation_scopes) + else: + scope_value = self._get_field("scope", None) + return _get_scopes(scope_value) + class BigQueryConnection: """BigQuery connection. diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 77862491b47a8e..84e03928dae0dc 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -1213,6 +1213,7 @@ def __init__( location: str | None = None, encryption_configuration: dict | None = None, impersonation_chain: str | Sequence[str] | None = None, + impersonation_scopes: str | Sequence[str] | None = None, job_id: str | list[str] | None = None, **kwargs, ) -> None: @@ -1239,6 +1240,7 @@ def __init__( self.encryption_configuration = encryption_configuration self.hook: BigQueryHook | None = None self.impersonation_chain = impersonation_chain + self.impersonation_scopes = impersonation_scopes self.job_id = job_id def execute(self, context: Context): @@ -1249,6 +1251,7 @@ def execute(self, context: Context): use_legacy_sql=self.use_legacy_sql, location=self.location, impersonation_chain=self.impersonation_chain, + impersonation_scopes=self.impersonation_scopes, ) if isinstance(self.sql, str): self.job_id = self.hook.run_query( diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 64e31913fb7781..4d63ba57f1003b 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -543,9 +543,24 @@ def test_execute(self, mock_hook): api_resource_configs=None, cluster_fields=None, encryption_configuration=encryption_configuration, + impersonation_chain=["service-account@myproject.iam.gserviceaccount.com"], + impersonation_scopes=[ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/drive", + ], ) operator.execute(MagicMock()) + mock_hook.assert_called_with( + gcp_conn_id="google_cloud_default", + use_legacy_sql=True, + location=None, + impersonation_chain=["service-account@myproject.iam.gserviceaccount.com"], + impersonation_scopes=[ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/drive", + ], + ) mock_hook.return_value.run_query.assert_called_once_with( sql="Select * from test_table", destination_dataset_table=None,