Skip to content

Commit

Permalink
uu
Browse files Browse the repository at this point in the history
  • Loading branch information
dbittenbender committed Jun 10, 2024
1 parent d504d06 commit 3eb118c
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def get_key_sql_statement(self, schema_name, table_name) -> Any:
return None

@abc.abstractmethod
def get_old_view_def_sql_statement(self, schema_name, table_name) -> Any:
    """Return dialect SQL that fetches a view's definition (legacy source).

    Abstract hook for concrete extractors; the base stub returns None.
    Removed a stale duplicate ``def`` line (diff residue) that left the
    first ``def`` without a body, which is a syntax error.
    """
    return None

@abc.abstractmethod
def get_new_view_def_sql_statement(self, schema_name, view_name) -> Any:
    """Return dialect SQL that fetches a view's definition (preferred source).

    Abstract hook for concrete extractors; the base stub returns None.
    """
    return None

def init(self, conf: ConfigTree) -> None:
Expand All @@ -66,6 +70,10 @@ def init(self, conf: ConfigTree) -> None:

self._database = conf.get_string(BasePostgresMetadataExtractor.DATABASE_KEY, default='postgres')

# where_clause_suffix = conf.get_string(BasePostgresMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY)
# if where_clause_suffix and where_clause_suffix != '':
# where_clause_suffix = f'WHERE {where_clause_suffix}'

self.sql_stmt = self.get_sql_statement(
use_catalog_as_cluster_name=conf.get_bool(BasePostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME),
where_clause_suffix=conf.get_string(BasePostgresMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY),
Expand Down Expand Up @@ -150,37 +158,43 @@ def _get_extract_iter(self) -> Iterator[TableMetadata]:
table_metadata = TableMetadata(self._database, last_row['cluster'],
last_row['schema'],
last_row['name'],
last_row['description'],
last_row['description'] if 'description' in last_row else None,
columns,
is_view=last_row['is_view'])
yield table_metadata

if bool(last_row['is_view']) == True:
results = self.connection.execute(self.get_view_def_sql_statement(schema_name=last_row['schema'], view_name=last_row['name']))
view_row = results.fetchone()
LOGGER.info(f"view_row={view_row}")
if view_row:
view_def = view_row[0]
if view_def:
qp = PostgreSQLQueryProcessor()
try:
qp.set_query(view_def)

qp.process_query()

LOGGER.info(f"View table: {qp.tables}")

if qp.tables is not None and len(qp.tables) > 0:
for table in qp.tables:
table_key = TableMetadata.TABLE_KEY_FORMAT.format(db=self._database, cluster=last_row['cluster'], schema=table[0].lower(), tbl=table[1].lower())
LOGGER.info(f"Table Lineage: table={table_key} downstream={table_metadata._get_table_key()}")
yield TableLineage(
table_key=table_key,
downstream_deps=[table_metadata._get_table_key()]
)

except QuerySyntaxError as e:
LOGGER.exception(f"Error parsing the query for {last_row['schema']}.{last_row['name']}:")
results = None
try:
results = self.connection.execute(self.get_new_view_def_sql_statement(schema_name=last_row['schema'], view_name=last_row['name']))
except Exception as e:
results = self.connection.execute(self.get_old_view_def_sql_statement(schema_name=last_row['schema'], view_name=last_row['name']))
finally:
if results is not None:
view_row = results.fetchone()
LOGGER.info(f"view_row={view_row}")
if view_row:
view_def = view_row[0]
if view_def:
qp = PostgreSQLQueryProcessor()
try:
qp.set_query(view_def)

qp.process_query()

LOGGER.info(f"View table: {qp.tables}")

if qp.tables is not None and len(qp.tables) > 0:
for table in qp.tables:
table_key = TableMetadata.TABLE_KEY_FORMAT.format(db=self._database, cluster=last_row['cluster'], schema=table[0].lower(), tbl=table[1].lower())
LOGGER.info(f"Table Lineage: table={table_key} downstream={table_metadata._get_table_key()}")
yield TableLineage(
table_key=table_key,
downstream_deps=[table_metadata._get_table_key()]
)

except QuerySyntaxError as e:
LOGGER.exception(f"Error parsing the query for {last_row['schema']}.{last_row['name']}:")

def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]:
"""
Expand Down
148 changes: 135 additions & 13 deletions databuilder/databuilder/extractor/mssql_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@

from pyhocon import ConfigFactory, ConfigTree

from sqlalchemy import create_engine
from sqlglot import parse_one, exp

from databuilder import Scoped
from databuilder.extractor import sql_alchemy_extractor
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata
from databuilder.models.table_lineage import TableLineage

TableKey = namedtuple('TableKey', ['schema_name', 'table_name'])

Expand All @@ -35,7 +41,8 @@ class MSSQLMetadataExtractor(Extractor):
COL.COLUMN_NAME AS [col_name],
COL.DATA_TYPE AS [col_type],
CAST(PROP_COL.VALUE AS NVARCHAR(MAX)) AS [col_description],
COL.ORDINAL_POSITION AS col_sort_order
COL.ORDINAL_POSITION AS col_sort_order,
CASE WHEN TBL.TABLE_TYPE = 'VIEW' THEN 'True' ELSE 'False' END AS is_view
FROM INFORMATION_SCHEMA.TABLES TBL
INNER JOIN INFORMATION_SCHEMA.COLUMNS COL
ON (COL.TABLE_NAME = TBL.TABLE_NAME
Expand All @@ -48,13 +55,12 @@ class MSSQLMetadataExtractor(Extractor):
ON (PROP_COL.MAJOR_ID = OBJECT_ID(TBL.TABLE_SCHEMA + '.' + TBL.TABLE_NAME)
AND PROP_COL.MINOR_ID = COL.ORDINAL_POSITION
AND PROP_COL.NAME = 'MS_Description')
WHERE TBL.TABLE_TYPE = 'base table' {where_clause_suffix}
WHERE (TBL.TABLE_TYPE = 'base table' OR TBL.TABLE_TYPE = 'VIEW') {where_clause_suffix}
ORDER BY
CLUSTER,
SCHEMA_NAME,
NAME,
COL_SORT_ORDER
;
COL_SORT_ORDER;
"""

