diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 9be3a0c2eef18b..57bf0268dfb884 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -56,6 +56,7 @@ BigQueryValueCheckTrigger, ) from airflow.providers.google.cloud.utils.bigquery import convert_job_id +from airflow.providers.google.cloud.utils.openlineage import _BigQueryOpenLineageMixin from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.utils.helpers import exactly_one @@ -141,68 +142,6 @@ def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook: # type:ignore[mis ) -class _BigQueryOpenLineageMixin: - def get_openlineage_facets_on_complete(self, task_instance): - """ - Retrieve OpenLineage data for a COMPLETE BigQuery job. - - This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider. - It calls BigQuery API, retrieving input and output dataset info from it, as well as run-level - usage statistics. - - Run facets should contain: - - ExternalQueryRunFacet - - BigQueryJobRunFacet - - Job facets should contain: - - SqlJobFacet if operator has self.sql - - Input datasets should contain facets: - - DataSourceDatasetFacet - - SchemaDatasetFacet - - Output datasets should contain facets: - - DataSourceDatasetFacet - - SchemaDatasetFacet - - OutputStatisticsOutputDatasetFacet - """ - from openlineage.client.facet import SqlJobFacet - from openlineage.common.provider.bigquery import BigQueryDatasetsProvider - - from airflow.providers.openlineage.extractors import OperatorLineage - from airflow.providers.openlineage.utils.utils import normalize_sql - - if not self.job_id: - return OperatorLineage() - - client = self.hook.get_client(project_id=self.hook.project_id) - job_ids = self.job_id - if isinstance(self.job_id, str): - job_ids = [self.job_id] - inputs, outputs, run_facets = {}, {}, {} - for job_id in job_ids: - stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id) - for input in stats.inputs: - input = input.to_openlineage_dataset() - inputs[input.name] = input - if stats.output: - output = stats.output.to_openlineage_dataset() - outputs[output.name] = output - for key, value in stats.run_facets.items(): - run_facets[key] = value - - job_facets = {} - if hasattr(self, "sql"): - job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql)) - - return OperatorLineage( - inputs=list(inputs.values()), - outputs=list(outputs.values()), - run_facets=run_facets, - job_facets=job_facets, - ) - - class _BigQueryOperatorsEncryptionConfigurationMixin: """A class to handle the configuration for BigQueryHook.insert_job method.""" diff --git a/airflow/providers/google/cloud/utils/openlineage.py b/airflow/providers/google/cloud/utils/openlineage.py index 2121ba4ad84657..a89559b6479ffe 100644 --- a/airflow/providers/google/cloud/utils/openlineage.py +++ b/airflow/providers/google/cloud/utils/openlineage.py @@ -15,24 +15,33 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""This module contains code related to OpenLineage and lineage extraction.""" - from __future__ import annotations +import copy +import json +import traceback from typing import TYPE_CHECKING, Any +from attr import define, field from openlineage.client.facet import ( + BaseFacet, ColumnLineageDatasetFacet, ColumnLineageDatasetFacetFieldsAdditional, ColumnLineageDatasetFacetFieldsAdditionalInputFields, DocumentationDatasetFacet, + ErrorMessageRunFacet, + OutputStatisticsOutputDatasetFacet, SchemaDatasetFacet, SchemaField, ) +from openlineage.client.run import Dataset if TYPE_CHECKING: from google.cloud.bigquery.table import Table - from openlineage.client.run import Dataset + + +BIGQUERY_NAMESPACE = "bigquery" +BIGQUERY_URI = "bigquery" def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: @@ -79,3 +88,287 @@ def get_identity_column_lineage_facet( } ) return column_lineage_facet + + +@define +class BigQueryJobRunFacet(BaseFacet): + """Facet that represents relevant statistics of bigquery run. + + This facet is used to provide statistics about bigquery run. + + :param cached: BigQuery caches query results. Rest of the statistics will not be provided for cached queries. + :param billedBytes: How many bytes BigQuery bills for. + :param properties: Full property tree of BigQUery run. + """ + + cached: bool + billedBytes: int | None = field(default=None) + properties: str | None = field(default=None) + + +# TODO: remove ErrorMessageRunFacet in next release +@define +class BigQueryErrorRunFacet(BaseFacet): + """ + Represents errors that can happen during execution of BigqueryExtractor. + + :param clientError: represents errors originating in bigquery client + :param parserError: represents errors that happened during parsing SQL provided to bigquery + """ + + clientError: str = field(default=None) + parserError: str = field(default=None) + + +def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None: + """Get object from nested structure of objects, where it's not guaranteed that all keys in the nested structure exist. + + Intended to replace chain of `dict.get()` statements. + + Example usage: + + .. code-block:: python + + if ( + not job._properties.get("statistics") + or not job._properties.get("statistics").get("query") + or not job._properties.get("statistics").get("query").get("referencedTables") + ): + return None + result = job._properties.get("statistics").get("query").get("referencedTables") + + becomes: + + .. code-block:: python + + result = get_from_nullable_chain(properties, ["statistics", "query", "queryPlan"]) + if not result: + return None + """ + chain.reverse() + try: + while chain: + next_key = chain.pop() + if isinstance(source, dict): + source = source.get(next_key) + else: + source = getattr(source, next_key) + return source + except AttributeError: + return None + + +class _BigQueryOpenLineageMixin: + def get_openlineage_facets_on_complete(self, _): + """ + Retrieve OpenLineage data for a COMPLETE BigQuery job. + + This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider. + It calls BigQuery API, retrieving input and output dataset info from it, as well as run-level + usage statistics. 
+
+        Run facets should contain:
+            - ExternalQueryRunFacet
+            - BigQueryJobRunFacet
+
+        Run facets may contain:
+            - ErrorMessageRunFacet
+
+        Job facets should contain:
+            - SqlJobFacet
+
+        Input datasets should contain facets:
+            - SchemaDatasetFacet
+
+        Output datasets should contain facets:
+            - SchemaDatasetFacet
+            - OutputStatisticsOutputDatasetFacet
+        """
+        from openlineage.client.facet import ExternalQueryRunFacet, SqlJobFacet
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self.job_id:
+            return OperatorLineage()
+
+        run_facets: dict[str, BaseFacet] = {
+            "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery")
+        }
+
+        job_facets = {"sql": SqlJobFacet(query=SQLParser.normalize_sql(self.sql))}
+
+        self.client = self.hook.get_client(project_id=self.hook.project_id)
+        job_ids = self.job_id
+        if isinstance(self.job_id, str):
+            job_ids = [self.job_id]
+        inputs, outputs = [], []
+        for job_id in job_ids:
+            inner_inputs, inner_outputs, inner_run_facets = self.get_facets(job_id=job_id)
+            inputs.extend(inner_inputs)
+            outputs.extend(inner_outputs)
+            run_facets.update(inner_run_facets)
+
+        return OperatorLineage(
+            inputs=inputs,
+            outputs=outputs,
+            run_facets=run_facets,
+            job_facets=job_facets,
+        )
+
+    def get_facets(self, job_id: str):
+        inputs = []
+        outputs = []
+        run_facets: dict[str, BaseFacet] = {}
+        if hasattr(self, "log"):
+            self.log.debug("Extracting data from bigquery job: `%s`", job_id)
+        try:
+            job = self.client.get_job(job_id=job_id)  # type: ignore
+            props = job._properties
+
+            if get_from_nullable_chain(props, ["status", "state"]) != "DONE":
+                raise ValueError(f"Trying to extract data from running bigquery job: `{job_id}`")
+
+            # TODO: remove the deprecated bigQuery_job facet in next release
+            run_facets["bigQuery_job"] = run_facets["bigQueryJob"] = self._get_bigquery_job_run_facet(props)
+
+            if get_from_nullable_chain(props, ["statistics", "numChildJobs"]):
+                if hasattr(self, "log"):
+                    self.log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.")
+                # SCRIPT job type has no input / output information but spawns child jobs that do
+                # https://cloud.google.com/bigquery/docs/information-schema-jobs#multi-statement_query_job
+                for child_job_id in self.client.list_jobs(parent_job=job_id):
+                    child_job = self.client.get_job(job_id=child_job_id)  # type: ignore
+                    child_inputs, child_output = self._get_inputs_outputs_from_job(child_job._properties)
+                    inputs.extend(child_inputs)
+                    outputs.append(child_output)
+            else:
+                inputs, _output = self._get_inputs_outputs_from_job(props)
+                outputs.append(_output)
+        except Exception as e:
+            if hasattr(self, "log"):
+                self.log.warning("Cannot retrieve job details from BigQuery client. %s", e, exc_info=True)
+            exception_msg = traceback.format_exc()
+            # TODO: remove the deprecated bigQuery_error facet in next release
+            run_facets.update(
+                {
+                    "errorMessage": ErrorMessageRunFacet(
+                        message=f"{e}: {exception_msg}",
+                        programmingLanguage="python",
+                    ),
+                    "bigQuery_error": BigQueryErrorRunFacet(
+                        clientError=f"{e}: {exception_msg}",
+                    ),
+                }
+            )
+        deduplicated_outputs = self._deduplicate_outputs(outputs)
+        # For complex scripts there can be multiple outputs - in that case keep them all in `outputs` and
+        # leave the `output` empty to avoid providing misleading information. When the script has a single
+        # output (e.g. 
a single statement with some variable declarations), treat it as a regular non-script
+        # job and put the output in `output` as an addition to new `outputs`. `output` is deprecated.
+        return inputs, deduplicated_outputs, run_facets
+
+    def _deduplicate_outputs(self, outputs: list[Dataset | None]) -> list[Dataset]:
+        # Sources are the same so we can compare only names
+        final_outputs = {}
+        for single_output in outputs:
+            if not single_output:
+                continue
+            key = single_output.name
+            if key not in final_outputs:
+                final_outputs[key] = single_output
+                continue
+
+            # No OutputStatisticsOutputDatasetFacet is added to duplicated outputs as we cannot determine
+            # whether the rowCount or size can be summed together.
+            single_output.facets.pop("outputStatistics", None)
+            final_outputs[key] = single_output
+
+        return list(final_outputs.values())
+
+    def _get_inputs_outputs_from_job(self, properties: dict) -> tuple[list[Dataset], Dataset | None]:
+        input_tables = get_from_nullable_chain(properties, ["statistics", "query", "referencedTables"]) or []
+        output_table = get_from_nullable_chain(properties, ["configuration", "query", "destinationTable"])
+        inputs = [self._get_dataset(input_table) for input_table in input_tables]
+        output = None
+        if output_table:
+            output = self._get_dataset(output_table)
+            dataset_stat_facet = self._get_statistics_dataset_facet(properties)
+            if dataset_stat_facet:
+                output.facets.update({"outputStatistics": dataset_stat_facet})
+
+        return inputs, output
+
+    @staticmethod
+    def _get_bigquery_job_run_facet(properties: dict) -> BigQueryJobRunFacet:
+        if get_from_nullable_chain(properties, ["configuration", "query", "query"]):
+            # Exclude the query to avoid event size issues and duplicating SqlJobFacet information.
+            properties = copy.deepcopy(properties)
+            properties["configuration"]["query"].pop("query")
+        cache_hit = get_from_nullable_chain(properties, ["statistics", "query", "cacheHit"])
+        billed_bytes = get_from_nullable_chain(properties, ["statistics", "query", "totalBytesBilled"])
+        return BigQueryJobRunFacet(
+            cached=str(cache_hit).lower() == "true",
+            billedBytes=int(billed_bytes) if billed_bytes else None,
+            properties=json.dumps(properties),
+        )
+
+    @staticmethod
+    def _get_statistics_dataset_facet(properties) -> OutputStatisticsOutputDatasetFacet | None:
+        query_plan = get_from_nullable_chain(properties, chain=["statistics", "query", "queryPlan"])
+        if not query_plan:
+            return None
+
+        out_stage = query_plan[-1]
+        out_rows = out_stage.get("recordsWritten", None)
+        out_bytes = out_stage.get("shuffleOutputBytes", None)
+        if out_bytes and out_rows:
+            return OutputStatisticsOutputDatasetFacet(rowCount=int(out_rows), size=int(out_bytes))
+        return None
+
+    def _get_dataset(self, table: dict) -> Dataset:
+        project = table.get("projectId")
+        dataset = table.get("datasetId")
+        table_name = table.get("tableId")
+        dataset_name = f"{project}.{dataset}.{table_name}"
+
+        dataset_schema = self._get_table_schema_safely(dataset_name)
+        return Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=dataset_name,
+            facets={
+                "schema": dataset_schema,
+            }
+            if dataset_schema
+            else {},
+        )
+
+    def _get_table_schema_safely(self, table_name: str) -> SchemaDatasetFacet | None:
+        try:
+            return self._get_table_schema(table_name)
+        except Exception as e:
+            if hasattr(self, "log"):
+                self.log.warning("Could not extract output schema from bigquery. 
%s", e) + return None + + def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None: + bq_table = self.client.get_table(table) + + if not bq_table._properties: + return None + + fields = get_from_nullable_chain(bq_table._properties, ["schema", "fields"]) + if not fields: + return None + + return SchemaDatasetFacet( + fields=[ + SchemaField( + name=field.get("name"), + type=field.get("type"), + description=field.get("description"), + ) + for field in fields + ] + ) diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index b9cd385cc5ad73..5fdaf2401fa319 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -22,7 +22,7 @@ import logging from contextlib import suppress from functools import wraps -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any import attrs from openlineage.client.utils import RedactMixin # TODO: move this maybe to Airflow's logic? @@ -382,13 +382,6 @@ def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict: return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys} -def normalize_sql(sql: str | Iterable[str]): - if isinstance(sql, str): - sql = [stmt for stmt in sql.split(";") if stmt != ""] - sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""] - return ";\n".join(sql) - - def should_use_external_connection(hook) -> bool: # TODO: Add checking overrides return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"] diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index a9b3ee5209d377..978bcf75e1c566 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -34,6 +34,7 @@ from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results +from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri from airflow.utils.strings import to_boolean T = TypeVar("T") @@ -462,9 +463,7 @@ def get_openlineage_database_dialect(self, _) -> str: def get_openlineage_default_schema(self) -> str | None: return self._get_conn_params["schema"] - def _get_openlineage_authority(self, _) -> str: - from openlineage.common.provider.snowflake import fix_snowflake_sqlalchemy_uri - + def _get_openlineage_authority(self, _) -> str | None: uri = fix_snowflake_sqlalchemy_uri(self.get_uri()) return urlparse(uri).hostname diff --git a/airflow/providers/snowflake/utils/openlineage.py b/airflow/providers/snowflake/utils/openlineage.py new file mode 100644 index 00000000000000..fa784bbb37c553 --- /dev/null +++ b/airflow/providers/snowflake/utils/openlineage.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from urllib.parse import quote, urlparse, urlunparse
+
+
+def fix_account_name(name: str) -> str:
+    """Fix the account name to have the following format: <account>.<region>.<cloud>."""
+    spl = name.split(".")
+    if len(spl) == 1:
+        account = spl[0]
+        region, cloud = "us-west-1", "aws"
+    elif len(spl) == 2:
+        account, region = spl
+        cloud = "aws"
+    else:
+        account, region, cloud = spl
+    return f"{account}.{region}.{cloud}"
+
+
+def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
+    """Fix the Snowflake SQLAlchemy connection URI to the OpenLineage structure.
+
+    A Snowflake SQLAlchemy connection URI has the following structure:
+    'snowflake://<user>:<password>@<account_identifier>/<database>/<schema>?warehouse=<warehouse>&role=<role>'
+    We want the account identifier normalized. It can have two forms:
+    - newer, in the form of <organization>-<account>. In this case we want to do nothing.
+    - older, composed of <account_locator>.<region>.<cloud>, where region and cloud can be
+    omitted in some cases. If <cloud> is omitted, it's AWS.
+    If region and cloud are omitted, it's AWS us-west-1.
+    """
+    try:
+        parts = urlparse(uri)
+    except ValueError:
+        # snowflake.sqlalchemy.URL does not quote `[` and `]`
+        # that's a rare case so we can run more debugging code here
+        # to make sure we replace only the password
+        parts = urlparse(uri.replace("[", quote("[")).replace("]", quote("]")))
+
+    hostname = parts.hostname
+    if not hostname:
+        return uri
+
+    # old account identifier like xy123456
+    if "." in hostname or not any(word in hostname for word in ["-", "_"]):
+        hostname = fix_account_name(hostname)
+    # else - it's a new-style hostname, just return it as is
+    return urlunparse((parts.scheme, hostname, parts.path, parts.params, parts.query, parts.fragment))
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 4cfcb7fe87c631..16c9cbdb820bae 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -26,9 +26,8 @@
 import pytest
 from google.cloud.bigquery import DEFAULT_RETRY
 from google.cloud.exceptions import Conflict
-from openlineage.client.facet import DataSourceDatasetFacet, ExternalQueryRunFacet, SqlJobFacet
+from openlineage.client.facet import ErrorMessageRunFacet, ExternalQueryRunFacet, SqlJobFacet
 from openlineage.client.run import Dataset
-from openlineage.common.provider.bigquery import BigQueryErrorRunFacet
 
 from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout, TaskDeferred
 from airflow.providers.google.cloud.operators.bigquery import (
@@ -1712,21 +1711,19 @@ def test_execute_openlineage_events(self, mock_hook):
 
         assert result == real_job_id
 
-        with open(file="tests/providers/google/cloud/operators/job_details.json") as f:
+        with open(file="tests/providers/google/cloud/utils/job_details.json") as f:
             job_details = json.loads(f.read())
         mock_hook.return_value.get_client.return_value.get_job.return_value._properties = job_details
+        mock_hook.return_value.get_client.return_value.get_table.side_effect = Exception()
 
         lineage = op.get_openlineage_facets_on_complete(None)
         assert lineage.inputs == [
-            Dataset(
-                namespace="bigquery",
-                name="airflow-openlineage.new_dataset.test_table",
-                facets={"dataSource": DataSourceDatasetFacet(name="bigquery", uri="bigquery")},
-            )
+            Dataset(namespace="bigquery", name="airflow-openlineage.new_dataset.test_table")
         ]
 
         assert lineage.run_facets == {
             "bigQuery_job": mock.ANY,
+            "bigQueryJob": mock.ANY,
             "externalQuery": 
ExternalQueryRunFacet(externalQueryId=mock.ANY, source="bigquery"), } assert lineage.job_facets == {"sql": SqlJobFacet(query="SELECT * FROM test_table")} @@ -1756,7 +1753,7 @@ def test_execute_fails_openlineage_events(self, mock_hook): operator.execute(MagicMock()) lineage = operator.get_openlineage_facets_on_complete(None) - assert lineage.run_facets["bigQuery_error"] == BigQueryErrorRunFacet(clientError=mock.ANY) + assert isinstance(lineage.run_facets["errorMessage"], ErrorMessageRunFacet) @pytest.mark.db_test @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") diff --git a/tests/providers/google/cloud/operators/job_details.json b/tests/providers/google/cloud/utils/job_details.json similarity index 100% rename from tests/providers/google/cloud/operators/job_details.json rename to tests/providers/google/cloud/utils/job_details.json diff --git a/tests/providers/google/cloud/utils/out_table_details.json b/tests/providers/google/cloud/utils/out_table_details.json new file mode 100644 index 00000000000000..c162d6f84d3e5c --- /dev/null +++ b/tests/providers/google/cloud/utils/out_table_details.json @@ -0,0 +1,30 @@ +{ + "kind": "bigquery#table", + "etag": "/xPW+XqH9yYTFhR/MQSm+Q==", + "id": "bq-airflow-openlineage:new_dataset.output_table_3", + "selfLink": "https://bigquery.googleapis.com/bigquery/v2/projects/bq-airflow-openlineage/datasets/new_dataset/tables/output_table_3", + "tableReference": { + "projectId": "bq-airflow-openlineage", + "datasetId": "new_dataset", + "tableId": "output_table" + }, + "schema": { + "fields": [{ + "name": "name", + "type": "STRING", + "mode": "NULLABLE" + }, { + "name": "total_people", + "type": "INTEGER", + "mode": "NULLABLE" + }] + }, + "numBytes": "321", + "numLongTermBytes": "0", + "numRows": "20", + "creationTime": "1604622520179", + "expirationTime": "1609806520179", + "lastModifiedTime": "1607406385965", + "type": "TABLE", + "location": "US" +} diff --git a/tests/providers/google/cloud/utils/script_job_details.json b/tests/providers/google/cloud/utils/script_job_details.json new file mode 100644 index 00000000000000..5f9ed2cf9b1c15 --- /dev/null +++ b/tests/providers/google/cloud/utils/script_job_details.json @@ -0,0 +1,36 @@ +{ + "kind": "bigquery#job", + "etag": "123", + "id": "project-id:EU.bquxjob_2894c210_18e85d7a86e", + "selfLink": "https://bigquery.googleapis.com/bigquery/v2/projects/project-id/jobs/bquxjob_2894c210_18e85d7a86e?location=EU", + "configuration": { + "query": { + "query": "DECLARE\n start_ts TIMESTAMP;\nSET\n start_ts = TIMESTAMP(\"2020-01-04 12:00:00 UTC\"); \nCREATE OR REPLACE TABLE ...", + "priority": "INTERACTIVE", + "allowLargeResults": false + }, + "jobType": "QUERY" + }, + "jobReference": { + "projectId": "project-id", + "jobId": "bquxjob_2894c210_18e85d7a86e", + "location": "EU" + }, + "statistics": { + "creationTime": 1711642487190.0, + "startTime": 1711642487224.0, + "endTime": 1711642490618.0, + "totalBytesProcessed": "119672332", + "query": { + "totalBytesProcessed": "119672332", + "totalBytesBilled": "120586240", + "totalSlotMs": "31441", + "statementType": "SCRIPT" + }, + "totalSlotMs": "31441", + "numChildJobs": "1" + }, + "status": { + "state": "DONE" + } +} diff --git a/tests/providers/google/cloud/utils/table_details.json b/tests/providers/google/cloud/utils/table_details.json new file mode 100644 index 00000000000000..7904c913b48b46 --- /dev/null +++ b/tests/providers/google/cloud/utils/table_details.json @@ -0,0 +1,53 @@ +{ + "kind": "bigquery#table", + "etag": "L7YwMYGtkofoiqqF8tXSDA==", 
+ "id": "bigquery-public-data:usa_names.usa_1910_2013", + "selfLink": "https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/usa_names/tables/usa_1910_2013", + "tableReference": { + "projectId": "bigquery-public-data", + "datasetId": "usa_names", + "tableId": "usa_1910_2013" + }, + "description": "The table contains the number of applicants for a Social Security card by year of birth and sex. The number of such applicants is restricted to U.S. births where the year of birth, sex, State of birth (50 States and District of Columbia) are known, and where the given name is at least 2 characters long.\n\nsource: http://www.ssa.gov/OACT/babynames/limits.html", + "schema": { + "fields": [ + { + "name": "state", + "type": "STRING", + "mode": "NULLABLE", + "description": "2-digit state code" + }, + { + "name": "gender", + "type": "STRING", + "mode": "NULLABLE", + "description": "Sex (M=male or F=female)" + }, + { + "name": "year", + "type": "INTEGER", + "mode": "NULLABLE", + "description": "4-digit year of birth" + }, + { + "name": "name", + "type": "STRING", + "mode": "NULLABLE", + "description": "Given name of a person at birth" + }, + { + "name": "number", + "type": "INTEGER", + "mode": "NULLABLE", + "description": "Number of occurrences of the name" + } + ] + }, + "numBytes": "171432506", + "numLongTermBytes": "171432506", + "numRows": "5552452", + "creationTime": "1457744542425", + "lastModifiedTime": "1457746213452", + "type": "TABLE", + "location": "US" +} diff --git a/tests/providers/google/cloud/utils/test_openlineage.py b/tests/providers/google/cloud/utils/test_openlineage.py index 608007fa4ba0d3..d2fad84b22cf22 100644 --- a/tests/providers/google/cloud/utils/test_openlineage.py +++ b/tests/providers/google/cloud/utils/test_openlineage.py @@ -16,6 +16,9 @@ # under the License. 
from __future__ import annotations +import json +from unittest.mock import MagicMock + import pytest from google.cloud.bigquery.table import Table from openlineage.client.facet import ( @@ -23,12 +26,19 @@ ColumnLineageDatasetFacetFieldsAdditional, ColumnLineageDatasetFacetFieldsAdditionalInputFields, DocumentationDatasetFacet, + ExternalQueryRunFacet, + OutputStatisticsOutputDatasetFacet, SchemaDatasetFacet, SchemaField, ) from openlineage.client.run import Dataset -from airflow.providers.google.cloud.utils import openlineage +from airflow.providers.google.cloud.utils.openlineage import ( + BigQueryJobRunFacet, + _BigQueryOpenLineageMixin, + get_facets_from_bq_table, + get_identity_column_lineage_facet, +) TEST_DATASET = "test-dataset" TEST_TABLE_ID = "test-table-id" @@ -50,6 +60,227 @@ TEST_EMPTY_TABLE: Table = Table.from_api_repr(TEST_EMPTY_TABLE_API_REPR) +def read_file_json(file): + with open(file=file) as f: + return json.loads(f.read()) + + +class TableMock(MagicMock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inputs = [ + read_file_json("tests/providers/google/cloud/utils/table_details.json"), + read_file_json("tests/providers/google/cloud/utils/out_table_details.json"), + ] + + @property + def _properties(self): + return self.inputs.pop() + + +class TestBigQueryOpenLineageMixin: + def setup_method(self): + self.job_details = read_file_json("tests/providers/google/cloud/utils/job_details.json") + self.script_job_details = read_file_json("tests/providers/google/cloud/utils/script_job_details.json") + hook = MagicMock() + self.client = MagicMock() + + class BQOperator(_BigQueryOpenLineageMixin): + sql = "" + job_id = "job_id" + + @property + def hook(self): + return hook + + hook.get_client.return_value = self.client + + self.client.get_table.return_value = TableMock() + + self.operator = BQOperator() + + def test_bq_job_information(self): + self.client.get_job.return_value._properties = self.job_details + + lineage = self.operator.get_openlineage_facets_on_complete(None) + + self.job_details["configuration"]["query"].pop("query") + assert lineage.run_facets == { + "bigQuery_job": BigQueryJobRunFacet( + cached=False, billedBytes=111149056, properties=json.dumps(self.job_details) + ), + "bigQueryJob": BigQueryJobRunFacet( + cached=False, billedBytes=111149056, properties=json.dumps(self.job_details) + ), + "externalQuery": ExternalQueryRunFacet(externalQueryId="job_id", source="bigquery"), + } + assert lineage.inputs == [ + Dataset( + namespace="bigquery", + name="airflow-openlineage.new_dataset.test_table", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaField("state", "STRING", "2-digit state code"), + SchemaField("gender", "STRING", "Sex (M=male or F=female)"), + SchemaField("year", "INTEGER", "4-digit year of birth"), + SchemaField("name", "STRING", "Given name of a person at birth"), + SchemaField("number", "INTEGER", "Number of occurrences of the name"), + ] + ) + }, + ) + ] + assert lineage.outputs == [ + Dataset( + namespace="bigquery", + name="airflow-openlineage.new_dataset.output_table", + facets={ + "outputStatistics": OutputStatisticsOutputDatasetFacet( + rowCount=20, size=321, fileCount=None + ) + }, + ), + ] + + def test_bq_script_job_information(self): + self.client.get_job.side_effect = [ + MagicMock(_properties=self.script_job_details), + MagicMock(_properties=self.job_details), + ] + self.client.list_jobs.return_value = ["child_job_id"] + + lineage = self.operator.get_openlineage_facets_on_complete(None) + + 
self.script_job_details["configuration"]["query"].pop("query") + assert lineage.run_facets == { + "bigQueryJob": BigQueryJobRunFacet( + cached=False, billedBytes=120586240, properties=json.dumps(self.script_job_details) + ), + "bigQuery_job": BigQueryJobRunFacet( + cached=False, billedBytes=120586240, properties=json.dumps(self.script_job_details) + ), + "externalQuery": ExternalQueryRunFacet(externalQueryId="job_id", source="bigquery"), + } + assert lineage.inputs == [ + Dataset( + namespace="bigquery", + name="airflow-openlineage.new_dataset.test_table", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaField("state", "STRING", "2-digit state code"), + SchemaField("gender", "STRING", "Sex (M=male or F=female)"), + SchemaField("year", "INTEGER", "4-digit year of birth"), + SchemaField("name", "STRING", "Given name of a person at birth"), + SchemaField("number", "INTEGER", "Number of occurrences of the name"), + ] + ) + }, + ) + ] + assert lineage.outputs == [ + Dataset( + namespace="bigquery", + name="airflow-openlineage.new_dataset.output_table", + facets={ + "outputStatistics": OutputStatisticsOutputDatasetFacet( + rowCount=20, size=321, fileCount=None + ) + }, + ), + ] + + def test_deduplicate_outputs(self): + outputs = [ + None, + Dataset( + name="d1", namespace="", facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(3, 4)} + ), + Dataset( + name="d1", + namespace="", + facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(3, 4), "t1": "t1"}, + ), + Dataset( + name="d2", + namespace="", + facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(6, 7), "t2": "t2"}, + ), + Dataset( + name="d2", + namespace="", + facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(60, 70), "t20": "t20"}, + ), + ] + result = self.operator._deduplicate_outputs(outputs) + assert len(result) == 2 + first_result = result[0] + assert first_result.name == "d1" + assert first_result.facets == {"t1": "t1"} + second_result = result[1] + assert second_result.name == "d2" + assert second_result.facets == {"t20": "t20"} + + @pytest.mark.parametrize("cache", (None, "false", False, 0)) + def test_get_job_run_facet_no_cache_and_with_bytes(self, cache): + properties = { + "statistics": {"query": {"cacheHit": cache, "totalBytesBilled": 10}}, + "configuration": {"query": {"query": "SELECT ..."}}, + } + result = self.operator._get_bigquery_job_run_facet(properties) + assert result.cached is False + assert result.billedBytes == 10 + properties["configuration"]["query"].pop("query") + assert result.properties == json.dumps(properties) + + @pytest.mark.parametrize("cache", ("true", True)) + def test_get_job_run_facet_with_cache_and_no_bytes(self, cache): + properties = { + "statistics": { + "query": { + "cacheHit": cache, + } + }, + "configuration": {"query": {"query": "SELECT ..."}}, + } + result = self.operator._get_bigquery_job_run_facet(properties) + assert result.cached is True + assert result.billedBytes is None + properties["configuration"]["query"].pop("query") + assert result.properties == json.dumps(properties) + + def test_get_statistics_dataset_facet_no_query_plan(self): + properties = { + "statistics": {"query": {"totalBytesBilled": 10}}, + "configuration": {"query": {"query": "SELECT ..."}}, + } + result = self.operator._get_statistics_dataset_facet(properties) + assert result is None + + def test_get_statistics_dataset_facet_no_stats(self): + properties = { + "statistics": {"query": {"totalBytesBilled": 10, "queryPlan": [{"test": "test"}]}}, + "configuration": 
{"query": {"query": "SELECT ..."}}, + } + result = self.operator._get_statistics_dataset_facet(properties) + assert result is None + + def test_get_statistics_dataset_facet_with_stats(self): + properties = { + "statistics": { + "query": { + "totalBytesBilled": 10, + "queryPlan": [{"recordsWritten": 123, "shuffleOutputBytes": "321"}], + } + }, + "configuration": {"query": {"query": "SELECT ..."}}, + } + result = self.operator._get_statistics_dataset_facet(properties) + assert result.rowCount == 123 + assert result.size == 321 + + def test_get_facets_from_bq_table(): expected_facets = { "schema": SchemaDatasetFacet( @@ -60,7 +291,7 @@ def test_get_facets_from_bq_table(): ), "documentation": DocumentationDatasetFacet(description="Table description."), } - result = openlineage.get_facets_from_bq_table(TEST_TABLE) + result = get_facets_from_bq_table(TEST_TABLE) assert result == expected_facets @@ -69,7 +300,7 @@ def test_get_facets_from_empty_bq_table(): "schema": SchemaDatasetFacet(fields=[]), "documentation": DocumentationDatasetFacet(description=""), } - result = openlineage.get_facets_from_bq_table(TEST_EMPTY_TABLE) + result = get_facets_from_bq_table(TEST_EMPTY_TABLE) assert result == expected_facets @@ -115,9 +346,7 @@ def test_get_identity_column_lineage_facet_multiple_input_datasets(): ), } ) - result = openlineage.get_identity_column_lineage_facet( - field_names=field_names, input_datasets=input_datasets - ) + result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) assert result == expected_facet @@ -128,9 +357,7 @@ def test_get_identity_column_lineage_facet_no_field_names(): Dataset(namespace="gs://second_bucket", name="dir2"), ] expected_facet = ColumnLineageDatasetFacet(fields={}) - result = openlineage.get_identity_column_lineage_facet( - field_names=field_names, input_datasets=input_datasets - ) + result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) assert result == expected_facet @@ -139,4 +366,4 @@ def test_get_identity_column_lineage_facet_no_input_datasets(): input_datasets = [] with pytest.raises(ValueError): - openlineage.get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) + get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) diff --git a/tests/providers/openlineage/utils/test_sql.py b/tests/providers/openlineage/utils/test_sql.py index 180defbeec4385..f094fdaf1f7c5e 100644 --- a/tests/providers/openlineage/utils/test_sql.py +++ b/tests/providers/openlineage/utils/test_sql.py @@ -21,7 +21,6 @@ import pytest from openlineage.client.facet import SchemaDatasetFacet, SchemaField, set_producer from openlineage.client.run import Dataset -from openlineage.common.models import DbColumn, DbTableSchema from openlineage.common.sql import DbTableMeta from sqlalchemy import Column, MetaData, Table @@ -38,16 +37,6 @@ DB_NAME = "FOOD_DELIVERY" DB_SCHEMA_NAME = "PUBLIC" DB_TABLE_NAME = DbTableMeta("DISCOUNTS") -DB_TABLE_COLUMNS = [ - DbColumn(name="ID", type="int4", ordinal_position=1), - DbColumn(name="AMOUNT_OFF", type="int4", ordinal_position=2), - DbColumn(name="CUSTOMER_EMAIL", type="varchar", ordinal_position=3), - DbColumn(name="STARTS_ON", type="timestamp", ordinal_position=4), - DbColumn(name="ENDS_ON", type="timestamp", ordinal_position=5), -] -DB_TABLE_SCHEMA = DbTableSchema( - schema_name=DB_SCHEMA_NAME, table_name=DB_TABLE_NAME, columns=DB_TABLE_COLUMNS -) SCHEMA_FACET = SchemaDatasetFacet( fields=[ diff --git 
a/tests/providers/snowflake/utils/test_openlineage.py b/tests/providers/snowflake/utils/test_openlineage.py
new file mode 100644
index 00000000000000..a85ed9c2afaf0e
--- /dev/null
+++ b/tests/providers/snowflake/utils/test_openlineage.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
+
+
+@pytest.mark.parametrize(
+    "source,target",
+    [
+        (
+            "snowflake://user:pass@xy123456.us-east-1.aws/database/schema",
+            "snowflake://xy123456.us-east-1.aws/database/schema",
+        ),
+        (
+            "snowflake://xy123456/database/schema",
+            "snowflake://xy123456.us-west-1.aws/database/schema",
+        ),
+        (
+            "snowflake://xy12345.ap-southeast-1/database/schema",
+            "snowflake://xy12345.ap-southeast-1.aws/database/schema",
+        ),
+        (
+            "snowflake://user:pass@xy12345.south-central-us.azure/database/schema",
+            "snowflake://xy12345.south-central-us.azure/database/schema",
+        ),
+        (
+            "snowflake://user:pass@xy12345.us-east4.gcp/database/schema",
+            "snowflake://xy12345.us-east4.gcp/database/schema",
+        ),
+        (
+            "snowflake://user:pass@organization-account/database/schema",
+            "snowflake://organization-account/database/schema",
+        ),
+        (
+            "snowflake://user:p[ass@organization-account/database/schema",
+            "snowflake://organization-account/database/schema",
+        ),
+        (
+            "snowflake://user:pass@organization]-account/database/schema",
+            "snowflake://organization%5D-account/database/schema",
+        ),
+    ],
+)
+def test_snowflake_sqlalchemy_account_urls(source, target):
+    assert fix_snowflake_sqlalchemy_uri(source) == target
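
Reviewer note (not part of the patch): a minimal sketch of how the relocated get_from_nullable_chain helper behaves on job metadata. The props payload below is invented for illustration; only the helper and its module path come from this change.

from airflow.providers.google.cloud.utils.openlineage import get_from_nullable_chain

# Invented job-properties payload mirroring the shape of BigQuery's job API response.
props = {
    "status": {"state": "DONE"},
    "statistics": {"query": {"cacheHit": False, "totalBytesBilled": "120586240"}},
}

# Walks the nested dicts and returns the leaf value.
assert get_from_nullable_chain(props, ["statistics", "query", "totalBytesBilled"]) == "120586240"

# A missing key anywhere along the chain yields None instead of raising.
assert get_from_nullable_chain(props, ["statistics", "load", "outputRows"]) is None

# Note: the helper consumes the list passed as `chain` (it reverses and pops it),
# so pass a fresh list on every call.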