Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
feat(bigquery): get_catalog_names (apache#23461)
Browse files Browse the repository at this point in the history
(cherry picked from commit 7a1aa63)
  • Loading branch information
betodealmeida authored and john-bodley committed Apr 12, 2023
1 parent 592e7b9 commit 9b0f55b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 76 deletions.
99 changes: 63 additions & 36 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from marshmallow.exceptions import ValidationError
from sqlalchemy import column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import sqltypes
from typing_extensions import TypedDict

Expand All @@ -39,10 +40,26 @@
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.errors import SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.sql_parse import Table
from superset.utils import core as utils
from superset.utils.hashing import md5_sha_from_str

# Optional-dependency probes: BigQuery support degrades gracefully when the
# Google client libraries are not installed, so we attempt the imports once at
# module load and record the outcome in module-level flags instead of failing.
try:
    # Core libraries needed to connect to BigQuery and build credentials.
    from google.cloud import bigquery
    from google.oauth2 import service_account

    dependencies_installed = True  # connecting/estimating features available
except ModuleNotFoundError:
    dependencies_installed = False

try:
    # pandas-gbq is only needed for uploading DataFrames (df_to_sql).
    import pandas_gbq

    can_upload = True  # CSV/DataFrame upload feature available
except ModuleNotFoundError:
    can_upload = False

if TYPE_CHECKING:
from superset.models.core import Database # pragma: no cover

Expand Down Expand Up @@ -86,7 +103,7 @@ class BigQueryParametersType(TypedDict):
query: Dict[str, Any]


class BigQueryEngineSpec(BaseEngineSpec):
class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods
"""Engine spec for Google's BigQuery
As contributed by @mxmzdlv on issue #945"""
Expand Down Expand Up @@ -349,20 +366,13 @@ def df_to_sql(
:param df: The dataframe with data to be uploaded
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
"""

try:
# pylint: disable=import-outside-toplevel
import pandas_gbq
from google.oauth2 import service_account
except ImportError as ex:
raise Exception(
"Could not import libraries `pandas_gbq` or `google.oauth2`, which are "
"required to be installed in your environment in order "
"to upload data to BigQuery"
) from ex
if not can_upload:
raise SupersetException(
"Could not import libraries needed to upload data to BigQuery."
)

if not table.schema:
raise Exception("The table schema must be defined")
raise SupersetException("The table schema must be defined")

to_gbq_kwargs = {}
with cls.get_engine(database) as engine:
Expand All @@ -388,6 +398,21 @@ def df_to_sql(

pandas_gbq.to_gbq(df, **to_gbq_kwargs)

@classmethod
def _get_client(cls, engine: Engine) -> Any:
"""
Return the BigQuery client associated with an engine.
"""
if not dependencies_installed:
raise SupersetException(
"Could not import libraries needed to connect to BigQuery."
)

credentials = service_account.Credentials.from_service_account_info(
engine.dialect.credentials_info
)
return bigquery.Client(credentials=credentials)

@classmethod
def estimate_query_cost(
cls,
Expand All @@ -406,7 +431,7 @@ def estimate_query_cost(
"""
extra = database.get_extra() or {}
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")
raise SupersetException("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()
Expand All @@ -417,35 +442,37 @@ def estimate_query_cost(
costs.append(cls.estimate_statement_cost(processed_statement, database))
return costs

@classmethod
def get_catalog_names(
cls,
database: "Database",
inspector: Inspector,
) -> List[str]:
"""
Get all catalogs.
In BigQuery, a catalog is called a "project".
"""
engine: Engine
with database.get_sqla_engine_with_context() as engine:
client = cls._get_client(engine)
projects = client.list_projects()

return sorted(project.project_id for project in projects)

@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return True

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
try:
# pylint: disable=import-outside-toplevel
# It's the only way to perfom a dry-run estimate cost
from google.cloud import bigquery
from google.oauth2 import service_account
except ImportError as ex:
raise Exception(
"Could not import libraries `pygibquery` or `google.oauth2`, which are "
"required to be installed in your environment in order "
"to upload data to BigQuery"
) from ex

with cls.get_engine(cursor) as engine:
creds = engine.dialect.credentials_info

creds = service_account.Credentials.from_service_account_info(creds)
client = bigquery.Client(credentials=creds)
job_config = bigquery.QueryJobConfig(dry_run=True)

query_job = client.query(
statement,
job_config=job_config,
) # Make an API request.
client = cls._get_client(engine)
job_config = bigquery.QueryJobConfig(dry_run=True)
query_job = client.query(
statement,
job_config=job_config,
) # Make an API request.

# Format Bytes.
# TODO: Humanize in case more db engine specs need to be added,
Expand Down
47 changes: 7 additions & 40 deletions tests/integration_tests/db_engine_specs/bigquery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,48 +216,13 @@ def test_get_indexes(self):
]

@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
def test_df_to_sql(self, mock_get_engine):
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
@mock.patch("superset.db_engine_specs.bigquery.service_account")
def test_df_to_sql(self, mock_service_account, mock_pandas_gbq, mock_get_engine):
"""
DB Eng Specs (bigquery): Test DataFrame to SQL contract
"""
# test missing google.oauth2 dependency
sys.modules["pandas_gbq"] = mock.MagicMock()
df = DataFrame()
database = mock.MagicMock()
with self.assertRaises(Exception):
BigQueryEngineSpec.df_to_sql(
database=database,
table=Table(table="name", schema="schema"),
df=df,
to_sql_kwargs={},
)

invalid_kwargs = [
{"name": "some_name"},
{"schema": "some_schema"},
{"con": "some_con"},
{"name": "some_name", "con": "some_con"},
{"name": "some_name", "schema": "some_schema"},
{"con": "some_con", "schema": "some_schema"},
]
# Test check for missing schema.
sys.modules["google.oauth2"] = mock.MagicMock()
for invalid_kwarg in invalid_kwargs:
self.assertRaisesRegex(
Exception,
"The table schema must be defined",
BigQueryEngineSpec.df_to_sql,
database=database,
table=Table(table="name"),
df=df,
to_sql_kwargs=invalid_kwarg,
)

import pandas_gbq
from google.oauth2 import service_account

pandas_gbq.to_gbq = mock.Mock()
service_account.Credentials.from_service_account_info = mock.MagicMock(
mock_service_account.Credentials.from_service_account_info = mock.MagicMock(
return_value="account_info"
)

Expand All @@ -266,14 +231,16 @@ def test_df_to_sql(self, mock_get_engine):
"secrets"
)

df = DataFrame()
database = mock.MagicMock()
BigQueryEngineSpec.df_to_sql(
database=database,
table=Table(table="name", schema="schema"),
df=df,
to_sql_kwargs={"if_exists": "extra_key"},
)

pandas_gbq.to_gbq.assert_called_with(
mock_pandas_gbq.to_gbq.assert_called_with(
df,
project_id="google-host",
destination_table="schema.name",
Expand Down

0 comments on commit 9b0f55b

Please sign in to comment.