Skip to content

Commit

Permalink
[AIRFLOW-5313] Add params support for awsbatch_operator (#5900)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrowen authored and kaxil committed Dec 17, 2019
1 parent 255e89f commit 1818d41
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
18 changes: 12 additions & 6 deletions airflow/contrib/operators/awsbatch_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,17 @@ class AWSBatchOperator(BaseOperator):
:param job_queue: the queue name on AWS Batch
:type job_queue: str
:param overrides: the same parameter that boto3 will receive on
containerOverrides (templated):
http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
containerOverrides (templated)
http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
:type overrides: dict
:param array_properties: the same parameter that boto3 will receive on
arrayProperties:
http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
arrayProperties
http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
:type array_properties: dict
:param parameters: the same parameter that boto3 will receive on
parameters (templated)
http://boto3.readthedocs.io/en/latest/reference/services/batch.html#Batch.Client.submit_job
:type parameters: dict
:param max_retries: exponential backoff retries while waiter is not
merged, 4200 = 48 hours
:type max_retries: int
Expand All @@ -66,11 +70,11 @@ class AWSBatchOperator(BaseOperator):
ui_color = '#c3dae0'
client = None
arn = None
template_fields = ('job_name', 'overrides',)
template_fields = ('job_name', 'overrides', 'parameters',)

@apply_defaults
def __init__(self, job_name, job_definition, job_queue, overrides, array_properties=None,
max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
parameters=None, max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
super(AWSBatchOperator, self).__init__(**kwargs)

self.job_name = job_name
Expand All @@ -80,6 +84,7 @@ def __init__(self, job_name, job_definition, job_queue, overrides, array_propert
self.job_queue = job_queue
self.overrides = overrides
self.array_properties = array_properties or {}
self.parameters = parameters
self.max_retries = max_retries

self.jobId = None # pylint: disable=invalid-name
Expand All @@ -105,6 +110,7 @@ def execute(self, context):
jobQueue=self.job_queue,
jobDefinition=self.job_definition,
arrayProperties=self.array_properties,
parameters=self.parameters,
containerOverrides=self.overrides)

self.log.info('AWS Batch Job started: %s', response)
Expand Down
10 changes: 7 additions & 3 deletions tests/contrib/operators/test_awsbatch_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def setUp(self, aws_hook_mock):
job_queue='queue',
job_definition='hello-world',
max_retries=5,
parameters=None,
overrides={},
array_properties=None,
aws_conn_id=None,
Expand All @@ -52,6 +53,7 @@ def test_init(self):
self.assertEqual(self.batch.job_queue, 'queue')
self.assertEqual(self.batch.job_definition, 'hello-world')
self.assertEqual(self.batch.max_retries, 5)
self.assertEqual(self.batch.parameters, None)
self.assertEqual(self.batch.overrides, {})
self.assertEqual(self.batch.array_properties, {})
self.assertEqual(self.batch.region_name, 'eu-west-1')
Expand All @@ -61,7 +63,7 @@ def test_init(self):
self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

def test_template_fields_overrides(self):
self.assertEqual(self.batch.template_fields, ('job_name', 'overrides',))
self.assertEqual(self.batch.template_fields, ('job_name', 'overrides', 'parameters',))

@mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
@mock.patch.object(AWSBatchOperator, '_check_success_task')
Expand All @@ -78,7 +80,8 @@ def test_execute_without_failures(self, check_mock, wait_mock):
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
containerOverrides={},
jobDefinition='hello-world',
arrayProperties={}
arrayProperties={},
parameters=None
)

wait_mock.assert_called_once_with()
Expand All @@ -99,7 +102,8 @@ def test_execute_with_failures(self):
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
containerOverrides={},
jobDefinition='hello-world',
arrayProperties={}
arrayProperties={},
parameters=None
)

def test_wait_end_tasks(self):
Expand Down

0 comments on commit 1818d41

Please sign in to comment.