Skip to content

Commit

Permalink
EKS Overrides for AWS Batch submit_job (apache#40718)
Browse files Browse the repository at this point in the history
* add eks properties overrride
  • Loading branch information
ssilb4 authored and Artuz37 committed Aug 13, 2024
1 parent eba50bf commit d073d20
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 1 deletion.
3 changes: 3 additions & 0 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def submit_job(
parameters: dict,
containerOverrides: dict,
ecsPropertiesOverride: dict,
eksPropertiesOverride: dict,
tags: dict,
) -> dict:
"""
Expand All @@ -122,6 +123,8 @@ def submit_job(
:param ecsPropertiesOverride: the same parameter that boto3 will receive
:param eksPropertiesOverride: the same parameter that boto3 will receive
:param tags: the same parameter that boto3 will receive
:return: an API response
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class BatchOperator(BaseOperator):
:param overrides: DEPRECATED, use container_overrides instead with the same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
:param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated)
:param eks_properties_override: the `eksPropertiesOverride` parameter for boto3 (templated)
:param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
:param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue
doesn't have a scheduling policy.
Expand Down Expand Up @@ -116,6 +117,7 @@ class BatchOperator(BaseOperator):
"container_overrides",
"array_properties",
"ecs_properties_override",
"eks_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
Expand All @@ -129,6 +131,7 @@ class BatchOperator(BaseOperator):
"container_overrides": "json",
"parameters": "json",
"ecs_properties_override": "json",
"eks_properties_override": "json",
"node_overrides": "json",
"retry_strategy": "json",
}
Expand Down Expand Up @@ -166,6 +169,7 @@ def __init__(
container_overrides: dict | None = None,
array_properties: dict | None = None,
ecs_properties_override: dict | None = None,
eks_properties_override: dict | None = None,
node_overrides: dict | None = None,
share_identifier: str | None = None,
scheduling_priority_override: int | None = None,
Expand Down Expand Up @@ -208,6 +212,7 @@ def __init__(
)

self.ecs_properties_override = ecs_properties_override
self.eks_properties_override = eks_properties_override
self.node_overrides = node_overrides
self.share_identifier = share_identifier
self.scheduling_priority_override = scheduling_priority_override
Expand Down Expand Up @@ -307,6 +312,8 @@ def submit_job(self, context: Context):
self.log.info("AWS Batch job - array properties: %s", self.array_properties)
if self.ecs_properties_override:
self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override)
if self.eks_properties_override:
self.log.info("AWS Batch job - EKS properties: %s", self.eks_properties_override)
if self.node_overrides:
self.log.info("AWS Batch job - node properties: %s", self.node_overrides)

Expand All @@ -319,6 +326,7 @@ def submit_job(self, context: Context):
"tags": self.tags,
"containerOverrides": self.container_overrides,
"ecsPropertiesOverride": self.ecs_properties_override,
"eksPropertiesOverride": self.eks_properties_override,
"nodeOverrides": self.node_overrides,
"retryStrategy": self.retry_strategy,
"shareIdentifier": self.share_identifier,
Expand Down
105 changes: 104 additions & 1 deletion tests/providers/amazon/aws/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test_init_defaults(self):
assert batch_job.container_overrides is None
assert batch_job.array_properties is None
assert batch_job.ecs_properties_override is None
assert batch_job.eks_properties_override is None
assert batch_job.node_overrides is None
assert batch_job.share_identifier is None
assert batch_job.scheduling_priority_override is None
Expand All @@ -151,6 +152,7 @@ def test_template_fields_overrides(self):
"container_overrides",
"array_properties",
"ecs_properties_override",
"eks_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
Expand Down Expand Up @@ -262,6 +264,104 @@ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description
tags={},
)

@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch.object(BatchClientHook, "wait_for_job")
@mock.patch.object(BatchClientHook, "check_job_success")
def test_execute_with_eks_overrides(self, check_mock, wait_mock, job_description_mock):
self.batch.container_overrides = None
self.batch.eks_properties_override = {
"podProperties": [
{
"containers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [{"limits": {"string": "string"}, "requests": {"string": "string"}}],
},
],
"initContainers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [{"limits": {"string": "string"}, "requests": {"string": "string"}}],
},
],
"metadata": {
"labels": {"string": "string"},
},
},
]
}
self.batch.execute(self.mock_context)

self.client_mock.submit_job.assert_called_once_with(
jobQueue="queue",
jobName=JOB_NAME,
jobDefinition="hello-world",
eksPropertiesOverride={
"podProperties": [
{
"containers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [
{"limits": {"string": "string"}, "requests": {"string": "string"}}
],
},
],
"initContainers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [
{"limits": {"string": "string"}, "requests": {"string": "string"}}
],
},
],
"metadata": {
"labels": {"string": "string"},
},
},
]
},
parameters={},
retryStrategy={"attempts": 1},
tags={},
)

@mock.patch.object(BatchClientHook, "check_job_success")
def test_wait_job_complete_using_waiters(self, check_mock):
mock_waiters = mock.Mock()
Expand Down Expand Up @@ -296,7 +396,9 @@ def test_kill_job(self):
self.batch.on_kill()
self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user")

@pytest.mark.parametrize("override", ["overrides", "node_overrides", "ecs_properties_override"])
@pytest.mark.parametrize(
"override", ["overrides", "node_overrides", "ecs_properties_override", "eks_properties_override"]
)
@patch(
"airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client",
new_callable=mock.PropertyMock,
Expand Down Expand Up @@ -348,6 +450,7 @@ def test_override_not_sent_if_not_set(self, client_mock, override):
"overrides": "containerOverrides",
"node_overrides": "nodeOverrides",
"ecs_properties_override": "ecsPropertiesOverride",
"eks_properties_override": "eksPropertiesOverride",
}

expected_args[py2api[override]] = {"a": "a"}
Expand Down

0 comments on commit d073d20

Please sign in to comment.