Skip to content

Commit

Permalink
Format via ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
grieve54706 committed Jun 6, 2024
1 parent fd68736 commit a2d6f91
Show file tree
Hide file tree
Showing 12 changed files with 397 additions and 247 deletions.
6 changes: 3 additions & 3 deletions ibis-server/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def __new__(cls):

def __init__(self):
load_dotenv(override=True)
self.wren_engine_endpoint = os.getenv('WREN_ENGINE_ENDPOINT')
self.wren_engine_endpoint = os.getenv("WREN_ENGINE_ENDPOINT")
self.validate_wren_engine_endpoint(self.wren_engine_endpoint)
self.log_level = os.getenv('LOG_LEVEL', 'INFO')
self.log_level = os.getenv("LOG_LEVEL", "INFO")

@staticmethod
def validate_wren_engine_endpoint(endpoint):
if endpoint is None:
raise ValueError('WREN_ENGINE_ENDPOINT is not set')
raise ValueError("WREN_ENGINE_ENDPOINT is not set")


def get_config() -> Config:
Expand Down
6 changes: 3 additions & 3 deletions ibis-server/app/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def get_logger(name):


def log_dto(f):
logger = get_logger('app.routers.ibis')
logger = get_logger("app.routers.ibis")

@wraps(f)
def wrapper(*args, **kwargs):
Expand All @@ -22,12 +22,12 @@ def wrapper(*args, **kwargs):


def log_rewritten(f):
logger = get_logger('app.mdl.rewriter')
logger = get_logger("app.mdl.rewriter")

@wraps(f)
def wrapper(*args, **kwargs):
rs = f(*args, **kwargs)
logger.debug(f'Rewritten SQL: {rs}')
logger.debug(f"Rewritten SQL: {rs}")
return rs

return wrapper
13 changes: 6 additions & 7 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
def rewrite(manifest_str: str, sql: str) -> str:
try:
r = httpx.request(
method='GET',
url=f'{wren_engine_endpoint}/v2/mdl/dry-plan',
headers={'Content-Type': 'application/json', 'Accept': 'application/json'},
content=orjson.dumps({
'manifestStr': manifest_str,
'sql': sql}))
method="GET",
url=f"{wren_engine_endpoint}/v2/mdl/dry-plan",
headers={"Content-Type": "application/json", "Accept": "application/json"},
content=orjson.dumps({"manifestStr": manifest_str, "sql": sql}),
)
return r.text if r.status_code == httpx.codes.OK else r.raise_for_status()
except httpx.ConnectError as e:
raise ConnectionError(f'Can not connect to Wren Engine: {e}')
raise ConnectionError(f"Can not connect to Wren Engine: {e}")
34 changes: 24 additions & 10 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,56 @@


class Connector:
def __init__(self, data_source: DataSource, connection_info: ConnectionInfo, manifest_str: str, column_dtypes: dict[str, str]):
def __init__(
self,
data_source: DataSource,
connection_info: ConnectionInfo,
manifest_str: str,
column_dtypes: dict[str, str],
):
self.data_source = data_source
self.connection = self.data_source.get_connection(connection_info)
self.manifest_str = manifest_str
self.column_dtypes = column_dtypes

def query(self, sql) -> dict:
rewritten_sql = rewrite(self.manifest_str, sql)
return self._to_json(self.connection.sql(rewritten_sql, dialect='trino').to_pandas())
return self._to_json(
self.connection.sql(rewritten_sql, dialect="trino").to_pandas()
)

def dry_run(self, sql) -> None:
try:
rewritten_sql = rewrite(self.manifest_str, sql)
self.connection.sql(rewritten_sql, dialect='trino')
self.connection.sql(rewritten_sql, dialect="trino")
except Exception as e:
raise QueryDryRunError(f'Exception: {type(e)}, message: {str(e)}')
raise QueryDryRunError(f"Exception: {type(e)}, message: {str(e)}")

def _to_json(self, df):
if self.column_dtypes:
self._to_specific_types(df, self.column_dtypes)
json_obj = loads(df.to_json(orient='split'))
del json_obj['index']
json_obj['dtypes'] = df.dtypes.apply(lambda x: x.name).to_dict()
json_obj = loads(df.to_json(orient="split"))
del json_obj["index"]
json_obj["dtypes"] = df.dtypes.apply(lambda x: x.name).to_dict()
return json_obj

def _to_specific_types(self, df: pd.DataFrame, column_dtypes: dict[str, str]):
for column, dtype in column_dtypes.items():
if dtype == 'datetime64':
if dtype == "datetime64":
df[column] = self._to_datetime_and_format(df[column])
else:
df[column] = df[column].astype(dtype)

