diff --git a/.changes/unreleased/Under the Hood-20230508-222313.yaml b/.changes/unreleased/Under the Hood-20230508-222313.yaml
new file mode 100644
index 000000000..29a628119
--- /dev/null
+++ b/.changes/unreleased/Under the Hood-20230508-222313.yaml
@@ -0,0 +1,6 @@
+kind: Under the Hood
+body: Convert information into dict
+time: 2023-05-08T22:23:13.704302+02:00
+custom:
+  Author: Fokko
+  Issue: "751"
diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 2864c4f30..6a571ec4b 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -60,6 +60,14 @@ class SparkConfig(AdapterConfig):
     merge_update_columns: Optional[str] = None
 
 
+@dataclass(frozen=True)
+class RelationInfo:
+    table_schema: str
+    table_name: str
+    columns: List[Tuple[str, str]]
+    properties: Dict[str, str]
+
+
 class SparkAdapter(SQLAdapter):
     COLUMN_NAMES = (
         "table_database",
@@ -81,9 +89,7 @@ class SparkAdapter(SQLAdapter):
         "stats:rows:description",
         "stats:rows:include",
     )
-    INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE)
-    INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE)
-    INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE)
+    INFORMATION_COLUMN_REGEX = re.compile(r"[ | ]* \|-- (.*)\: (.*) \(nullable = (.*)\)")
 
     HUDI_METADATA_COLUMNS = [
         "_hoodie_commit_time",
@@ -102,7 +108,6 @@ class SparkAdapter(SQLAdapter):
     }
 
     Relation: TypeAlias = SparkRelation
-    RelationInfo = Tuple[str, str, str]
     Column: TypeAlias = SparkColumn
     ConnectionManager: TypeAlias = SparkConnectionManager
     AdapterSpecificConfigs: TypeAlias = SparkConfig
@@ -138,13 +143,54 @@ def quote(self, identifier: str) -> str:  # type: ignore
 
     def _get_relation_information(self, row: agate.Row) -> RelationInfo:
         """relation info was fetched with SHOW TABLES EXTENDED"""
         try:
-            _schema, name, _, information = row
+            table_properties = {}
+            columns = []
+            _schema, name, _, information_blob = row
+            for line in information_blob.split("\n"):
+                if line:
+                    if " |--" in line:
+                        # A column
+                        match = self.INFORMATION_COLUMN_REGEX.match(line)
+                        if match:
+                            columns.append((match[1], match[2]))
+                        else:
+                            logger.warning(f"Could not parse column: {line}")
+                    else:
+                        # A property
+                        parts = line.split(": ", maxsplit=2)
+                        if len(parts) == 2:
+                            table_properties[parts[0]] = parts[1]
+                        else:
+                            logger.warning(f"Found invalid property: {line}")
+
         except ValueError:
             raise dbt.exceptions.DbtRuntimeError(
                 f'Invalid value from "show tables extended ...", got {len(row)} values, expected 4'
             )
-        return _schema, name, information
+        return RelationInfo(_schema, name, columns, table_properties)
+
+    def _parse_describe_table_extended(
+        self, table_results: agate.Table
+    ) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
+        # Wrap it in an iter, so we continue reading the properties from where we stopped reading columns
+        table_results_itr = iter(table_results)
+
+        # First the columns
+        columns = []
+        for info_row in table_results_itr:
+            if info_row[0] is None or info_row[0] == "" or info_row[0].startswith("#"):
+                break
+            columns.append((info_row[0], str(info_row[1])))
+
+        # Next all the properties
+        table_properties = {}
+        for info_row in table_results_itr:
+            info_type, info_value = info_row[:2]
+            if info_type is not None and not info_type.startswith("#") and info_type != "":
+                table_properties[info_type] = str(info_value)
+
+        return columns, table_properties
 
     def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo:
         """Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement"""
@@ -164,13 +210,8 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn
             logger.debug(f"Error while retrieving information about {table_name}: {e.msg}")
             table_results = AttrDict()
 
-        information = ""
-        for info_row in table_results:
-            info_type, info_value, _ = info_row
-            if not info_type.startswith("#"):
-                information += f"{info_type}: {info_value}\n"
-
-        return _schema, name, information
+        columns, table_properties = self._parse_describe_table_extended(table_results)
+        return RelationInfo(_schema, name, columns, table_properties)
 
     def _build_spark_relation_list(
         self,
@@ -178,27 +219,28 @@ def _build_spark_relation_list(
         relation_info_func: Callable[[agate.Row], RelationInfo],
    ) -> List[BaseRelation]:
         """Aggregate relations with format metadata included."""
-        relations = []
+        relations: List[BaseRelation] = []
         for row in row_list:
-            _schema, name, information = relation_info_func(row)
+            relation = relation_info_func(row)
 
             rel_type: RelationType = (
-                RelationType.View if "Type: VIEW" in information else RelationType.Table
+                RelationType.View
+                if relation.properties.get("Type") == "VIEW"
+                else RelationType.Table
             )
-            is_delta: bool = "Provider: delta" in information
-            is_hudi: bool = "Provider: hudi" in information
-            is_iceberg: bool = "Provider: iceberg" in information
-
-            relation: BaseRelation = self.Relation.create(
-                schema=_schema,
-                identifier=name,
-                type=rel_type,
-                information=information,
-                is_delta=is_delta,
-                is_iceberg=is_iceberg,
-                is_hudi=is_hudi,
+
+            relations.append(
+                self.Relation.create(
+                    schema=relation.table_schema,
+                    identifier=relation.table_name,
+                    type=rel_type,
+                    is_delta=relation.properties.get("Provider") == "delta",
+                    is_iceberg=relation.properties.get("Provider") == "iceberg",
+                    is_hudi=relation.properties.get("Provider") == "hudi",
+                    columns=relation.columns,
+                    properties=relation.properties,
+                )
             )
-            relations.append(relation)
 
         return relations
 
@@ -248,20 +290,30 @@ def get_relation(self, database: str, schema: str, identifier: str) -> Optional[
 
         return super().get_relation(database, schema, identifier)
 
-    def parse_describe_extended(
-        self, relation: BaseRelation, raw_rows: AttrDict
-    ) -> List[SparkColumn]:
-        # Convert the Row to a dict
-        dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows]
-        # Find the separator between the rows and the metadata provided
-        # by the DESCRIBE TABLE EXTENDED statement
-        pos = self.find_table_information_separator(dict_rows)
-
-        # Remove rows that start with a hash, they are comments
-        rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")]
-        metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]}
+    def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
+        assert isinstance(relation, SparkRelation)
+        if relation.columns is not None and len(relation.columns) > 0:
+            columns = relation.columns
+            properties = relation.properties
+        else:
+            try:
+                describe_extended_result = self.execute_macro(
+                    GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation}
+                )
+                columns, properties = self._parse_describe_table_extended(describe_extended_result)
+            except dbt.exceptions.DbtRuntimeError as e:
+                # spark would throw error when table doesn't exist, where other
+                # CDW would just return and empty list, normalizing the behavior here
+                errmsg = getattr(e, "msg", "")
+                found_msgs = (msg in errmsg for msg in TABLE_OR_VIEW_NOT_FOUND_MESSAGES)
+                if any(found_msgs):
+                    columns = []
+                    properties = {}
+                else:
+                    raise e
 
-        raw_table_stats = metadata.get(KEY_TABLE_STATISTICS)
+        # Convert the Row to a dict
+        raw_table_stats = properties.get(KEY_TABLE_STATISTICS)
         table_stats = SparkColumn.convert_table_stats(raw_table_stats)
         return [
             SparkColumn(
@@ -269,59 +321,23 @@ def parse_describe_extended(
                 table_schema=relation.schema,
                 table_name=relation.name,
                 table_type=relation.type,
-                table_owner=str(metadata.get(KEY_TABLE_OWNER)),
+                table_owner=properties.get(KEY_TABLE_OWNER, ""),
                 table_stats=table_stats,
-                column=column["col_name"],
+                column=column_name,
                 column_index=idx,
-                dtype=column["data_type"],
+                dtype=column_type,
             )
-            for idx, column in enumerate(rows)
+            for idx, (column_name, column_type) in enumerate(columns)
+            if column_name not in self.HUDI_METADATA_COLUMNS
         ]
 
-    @staticmethod
-    def find_table_information_separator(rows: List[dict]) -> int:
-        pos = 0
-        for row in rows:
-            if not row["col_name"] or row["col_name"].startswith("#"):
-                break
-            pos += 1
-        return pos
-
-    def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]:
-        columns = []
-        try:
-            rows: AttrDict = self.execute_macro(
-                GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation}
-            )
-            columns = self.parse_describe_extended(relation, rows)
-        except dbt.exceptions.DbtRuntimeError as e:
-            # spark would throw error when table doesn't exist, where other
-            # CDW would just return and empty list, normalizing the behavior here
-            errmsg = getattr(e, "msg", "")
-            found_msgs = (msg in errmsg for msg in TABLE_OR_VIEW_NOT_FOUND_MESSAGES)
-            if any(found_msgs):
-                pass
-            else:
-                raise e
-
-        # strip hudi metadata columns.
-        columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS]
-        return columns
-
-    def parse_columns_from_information(self, relation: BaseRelation) -> List[SparkColumn]:
-        if hasattr(relation, "information"):
-            information = relation.information or ""
-        else:
-            information = ""
-        owner_match = re.findall(self.INFORMATION_OWNER_REGEX, information)
-        owner = owner_match[0] if owner_match else None
-        matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, information)
+    def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]:
+        owner = relation.properties.get(KEY_TABLE_OWNER, "")
         columns = []
-        stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, information)
-        raw_table_stats = stats_match[0] if stats_match else None
-        table_stats = SparkColumn.convert_table_stats(raw_table_stats)
-        for match_num, match in enumerate(matches):
-            column_name, column_type, nullable = match.groups()
+        table_stats = SparkColumn.convert_table_stats(
+            relation.properties.get(KEY_TABLE_STATISTICS)
+        )
+        for match_num, (column_name, column_type) in enumerate(relation.columns):
             column = SparkColumn(
                 table_database=None,
                 table_schema=relation.schema,
@@ -337,7 +353,7 @@ def parse_columns_from_information(self, relation: BaseRelation) -> List[SparkCo
         return columns
 
     def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, Any]]:
-        columns = self.parse_columns_from_information(relation)
+        columns = self.parse_columns_from_information(relation)  # type: ignore
 
         for column in columns:
             # convert SparkColumns into catalog dicts
@@ -410,13 +426,15 @@ def get_rows_different_sql(
         """
         # This method only really exists for test reasons.
         names: List[str]
-        if column_names is None:
+        if not column_names:
             columns = self.get_columns_in_relation(relation_a)
             names = sorted((self.quote(c.name) for c in columns))
         else:
             names = sorted((self.quote(n) for n in column_names))
 
         columns_csv = ", ".join(names)
+        assert columns_csv, f"Could not find columns for: {relation_a}"
+
         sql = COLUMNS_EQUAL_SQL.format(
             columns=columns_csv,
             relation_a=str(relation_a),
diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py
index e80f2623f..453af51af 100644
--- a/dbt/adapters/spark/relation.py
+++ b/dbt/adapters/spark/relation.py
@@ -1,4 +1,4 @@
-from typing import Optional, TypeVar
+from typing import Optional, TypeVar, List, Tuple, Dict
 from dataclasses import dataclass, field
 
 from dbt.adapters.base.relation import BaseRelation, Policy
@@ -33,8 +33,8 @@ class SparkRelation(BaseRelation):
     is_delta: Optional[bool] = None
     is_hudi: Optional[bool] = None
     is_iceberg: Optional[bool] = None
-    # TODO: make this a dict everywhere
-    information: Optional[str] = None
+    columns: List[Tuple[str, str]] = field(default_factory=list)
+    properties: Dict[str, str] = field(default_factory=dict)
 
     def __post_init__(self) -> None:
         if self.database != self.schema and self.database:
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 1eb818241..92f7ac709 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -1,11 +1,13 @@
 import unittest
 from unittest import mock
 
+import agate
 import dbt.flags as flags
 from dbt.exceptions import DbtRuntimeError
 from agate import Row
 from pyhive import hive
 from dbt.adapters.spark import SparkAdapter, SparkRelation
+from dbt.adapters.spark.impl import RelationInfo, KEY_TABLE_OWNER
 from .utils import config_from_parts_or_dicts
 
 
@@ -322,10 +324,15 @@ def test_parse_relation(self):
         input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
         config = self._get_target_http(self.project_cfg)
-        rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
-        self.assertEqual(len(rows), 4)
+        adapter = SparkAdapter(config)
+        columns, properties = adapter._parse_describe_table_extended(input_cols)
+        relation_info = adapter._build_spark_relation_list(
+            columns, lambda a: RelationInfo(relation.schema, relation.name, columns, properties)
+        )
+        columns = adapter.get_columns_in_relation(relation_info[0])
+        self.assertEqual(len(columns), 4)
 
         self.assertEqual(
-            rows[0].to_column_dict(omit_none=False),
+            columns[0].to_column_dict(omit_none=False),
             {
                 "table_database": None,
                 "table_schema": relation.schema,
@@ -342,7 +349,7 @@
         )
 
         self.assertEqual(
-            rows[1].to_column_dict(omit_none=False),
+            columns[1].to_column_dict(omit_none=False),
             {
                 "table_database": None,
                 "table_schema": relation.schema,
@@ -359,7 +366,7 @@
         )
 
         self.assertEqual(
-            rows[2].to_column_dict(omit_none=False),
+            columns[2].to_column_dict(omit_none=False),
             {
                 "table_database": None,
                 "table_schema": relation.schema,
@@ -376,7 +383,7 @@
         )
 
         self.assertEqual(
-            rows[3].to_column_dict(omit_none=False),
+            columns[3].to_column_dict(omit_none=False),
             {
                 "table_database": None,
                 "table_schema": relation.schema,
@@ -408,12 +415,10 @@ def test_parse_relation_with_integer_owner(self):
             ("Owner", 1234),
         ]
 
-        input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
-
         config = self._get_target_http(self.project_cfg)
-        rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+        _, properties = SparkAdapter(config)._parse_describe_table_extended(plain_rows)
 
-        self.assertEqual(rows[0].to_column_dict().get("table_owner"), "1234")
+        self.assertEqual(properties.get(KEY_TABLE_OWNER), "1234")
 
     def test_parse_relation_with_statistics(self):
         self.maxDiff = None
@@ -444,10 +449,16 @@ def test_parse_relation_with_statistics(self):
             ("Partition Provider", "Catalog"),
         ]
-        input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
-
         config = self._get_target_http(self.project_cfg)
-        rows = SparkAdapter(config).parse_describe_extended(relation, input_cols)
+        columns, properties = SparkAdapter(config)._parse_describe_table_extended(plain_rows)
+        spark_relation = SparkRelation.create(
+            schema=relation.schema,
+            identifier=relation.name,
+            type=rel_type,
+            columns=columns,
+            properties=properties,
+        )
+        rows = SparkAdapter(config).parse_columns_from_information(spark_relation)
 
         self.assertEqual(len(rows), 1)
         self.assertEqual(
             rows[0].to_column_dict(omit_none=False),
@@ -552,19 +563,37 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
             " |-- struct_col: struct (nullable = true)\n"
             " |    |-- struct_inner_col: string (nullable = true)\n"
         )
-        relation = SparkRelation.create(
-            schema="default_schema", identifier="mytable", type=rel_type, information=information
+        row = agate.MappedSequence(("default_schema", "mytable", False, information))
+        config = self._get_target_http(self.project_cfg)
+        adapter = SparkAdapter(config)
+
+        tables = adapter._build_spark_relation_list(
+            row_list=[row],
+            relation_info_func=adapter._get_relation_information,
         )
+        self.assertEqual(len(tables), 1)
 
-        config = self._get_target_http(self.project_cfg)
-        columns = SparkAdapter(config).parse_columns_from_information(relation)
-        self.assertEqual(len(columns), 4)
+        table = tables[0]
+
+        assert isinstance(table, SparkRelation)
+
+        columns = adapter.get_columns_in_relation(
+            SparkRelation.create(
+                type=rel_type,
+                schema="default_schema",
+                identifier="mytable",
+                columns=table.columns,
+                properties=table.properties,
+            )
+        )
+
+        self.assertEqual(len(columns), 5)
         self.assertEqual(
             columns[0].to_column_dict(omit_none=False),
             {
                 "table_database": None,
-                "table_schema": relation.schema,
-                "table_name": relation.name,
+                "table_schema": "default_schema",
+                "table_name": "mytable",
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "col1",
@@ -584,8 +613,8 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self)
             columns[3].to_column_dict(omit_none=False),
             {
                 "table_database": None,
-                "table_schema": relation.schema,
-                "table_name": relation.name,
+                "table_schema": "default_schema",
+                "table_name": "mytable",
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "struct_col",
@@ -637,19 +666,38 @@ def test_parse_columns_from_information_with_view_type(self):
             " |-- struct_col: struct (nullable = true)\n"
             " |    |-- struct_inner_col: string (nullable = true)\n"
         )
-        relation = SparkRelation.create(
-            schema="default_schema", identifier="myview", type=rel_type, information=information
-        )
+        row = agate.MappedSequence(("default_schema", "myview", False, information))
 
         config = self._get_target_http(self.project_cfg)
-        columns = SparkAdapter(config).parse_columns_from_information(relation)
-        self.assertEqual(len(columns), 4)
+        adapter = SparkAdapter(config)
+
+        tables = adapter._build_spark_relation_list(
+            row_list=[row],
+            relation_info_func=adapter._get_relation_information,
+        )
+        self.assertEqual(len(tables), 1)
+
+        table = tables[0]
+
+        assert isinstance(table, SparkRelation)
+
+        columns = adapter.get_columns_in_relation(
+            SparkRelation.create(
+                type=rel_type,
+                schema="default_schema",
+                identifier="myview",
+                columns=table.columns,
+                properties=table.properties,
+            )
+        )
+
+        self.assertEqual(len(columns), 5)
         self.assertEqual(
             columns[1].to_column_dict(omit_none=False),
             {
                 "table_database": None,
-                "table_schema": relation.schema,
-                "table_name": relation.name,
+                "table_schema": "default_schema",
+                "table_name": "myview",
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "col2",
@@ -665,8 +713,8 @@
             columns[3].to_column_dict(omit_none=False),
             {
                 "table_database": None,
-                "table_schema": relation.schema,
-                "table_name": relation.name,
+                "table_schema": "default_schema",
+                "table_name": "myview",
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "struct_col",
@@ -703,19 +751,38 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel
             " |-- struct_col: struct (nullable = true)\n"
             " |    |-- struct_inner_col: string (nullable = true)\n"
         )
-        relation = SparkRelation.create(
-            schema="default_schema", identifier="mytable", type=rel_type, information=information
-        )
+        row = agate.MappedSequence(("default_schema", "mytable", False, information))
 
         config = self._get_target_http(self.project_cfg)
-        columns = SparkAdapter(config).parse_columns_from_information(relation)
-        self.assertEqual(len(columns), 4)
+        adapter = SparkAdapter(config)
+
+        tables = adapter._build_spark_relation_list(
+            row_list=[row],
+            relation_info_func=adapter._get_relation_information,
+        )
+        self.assertEqual(len(tables), 1)
+
+        table = tables[0]
+
+        assert isinstance(table, SparkRelation)
+
+        columns = adapter.get_columns_in_relation(
+            SparkRelation.create(
+                type=rel_type,
+                schema="default_schema",
+                identifier="mytable",
+                columns=table.columns,
+                properties=table.properties,
+            )
+        )
+
+        self.assertEqual(len(columns), 5)
         self.assertEqual(
             columns[2].to_column_dict(omit_none=False),
             {
                 "table_database": None,
-                "table_schema": relation.schema,
-                "table_name": relation.name,
+                "table_schema": "default_schema",
+                "table_name": "mytable",
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "dt",
@@ -739,8 +806,8 @@
             columns[3].to_column_dict(omit_none=False),
             {
                 "table_database": None,
-                "table_schema": relation.schema,
-                "table_name": relation.name,
+                "table_schema": "default_schema",
+                "table_name": "mytable",
                 "table_type": rel_type,
                 "table_owner": "root",
                 "column": "struct_col",
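
For reference, a minimal standalone sketch of the parsing approach introduced above. The column regex and the "key: value" property split are taken from the new SparkAdapter code in this diff; the parse_information helper and the sample information blob are illustrative only (modelled on the fixtures in tests/unit/test_adapter.py) and are not part of the patch.

import re
from typing import Dict, List, Tuple

# Same pattern as the new SparkAdapter.INFORMATION_COLUMN_REGEX above
INFORMATION_COLUMN_REGEX = re.compile(r"[ | ]* \|-- (.*)\: (.*) \(nullable = (.*)\)")


def parse_information(information_blob: str) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
    """Split a SHOW TABLE EXTENDED 'information' blob into (columns, properties)."""
    columns: List[Tuple[str, str]] = []
    properties: Dict[str, str] = {}
    for line in information_blob.split("\n"):
        if not line:
            continue
        if " |--" in line:
            # Column lines look like " |-- col1: decimal(22,0) (nullable = true)"
            match = INFORMATION_COLUMN_REGEX.match(line)
            if match:
                columns.append((match[1], match[2]))
        else:
            # Everything else is treated as a "Key: value" table property
            parts = line.split(": ", maxsplit=2)
            if len(parts) == 2:
                properties[parts[0]] = parts[1]
    return columns, properties


information = (
    "Database: default_schema\n"
    "Table: mytable\n"
    "Owner: root\n"
    "Type: MANAGED\n"
    "Provider: delta\n"
    "Schema: root\n"
    " |-- col1: decimal(22,0) (nullable = true)\n"
    " |-- struct_col: struct (nullable = true)\n"
)

columns, properties = parse_information(information)
print(columns)                 # [('col1', 'decimal(22,0)'), ('struct_col', 'struct')]
print(properties["Provider"])  # delta

Because the relation built from this dict already carries the parsed columns, the new get_columns_in_relation can answer from that cached metadata and only falls back to running DESCRIBE TABLE EXTENDED via the macro when the columns are not yet known.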