# CONFIG KEYS
Expand Down Expand Up @@ -108,8 +114,69 @@ def init(self, conf: ConfigTree) -> None:
LOGGER.info('SQL for MS SQL Metadata: %s', self.sql_stmt)

self._alchemy_extractor = sql_alchemy_extractor.from_surrounding_config(conf, self.sql_stmt)
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())
self.connection = self._get_connection(sql_alch_conf)

self._extract_iter: Union[None, Iterator] = None

def _get_connection(self, sql_alch_conf) -> Any:
    """Open and return a SQLAlchemy connection built from the scoped config.

    Reads the connection string and any ``connect_args`` entries from
    *sql_alch_conf* and hands them to ``create_engine``.
    """
    conn_string = sql_alch_conf.get_string(SQLAlchemyExtractor.CONN_STRING)
    extra_args = dict(
        sql_alch_conf.get_config('connect_args', default=ConfigTree()).items()
    )
    engine = create_engine(conn_string, connect_args=extra_args)
    return engine.connect()

def get_key_sql_statement(self, schema_name, table_name) -> Any:
    """Build the T-SQL that lists key constraints (PK/FK/unique) per column.

    Args:
        schema_name: schema of the target table; interpolated into the SQL.
        table_name: name of the target table; interpolated into the SQL.

    Returns:
        A SQL statement string; each result row carries ``constraint_type``,
        ``table_schema``, ``table_name`` and ``column_name``.
    """
    # Identifiers are interpolated into quoted SQL literals; double any
    # single quotes so a name cannot break out of the literal.
    safe_schema = schema_name.replace("'", "''")
    safe_table = table_name.replace("'", "''")
    return """
    SELECT
        CASE
            WHEN CONSTRAINT_TYPE = 'PRIMARY KEY' THEN 'Primary Key'
            WHEN CONSTRAINT_TYPE = 'FOREIGN KEY' THEN 'Foreign Key'
            WHEN CONSTRAINT_TYPE = 'UNIQUE' THEN 'Unique Constraint'
            ELSE 'Unknown'
        END AS constraint_type,
        SCHEMA_NAME(t.schema_id) AS table_schema,
        t.name AS table_name,
        c.name AS column_name
    FROM sys.tables t
    INNER JOIN sys.columns c ON t.object_id = c.object_id
    LEFT JOIN (
        SELECT
            KCU.TABLE_SCHEMA,
            KCU.TABLE_NAME,
            KCU.COLUMN_NAME,
            TC.CONSTRAINT_TYPE
        FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS TC
        JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE KCU
            ON TC.CONSTRAINT_NAME = KCU.CONSTRAINT_NAME
            AND TC.TABLE_NAME = KCU.TABLE_NAME
        WHERE TC.CONSTRAINT_TYPE IN ('PRIMARY KEY', 'FOREIGN KEY', 'UNIQUE')
    ) AS cons ON t.name = cons.TABLE_NAME
        AND c.name = cons.COLUMN_NAME
        -- also match on schema, so a same-named table in another schema
        -- does not contribute its constraints to this table's columns
        AND SCHEMA_NAME(t.schema_id) = cons.TABLE_SCHEMA
    WHERE SCHEMA_NAME(t.schema_id) = '{schema_name}'
        AND t.name = '{table_name}'
    ORDER BY table_schema, table_name, column_name;
    """.format(schema_name=safe_schema, table_name=safe_table)

