[AIRFLOW-XXX] Docs rendering improvement (apache#4684)
mik-laj authored and wayne.morris committed Jul 29, 2019
1 parent 5badd80 commit e65b050
Showing 8 changed files with 69 additions and 56 deletions.
44 changes: 21 additions & 23 deletions airflow/contrib/hooks/bigquery_hook.py
@@ -255,7 +255,7 @@ def create_empty_table(self,
partition by field, type and expiration as per API specifications.
.. seealso::
https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning
https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#timePartitioning
:type time_partitioning: dict
:param view: [Optional] A dictionary containing definition for the view.
If set, it will create a view instead of a table:
@@ -269,7 +269,7 @@ def create_empty_table(self,
"useLegacySql": False
}
:return:
:return: None
"""

project_id = project_id if project_id is not None else self.project_id
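The view parameter documented above takes the same dictionary the BigQuery tables API expects. A minimal usage sketch, not part of this diff, assuming the contrib BigQueryHook cursor API and a configured bigquery_default connection; every project, dataset and table id below is a placeholder:

    # Hypothetical usage sketch -- connection, project, dataset and table ids
    # are placeholders, not values from this commit.
    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    hook = BigQueryHook(bigquery_conn_id='bigquery_default')
    cursor = hook.get_conn().cursor()
    cursor.create_empty_table(
        project_id='test-project-id',
        dataset_id='test_dataset_id',
        table_id='test_view',
        view={
            "query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 1000",
            "useLegacySql": False,
        },
    )

Because a view definition is supplied, no schema_fields are needed: the view's query defines its columns.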
@@ -356,7 +356,7 @@ def create_external_table(self,
Possible values include GZIP and NONE.
The default value is NONE.
This setting is ignored for Google Cloud Bigtable,
Google Cloud Datastore backups and Avro formats.
Google Cloud Datastore backups and Avro formats.
:type compression: str
:param ignore_unknown_values: [Optional] Indicates if BigQuery should allow
extra values that are not represented in the table schema.
@@ -546,28 +546,26 @@ def patch_table(self,
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schema
The supported schema modifications and unsupported schema modification are listed here:
https://cloud.google.com/bigquery/docs/managing-table-schemas
:type schema: list
**Example**: ::
**Example**: ::
schema=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
{"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}]
schema=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
{"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}]
:type schema: list
:param time_partitioning: [Optional] A dictionary containing time-based partitioning
definition for the table.
:type time_partitioning: dict
:param view: [Optional] A dictionary containing definition for the view.
If set, it will patch a view instead of a table:
https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#view
:type view: dict
**Example**: ::
**Example**: ::
view = {
"query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500",
"useLegacySql": False
}
view = {
"query": "SELECT * FROM `test-project-id.test_dataset_id.test_table_prefix*` LIMIT 500",
"useLegacySql": False
}
:type view: dict
:param require_partition_filter: [Optional] If true, queries over this table require a
partition filter. If false, queries over the table
:type require_partition_filter: bool
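A sketch of patch_table using the schema example from the docstring above; it reuses the cursor from the create_empty_table sketch earlier, and every identifier is a placeholder rather than part of this commit:

    # Hypothetical sketch -- `cursor` as obtained in the create_empty_table
    # example above; table and dataset ids are placeholders.
    cursor.patch_table(
        dataset_id='test_dataset_id',
        table_id='test_table',
        schema=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
                {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}],
        require_partition_filter=True,
    )

Only the properties passed in are patched; everything else on the table is left untouched, which is the point of patch over update.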
@@ -919,14 +917,14 @@ def run_copy(self,
For more details about these parameters.
:param source_project_dataset_tables: One or more dotted
(project:|project.)<dataset>.<table>
``(project:|project.)<dataset>.<table>``
BigQuery tables to use as the source data. Use a list if there are
multiple source tables.
If <project> is not included, project will be the project defined
in the connection json.
:type source_project_dataset_tables: list|string
:param destination_project_dataset_table: The destination BigQuery
table. Format is: (project:|project.)<dataset>.<table>
table. Format is: ``(project:|project.)<dataset>.<table>``
:type destination_project_dataset_table: str
:param write_disposition: The write disposition if the table already exists.
:type write_disposition: str
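A sketch of run_copy with the dotted (project.)<dataset>.<table> notation described above, again reusing the cursor from the first sketch; the table names and write disposition are illustrative only:

    # Hypothetical sketch -- `cursor` as obtained in the create_empty_table
    # example above; source and destination table names are placeholders.
    cursor.run_copy(
        source_project_dataset_tables=['test-project-id.test_dataset_id.table_a',
                                       'test-project-id.test_dataset_id.table_b'],
        destination_project_dataset_table='test-project-id.test_dataset_id.table_all',
        write_disposition='WRITE_TRUNCATE',
    )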
@@ -1371,11 +1369,11 @@ def run_table_delete(self, deletion_dataset_table,
is set to True.
:param deletion_dataset_table: A dotted
(<project>.|<project>:)<dataset>.<table> that indicates which table
will be deleted.
``(<project>.|<project>:)<dataset>.<table>`` that indicates which table
will be deleted.
:type deletion_dataset_table: str
:param ignore_if_missing: if True, then return success even if the
requested table does not exist.
requested table does not exist.
:type ignore_if_missing: bool
:return: None
"""
@@ -1410,7 +1408,7 @@ def run_table_upsert(self, dataset_id, table_resource, project_id=None):
https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
:type table_resource: dict
:param project_id: the project to upsert the table into. If None,
project will be self.project_id.
project will be self.project_id.
:return:
"""
# check to see if the table exists
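The signature in the hunk header takes a tables#resource dictionary; a sketch with a minimal resource, cursor as in the first sketch and all identifiers placeholders:

    # Hypothetical sketch -- `cursor` as obtained in the create_empty_table
    # example above; ids are placeholders.
    cursor.run_table_upsert(
        dataset_id='test_dataset_id',
        table_resource={
            'tableReference': {
                'projectId': 'test-project-id',
                'datasetId': 'test_dataset_id',
                'tableId': 'test_table',
            },
            'description': 'Example table (placeholder)',
        },
    )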
@@ -1464,10 +1462,10 @@ def run_grant_dataset_view_access(self,
:param view_table: the table of the view
:type view_table: str
:param source_project: the project of the source dataset. If None,
self.project_id will be used.
self.project_id will be used.
:type source_project: str
:param view_project: the project that the view is in. If None,
self.project_id will be used.
self.project_id will be used.
:type view_project: str
:return: the datasets resource of the source dataset.
"""
19 changes: 11 additions & 8 deletions airflow/contrib/hooks/databricks_hook.py
@@ -79,16 +79,18 @@ def _parse_host(host):
The purpose of this function is to be robust to improper connection
settings provided by users, specifically in the host field.
For example -- when users supply ``https://xx.cloud.databricks.com`` as the
host, we must strip out the protocol to get the host.
>>> h = DatabricksHook()
>>> assert h._parse_host('https://xx.cloud.databricks.com') == \
'xx.cloud.databricks.com'
host, we must strip out the protocol to get the host.::
h = DatabricksHook()
assert h._parse_host('https://xx.cloud.databricks.com') == \
'xx.cloud.databricks.com'
In the case where users supply the correct ``xx.cloud.databricks.com`` as the
host, this function is a no-op.
>>> assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
host, this function is a no-op.::
assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
"""
urlparse_host = urlparse.urlparse(host).hostname
if urlparse_host:
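The examples above reduce to the urlparse call visible in this hunk; a standalone sketch of the same normalisation, outside any Airflow class and using only the Python 3 stdlib:

    # Standalone sketch of the behaviour documented above (stdlib only).
    from urllib.parse import urlparse

    def parse_host(host):
        # With a scheme present, urlparse exposes the bare hostname;
        # without one, hostname is None and the input is already a plain host.
        hostname = urlparse(host).hostname
        return hostname if hostname else host

    assert parse_host('https://xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
    assert parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'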
@@ -101,8 +103,9 @@ def _parse_host(host):
def _do_api_call(self, endpoint_info, json):
"""
Utility function to perform an API call with retries
:param endpoint_info: Tuple of method and endpoint
:type endpoint_info: (string, string)
:type endpoint_info: tuple[string, string]
:param json: Parameters for this API call.
:type json: dict
:return: If the api call returns a OK status code,
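As a usage illustration only: endpoint_info is a (method, endpoint) tuple and json carries the request payload. The endpoint below is the standard Databricks 2.0 "list jobs" path, and the sketch assumes a configured databricks_default connection:

    # Hypothetical sketch -- assumes a working databricks_default connection;
    # _do_api_call is an internal helper of the hook.
    from airflow.contrib.hooks.databricks_hook import DatabricksHook

    hook = DatabricksHook()
    jobs = hook._do_api_call(('GET', 'api/2.0/jobs/list'), {})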
27 changes: 18 additions & 9 deletions airflow/contrib/hooks/spark_submit_hook.py
@@ -452,16 +452,25 @@ def _start_driver_status_tracking(self):
Finish failed when the status is ERROR/UNKNOWN/KILLED/FAILED.
Possible status:
SUBMITTED: Submitted but not yet scheduled on a worker
RUNNING: Has been allocated to a worker to run
FINISHED: Previously ran and exited cleanly
RELAUNCHING: Exited non-zero or due to worker failure, but has not yet
SUBMITTED
Submitted but not yet scheduled on a worker
RUNNING
Has been allocated to a worker to run
FINISHED
Previously ran and exited cleanly
RELAUNCHING
Exited non-zero or due to worker failure, but has not yet
started running again
UNKNOWN: The status of the driver is temporarily not known due to
master failure recovery
KILLED: A user manually killed this driver
FAILED: The driver exited non-zero and was not supervised
ERROR: Unable to run or restart due to an unrecoverable error
UNKNOWN
The status of the driver is temporarily not known due to
master failure recovery
KILLED
A user manually killed this driver
FAILED
The driver exited non-zero and was not supervised
ERROR
Unable to run or restart due to an unrecoverable error
(e.g. missing jar file)
"""

2 changes: 1 addition & 1 deletion airflow/contrib/operators/dataflow_operator.py
@@ -399,7 +399,7 @@ def google_cloud_to_local(self, file_name):
:param file_name: The full path of input file.
:type file_name: str
:return: The full path of local file.
:type str
:rtype: str
"""
if not file_name.startswith('gs://'):
return file_name
@@ -213,7 +213,7 @@ def validate_err_and_count(summary):
evaluate_summary = DataFlowPythonOperator(
task_id=(task_prefix + "-summary"),
py_options=["-m"],
py_file="airflow.contrib.operators.mlengine_prediction_summary",
py_file="airflow.contrib.utils.mlengine_prediction_summary",
dataflow_default_options=dataflow_options,
options={
"prediction_path": prediction_path,
10 changes: 6 additions & 4 deletions airflow/executors/celery_executor.py
@@ -75,10 +75,11 @@ def execute_command(command_to_exec):
class ExceptionWithTraceback(object):
"""
Wrapper class used to propagate exceptions to parent processes from subprocesses.
:param exception: The exception to wrap
:type exception: Exception
:param traceback: The stacktrace to wrap
:type traceback: str
:param exception_traceback: The stacktrace to wrap
:type exception_traceback: str
"""

def __init__(self, exception, exception_traceback):
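A sketch of how such a wrapper is typically populated in a worker subprocess: the exception travels together with its formatted traceback string, which is what the exception_traceback parameter documented above holds. The helper name is hypothetical:

    # Hypothetical sketch -- `run_safely` is illustrative, not part of this module.
    import traceback

    from airflow.executors.celery_executor import ExceptionWithTraceback

    def run_safely(fn):
        try:
            return fn()
        except Exception as e:
            # Keep the child's stack trace as a plain string so it can cross
            # the process boundary and be surfaced by the parent.
            return ExceptionWithTraceback(e, traceback.format_exc())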
@@ -90,11 +91,12 @@ def fetch_celery_task_state(celery_task):
"""
Fetch and return the state of the given celery task. The scope of this function is
global so that it can be called by subprocesses in the pool.
:param celery_task: a tuple of the Celery task key and the async Celery object used
to fetch the task's state
to fetch the task's state
:type celery_task: tuple(str, celery.result.AsyncResult)
:return: a tuple of the Celery task key and the Celery state of the task
:rtype: luple[str, str]
:rtype: tuple[str, str]
"""

try:
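The documented contract boils down to unpacking the (key, AsyncResult) pair and returning the key with the Celery state string; a simplified sketch that omits the error handling the real function begins with (the try: above):

    # Simplified sketch of the documented contract; error handling omitted.
    def fetch_state_sketch(celery_task):
        key, async_result = celery_task
        return key, async_result.state  # AsyncResult.state is standard Celery API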
@@ -23,8 +23,7 @@
import unittest

from airflow import configuration, DAG
from airflow.contrib.operators import mlengine_operator_utils
from airflow.contrib.operators.mlengine_operator_utils import create_evaluate_ops
from airflow.contrib.utils import mlengine_operator_utils
from airflow.exceptions import AirflowException
from airflow.version import version

@@ -76,7 +75,7 @@ def setUp(self):
def testSuccessfulRun(self):
input_with_model = self.INPUT_MISSING_ORIGIN.copy()

pred, summary, validate = create_evaluate_ops(
pred, summary, validate = mlengine_operator_utils.create_evaluate_ops(
task_prefix='eval-test',
batch_prediction_job_id='eval-test-prediction',
data_format=input_with_model['dataFormat'],
@@ -118,10 +117,10 @@ def testSuccessfulRun(self):
'metric_keys': 'err',
'metric_fn_encoded': self.metric_fn_encoded,
},
'airflow.contrib.operators.mlengine_prediction_summary',
'airflow.contrib.utils.mlengine_prediction_summary',
['-m'])

with patch('airflow.contrib.operators.mlengine_operator_utils.'
with patch('airflow.contrib.utils.mlengine_operator_utils.'
'GoogleCloudStorageHook') as mock_gcs_hook:
hook_instance = mock_gcs_hook.return_value
hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
@@ -155,25 +154,27 @@ def testFailures(self):
}

with self.assertRaisesRegexp(AirflowException, 'Missing model origin'):
create_evaluate_ops(**other_params_but_models)
mlengine_operator_utils.create_evaluate_ops(**other_params_but_models)

with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'):
create_evaluate_ops(model_uri='abc', model_name='cde', **other_params_but_models)
mlengine_operator_utils.create_evaluate_ops(model_uri='abc', model_name='cde',
**other_params_but_models)

with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'):
create_evaluate_ops(model_uri='abc', version_name='vvv', **other_params_but_models)
mlengine_operator_utils.create_evaluate_ops(model_uri='abc', version_name='vvv',
**other_params_but_models)

with self.assertRaisesRegexp(AirflowException,
'`metric_fn` param must be callable'):
params = other_params_but_models.copy()
params['metric_fn_and_keys'] = (None, ['abc'])
create_evaluate_ops(model_uri='gs://blah', **params)
mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)

with self.assertRaisesRegexp(AirflowException,
'`validate_fn` param must be callable'):
params = other_params_but_models.copy()
params['validate_fn'] = None
create_evaluate_ops(model_uri='gs://blah', **params)
mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)


if __name__ == '__main__':
