From cf6324e43b2f7c183c3872704733b69d1498cda1 Mon Sep 17 00:00:00 2001
From: Jens Larsson
Date: Tue, 4 May 2021 18:21:08 +0200
Subject: [PATCH] Implement BigQuery Table Schema Update Operator (#15367)

Co-authored-by: Jens Larsson
---
 .../example_bigquery_operations.py            |  14 ++
 .../providers/google/cloud/hooks/bigquery.py  |  95 ++++++++++
 .../google/cloud/operators/bigquery.py        | 112 ++++++++++++
 .../operators/cloud/bigquery.rst              |  17 ++
 .../google/cloud/hooks/test_bigquery.py       | 172 ++++++++++++++++++
 .../google/cloud/operators/test_bigquery.py   |  31 ++++
 6 files changed, 441 insertions(+)

diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py b/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py
index b487efb610222..a72ff36f58c5f 100644
--- a/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py
+++ b/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py
@@ -36,6 +36,7 @@
     BigQueryPatchDatasetOperator,
     BigQueryUpdateDatasetOperator,
     BigQueryUpdateTableOperator,
+    BigQueryUpdateTableSchemaOperator,
     BigQueryUpsertTableOperator,
 )
 from airflow.utils.dates import days_ago
@@ -73,6 +74,18 @@
     )
     # [END howto_operator_bigquery_create_table]
 
+    # [START howto_operator_bigquery_update_table_schema]
+    update_table_schema = BigQueryUpdateTableSchemaOperator(
+        task_id="update_table_schema",
+        dataset_id=DATASET_NAME,
+        table_id="test_table",
+        schema_fields_updates=[
+            {"name": "emp_name", "description": "Name of employee"},
+            {"name": "salary", "description": "Monthly salary in USD"},
+        ],
+    )
+    # [END howto_operator_bigquery_update_table_schema]
+
     # [START howto_operator_bigquery_delete_table]
     delete_table = BigQueryDeleteTableOperator(
         task_id="delete_table",
@@ -216,6 +229,7 @@
         delete_view,
     ]
     >> upsert_table
+    >> update_table_schema
     >> delete_materialized_view
     >> delete_table
     >> delete_dataset
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 044c014bf6377..e36baf548c9fd 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -1373,6 +1373,101 @@ def get_schema(self, dataset_id: str, table_id: str, project_id: Optional[str] =
         table = self.get_client(project_id=project_id).get_table(table_ref)
         return {"fields": [s.to_api_repr() for s in table.schema]}
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def update_table_schema(
+        self,
+        schema_fields_updates: List[Dict[str, Any]],
+        include_policy_tags: bool,
+        dataset_id: str,
+        table_id: str,
+        project_id: Optional[str] = None,
+    ) -> None:
+        """
+        Update fields within a schema for a given dataset and table. Note that
+        some fields in schemas are immutable; trying to change them will raise
+        an exception.
+        If a new field is included, it will be inserted, which requires all
+        required fields to be set.
+        See https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#TableSchema
+
+        :param include_policy_tags: If set to True, policy tags will be included in
+            the update request, which requires special permissions even if unchanged;
+            see https://cloud.google.com/bigquery/docs/column-level-security#roles
+        :type include_policy_tags: bool
+        :param dataset_id: the dataset ID of the requested table to be updated
+        :type dataset_id: str
+        :param table_id: the table ID of the table to be updated
+        :type table_id: str
+        :param schema_fields_updates: a partial schema resource; see
+            https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#TableSchema
+
+        **Example**: ::
+
+            schema_fields_updates=[
+                {"name": "emp_name", "description": "Some New Description"},
+                {"name": "salary", "description": "Some New Description"},
+                {"name": "departments", "fields": [
+                    {"name": "name", "description": "Some New Description"},
+                    {"name": "type", "description": "Some New Description"}
+                ]},
+            ]
+
+        :type schema_fields_updates: List[dict]
+        :param project_id: The name of the project where we want to update the table.
+        :type project_id: str
+        """
+
+        def _build_new_schema(
+            current_schema: List[Dict[str, Any]], schema_fields_updates: List[Dict[str, Any]]
+        ) -> List[Dict[str, Any]]:
+
+            # Turn schema_fields_updates into a dict keyed on field names
+            schema_fields_updates = {field["name"]: field for field in deepcopy(schema_fields_updates)}
+
+            # Create a new dict for storing the new schema, initiated based on the current_schema;
+            # as of Python 3.6, dicts retain insertion order.
+            new_schema = {field["name"]: field for field in deepcopy(current_schema)}
+
+            # Each item in schema_fields_updates contains a potential patch
+            # to a schema field; iterate over them
+            for field_name, patched_value in schema_fields_updates.items():
+                # If this field already exists, update it
+                if field_name in new_schema:
+                    # If this field is of type RECORD and has a fields key, we need to patch it recursively
+                    if "fields" in patched_value:
+                        patched_value["fields"] = _build_new_schema(
+                            new_schema[field_name]["fields"], patched_value["fields"]
+                        )
+                    # Update the new_schema with the patched value
+                    new_schema[field_name].update(patched_value)
+                # This is a new field, so include the whole configuration for it
+                else:
+                    new_schema[field_name] = patched_value
+
+            return list(new_schema.values())
+
+        def _remove_policy_tags(schema: List[Dict[str, Any]]):
+            for field in schema:
+                if "policyTags" in field:
+                    del field["policyTags"]
+                if "fields" in field:
+                    _remove_policy_tags(field["fields"])
+
+        current_table_schema = self.get_schema(
+            dataset_id=dataset_id, table_id=table_id, project_id=project_id
+        )["fields"]
+        new_schema = _build_new_schema(current_table_schema, schema_fields_updates)
+
+        if not include_policy_tags:
+            _remove_policy_tags(new_schema)
+
+        self.update_table(
+            table_resource={"schema": {"fields": new_schema}},
+            fields=["schema"],
+            project_id=project_id,
+            dataset_id=dataset_id,
+            table_id=table_id,
+        )
+
     @GoogleBaseHook.fallback_to_default_project_id
     def poll_job_complete(
         self,
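The merge performed by ``_build_new_schema`` above is a keyed, recursive dict update: existing fields are patched in place, ``RECORD`` subfields are merged recursively, and unknown names are appended as new fields. A minimal standalone sketch of that behaviour (field names and values mirror the example DAG and are illustrative only)::

    from copy import deepcopy
    from typing import Any, Dict, List

    def build_new_schema(
        current: List[Dict[str, Any]], updates: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        # Key both lists on "name"; dicts retain insertion order in Python 3.6+.
        patches = {f["name"]: f for f in deepcopy(updates)}
        merged = {f["name"]: f for f in deepcopy(current)}
        for name, patch in patches.items():
            if name in merged:
                if "fields" in patch:
                    # RECORD subfields are merged recursively.
                    patch["fields"] = build_new_schema(merged[name]["fields"], patch["fields"])
                merged[name].update(patch)
            else:
                # Unknown field names are inserted as new fields.
                merged[name] = patch
        return list(merged.values())

    current = [
        {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
        {"name": "salary", "type": "INTEGER", "mode": "REQUIRED"},
    ]
    updates = [{"name": "salary", "description": "Monthly salary in USD"}]
    assert build_new_schema(current, updates)[1]["description"] == "Monthly salary in USD"
    assert build_new_schema(current, updates)[0] == current[0]  # untouched fields survive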
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 67e38f4f089fb..1c2cbedcb423f 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -2039,6 +2039,118 @@ def execute(self, context) -> None:
         )
 
 
+class BigQueryUpdateTableSchemaOperator(BaseOperator):
+    """
+    Update BigQuery Table Schema.
+    Updates fields on a table schema based on the contents of the supplied schema_fields_updates
+    parameter. The supplied schema does not need to be complete: if a field
+    already exists in the schema, you only need to supply keys and values for the
+    items you want to patch; just ensure the "name" key is set.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryUpdateTableSchemaOperator`
+
+    :param schema_fields_updates: a partial schema resource; see
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#TableSchema
+
+    **Example**: ::
+
+        schema_fields_updates=[
+            {"name": "emp_name", "description": "Some New Description"},
+            {"name": "salary", "policyTags": {"names": ["some_new_policy_tag"]}},
+            {"name": "departments", "fields": [
+                {"name": "name", "description": "Some New Description"},
+                {"name": "type", "description": "Some New Description"}
+            ]},
+        ]
+
+    :type schema_fields_updates: List[dict]
+    :param include_policy_tags: (Optional) If set to True, policy tags will be included in
+        the update request, which requires special permissions even if unchanged (default False);
+        see https://cloud.google.com/bigquery/docs/column-level-security#roles
+    :type include_policy_tags: bool
+    :param dataset_id: A dotted ``(<project>.|<project>:)<dataset>`` that indicates which dataset
+        will be updated. (templated)
+    :type dataset_id: str
+    :param table_id: The table ID of the requested table. (templated)
+    :type table_id: str
+    :param project_id: The name of the project where we want to update the table.
+        If not provided, the default project from the connection is used.
+    :type project_id: str
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have domain-wide
+        delegation enabled.
+    :type delegate_to: str
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with the first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    """
+
+    template_fields = (
+        'schema_fields_updates',
+        'dataset_id',
+        'table_id',
+        'project_id',
+        'impersonation_chain',
+    )
+    template_fields_renderers = {"schema_fields_updates": "json"}
+    ui_color = BigQueryUIColors.TABLE.value
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        schema_fields_updates: List[Dict[str, Any]],
+        include_policy_tags: Optional[bool] = False,
+        dataset_id: Optional[str] = None,
+        table_id: Optional[str] = None,
+        project_id: Optional[str] = None,
+        gcp_conn_id: str = 'google_cloud_default',
+        delegate_to: Optional[str] = None,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ) -> None:
+        self.schema_fields_updates = schema_fields_updates
+        self.include_policy_tags = include_policy_tags
+        self.table_id = table_id
+        self.dataset_id = dataset_id
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.impersonation_chain = impersonation_chain
+        super().__init__(**kwargs)
+
+    def execute(self, context):
+        bq_hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        return bq_hook.update_table_schema(
+            schema_fields_updates=self.schema_fields_updates,
+            include_policy_tags=self.include_policy_tags,
+            dataset_id=self.dataset_id,
+            table_id=self.table_id,
+            project_id=self.project_id,
+        )
+
+
 # pylint: disable=too-many-arguments
 class BigQueryInsertJobOperator(BaseOperator):
     """
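Note the interaction between ``schema_fields_updates`` and ``include_policy_tags``: the hook strips ``policyTags`` from every field in the request unless ``include_policy_tags=True``, so a patch that changes policy tags must opt in explicitly. A sketch of that case (the task ID, dataset, table, and tag name are illustrative, reusing the docstring's example value)::

    update_policy_tag = BigQueryUpdateTableSchemaOperator(
        task_id="update_policy_tag",
        dataset_id="my_dataset",  # illustrative
        table_id="my_table",      # illustrative
        # Without include_policy_tags=True, the hook would drop policyTags
        # from the request and this update would have no effect on the tags.
        include_policy_tags=True,
        schema_fields_updates=[
            {"name": "salary", "policyTags": {"names": ["some_new_policy_tag"]}},
        ],
    )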
diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
index 7fc27804db38d..b99971a193065 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
@@ -245,6 +245,23 @@ in the given dataset.
     :start-after: [START howto_operator_bigquery_upsert_table]
     :end-before: [END howto_operator_bigquery_upsert_table]
 
+.. _howto/operator:BigQueryUpdateTableSchemaOperator:
+
+Update table schema
+"""""""""""""""""""
+
+To update the schema of a table you can use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryUpdateTableSchemaOperator`.
+
+This operator updates only the schema field values supplied, leaving the rest of the schema
+unchanged. It is useful, for instance, for setting new field descriptions on an existing table.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_bigquery_operations.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_bigquery_update_table_schema]
+    :end-before: [END howto_operator_bigquery_update_table_schema]
+
 .. _howto/operator:BigQueryDeleteTableOperator:
 
 Delete table
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index f7b09be9cfeb1..79057419b0cf6 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -676,6 +676,178 @@ def test_get_schema(self, mock_client):
         assert "fields" in result
         assert len(result["fields"]) == 2
 
+    @mock.patch('airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_schema')
+    @mock.patch('airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_table')
+    def test_update_table_schema_with_policy_tags(self, mock_update, mock_get_schema):
+        mock_get_schema.return_value = {
+            "fields": [
+                {'name': 'emp_name', 'type': 'STRING', 'mode': 'REQUIRED'},
+                {
+                    'name': 'salary',
+                    'type': 'INTEGER',
+                    'mode': 'REQUIRED',
+                    'policyTags': {'names': ['sensitive']},
+                },
+                {'name': 'not_changed', 'type': 'INTEGER', 'mode': 'REQUIRED'},
+                {
+                    'name': 'subrecord',
+                    'type': 'RECORD',
+                    'mode': 'REQUIRED',
+                    'fields': [
+                        {
+                            'name': 'field_1',
+                            'type': 'STRING',
+                            'mode': 'REQUIRED',
+                            'policyTags': {'names': ['sensitive']},
+                        },
+                    ],
+                },
+            ]
+        }
+
+        schema_fields_updates = [
+            {'name': 'emp_name', 'description': 'Name of employee', 'policyTags': {'names': ['sensitive']}},
+            {
+                'name': 'salary',
+                'description': 'Monthly salary in USD',
+                'policyTags': {},
+            },
+            {
+                'name': 'subrecord',
+                'description': 'Some Desc',
+                'fields': [
+                    {'name': 'field_1', 'description': 'Some nested desc'},
+                ],
+            },
+        ]
+
+        expected_result_schema = {
+            'fields': [
+                {
+                    'name': 'emp_name',
+                    'type': 'STRING',
+                    'mode': 'REQUIRED',
+                    'description': 'Name of employee',
+                    'policyTags': {'names': ['sensitive']},
+                },
+                {
+                    'name': 'salary',
+                    'type': 'INTEGER',
+                    'mode': 'REQUIRED',
+                    'description': 'Monthly salary in USD',
+                    'policyTags': {},
+                },
+                {'name': 'not_changed', 'type': 'INTEGER', 'mode': 'REQUIRED'},
+                {
+                    'name': 'subrecord',
+                    'type': 'RECORD',
+                    'mode': 'REQUIRED',
+                    'description': 'Some Desc',
+                    'fields': [
+                        {
+                            'name': 'field_1',
+                            'type': 'STRING',
+                            'mode': 'REQUIRED',
+                            'description': 'Some nested desc',
+                            'policyTags': {'names': ['sensitive']},
+                        }
+                    ],
+                },
+            ]
+        }
+
+        self.hook.update_table_schema(
+            schema_fields_updates=schema_fields_updates,
+            include_policy_tags=True,
+            dataset_id=DATASET_ID,
+            table_id=TABLE_ID,
+        )
+
+        mock_update.assert_called_once_with(
+            dataset_id=DATASET_ID,
+            table_id=TABLE_ID,
+            project_id=PROJECT_ID,
+            table_resource={'schema': expected_result_schema},
+            fields=['schema'],
+        )
+
+    @mock.patch('airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_schema')
+    @mock.patch('airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_table')
+    def test_update_table_schema_without_policy_tags(self, mock_update, mock_get_schema):
+        mock_get_schema.return_value = {
+            "fields": [
+                {'name': 'emp_name', 'type': 'STRING', 'mode': 'REQUIRED'},
+                {'name': 'salary', 'type': 'INTEGER', 'mode': 'REQUIRED'},
+                {'name': 'not_changed', 'type': 'INTEGER', 'mode': 'REQUIRED'},
+                {
+                    'name': 'subrecord',
+                    'type': 'RECORD',
+                    'mode': 'REQUIRED',
+                    'fields': [
+                        {'name': 'field_1', 'type': 'STRING', 'mode': 'REQUIRED'},
+                    ],
+                },
+            ]
+        }
+
+        schema_fields_updates = [
+            {'name': 'emp_name', 'description': 'Name of employee'},
+            {
+                'name': 'salary',
+                'description': 'Monthly salary in USD',
+                'policyTags': {'names': ['sensitive']},
+            },
+            {
+                'name': 'subrecord',
+                'description': 'Some Desc',
+                'fields': [
+                    {'name': 'field_1', 'description': 'Some nested desc'},
+                ],
+            },
+        ]
+
+        expected_result_schema = {
+            'fields': [
+                {'name': 'emp_name', 'type': 'STRING', 'mode': 'REQUIRED', 'description': 'Name of employee'},
+                {
+                    'name': 'salary',
+                    'type': 'INTEGER',
+                    'mode': 'REQUIRED',
+                    'description': 'Monthly salary in USD',
+                },
+                {'name': 'not_changed', 'type': 'INTEGER', 'mode': 'REQUIRED'},
+                {
+                    'name': 'subrecord',
+                    'type': 'RECORD',
+                    'mode': 'REQUIRED',
+                    'description': 'Some Desc',
+                    'fields': [
+                        {
+                            'name': 'field_1',
+                            'type': 'STRING',
+                            'mode': 'REQUIRED',
+                            'description': 'Some nested desc',
+                        }
+                    ],
+                },
+            ]
+        }
+
+        self.hook.update_table_schema(
+            schema_fields_updates=schema_fields_updates,
+            include_policy_tags=False,
+            dataset_id=DATASET_ID,
+            table_id=TABLE_ID,
+        )
+
+        mock_update.assert_called_once_with(
+            dataset_id=DATASET_ID,
+            table_id=TABLE_ID,
+            project_id=PROJECT_ID,
+            table_resource={'schema': expected_result_schema},
+            fields=['schema'],
+        )
+
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
     def test_invalid_source_format(self, mock_get_service):
         with pytest.raises(
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 61030f659ed11..801034c06ce66 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -46,6 +46,7 @@
     BigQueryPatchDatasetOperator,
     BigQueryUpdateDatasetOperator,
     BigQueryUpdateTableOperator,
+    BigQueryUpdateTableSchemaOperator,
     BigQueryUpsertTableOperator,
     BigQueryValueCheckOperator,
 )
@@ -290,6 +291,36 @@ def test_execute(self, mock_hook):
         )
 
 
+class TestBigQueryUpdateTableSchemaOperator(unittest.TestCase):
+    @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
+    def test_execute(self, mock_hook):
+
+        schema_field_updates = [
+            {
+                'name': 'emp_name',
+                'description': 'Name of employee',
+            }
+        ]
+
+        operator = BigQueryUpdateTableSchemaOperator(
+            schema_fields_updates=schema_field_updates,
+            include_policy_tags=False,
+            task_id=TASK_ID,
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            project_id=TEST_GCP_PROJECT_ID,
+        )
+        operator.execute(None)
+
+        mock_hook.return_value.update_table_schema.assert_called_once_with(
+            schema_fields_updates=schema_field_updates,
+            include_policy_tags=False,
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            project_id=TEST_GCP_PROJECT_ID,
+        )
+
+
 class TestBigQueryPatchDatasetOperator(unittest.TestCase):
     @mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
     def test_execute(self, mock_hook):
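For completeness, the new hook method can also be exercised directly and the result inspected with the existing ``get_schema`` shown in the hook diff above; a minimal sketch, assuming illustrative connection, dataset, and table identifiers::

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    hook = BigQueryHook(gcp_conn_id="google_cloud_default")
    hook.update_table_schema(
        schema_fields_updates=[{"name": "salary", "description": "Monthly salary in USD"}],
        include_policy_tags=False,
        dataset_id="my_dataset",  # illustrative
        table_id="my_table",      # illustrative
    )
    # get_schema returns {"fields": [...]}; the patched description should now be present.
    print(hook.get_schema(dataset_id="my_dataset", table_id="my_table"))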