[AIRFLOW-2704] Add support for labels in the bigquery_operator

Adds support for bigquery labels in the bigquery
operator and hook.

Makes labels a template field on the affected operators.

Closes #3573 from mastoj/AIRFLOW-2704
Tomas Jansson authored and kaxil committed Sep 15, 2018
1 parent 65bdc78 commit b7f5a3d
Showing 6 changed files with 105 additions and 17 deletions.
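
For orientation, a minimal DAG-level sketch of the new argument (not part of the commit; the task id, label values, and the dag object are hypothetical):

from airflow.contrib.operators.bigquery_operator import BigQueryOperator

# The labels dict is forwarded into the BigQuery job's
# configuration['labels'], tagging the query job for cost attribution.
labeled_query = BigQueryOperator(
    task_id='labeled_query',
    sql='SELECT 1',
    labels={'team': 'data', 'env': 'dev'},
    bigquery_conn_id='bigquery_default',
    dag=dag,  # assumes a DAG defined elsewhere
)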
41 changes: 37 additions & 4 deletions airflow/contrib/hooks/bigquery_hook.py
@@ -206,7 +206,8 @@ def create_empty_table(self,
dataset_id,
table_id,
schema_fields=None,
time_partitioning={}
time_partitioning={},
labels=None
):
"""
Creates a new, empty table in the dataset.
@@ -219,6 +220,8 @@ def create_empty_table(self,
:type table_id: str
:param schema_fields: If set, the schema field list as defined here:
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema
:param labels: a dictionary containing labels for the table, passed to BigQuery
:type labels: dict
**Example**: ::
@@ -249,6 +252,9 @@ def create_empty_table(self,
if time_partitioning:
table_resource['timePartitioning'] = time_partitioning

if labels:
table_resource['labels'] = labels

self.log.info('Creating Table %s:%s.%s',
project_id, dataset_id, table_id)

@@ -280,7 +286,8 @@ def create_external_table(self,
quote_character=None,
allow_quoted_newlines=False,
allow_jagged_rows=False,
src_fmt_configs={}
src_fmt_configs={},
labels=None
):
"""
Creates a new external table in the dataset with the data in Google
@@ -341,6 +348,8 @@ def create_external_table(self,
:type allow_jagged_rows: bool
:param src_fmt_configs: configure optional fields specific to the source format
:type src_fmt_configs: dict
:param labels: a dictionary containing labels for the table, passed to BigQuery
:type labels: dict
"""

project_id, dataset_id, external_table_id = \
@@ -439,6 +448,9 @@ def create_external_table(self,
table_resource['externalDataConfiguration'][src_fmt_to_param_mapping[
source_format]] = src_fmt_configs

if labels:
table_resource['labels'] = labels

try:
self.service.tables().insert(
projectId=project_id,
@@ -467,6 +479,7 @@ def run_query(self,
maximum_bytes_billed=None,
create_disposition='CREATE_IF_NEEDED',
query_params=None,
labels=None,
schema_update_options=(),
priority='INTERACTIVE',
time_partitioning={}):
@@ -516,6 +529,9 @@ def run_query(self,
:param query_params: a dictionary containing query parameter types and
values, passed to BigQuery
:type query_params: dict
:param labels: a dictionary containing labels for the job/query,
passed to BigQuery
:type labels: dict
:param schema_update_options: Allows the schema of the destination
table to be updated as a side effect of the query job.
:type schema_update_options: tuple
@@ -606,6 +622,9 @@ def run_query(self,
else:
configuration['query']['queryParameters'] = query_params

if labels:
configuration['labels'] = labels

time_partitioning = _cleanse_time_partitioning(
destination_dataset_table,
time_partitioning
@@ -636,7 +655,8 @@ def run_extract( # noqa
compression='NONE',
export_format='CSV',
field_delimiter=',',
print_header=True):
print_header=True,
labels=None):
"""
Executes a BigQuery extract command to copy data from BigQuery to
Google Cloud Storage. See here:
@@ -661,6 +681,9 @@ def run_extract( # noqa
:type field_delimiter: string
:param print_header: Whether to print a header for a CSV file extract.
:type print_header: boolean
:param labels: a dictionary containing labels for the job/query,
passed to BigQuery
:type labels: dict
"""

source_project, source_dataset, source_table = \
@@ -681,6 +704,9 @@ def run_extract( # noqa
}
}

if labels:
configuration['labels'] = labels

if export_format == 'CSV':
# Only set fieldDelimiter and printHeader fields if using CSV.
# Google does not like it if you set these fields for other export
@@ -694,7 +720,8 @@ def run_copy(self,
source_project_dataset_tables,
destination_project_dataset_table,
write_disposition='WRITE_EMPTY',
create_disposition='CREATE_IF_NEEDED'):
create_disposition='CREATE_IF_NEEDED',
labels=None):
"""
Executes a BigQuery copy command to copy data from one BigQuery table
to another. See here:
@@ -717,6 +744,9 @@ def run_copy(self,
:type write_disposition: string
:param create_disposition: The create disposition if the table doesn't exist.
:type create_disposition: string
:param labels: a dictionary containing labels for the job/query,
passed to BigQuery
:type labels: dict
"""
source_project_dataset_tables = ([
source_project_dataset_tables
@@ -754,6 +784,9 @@ def run_copy(self,
}
}

if labels:
configuration['labels'] = labels

return self.run_with_configuration(configuration)

def run_load(self,
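
Hook-level usage sketch for the methods changed above (project, dataset, and table names are hypothetical; the cursor pattern follows existing BigQueryHook usage):

from airflow.contrib.hooks.bigquery_hook import BigQueryHook

bq_hook = BigQueryHook(bigquery_conn_id='bigquery_default')
cursor = bq_hook.get_conn().cursor()

# Table-level labels: stored on the table resource itself.
cursor.create_empty_table(
    project_id='my-project',
    dataset_id='my_dataset',
    table_id='my_table',
    schema_fields=[{'name': 'id', 'type': 'INTEGER', 'mode': 'REQUIRED'}],
    labels={'owner': 'analytics'},
)

# Job-level labels: placed at the top level of the job configuration.
cursor.run_query(
    sql='SELECT 1',
    destination_dataset_table='my_dataset.my_table_copy',
    labels={'owner': 'analytics'},
)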
27 changes: 22 additions & 5 deletions airflow/contrib/operators/bigquery_operator.py
@@ -81,6 +81,9 @@ class BigQueryOperator(BaseOperator):
:param query_params: a dictionary containing query parameter types and
values, passed to BigQuery.
:type query_params: dict
:param labels: a dictionary containing labels for the job/query,
passed to BigQuery
:type labels: dict
:param priority: Specifies a priority for the query.
Possible values include INTERACTIVE and BATCH.
The default value is INTERACTIVE.
@@ -92,7 +95,7 @@ class BigQueryOperator(BaseOperator):
:type time_partitioning: dict
"""

template_fields = ('bql', 'sql', 'destination_dataset_table')
template_fields = ('bql', 'sql', 'destination_dataset_table', 'labels')
template_ext = ('.sql', )
ui_color = '#e4f0e8'

@@ -113,6 +116,7 @@ def __init__(self,
create_disposition='CREATE_IF_NEEDED',
schema_update_options=(),
query_params=None,
labels=None,
priority='INTERACTIVE',
time_partitioning={},
*args,
@@ -133,6 +137,7 @@ def __init__(self,
self.maximum_bytes_billed = maximum_bytes_billed
self.schema_update_options = schema_update_options
self.query_params = query_params
self.labels = labels
self.bq_cursor = None
self.priority = priority
self.time_partitioning = time_partitioning
@@ -171,6 +176,7 @@ def execute(self, context):
maximum_bytes_billed=self.maximum_bytes_billed,
create_disposition=self.create_disposition,
query_params=self.query_params,
labels=self.labels,
schema_update_options=self.schema_update_options,
priority=self.priority,
time_partitioning=self.time_partitioning
@@ -228,6 +234,8 @@ class BigQueryCreateEmptyTableOperator(BaseOperator):
work, the service account making the request must have domain-wide
delegation enabled.
:type delegate_to: string
:param labels: a dictionary containing labels for the table, passed to BigQuery
:type labels: dict
**Example (with schema JSON in GCS)**: ::
@@ -270,7 +278,8 @@ class BigQueryCreateEmptyTableOperator(BaseOperator):
)
"""
template_fields = ('dataset_id', 'table_id', 'project_id', 'gcs_schema_object')
template_fields = ('dataset_id', 'table_id', 'project_id',
'gcs_schema_object', 'labels')
ui_color = '#f0eee4'

@apply_defaults
@@ -284,6 +293,7 @@ def __init__(self,
bigquery_conn_id='bigquery_default',
google_cloud_storage_conn_id='google_cloud_default',
delegate_to=None,
labels=None,
*args, **kwargs):

super(BigQueryCreateEmptyTableOperator, self).__init__(*args, **kwargs)
@@ -297,6 +307,7 @@ def __init__(self,
self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
self.delegate_to = delegate_to
self.time_partitioning = time_partitioning
self.labels = labels

def execute(self, context):
bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -323,7 +334,8 @@ def execute(self, context):
dataset_id=self.dataset_id,
table_id=self.table_id,
schema_fields=schema_fields,
time_partitioning=self.time_partitioning
time_partitioning=self.time_partitioning,
labels=self.labels
)


@@ -396,9 +408,11 @@ class BigQueryCreateExternalTableOperator(BaseOperator):
:type delegate_to: string
:param src_fmt_configs: configure optional fields specific to the source format
:type src_fmt_configs: dict
:param labels: a dictionary containing labels for the table, passed to BigQuery
:type labels: dict
"""
template_fields = ('bucket', 'source_objects',
'schema_object', 'destination_project_dataset_table')
'schema_object', 'destination_project_dataset_table', 'labels')
ui_color = '#f0eee4'

@apply_defaults
@@ -420,6 +434,7 @@ def __init__(self,
google_cloud_storage_conn_id='google_cloud_default',
delegate_to=None,
src_fmt_configs={},
labels=None,
*args, **kwargs):

super(BigQueryCreateExternalTableOperator, self).__init__(*args, **kwargs)
@@ -446,6 +461,7 @@ def __init__(self,
self.delegate_to = delegate_to

self.src_fmt_configs = src_fmt_configs
self.labels = labels

def execute(self, context):
bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -479,7 +495,8 @@ def execute(self, context):
quote_character=self.quote_character,
allow_quoted_newlines=self.allow_quoted_newlines,
allow_jagged_rows=self.allow_jagged_rows,
src_fmt_configs=self.src_fmt_configs
src_fmt_configs=self.src_fmt_configs,
labels=self.labels
)


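
Because labels is now in template_fields, label values can be Jinja-templated. A hedged sketch using the create-table operator changed above (names and the Variable lookup are hypothetical):

from airflow.contrib.operators.bigquery_operator import \
    BigQueryCreateEmptyTableOperator

create_table = BigQueryCreateEmptyTableOperator(
    task_id='create_labeled_table',
    project_id='my-project',
    dataset_id='my_dataset',
    table_id='my_table',
    schema_fields=[{'name': 'id', 'type': 'INTEGER', 'mode': 'REQUIRED'}],
    # Rendered at run time because 'labels' is a template field.
    labels={'env': '{{ var.value.env }}', 'owner': 'analytics'},
    dag=dag,
)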
10 changes: 8 additions & 2 deletions airflow/contrib/operators/bigquery_to_bigquery.py
@@ -49,9 +49,12 @@ class BigQueryToBigQueryOperator(BaseOperator):
For this to work, the service account making the request must have domain-wide
delegation enabled.
:type delegate_to: string
:param labels: a dictionary containing labels for the job/query,
passed to BigQuery
:type labels: dict
"""
template_fields = ('source_project_dataset_tables',
'destination_project_dataset_table')
'destination_project_dataset_table', 'labels')
template_ext = ('.sql',)
ui_color = '#e6f0e4'

@@ -63,6 +66,7 @@ def __init__(self,
create_disposition='CREATE_IF_NEEDED',
bigquery_conn_id='bigquery_default',
delegate_to=None,
labels=None,
*args,
**kwargs):
super(BigQueryToBigQueryOperator, self).__init__(*args, **kwargs)
@@ -72,6 +76,7 @@ def __init__(self,
self.create_disposition = create_disposition
self.bigquery_conn_id = bigquery_conn_id
self.delegate_to = delegate_to
self.labels = labels

def execute(self, context):
self.log.info(
@@ -86,4 +91,5 @@ def execute(self, context):
self.source_project_dataset_tables,
self.destination_project_dataset_table,
self.write_disposition,
self.create_disposition)
self.create_disposition,
self.labels)
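
Usage sketch for the copy operator with the new argument (table names and the dag object are hypothetical):

from airflow.contrib.operators.bigquery_to_bigquery import \
    BigQueryToBigQueryOperator

copy_tables = BigQueryToBigQueryOperator(
    task_id='copy_with_labels',
    source_project_dataset_tables='my_dataset.source_table',
    destination_project_dataset_table='my_dataset.dest_table',
    # Forwarded positionally into BigQueryBaseCursor.run_copy above.
    labels={'pipeline': 'nightly_copy'},
    dag=dag,
)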
11 changes: 9 additions & 2 deletions airflow/contrib/operators/bigquery_to_gcs.py
@@ -54,8 +54,12 @@ class BigQueryToCloudStorageOperator(BaseOperator):
For this to work, the service account making the request must have domain-wide
delegation enabled.
:type delegate_to: string
:param labels: a dictionary containing labels for the job/query,
passed to BigQuery
:type labels: dict
"""
template_fields = ('source_project_dataset_table', 'destination_cloud_storage_uris')
template_fields = ('source_project_dataset_table',
'destination_cloud_storage_uris', 'labels')
template_ext = ('.sql',)
ui_color = '#e4e6f0'

@@ -69,6 +73,7 @@ def __init__(self,
print_header=True,
bigquery_conn_id='bigquery_default',
delegate_to=None,
labels=None,
*args,
**kwargs):
super(BigQueryToCloudStorageOperator, self).__init__(*args, **kwargs)
@@ -80,6 +85,7 @@ def __init__(self,
self.print_header = print_header
self.bigquery_conn_id = bigquery_conn_id
self.delegate_to = delegate_to
self.labels = labels

def execute(self, context):
self.log.info('Executing extract of %s into: %s',
@@ -95,4 +101,5 @@ def execute(self, context):
self.compression,
self.export_format,
self.field_delimiter,
self.print_header)
self.print_header,
self.labels)
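
Usage sketch for the extract operator with the new argument (bucket and table names are hypothetical):

from airflow.contrib.operators.bigquery_to_gcs import \
    BigQueryToCloudStorageOperator

export_table = BigQueryToCloudStorageOperator(
    task_id='export_with_labels',
    source_project_dataset_table='my_dataset.my_table',
    destination_cloud_storage_uris=['gs://my-bucket/export/part-*.csv'],
    # Attached to the extract job via run_extract(..., labels=...).
    labels={'job_type': 'export'},
    dag=dag,
)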
27 changes: 25 additions & 2 deletions tests/contrib/hooks/test_bigquery_hook.py
@@ -58,7 +58,8 @@ def test_suceeds_with_explicit_legacy_query(self):

@unittest.skipIf(not bq_available, 'BQ is not available to run tests')
def test_suceeds_with_explicit_std_query(self):
df = self.instance.get_pandas_df('select * except(b) from (select 1 a, 2 b)', dialect='standard')
df = self.instance.get_pandas_df(
'select * except(b) from (select 1 a, 2 b)', dialect='standard')
self.assertEqual(df.iloc(0)[0][0], 1)

@unittest.skipIf(not bq_available, 'BQ is not available to run tests')
@@ -281,6 +282,27 @@ def test_run_query_sql_dialect_override(self, run_with_config):
self.assertIs(args[0]['query']['useLegacySql'], bool_val)


class TestLabelsInRunJob(unittest.TestCase):
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
def test_run_query_with_arg(self, mocked_rwc):
project_id = 12345

def run_with_config(config):
self.assertEqual(
config['labels'], {'label1': 'test1', 'label2': 'test2'}
)
mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
bq_hook.run_query(
sql='select 1',
destination_dataset_table='my_dataset.my_table',
labels={'label1': 'test1', 'label2': 'test2'}
)

mocked_rwc.assert_called_once()


class TestTimePartitioningInRunJob(unittest.TestCase):

@mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
@@ -365,7 +387,8 @@ def run_with_config(config):
bq_hook.run_query(
sql='select 1',
destination_dataset_table='my_dataset.my_table',
time_partitioning={'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
time_partitioning={'type': 'DAY',
'field': 'test_field', 'expirationMs': 1000}
)

mocked_rwc.assert_called_once()
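
The run_with_configuration side-effect pattern in TestLabelsInRunJob would extend naturally to the other methods this commit touches; a hedged sketch for run_copy (not part of the commit, reusing the same mock module and hook import as the tests above):

class TestLabelsInRunCopy(unittest.TestCase):
    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
    def test_run_copy_with_labels(self, mocked_rwc):
        def run_with_config(config):
            # Labels should land at the top level of the copy job config.
            self.assertEqual(config['labels'], {'team': 'data'})
        mocked_rwc.side_effect = run_with_config

        cursor = hook.BigQueryBaseCursor(mock.Mock(), 'my-project')
        cursor.run_copy(
            source_project_dataset_tables='my_dataset.src_table',
            destination_project_dataset_table='my_dataset.dst_table',
            labels={'team': 'data'},
        )
        mocked_rwc.assert_called_once()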