Skip to content

Commit

Permalink
* rebase
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Farris <[email protected]>
  • Loading branch information
mobuchowski authored and Tylerpfarris committed Jun 24, 2022
1 parent 75ae3f5 commit 6789430
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 149 deletions.
26 changes: 13 additions & 13 deletions integration/airflow/openlineage/airflow/extractors/dbapi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,28 @@ def get_table_schemas(
query_schemas.append(parse_query_result(cursor))
else:
query_schemas.append([])
return tuple([
return tuple(
[
Dataset.from_table_schema(
source=source,
table_schema=schema,
database_name=database
) for schema in schemas
] for schemas in query_schemas
])
[
Dataset.from_table_schema(
source=source, table_schema=schema, database_name=database
)
for schema in schemas
]
for schemas in query_schemas
]
)


def parse_query_result(cursor) -> List[DbTableSchema]:
schemas: Dict = {}
for row in cursor.fetchall():
table_schema_name: str = row[_TABLE_SCHEMA]
table_name: DbTableMeta = DbTableMeta(
row[_TABLE_NAME]
)
table_name: DbTableMeta = DbTableMeta(row[_TABLE_NAME])
table_column: DbColumn = DbColumn(
name=row[_COLUMN_NAME],
type=row[_UDT_NAME],
ordinal_position=row[_ORDINAL_POSITION]
ordinal_position=row[_ORDINAL_POSITION],
)

# Attempt to get table schema
Expand All @@ -85,6 +85,6 @@ def parse_query_result(cursor) -> List[DbTableSchema]:
schemas[table_key] = DbTableSchema(
schema_name=table_schema_name,
table_name=table_name,
columns=[table_column]
columns=[table_column],
)
return list(schemas.values())
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from openlineage.airflow.extractors.dbapi_utils import get_table_schemas
from openlineage.airflow.utils import (
get_connection_uri,
get_connection, safe_import_airflow
)
from openlineage.airflow.extractors.base import (
BaseExtractor,
TaskMetadata
get_connection,
safe_import_airflow,
)
from openlineage.airflow.extractors.base import BaseExtractor, TaskMetadata
from openlineage.client.facet import SqlJobFacet
from openlineage.common.sql import SqlMeta, parse, DbTableMeta
from openlineage.common.dataset import Source
Expand All @@ -29,25 +27,25 @@ def __init__(self, operator):

@classmethod
def get_operator_classnames(cls) -> List[str]:
    """Return the Airflow operator class names this extractor handles.

    The extractor framework matches operators to extractors by class name;
    this extractor only covers ``MySqlOperator``.
    """
    # Diff residue removed: the scrape contained both the pre- and
    # post-commit return lines; keep the post-commit (black-formatted) one.
    return ["MySqlOperator"]

def extract(self) -> TaskMetadata:
task_name = f"{self.operator.dag_id}.{self.operator.task_id}"
run_facets: Dict = {}
job_facets = {
'sql': SqlJobFacet(self.operator.sql)
}
job_facets = {"sql": SqlJobFacet(self.operator.sql)}

# (1) Parse sql statement to obtain input / output tables.
sql_meta: Optional[SqlMeta] = parse(self.operator.sql, self.default_schema)
sql_meta: Optional[SqlMeta] = parse(
self.operator.sql, self.default_schema
)

if not sql_meta:
return TaskMetadata(
name=task_name,
inputs=[],
outputs=[],
run_facets=run_facets,
job_facets=job_facets
job_facets=job_facets,
)

# (2) Get database connection
Expand All @@ -59,7 +57,7 @@ def extract(self) -> TaskMetadata:
source = Source(
scheme=self._get_scheme(),
authority=self._get_authority(),
connection_url=self._get_connection_uri()
connection_url=self._get_connection_uri(),
)

