diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py index f024134560d842..53c1b5b7e6e3eb 100644 --- a/airflow/providers/amazon/aws/hooks/batch_client.py +++ b/airflow/providers/amazon/aws/hooks/batch_client.py @@ -103,6 +103,7 @@ def submit_job( parameters: dict, containerOverrides: dict, ecsPropertiesOverride: dict, + eksPropertiesOverride: dict, tags: dict, ) -> dict: """ @@ -122,6 +123,8 @@ def submit_job( :param ecsPropertiesOverride: the same parameter that boto3 will receive + :param eksPropertiesOverride: the same parameter that boto3 will receive + :param tags: the same parameter that boto3 will receive :return: an API response diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 20c2e9dced6485..ca4ba8bfad8c87 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -68,6 +68,7 @@ class BatchOperator(BaseOperator): :param overrides: DEPRECATED, use container_overrides instead with the same value. :param container_overrides: the `containerOverrides` parameter for boto3 (templated) :param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated) + :param eks_properties_override: the `eksPropertiesOverride` parameter for boto3 (templated) :param node_overrides: the `nodeOverrides` parameter for boto3 (templated) :param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue doesn't have a scheduling policy. @@ -116,6 +117,7 @@ class BatchOperator(BaseOperator): "container_overrides", "array_properties", "ecs_properties_override", + "eks_properties_override", "node_overrides", "parameters", "retry_strategy", @@ -129,6 +131,7 @@ class BatchOperator(BaseOperator): "container_overrides": "json", "parameters": "json", "ecs_properties_override": "json", + "eks_properties_override": "json", "node_overrides": "json", "retry_strategy": "json", } @@ -166,6 +169,7 @@ def __init__( container_overrides: dict | None = None, array_properties: dict | None = None, ecs_properties_override: dict | None = None, + eks_properties_override: dict | None = None, node_overrides: dict | None = None, share_identifier: str | None = None, scheduling_priority_override: int | None = None, @@ -208,6 +212,7 @@ def __init__( ) self.ecs_properties_override = ecs_properties_override + self.eks_properties_override = eks_properties_override self.node_overrides = node_overrides self.share_identifier = share_identifier self.scheduling_priority_override = scheduling_priority_override @@ -307,6 +312,8 @@ def submit_job(self, context: Context): self.log.info("AWS Batch job - array properties: %s", self.array_properties) if self.ecs_properties_override: self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override) + if self.eks_properties_override: + self.log.info("AWS Batch job - EKS properties: %s", self.eks_properties_override) if self.node_overrides: self.log.info("AWS Batch job - node properties: %s", self.node_overrides) @@ -319,6 +326,7 @@ def submit_job(self, context: Context): "tags": self.tags, "containerOverrides": self.container_overrides, "ecsPropertiesOverride": self.ecs_properties_override, + "eksPropertiesOverride": self.eks_properties_override, "nodeOverrides": self.node_overrides, "retryStrategy": self.retry_strategy, "shareIdentifier": self.share_identifier, diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 137044a212539b..7d9f27a6f4c5ae 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -133,6 +133,7 @@ def test_init_defaults(self): assert batch_job.container_overrides is None assert batch_job.array_properties is None assert batch_job.ecs_properties_override is None + assert batch_job.eks_properties_override is None assert batch_job.node_overrides is None assert batch_job.share_identifier is None assert batch_job.scheduling_priority_override is None @@ -151,6 +152,7 @@ def test_template_fields_overrides(self): "container_overrides", "array_properties", "ecs_properties_override", + "eks_properties_override", "node_overrides", "parameters", "retry_strategy", @@ -262,6 +264,104 @@ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description tags={}, ) + @mock.patch.object(BatchClientHook, "get_job_description") + @mock.patch.object(BatchClientHook, "wait_for_job") + @mock.patch.object(BatchClientHook, "check_job_success") + def test_execute_with_eks_overrides(self, check_mock, wait_mock, job_description_mock): + self.batch.container_overrides = None + self.batch.eks_properties_override = { + "podProperties": [ + { + "containers": [ + { + "image": "string", + "command": [ + "string", + ], + "args": [ + "string", + ], + "env": [ + {"name": "string", "value": "string"}, + ], + "resources": [{"limits": {"string": "string"}, "requests": {"string": "string"}}], + }, + ], + "initContainers": [ + { + "image": "string", + "command": [ + "string", + ], + "args": [ + "string", + ], + "env": [ + {"name": "string", "value": "string"}, + ], + "resources": [{"limits": {"string": "string"}, "requests": {"string": "string"}}], + }, + ], + "metadata": { + "labels": {"string": "string"}, + }, + }, + ] + } + self.batch.execute(self.mock_context) + + self.client_mock.submit_job.assert_called_once_with( + jobQueue="queue", + jobName=JOB_NAME, + jobDefinition="hello-world", + eksPropertiesOverride={ + "podProperties": [ + { + "containers": [ + { + "image": "string", + "command": [ + "string", + ], + "args": [ + "string", + ], + "env": [ + {"name": "string", "value": "string"}, + ], + "resources": [ + {"limits": {"string": "string"}, "requests": {"string": "string"}} + ], + }, + ], + "initContainers": [ + { + "image": "string", + "command": [ + "string", + ], + "args": [ + "string", + ], + "env": [ + {"name": "string", "value": "string"}, + ], + "resources": [ + {"limits": {"string": "string"}, "requests": {"string": "string"}} + ], + }, + ], + "metadata": { + "labels": {"string": "string"}, + }, + }, + ] + }, + parameters={}, + retryStrategy={"attempts": 1}, + tags={}, + ) + @mock.patch.object(BatchClientHook, "check_job_success") def test_wait_job_complete_using_waiters(self, check_mock): mock_waiters = mock.Mock() @@ -296,7 +396,9 @@ def test_kill_job(self): self.batch.on_kill() self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user") - @pytest.mark.parametrize("override", ["overrides", "node_overrides", "ecs_properties_override"]) + @pytest.mark.parametrize( + "override", ["overrides", "node_overrides", "ecs_properties_override", "eks_properties_override"] + ) @patch( "airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client", new_callable=mock.PropertyMock, @@ -348,6 +450,7 @@ def test_override_not_sent_if_not_set(self, client_mock, override): "overrides": "containerOverrides", "node_overrides": "nodeOverrides", "ecs_properties_override": "ecsPropertiesOverride", + "eks_properties_override": "eksPropertiesOverride", } expected_args[py2api[override]] = {"a": "a"}