Skip to content

Commit

Permalink
Refactor socket communication handling (#79)
Browse files Browse the repository at this point in the history
* Refactor socket communication handling

Introduce a `contextmanager` to streamline socket attachment and closing. Extract socket processing into a new helper function to improve readability and maintainability of the `docker_communicate` function.

* Update ruff linter command in auto-format workflow

Changed the ruff linter command from 'ruff' to 'ruff check' in the GitHub Actions auto-format workflow. This aligns the command usage with best practices and ensures consistency in the linting process.

---------

Co-authored-by: meanmail <[email protected]>
  • Loading branch information
meanmail and meanmail authored Sep 13, 2024
1 parent 324cd02 commit 5a8b74e
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/auto-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
run: poetry run ruff format

- name: Check files using the ruff linter
run: poetry run ruff --fix --unsafe-fixes --preview --exit-zero .
run: poetry run ruff check --fix --unsafe-fixes --preview --exit-zero .

- name: Commit changes
uses: EndBug/add-and-commit@v9
Expand Down
2 changes: 1 addition & 1 deletion epicbox/sandboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def create(
raise ValueError(msg)

if not isinstance(workdir, WorkingDirectory | None):
msg = (
msg = ( # type: ignore[unreachable]
"Invalid 'workdir', "
"it should be created using 'working_directory' context manager"
)
Expand Down
130 changes: 83 additions & 47 deletions epicbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import socket
import struct
import time
from contextlib import contextmanager
from typing import Any, TYPE_CHECKING

import dateutil.parser
Expand All @@ -22,6 +23,8 @@
from epicbox import config, exceptions

if TYPE_CHECKING:
from collections.abc import Iterator

from docker.models.containers import Container

logger = structlog.get_logger()
Expand Down Expand Up @@ -157,6 +160,54 @@ def _socket_write(sock: socket.SocketIO, data: bytes) -> int:
raise


def process_sock(
sock: socket.SocketIO, stdin: bytes | None, log: structlog.BoundLogger
) -> tuple[bytes, int]:
"""Process the socket IO.
Read data from the socket if it is ready for reading.
Write data to the socket if it is ready for writing.
:returns: A tuple containing the received data and the number of bytes written.
"""
ready_to_read, ready_to_write, _ = select.select([sock], [sock], [], 1)
received_data: bytes = b""
bytes_written = 0
if ready_to_read:
data = _socket_read(sock)
if data is None:
msg = "Received EOF from the container"
raise EOFError(msg)
received_data = data

if ready_to_write and stdin:
bytes_written = _socket_write(sock, stdin)
if bytes_written >= len(stdin):
log.debug(
"All input data has been sent. "
"Shut down the write half of the socket.",
)
sock._sock.shutdown(socket.SHUT_WR) # type: ignore[attr-defined]

if not ready_to_read and (not ready_to_write or not stdin):
# Save CPU time by sleeping when there is no IO activity.
time.sleep(0.05)

return received_data, bytes_written


@contextmanager
def attach_socket(
docker_client: DockerClient,
container: Container,
params: dict[str, Any],
) -> Iterator[socket.SocketIO]:
sock = docker_client.api.attach_socket(container.id, params=params)

yield sock

sock.close()


def docker_communicate(
container: Container,
stdin: bytes | None = None,
Expand Down Expand Up @@ -196,44 +247,33 @@ def docker_communicate(
"stream": 1,
"logs": 0,
}
sock = docker_client.api.attach_socket(container.id, params=params)
sock._sock.setblocking(False) # Make socket non-blocking
log.info(
"Attached to the container",
params=params,
fd=sock.fileno(),
timeout=timeout,
)
if not stdin:
log.debug("There is no input data. Shut down the write half of the socket.")
sock._sock.shutdown(socket.SHUT_WR)
if start_container:
container.start()
log.info("Container started")

stream_data = b""
start_time = time.monotonic()
while timeout is None or time.monotonic() - start_time < timeout:
read_ready, write_ready, _ = select.select([sock], [sock], [], 1)
is_io_active = bool(read_ready or (write_ready and stdin))

if read_ready:

with attach_socket(docker_client, container, params) as sock:
sock._sock.setblocking(False) # type: ignore[attr-defined] # Make socket non-blocking
log.info(
"Attached to the container",
params=params,
fd=sock.fileno(),
timeout=timeout,
)
if not stdin:
log.debug("There is no input data. Shut down the write half of the socket.")
sock._sock.shutdown(socket.SHUT_WR) # type: ignore[attr-defined]
if start_container:
container.start()
log.info("Container started")

stream_data = b""
start_time = time.monotonic()
while timeout is None or time.monotonic() - start_time < timeout:
try:
data = _socket_read(sock)
received_data, bytes_written = process_sock(sock, stdin, log)
except ConnectionResetError:
log.warning(
"Connection reset caught on reading the container "
"output stream. Break communication",
)
break
if data is None:
log.debug("Container output reached EOF. Closing the socket")
break
stream_data += data

if write_ready and stdin:
try:
written = _socket_write(sock, stdin)
except BrokenPipeError:
# Broken pipe may happen when a container terminates quickly
# (e.g. OOM Killer) and docker manages to close the socket
Expand All @@ -242,22 +282,18 @@ def docker_communicate(
"Broken pipe caught on writing to stdin. Break communication",
)
break
stdin = stdin[written:]
if not stdin:
log.debug(
"All input data has been sent. Shut down the write "
"half of the socket.",
)
sock._sock.shutdown(socket.SHUT_WR)

if not is_io_active:
# Save CPU time
time.sleep(0.05)
else:
sock.close()
msg = "Container didn't terminate after timeout seconds"
raise TimeoutError(msg)
sock.close()
except EOFError:
log.debug("Container output reached EOF. Closing the socket")
break

if received_data:
stream_data += received_data
if stdin and bytes_written > 0:
stdin = stdin[bytes_written:]
else:
msg = "Container didn't terminate after timeout seconds"
raise TimeoutError(msg)

return demultiplex_docker_stream(stream_data)


Expand Down

0 comments on commit 5a8b74e

Please sign in to comment.