database = self.operator.database
Expand All @@ -74,46 +72,48 @@ def extract(self) -> TaskMetadata:
self._get_hook(),
source,
database,
self._information_schema_query(sql_meta.in_tables) if sql_meta.in_tables else None,
self._information_schema_query(sql_meta.out_tables) if sql_meta.out_tables else None
self._information_schema_query(sql_meta.in_tables)
if sql_meta.in_tables
else None,
self._information_schema_query(sql_meta.out_tables)
if sql_meta.out_tables
else None,
)

return TaskMetadata(
name=task_name,
inputs=[ds.to_openlineage_dataset() for ds in inputs],
outputs=[ds.to_openlineage_dataset() for ds in outputs],
run_facets=run_facets,
job_facets=job_facets
job_facets=job_facets,
)

def _get_connection_uri(self):
    """Build the lineage connection URI string for this task's Airflow connection."""
    connection = self.conn
    return get_connection_uri(connection)

def _get_scheme(self):
return 'mysql'
return "mysql"

def _get_database(self) -> str:
if self.conn.schema:
return self.conn.schema
else:
parsed = urlparse(self.conn.get_uri())
return f'{parsed.path}'
return f"{parsed.path}"

def _get_authority(self) -> str:
if self.conn.host and self.conn.port:
return f'{self.conn.host}:{self.conn.port}'
return f"{self.conn.host}:{self.conn.port}"
else:
parsed = urlparse(self.conn.get_uri())
return f'{parsed.hostname}:{parsed.port}'
return f"{parsed.hostname}:{parsed.port}"

def _conn_id(self):
return self.operator.mysql_conn_id

