Write events to CloudWatch in batches #1712

Merged 2 commits on Sep 20, 2024
80 changes: 71 additions & 9 deletions src/dstack/_internal/server/services/logs.py
@@ -1,10 +1,11 @@
import atexit
import base64
import itertools
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterator, List, Optional, Set, TypedDict, Union
from typing import Iterator, List, Optional, Set, Tuple, TypedDict, Union
from uuid import UUID

from dstack._internal.core.errors import DstackError
@@ -61,6 +62,16 @@ class _CloudWatchLogEvent(TypedDict):


class CloudWatchLogStorage(LogStorage):
# "The maximum number of log events in a batch is 10,000".
EVENT_MAX_COUNT_IN_BATCH = 10000
# "The maximum batch size is 1,048,576 bytes" — exactly 1 MiB. "This size is calculated
# as the sum of all event messages in UTF-8, plus 26 bytes for each log event".
BATCH_MAX_SIZE = 1048576
# "Each log event can be no larger than 256 KB" — KB means KiB; includes MESSAGE_OVERHEAD_SIZE.
MESSAGE_MAX_SIZE = 262144
# Message size in bytes = len(message.encode("utf-8")) + MESSAGE_OVERHEAD_SIZE.
MESSAGE_OVERHEAD_SIZE = 26

def __init__(self, *, group: str, region: Optional[str] = None) -> None:
with self._wrap_boto_errors():
session = boto3.Session(region_name=region)
@@ -169,24 +180,75 @@ def write_logs(
)

def _write_logs(self, stream: str, log_events: List[RunnerLogEvent]) -> None:
events = [self._runner_log_event_to_cloudwatch_event(event) for event in log_events]
params = {
"logGroupName": self._group,
"logStreamName": stream,
"logEvents": events,
}
with self._wrap_boto_errors():
self._ensure_stream_exists(stream)
try:
self._client.put_log_events(**params)
self._put_log_events(stream, log_events)
return
except botocore.exceptions.ClientError as e:
if not self._is_resource_not_found_exception(e):
raise
logger.debug("Stream %s not found, recreating", stream)
# The stream is probably deleted due to retention policy, our cache is stale.
self._ensure_stream_exists(stream, force=True)
self._client.put_log_events(**params)
self._put_log_events(stream, log_events)

def _put_log_events(self, stream: str, log_events: List[RunnerLogEvent]) -> None:
for batch in self._get_batch_iter(stream, log_events):
self._client.put_log_events(
logGroupName=self._group,
logStreamName=stream,
logEvents=batch,
)

def _get_batch_iter(
self, stream: str, log_events: List[RunnerLogEvent]
) -> Iterator[List[_CloudWatchLogEvent]]:
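# A single shared iterator is consumed across batches, so each batch picks up
# where the previous one stopped; an event that overflowed a batch is pushed
# back in front of the remaining events via itertools.chain.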
shared_event_iter = iter(log_events)
event_iter = shared_event_iter
while True:
batch, excessive_event = self._get_next_batch(stream, event_iter)
if not batch:
return
yield batch
if excessive_event is not None:
event_iter = itertools.chain([excessive_event], shared_event_iter)
else:
event_iter = shared_event_iter

def _get_next_batch(
self, stream: str, event_iter: Iterator[RunnerLogEvent]
) -> Tuple[List[_CloudWatchLogEvent], Optional[RunnerLogEvent]]:
batch: List[_CloudWatchLogEvent] = []
total_size = 0
event_count = 0
for event in event_iter:
# Normally there are no empty messages; skip them just in case.
if not event.message:
continue
cw_event = self._runner_log_event_to_cloudwatch_event(event)
# As the message is base64-encoded, its length in bytes equals its length in code points.
message_size = len(cw_event["message"]) + self.MESSAGE_OVERHEAD_SIZE
if message_size > self.MESSAGE_MAX_SIZE:
# We should never hit this limit: the runner copies from the pty to the logs
# with `io.Copy`, which internally uses a 32 KiB buffer; see
# runner/internal/executor/executor.go, `execJob` -> `io.Copy(logger, ptmx)`.
logger.error(
"Stream %s: skipping event %d, message exceeds max size: %d > %d",
stream,
event.timestamp,
message_size,
self.MESSAGE_MAX_SIZE,
)
continue
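# This event would overflow the current batch; return it so that it
# becomes the first event of the next batch.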
if total_size + message_size > self.BATCH_MAX_SIZE:
return batch, event
batch.append(cw_event)
total_size += message_size
event_count += 1
if event_count >= self.EVENT_MAX_COUNT_IN_BATCH:
break
return batch, None

def _runner_log_event_to_cloudwatch_event(
self, runner_log_event: RunnerLogEvent
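For reference, the batching scheme above can be read as a small standalone algorithm. The sketch below is illustrative only: the limits are toy values, events are plain tuples instead of the real event types, and `next_batch`/`make_batches` are hypothetical names, not part of this PR.

import itertools
from typing import Iterator, List, Optional, Tuple

OVERHEAD = 26          # per-event overhead, as in the CloudWatch docs
MAX_MESSAGE_SIZE = 90  # toy value; the real limit is 262,144 bytes
MAX_BATCH_SIZE = 100   # toy value; the real limit is 1,048,576 bytes
MAX_BATCH_COUNT = 3    # toy value; the real limit is 10,000 events

Event = Tuple[int, str]  # (timestamp, message)

def next_batch(events: Iterator[Event]) -> Tuple[List[Event], Optional[Event]]:
    batch: List[Event] = []
    size = 0
    for event in events:
        event_size = len(event[1].encode("utf-8")) + OVERHEAD
        if event_size > MAX_MESSAGE_SIZE:
            continue  # oversized events are skipped (and logged in the real code)
        if size + event_size > MAX_BATCH_SIZE:
            return batch, event  # hand the overflow event back to the caller
        batch.append(event)
        size += event_size
        if len(batch) >= MAX_BATCH_COUNT:
            break
    return batch, None

def make_batches(events: List[Event]) -> Iterator[List[Event]]:
    shared = iter(events)
    it: Iterator[Event] = shared
    while True:
        batch, overflow = next_batch(it)
        if not batch:
            return
        yield batch
        # The overflow event becomes the first event of the next batch.
        it = itertools.chain([overflow], shared) if overflow is not None else shared

# Three events of sizes 36, 86, and 27 bytes split into three single-event batches:
for batch in make_batches([(1, "a" * 10), (2, "b" * 60), (3, "c")]):
    print(batch)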
106 changes: 106 additions & 0 deletions src/tests/_internal/server/services/test_logs.py
@@ -1,6 +1,8 @@
import base64
import itertools
from datetime import datetime, timezone
from pathlib import Path
from typing import List
from unittest.mock import Mock, call
from uuid import UUID

@@ -528,3 +530,107 @@ async def test_write_logs_other_exception(
],
job_logs=[],
)

@pytest.mark.parametrize(
["messages", "expected"],
[
# `messages` is written as a concatenation of sub-lists for readability;
# each sub-list corresponds to one expected batch.
# `expected` is a list of lists; each nested list is one batch.
[
["", "toolong"],
[],
],
[
["111", "toolong", "111"] + ["222222"] + ["333"],
[["111", "111"], ["222222"], ["333"]],
],
[
["111", "111"] + ["222", "222"],
[["111", "111"], ["222", "222"]],
],
[
["111", "111"] + ["222"],
[["111", "111"], ["222"]],
],
[
["111"] + ["222222"] + ["333", "333"],
[["111"], ["222222"], ["333", "333"]],
],
],
)
@pytest.mark.asyncio
async def test_write_logs_batching_by_size(
self,
monkeypatch: pytest.MonkeyPatch,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
mock_ensure_stream_exists: Mock,
messages: List[str],
expected: List[List[str]],
):
# MESSAGE_MAX_SIZE=34 allows messages of up to 6 raw bytes: 8 base64 chars + 26 overhead = 34.
# "toolong" is 7 bytes -> 12 base64 chars -> 12 + 26 = 38 > 34, so it is skipped.
# BATCH_MAX_SIZE=60 fits two 3-byte messages (30 + 30) but not a third one.
monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34)
monkeypatch.setattr(CloudWatchLogStorage, "BATCH_MAX_SIZE", 60)
log_storage.write_logs(
project=project,
run_name="test-run",
job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
runner_logs=[
RunnerLogEvent(timestamp=1696586513234, message=message.encode())
for message in messages
],
job_logs=[],
)
assert mock_client.put_log_events.call_count == len(expected)
actual = [
[base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]]
for c in mock_client.put_log_events.call_args_list
]
assert actual == expected

@pytest.mark.parametrize(
["messages", "expected"],
[
# `messages` is written as a concatenation of sub-lists for readability;
# each sub-list corresponds to one expected batch.
# `expected` is a list of lists; each nested list is one batch.
[
["111", "111", "111"] + ["222"],
[["111", "111", "111"], ["222"]],
],
[
["111", "111", "111"] + ["222", "222", "toolong", "", "222222"],
[["111", "111", "111"], ["222", "222", "222222"]],
],
],
)
@pytest.mark.asyncio
async def test_write_logs_batching_by_count(
self,
monkeypatch: pytest.MonkeyPatch,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
mock_ensure_stream_exists: Mock,
messages: List[str],
expected: List[List[str]],
):
# MESSAGE_MAX_SIZE=34 allows messages of up to 6 raw bytes: 8 base64 chars + 26 overhead = 34.
# "toolong" is 7 bytes -> 12 base64 chars -> 12 + 26 = 38 > 34, so it is skipped.
monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34)
monkeypatch.setattr(CloudWatchLogStorage, "EVENT_MAX_COUNT_IN_BATCH", 3)
log_storage.write_logs(
project=project,
run_name="test-run",
job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
runner_logs=[
RunnerLogEvent(timestamp=1696586513234, message=message.encode())
for message in messages
],
job_logs=[],
)
assert mock_client.put_log_events.call_count == len(expected)
actual = [
[base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]]
for c in mock_client.put_log_events.call_args_list
]
assert actual == expected
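The size arithmetic behind these toy limits is easy to verify: base64 emits 4 output characters for every 3 input bytes, so an n-byte message counts as 4*ceil(n/3) + 26 bytes against the batch limit. A quick sanity check (the `cw_event_size` helper is an assumption for illustration, not part of the test suite):

import base64
import math

def cw_event_size(raw: bytes, overhead: int = 26) -> int:
    # Size as counted against the batch limit: base64 length plus per-event overhead.
    return len(base64.b64encode(raw)) + overhead

assert cw_event_size(b"111") == 4 * math.ceil(3 / 3) + 26 == 30      # 30 <= 34, fits
assert cw_event_size(b"222222") == 4 * math.ceil(6 / 3) + 26 == 34   # fits exactly
assert cw_event_size(b"toolong") == 4 * math.ceil(7 / 3) + 26 == 38  # 38 > 34, skipped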