diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 526ab9a8a4f012..1f9917e86bf2e3 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -414,43 +414,76 @@ def parse_job_description(job_id: str, response: dict) -> dict:
         return matching_jobs[0]
 
     def get_job_awslogs_info(self, job_id: str) -> dict[str, str] | None:
+        all_info = self.get_job_all_awslogs_info(job_id)
+        if not all_info:
+            return None
+        if len(all_info) > 1:
+            self.log.warning(
+                f"AWS Batch job ({job_id}) has more than one log stream, " f"only returning the first one."
+            )
+        return all_info[0]
+
+    def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:
         """
         Parse job description to extract AWS CloudWatch information.
 
         :param job_id: AWS Batch Job ID
         """
-        job_container_desc = self.get_job_description(job_id=job_id).get("container", {})
-        log_configuration = job_container_desc.get("logConfiguration", {})
-
-        # In case if user select other "logDriver" rather than "awslogs"
-        # than CloudWatch logging should be disabled.
-        # If user not specify anything than expected that "awslogs" will use
-        # with default settings:
-        #   awslogs-group = /aws/batch/job
-        #   awslogs-region = `same as AWS Batch Job region`
-        log_driver = log_configuration.get("logDriver", "awslogs")
-        if log_driver != "awslogs":
+        job_desc = self.get_job_description(job_id=job_id)
+
+        job_node_properties = job_desc.get("nodeProperties", {})
+        job_container_desc = job_desc.get("container", {})
+
+        if job_node_properties:
+            # one log config per node
+            log_configs = [
+                p.get("container", {}).get("logConfiguration", {})
+                for p in job_node_properties.get("nodeRangeProperties", {})
+            ]
+            # one stream name per attempt
+            stream_names = [a.get("container", {}).get("logStreamName") for a in job_desc.get("attempts", [])]
+        elif job_container_desc:
+            log_configs = [job_container_desc.get("logConfiguration", {})]
+            stream_name = job_container_desc.get("logStreamName")
+            stream_names = [stream_name] if stream_name is not None else []
+        else:
+            raise AirflowException(
+                f"AWS Batch job ({job_id}) is not a supported job type. "
+                "Supported job types: container, array, multinode."
+            )
+
+        # If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
+        if any([c.get("logDriver", "awslogs") != "awslogs" for c in log_configs]):
             self.log.warning(
-                "AWS Batch job (%s) uses logDriver (%s). AWS CloudWatch logging disabled.", job_id, log_driver
+                f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
             )
-            return None
+            return []
 
-        awslogs_stream_name = job_container_desc.get("logStreamName")
-        if not awslogs_stream_name:
-            # In case of call this method on very early stage of running AWS Batch
-            # there is possibility than AWS CloudWatch Stream Name not exists yet.
-            # AWS CloudWatch Stream Name also not created in case of misconfiguration.
-            self.log.warning("AWS Batch job (%s) doesn't create AWS CloudWatch Stream.", job_id)
-            return None
+        if not stream_names:
+            # If this method is called very early after starting the AWS Batch job,
+            # there is a possibility that the AWS CloudWatch Stream Name would not exist yet.
+            # This can also happen in case of misconfiguration.
+            self.log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")
+            return []
 
         # Try to get user-defined log configuration options
-        log_options = log_configuration.get("options", {})
-
-        return {
-            "awslogs_stream_name": awslogs_stream_name,
-            "awslogs_group": log_options.get("awslogs-group", "/aws/batch/job"),
-            "awslogs_region": log_options.get("awslogs-region", self.conn_region_name),
-        }
+        log_options = [c.get("options", {}) for c in log_configs]
+
+        # cross stream names with options (i.e. attempts X nodes) to generate all log infos
+        result = []
+        for stream in stream_names:
+            for option in log_options:
+                result.append(
+                    {
+                        "awslogs_stream_name": stream,
+                        # If the user did not specify anything, the default settings are:
+                        #   awslogs-group = /aws/batch/job
+                        #   awslogs-region = `same as AWS Batch Job region`
+                        "awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
+                        "awslogs_region": option.get("awslogs-region", self.conn_region_name),
+                    }
+                )
+        return result
 
     @staticmethod
     def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float:
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 6565bcecfbaf18..30264d9f4e1a7c 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -25,6 +25,7 @@
 """
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.compat.functools import cached_property
@@ -54,7 +55,9 @@ class BatchOperator(BaseOperator):
     :param job_name: the name for the job that will run on AWS Batch (templated)
     :param job_definition: the job definition name on AWS Batch
     :param job_queue: the queue name on AWS Batch
-    :param overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param overrides: DEPRECATED, use container_overrides instead with the same value.
+    :param container_overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
     :param array_properties: the `arrayProperties` parameter for boto3
     :param parameters: the `parameters` for boto3 (templated)
     :param job_id: the job ID, usually unknown (None) until the
@@ -88,14 +91,19 @@ class BatchOperator(BaseOperator):
         "job_name",
         "job_definition",
         "job_queue",
-        "overrides",
+        "container_overrides",
         "array_properties",
+        "node_overrides",
         "parameters",
         "waiters",
         "tags",
         "wait_for_completion",
     )
-    template_fields_renderers = {"overrides": "json", "parameters": "json"}
+    template_fields_renderers = {
+        "container_overrides": "json",
+        "parameters": "json",
+        "node_overrides": "json",
+    }
 
     @property
     def operator_extra_links(self):
@@ -114,8 +122,10 @@ def __init__(
         job_name: str,
         job_definition: str,
         job_queue: str,
-        overrides: dict,
+        overrides: dict | None = None,  # deprecated
+        container_overrides: dict | None = None,
         array_properties: dict | None = None,
+        node_overrides: dict | None = None,
         parameters: dict | None = None,
         job_id: str | None = None,
         waiters: Any | None = None,
@@ -133,17 +143,43 @@ def __init__(
         self.job_name = job_name
         self.job_definition = job_definition
         self.job_queue = job_queue
-        self.overrides = overrides or {}
-        self.array_properties = array_properties or {}
+
+        self.container_overrides = container_overrides
+        # handle `overrides` deprecation in favor of `container_overrides`
+        if overrides:
+            if container_overrides:
+                # disallow setting both old and new params
+                raise AirflowException(
+                    "'container_overrides' replaces the 'overrides' parameter. "
+                    "You cannot specify both. Please remove assignation to the deprecated 'overrides'."
+                )
+            self.container_overrides = overrides
+            warnings.warn(
+                "Parameter `overrides` is deprecated, Please use `container_overrides` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        self.node_overrides = node_overrides
+        self.array_properties = array_properties
         self.parameters = parameters or {}
         self.waiters = waiters
         self.tags = tags or {}
         self.wait_for_completion = wait_for_completion
-        self.hook = BatchClientHook(
-            max_retries=max_retries,
-            status_retries=status_retries,
-            aws_conn_id=aws_conn_id,
-            region_name=region_name,
+
+        # params for hook
+        self.max_retries = max_retries
+        self.status_retries = status_retries
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    @cached_property
+    def hook(self) -> BatchClientHook:
+        return BatchClientHook(
+            max_retries=self.max_retries,
+            status_retries=self.status_retries,
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
         )
 
     def execute(self, context: Context):
@@ -174,18 +210,27 @@ def submit_job(self, context: Context):
             self.job_definition,
             self.job_queue,
         )
-        self.log.info("AWS Batch job - container overrides: %s", self.overrides)
+
+        if self.container_overrides:
+            self.log.info("AWS Batch job - container overrides: %s", self.container_overrides)
+        if self.array_properties:
+            self.log.info("AWS Batch job - array properties: %s", self.array_properties)
+        if self.node_overrides:
+            self.log.info("AWS Batch job - node properties: %s", self.node_overrides)
+
+        args = {
+            "jobName": self.job_name,
+            "jobQueue": self.job_queue,
+            "jobDefinition": self.job_definition,
+            "arrayProperties": self.array_properties,
+            "parameters": self.parameters,
+            "tags": self.tags,
+            "containerOverrides": self.container_overrides,
+            "nodeOverrides": self.node_overrides,
+        }
 
         try:
-            response = self.hook.client.submit_job(
-                jobName=self.job_name,
-                jobQueue=self.job_queue,
-                jobDefinition=self.job_definition,
-                arrayProperties=self.array_properties,
-                parameters=self.parameters,
-                containerOverrides=self.overrides,
-                tags=self.tags,
-            )
+            response = self.hook.client.submit_job(**trim_none_values(args))
         except Exception as e:
             self.log.error(
                 "AWS Batch job failed submission - job definition: %s - on queue %s",
@@ -249,15 +294,24 @@ def monitor_job(self, context: Context):
         else:
             self.hook.wait_for_job(self.job_id)
 
-        awslogs = self.hook.get_job_awslogs_info(self.job_id)
+        awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
         if awslogs:
-            self.log.info("AWS Batch job (%s) CloudWatch Events details found: %s", self.job_id, awslogs)
+            self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
+            link_builder = CloudWatchEventsLink()
+            for log in awslogs:
+                self.log.info(link_builder.format_link(**log))
+            if len(awslogs) > 1:
+                # there can be several log streams on multi-node jobs
+                self.log.warning(
+                    "out of all those logs, we can only link to one in the UI. " "Using the first one."
+                )
+
             CloudWatchEventsLink.persist(
                 context=context,
                 operator=self,
                 region_name=self.hook.conn_region_name,
                 aws_partition=self.hook.conn_partition,
-                **awslogs,
+                **awslogs[0],
             )
 
         self.hook.check_job_success(self.job_id)
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py
index 13726e5518ff4e..aef8be1d26262c 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -280,7 +280,7 @@ def test_job_no_awslogs_stream(self, caplog):
             "jobs": [
                 {
                     "jobId": JOB_ID,
-                    "container": {},
+                    "container": {"logConfiguration": {}},
                 }
             ]
         }
@@ -288,7 +288,16 @@ def test_job_no_awslogs_stream(self, caplog):
         with caplog.at_level(level=logging.WARNING):
             assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
         assert len(caplog.records) == 1
-        assert "doesn't create AWS CloudWatch Stream" in caplog.messages[0]
+        assert "doesn't have any AWS CloudWatch Stream" in caplog.messages[0]
+
+    def test_job_not_recognized_job(self):
+        self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID}]}
+        with pytest.raises(AirflowException) as ctx:
+            self.batch_client.get_job_awslogs_info(JOB_ID)
+        # It should not retry when this client error occurs
+        self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
+        msg = "is not a supported job type"
+        assert msg in str(ctx.value)
 
     def test_job_splunk_logs(self, caplog):
         self.client_mock.describe_jobs.return_value = {
@@ -307,7 +316,66 @@ def test_job_splunk_logs(self, caplog):
         with caplog.at_level(level=logging.WARNING):
             assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
         assert len(caplog.records) == 1
-        assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in caplog.messages[0]
+        assert "uses non-aws log drivers. AWS CloudWatch logging disabled." in caplog.messages[0]
+
+    def test_job_awslogs_multinode_job(self):
+        self.client_mock.describe_jobs.return_value = {
+            "jobs": [
+                {
+                    "jobId": JOB_ID,
+                    "attempts": [
+                        {"container": {"exitCode": 0, "logStreamName": "test/stream/attempt0"}},
+                        {"container": {"exitCode": 0, "logStreamName": "test/stream/attempt1"}},
+                    ],
+                    "nodeProperties": {
+                        "mainNode": 0,
+                        "nodeRangeProperties": [
+                            {
+                                "targetNodes": "0:",
+                                "container": {
+                                    "logConfiguration": {
+                                        "logDriver": "awslogs",
+                                        "options": {
+                                            "awslogs-group": "/test/batch/job-a",
+                                            "awslogs-region": AWS_REGION,
+                                        },
+                                    }
+                                },
+                            },
+                            {
+                                "targetNodes": "1:",
+                                "container": {
+                                    "logConfiguration": {
+                                        "logDriver": "awslogs",
+                                        "options": {
+                                            "awslogs-group": "/test/batch/job-b",
+                                            "awslogs-region": AWS_REGION,
+                                        },
+                                    }
+                                },
+                            },
+                        ],
+                    },
+                }
+            ]
+        }
+        awslogs = self.batch_client.get_job_all_awslogs_info(JOB_ID)
+        assert len(awslogs) == 4
+        assert all([log["awslogs_region"] == AWS_REGION for log in awslogs])
+
+        combinations = {
+            ("test/stream/attempt0", "/test/batch/job-a"): False,
+            ("test/stream/attempt0", "/test/batch/job-b"): False,
+            ("test/stream/attempt1", "/test/batch/job-a"): False,
+            ("test/stream/attempt1", "/test/batch/job-b"): False,
+        }
+        for log_info in awslogs:
+            # mark combinations that we see
+            combinations[(log_info["awslogs_stream_name"], log_info["awslogs_group"])] = True
+
+        assert len(combinations) == 4
+        # all combinations listed above should have been seen
+        assert all(combinations.values())
 
 
 class TestBatchClientDelays:
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index 0ddfcea591713a..c6a923b51ec2d9 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import patch
 
 import pytest
 
@@ -48,7 +49,7 @@ class TestBatchOperator:
     @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
    @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
     @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
-    def setup_method(self, method, get_client_type_mock):
+    def setup_method(self, _, get_client_type_mock):
         self.get_client_type_mock = get_client_type_mock
         self.batch = BatchOperator(
             task_id="task",
@@ -58,7 +59,7 @@ def setup_method(self, method, get_client_type_mock):
             max_retries=self.MAX_RETRIES,
             status_retries=self.STATUS_RETRIES,
             parameters=None,
-            overrides={},
+            container_overrides={},
             array_properties=None,
             aws_conn_id="airflow_test",
             region_name="eu-west-1",
@@ -91,8 +92,9 @@ def test_init(self):
         assert self.batch.hook.max_retries == self.MAX_RETRIES
         assert self.batch.hook.status_retries == self.STATUS_RETRIES
         assert self.batch.parameters == {}
-        assert self.batch.overrides == {}
-        assert self.batch.array_properties == {}
+        assert self.batch.container_overrides == {}
+        assert self.batch.array_properties is None
+        assert self.batch.node_overrides is None
         assert self.batch.hook.region_name == "eu-west-1"
         assert self.batch.hook.aws_conn_id == "airflow_test"
         assert self.batch.hook.client == self.client_mock
@@ -107,8 +109,9 @@ def test_template_fields_overrides(self):
             "job_name",
             "job_definition",
             "job_queue",
-            "overrides",
+            "container_overrides",
             "array_properties",
+            "node_overrides",
             "parameters",
             "waiters",
             "tags",
@@ -131,7 +134,6 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_mock):
             jobName=JOB_NAME,
             containerOverrides={},
             jobDefinition="hello-world",
-            arrayProperties={},
             parameters={},
             tags={},
         )
@@ -155,7 +157,6 @@ def test_execute_with_failures(self):
             jobName=JOB_NAME,
             containerOverrides={},
             jobDefinition="hello-world",
-            arrayProperties={},
             parameters={},
             tags={},
         )
@@ -166,9 +167,17 @@ def test_wait_job_complete_using_waiters(self, check_mock):
         self.batch.waiters = mock_waiters
 
         self.client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
-        self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]}
+        self.client_mock.describe_jobs.return_value = {
+            "jobs": [
+                {
+                    "jobId": JOB_ID,
+                    "status": "SUCCEEDED",
+                    "logStreamName": "logStreamName",
+                    "container": {"logConfiguration": {}},
+                }
+            ]
+        }
         self.batch.execute(self.mock_context)
-
         mock_waiters.wait_for_job.assert_called_once_with(JOB_ID)
         check_mock.assert_called_once_with(JOB_ID)
 
@@ -186,6 +195,65 @@ def test_kill_job(self):
         self.batch.on_kill()
         self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user")
 
+    @pytest.mark.parametrize("override", ["overrides", "node_overrides"])
+    @patch(
+        "airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client",
+        new_callable=mock.PropertyMock,
+    )
+    def test_override_not_sent_if_not_set(self, client_mock, override):
+        """
+        check that when setting container override or node override, the other key is not sent
+        in the API call (which would create a validation error from boto)
+        """
+        override_arg = {override: {"a": "a"}}
+        batch = BatchOperator(
+            task_id="task",
+            job_name=JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            **override_arg,
+            # setting those to bypass code that is not relevant here
+            do_xcom_push=False,
+            wait_for_completion=False,
+        )
+
+        batch.execute(None)
+
+        expected_args = {
+            "jobQueue": "queue",
+            "jobName": JOB_NAME,
+            "jobDefinition": "hello-world",
+            "parameters": {},
+            "tags": {},
+        }
+        if override == "overrides":
+            expected_args["containerOverrides"] = {"a": "a"}
+        else:
+            expected_args["nodeOverrides"] = {"a": "a"}
+        client_mock().submit_job.assert_called_once_with(**expected_args)
+
+    def test_deprecated_override_param(self):
+        with pytest.warns(DeprecationWarning):
+            _ = BatchOperator(
+                task_id="task",
+                job_name=JOB_NAME,
+                job_queue="queue",
+                job_definition="hello-world",
+                overrides={"a": "b"},  # <- the deprecated field
+            )
+
+    def test_cant_set_old_and_new_override_param(self):
+        with pytest.raises(AirflowException):
+            _ = BatchOperator(
+                task_id="task",
+                job_name=JOB_NAME,
+                job_queue="queue",
+                job_definition="hello-world",
+                # can't set both of those, as one is a replacement for the other
+                overrides={"a": "b"},
+                container_overrides={"a": "b"},
+            )
+
 
 class TestBatchCreateComputeEnvironmentOperator:
     @mock.patch.object(BatchClientHook, "client")
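
For context while reviewing, below is a minimal usage sketch (not part of the patch) showing how the new `container_overrides` and `node_overrides` parameters map onto the boto3 `submit_job` payload. The queue, job definition, names, and override values are placeholders; only the operator parameters themselves come from this change.

# Illustrative sketch only: placeholder queue/definition names and override values.
from airflow.providers.amazon.aws.operators.batch import BatchOperator

# Single-container job: `container_overrides` is passed through as boto3 `containerOverrides`.
single_node = BatchOperator(
    task_id="single_node_job",
    job_name="example-single-node",
    job_queue="example-queue",  # placeholder
    job_definition="example-job-def",  # placeholder
    container_overrides={"command": ["echo", "hello"]},
)

# Multinode job: `node_overrides` is passed through as boto3 `nodeOverrides`; the unset
# containerOverrides key is dropped by trim_none_values() before submit_job is called.
multi_node = BatchOperator(
    task_id="multi_node_job",
    job_name="example-multinode",
    job_queue="example-queue",
    job_definition="example-multinode-def",
    node_overrides={
        "nodePropertyOverrides": [
            {"targetNodes": "0:", "containerOverrides": {"command": ["echo", "main node"]}},
        ],
    },
)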