@staticmethod
def _to_datetime_and_format(series: pd.Series) -> pd.Series:
series = pd.to_datetime(series, errors='coerce')
return series.apply(lambda d: d.strftime('%Y-%m-%d %H:%M:%S.%f' + (' %Z' if series.dt.tz is not None else '')) if not pd.isnull(d) else d)
series = pd.to_datetime(series, errors="coerce")
return series.apply(
lambda d: d.strftime(
"%Y-%m-%d %H:%M:%S.%f" + (" %Z" if series.dt.tz is not None else "")
)
if not pd.isnull(d)
else d
)


class QueryDryRunError(Exception):
Expand Down
21 changes: 15 additions & 6 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,23 @@ def get_connection(self, dto) -> BaseBackend:
case DataSource.snowflake:
return self.get_snowflake_connection(dto)
case _:
raise NotImplementedError(f'Unsupported data source: {self}')
raise NotImplementedError(f"Unsupported data source: {self}")

@staticmethod
def get_postgres_connection(info: PostgresConnectionUrl | PostgresConnectionInfo) -> BaseBackend:
return ibis.connect(getattr(info, 'connection_url', None) or f"postgres://{info.user}:{info.password}@{info.host}:{info.port}/{info.database}")
def get_postgres_connection(
info: PostgresConnectionUrl | PostgresConnectionInfo,
) -> BaseBackend:
return ibis.connect(
getattr(info, "connection_url", None)
or f"postgres://{info.user}:{info.password}@{info.host}:{info.port}/{info.database}"
)

@staticmethod
def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend:
credits_json = loads(base64.b64decode(info.credentials).decode('utf-8'))
credentials = service_account.Credentials.from_service_account_info(credits_json)
credits_json = loads(base64.b64decode(info.credentials).decode("utf-8"))
credentials = service_account.Credentials.from_service_account_info(
credits_json
)
return ibis.bigquery.connect(
project_id=info.project_id,
dataset_id=info.dataset_id,
Expand Down Expand Up @@ -75,7 +82,9 @@ class SnowflakeConnectionInfo(BaseModel):
password: str
account: str
database: str
sf_schema: str = Field(alias="schema") # Use `sf_schema` to avoid `schema` shadowing in BaseModel
sf_schema: str = Field(
alias="schema"
) # Use `sf_schema` to avoid `schema` shadowing in BaseModel


ConnectionInfo = Union[
Expand Down
12 changes: 9 additions & 3 deletions ibis-server/app/model/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,24 @@
PostgresConnectionUrl,
PostgresConnectionInfo,
BigQueryConnectionInfo,
SnowflakeConnectionInfo
SnowflakeConnectionInfo,
)


class IbisDTO(BaseModel):
sql: str
manifest_str: str = Field(alias="manifestStr", description="Base64 manifest")
column_dtypes: dict[str, str] | None = Field(alias="columnDtypes", description="If this field is set, it will forcibly convert the type.", default=None)
column_dtypes: dict[str, str] | None = Field(
alias="columnDtypes",
description="If this field is set, it will forcibly convert the type.",
default=None,
)


class PostgresDTO(IbisDTO):
connection_info: PostgresConnectionUrl | PostgresConnectionInfo = Field(alias="connectionInfo")
connection_info: PostgresConnectionUrl | PostgresConnectionInfo = Field(
alias="connectionInfo"
)


class BigQueryDTO(IbisDTO):
Expand Down
10 changes: 7 additions & 3 deletions ibis-server/app/routers/ibis/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
from app.model.data_source import DataSource
from app.model.dto import BigQueryDTO

router = APIRouter(prefix='/bigquery', tags=['bigquery'])
router = APIRouter(prefix="/bigquery", tags=["bigquery"])

data_source = DataSource.bigquery


@router.post("/query")
@log_dto
def query(dto: BigQueryDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes)
def query(
dto: BigQueryDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False
) -> Response:
connector = Connector(
data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes
)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
Expand Down
10 changes: 7 additions & 3 deletions ibis-server/app/routers/ibis/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
from app.model.data_source import DataSource
from app.model.dto import PostgresDTO

router = APIRouter(prefix='/postgres', tags=['postgres'])
router = APIRouter(prefix="/postgres", tags=["postgres"])

data_source = DataSource.postgres


@router.post("/query")
@log_dto
def query(dto: PostgresDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes)
def query(
dto: PostgresDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False
) -> Response:
connector = Connector(
data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes
)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
Expand Down
10 changes: 7 additions & 3 deletions ibis-server/app/routers/ibis/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
from app.model.data_source import DataSource
from app.model.dto import SnowflakeDTO

router = APIRouter(prefix='/snowflake', tags=['snowflake'])
router = APIRouter(prefix="/snowflake", tags=["snowflake"])

data_source = DataSource.snowflake


@router.post("/query")
@log_dto
def query(dto: SnowflakeDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes)
def query(
dto: SnowflakeDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False
) -> Response:
connector = Connector(
data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes
)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
Expand Down
Loading

0 comments on commit a2d6f91

Please sign in to comment.