From 5f2ebb312b08769b454a777280ddf5c43c38bb87 Mon Sep 17 00:00:00 2001 From: Josh Dimarsky <24758845+yehoshuadimarsky@users.noreply.github.com> Date: Wed, 29 May 2024 05:18:10 -0400 Subject: [PATCH] ECS Overrides for AWS Batch submit_job (#39903) --- .../amazon/aws/hooks/batch_client.py | 3 + .../providers/amazon/aws/operators/batch.py | 8 ++ .../amazon/aws/operators/test_batch.py | 73 +++++++++++++++++-- 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py index b419239a16b50d..f024134560d842 100644 --- a/airflow/providers/amazon/aws/hooks/batch_client.py +++ b/airflow/providers/amazon/aws/hooks/batch_client.py @@ -102,6 +102,7 @@ def submit_job( arrayProperties: dict, parameters: dict, containerOverrides: dict, + ecsPropertiesOverride: dict, tags: dict, ) -> dict: """ @@ -119,6 +120,8 @@ def submit_job( :param containerOverrides: the same parameter that boto3 will receive + :param ecsPropertiesOverride: the same parameter that boto3 will receive + :param tags: the same parameter that boto3 will receive :return: an API response diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 00b6287145a817..849fc193461a55 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -65,6 +65,7 @@ class BatchOperator(BaseOperator): :param job_queue: the queue name on AWS Batch :param overrides: DEPRECATED, use container_overrides instead with the same value. :param container_overrides: the `containerOverrides` parameter for boto3 (templated) + :param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated) :param node_overrides: the `nodeOverrides` parameter for boto3 (templated) :param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue doesn't have a scheduling policy. @@ -112,6 +113,7 @@ class BatchOperator(BaseOperator): "job_queue", "container_overrides", "array_properties", + "ecs_properties_override", "node_overrides", "parameters", "retry_strategy", @@ -124,6 +126,7 @@ class BatchOperator(BaseOperator): template_fields_renderers = { "container_overrides": "json", "parameters": "json", + "ecs_properties_override": "json", "node_overrides": "json", "retry_strategy": "json", } @@ -160,6 +163,7 @@ def __init__( overrides: dict | None = None, # deprecated container_overrides: dict | None = None, array_properties: dict | None = None, + ecs_properties_override: dict | None = None, node_overrides: dict | None = None, share_identifier: str | None = None, scheduling_priority_override: int | None = None, @@ -201,6 +205,7 @@ def __init__( stacklevel=2, ) + self.ecs_properties_override = ecs_properties_override self.node_overrides = node_overrides self.share_identifier = share_identifier self.scheduling_priority_override = scheduling_priority_override @@ -296,6 +301,8 @@ def submit_job(self, context: Context): self.log.info("AWS Batch job - container overrides: %s", self.container_overrides) if self.array_properties: self.log.info("AWS Batch job - array properties: %s", self.array_properties) + if self.ecs_properties_override: + self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override) if self.node_overrides: self.log.info("AWS Batch job - node properties: %s", self.node_overrides) @@ -307,6 +314,7 @@ def submit_job(self, context: Context): "parameters": self.parameters, "tags": self.tags, "containerOverrides": self.container_overrides, + "ecsPropertiesOverride": self.ecs_properties_override, "nodeOverrides": self.node_overrides, "retryStrategy": self.retry_strategy, "shareIdentifier": self.share_identifier, diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index f769c1baa81819..27f86e279c27ee 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -132,6 +132,7 @@ def test_init_defaults(self): assert batch_job.retry_strategy is None assert batch_job.container_overrides is None assert batch_job.array_properties is None + assert batch_job.ecs_properties_override is None assert batch_job.node_overrides is None assert batch_job.share_identifier is None assert batch_job.scheduling_priority_override is None @@ -149,6 +150,7 @@ def test_template_fields_overrides(self): "job_queue", "container_overrides", "array_properties", + "ecs_properties_override", "node_overrides", "parameters", "retry_strategy", @@ -204,6 +206,62 @@ def test_execute_with_failures(self): tags={}, ) + @mock.patch.object(BatchClientHook, "get_job_description") + @mock.patch.object(BatchClientHook, "wait_for_job") + @mock.patch.object(BatchClientHook, "check_job_success") + def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description_mock): + self.batch.container_overrides = None + self.batch.ecs_properties_override = { + "taskProperties": [ + { + "containers": [ + { + "command": [ + "string", + ], + "environment": [ + {"name": "string", "value": "string"}, + ], + "name": "string", + "resourceRequirements": [ + {"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"}, + ], + }, + ] + }, + ] + } + self.batch.execute(self.mock_context) + + self.client_mock.submit_job.assert_called_once_with( + jobQueue="queue", + jobName=JOB_NAME, + jobDefinition="hello-world", + ecsPropertiesOverride={ + "taskProperties": [ + { + "containers": [ + { + "command": [ + "string", + ], + "environment": [ + {"name": "string", "value": "string"}, + ], + "name": "string", + "resourceRequirements": [ + {"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"}, + ], + }, + ] + }, + ] + }, + parameters={}, + retryStrategy={"attempts": 1}, + tags={}, + ) + @mock.patch.object(BatchClientHook, "check_job_success") def test_wait_job_complete_using_waiters(self, check_mock): mock_waiters = mock.Mock() @@ -238,7 +296,7 @@ def test_kill_job(self): self.batch.on_kill() self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user") - @pytest.mark.parametrize("override", ["overrides", "node_overrides"]) + @pytest.mark.parametrize("override", ["overrides", "node_overrides", "ecs_properties_override"]) @patch( "airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client", new_callable=mock.PropertyMock, @@ -269,10 +327,15 @@ def test_override_not_sent_if_not_set(self, client_mock, override): "parameters": {}, "tags": {}, } - if override == "overrides": - expected_args["containerOverrides"] = {"a": "a"} - else: - expected_args["nodeOverrides"] = {"a": "a"} + + py2api = { + "overrides": "containerOverrides", + "node_overrides": "nodeOverrides", + "ecs_properties_override": "ecsPropertiesOverride", + } + + expected_args[py2api[override]] = {"a": "a"} + client_mock().submit_job.assert_called_once_with(**expected_args) def test_deprecated_override_param(self):