Skip to content

Commit

Permalink
Add support of capacity provider strategy for ECSOperator (#15848)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Hlushchanka authored Jun 12, 2021
1 parent d99afc3 commit 30708b5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 15 deletions.
12 changes: 11 additions & 1 deletion airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
:type region_name: str
:param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE')
:type launch_type: str
:param capacity_provider_strategy: the capacity provider strategy to use for the task.
When capacity_provider_strategy is specified, the launch_type parameter is omitted.
If no capacity_provider_strategy or launch_type is specified,
the default capacity provider strategy for the cluster is used.
:type capacity_provider_strategy: list
:param group: the name of the task group associated with the task
:type group: str
:param placement_constraints: an array of placement constraint objects to use for
Expand Down Expand Up @@ -153,6 +158,7 @@ def __init__(
aws_conn_id: Optional[str] = None,
region_name: Optional[str] = None,
launch_type: str = 'EC2',
capacity_provider_strategy: Optional[list] = None,
group: Optional[str] = None,
placement_constraints: Optional[list] = None,
placement_strategy: Optional[list] = None,
Expand All @@ -175,6 +181,7 @@ def __init__(
self.cluster = cluster
self.overrides = overrides
self.launch_type = launch_type
self.capacity_provider_strategy = capacity_provider_strategy
self.group = group
self.placement_constraints = placement_constraints
self.placement_strategy = placement_strategy
Expand Down Expand Up @@ -229,7 +236,10 @@ def _start_task(self):
'startedBy': self.owner,
}

if self.launch_type:
if self.capacity_provider_strategy:
run_opts['capacityProviderStrategy'] = self.capacity_provider_strategy
run_opts['platformVersion'] = self.platform_version
elif self.launch_type:
run_opts['launchType'] = self.launch_type
if self.launch_type == 'FARGATE':
run_opts['platformVersion'] = self.platform_version
Expand Down
66 changes: 52 additions & 14 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,30 +96,68 @@ def test_template_fields_overrides(self):

@parameterized.expand(
[
['EC2', None],
['FARGATE', None],
['EC2', {'testTagKey': 'testTagValue'}],
['', {'testTagKey': 'testTagValue'}],
['EC2', None, None, {'launchType': 'EC2'}],
['FARGATE', None, None, {'launchType': 'FARGATE', 'platformVersion': 'LATEST'}],
[
'EC2',
None,
{'testTagKey': 'testTagValue'},
{'launchType': 'EC2', 'tags': [{'key': 'testTagKey', 'value': 'testTagValue'}]},
],
[
'',
None,
{'testTagKey': 'testTagValue'},
{'tags': [{'key': 'testTagKey', 'value': 'testTagValue'}]},
],
[
None,
{'capacityProvider': 'FARGATE_SPOT'},
None,
{
'capacityProviderStrategy': {'capacityProvider': 'FARGATE_SPOT'},
'platformVersion': 'LATEST',
},
],
[
'FARGATE',
{'capacityProvider': 'FARGATE_SPOT', 'weight': 123, 'base': 123},
None,
{
'capacityProviderStrategy': {
'capacityProvider': 'FARGATE_SPOT',
'weight': 123,
'base': 123,
},
'platformVersion': 'LATEST',
},
],
[
'EC2',
{'capacityProvider': 'FARGATE_SPOT'},
None,
{
'capacityProviderStrategy': {'capacityProvider': 'FARGATE_SPOT'},
'platformVersion': 'LATEST',
},
],
]
)
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock):
def test_execute_without_failures(
self, launch_type, capacity_provider_strategy, tags, expected_args, check_mock, wait_mock
):

self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter
self.set_up_operator( # pylint: disable=no-value-for-parameter
launch_type=launch_type, capacity_provider_strategy=capacity_provider_strategy, tags=tags
)
client_mock = self.aws_hook_mock.return_value.get_conn.return_value
client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

self.ecs.execute(None)

self.aws_hook_mock.return_value.get_conn.assert_called_once()
extend_args = {}
if launch_type:
extend_args['launchType'] = launch_type
if launch_type == 'FARGATE':
extend_args['platformVersion'] = 'LATEST'
if tags:
extend_args['tags'] = [{'key': k, 'value': v} for (k, v) in tags.items()]

client_mock.run_task.assert_called_once_with(
cluster='c',
Expand All @@ -133,7 +171,7 @@ def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock
'awsvpcConfiguration': {'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab']}
},
propagateTags='TASK_DEFINITION',
**extend_args,
**expected_args,
)

wait_mock.assert_called_once_with()
Expand Down

0 comments on commit 30708b5

Please sign in to comment.