def get_view_def_sql_statement(self, schema_name, table_name=None, view_name=None) -> Any:
    """Build the T-SQL that fetches a view's definition from sys.sql_modules.

    Bug fix: the extract loop invokes this as ``view_name=...`` but the old
    signature only accepted ``table_name``, so every call raised TypeError
    (silently swallowed by the caller's broad ``except``) and view lineage
    was never produced. ``view_name`` is now accepted; the legacy positional
    ``table_name`` still works.

    Args:
        schema_name: schema of the view; interpolated into the SQL.
        table_name: legacy alias for the view name (positional callers).
        view_name: name of the view; wins over table_name when both given.

    Returns:
        A SQL statement string returning schema_name, view_name and
        view_definition for the requested view.

    Raises:
        ValueError: if neither table_name nor view_name is supplied.
    """
    name = view_name if view_name is not None else table_name
    if name is None:
        raise ValueError('either table_name or view_name must be provided')
    # Double single quotes so names cannot break out of the SQL literals.
    safe_schema = schema_name.replace("'", "''")
    safe_name = name.replace("'", "''")
    return """
    SELECT OBJECT_SCHEMA_NAME(object_id) AS schema_name,
        OBJECT_NAME(object_id) AS view_name,
        definition AS view_definition
    FROM sys.sql_modules
    WHERE OBJECTPROPERTY(object_id, 'IsView') = 1
        AND OBJECT_SCHEMA_NAME(object_id) = '{schema_name}'
        AND OBJECT_NAME(object_id) = '{table_name}';
    """.format(schema_name=safe_schema, table_name=safe_name)

def close(self) -> None:
    """Release resources by closing the wrapped SQLAlchemy extractor, if any."""
    if getattr(self, '_alchemy_extractor', None) is not None:
        self._alchemy_extractor.close()
