Skip to content

Commit

Permalink
fixes column_schemas
Browse files: browse the repository at this point in the history
  • Loading branch information
rudolfix committed Nov 2, 2024
1 parent 8570c31 commit 779d054
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRe
class SupportsReadableRelation(Protocol):
"""A readable relation retrieved from a destination that supports it"""

schema_columns: TTableSchemaColumns
columns_schema: TTableSchemaColumns
"""Known dlt table columns for this relation"""

def df(self, chunk_size: int = None) -> Optional[DataFrame]:
Expand Down
25 changes: 16 additions & 9 deletions dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,31 @@ def query(self) -> Any:
return f"SELECT {maybe_limit_clause_1} {selector} FROM {table_name} {maybe_limit_clause_2}"

@property
def computed_schema_columns(self) -> TTableSchemaColumns:
def columns_schema(self) -> TTableSchemaColumns:
return self.compute_columns_schema()

@columns_schema.setter
def columns_schema(self, new_value: TTableSchemaColumns) -> None:
raise NotImplementedError("columns schema in ReadableDBAPIRelation can only be computed")

def compute_columns_schema(self) -> TTableSchemaColumns:
"""provide schema columns for the cursor, may be filtered by selected columns"""

schema_columns = (
columns_schema = (
self.schema.tables.get(self._table_name, {}).get("columns", {}) if self.schema else {}
)

if not schema_columns:
if not columns_schema:
return None
if not self._selected_columns:
return schema_columns
return columns_schema

filtered_columns: TTableSchemaColumns = {}
for sc in self._selected_columns:
sc = self.schema.naming.normalize_path(sc)
if sc not in schema_columns.keys():
if sc not in columns_schema.keys():
raise ReadableRelationUnknownColumnException(sc)
filtered_columns[sc] = schema_columns[sc]
filtered_columns[sc] = columns_schema[sc]

return filtered_columns

Expand All @@ -146,8 +153,8 @@ def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]:
if hasattr(self.sql_client, "_conn") and hasattr(self.sql_client._conn, "autocommit"):
self.sql_client._conn.autocommit = False
with client.execute_query(self.query) as cursor:
if schema_columns := self.computed_schema_columns:
cursor.schema_columns = schema_columns
if columns_schema := self.columns_schema:
cursor.columns_schema = columns_schema
yield cursor

def _wrap_iter(self, func_name: str) -> Any:
Expand Down Expand Up @@ -191,7 +198,7 @@ def select(self, *columns: str) -> "ReadableDBAPIRelation":
rel._selected_columns = columns
# NOTE: the line below will ensure that no unknown columns are selected if
# schema is known
rel.computed_schema_columns
rel.compute_columns_schema()
return rel

def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation":
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/sqlalchemy/db_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self, curr: sa.engine.CursorResult) -> None:
self.fetchone = curr.fetchone # type: ignore[assignment]
self.fetchmany = curr.fetchmany # type: ignore[assignment]

self.set_default_schema_columns()
self._set_default_schema_columns()

def _get_columns(self) -> List[str]:
try:
Expand Down
10 changes: 5 additions & 5 deletions dlt/destinations/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def __init__(self, curr: DBApiCursor) -> None:
self.fetchmany = curr.fetchmany # type: ignore
self.fetchone = curr.fetchone # type: ignore

self.set_default_schema_columns()
self._set_default_schema_columns()

def __getattr__(self, name: str) -> Any:
return getattr(self.native_cursor, name)
Expand All @@ -342,8 +342,8 @@ def _get_columns(self) -> List[str]:
return [c[0] for c in self.native_cursor.description]
return []

def set_default_schema_columns(self) -> None:
self.schema_columns = cast(
def _set_default_schema_columns(self) -> None:
self.columns_schema = cast(
TTableSchemaColumns, {c: {"name": c, "nullable": True} for c in self._get_columns()}
)

Expand Down Expand Up @@ -397,11 +397,11 @@ def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:

if not chunk_size:
result = self.fetchall()
yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC")
yield row_tuples_to_arrow(result, caps, self.columns_schema, tz="UTC")
return

for result in self.iter_fetch(chunk_size=chunk_size):
yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC")
yield row_tuples_to_arrow(result, caps, self.columns_schema, tz="UTC")


def raise_database_error(f: TFun) -> TFun:
Expand Down

0 comments on commit 779d054

Please sign in to comment.