[AIRFLOW-2845] Asserts in contrib package code are changed on raise ValueError and TypeError (#3690)
xnuinside authored and kaxil committed Sep 15, 2018
1 parent b7f5a3d commit a99d5c2
Showing 10 changed files with 74 additions and 58 deletions.
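Every change below follows the same pattern: a bare assert is replaced by an explicit check that raises ValueError (or TypeError for type checks) with a descriptive message. A minimal sketch of the before/after pattern, using a hypothetical check_destination helper purely for illustration:

    # Before: assert statements are stripped entirely when Python runs
    # with the -O flag, and only ever raise a bare AssertionError.
    def check_destination(destination_dataset_table):
        assert '.' in destination_dataset_table, (
            'Expected <dataset>.<table>. Got: {}'.format(destination_dataset_table))

    # After: the check survives optimized mode and raises a descriptive
    # ValueError that callers and tests can rely on.
    def check_destination(destination_dataset_table):
        if '.' not in destination_dataset_table:
            raise ValueError(
                'Expected <dataset>.<table>. Got: {}'.format(
                    destination_dataset_table))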
55 changes: 29 additions & 26 deletions airflow/contrib/hooks/bigquery_hook.py
@@ -592,9 +592,11 @@ def run_query(self,
}

if destination_dataset_table:
assert '.' in destination_dataset_table, (
'Expected destination_dataset_table in the format of '
'<dataset>.<table>. Got: {}').format(destination_dataset_table)
if '.' not in destination_dataset_table:
raise ValueError(
'Expected destination_dataset_table name in the format of '
'<dataset>.<table>. Got: {}'.format(
destination_dataset_table))
destination_project, destination_dataset, destination_table = \
_split_tablename(table_input=destination_dataset_table,
default_project_id=self.project_id)
@@ -610,7 +612,9 @@ def run_query(self,
}
})
if udf_config:
assert isinstance(udf_config, list)
if not isinstance(udf_config, list):
raise TypeError("udf_config argument must have a type 'list'"
" not {}".format(type(udf_config)))
configuration['query'].update({
'userDefinedFunctionResources': udf_config
})
@@ -1153,10 +1157,10 @@ def run_table_delete(self, deletion_dataset_table,
:type ignore_if_missing: boolean
:return:
"""

assert '.' in deletion_dataset_table, (
'Expected deletion_dataset_table in the format of '
'<dataset>.<table>. Got: {}').format(deletion_dataset_table)
if '.' not in deletion_dataset_table:
raise ValueError(
'Expected deletion_dataset_table name in the format of '
'<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
deletion_project, deletion_dataset, deletion_table = \
_split_tablename(table_input=deletion_dataset_table,
default_project_id=self.project_id)
@@ -1284,14 +1288,10 @@ def run_grant_dataset_view_access(self,
# if view is already in access, do nothing.
self.log.info(
'Table %s:%s.%s already has authorized view access to %s:%s dataset.',
view_project, view_dataset, view_table, source_project,
source_dataset)
view_project, view_dataset, view_table, source_project, source_dataset)
return source_dataset_resource

def delete_dataset(self,
project_id,
dataset_id
):
def delete_dataset(self, project_id, dataset_id):
"""
Delete a dataset of Big query in your project.
:param project_id: The name of the project where we have the dataset .
@@ -1308,9 +1308,8 @@ def delete_dataset(self,
self.service.datasets().delete(
projectId=project_id,
datasetId=dataset_id).execute()

self.log.info('Dataset deleted successfully: In project %s Dataset %s',
project_id, dataset_id)
self.log.info('Dataset deleted successfully: In project %s '
'Dataset %s', project_id, dataset_id)

except HttpError as err:
raise AirflowException(
@@ -1518,14 +1517,17 @@ def _bq_cast(string_field, bq_type):
elif bq_type == 'FLOAT' or bq_type == 'TIMESTAMP':
return float(string_field)
elif bq_type == 'BOOLEAN':
assert string_field in set(['true', 'false'])
if string_field not in ['true', 'false']:
raise ValueError("{} must have value 'true' or 'false'".format(
string_field))
return string_field == 'true'
else:
return string_field


def _split_tablename(table_input, default_project_id, var_name=None):
assert default_project_id is not None, "INTERNAL: No default project is specified"
if not default_project_id:
raise ValueError("INTERNAL: No default project is specified")

def var_print(var_name):
if var_name is None:
@@ -1537,7 +1539,6 @@ def var_print(var_name):
raise Exception(('{var}Use either : or . to specify project '
'got {input}').format(
var=var_print(var_name), input=table_input))

cmpt = table_input.rsplit(':', 1)
project_id = None
rest = table_input
@@ -1555,8 +1556,10 @@ def var_print(var_name):

cmpt = rest.split('.')
if len(cmpt) == 3:
assert project_id is None, ("{var}Use either : or . to specify project"
).format(var=var_print(var_name))
if project_id:
raise ValueError(
"{var}Use either : or . to specify project".format(
var=var_print(var_name)))
project_id = cmpt[0]
dataset_id = cmpt[1]
table_id = cmpt[2]
@@ -1586,10 +1589,10 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
# if it is a partitioned table ($ is in the table name) add partition load option
time_partitioning_out = {}
if destination_dataset_table and '$' in destination_dataset_table:
assert not time_partitioning_in.get('field'), (
"Cannot specify field partition and partition name "
"(dataset.table$partition) at the same time"
)
if time_partitioning_in.get('field'):
raise ValueError(
"Cannot specify field partition and partition name"
"(dataset.table$partition) at the same time")
time_partitioning_out['type'] = 'DAY'

time_partitioning_out.update(time_partitioning_in)
3 changes: 2 additions & 1 deletion airflow/contrib/hooks/databricks_hook.py
@@ -61,7 +61,8 @@ def __init__(
self.databricks_conn_id = databricks_conn_id
self.databricks_conn = self.get_connection(databricks_conn_id)
self.timeout_seconds = timeout_seconds
assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
if retry_limit < 1:
raise ValueError('Retry limit must be greater than equal to 1')
self.retry_limit = retry_limit

def _parse_host(self, host):
24 changes: 14 additions & 10 deletions airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -225,11 +225,11 @@ def label_formatter(labels_dict):
def _build_dataflow_job_name(task_id, append_job_name=True):
task_id = str(task_id).replace('_', '-')

assert re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id), \
'Invalid job_name ({}); the name must consist of ' \
'only the characters [-a-z0-9], starting with a ' \
'letter and ending with a letter or number '.format(
task_id)
if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id):
raise ValueError(
'Invalid job_name ({}); the name must consist of'
'only the characters [-a-z0-9], starting with a '
'letter and ending with a letter or number '.format(task_id))

if append_job_name:
job_name = task_id + "-" + str(uuid.uuid1())[:8]
@@ -238,7 +238,8 @@ def _build_dataflow_job_name(task_id, append_job_name=True):

return job_name

def _build_cmd(self, task_id, variables, label_formatter):
@staticmethod
def _build_cmd(task_id, variables, label_formatter):
command = ["--runner=DataflowRunner"]
if variables is not None:
for attr, value in variables.items():
@@ -250,7 +251,8 @@ def _build_cmd(self, task_id, variables, label_formatter):
command.append("--" + attr + "=" + value)
return command

def _start_template_dataflow(self, name, variables, parameters, dataflow_template):
def _start_template_dataflow(self, name, variables, parameters,
dataflow_template):
# Builds RuntimeEnvironment from variables dictionary
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
environment = {}
@@ -262,9 +264,11 @@ def _start_template_dataflow(self, name, variables, parameters, dataflow_templat
"parameters": parameters,
"environment": environment}
service = self.get_conn()
request = service.projects().templates().launch(projectId=variables['project'],
gcsPath=dataflow_template,
body=body)
request = service.projects().templates().launch(
projectId=variables['project'],
gcsPath=dataflow_template,
body=body
)
response = request.execute()
variables = self._set_variables(variables)
_DataflowJob(self.get_conn(), variables['project'], name, variables['region'],
11 changes: 8 additions & 3 deletions airflow/contrib/hooks/gcp_mlengine_hook.py
@@ -152,7 +152,8 @@ def _wait_for_job_done(self, project_id, job_id, interval=30):
apiclient.errors.HttpError: if HTTP error is returned when getting
the job
"""
assert interval > 0
if interval <= 0:
raise ValueError("Interval must be > 0")
while True:
job = self._get_job(project_id, job_id)
if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
@@ -242,7 +243,9 @@ def create_model(self, project_id, model):
"""
Create a Model. Blocks until finished.
"""
assert model['name'] is not None and model['name'] is not ''
if not model['name']:
raise ValueError("Model name must be provided and "
"could not be an empty string")
project = 'projects/{}'.format(project_id)

request = self._mlengine.projects().models().create(
@@ -253,7 +256,9 @@ def get_model(self, project_id, model_name):
"""
Gets a Model. Blocks until finished.
"""
assert model_name is not None and model_name is not ''
if not model_name:
raise ValueError("Model name must be provided and "
"it could not be an empty string")
full_model_name = 'projects/{}/models/{}'.format(
project_id, model_name)
request = self._mlengine.projects().models().get(name=full_model_name)
15 changes: 8 additions & 7 deletions airflow/contrib/hooks/gcs_hook.py
@@ -477,15 +477,16 @@ def create_bucket(self,

self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s',
bucket_name, location, storage_class)
assert storage_class in storage_classes, \
'Invalid value ({}) passed to storage_class. Value should be ' \
'one of {}'.format(storage_class, storage_classes)
if storage_class not in storage_classes:
raise ValueError(
'Invalid value ({}) passed to storage_class. Value should be '
'one of {}'.format(storage_class, storage_classes))

assert re.match('[a-zA-Z0-9]+', bucket_name[0]), \
'Bucket names must start with a number or letter.'
if not re.match('[a-zA-Z0-9]+', bucket_name[0]):
raise ValueError('Bucket names must start with a number or letter.')

assert re.match('[a-zA-Z0-9]+', bucket_name[-1]), \
'Bucket names must end with a number or letter.'
if not re.match('[a-zA-Z0-9]+', bucket_name[-1]):
raise ValueError('Bucket names must end with a number or letter.')

service = self.get_conn()
bucket_resource = {
4 changes: 3 additions & 1 deletion airflow/contrib/operators/mlengine_operator.py
@@ -427,7 +427,9 @@ def execute(self, context):
gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

if self._operation == 'create':
assert self._version is not None
if not self._version:
raise ValueError("version attribute of {} could not "
"be empty".format(self.__class__.__name__))
return hook.create_version(self._project_id, self._model_name,
self._version)
elif self._operation == 'set_default':
2 changes: 1 addition & 1 deletion tests/contrib/hooks/test_bigquery_hook.py
@@ -414,7 +414,7 @@ def test_extra_time_partitioning_options(self):
self.assertEqual(tp_out, expect)

def test_cant_add_dollar_and_field_name(self):
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
_cleanse_time_partitioning(
'test.teast$20170101',
{'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
6 changes: 3 additions & 3 deletions tests/contrib/hooks/test_databricks_hook.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -110,7 +110,7 @@ def test_parse_host_with_scheme(self):
self.assertEquals(host, HOST)

def test_init_bad_retry_limit(self):
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
DatabricksHook(retry_limit = 0)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
8 changes: 4 additions & 4 deletions tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -195,7 +195,7 @@ def test_invalid_dataflow_job_name(self):
fixed_name = invalid_job_name.replace(
'_', '-')

with self.assertRaises(AssertionError) as e:
with self.assertRaises(ValueError) as e:
self.dataflow_hook._build_dataflow_job_name(
task_id=invalid_job_name, append_job_name=False
)
@@ -222,19 +222,19 @@ def test_dataflow_job_regex_check(self):
), 'dfjob1')

self.assertRaises(
AssertionError,
ValueError,
self.dataflow_hook._build_dataflow_job_name,
task_id='1dfjob', append_job_name=False
)

self.assertRaises(
AssertionError,
ValueError,
self.dataflow_hook._build_dataflow_job_name,
task_id='dfjob@', append_job_name=False
)

self.assertRaises(
AssertionError,
ValueError,
self.dataflow_hook._build_dataflow_job_name,
task_id='df^jo', append_job_name=False
)
4 changes: 2 additions & 2 deletions tests/contrib/hooks/test_gcs_hook.py
@@ -66,14 +66,14 @@ class TestGCSBucket(unittest.TestCase):
def test_bucket_name_value(self):

bad_start_bucket_name = '/testing123'
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):

gcs_hook.GoogleCloudStorageHook().create_bucket(
bucket_name=bad_start_bucket_name
)

bad_end_bucket_name = 'testing123/'
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
gcs_hook.GoogleCloudStorageHook().create_bucket(
bucket_name=bad_end_bucket_name
)
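Downstream code and tests that previously expected AssertionError from these hooks now need to expect ValueError (or TypeError), as the test changes above reflect. A minimal, hypothetical caller-side sketch against the GCS hook (names taken from the tests above; connection setup omitted):

    from airflow.contrib.hooks import gcs_hook

    # Invalid bucket names now surface as ValueError rather than AssertionError.
    try:
        gcs_hook.GoogleCloudStorageHook().create_bucket(bucket_name='/testing123')
    except ValueError as err:
        # e.g. log the problem and skip bucket creation
        print('Refusing to create bucket: {}'.format(err))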