Expand All @@ -133,24 +200,79 @@ def _get_extract_iter(self) -> Iterator[TableMetadata]:
"""
for key, group in groupby(self._get_raw_extract_iter(), self._get_table_key):
columns = []
key_cols = None

for row in group:
last_row = row
columns.append(
ColumnMetadata(
row['col_name'],
row['col_description'],
row['col_type'],
row['col_sort_order']))

yield TableMetadata(

if key_cols is None:
results = self.connection.execute(self.get_key_sql_statement(schema_name=last_row['schema_name'], table_name=last_row['name']))
LOGGER.info(f"results={results}")
if results:
key_cols = {}
for key_row in results:
# Access columns by name or index
constraint_type = key_row['constraint_type'].lower().replace(" ", "")
key_column = key_row['column_name']

if key_column in key_cols:
key_cols[key_column].append(constraint_type)
else:
key_cols[key_column] = [constraint_type]

col_badges = []
if key_cols is not None and row['col_name'] in key_cols:
LOGGER.info(f"Found KEY={row['col_name']}")
LOGGER.info(f"Badges={key_cols[row['col_name']]}")
col_badges = key_cols[row['col_name']]

col_metadata = ColumnMetadata(row['col_name'], row['col_description'],
row['col_type'], row['col_sort_order'],
badges=col_badges)

columns.append(col_metadata)

table_metadata = TableMetadata(
self._database,
last_row['cluster'],
last_row['schema_name'],
last_row['name'],
last_row['description'],
columns,
tags=last_row['schema_name'])
tags=last_row['schema_name'],
is_view=last_row['is_view'])
yield table_metadata

if bool(last_row['is_view']) == True:
results = None
try:
results = self.connection.execute(self.get_view_def_sql_statement(schema_name=last_row['schema_name'], view_name=last_row['name']))
except Exception as e:
LOGGER.exception('Failed to get view def:')
finally:
if results is not None:
view_row = results.fetchone()
LOGGER.info(f"view_row={view_row}")
if view_row:
view_def = view_row[0]
if view_def:
try:
view_def = view_def.replace(']', '').replace('[', '')
view_tables = parse_one(view_def).find_all(exp.Table)

LOGGER.info(f"View table: {view_tables}")

if view_tables is not None and len(view_tables) > 0:
for table in view_tables:
table_key = TableMetadata.TABLE_KEY_FORMAT.format(db=self._database, cluster=last_row['cluster'], schema=table.db.lower(), tbl=table.this.lower())
LOGGER.info(f"Table Lineage: table={table_key} downstream={table_metadata._get_table_key()}")
yield TableLineage(
table_key=table_key,
downstream_deps=[table_metadata._get_table_key()]
)

except Exception as e:
LOGGER.exception(f"Error parsing the view def for {last_row['schema_name']}.{last_row['name']}:")

def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]:
"""
Expand Down
82 changes: 51 additions & 31 deletions databuilder/databuilder/extractor/postgres_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,46 @@ def get_sql_statement(self, use_catalog_as_cluster_name: bool, where_clause_suff
cluster_source = f"'{self._cluster}'"

return """
WITH Objects AS (
SELECT
current_database() AS cluster,
n.nspname AS schema,
c.relname AS name,
CASE WHEN c.relkind = 'r' THEN NULL ELSE v.definition END AS object_description,
c.relkind = 'v' AS is_view
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_views v ON c.relname = v.viewname
INNER JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
WHERE c.relkind IN ('r', 'v') {where_clause_suffix}
),
Columns AS (
SELECT
current_database() AS cluster,
n.nspname AS schema,
c.relname AS name,
a.attname AS col_name,
format_type(a.atttypid, a.atttypmod) AS col_type,
pd.description AS col_description,
a.attnum AS col_sort_order
FROM pg_catalog.pg_attribute a
INNER JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
INNER JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
LEFT JOIN pg_catalog.pg_description pd ON c.oid = pd.objoid AND a.attnum = pd.objsubid
WHERE c.relkind IN ('r', 'v') {where_clause_suffix} AND a.attnum > 0
)
SELECT
{cluster_source} as cluster,
st.schemaname as schema,
st.relname as name,
pgtd.description as description,
att.attname as col_name,
pgtyp.typname as col_type,
pgcd.description as col_description,
att.attnum as col_sort_order,
CASE
WHEN pg_class.relkind = 'v' THEN true
ELSE false
END AS is_view
FROM pg_catalog.pg_attribute att
INNER JOIN
pg_catalog.pg_statio_all_tables as st
on att.attrelid=st.relid
INNER JOIN
pg_catalog.pg_class pg_class
ON st.relid = pg_class.oid
LEFT JOIN
pg_catalog.pg_type pgtyp
on pgtyp.oid=att.atttypid
LEFT JOIN
pg_catalog.pg_description pgtd
on pgtd.objoid=st.relid and pgtd.objsubid=0
LEFT JOIN
pg_catalog.pg_description pgcd
on pgcd.objoid=st.relid and pgcd.objsubid=att.attnum
WHERE att.attnum >=0 and {where_clause_suffix}
ORDER by cluster, schema, name, col_sort_order;
o.cluster,
o.schema,
o.name,
o.object_description,
c.col_name,
c.col_type,
c.col_description,
c.col_sort_order,
o.is_view
FROM Objects o
LEFT JOIN Columns c ON o.schema = c.schema AND o.name = c.name
ORDER BY o.cluster, o.schema, o.name, c.col_sort_order;
""".format(
cluster_source=cluster_source,
where_clause_suffix=where_clause_suffix,
Expand Down Expand Up @@ -85,7 +94,7 @@ def get_key_sql_statement(self, schema_name, table_name) -> Any:
AND tbl.relname = '{table_name}';
""".format(schema_name=schema_name, table_name=table_name)

def get_view_def_sql_statement(self, schema_name, view_name) -> Any:
def get_old_view_def_sql_statement(self, schema_name, view_name) -> Any:
return """
SELECT
view_definition
Expand All @@ -96,5 +105,16 @@ def get_view_def_sql_statement(self, schema_name, view_name) -> Any:
AND view_name = '{view_name}';
""".format(schema_name=schema_name, view_name=view_name)

def get_new_view_def_sql_statement(self, schema_name, view_name) -> Any:
    """Build the SQL that reads a view's definition from pg_views.

    Preferred lookup; the extract loop falls back to
    get_old_view_def_sql_statement when executing this statement fails.

    Args:
        schema_name: schema of the view; interpolated into the SQL.
        view_name: name of the view; interpolated into the SQL.

    Returns:
        A SQL statement string selecting the view's definition text.
    """
    # Double single quotes so names cannot break out of the SQL literals.
    safe_schema = schema_name.replace("'", "''")
    safe_view = view_name.replace("'", "''")
    return """
    SELECT
        definition
    FROM
        pg_views
    WHERE
        schemaname = '{schema_name}'
        AND viewname = '{view_name}';
    """.format(schema_name=safe_schema, view_name=safe_view)

def get_scope(self) -> str:
return 'extractor.postgres_metadata'
Loading

0 comments on commit 3eb118c

Please sign in to comment.