diff --git a/src/dstack/_internal/server/services/logs.py b/src/dstack/_internal/server/services/logs.py index 76ce14cfc..d771cfbab 100644 --- a/src/dstack/_internal/server/services/logs.py +++ b/src/dstack/_internal/server/services/logs.py @@ -1,10 +1,11 @@ import atexit import base64 +import itertools from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path -from typing import Iterator, List, Optional, Set, TypedDict, Union +from typing import Iterator, List, Optional, Set, Tuple, TypedDict, Union from uuid import UUID from dstack._internal.core.errors import DstackError @@ -61,6 +62,16 @@ class _CloudWatchLogEvent(TypedDict): class CloudWatchLogStorage(LogStorage): + # "The maximum number of log events in a batch is 10,000". + EVENT_MAX_COUNT_IN_BATCH = 10000 + # "The maximum batch size is 1,048,576 bytes" — exactly 1 MiB. "This size is calculated + # as the sum of all event messages in UTF-8, plus 26 bytes for each log event". + BATCH_MAX_SIZE = 1048576 + # "Each log event can be no larger than 256 KB" — KB means KiB; includes MESSAGE_OVERHEAD_SIZE. + MESSAGE_MAX_SIZE = 262144 + # Message size in bytes = len(message.encode("utf-8")) + MESSAGE_OVERHEAD_SIZE. + MESSAGE_OVERHEAD_SIZE = 26 + def __init__(self, *, group: str, region: Optional[str] = None) -> None: with self._wrap_boto_errors(): session = boto3.Session(region_name=region) @@ -169,16 +180,10 @@ def write_logs( ) def _write_logs(self, stream: str, log_events: List[RunnerLogEvent]) -> None: - events = [self._runner_log_event_to_cloudwatch_event(event) for event in log_events] - params = { - "logGroupName": self._group, - "logStreamName": stream, - "logEvents": events, - } with self._wrap_boto_errors(): self._ensure_stream_exists(stream) try: - self._client.put_log_events(**params) + self._put_log_events(stream, log_events) return except botocore.exceptions.ClientError as e: if not self._is_resource_not_found_exception(e): @@ -186,7 +191,64 @@ def _write_logs(self, stream: str, log_events: List[RunnerLogEvent]) -> None: logger.debug("Stream %s not found, recreating", stream) # The stream is probably deleted due to retention policy, our cache is stale. self._ensure_stream_exists(stream, force=True) - self._client.put_log_events(**params) + self._put_log_events(stream, log_events) + + def _put_log_events(self, stream: str, log_events: List[RunnerLogEvent]) -> None: + for batch in self._get_batch_iter(stream, log_events): + self._client.put_log_events( + logGroupName=self._group, + logStreamName=stream, + logEvents=batch, + ) + + def _get_batch_iter( + self, stream: str, log_events: List[RunnerLogEvent] + ) -> Iterator[List[_CloudWatchLogEvent]]: + shared_event_iter = iter(log_events) + event_iter = shared_event_iter + while True: + batch, excessive_event = self._get_next_batch(stream, event_iter) + if not batch: + return + yield batch + if excessive_event is not None: + event_iter = itertools.chain([excessive_event], shared_event_iter) + else: + event_iter = shared_event_iter + + def _get_next_batch( + self, stream: str, event_iter: Iterator[RunnerLogEvent] + ) -> Tuple[List[_CloudWatchLogEvent], Optional[RunnerLogEvent]]: + batch: List[_CloudWatchLogEvent] = [] + total_size = 0 + event_count = 0 + for event in event_iter: + # Normally there should not be empty messages. + if not event.message: + continue + cw_event = self._runner_log_event_to_cloudwatch_event(event) + # as message is base64-encoded, length in bytes = length in code points. + message_size = len(cw_event["message"]) + self.MESSAGE_OVERHEAD_SIZE + if message_size > self.MESSAGE_MAX_SIZE: + # we should never hit this limit, as we use `io.Copy` to copy from pty to logs, + # which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go, + # `execJob` -> `io.Copy(logger, ptmx)` + logger.error( + "Stream %s: skipping event %d, message exceeds max size: %d > %d", + stream, + event.timestamp, + message_size, + self.MESSAGE_MAX_SIZE, + ) + continue + if total_size + message_size > self.BATCH_MAX_SIZE: + return batch, event + batch.append(cw_event) + total_size += message_size + event_count += 1 + if event_count >= self.EVENT_MAX_COUNT_IN_BATCH: + break + return batch, None def _runner_log_event_to_cloudwatch_event( self, runner_log_event: RunnerLogEvent diff --git a/src/tests/_internal/server/services/test_logs.py b/src/tests/_internal/server/services/test_logs.py index 7c2423e94..005c27c40 100644 --- a/src/tests/_internal/server/services/test_logs.py +++ b/src/tests/_internal/server/services/test_logs.py @@ -1,6 +1,8 @@ +import base64 import itertools from datetime import datetime, timezone from pathlib import Path +from typing import List from unittest.mock import Mock, call from uuid import UUID @@ -528,3 +530,107 @@ async def test_write_logs_other_exception( ], job_logs=[], ) + + @pytest.mark.parametrize( + ["messages", "expected"], + [ + # `messages` is a concatenated list for better readability — each list is a batch + # `expected` is a list of lists, each nested list is a batch. + [ + ["", "toolong"], + [], + ], + [ + ["111", "toolong", "111"] + ["222222"] + ["333"], + [["111", "111"], ["222222"], ["333"]], + ], + [ + ["111", "111"] + ["222", "222"], + [["111", "111"], ["222", "222"]], + ], + [ + ["111", "111"] + ["222"], + [["111", "111"], ["222"]], + ], + [ + ["111"] + ["222222"] + ["333", "333"], + [["111"], ["222222"], ["333", "333"]], + ], + ], + ) + @pytest.mark.asyncio + async def test_write_logs_batching_by_size( + self, + monkeypatch: pytest.MonkeyPatch, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + mock_ensure_stream_exists: Mock, + messages: List[str], + expected: List[List[str]], + ): + # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34 + monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34) + monkeypatch.setattr(CloudWatchLogStorage, "BATCH_MAX_SIZE", 60) + log_storage.write_logs( + project=project, + run_name="test-run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513234, message=message.encode()) + for message in messages + ], + job_logs=[], + ) + assert mock_client.put_log_events.call_count == len(expected) + actual = [ + [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + for c in mock_client.put_log_events.call_args_list + ] + assert actual == expected + + @pytest.mark.parametrize( + ["messages", "expected"], + [ + # `messages` is a concatenated list for better readability — each list is a batch + # `expected` is a list of lists, each nested list is a batch. + [ + ["111", "111", "111"] + ["222"], + [["111", "111", "111"], ["222"]], + ], + [ + ["111", "111", "111"] + ["222", "222", "toolong", "", "222222"], + [["111", "111", "111"], ["222", "222", "222222"]], + ], + ], + ) + @pytest.mark.asyncio + async def test_write_logs_batching_by_count( + self, + monkeypatch: pytest.MonkeyPatch, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + mock_ensure_stream_exists: Mock, + messages: List[str], + expected: List[List[str]], + ): + # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34 + monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34) + monkeypatch.setattr(CloudWatchLogStorage, "EVENT_MAX_COUNT_IN_BATCH", 3) + log_storage.write_logs( + project=project, + run_name="test-run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513234, message=message.encode()) + for message in messages + ], + job_logs=[], + ) + assert mock_client.put_log_events.call_count == len(expected) + actual = [ + [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + for c in mock_client.put_log_events.call_args_list + ] + assert actual == expected