Skip to content

Commit

Permalink
fix: hive metadata extractor does not work on PostgreSQL (#394)
Browse files Browse the repository at this point in the history
* fix: hive metadata extractor does not work on PostgreSQL

Signed-off-by: zhmin <[email protected]>

* fix: add unit test for the hive metadata extractor patch

Signed-off-by: zhmin <[email protected]>

Co-authored-by: root <[email protected]>
  • Loading branch information
zhmin and root authored Oct 28, 2020
1 parent a230c56 commit 2992618
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
47 changes: 42 additions & 5 deletions databuilder/extractor/hive_table_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pyhocon import ConfigFactory, ConfigTree
from typing import Iterator, Union, Dict, Any
from sqlalchemy.engine.url import make_url

from databuilder import Scoped
from databuilder.extractor.table_metadata_constants import PARTITION_BADGE
Expand Down Expand Up @@ -56,6 +57,34 @@ class HiveTableMetadataExtractor(Extractor):
ORDER by tbl_id, is_partition_col desc;
"""

DEFAULT_POSTGRES_SQL_STATEMENT = """
SELECT source.* FROM
(SELECT t."TBL_ID" as tbl_id, d."NAME" as "schema", t."TBL_NAME" as name, t."TBL_TYPE",
tp."PARAM_VALUE" as description, p."PKEY_NAME" as col_name, p."INTEGER_IDX" as col_sort_order,
p."PKEY_TYPE" as col_type, p."PKEY_COMMENT" as col_description, 1 as "is_partition_col",
CASE WHEN t."TBL_TYPE" = 'VIRTUAL_VIEW' THEN 1
ELSE 0 END as "is_view"
FROM "TBLS" t
JOIN "DBS" d ON t."DB_ID" = d."DB_ID"
JOIN "PARTITION_KEYS" p ON t."TBL_ID" = p."TBL_ID"
LEFT JOIN "TABLE_PARAMS" tp ON (t."TBL_ID" = tp."TBL_ID" AND tp."PARAM_KEY"='comment')
{where_clause_suffix}
UNION
SELECT t."TBL_ID" as tbl_id, d."NAME" as "schema", t."TBL_NAME" as name, t."TBL_TYPE",
tp."PARAM_VALUE" as description, c."COLUMN_NAME" as col_name, c."INTEGER_IDX" as col_sort_order,
c."TYPE_NAME" as col_type, c."COMMENT" as col_description, 0 as "is_partition_col",
CASE WHEN t."TBL_TYPE" = 'VIRTUAL_VIEW' THEN 1
ELSE 0 END as "is_view"
FROM "TBLS" t
JOIN "DBS" d ON t."DB_ID" = d."DB_ID"
JOIN "SDS" s ON t."SD_ID" = s."SD_ID"
JOIN "COLUMNS_V2" c ON s."CD_ID" = c."CD_ID"
LEFT JOIN "TABLE_PARAMS" tp ON (t."TBL_ID" = tp."TBL_ID" AND tp."PARAM_KEY"='comment')
{where_clause_suffix}
) source
ORDER by tbl_id, is_partition_col desc;
"""

# CONFIG KEYS
WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix'
CLUSTER_KEY = 'cluster'
Expand All @@ -67,20 +96,28 @@ def init(self, conf: ConfigTree) -> None:
conf = conf.with_fallback(HiveTableMetadataExtractor.DEFAULT_CONFIG)
self._cluster = '{}'.format(conf.get_string(HiveTableMetadataExtractor.CLUSTER_KEY))

default_sql = HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT.format(
self._alchemy_extractor = SQLAlchemyExtractor()

sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())
default_sql = self._choose_default_sql_stm(sql_alch_conf).format(
where_clause_suffix=conf.get_string(HiveTableMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY))

self.sql_stmt = conf.get_string(HiveTableMetadataExtractor.EXTRACT_SQL, default=default_sql)

LOGGER.info('SQL for hive metastore: {}'.format(self.sql_stmt))

self._alchemy_extractor = SQLAlchemyExtractor()
sql_alch_conf = Scoped.get_scoped_conf(conf, self._alchemy_extractor.get_scope())\
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

sql_alch_conf = sql_alch_conf.with_fallback(ConfigFactory.from_dict(
{SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))
self._alchemy_extractor.init(sql_alch_conf)
self._extract_iter: Union[None, Iterator] = None

def _choose_default_sql_stm(self, conf: ConfigTree) -> str:
url = make_url(conf.get_string(SQLAlchemyExtractor.CONN_STRING))
if url.drivername.lower() in ['postgresql', 'postgres']:
return HiveTableMetadataExtractor.DEFAULT_POSTGRES_SQL_STATEMENT
else:
return HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT

def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
self._extract_iter = self._get_extract_iter()
Expand Down
20 changes: 15 additions & 5 deletions tests/unit/extractor/test_hive_table_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ def test_extraction_with_empty_query_result(self) -> None:
"""
Test Extraction with empty result from query
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
with patch.object(SQLAlchemyExtractor, '_get_connection'), \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
extractor = HiveTableMetadataExtractor()
extractor.init(self.conf)

results = extractor.extract()
self.assertEqual(results, None)

def test_extraction_with_single_result(self) -> None:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection, \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
connection = MagicMock()
mock_connection.return_value = connection
sql_execute = MagicMock()
Expand Down Expand Up @@ -101,7 +105,9 @@ def test_extraction_with_single_result(self) -> None:
self.assertIsNone(extractor.extract())

def test_extraction_with_multiple_result(self) -> None:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection, \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
connection = MagicMock()
mock_connection.return_value = connection
sql_execute = MagicMock()
Expand Down Expand Up @@ -240,7 +246,9 @@ def test_sql_statement(self) -> None:
"""
Test Extraction with empty result from query
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
with patch.object(SQLAlchemyExtractor, '_get_connection'), \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
extractor = HiveTableMetadataExtractor()
extractor.init(self.conf)
self.assertTrue(self.where_clause_suffix in extractor.sql_stmt)
Expand All @@ -250,7 +258,9 @@ def test_hive_sql_statement_with_custom_sql(self) -> None:
Test Extraction by providing a custom sql
:return:
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
with patch.object(SQLAlchemyExtractor, '_get_connection'), \
patch.object(HiveTableMetadataExtractor, '_choose_default_sql_stm',
return_value=HiveTableMetadataExtractor.DEFAULT_SQL_STATEMENT):
config_dict = {
HiveTableMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY: self.where_clause_suffix,
'extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING):
Expand Down

0 comments on commit 2992618

Please sign in to comment.