Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EKS Overrides for AWS Batch submit_job #40718

Merged
merged 11 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def submit_job(
parameters: dict,
containerOverrides: dict,
ecsPropertiesOverride: dict,
eksPropertiesOverride: dict,
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
tags: dict,
) -> dict:
"""
Expand All @@ -122,6 +123,8 @@ def submit_job(
:param ecsPropertiesOverride: the same parameter that boto3 will receive
:param eksPropertiesOverride: the same parameter that boto3 will receive
:param tags: the same parameter that boto3 will receive
:return: an API response
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class BatchOperator(BaseOperator):
:param overrides: DEPRECATED, use container_overrides instead with the same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
:param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated)
:param eks_properties_override: the `eksPropertiesOverride` parameter for boto3 (templated)
:param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
:param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue
doesn't have a scheduling policy.
Expand Down Expand Up @@ -116,6 +117,7 @@ class BatchOperator(BaseOperator):
"container_overrides",
"array_properties",
"ecs_properties_override",
"eks_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
Expand All @@ -129,6 +131,7 @@ class BatchOperator(BaseOperator):
"container_overrides": "json",
"parameters": "json",
"ecs_properties_override": "json",
"eks_properties_override": "json",
"node_overrides": "json",
"retry_strategy": "json",
}
Expand Down Expand Up @@ -166,6 +169,7 @@ def __init__(
container_overrides: dict | None = None,
array_properties: dict | None = None,
ecs_properties_override: dict | None = None,
eks_properties_override: dict | None = None,
node_overrides: dict | None = None,
share_identifier: str | None = None,
scheduling_priority_override: int | None = None,
Expand Down Expand Up @@ -208,6 +212,7 @@ def __init__(
)

self.ecs_properties_override = ecs_properties_override
self.eks_properties_override = eks_properties_override
self.node_overrides = node_overrides
self.share_identifier = share_identifier
self.scheduling_priority_override = scheduling_priority_override
Expand Down Expand Up @@ -307,6 +312,8 @@ def submit_job(self, context: Context):
self.log.info("AWS Batch job - array properties: %s", self.array_properties)
if self.ecs_properties_override:
self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override)
if self.eks_properties_override:
self.log.info("AWS Batch job - EKS properties: %s", self.eks_properties_override)
if self.node_overrides:
self.log.info("AWS Batch job - node properties: %s", self.node_overrides)

Expand All @@ -319,6 +326,7 @@ def submit_job(self, context: Context):
"tags": self.tags,
"containerOverrides": self.container_overrides,
"ecsPropertiesOverride": self.ecs_properties_override,
"eksPropertiesOverride": self.eks_properties_override,
"nodeOverrides": self.node_overrides,
"retryStrategy": self.retry_strategy,
"shareIdentifier": self.share_identifier,
Expand Down
105 changes: 104 additions & 1 deletion tests/providers/amazon/aws/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test_init_defaults(self):
assert batch_job.container_overrides is None
assert batch_job.array_properties is None
assert batch_job.ecs_properties_override is None
assert batch_job.eks_properties_override is None
assert batch_job.node_overrides is None
assert batch_job.share_identifier is None
assert batch_job.scheduling_priority_override is None
Expand All @@ -151,6 +152,7 @@ def test_template_fields_overrides(self):
"container_overrides",
"array_properties",
"ecs_properties_override",
"eks_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
Expand Down Expand Up @@ -262,6 +264,104 @@ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description
tags={},
)

@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch.object(BatchClientHook, "wait_for_job")
@mock.patch.object(BatchClientHook, "check_job_success")
def test_execute_with_eks_overrides(self, check_mock, wait_mock, job_description_mock):
self.batch.container_overrides = None
self.batch.eks_properties_override = {
"podProperties": [
{
"containers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [{"limits": {"string": "string"}, "requests": {"string": "string"}}],
},
],
"initContainers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [{"limits": {"string": "string"}, "requests": {"string": "string"}}],
},
],
"metadata": {
"labels": {"string": "string"},
},
},
]
}
self.batch.execute(self.mock_context)

self.client_mock.submit_job.assert_called_once_with(
jobQueue="queue",
jobName=JOB_NAME,
jobDefinition="hello-world",
eksPropertiesOverride={
"podProperties": [
{
"containers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [
{"limits": {"string": "string"}, "requests": {"string": "string"}}
],
},
],
"initContainers": [
{
"image": "string",
"command": [
"string",
],
"args": [
"string",
],
"env": [
{"name": "string", "value": "string"},
],
"resources": [
{"limits": {"string": "string"}, "requests": {"string": "string"}}
],
},
],
"metadata": {
"labels": {"string": "string"},
},
},
]
},
parameters={},
retryStrategy={"attempts": 1},
tags={},
)

@mock.patch.object(BatchClientHook, "check_job_success")
def test_wait_job_complete_using_waiters(self, check_mock):
mock_waiters = mock.Mock()
Expand Down Expand Up @@ -296,7 +396,9 @@ def test_kill_job(self):
self.batch.on_kill()
self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user")

@pytest.mark.parametrize("override", ["overrides", "node_overrides", "ecs_properties_override"])
@pytest.mark.parametrize(
"override", ["overrides", "node_overrides", "ecs_properties_override", "eks_properties_override"]
)
@patch(
"airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client",
new_callable=mock.PropertyMock,
Expand Down Expand Up @@ -348,6 +450,7 @@ def test_override_not_sent_if_not_set(self, client_mock, override):
"overrides": "containerOverrides",
"node_overrides": "nodeOverrides",
"ecs_properties_override": "ecsPropertiesOverride",
"eks_properties_override": "eksPropertiesOverride",
}

expected_args[py2api[override]] = {"a": "a"}
Expand Down