Skip to content

Commit

Permalink
Support OAuth authentication for Big Query (#431)
Browse files Browse the repository at this point in the history
## Description

Support OAuth authentication for Big Query

## Related Issue(s)
closes #420 

## Breaking Change?

None

## Checklist

- [ ] I have made corresponding changes to the documentation (if
required)
- [X] I have added tests that prove my fix is effective or that my
feature works

---------

Co-authored-by: Monideep De <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored and tatiana committed Aug 9, 2023
1 parent 35d1b69 commit 5a398b0
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 2 deletions.
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .base import BaseProfileMapping
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
from .bigquery.oauth import GoogleCloudOauthProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
Expand All @@ -22,6 +23,7 @@
profile_mappings: list[Type[BaseProfileMapping]] = [
GoogleCloudServiceAccountFileProfileMapping,
GoogleCloudServiceAccountDictProfileMapping,
GoogleCloudOauthProfileMapping,
DatabricksTokenProfileMapping,
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
Expand Down Expand Up @@ -57,6 +59,7 @@ def get_automatic_profile_mapping(
"BaseProfileMapping",
"GoogleCloudServiceAccountFileProfileMapping",
"GoogleCloudServiceAccountDictProfileMapping",
"GoogleCloudOauthProfileMapping",
"DatabricksTokenProfileMapping",
"PostgresUserPasswordProfileMapping",
"RedshiftUserPasswordProfileMapping",
Expand Down
4 changes: 2 additions & 2 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_dbt_value(self, name: str) -> Any:
if self.profile_args.get(name):
return self.profile_args[name]

# if it's has an entry in airflow_param_mapping, we can get it from conn
# if it has an entry in airflow_param_mapping, we can get it from conn
if name in self.airflow_param_mapping:
airflow_fields = self.airflow_param_mapping[name]

Expand All @@ -147,7 +147,7 @@ def get_dbt_value(self, name: str) -> Any:
airflow_field = airflow_field.replace("extra.", "", 1)
value = self.conn.extra_dejson.get(airflow_field)
else:
value = getattr(self.conn, airflow_field)
value = getattr(self.conn, airflow_field, None)

if not value:
continue
Expand Down
2 changes: 2 additions & 0 deletions cosmos/profiles/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from .service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
from .oauth import GoogleCloudOauthProfileMapping

__all__ = [
"GoogleCloudServiceAccountFileProfileMapping",
"GoogleCloudServiceAccountDictProfileMapping",
"GoogleCloudOauthProfileMapping",
]
39 changes: 39 additions & 0 deletions cosmos/profiles/bigquery/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"Maps Airflow GCP connections to dbt BigQuery profiles that uses oauth via gcloud, if they don't use key file or JSON."
from __future__ import annotations

from typing import Any

from cosmos.profiles.base import BaseProfileMapping


class GoogleCloudOauthProfileMapping(BaseProfileMapping):
"""
Maps Airflow GCP connections to dbt BigQuery profiles that uses oauth via gcloud,
if they don't use key file or JSON.
https://docs.getdbt.com/docs/core/connect-data-platform/bigquery-setup#oauth-via-gcloud
https://airflow.apache.org/docs/apache-airflow-providers-google/stable/connections/gcp.html
"""

airflow_connection_type: str = "google_cloud_platform"

required_fields = [
"project",
"dataset",
]

airflow_param_mapping = {
"project": "extra.project",
"dataset": "extra.dataset",
}

@property
def profile(self) -> dict[str, Any | None]:
"Generates profile. Defaults `threads` to 1."
return {
**self.mapped_params,
"type": "bigquery",
"method": "oauth",
"threads": 1,
**self.profile_args,
}
61 changes: 61 additions & 0 deletions tests/profiles/bigquery/test_bq_oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"Tests for the BigQuery profile."

import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.bigquery.oauth import (
GoogleCloudOauthProfileMapping,
)


@pytest.fixture()
def mock_bigquery_conn(request):
"""
Mocks and returns an Airflow BigQuery connection.
"""
extra = {"project": "my_project", "dataset": "my_dataset"} if not hasattr(request, "param") else request.param
conn = Connection(
conn_id="my_bigquery_connection",
conn_type="google_cloud_platform",
extra=json.dumps(extra),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_bigquery_mapping_selected(mock_bigquery_conn: Connection):
profile_mapping = get_automatic_profile_mapping(mock_bigquery_conn.conn_id, {})
assert isinstance(profile_mapping, GoogleCloudOauthProfileMapping)


@pytest.mark.parametrize(
"mock_bigquery_conn", [{"project": "my_project"}, {"dataset": "my_dataset"}, {}], indirect=True
)
def test_connection_claiming_fails(mock_bigquery_conn: Connection) -> None:
"""
Tests that the BigQuery profile mapping claims the correct connection type.
"""
profile_mapping = GoogleCloudOauthProfileMapping(mock_bigquery_conn)
assert not profile_mapping.can_claim_connection()


def test_connection_claiming_succeeds(mock_bigquery_conn: Connection):
profile_mapping = GoogleCloudOauthProfileMapping(mock_bigquery_conn, {})
assert profile_mapping.can_claim_connection()


def test_profile(mock_bigquery_conn: Connection):
profile_mapping = GoogleCloudOauthProfileMapping(mock_bigquery_conn, {})
expected = {
"type": "bigquery",
"method": "oauth",
"project": "my_project",
"dataset": "my_dataset",
"threads": 1,
}
assert profile_mapping.profile == expected

0 comments on commit 5a398b0

Please sign in to comment.