@staticmethod
def _information_schema_query(tables: List[DbTableMeta]) -> str:
table_names = ",".join(map(
lambda name: f"'{name.name}'", tables
))
table_names = ",".join(map(lambda name: f"'{name.name}'", tables))
return f"""
SELECT table_schema,
table_name,
Expand All @@ -127,9 +127,9 @@ def _information_schema_query(tables: List[DbTableMeta]) -> str:
def _get_hook(self):
    """Instantiate a ``MySqlHook`` for the operator's connection and database.

    ``safe_import_airflow`` resolves the hook class from the Airflow 1 or
    Airflow 2 import path, whichever is available at runtime.
    """
    # Diff residue removed: the scraped span repeated the `airflow_2_path`
    # and `schema` keyword arguments (pre- and post-commit lines), which is
    # a SyntaxError (duplicate keyword argument). Kept the post-commit
    # trailing-comma form once.
    MySqlHook = safe_import_airflow(
        airflow_1_path="airflow.hooks.mysql_hook.MySqlHook",
        airflow_2_path="airflow.providers.mysql.hooks.mysql.MySqlHook",
    )
    return MySqlHook(
        mysql_conn_id=self.operator.mysql_conn_id,
        schema=self.operator.database,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@


class SnowflakeExtractor(BaseExtractor):
source_type = 'SNOWFLAKE'
default_schema = 'PUBLIC'
source_type = "SNOWFLAKE"
default_schema = "PUBLIC"

def __init__(self, operator):
super().__init__(operator)
Expand All @@ -29,18 +29,18 @@ def __init__(self, operator):

@classmethod
def get_operator_classnames(cls) -> List[str]:
    """Return the Airflow operator class names this extractor handles.

    The extractor framework matches operators to extractors by class name;
    this extractor only covers ``SnowflakeOperator``.
    """
    # Diff residue removed: kept the post-commit double-quoted return;
    # the single-quoted duplicate was an unreachable removed line.
    return ["SnowflakeOperator"]

def extract(self) -> TaskMetadata:
task_name = f"{self.operator.dag_id}.{self.operator.task_id}"
run_facets: Dict = {}
job_facets = {
'sql': SqlJobFacet(self.operator.sql)
}
job_facets = {"sql": SqlJobFacet(self.operator.sql)}

# (1) Parse sql statement to obtain input / output tables.
logger.debug(f"Sending SQL to parser: {self.operator.sql}")
sql_meta: Optional[SqlMeta] = parse(self.operator.sql, self.default_schema)
sql_meta: Optional[SqlMeta] = parse(
self.operator.sql, self.default_schema
)
logger.debug(f"Got meta {sql_meta}")

if not sql_meta:
Expand All @@ -49,7 +49,7 @@ def extract(self) -> TaskMetadata:
inputs=[],
outputs=[],
run_facets=run_facets,
job_facets=job_facets
job_facets=job_facets,
)

# (2) Get Airflow connection
Expand All @@ -59,9 +59,9 @@ def extract(self) -> TaskMetadata:
# NOTE: We'll want to look into adding support for the `database`
# property that is used to override the one defined in the connection.
source = Source(
scheme='snowflake',
scheme="snowflake",
authority=self._get_authority(),
connection_url=self._get_connection_uri()
connection_url=self._get_connection_uri(),
)

database = self.operator.database
Expand All @@ -76,15 +76,18 @@ def extract(self) -> TaskMetadata:
self._get_hook(),
source,
database,
self._information_schema_query(sql_meta.in_tables) if sql_meta.in_tables else None,
self._information_schema_query(sql_meta.out_tables) if sql_meta.out_tables else None
self._information_schema_query(sql_meta.in_tables)
if sql_meta.in_tables
else None,
self._information_schema_query(sql_meta.out_tables)
if sql_meta.out_tables
else None,
)

query_ids = self._get_query_ids()
if len(query_ids) == 1:
run_facets['externalQuery'] = ExternalQueryRunFacet(
externalQueryId=query_ids[0],
source=source.name
run_facets["externalQuery"] = ExternalQueryRunFacet(
externalQueryId=query_ids[0], source=source.name
)
elif len(query_ids) > 1:
logger.warning(
Expand All @@ -97,13 +100,16 @@ def extract(self) -> TaskMetadata:
inputs=[ds.to_openlineage_dataset() for ds in inputs],
outputs=[ds.to_openlineage_dataset() for ds in outputs],
run_facets=run_facets,
job_facets=job_facets
job_facets=job_facets,
)

def _information_schema_query(self, tables: List[DbTableMeta]) -> str:
table_names = ",".join(map(
lambda name: f"'{self._normalize_identifiers(name.name)}'", tables
))
table_names = ",".join(
map(
lambda name: f"'{self._normalize_identifiers(name.name)}'",
tables,
)
)
database = self.operator.database
if not database:
database = self._get_database()
Expand All @@ -119,19 +125,27 @@ def _information_schema_query(self, tables: List[DbTableMeta]) -> str:
return sql

def _get_database(self) -> str:
if hasattr(self.operator, 'database') and self.operator.database is not None:
if (
hasattr(self.operator, "database")
and self.operator.database is not None
):
return self.operator.database
return self.conn.extra_dejson.get('extra__snowflake__database', '') \
or self.conn.extra_dejson.get('database', '')
return self.conn.extra_dejson.get(
"extra__snowflake__database", ""
) or self.conn.extra_dejson.get("database", "")

def _get_authority(self) -> str:
if hasattr(self.operator, 'account') and self.operator.account is not None:
if (
hasattr(self.operator, "account")
and self.operator.account is not None
):
return self.operator.account
return self.conn.extra_dejson.get('extra__snowflake__account', '') \
or self.conn.extra_dejson.get('account', '')
return self.conn.extra_dejson.get(
"extra__snowflake__account", ""
) or self.conn.extra_dejson.get("account", "")

def _get_hook(self):
if hasattr(self.operator, 'get_db_hook'):
if hasattr(self.operator, "get_db_hook"):
return self.operator.get_db_hook()
else:
return self.operator.get_hook()
Expand All @@ -151,6 +165,6 @@ def _get_connection_uri(self):
return get_connection_uri(self.conn)

def _get_query_ids(self) -> List[str]:
if hasattr(self.operator, 'query_ids'):
if hasattr(self.operator, "query_ids"):
return self.operator.query_ids
return []
Loading

0 comments on commit 6789430

Please sign in to comment.