diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 1580e35c6a42b..0c893176aa53d 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -106,6 +106,11 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes :type region_name: str :param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE') :type launch_type: str + :param capacity_provider_strategy: the capacity provider strategy to use for the task. + When capacity_provider_strategy is specified, the launch_type parameter is omitted. + If no capacity_provider_strategy or launch_type is specified, + the default capacity provider strategy for the cluster is used. + :type capacity_provider_strategy: list :param group: the name of the task group associated with the task :type group: str :param placement_constraints: an array of placement constraint objects to use for @@ -153,6 +158,7 @@ def __init__( aws_conn_id: Optional[str] = None, region_name: Optional[str] = None, launch_type: str = 'EC2', + capacity_provider_strategy: Optional[list] = None, group: Optional[str] = None, placement_constraints: Optional[list] = None, placement_strategy: Optional[list] = None, @@ -175,6 +181,7 @@ def __init__( self.cluster = cluster self.overrides = overrides self.launch_type = launch_type + self.capacity_provider_strategy = capacity_provider_strategy self.group = group self.placement_constraints = placement_constraints self.placement_strategy = placement_strategy @@ -229,7 +236,10 @@ def _start_task(self): 'startedBy': self.owner, } - if self.launch_type: + if self.capacity_provider_strategy: + run_opts['capacityProviderStrategy'] = self.capacity_provider_strategy + run_opts['platformVersion'] = self.platform_version + elif self.launch_type: run_opts['launchType'] = self.launch_type if self.launch_type == 'FARGATE': run_opts['platformVersion'] = self.platform_version diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 96717c37fe174..cf4049e49face 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -96,30 +96,68 @@ def test_template_fields_overrides(self): @parameterized.expand( [ - ['EC2', None], - ['FARGATE', None], - ['EC2', {'testTagKey': 'testTagValue'}], - ['', {'testTagKey': 'testTagValue'}], + ['EC2', None, None, {'launchType': 'EC2'}], + ['FARGATE', None, None, {'launchType': 'FARGATE', 'platformVersion': 'LATEST'}], + [ + 'EC2', + None, + {'testTagKey': 'testTagValue'}, + {'launchType': 'EC2', 'tags': [{'key': 'testTagKey', 'value': 'testTagValue'}]}, + ], + [ + '', + None, + {'testTagKey': 'testTagValue'}, + {'tags': [{'key': 'testTagKey', 'value': 'testTagValue'}]}, + ], + [ + None, + {'capacityProvider': 'FARGATE_SPOT'}, + None, + { + 'capacityProviderStrategy': {'capacityProvider': 'FARGATE_SPOT'}, + 'platformVersion': 'LATEST', + }, + ], + [ + 'FARGATE', + {'capacityProvider': 'FARGATE_SPOT', 'weight': 123, 'base': 123}, + None, + { + 'capacityProviderStrategy': { + 'capacityProvider': 'FARGATE_SPOT', + 'weight': 123, + 'base': 123, + }, + 'platformVersion': 'LATEST', + }, + ], + [ + 'EC2', + {'capacityProvider': 'FARGATE_SPOT'}, + None, + { + 'capacityProviderStrategy': {'capacityProvider': 'FARGATE_SPOT'}, + 'platformVersion': 'LATEST', + }, + ], ] ) @mock.patch.object(ECSOperator, '_wait_for_task_ended') @mock.patch.object(ECSOperator, '_check_success_task') - def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock): + def test_execute_without_failures( + self, launch_type, capacity_provider_strategy, tags, expected_args, check_mock, wait_mock + ): - self.set_up_operator(launch_type=launch_type, tags=tags) # pylint: disable=no-value-for-parameter + self.set_up_operator( # pylint: disable=no-value-for-parameter + launch_type=launch_type, capacity_provider_strategy=capacity_provider_strategy, tags=tags + ) client_mock = self.aws_hook_mock.return_value.get_conn.return_value client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES self.ecs.execute(None) self.aws_hook_mock.return_value.get_conn.assert_called_once() - extend_args = {} - if launch_type: - extend_args['launchType'] = launch_type - if launch_type == 'FARGATE': - extend_args['platformVersion'] = 'LATEST' - if tags: - extend_args['tags'] = [{'key': k, 'value': v} for (k, v) in tags.items()] client_mock.run_task.assert_called_once_with( cluster='c', @@ -133,7 +171,7 @@ def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock 'awsvpcConfiguration': {'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab']} }, propagateTags='TASK_DEFINITION', - **extend_args, + **expected_args, ) wait_mock.assert_called_once_with()