From 779d0540922a0657e15bdfba8f0c186c817ea624 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 2 Nov 2024 17:40:07 +0100 Subject: [PATCH] fixes column_schemas --- dlt/common/destination/reference.py | 2 +- dlt/destinations/dataset.py | 25 ++++++++++++------- .../impl/sqlalchemy/db_api_client.py | 2 +- dlt/destinations/sql_client.py | 10 ++++---- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 8847dee541..1f3a6c1120 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -475,7 +475,7 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRe class SupportsReadableRelation(Protocol): """A readable relation retrieved from a destination that supports it""" - schema_columns: TTableSchemaColumns + columns_schema: TTableSchemaColumns """Known dlt table columns for this relation""" def df(self, chunk_size: int = None) -> Optional[DataFrame]: diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index 2adc177401..cffdc0f059 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -115,24 +115,31 @@ def query(self) -> Any: return f"SELECT {maybe_limit_clause_1} {selector} FROM {table_name} {maybe_limit_clause_2}" @property - def computed_schema_columns(self) -> TTableSchemaColumns: + def columns_schema(self) -> TTableSchemaColumns: + return self.compute_columns_schema() + + @columns_schema.setter + def columns_schema(self, new_value: TTableSchemaColumns) -> None: + raise NotImplementedError("columns schema in ReadableDBAPIRelation can only be computed") + + def compute_columns_schema(self) -> TTableSchemaColumns: """provide schema columns for the cursor, may be filtered by selected columns""" - schema_columns = ( + columns_schema = ( self.schema.tables.get(self._table_name, {}).get("columns", {}) if self.schema else {} ) - if not schema_columns: + if not columns_schema: return None if not self._selected_columns: - return schema_columns + return columns_schema filtered_columns: TTableSchemaColumns = {} for sc in self._selected_columns: sc = self.schema.naming.normalize_path(sc) - if sc not in schema_columns.keys(): + if sc not in columns_schema.keys(): raise ReadableRelationUnknownColumnException(sc) - filtered_columns[sc] = schema_columns[sc] + filtered_columns[sc] = columns_schema[sc] return filtered_columns @@ -146,8 +153,8 @@ def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]: if hasattr(self.sql_client, "_conn") and hasattr(self.sql_client._conn, "autocommit"): self.sql_client._conn.autocommit = False with client.execute_query(self.query) as cursor: - if schema_columns := self.computed_schema_columns: - cursor.schema_columns = schema_columns + if columns_schema := self.columns_schema: + cursor.columns_schema = columns_schema yield cursor def _wrap_iter(self, func_name: str) -> Any: @@ -191,7 +198,7 @@ def select(self, *columns: str) -> "ReadableDBAPIRelation": rel._selected_columns = columns # NOTE: the line below will ensure that no unknown columns are selected if # schema is known - rel.computed_schema_columns + rel.compute_columns_schema() return rel def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation": diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index a407e53d70..6f3ff065bf 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -80,7 +80,7 @@ def __init__(self, curr: sa.engine.CursorResult) -> None: self.fetchone = curr.fetchone # type: ignore[assignment] self.fetchmany = curr.fetchmany # type: ignore[assignment] - self.set_default_schema_columns() + self._set_default_schema_columns() def _get_columns(self) -> List[str]: try: diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index cbce930d98..7c2029dd7b 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -332,7 +332,7 @@ def __init__(self, curr: DBApiCursor) -> None: self.fetchmany = curr.fetchmany # type: ignore self.fetchone = curr.fetchone # type: ignore - self.set_default_schema_columns() + self._set_default_schema_columns() def __getattr__(self, name: str) -> Any: return getattr(self.native_cursor, name) @@ -342,8 +342,8 @@ def _get_columns(self) -> List[str]: return [c[0] for c in self.native_cursor.description] return [] - def set_default_schema_columns(self) -> None: - self.schema_columns = cast( + def _set_default_schema_columns(self) -> None: + self.columns_schema = cast( TTableSchemaColumns, {c: {"name": c, "nullable": True} for c in self._get_columns()} ) @@ -397,11 +397,11 @@ def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: if not chunk_size: result = self.fetchall() - yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC") + yield row_tuples_to_arrow(result, caps, self.columns_schema, tz="UTC") return for result in self.iter_fetch(chunk_size=chunk_size): - yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC") + yield row_tuples_to_arrow(result, caps, self.columns_schema, tz="UTC") def raise_database_error(f: TFun) -> TFun: