From a2d6f91f32dfeadaa059ccf1a660fa5247308db0 Mon Sep 17 00:00:00 2001 From: Grieve Date: Thu, 6 Jun 2024 19:02:35 +0800 Subject: [PATCH] Format via ruff --- ibis-server/app/config.py | 6 +- ibis-server/app/logger.py | 6 +- ibis-server/app/mdl/rewriter.py | 13 +- ibis-server/app/model/connector.py | 34 +++- ibis-server/app/model/data_source.py | 21 +- ibis-server/app/model/dto.py | 12 +- ibis-server/app/routers/ibis/bigquery.py | 10 +- ibis-server/app/routers/ibis/postgres.py | 10 +- ibis-server/app/routers/ibis/snowflake.py | 10 +- .../tests/routers/ibis/test_bigquery.py | 166 +++++++++------ .../tests/routers/ibis/test_postgres.py | 190 +++++++++++------- .../tests/routers/ibis/test_snowflake.py | 166 +++++++++------ 12 files changed, 397 insertions(+), 247 deletions(-) diff --git a/ibis-server/app/config.py b/ibis-server/app/config.py index 4ca9d8c20..83e440af5 100644 --- a/ibis-server/app/config.py +++ b/ibis-server/app/config.py @@ -13,14 +13,14 @@ def __new__(cls): def __init__(self): load_dotenv(override=True) - self.wren_engine_endpoint = os.getenv('WREN_ENGINE_ENDPOINT') + self.wren_engine_endpoint = os.getenv("WREN_ENGINE_ENDPOINT") self.validate_wren_engine_endpoint(self.wren_engine_endpoint) - self.log_level = os.getenv('LOG_LEVEL', 'INFO') + self.log_level = os.getenv("LOG_LEVEL", "INFO") @staticmethod def validate_wren_engine_endpoint(endpoint): if endpoint is None: - raise ValueError('WREN_ENGINE_ENDPOINT is not set') + raise ValueError("WREN_ENGINE_ENDPOINT is not set") def get_config() -> Config: diff --git a/ibis-server/app/logger.py b/ibis-server/app/logger.py index c78010940..c7cc4aa95 100644 --- a/ibis-server/app/logger.py +++ b/ibis-server/app/logger.py @@ -11,7 +11,7 @@ def get_logger(name): def log_dto(f): - logger = get_logger('app.routers.ibis') + logger = get_logger("app.routers.ibis") @wraps(f) def wrapper(*args, **kwargs): @@ -22,12 +22,12 @@ def wrapper(*args, **kwargs): def log_rewritten(f): - logger = get_logger('app.mdl.rewriter') + logger = get_logger("app.mdl.rewriter") @wraps(f) def wrapper(*args, **kwargs): rs = f(*args, **kwargs) - logger.debug(f'Rewritten SQL: {rs}') + logger.debug(f"Rewritten SQL: {rs}") return rs return wrapper diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index a3435006e..ce39cfc74 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -11,12 +11,11 @@ def rewrite(manifest_str: str, sql: str) -> str: try: r = httpx.request( - method='GET', - url=f'{wren_engine_endpoint}/v2/mdl/dry-plan', - headers={'Content-Type': 'application/json', 'Accept': 'application/json'}, - content=orjson.dumps({ - 'manifestStr': manifest_str, - 'sql': sql})) + method="GET", + url=f"{wren_engine_endpoint}/v2/mdl/dry-plan", + headers={"Content-Type": "application/json", "Accept": "application/json"}, + content=orjson.dumps({"manifestStr": manifest_str, "sql": sql}), + ) return r.text if r.status_code == httpx.codes.OK else r.raise_for_status() except httpx.ConnectError as e: - raise ConnectionError(f'Can not connect to Wren Engine: {e}') + raise ConnectionError(f"Can not connect to Wren Engine: {e}") diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index 5b3e54852..014695a46 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -7,7 +7,13 @@ class Connector: - def __init__(self, data_source: DataSource, connection_info: ConnectionInfo, manifest_str: str, column_dtypes: dict[str, str]): + def __init__( + self, + data_source: DataSource, + connection_info: ConnectionInfo, + manifest_str: str, + column_dtypes: dict[str, str], + ): self.data_source = data_source self.connection = self.data_source.get_connection(connection_info) self.manifest_str = manifest_str @@ -15,34 +21,42 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo, man def query(self, sql) -> dict: rewritten_sql = rewrite(self.manifest_str, sql) - return self._to_json(self.connection.sql(rewritten_sql, dialect='trino').to_pandas()) + return self._to_json( + self.connection.sql(rewritten_sql, dialect="trino").to_pandas() + ) def dry_run(self, sql) -> None: try: rewritten_sql = rewrite(self.manifest_str, sql) - self.connection.sql(rewritten_sql, dialect='trino') + self.connection.sql(rewritten_sql, dialect="trino") except Exception as e: - raise QueryDryRunError(f'Exception: {type(e)}, message: {str(e)}') + raise QueryDryRunError(f"Exception: {type(e)}, message: {str(e)}") def _to_json(self, df): if self.column_dtypes: self._to_specific_types(df, self.column_dtypes) - json_obj = loads(df.to_json(orient='split')) - del json_obj['index'] - json_obj['dtypes'] = df.dtypes.apply(lambda x: x.name).to_dict() + json_obj = loads(df.to_json(orient="split")) + del json_obj["index"] + json_obj["dtypes"] = df.dtypes.apply(lambda x: x.name).to_dict() return json_obj def _to_specific_types(self, df: pd.DataFrame, column_dtypes: dict[str, str]): for column, dtype in column_dtypes.items(): - if dtype == 'datetime64': + if dtype == "datetime64": df[column] = self._to_datetime_and_format(df[column]) else: df[column] = df[column].astype(dtype) @staticmethod def _to_datetime_and_format(series: pd.Series) -> pd.Series: - series = pd.to_datetime(series, errors='coerce') - return series.apply(lambda d: d.strftime('%Y-%m-%d %H:%M:%S.%f' + (' %Z' if series.dt.tz is not None else '')) if not pd.isnull(d) else d) + series = pd.to_datetime(series, errors="coerce") + return series.apply( + lambda d: d.strftime( + "%Y-%m-%d %H:%M:%S.%f" + (" %Z" if series.dt.tz is not None else "") + ) + if not pd.isnull(d) + else d + ) class QueryDryRunError(Exception): diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index bc9b4e729..a970e5157 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -25,16 +25,23 @@ def get_connection(self, dto) -> BaseBackend: case DataSource.snowflake: return self.get_snowflake_connection(dto) case _: - raise NotImplementedError(f'Unsupported data source: {self}') + raise NotImplementedError(f"Unsupported data source: {self}") @staticmethod - def get_postgres_connection(info: PostgresConnectionUrl | PostgresConnectionInfo) -> BaseBackend: - return ibis.connect(getattr(info, 'connection_url', None) or f"postgres://{info.user}:{info.password}@{info.host}:{info.port}/{info.database}") + def get_postgres_connection( + info: PostgresConnectionUrl | PostgresConnectionInfo, + ) -> BaseBackend: + return ibis.connect( + getattr(info, "connection_url", None) + or f"postgres://{info.user}:{info.password}@{info.host}:{info.port}/{info.database}" + ) @staticmethod def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend: - credits_json = loads(base64.b64decode(info.credentials).decode('utf-8')) - credentials = service_account.Credentials.from_service_account_info(credits_json) + credits_json = loads(base64.b64decode(info.credentials).decode("utf-8")) + credentials = service_account.Credentials.from_service_account_info( + credits_json + ) return ibis.bigquery.connect( project_id=info.project_id, dataset_id=info.dataset_id, @@ -75,7 +82,9 @@ class SnowflakeConnectionInfo(BaseModel): password: str account: str database: str - sf_schema: str = Field(alias="schema") # Use `sf_schema` to avoid `schema` shadowing in BaseModel + sf_schema: str = Field( + alias="schema" + ) # Use `sf_schema` to avoid `schema` shadowing in BaseModel ConnectionInfo = Union[ diff --git a/ibis-server/app/model/dto.py b/ibis-server/app/model/dto.py index e7469928f..74059256a 100644 --- a/ibis-server/app/model/dto.py +++ b/ibis-server/app/model/dto.py @@ -6,18 +6,24 @@ PostgresConnectionUrl, PostgresConnectionInfo, BigQueryConnectionInfo, - SnowflakeConnectionInfo + SnowflakeConnectionInfo, ) class IbisDTO(BaseModel): sql: str manifest_str: str = Field(alias="manifestStr", description="Base64 manifest") - column_dtypes: dict[str, str] | None = Field(alias="columnDtypes", description="If this field is set, it will forcibly convert the type.", default=None) + column_dtypes: dict[str, str] | None = Field( + alias="columnDtypes", + description="If this field is set, it will forcibly convert the type.", + default=None, + ) class PostgresDTO(IbisDTO): - connection_info: PostgresConnectionUrl | PostgresConnectionInfo = Field(alias="connectionInfo") + connection_info: PostgresConnectionUrl | PostgresConnectionInfo = Field( + alias="connectionInfo" + ) class BigQueryDTO(IbisDTO): diff --git a/ibis-server/app/routers/ibis/bigquery.py b/ibis-server/app/routers/ibis/bigquery.py index 0c97bc756..cc9cf6b61 100644 --- a/ibis-server/app/routers/ibis/bigquery.py +++ b/ibis-server/app/routers/ibis/bigquery.py @@ -8,15 +8,19 @@ from app.model.data_source import DataSource from app.model.dto import BigQueryDTO -router = APIRouter(prefix='/bigquery', tags=['bigquery']) +router = APIRouter(prefix="/bigquery", tags=["bigquery"]) data_source = DataSource.bigquery @router.post("/query") @log_dto -def query(dto: BigQueryDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response: - connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes) +def query( + dto: BigQueryDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False +) -> Response: + connector = Connector( + data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes + ) if dry_run: connector.dry_run(dto.sql) return Response(status_code=204) diff --git a/ibis-server/app/routers/ibis/postgres.py b/ibis-server/app/routers/ibis/postgres.py index 3c73ebe97..d1c2a33a6 100644 --- a/ibis-server/app/routers/ibis/postgres.py +++ b/ibis-server/app/routers/ibis/postgres.py @@ -8,15 +8,19 @@ from app.model.data_source import DataSource from app.model.dto import PostgresDTO -router = APIRouter(prefix='/postgres', tags=['postgres']) +router = APIRouter(prefix="/postgres", tags=["postgres"]) data_source = DataSource.postgres @router.post("/query") @log_dto -def query(dto: PostgresDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response: - connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes) +def query( + dto: PostgresDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False +) -> Response: + connector = Connector( + data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes + ) if dry_run: connector.dry_run(dto.sql) return Response(status_code=204) diff --git a/ibis-server/app/routers/ibis/snowflake.py b/ibis-server/app/routers/ibis/snowflake.py index 9cc4d793a..7df007e02 100644 --- a/ibis-server/app/routers/ibis/snowflake.py +++ b/ibis-server/app/routers/ibis/snowflake.py @@ -8,15 +8,19 @@ from app.model.data_source import DataSource from app.model.dto import SnowflakeDTO -router = APIRouter(prefix='/snowflake', tags=['snowflake']) +router = APIRouter(prefix="/snowflake", tags=["snowflake"]) data_source = DataSource.snowflake @router.post("/query") @log_dto -def query(dto: SnowflakeDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response: - connector = Connector(data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes) +def query( + dto: SnowflakeDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False +) -> Response: + connector = Connector( + data_source, dto.connection_info, dto.manifest_str, dto.column_dtypes + ) if dry_run: connector.dry_run(dto.sql) return Response(status_code=204) diff --git a/ibis-server/tests/routers/ibis/test_bigquery.py b/ibis-server/tests/routers/ibis/test_bigquery.py index ab5d1fe01..b7d6f1318 100644 --- a/ibis-server/tests/routers/ibis/test_bigquery.py +++ b/ibis-server/tests/routers/ibis/test_bigquery.py @@ -21,36 +21,57 @@ class TestBigquery: "columns": [ {"name": "orderkey", "expression": "o_orderkey", "type": "integer"}, {"name": "custkey", "expression": "o_custkey", "type": "integer"}, - {"name": "orderstatus", "expression": "o_orderstatus", "type": "varchar"}, - {"name": "totalprice", "expression": "o_totalprice", "type": "float"}, + { + "name": "orderstatus", + "expression": "o_orderstatus", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "o_totalprice", + "type": "float", + }, {"name": "orderdate", "expression": "o_orderdate", "type": "date"}, - {"name": "order_cust_key", "expression": "concat(o_orderkey, '_', o_custkey)", "type": "varchar"}, - {"name": "timestamp", "expression": "cast('2024-01-01T23:59:59' as timestamp)", "type": "timestamp"}, - {"name": "timestamptz", "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", "type": "timestamp"} + { + "name": "order_cust_key", + "expression": "concat(o_orderkey, '_', o_custkey)", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "cast('2024-01-01T23:59:59' as timestamp)", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", + "type": "timestamp", + }, ], - "primaryKey": "orderkey" + "primaryKey": "orderkey", }, { "name": "Customer", "refSql": "select * from tpch_tiny.customer", "columns": [ {"name": "custkey", "expression": "c_custkey", "type": "integer"}, - {"name": "name", "expression": "c_name", "type": "varchar"} + {"name": "name", "expression": "c_name", "type": "varchar"}, ], - "primaryKey": "custkey" - } - ] + "primaryKey": "custkey", + }, + ], } - manifest_str = base64.b64encode(orjson.dumps(manifest)).decode('utf-8') + manifest_str = base64.b64encode(orjson.dumps(manifest)).decode("utf-8") @staticmethod def get_connection_info(): import os + return { "project_id": os.getenv("TEST_BIG_QUERY_PROJECT_ID"), "dataset_id": "tpch_tiny", - "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON") + "credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON"), } def test_query(self): @@ -60,23 +81,32 @@ def test_query(self): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" ORDER BY orderkey LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" ORDER BY orderkey LIMIT 1', + }, ) assert response.status_code == 200 result = response.json() - assert len(result['columns']) == len(self.manifest['models'][0]['columns']) - assert len(result['data']) == 1 - assert result['data'][0] == [1, 370, 'O', 172799.49, 820540800000, '1_370', 1704153599000, 1704153599000] - assert result['dtypes'] == { - 'orderkey': 'int64', - 'custkey': 'int64', - 'orderstatus': 'object', - 'totalprice': 'float64', - 'orderdate': 'object', - 'order_cust_key': 'object', - 'timestamp': 'datetime64[ns]', - 'timestamptz': 'datetime64[ns, UTC]' + assert len(result["columns"]) == len(self.manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 370, + "O", + 172799.49, + 820540800000, + "1_370", + 1704153599000, + 1704153599000, + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "float64", + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "datetime64[ns]", + "timestamptz": "datetime64[ns, UTC]", } def test_query_with_column_dtypes(self): @@ -91,24 +121,33 @@ def test_query_with_column_dtypes(self): "totalprice": "float", "orderdate": "datetime64", "timestamp": "datetime64", - "timestamptz": "datetime64" - } - } + "timestamptz": "datetime64", + }, + }, ) assert response.status_code == 200 result = response.json() - assert len(result['columns']) == len(self.manifest['models'][0]['columns']) - assert len(result['data']) == 1 - assert result['data'][0] == [1, 370, 'O', 172799.49, '1996-01-02 00:00:00.000000', '1_370', '2024-01-01 23:59:59.000000', '2024-01-01 23:59:59.000000 UTC'] - assert result['dtypes'] == { - 'orderkey': 'int64', - 'custkey': 'int64', - 'orderstatus': 'object', - 'totalprice': 'float64', - 'orderdate': 'object', - 'order_cust_key': 'object', - 'timestamp': 'object', - 'timestamptz': 'object' + assert len(result["columns"]) == len(self.manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 370, + "O", + 172799.49, + "1996-01-02 00:00:00.000000", + "1_370", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000 UTC", + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "float64", + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "object", + "timestamptz": "object", } def test_query_without_manifest(self): @@ -117,46 +156,43 @@ def test_query_without_manifest(self): url="/v2/ibis/bigquery/query", json={ "connectionInfo": connection_info, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'manifestStr'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "manifestStr"] + assert result["detail"][0]["msg"] == "Field required" def test_query_without_sql(self): connection_info = self.get_connection_info() response = client.post( url="/v2/ibis/bigquery/query", - json={ - "connectionInfo": connection_info, - "manifestStr": self.manifest_str - } + json={"connectionInfo": connection_info, "manifestStr": self.manifest_str}, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'sql'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "sql"] + assert result["detail"][0]["msg"] == "Field required" def test_query_without_connection_info(self): response = client.post( url="/v2/ibis/bigquery/query", json={ "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'connectionInfo'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "connectionInfo"] + assert result["detail"][0]["msg"] == "Field required" def test_query_with_dry_run(self): connection_info = self.get_connection_info() @@ -166,8 +202,8 @@ def test_query_with_dry_run(self): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 204 @@ -179,8 +215,8 @@ def test_query_with_dry_run_and_invalid_sql(self): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM X' - } + "sql": "SELECT * FROM X", + }, ) assert response.status_code == 422 assert response.text is not None diff --git a/ibis-server/tests/routers/ibis/test_postgres.py b/ibis-server/tests/routers/ibis/test_postgres.py index 885efd808..004cbd8df 100644 --- a/ibis-server/tests/routers/ibis/test_postgres.py +++ b/ibis-server/tests/routers/ibis/test_postgres.py @@ -14,6 +14,7 @@ def postgres(request) -> PostgresContainer: def file_path(path: str) -> str: import os + return os.path.join(os.path.dirname(__file__), path) import sqlalchemy @@ -23,8 +24,12 @@ def file_path(path: str) -> str: pg.start() psql_url = pg.get_connection_url() engine = sqlalchemy.create_engine(psql_url) - pd.read_parquet(file_path("../../resource/tpch/data/orders.parquet")).to_sql("orders", engine, index=False) - pd.read_parquet(file_path("../../resource/tpch/data/customer.parquet")).to_sql("customer", engine, index=False) + pd.read_parquet(file_path("../../resource/tpch/data/orders.parquet")).to_sql( + "orders", engine, index=False + ) + pd.read_parquet(file_path("../../resource/tpch/data/customer.parquet")).to_sql( + "customer", engine, index=False + ) def stop_pg(): pg.stop() @@ -46,28 +51,48 @@ class TestPostgres: "columns": [ {"name": "orderkey", "expression": "o_orderkey", "type": "integer"}, {"name": "custkey", "expression": "o_custkey", "type": "integer"}, - {"name": "orderstatus", "expression": "o_orderstatus", "type": "varchar"}, - {"name": "totalprice", "expression": "o_totalprice", "type": "float"}, + { + "name": "orderstatus", + "expression": "o_orderstatus", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "o_totalprice", + "type": "float", + }, {"name": "orderdate", "expression": "o_orderdate", "type": "date"}, - {"name": "order_cust_key", "expression": "concat(o_orderkey, '_', o_custkey)", "type": "varchar"}, - {"name": "timestamp", "expression": "cast('2024-01-01T23:59:59' as timestamp)", "type": "timestamp"}, - {"name": "timestamptz", "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", "type": "timestamp"} + { + "name": "order_cust_key", + "expression": "concat(o_orderkey, '_', o_custkey)", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "cast('2024-01-01T23:59:59' as timestamp)", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", + "type": "timestamp", + }, ], - "primaryKey": "orderkey" + "primaryKey": "orderkey", }, { "name": "Customer", "refSql": "select * from public.customer", "columns": [ {"name": "custkey", "expression": "c_custkey", "type": "integer"}, - {"name": "name", "expression": "c_name", "type": "varchar"} + {"name": "name", "expression": "c_name", "type": "varchar"}, ], - "primaryKey": "custkey" - } - ] + "primaryKey": "custkey", + }, + ], } - manifest_str = base64.b64encode(orjson.dumps(manifest)).decode('utf-8') + manifest_str = base64.b64encode(orjson.dumps(manifest)).decode("utf-8") @staticmethod def to_connection_info(pg: PostgresContainer): @@ -76,7 +101,7 @@ def to_connection_info(pg: PostgresContainer): "port": pg.get_exposed_port(pg.port), "user": pg.username, "password": pg.password, - "database": pg.dbname + "database": pg.dbname, } @staticmethod @@ -91,23 +116,32 @@ def test_query(self, postgres: PostgresContainer): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 200 result = response.json() - assert len(result['columns']) == len(self.manifest['models'][0]['columns']) - assert len(result['data']) == 1 - assert result['data'][0] == [1, 370, 'O', '172799.49', 820540800000, '1_370', 1704153599000, 1704153599000] - assert result['dtypes'] == { - 'orderkey': 'int32', - 'custkey': 'int32', - 'orderstatus': 'object', - 'totalprice': 'object', - 'orderdate': 'object', - 'order_cust_key': 'object', - 'timestamp': 'datetime64[ns]', - 'timestamptz': 'datetime64[ns, UTC]' + assert len(result["columns"]) == len(self.manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 370, + "O", + "172799.49", + 820540800000, + "1_370", + 1704153599000, + 1704153599000, + ] + assert result["dtypes"] == { + "orderkey": "int32", + "custkey": "int32", + "orderstatus": "object", + "totalprice": "object", + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "datetime64[ns]", + "timestamptz": "datetime64[ns, UTC]", } def test_query_with_connection_url(self, postgres: PostgresContainer): @@ -115,19 +149,17 @@ def test_query_with_connection_url(self, postgres: PostgresContainer): response = client.post( url="/v2/ibis/postgres/query", json={ - "connectionInfo": { - "connectionUrl": connection_url - }, + "connectionInfo": {"connectionUrl": connection_url}, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 200 result = response.json() - assert len(result['columns']) == len(self.manifest['models'][0]['columns']) - assert len(result['data']) == 1 - assert result['data'][0][0] == 1 - assert result['dtypes'] is not None + assert len(result["columns"]) == len(self.manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0][0] == 1 + assert result["dtypes"] is not None def test_query_with_column_dtypes(self, postgres: PostgresContainer): connection_info = self.to_connection_info(postgres) @@ -141,24 +173,33 @@ def test_query_with_column_dtypes(self, postgres: PostgresContainer): "totalprice": "float", "orderdate": "datetime64", "timestamp": "datetime64", - "timestamptz": "datetime64" - } - } + "timestamptz": "datetime64", + }, + }, ) assert response.status_code == 200 result = response.json() - assert len(result['columns']) == len(self.manifest['models'][0]['columns']) - assert len(result['data']) == 1 - assert result['data'][0] == [1, 370, 'O', 172799.49, '1996-01-02 00:00:00.000000', '1_370', '2024-01-01 23:59:59.000000', '2024-01-01 23:59:59.000000 UTC'] - assert result['dtypes'] == { - 'orderkey': 'int32', - 'custkey': 'int32', - 'orderstatus': 'object', - 'totalprice': 'float64', - 'orderdate': 'object', - 'order_cust_key': 'object', - 'timestamp': 'object', - 'timestamptz': 'object' + assert len(result["columns"]) == len(self.manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 370, + "O", + 172799.49, + "1996-01-02 00:00:00.000000", + "1_370", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000 UTC", + ] + assert result["dtypes"] == { + "orderkey": "int32", + "custkey": "int32", + "orderstatus": "object", + "totalprice": "float64", + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "object", + "timestamptz": "object", } def test_query_without_manifest(self, postgres: PostgresContainer): @@ -167,46 +208,43 @@ def test_query_without_manifest(self, postgres: PostgresContainer): url="/v2/ibis/postgres/query", json={ "connectionInfo": connection_info, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'manifestStr'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "manifestStr"] + assert result["detail"][0]["msg"] == "Field required" def test_query_without_sql(self, postgres: PostgresContainer): connection_info = self.to_connection_info(postgres) response = client.post( url="/v2/ibis/postgres/query", - json={ - "connectionInfo": connection_info, - "manifestStr": self.manifest_str - } + json={"connectionInfo": connection_info, "manifestStr": self.manifest_str}, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'sql'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "sql"] + assert result["detail"][0]["msg"] == "Field required" def test_query_without_connection_info(self): response = client.post( url="/v2/ibis/postgres/query", json={ "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'connectionInfo'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "connectionInfo"] + assert result["detail"][0]["msg"] == "Field required" def test_query_with_dry_run(self, postgres: PostgresContainer): connection_info = self.to_connection_info(postgres) @@ -216,8 +254,8 @@ def test_query_with_dry_run(self, postgres: PostgresContainer): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 204 @@ -229,8 +267,8 @@ def test_query_with_dry_run_and_invalid_sql(self, postgres: PostgresContainer): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM X' - } + "sql": "SELECT * FROM X", + }, ) assert response.status_code == 422 assert response.text is not None diff --git a/ibis-server/tests/routers/ibis/test_snowflake.py b/ibis-server/tests/routers/ibis/test_snowflake.py index 060a27404..482546df4 100644 --- a/ibis-server/tests/routers/ibis/test_snowflake.py +++ b/ibis-server/tests/routers/ibis/test_snowflake.py @@ -22,38 +22,59 @@ class TestSnowflake: "columns": [ {"name": "orderkey", "expression": "O_ORDERKEY", "type": "integer"}, {"name": "custkey", "expression": "O_CUSTKEY", "type": "integer"}, - {"name": "orderstatus", "expression": "O_ORDERSTATUS", "type": "varchar"}, - {"name": "totalprice", "expression": "O_TOTALPRICE", "type": "float"}, + { + "name": "orderstatus", + "expression": "O_ORDERSTATUS", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "O_TOTALPRICE", + "type": "float", + }, {"name": "orderdate", "expression": "O_ORDERDATE", "type": "date"}, - {"name": "order_cust_key", "expression": "concat(O_ORDERKEY, '_', O_CUSTKEY)", "type": "varchar"}, - {"name": "timestamp", "expression": "cast('2024-01-01T23:59:59' as timestamp)", "type": "timestamp"}, - {"name": "timestamptz", "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", "type": "timestamp"} + { + "name": "order_cust_key", + "expression": "concat(O_ORDERKEY, '_', O_CUSTKEY)", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "cast('2024-01-01T23:59:59' as timestamp)", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "cast('2024-01-01T23:59:59' as timestamp with time zone)", + "type": "timestamp", + }, ], - "primaryKey": "orderkey" + "primaryKey": "orderkey", }, { "name": "Customer", "refSql": "select * from TPCH_SF1.CUSTOMER", "columns": [ {"name": "custkey", "expression": "C_CUSTKEY", "type": "integer"}, - {"name": "name", "expression": "C_NAME", "type": "varchar"} + {"name": "name", "expression": "C_NAME", "type": "varchar"}, ], - "primaryKey": "custkey" - } - ] + "primaryKey": "custkey", + }, + ], } - manifest_str = base64.b64encode(orjson.dumps(manifest)).decode('utf-8') + manifest_str = base64.b64encode(orjson.dumps(manifest)).decode("utf-8") @staticmethod def get_connection_info(): import os + return { "user": os.getenv("SNOWFLAKE_USER"), "password": os.getenv("SNOWFLAKE_PASSWORD"), "account": os.getenv("SNOWFLAKE_ACCOUNT"), "database": "SNOWFLAKE_SAMPLE_DATA", - "schema": "TPCH_SF1" + "schema": "TPCH_SF1", } def test_query(self): @@ -63,23 +84,32 @@ def test_query(self): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" ORDER BY "orderkey" LIMIT 1', + }, ) assert response.status_code == 200 result = response.json() - assert len(result['columns']) == len(self.manifest['models'][0]['columns']) - assert len(result['data']) == 1 - assert result['data'][0] == [1, 36901, 'O', 173665.47, 820540800000, '1_36901', 1704153599000, 1704153599000] - assert result['dtypes'] == { - 'orderkey': 'int64', - 'custkey': 'int64', - 'orderstatus': 'object', - 'totalprice': 'object', - 'orderdate': 'object', - 'order_cust_key': 'object', - 'timestamp': 'datetime64[ns]', - 'timestamptz': 'datetime64[ns, UTC]' + assert len(result["columns"]) == len(self.manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 36901, + "O", + 173665.47, + 820540800000, + "1_36901", + 1704153599000, + 1704153599000, + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "object", + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "datetime64[ns]", + "timestamptz": "datetime64[ns, UTC]", } def test_query_with_column_dtypes(self): @@ -94,24 +124,33 @@ def test_query_with_column_dtypes(self): "totalprice": "float", "orderdate": "datetime64", "timestamp": "datetime64", - "timestamptz": "datetime64" - } - } + "timestamptz": "datetime64", + }, + }, ) assert response.status_code == 200 result = response.json() - assert len(result['columns']) == len(self.manifest['models'][0]['columns']) - assert len(result['data']) == 1 - assert result['data'][0] == [1, 36901, 'O', 173665.47, '1996-01-02 00:00:00.000000', '1_36901', '2024-01-01 23:59:59.000000', '2024-01-01 23:59:59.000000 UTC'] - assert result['dtypes'] == { - 'orderkey': 'int64', - 'custkey': 'int64', - 'orderstatus': 'object', - 'totalprice': 'float64', - 'orderdate': 'object', - 'order_cust_key': 'object', - 'timestamp': 'object', - 'timestamptz': 'object' + assert len(result["columns"]) == len(self.manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 36901, + "O", + 173665.47, + "1996-01-02 00:00:00.000000", + "1_36901", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000 UTC", + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "float64", + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "object", + "timestamptz": "object", } def test_query_without_manifest(self): @@ -120,46 +159,43 @@ def test_query_without_manifest(self): url="/v2/ibis/snowflake/query", json={ "connectionInfo": connection_info, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'manifestStr'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "manifestStr"] + assert result["detail"][0]["msg"] == "Field required" def test_query_without_sql(self): connection_info = self.get_connection_info() response = client.post( url="/v2/ibis/snowflake/query", - json={ - "connectionInfo": connection_info, - "manifestStr": self.manifest_str - } + json={"connectionInfo": connection_info, "manifestStr": self.manifest_str}, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'sql'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "sql"] + assert result["detail"][0]["msg"] == "Field required" def test_query_without_connection_info(self): response = client.post( url="/v2/ibis/snowflake/query", json={ "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 422 result = response.json() - assert result['detail'][0] is not None - assert result['detail'][0]['type'] == 'missing' - assert result['detail'][0]['loc'] == ['body', 'connectionInfo'] - assert result['detail'][0]['msg'] == 'Field required' + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "connectionInfo"] + assert result["detail"][0]["msg"] == "Field required" def test_query_with_dry_run(self): connection_info = self.get_connection_info() @@ -169,8 +205,8 @@ def test_query_with_dry_run(self): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1' - } + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, ) assert response.status_code == 204 @@ -182,8 +218,8 @@ def test_query_with_dry_run_and_invalid_sql(self): json={ "connectionInfo": connection_info, "manifestStr": self.manifest_str, - "sql": 'SELECT * FROM X' - } + "sql": "SELECT * FROM X", + }, ) assert response.status_code == 422 assert response.text is not None