From 3a65517244bbfe61b938e78e9b8f98dea589c2a7 Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 22 Sep 2023 16:01:07 -0600 Subject: [PATCH] chore: write log shipper in python [MLG-993] Instead of orchestrating multiple processes with bash and confusing/complex process substitution, with all the trapping of signals and `printf x > ...` nonsense, just write a wrapper process in a real programming language. This reduces the complexity of our logging solution, and prevents the out-of-order bugs inherent in having two separate log shippers, one for stdout and one for stderr. The new ship_logs.py has the following features: - it launches and monitors a child command - it shares a log buffer for shipping stdout and stderr - it has the same log parsing regexes as enrich_logging.py - it converts carriage returns to newlines - it forwards signals to its child process - it exits after a maximum of DET_LOG_WAIT_TIME or 30 seconds - it depends on the python interpreter, but only the interpreter (all imports are from the standard library) - in the special case that the child process can't be started, it ships an explanation of what happened to the master and exits with standard bash exit codes --- e2e_tests/tests/cluster/test_logging.py | 11 +- e2e_tests/tests/cluster/test_ship_logs.py | 512 +++++++++++++++ harness/determined/__init__.py | 3 +- master/internal/rm/kubernetesrm/pod_test.go | 42 +- master/internal/rm/kubernetesrm/spec.go | 2 - master/pkg/etc/etc.go | 14 +- master/pkg/tasks/const.go | 15 +- master/pkg/tasks/task.go | 32 +- master/pkg/tasks/task_trial.go | 4 - master/static/srv/command-entrypoint.sh | 9 +- master/static/srv/enrich_task_logs.py | 206 ------ master/static/srv/entrypoint.sh | 7 +- .../static/srv/gc-checkpoints-entrypoint.sh | 7 +- master/static/srv/notebook-entrypoint.sh | 5 +- master/static/srv/shell-entrypoint.sh | 7 +- master/static/srv/ship-logs.sh | 21 + master/static/srv/ship_logs.py | 609 ++++++++++++++++++ master/static/srv/task-logging-setup.sh | 123 ---- master/static/srv/task-logging-teardown.sh | 39 -- master/static/srv/task-setup.sh | 41 ++ master/static/srv/task-signal-handling.sh | 45 -- master/static/srv/tensorboard-entrypoint.sh | 7 +- 22 files changed, 1227 insertions(+), 534 deletions(-) create mode 100644 e2e_tests/tests/cluster/test_ship_logs.py delete mode 100644 master/static/srv/enrich_task_logs.py create mode 100755 master/static/srv/ship-logs.sh create mode 100644 master/static/srv/ship_logs.py delete mode 100644 master/static/srv/task-logging-setup.sh delete mode 100644 master/static/srv/task-logging-teardown.sh create mode 100644 master/static/srv/task-setup.sh delete mode 100644 master/static/srv/task-signal-handling.sh diff --git a/e2e_tests/tests/cluster/test_logging.py b/e2e_tests/tests/cluster/test_logging.py index 9ef4d7bc5a67..e10dfb55779c 100644 --- a/e2e_tests/tests/cluster/test_logging.py +++ b/e2e_tests/tests/cluster/test_logging.py @@ -1,8 +1,8 @@ import functools import re +import socket from typing import Any, Callable, Dict, Iterable, Optional, Union -import _socket import pytest from determined.cli import command @@ -78,13 +78,6 @@ def test_task_logs(task_type: str, task_config: Dict[str, Any], log_regex: Any) rps = bindings.get_GetResourcePools(session) assert rps.resourcePools and len(rps.resourcePools) > 0, "missing resource pool" - if ( - rps.resourcePools[0].type == bindings.v1ResourcePoolType.K8S - and task_type == command.TaskTypeCommand - ): - # TODO(DET-6712): Investigate intermittent slowness with K8s command logs. 
- pytest.skip("DET-6712: Investigate intermittent slowness with K8s command logs") - if task_type == command.TaskTypeTensorBoard: exp_id = exp.run_basic_test( conf.fixtures_path("no_op/single.yaml"), @@ -117,7 +110,7 @@ def task_log_fields(follow: Optional[bool] = None) -> Iterable[LogFields]: functools.partial(api.task_logs, session, task_id), functools.partial(bindings.get_TaskLogsFields, session, taskId=task_id), ) - except _socket.timeout: + except socket.timeout: raise TimeoutError(f"timed out waiting for {task_type} with id {task_id}") finally: diff --git a/e2e_tests/tests/cluster/test_ship_logs.py b/e2e_tests/tests/cluster/test_ship_logs.py new file mode 100644 index 000000000000..e56e9df14c79 --- /dev/null +++ b/e2e_tests/tests/cluster/test_ship_logs.py @@ -0,0 +1,512 @@ +import io +import json +import logging +import os +import shutil +import signal +import socket +import ssl +import subprocess +import sys +import tempfile +import textwrap +import threading +import time +from typing import Any, List, Optional + +import pytest + +here = os.path.dirname(__file__) +static_srv = os.path.join(here, "../../../master/static/srv") +old = sys.path +try: + sys.path = [static_srv] + sys.path + import ship_logs +finally: + sys.path = old + + +class ShipLogServer: + """ + A tiny, hacky http(s) server for testing ship logs. + + It's about the same amount of code as subclassing the stdlib's SimpleHTTPRequestHandler, but it + shuts down much faster. + """ + + def __init__(self, ctx: Optional[ssl.SSLContext] = None) -> None: + self.ctx = ctx + self.quit = False + self.logs = [] # type: List[str] + + self.listener = socket.socket() + self.listener.bind(("127.0.0.1", 0)) + self.listener.listen(5) + + _, self.port = self.listener.getsockname() + + self.thread = threading.Thread(target=self.serve_requests) + self.thread.start() + + def __enter__(self) -> "ShipLogServer": + return self + + def __exit__(self, *args: Any) -> None: + self.quit = True + # Wake up the accept() call. + try: + with socket.socket() as s: + s.connect(("127.0.0.1", self.port)) + s.send(b"quit") + except Exception: + logging.error("failed to wake up accept loop", exc_info=True) + self.thread.join() + self.listener.close() + + def serve_requests(self) -> None: + try: + while not self.quit: + # Accept a conneciton. + s, _ = self.listener.accept() + try: + if self.quit: + return + if self.ctx: + s = self.ctx.wrap_socket(s, server_side=True) + try: + self.serve_one_request(s) + except Exception: + logging.error("error reading request", exc_info=True) + finally: + s.close() + except Exception: + logging.error("server crashed", exc_info=True) + + def serve_one_request(self, s: socket.socket) -> None: + # Receive headers. + hdrs = b"" + while b"\r\n\r\n" not in hdrs: + buf = s.recv(4096) + if not buf: + # EOF + return + hdrs += buf + # Receive body until we have valid json. + hdrs, body = hdrs.split(b"\r\n\r\n", maxsplit=1) + while True: + try: + jbody = json.loads(body) + break + except json.decoder.JSONDecodeError: + # Must not have the full body yet. + pass + buf = s.recv(4096) + if not buf: + # EOF + return + body += buf + + # Remember the logs we saw. + self.logs.extend(j["log"] for j in jbody) + + # Send a response. + s.sendall(b"HTTP/1.1 200 OK\r\n\r\n") + + def master_url(self) -> str: + return f"http://127.0.0.1:{self.port}" + + +def mkcmd(script: str) -> List[str]: + # -u: don't buffer stdout/stderr. 
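+    # (the shipper reads raw, unbuffered pipes and does its own line splitting, so -u makes sure
+    # each print() from the child reaches it immediately in the timing-sensitive tests below)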
+ return [sys.executable, "-u", "-c", textwrap.dedent(script)] + + +class TestShipLogs: + """ + A suite of unit tests for ship_logs.py + + Yeah, it's a hack that these tests live in e2e tests. But since they test python code it's + just easier this way. + """ + + def run_ship_logs( + self, + master_url: str, + cmd: List[str], + log_wait_time: float = 30, + cert_name: str = "", + cert_file: str = "", + ) -> int: + exit_code = ship_logs.main( + master_url=master_url, + cert_name=cert_name, + cert_file=cert_file, + metadata={}, + token="token", + emit_stdout_logs=False, + cmd=cmd, + log_wait_time=log_wait_time, + ) + assert isinstance(exit_code, int), exit_code + return exit_code + + @pytest.mark.e2e_cpu + def test_preserve_exit(self) -> None: + cmd = mkcmd( + """ + import sys + print("hi", file=sys.stdout, flush=True) + print("bye", file=sys.stderr, flush=True) + sys.exit(9) + """ + ) + with ShipLogServer() as srv: + exit_code = self.run_ship_logs(srv.master_url(), cmd) + assert exit_code == 9, exit_code + # Ordering of stdout vs stderr is non-deterministic. + assert set(srv.logs) == {"hi\n", "bye\n"}, srv.logs + + @pytest.mark.e2e_cpu + def test_cr_to_lf(self) -> None: + cmd = mkcmd( + r""" + print("1\n", end="") + print("2\r", end="") + print("3\r\n", end="") + """ + ) + with ShipLogServer() as srv: + exit_code = self.run_ship_logs(srv.master_url(), cmd) + assert exit_code == 0, exit_code + assert "".join(srv.logs) == "1\n2\n3\n\n", srv.logs + + @pytest.mark.e2e_cpu + def test_stdout_stderr_ordering(self) -> None: + # Stdout and stderr are collected on different threads, and therefore _can't_ be perfectly + # synced. But they should be "approximately" synced; i.e. each 1-second batch should + # contain both log types. + # + # Most dev machines probably will be fine with small timeouts, but CI machines might be + # slower and we allow up to 0.2 seconds of slop. + timeouts = [0.001, 0.2] + for timeout in timeouts: + cmd = mkcmd( + f""" + import sys + import time + print("1", file=sys.stdout, flush=True) + time.sleep({timeout}) + print("2", file=sys.stderr, flush=True) + time.sleep({timeout}) + print("3", file=sys.stdout, flush=True) + time.sleep({timeout}) + print("4", file=sys.stderr, flush=True) + time.sleep({timeout}) + print("5", file=sys.stdout, flush=True) + time.sleep({timeout}) + print("6", file=sys.stderr, flush=True) + """ + ) + with ShipLogServer() as srv: + exit_code = self.run_ship_logs(srv.master_url(), cmd) + assert exit_code == 0, exit_code + if "".join(srv.logs) == "1\n2\n3\n4\n5\n6\n": + # Success + break + elif timeout == timeouts[-1]: + # Failed, even on the highest timeout + raise ValueError("".join(srv.logs)) + + @pytest.mark.e2e_cpu + def test_signal_forwarding(self) -> None: + cmd = mkcmd( + """ + import signal + import time + + def handle_sigint(*arg): + print("caught sigint!") + + signal.signal(signal.SIGINT, handle_sigint) + + print("ready!", flush=True) + + time.sleep(5) + """ + ) + with ShipLogServer() as srv: + # Start a subprocess so we can signal it. + env = { + "DET_MASTER": srv.master_url(), + "DET_SESSION_TOKEN": "token", + "DET_TASK_LOGGING_METADATA": "{}", + "DET_SHIPPER_EMIT_STDOUT_LOGS": "1", + } + fullcmd = [sys.executable, "-u", ship_logs.__file__] + cmd + p = subprocess.Popen(fullcmd, env=env, stdout=subprocess.PIPE) + assert p.stdout + try: + # Wait for the granchild log to indicate signals are set up. + for line in p.stdout: + if b"ready!" in line: + break + # Send a signal that is caught and logged, to test signal forwarding. 
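+                # ship_logs should forward SIGINT to the child rather than exiting itself, so the
+                # child's handler gets a chance to run and log.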
+ p.send_signal(signal.SIGINT) + for line in p.stdout: + if b"caught sigint!" in line: + break + # Send a signal that is not caught, to test for signal exit codes. + p.send_signal(signal.SIGTERM) + exit_code = p.wait() + finally: + p.kill() + p.wait() + assert exit_code == 128 + signal.SIGTERM, exit_code + assert "".join(srv.logs) == "ready!\ncaught sigint!\n", srv.logs + + @pytest.mark.e2e_cpu + def test_exit_wait_time(self) -> None: + cmd = mkcmd("print('hello world')") + # We need a misbehaving server to guarantee the shipper times out. + # This misbehaving server will listen without ever accepting. + with socket.socket() as listener: + listener.bind(("127.0.0.1", 0)) + listener.listen(10) + _, port = listener.getsockname() + master_url = f"http://127.0.0.1:{port}" + start = time.time() + exit_code = self.run_ship_logs(master_url, cmd, log_wait_time=0.1) + end = time.time() + assert exit_code == 0, exit_code + assert end - start < 1, end - start + + @pytest.mark.e2e_cpu + def test_entrypoint_not_found(self) -> None: + cmd = ["/does-not-exist"] + with ShipLogServer() as srv: + exit_code = self.run_ship_logs(srv.master_url(), cmd) + # 127 is the standard bash exit code for file-not-found. + assert exit_code == 127, exit_code + assert "FileNotFoundError" in "".join(srv.logs), srv.logs + + @pytest.mark.e2e_cpu + def test_entrypoint_not_executable(self) -> None: + cmd = ["/bin/"] + with ShipLogServer() as srv: + exit_code = self.run_ship_logs(srv.master_url(), cmd) + # 126 is the standard bash exit code for permission failure. + assert exit_code == 126, exit_code + assert "PermissionError" in "".join(srv.logs), srv.logs + + @pytest.mark.e2e_cpu + def test_only_standard_library_dependences(self) -> None: + cmd = mkcmd( + """ + # ONLY STANDARD LIBRARY IMPORTS ARE ALLOWED + import datetime + import io + import json + import logging + import os + import queue + import re + import signal + import ssl + import subprocess + import sys + import threading + import time + import traceback + import typing + import urllib.request + # END OF STANDARD LIBRARY IMPORTS + + # Now the only new module that `import ship_logs` can add is ship_logs itself. + allowed_modules = set((*sys.modules, "ship_logs")) + + sys.path = ["%s"] + sys.path + import ship_logs + + new_modules = set(sys.modules).difference(allowed_modules) + + for name in new_modules: + print("possible non-standard-library dependency detected:", name) + + exit(1 if new_modules else 0) + """ + % (static_srv) + ) + p = subprocess.Popen(cmd, stdout=subprocess.PIPE) + assert p.stdout + errmsgs = p.stdout.read().decode("utf8") + assert p.wait() == 0, "\n" + errmsgs + + @pytest.mark.e2e_cpu + def test_custom_certs(self) -> None: + # Use the untrusted key and cert from the harness unit tests. + untrusted = os.path.join(here, "../../../harness/tests/common/untrusted-root") + keyfile = os.path.join(untrusted, "127.0.0.1-key.pem") + certfile = os.path.join(untrusted, "127.0.0.1-cert.pem") + + # Create the server ssl context. + ctx = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) + ctx.load_cert_chain(certfile=certfile, keyfile=keyfile) + + cmd = mkcmd("print('hello world')") + + with ShipLogServer(ctx) as srv: + # Use the wrong name to talk to the server, to test name verification override. 
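+            # The test cert was issued for 127.0.0.1, so verification only succeeds because
+            # cert_name overrides the hostname check.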
+ master_url = f"https://localhost:{srv.port}" + exit_code = self.run_ship_logs( + master_url, + cmd, + cert_file=certfile, + cert_name="127.0.0.1", + ) + assert exit_code == 0, exit_code + assert srv.logs == ["hello world\n"], srv.logs + + @pytest.mark.e2e_cpu + def test_honor_http_proxy(self) -> None: + # Use a subprocess to control the environment. + with ShipLogServer() as srv: + env = { + "DET_MASTER": "http://notreal.faketld", + "DET_SESSION_TOKEN": "token", + "DET_TASK_LOGGING_METADATA": "{}", + "http_proxy": srv.master_url(), + } + cmd = mkcmd("print('hello world')") + fullcmd = [sys.executable, "-u", ship_logs.__file__] + cmd + subprocess.run(fullcmd, env=env, check=True) + assert srv.logs == ["hello world\n"], srv.logs + + @pytest.mark.e2e_cpu + def test_honor_no_proxy(self) -> None: + # Use a subprocess to control the environment. + with ShipLogServer() as srv: + env = { + "DET_MASTER": srv.master_url(), + "DET_SESSION_TOKEN": "token", + "DET_TASK_LOGGING_METADATA": "{}", + "http_proxy": "http://notreal.faketld", + "NO_PROXY": "127.0.0.1", + } + cmd = mkcmd("print('hello world')") + fullcmd = [sys.executable, "-u", ship_logs.__file__] + cmd + subprocess.run(fullcmd, env=env, check=True) + assert srv.logs == ["hello world\n"], srv.logs + + @pytest.mark.e2e_cpu + def test_escape_hatch(self) -> None: + # Create a temporary directory to catch our escape-hatch logs + tmp = tempfile.mkdtemp(suffix="ship_logs") + try: + with ShipLogServer() as srv: + # Use a subprocess to control the environment. + # Leave out DET_MASTER to force a crash. + env = {"DET_SHIP_LOGS_PATH": tmp} + cmd = mkcmd("pass") + fullcmd = [sys.executable, "-u", ship_logs.__file__] + cmd + p = subprocess.run(fullcmd, env=env) + assert p.returncode == 80, p.returncode + assert srv.logs == [], srv.logs + files = os.listdir(tmp) + assert len(files) == 2, files + assert "ship-logs-ran" in files, files + files.remove("ship-logs-ran") + with open(os.path.join(tmp, files[0]), "r") as f: + text = f.read() + assert "KeyError: 'DET_MASTER'" in text, text + finally: + shutil.rmtree(tmp) + + +class TestReadNewlinesOrCarriageReturns: + # read_newlines_or_carriage_returns is designed to read from filedescriptors resulting from + # subprocess.Popen(bufsize=0).stdout, and different kinds of filehandles can result in slightly + # different read/write behaviors, and stdout and stderr can additionally be different on + # different operating systems, so this test will use Popen to create file descriptors instead of + # something more convenient like os.pipe(), in order to test against the most realistic + # conditions. + + @pytest.mark.e2e_cpu + @pytest.mark.parametrize("lastline", [repr("hi"), repr("hi\r"), repr("hi\n")]) + def test_eof_handling_after_process_killed(self, lastline: str) -> None: + cmd = mkcmd( + r""" + import sys + import time + print(%s, end="", flush=True) + # message test to kill us, then wait + print(".", file=sys.stderr, flush=True) + time.sleep(10) + """ + % lastline + ) + start = time.time() + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0) + assert p.stdout and p.stderr + reader = ship_logs.read_newlines_or_carriage_returns(p.stdout) + # wait for message on stderr + _ = p.stderr.read(1) + p.kill() + p.wait() + line = next(reader) + end = time.time() + assert line == "hi\n", line + # Make sure the test didn't wait for that sleep(10) to finish. 
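+        # (killing the child closes the write end of the pipe, so the reader hits EOF right away
+        # instead of blocking on the sleep)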
+ assert end - start < 1, end - start + + @pytest.mark.e2e_cpu + @pytest.mark.parametrize("lastline", ["hi", "hi\r", "hi\n"]) + def test_eof_handling_after_process_closes_stdout(self, lastline: str) -> None: + cmd = mkcmd( + r""" + import sys + import time + import os + print(%s, end="", flush=True) + # close stdout + os.close(1) + # message test that it's done, then wait + print(".", file=sys.stderr, flush=True) + time.sleep(10) + """ + % repr(lastline) + ) + start = time.time() + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0) + assert p.stdout and p.stderr + reader = ship_logs.read_newlines_or_carriage_returns(p.stdout) + # wait for message on stderr + _ = p.stderr.read(1) + line = next(reader) + assert line == "hi\n", line + # stdout is now empty + with pytest.raises(StopIteration): + next(reader) + p.kill() + p.wait() + end = time.time() + # Make sure the test didn't wait for that sleep(10) to finish. + assert end - start < 1, end - start + + @pytest.mark.e2e_cpu + def test_long_lines(self) -> None: + s = "abcdefghijklmopqrstuvwxyz" + # Reader will buffer up to io.DEAULT_BUFFER_SIZE-1 before forcing a line break. + max_chars = io.DEFAULT_BUFFER_SIZE - 1 + n = (max_chars + len(s) - 1) // len(s) + msg = s * n + exp_1 = msg[:max_chars] + "\n" + exp_2 = msg[max_chars:] + "\n" + cmd = mkcmd("print(%s, flush=True)" % repr(msg)) + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=0) + assert p.stdout + reader = ship_logs.read_newlines_or_carriage_returns(p.stdout) + line = next(reader) + assert line == exp_1, line + line = next(reader) + assert line == exp_2, line + p.wait() diff --git a/harness/determined/__init__.py b/harness/determined/__init__.py index 55a722be9281..3dc6a2b2f053 100644 --- a/harness/determined/__init__.py +++ b/harness/determined/__init__.py @@ -26,6 +26,5 @@ # LOG_FORMAT is the standard format for use with the logging module, which is required for the # WebUI's log viewer to filter logs by log level. # -# Dev note: if this format is changed, -# the enrich-task-logs.py log parsing must be updated as well. +# Dev note: if this format is changed, the ship_logs.py log parsing must be updated as well. 
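+# (In particular, ship_logs.py's level regex expects the levelname, followed by a colon, to lead
+# each line.)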
LOG_FORMAT = "%(levelname)s: [%(process)s] %(name)s: %(message)s" diff --git a/master/internal/rm/kubernetesrm/pod_test.go b/master/internal/rm/kubernetesrm/pod_test.go index 35eb620b7e1f..ea4cd7a7524a 100644 --- a/master/internal/rm/kubernetesrm/pod_test.go +++ b/master/internal/rm/kubernetesrm/pod_test.go @@ -3,7 +3,6 @@ package kubernetesrm import ( "context" "fmt" - "os" "reflect" "testing" "time" @@ -129,38 +128,11 @@ func createPodWithMockQueue(t *testing.T, k8sRequestQueue *requestQueue) ( return newPod, aID, sub } -var taskContainerFiles = []string{ - "k8_init_container_entrypoint.sh", - "task-logging-setup.sh", - "task-logging-teardown.sh", - "task-signal-handling.sh", - "enrich_task_logs.py", - "singularity-entrypoint-wrapper.sh", -} - func setupEntrypoint(t *testing.T) { - err := etc.SetRootPath(".") + err := etc.SetRootPath("../../../static/srv") if err != nil { t.Logf("Failed to set root directory") } - - for _, file := range taskContainerFiles { - //nolint:gosec - f, _ := os.OpenFile(file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) - err = f.Close() - if err != nil { - t.Logf("failed to close %s", file) - } - } -} - -func cleanup(t *testing.T) { - for _, file := range taskContainerFiles { - err := os.Remove(file) - if err != nil { - t.Logf("failed to remove %s", file) - } - } } func checkReceiveTermination( @@ -194,7 +166,6 @@ func checkReceiveTermination( func TestResourceCreationFailed(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) const correctMsg = "already exists" @@ -219,7 +190,6 @@ func TestResourceCreationFailed(t *testing.T) { func TestReceivePodStatusUpdateTerminated(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) typeMeta := metaV1.TypeMeta{Kind: "rest test"} objectMeta := metaV1.ObjectMeta{ @@ -279,7 +249,6 @@ func TestReceivePodStatusUpdateTerminated(t *testing.T) { func TestMultipleContainerTerminate(t *testing.T) { // Status update test involving two containers. setupEntrypoint(t) - defer cleanup(t) containerStatuses := []k8sV1.ContainerStatus{ { @@ -342,7 +311,6 @@ func TestMultipleContainerTerminate(t *testing.T) { func TestReceivePodStatusUpdateAssigned(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) ref, aID, sub := createPodWithMockQueue(t, nil) purge(aID, sub) @@ -378,7 +346,6 @@ func TestReceivePodStatusUpdateAssigned(t *testing.T) { func TestReceivePodStatusUpdateStarting(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) typeMeta := metaV1.TypeMeta{Kind: "rest test"} objectMeta := metaV1.ObjectMeta{ @@ -489,7 +456,6 @@ func TestReceivePodStatusUpdateStarting(t *testing.T) { func TestMultipleContainersRunning(t *testing.T) { // Status update test involving two containers. 
setupEntrypoint(t) - defer cleanup(t) typeMeta := metaV1.TypeMeta{Kind: "rest test"} objectMeta := metaV1.ObjectMeta{ @@ -579,7 +545,6 @@ func TestMultipleContainersRunning(t *testing.T) { func TestReceivePodEventUpdate(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) ref, aID, sub := createPodWithMockQueue(t, nil) purge(aID, sub) @@ -625,7 +590,6 @@ func TestReceivePodEventUpdate(t *testing.T) { func TestReceiveContainerLog(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) mockLogMessage := "mock log message" ref, aID, sub := createPodWithMockQueue(t, nil) @@ -700,7 +664,6 @@ func TestReceiveContainerLog(t *testing.T) { func TestKillTaskPod(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} @@ -723,7 +686,6 @@ func TestKillTaskPod(t *testing.T) { func TestResourceCreationCancelled(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) podInterface := &mockPodInterface{ pods: make(map[string]*k8sV1.Pod), @@ -776,7 +738,6 @@ func TestResourceCreationCancelled(t *testing.T) { func TestResourceDeletionFailed(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} @@ -821,7 +782,6 @@ func TestResourceDeletionFailed(t *testing.T) { func TestGetPodNodeInfo(t *testing.T) { setupEntrypoint(t) - defer cleanup(t) ref, aID, sub := createPodWithMockQueue(t, nil) ref.slots = 99 diff --git a/master/internal/rm/kubernetesrm/spec.go b/master/internal/rm/kubernetesrm/spec.go index 6c12eb2f970d..558b1c33d9d6 100644 --- a/master/internal/rm/kubernetesrm/spec.go +++ b/master/internal/rm/kubernetesrm/spec.go @@ -414,8 +414,6 @@ func (p *pod) createPodSpec(scheduler string) error { var sidecars []k8sV1.Container - envVars = append(envVars, k8sV1.EnvVar{Name: "DET_K8S_LOG_TO_FILE", Value: "true"}) - container := k8sV1.Container{ Name: model.DeterminedK8ContainerName, Command: spec.Entrypoint, diff --git a/master/pkg/etc/etc.go b/master/pkg/etc/etc.go index db6ccc0b7501..28561e5e66e5 100644 --- a/master/pkg/etc/etc.go +++ b/master/pkg/etc/etc.go @@ -30,8 +30,10 @@ const ( NotebookIdleCheckResource = "check_idle.py" // TaskCheckReadyLogsResource is the script to parse logs to check if a task is ready. TaskCheckReadyLogsResource = "check_ready_logs.py" - // TaskEnrichLogsResource is the script to enrich logs for slurm (which doesn't run fluent). - TaskEnrichLogsResource = "enrich_task_logs.py" + // TaskShipLogsShellResource is the shell script to call the python script to ship logs. + TaskShipLogsShellResource = "ship-logs.sh" + // TaskShipLogsPythonResource is the python script to ship logs. + TaskShipLogsPythonResource = "ship_logs.py" // TensorboardEntryScriptResource is the script to set up TensorBoard. TensorboardEntryScriptResource = "tensorboard-entrypoint.sh" // TrialEntrypointScriptResource is the script to set up a trial. @@ -40,12 +42,8 @@ const ( AgentSetupScriptTemplateResource = "agent_setup_script.sh.template" // K8InitContainerEntryScriptResource is the script to run the init container on k8s. K8InitContainerEntryScriptResource = "k8_init_container_entrypoint.sh" - // TaskLoggingSetupScriptResource is the script to setup prerequistites for logging. 
- TaskLoggingSetupScriptResource = "task-logging-setup.sh" - // TaskLoggingTeardownScriptResource is the script to teardown stuff for logging. - TaskLoggingTeardownScriptResource = "task-logging-teardown.sh" - // TaskSignalHandlingScriptResource is the script to teardown stuff for logging. - TaskSignalHandlingScriptResource = "task-signal-handling.sh" + // TaskSetupScriptResource is the script to setup various things for all tasks. + TaskSetupScriptResource = "task-setup.sh" // SingularityEntrypointWrapperScriptResource is the entrypoint for singularity containers. SingularityEntrypointWrapperScriptResource = "singularity-entrypoint-wrapper.sh" ) diff --git a/master/pkg/tasks/const.go b/master/pkg/tasks/const.go index 9c327109649a..2e8af58a70b0 100644 --- a/master/pkg/tasks/const.go +++ b/master/pkg/tasks/const.go @@ -8,17 +8,14 @@ const ( SingularityEntrypointWrapperScript = "singularity-entrypoint-wrapper.sh" singularityEntrypointWrapperMode = 0o744 - taskLoggingSetupScript = "task-logging-setup.sh" - taskLoggingSetupMode = 0o744 + taskSetupScript = "task-setup.sh" + taskSetupMode = 0o744 - taskLoggingTeardownScript = "task-logging-teardown.sh" - taskLoggingTeardownMode = 0o744 + taskShipLogsShell = "ship-logs.sh" + taskShipLogsShellMode = 0o755 - taskSignalHandlingScript = "task-signal-handling.sh" - taskSignalHandlingMode = 0o744 - - taskEnrichLogsScript = "enrich_task_logs.py" - taskEnrichLogsScriptMode = 0o744 + taskShipLogsPython = "ship_logs.py" + taskShipLogsPythonMode = 0o755 // Put as many ssh-related files in /run/determined as possible. In particular, it is very // important that we don't overwrite the user's host $HOME/.ssh/id_rsa, if the user happens to diff --git a/master/pkg/tasks/task.go b/master/pkg/tasks/task.go index 417ebff6617b..aced2ba73266 100644 --- a/master/pkg/tasks/task.go +++ b/master/pkg/tasks/task.go @@ -4,6 +4,7 @@ import ( "archive/tar" "encoding/json" "fmt" + "path/filepath" "strings" docker "github.com/docker/docker/api/types/container" @@ -257,6 +258,11 @@ func (t *TaskSpec) ToDockerSpec() cproto.Spec { }) } + // Prepend the entrypoint like: `ship-logs.sh "$@"`. + shipLogsShell := filepath.Join(RunDir, taskShipLogsShell) + shipLogsPython := filepath.Join(RunDir, taskShipLogsPython) + entrypoint := append([]string{shipLogsShell, shipLogsPython}, t.Entrypoint...) 
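+	// The resulting Cmd looks like:
+	//   ["/run/determined/ship-logs.sh", "/run/determined/ship_logs.py", <original entrypoint>...]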
+ runArchives, rootArchives := t.Archives() spec := cproto.Spec{ TaskType: string(t.TaskType), @@ -269,7 +275,7 @@ func (t *TaskSpec) ToDockerSpec() cproto.Spec { User: getUser(t.AgentUserGroup), ExposedPorts: toPortSet(env.Ports()), Env: envVars, - Cmd: t.Entrypoint, + Cmd: entrypoint, Image: env.Image().For(deviceType), WorkingDir: t.WorkDir, }, @@ -313,27 +319,21 @@ func workDirArchive( func runDirHelpersArchive(aug *model.AgentUserGroup) cproto.RunArchive { return wrapArchive(archive.Archive{ aug.OwnedArchiveItem( - taskLoggingSetupScript, - etc.MustStaticFile(etc.TaskLoggingSetupScriptResource), - taskLoggingSetupMode, - tar.TypeReg, - ), - aug.OwnedArchiveItem( - taskEnrichLogsScript, - etc.MustStaticFile(etc.TaskEnrichLogsResource), - taskEnrichLogsScriptMode, + taskSetupScript, + etc.MustStaticFile(etc.TaskSetupScriptResource), + taskSetupMode, tar.TypeReg, ), aug.OwnedArchiveItem( - taskLoggingTeardownScript, - etc.MustStaticFile(etc.TaskLoggingTeardownScriptResource), - taskLoggingTeardownMode, + taskShipLogsShell, + etc.MustStaticFile(etc.TaskShipLogsShellResource), + taskShipLogsShellMode, tar.TypeReg, ), aug.OwnedArchiveItem( - taskSignalHandlingScript, - etc.MustStaticFile(etc.TaskSignalHandlingScriptResource), - taskSignalHandlingMode, + taskShipLogsPython, + etc.MustStaticFile(etc.TaskShipLogsPythonResource), + taskShipLogsPythonMode, tar.TypeReg, ), aug.OwnedArchiveItem( diff --git a/master/pkg/tasks/task_trial.go b/master/pkg/tasks/task_trial.go index 2d334a69af1e..b3ec6bb6e81f 100644 --- a/master/pkg/tasks/task_trial.go +++ b/master/pkg/tasks/task_trial.go @@ -126,10 +126,6 @@ func (s TrialSpec) ToTaskSpec() TaskSpec { res.ExtraEnvVars = envVars - res.LoggingFields = map[string]string{ - "trial_id": strconv.Itoa(s.TrialID), - } - if shm := s.ExperimentConfig.Resources().ShmSize(); shm != nil { res.ShmSize = int64(*shm) } diff --git a/master/static/srv/command-entrypoint.sh b/master/static/srv/command-entrypoint.sh index c8038ca2fe73..d5d8af0d0557 100644 --- a/master/static/srv/command-entrypoint.sh +++ b/master/static/srv/command-entrypoint.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash -source /run/determined/task-signal-handling.sh -source /run/determined/task-logging-setup.sh +source /run/determined/task-setup.sh set -e @@ -15,10 +14,8 @@ fi # to register the proxy with the Determined master. "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --proxy --download_context_directory -trap_and_forward_signals if [ "$#" -eq 1 ]; then - /bin/sh -c "$@" & + exec /bin/sh -c "$@" else - "$@" & + exec "$@" fi -wait_and_handle_signals $! diff --git a/master/static/srv/enrich_task_logs.py b/master/static/srv/enrich_task_logs.py deleted file mode 100644 index ec7468c207cd..000000000000 --- a/master/static/srv/enrich_task_logs.py +++ /dev/null @@ -1,206 +0,0 @@ -import argparse -import datetime -import distutils.util -import json -import os -import queue -import re -import socket -import sys -import threading -import time -from typing import Any, Dict, Iterator - -from determined.common import api -from determined.common.api import certs, errors - -# Example log message given below. -# 2022-05-12 16:32:48,757:gc_checkpoints: [rank=0] INFO: Determined checkpoint GC, version 0.17.16-dev0 -# Below regex is used to extract the rank field from the log message. -# Excluding empty spaces this regex matches rank in the above example as [rank=0] -rank = re.compile("(?P ?)\[rank=(?P([0-9]+))\](?P ?)(?P.*)") -# Below regex is used to extract the message severity from the log message. 
-# Excluding empty spaces and delimiter(:) this regex matches message severity level in the above example as INFO -level = re.compile( - "(?P ?)(?P(DEBUG|INFO|WARNING|ERROR|CRITICAL)):(?P ?)(?P.*)" -) - - -# Interval at which to force a flush. -SHIPPER_FLUSH_INTERVAL = 1 # How often to make API calls - -# Full jitter time on encountering an API exception. -SHIPPER_FAILURE_BACKOFF_SECONDS = 1 - -# Max size of the log buffer before forcing a flush. -LOG_BATCH_MAX_SIZE = 1000 - -# Max size of the shipping queue before we start to apply backpressure by blocking sends. We would -# only hit this if we got underwater by three full batches while trying to ship a batch. -SHIP_QUEUE_MAX_SIZE = 3 * LOG_BATCH_MAX_SIZE - - -class ShutdownMessage: - pass - - -class LogCollector(threading.Thread): - def __init__( - self, - ship_queue: queue.Queue, - task_logging_metadata: Dict[str, Any], - emit_stdout_logs: bool, - ): - self.ship_queue = ship_queue - self.task_logging_metadata = task_logging_metadata - self.emit_stdout_logs = emit_stdout_logs - super().__init__() - - def run(self) -> None: - try: - for line in sys.stdin: - if self.emit_stdout_logs: - print(line, flush=True, end="") - try: - parsed_metadata = {} - - m = rank.match(line) - if m: - try: - parsed_metadata["rank_id"] = int(m.group("rank_id")) - line = m.group("log") - except ValueError: - pass - - m = level.match(line) - if m: - parsed_metadata["level"] = m.group("level") - line = m.group("log") - - self.ship_queue.put( - { - "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(), - "log": line if line.endswith("\n") else line + "\n", - **self.task_logging_metadata, - **parsed_metadata, - } - ) - except Exception as e: - print(f"fatal error collecting log {e}", file=sys.stderr) - finally: - self.ship_queue.put(ShutdownMessage()) - - -class LogShipper(threading.Thread): - """ - This is a thread that exists solely so that we can batch logs and ship them to the - SenderThread every FLUSH_INTERVAL seconds. - """ - - def __init__( - self, - ship_queue: queue.Queue, - master_url: str, - cert: certs.Cert, - ) -> None: - self.ship_queue = ship_queue - self.logs = [] - self.master_url = master_url - self.cert = cert - super().__init__() - - def run(self) -> None: - while True: - deadline = time.time() + SHIPPER_FLUSH_INTERVAL - for m in pop_until_deadline(self.ship_queue, deadline): - if isinstance(m, ShutdownMessage): - self.ship() - return - - self.logs.append(m) - if len(self.logs) >= LOG_BATCH_MAX_SIZE: - self.ship() - - # Timeout met. 
- self.ship() - - def ship(self) -> None: - if len(self.logs) <= 0: - return - - max_tries = 3 - tries = 0 - while tries < max_tries: - try: - api.post(self.master_url, "task-logs", self.logs, cert=self.cert) - self.logs = [] - return - except Exception as e: - tries += 1 - if tries == max_tries: - print( - f"failed to ship logs: {e}\nLogs to Ship: {len(self.logs)}", - file=sys.stderr, - ) - time.sleep(SHIPPER_FAILURE_BACKOFF_SECONDS) - - -def pop_until_deadline(q: queue.Queue, deadline: float) -> Iterator[Any]: - while True: - timeout = deadline - time.time() - if timeout <= 0: - break - - try: - yield q.get(timeout=timeout) - except queue.Empty: - break - - -def main( - master_url: str, - cert: certs.Cert, - task_logging_metadata: Dict[str, Any], - emit_stdout_logs: bool, -) -> None: - ship_queue = queue.Queue(maxsize=SHIP_QUEUE_MAX_SIZE) - collector = LogCollector(ship_queue, task_logging_metadata, emit_stdout_logs) - shipper = LogShipper(ship_queue, master_url, cert) - - collector.start() - shipper.start() - - # Collector will exit when it sees the end of stdin. - collector.join() - shipper.join() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="read a stream and enrich it with the standard logging metadata" - ) - parser.add_argument("--stdtype", type=str, help="the stdtype of this stream", required=True) - args = parser.parse_args() - - master_url = os.environ.get("DET_MASTER", os.environ.get("DET_MASTER_ADDR")) - assert master_url is not None, "DET_MASTER and DET_MASTER_ADDR unset" - - task_logging_metadata_json = os.environ.get("DET_TASK_LOGGING_METADATA") - assert task_logging_metadata_json is not None, "DET_TASK_LOGGING_METADATA unset" - - cert = certs.default_load(master_url) - - task_logging_metadata = json.loads(task_logging_metadata_json) - task_logging_metadata["stdtype"] = args.stdtype - task_logging_metadata["agent_id"] = socket.gethostname() - task_logging_metadata["source"] = "task" - container_id = os.environ.get("DET_CONTAINER_ID") - if container_id is not None: - task_logging_metadata["container_id"] = container_id - # If trial exists, just drop it since it could mess with de-ser on the API end. - task_logging_metadata.pop("trial_id", None) - emit_stdout_logs = distutils.util.strtobool( - os.environ.get("DET_SHIPPER_EMIT_STDOUT_LOGS", "True"), - ) - - main(master_url, cert, task_logging_metadata, emit_stdout_logs) diff --git a/master/static/srv/entrypoint.sh b/master/static/srv/entrypoint.sh index 20a3f4361fa2..3cd91b15b14d 100755 --- a/master/static/srv/entrypoint.sh +++ b/master/static/srv/entrypoint.sh @@ -1,7 +1,6 @@ #!/bin/bash -source /run/determined/task-signal-handling.sh -source /run/determined/task-logging-setup.sh +source /run/determined/task-setup.sh set -e @@ -34,6 +33,4 @@ set +x # Do rendezvous last, to ensure all launch layers start around the same time. "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --rendezvous -trap_and_forward_signals -"$DET_PYTHON_EXECUTABLE" -m determined.exec.launch "$@" & -wait_and_handle_signals $! 
+exec "$DET_PYTHON_EXECUTABLE" -m determined.exec.launch "$@" diff --git a/master/static/srv/gc-checkpoints-entrypoint.sh b/master/static/srv/gc-checkpoints-entrypoint.sh index 5a6fef68968d..398d9637db6d 100644 --- a/master/static/srv/gc-checkpoints-entrypoint.sh +++ b/master/static/srv/gc-checkpoints-entrypoint.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash -source /run/determined/task-signal-handling.sh -source /run/determined/task-logging-setup.sh +source /run/determined/task-setup.sh set -e @@ -12,6 +11,4 @@ fi "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container -trap_and_forward_signals -"$DET_PYTHON_EXECUTABLE" -m determined.exec.gc_checkpoints "$@" & -wait_and_handle_signals $! +exec "$DET_PYTHON_EXECUTABLE" -m determined.exec.gc_checkpoints "$@" diff --git a/master/static/srv/notebook-entrypoint.sh b/master/static/srv/notebook-entrypoint.sh index fe1d42460c26..11396d940b25 100755 --- a/master/static/srv/notebook-entrypoint.sh +++ b/master/static/srv/notebook-entrypoint.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash -source /run/determined/task-signal-handling.sh -source /run/determined/task-logging-setup.sh +source /run/determined/task-setup.sh set -e @@ -43,7 +42,6 @@ set +x JUPYTER_LAB_LOG_FORMAT="%(levelname)s: [%(name)s] %(message)s" READINESS_REGEX='^.*Jupyter Server .* is running.*$' -trap_and_forward_signals jupyter lab --ServerApp.port=${NOTEBOOK_PORT} \ --ServerApp.allow_origin="*" \ --ServerApp.base_url="/proxy/${DET_TASK_ID}/" \ @@ -61,4 +59,3 @@ jupyter lab --ServerApp.port=${NOTEBOOK_PORT} \ --LabApp.log_format="$JUPYTER_LAB_LOG_FORMAT" \ --ServerApp.log_format="$JUPYTER_LAB_LOG_FORMAT" \ 2> >(tee -p >("$DET_PYTHON_EXECUTABLE" /run/determined/check_ready_logs.py --ready-regex "${READINESS_REGEX}") >&2) -wait_and_handle_signals $! diff --git a/master/static/srv/shell-entrypoint.sh b/master/static/srv/shell-entrypoint.sh index 48d8439dbb4d..37180d06c291 100755 --- a/master/static/srv/shell-entrypoint.sh +++ b/master/static/srv/shell-entrypoint.sh @@ -1,7 +1,6 @@ #!/usr/bin/env bash -source /run/determined/task-signal-handling.sh -source /run/determined/task-logging-setup.sh +source /run/determined/task-setup.sh set -e @@ -93,7 +92,5 @@ chmod 600 "$modified" READINESS_REGEX="Server listening on" -trap_and_forward_signals /usr/sbin/sshd "$@" \ - 2> >(tee -p >("$DET_PYTHON_EXECUTABLE" /run/determined/check_ready_logs.py --ready-regex "$READINESS_REGEX") >&2) & -wait_and_handle_signals $! + 2> >(tee -p >("$DET_PYTHON_EXECUTABLE" /run/determined/check_ready_logs.py --ready-regex "$READINESS_REGEX") >&2) diff --git a/master/static/srv/ship-logs.sh b/master/static/srv/ship-logs.sh new file mode 100755 index 000000000000..a001f6fada0b --- /dev/null +++ b/master/static/srv/ship-logs.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# Ideally, we'd like to make ship_logs.py the entrypoint in task containers, so +# it could capture all logs from any process in the process tree. +# +# But we can't actually set it as the entrypoint because we don't know how to +# call python in the container until we're inside the container. +# +# So ship-logs.sh runs inside the container, figures out how to call python, and +# then calls ship_logs.py. 
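+#
+# The master prepends both scripts to the task entrypoint, so a typical invocation looks like:
+#
+#   ship-logs.sh /run/determined/ship_logs.py <real task entrypoint and args...>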
+ +set -e + +if [ -z "$DET_PYTHON_EXECUTABLE" ]; then + export DET_PYTHON_EXECUTABLE="python3" +fi + +ship_logs="$1" +shift + +exec "$DET_PYTHON_EXECUTABLE" "$ship_logs" "$@" diff --git a/master/static/srv/ship_logs.py b/master/static/srv/ship_logs.py new file mode 100644 index 000000000000..ed66a2b697c9 --- /dev/null +++ b/master/static/srv/ship_logs.py @@ -0,0 +1,609 @@ +""" +THIS LOG SHIPPER MUST ONLY SHIP LOGS. + +It is not allowed to depend on any libraries other than python standard libraries, including our own +determined library. The reason is that python virtual env errors in user containers are common +scenarios for users to encounter, and we must be able to still ship logs in those scenarios. + +The only thing that is allowed to break the log shipper is a misconfigured cluster (if the log +shipper isn't able to connect to the master) or if python isn't installed. + +But if you are thinking of making this log shipper do anything other than ship logs, or if you are +thinking of adding any dependencies not inside the standard library, stop. Let the log shipper +just ship logs; it's too important of a job to mix it with anything else. + +--- + +ship_logs.py: a suitable container entrypoint that ships logs from a child process to the master. + +usage: ship_logs.py CMD ARGS... + +ship_logs.py will read environment variables set by the master to obtain its configuration. It +isn't intended to be useful in any non-managed environments. +""" + +import datetime +import io +import json +import logging +import os +import queue +import re +import signal +import ssl +import subprocess +import sys +import threading +import time +import traceback +import urllib.request +from typing import Any, Dict, Iterator, List, Optional, cast + +# Duplicated from determined/__init__.py. It's nice to keep them in sync. +LOG_FORMAT = "%(levelname)s: [%(process)s] %(name)s: %(message)s" + + +# Example log message given below. +# 2022-05-12 16:32:48,757:gc_checkpoints: [rank=0] INFO: Determined checkpoint GC, ... +# Below regex is used to extract the rank field from the log message. +# Excluding empty spaces this regex matches rank in the above example as [rank=0] +rank = re.compile(r"(?P ?)\[rank=(?P([0-9]+))\](?P ?)(?P.*)") +# Below regex is used to extract the message severity from the log message. +# Excluding empty spaces and delimiter(:) this regex matches message severity level in the above +# example as INFO +level = re.compile( + r"(?P ?)(?P(DEBUG|INFO|WARNING|ERROR|CRITICAL)):(?P ?)(?P.*)" +) +lineend = re.compile(rb"[\r\n]") + + +# Interval at which to force a flush. +SHIPPER_FLUSH_INTERVAL = 1 # How often to make API calls + +# Full jitter time on encountering an API exception. +SHIPPER_FAILURE_BACKOFF_SECONDS = 1 + +# Max size of the log buffer before forcing a flush. +LOG_BATCH_MAX_SIZE = 1000 + +# Max size of the shipping queue before we start to apply backpressure by blocking sends. We would +# only hit this if we got underwater by three full batches while trying to ship a batch. +SHIP_QUEUE_MAX_SIZE = 3 * LOG_BATCH_MAX_SIZE + + +def read_newlines_or_carriage_returns(fd: io.RawIOBase) -> Iterator[str]: + r""" + Read lines, delineated by either '\n' or '\r. + + Unlike the default io.BufferedReader used in subprocess.Popen(bufsize=-1), we read until we + encounter either '\n' or \r', and treat that as one line. + + Specifically, io.BufferedReader doesn't handle tqdm progress bar outputs very well; it treats + all of the '\r' outputs as one enormous line. 
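+
+    For example, a tqdm-style stream of b"50%\r60%\r100%\n" is yielded here as three separate
+    lines: "50%\n", "60%\n", and "100%\n".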
+ + Args: + fd: an unbuffered stdout or stderr from a subprocess.Popen. + + Yields: + A series of str, one per line. Each line always ends with a '\n'. Each line will be + broken to length io.DEFAULT_BUFFER_SIZE, even if the underlying io didn't have a linebreak. + """ + # Ship lines of length of DEFAULT_BUFFER_SIZE, including the terminating newline. + limit = io.DEFAULT_BUFFER_SIZE - 1 + nread = 0 + chunks = [] # type: List[bytes] + + def oneline(): + nonlocal nread + nonlocal chunks + out = b"".join(chunks).decode("utf8") + chunks = [] + nread = 0 + return out + + while True: + buf = fd.read(limit - nread) + if not buf: + # EOF. + break + + # Extract all the lines from this buffer. + while buf: + m = lineend.search(buf) + if m is None: + # No line break here; just append to chunks. + chunks.append(buf) + nread += len(buf) + break + + # Line break found! + start, end = m.span() + chunks.append(buf[:start]) + # Even if we matched a '\r', emit a '\n'. + chunks.append(b"\n") + yield oneline() + # keep checking the rest of buf + buf = buf[end:] + + # Detect if we reached our buffer limit. + if nread >= limit: + # Pretend we got a line anyway. + chunks.append(b"\n") + yield oneline() + + # One last line, maybe. + if chunks: + chunks.append(b"\n") + yield oneline() + + +class Collector(threading.Thread): + """ + Collector is the thread that reads and parses lines from stdout or stderr. + + It will pass structured data to the logq, and will send a message on doneq when it finishes. + """ + + def __init__( + self, + fd: io.RawIOBase, + stdtype: str, + emit_stdout_logs: bool, + metadata: Dict[str, Any], + logq: queue.Queue, + doneq: queue.Queue, + ) -> None: + super().__init__() + self.fd = fd + self.stdtype = stdtype + self.metadata = {"stdtype": self.stdtype, **metadata} + self.logq = logq + self.doneq = doneq + + if not emit_stdout_logs: + self.dup_io = None + else: + self.dup_io = sys.stdout if stdtype == "stdout" else sys.stderr + + self.shipper_died = False + + def run(self) -> None: + try: + self._run() + self.doneq.put((self.stdtype, None, None)) + except Exception as e: + self.logq.put(None) + self.doneq.put((self.stdtype, None, e)) + else: + self.logq.put(None) + + def _run(self) -> None: + for line in read_newlines_or_carriage_returns(self.fd): + # Capture the timestamp as soon as the line is collected. + now = datetime.datetime.now(datetime.timezone.utc).isoformat() + + if self.dup_io: + print(line, file=self.dup_io, flush=True, end="") + + if self.shipper_died: + # Keep draining logs so process doesn't block on stdout or stderr, but don't bother + # queuing the logs we capture. + continue + + log = {"timestamp": now, **self.metadata} # type: Dict[str, Any] + + m = rank.match(line) + if m: + try: + log["rank_id"] = int(m.group("rank_id")) + line = m.group("log") + except ValueError: + pass + + m = level.match(line) + if m: + log["level"] = m.group("level") + line = m.group("log") + + log["log"] = line + + self.logq.put(log) + + +def override_verify_name(ctx: ssl.SSLContext, verify_name: str) -> ssl.SSLContext: + class VerifyNameOverride: + def __getattr__(self, name: str, default: Any = None) -> Any: + return getattr(ctx, name, default) + + def wrap_socket(self, *args, server_hostname=None, **kwargs) -> Any: + kwargs["server_hostname"] = verify_name + return ctx.wrap_socket(*args, **kwargs) + + return cast(ssl.SSLContext, VerifyNameOverride()) + + +class Shipper(threading.Thread): + """ + Shipper reads structured logs from logq and ships them to the determined-master. 
+
+    It will send a message on doneq when it finishes.
+    """
+
+    def __init__(
+        self,
+        master_url: str,
+        token: str,
+        cert_name: str,
+        cert_file: str,
+        logq: queue.Queue,
+        doneq: queue.Queue,
+        daemon: bool,
+    ) -> None:
+        super().__init__(daemon=daemon)
+        self.logq = logq
+        self.doneq = doneq
+
+        self.headers = {"Authorization": f"Bearer {token}"}
+
+        baseurl = master_url.rstrip("/")
+        self.url = f"{baseurl}/task-logs"
+
+        self.context = None
+        if master_url.startswith("https://"):
+            # Create an SSLContext that trusts our DET_MASTER_CERT_FILE, and checks the hostname
+            # against the DET_MASTER_CERT_NAME (which may differ from the hostname in the url).
+            self.context = ssl.create_default_context()
+            if cert_file.lower() == "noverify":
+                # Don't check the master's certificate.
+                # Presently the master never sets this value for DET_MASTER_CERT_FILE, but we keep
+                # this check to be consistent with the CLI behavior.
+                self.context.check_hostname = False
+                self.context.verify_mode = ssl.CERT_NONE
+            elif cert_file:
+                # Explicitly trust the cert in cert_file.
+                self.context.load_verify_locations(cafile=cert_file)
+            if cert_name:
+                # Override hostname verification
+                self.context = override_verify_name(self.context, cert_name)
+
+    def run(self) -> None:
+        try:
+            self._run()
+        except Exception as e:
+            self.doneq.put(("shipper", None, e))
+        else:
+            self.doneq.put(("shipper", None, None))
+
+    def _run(self) -> None:
+        eofs = 0
+        while eofs < 2:
+            logs = []  # type: List[Dict[str, Any]]
+            deadline = time.time() + SHIPPER_FLUSH_INTERVAL
+            # Pop logs until both collectors close, or we fill up a batch, or we hit the deadline.
+            while eofs < 2 and len(logs) < LOG_BATCH_MAX_SIZE:
+                now = time.time()
+                timeout = deadline - now
+                if timeout <= 0:
+                    # We are already past the deadline.
+                    break
+
+                try:
+                    log = self.logq.get(timeout=timeout)
+                except queue.Empty:
+                    # We hit the timeout.
+                    break
+
+                if log is None:
+                    eofs += 1
+                    continue
+
+                logs.append(log)
+
+            data = json.dumps(logs).encode("utf8")
+
+            # Try to ship for about ten minutes.
+            backoffs = [0, 1, 5, 10, 15, 15, 15, 15, 15, 15, 15, 60, 60, 60, 60, 60, 60, 60, 60, 60]
+
+            self.ship(data, backoffs)
+
+    def ship(self, data: bytes, backoffs: List[int]) -> None:
+        for delay in backoffs:
+            time.sleep(delay)
+            try:
+                req = urllib.request.Request(self.url, data, self.headers, method="POST")
+                with urllib.request.urlopen(req, context=self.context) as resp:
+                    respbody = resp.read()
+
+                # urlopen raises on non-2xx responses, so this is just a sanity check.
+                assert resp.getcode() == 200, (resp.getcode(), respbody.decode("utf8"))
+                # Shipped successfully
+                return
+
+            except Exception:
+                logging.error("failed to ship logs to master", exc_info=True)
+                pass
+
+        raise RuntimeError("failed to connect to master for too long, giving up")
+
+    def ship_special(self, msg: str, metadata: Dict[str, str], emit_stdout_logs: bool) -> None:
+        """
+        Ship a special message, probably from failing to start the child process.
+        """
+        now = datetime.datetime.now(datetime.timezone.utc).isoformat()
+        logs = []
+        # Build a json log line out of each message line.
+        if not msg.endswith("\n"):
+            msg += "\n"
+        for line in msg.splitlines(keepends=True):
+            if emit_stdout_logs:
+                print(line, end="", file=sys.stderr)
+            logs.append(
+                {
+                    "timestamp": now,
+                    "log": line,
+                    "level": "ERROR",
+                    "stdtype": "stderr",
+                    **metadata,
+                }
+            )
+
+        data = json.dumps(logs).encode("utf8")
+
+        # Try to ship for about 30 seconds.
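+        # (sleeps of 0 + 1 + 5 + 10 + 15 = 31 seconds across five attempts, plus request time)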
+ backoffs = [0, 1, 5, 10, 15] + self.ship(data, backoffs) + + +def main( + master_url: str, + cert_name: str, + cert_file: str, + metadata: Dict[str, str], + token: str, + emit_stdout_logs: bool, + cmd: List[str], + log_wait_time: int, +) -> int: + logq = queue.Queue() # type: queue.Queue + doneq = queue.Queue() # type: queue.Queue + + waiter_started = False + stdout_started = False + stderr_started = False + shipper_started = False + + # Normally we like structured concurrency; i.e. a function that owns a thread must not exit + # until that thead has been properly cleaned up. However, it is important that the log shipper + # is not allowed to keep a task container alive too long after the child process has exited. We + # want to guarantee that we exit about DET_LOG_WAIT_TIME seconds after the child process exits. + # + # However, interruping a synchronous HTTP call form urllib is nearly impossible; even if you + # were to select() until the underlying file descriptor had something to read before calling + # Request.read(), there are many buffered readers in there and most likely multiple os.read() + # calls would occur and you'd be blocking anyway. + # + # So as an easy workaround, we set daemon=True and just exit the process if it's not done on + # time. + shipper = Shipper(master_url, token, cert_name, cert_file, logq, doneq, daemon=True) + shipper_timed_out = False + + # Start the process or ship a special log message to the master why we couldn't. + try: + # Don't rely on Popen's standard line buffering; we want to do our own line buffering. + p = subprocess.Popen( + cmd, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0 + ) + except FileNotFoundError: + shipper.ship_special(f"FileNotFoundError executing {cmd}", metadata, emit_stdout_logs) + # 127 is the standard bash exit code for file-not-found. + return 127 + except PermissionError: + # Unable to read or to execute the command. + shipper.ship_special(f"PermissionError executing {cmd}", metadata, emit_stdout_logs) + # 126 is the standard bash exit code for permission failure. + return 126 + except Exception: + msg = f"unexpected failure executing {cmd}:\n" + traceback.format_exc() + shipper.ship_special(msg, metadata, emit_stdout_logs) + # 80 is the exit code we use to signal "ship_logs.py failed" + return 80 + + # Just for mypy. + assert isinstance(p.stdout, io.RawIOBase) and isinstance(p.stderr, io.RawIOBase) + + try: + stdout = Collector(p.stdout, "stdout", emit_stdout_logs, metadata, logq, doneq) + stderr = Collector(p.stderr, "stderr", emit_stdout_logs, metadata, logq, doneq) + waiter = threading.Thread(target=lambda: doneq.put(("waiter", p.wait(), None))) + + waiter.start() + waiter_started = True + + # Set up signal forwarding. + def signal_passthru(signum: Any, frame: Any): + p.send_signal(signum) + + for sig in [ + signal.SIGINT, + signal.SIGTERM, + signal.SIGHUP, + signal.SIGUSR1, + signal.SIGUSR2, + signal.SIGWINCH, + ]: + signal.signal(sig, signal_passthru) + + stdout.start() + stdout_started = True + + stderr.start() + stderr_started = True + + shipper.start() + shipper_started = True + + # expect 4 messages on the doneq + first_who = None + first_error = None + exit_code = None + deadline = None # type: Optional[float] + for _ in range(4): + # Wait for an event, possibly with a deadline (if the child process already exited). 
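+            # (one message each from the waiter thread, the stdout and stderr Collectors, and the
+            # Shipper)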
+ try: + timeout = None if deadline is None else deadline - time.time() + if timeout is not None and timeout <= 0: + raise queue.Empty() + who, what, error = doneq.get(timeout=timeout) + except queue.Empty: + # Deadline is done, just abandon the shipper. + shipper_timed_out = True + logging.error( + f"waited {log_wait_time} seconds for shipper to finish after child exit; " + "giving up now" + ) + break + + if who == "shipper": + # There's no point in collecting logs after the shipper is gone. + stdout.shipper_died = True + stderr.shipper_died = True + + if who == "waiter": + # After the log shipper exits, the shipping code is on a deadline. + deadline = time.time() + log_wait_time + exit_code = what + + if first_error is None and error is not None: + first_who = who + first_error = error + # If our logging infrastructure ever crashes, just give up on the child process. + p.kill() + + if first_error is not None: + raise RuntimeError(f"failure in log shipper; {first_who} thread died") from first_error + + # Mypy doesn't know that we're guaranteed to have an exit_code by now. + assert exit_code is not None + + # Convert signal exits to standard bash-style exits. + if exit_code < 0: + return 128 - exit_code + return exit_code + + finally: + # p.kill() is a SIGKILL, so p.wait() should be quick. + p.kill() + if waiter_started: + waiter.join() + # After p is dead, the Collectors should run out of input and exit quickly. + if stdout_started: + stdout.join() + if stderr_started: + stderr.join() + # After everyone else is dead, the shipper could still be in a retry loop for a long time, + # so we wait up to DET_LOG_WAIT_TIME seconds for it to exit, then we give up on it. + if shipper_started and not shipper_timed_out: + shipper.join(timeout=log_wait_time) + if shipper.is_alive(): + # The timeout was reached. + logging.error( + f"waited {log_wait_time} seconds for shipper to finish after crash; " + "giving up now" + ) + + +def configure_escape_hatch(dirpath: str) -> None: + """ + Even if the log shipper goes belly-up, dump logs to a bind-mounted path. + + If the log shipper is failing in production, you obviously can't expect to find those logs in + task logs, so this allows a user or support person to mount a directory into a container in + order to find out why the log shipper is broken. + """ + + try: + hostname = os.environ["DET_AGENT_ID"] + except Exception: + try: + import socket + + hostname = socket.gethostname() + except Exception: + hostname = "unknown" + # You can't run logging.basicConfig() twice so we manually add a second handler at the root + # logging level. + fh = logging.FileHandler( + filename=os.path.join(ship_logs_path, f"{hostname}-{time.time()}.log"), + # Only create the file if we actually log to it (aka if there's an error). That way if + # there's lots of processes not failing, we don't create tons of empty files. + delay=False, + ) + fh.setFormatter(logging.Formatter(LOG_FORMAT)) + logging.getLogger().addHandler(fh) + + try: + # Touch a single file to indicate that the escape hatch is working, so that in debugging + # someone can distinguish "the escape hatch isn't working" from "ship_logs just isn't + # hitting any errors". 
+ with open(os.path.join(ship_logs_path, "ship-logs-ran"), "w"): + pass + except Exception: + pass + + +if __name__ == "__main__": + try: + logging.basicConfig( + format=LOG_FORMAT, + stream=sys.stderr, + ) + + ship_logs_path = os.environ.get("DET_SHIP_LOGS_PATH", "/ship_logs") + if os.path.exists(ship_logs_path): + configure_escape_hatch(ship_logs_path) + + master_url = os.environ["DET_MASTER"] + cert_name = os.environ.get("DET_MASTER_CERT_NAME", "") + cert_file = os.environ.get("DET_MASTER_CERT_FILE", "") + # TODO(rb): fix DET_USER_TOKEN to support tokens with lifetimes tied to an allocation, and + # use DET_USER_TOKEN here instead. + token = os.environ["DET_SESSION_TOKEN"] + raw_metadata = os.environ["DET_TASK_LOGGING_METADATA"] + try: + metadata = json.loads(raw_metadata) + assert isinstance(metadata, dict) + except Exception: + raise ValueError(f"invalid DET_TASK_LOGGING_METADATA: '{raw_metadata}'") from None + + metadata["container_id"] = os.environ.get("DET_CONTAINER_ID", "") + metadata["agent_id"] = os.environ.get("DET_AGENT_ID", "") + + raw_log_wait_time = os.environ.get("DET_LOG_WAIT_TIME", "30") + try: + log_wait_time = int(raw_log_wait_time) + except Exception: + raise ValueError(f"invalid DET_LOG_WAIT_TIME: '{raw_log_wait_time}'") from None + + emit_stdout_logs = bool(os.environ.get("DET_SHIPPER_EMIT_STDOUT_LOGS")) + + metadata["source"] = "task" + + # XXX: remove this after ensuring it has been cleaned up master-side + if "trial_id" in metadata: + os._exit(81) + + exit_code = main( + master_url, + cert_name, + cert_file, + metadata, + token, + emit_stdout_logs, + cmd=sys.argv[1:], + log_wait_time=log_wait_time, + ) + except Exception: + logging.error("ship_logs.py crashed!", exc_info=True) + sys.exit(80) + + sys.exit(exit_code) diff --git a/master/static/srv/task-logging-setup.sh b/master/static/srv/task-logging-setup.sh deleted file mode 100644 index 9a05acfaf297..000000000000 --- a/master/static/srv/task-logging-setup.sh +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env bash - -STDOUT_FILE=/run/determined/train/logs/stdout.log -STDERR_FILE=/run/determined/train/logs/stderr.log - -mkdir -p "$(dirname "$STDOUT_FILE")" "$(dirname "$STDERR_FILE")" - -# Create symbolic links from well-known files to this process's STDOUT and -# STDERR. Anything written to those files will be inserted into the output -# streams of this process, allowing distributed training logs to route through -# individual containers rather than all going through SSH back to agent 0. -ln -sf /proc/$$/fd/1 "$STDOUT_FILE" -ln -sf /proc/$$/fd/2 "$STDERR_FILE" - -# Create a FIFO to monitor process substitution exits, and a count to know how -# many to wait on. -DET_LOG_WAIT_FIFO=/run/determined/train/logs/wait.fifo -DET_LOG_WAIT_COUNT=0 -mkfifo $DET_LOG_WAIT_FIFO - -# Save the original stdout and stderr. Process substitutions we'll be doing -# below block until their stdin is closed and, when we clean up, by saving these -# we can close them safely and replace stdout and stderr for the shell with the -# original. -exec {ORIGINAL_STDOUT}>&1 {ORIGINAL_STDERR}>&2 - -if [ -n "$DET_K8S_LOG_TO_FILE" ]; then - # To do logging with a sidecar in Kubernetes, we need to log to files that - # can then be read from the sidecar. To avoid a disk explosion, we need to - # layer on some rotation. multilog is a tool that automatically writes its - # stdin to rotated log files; the following line pipes stdout and stderr of - # this process to separate multilog invocations. 
"n2" means to only store - # one old log file -- the logs are being streamed out, so we - # don't need to keep any more old ones around. Create the dirs ahead of time - # so they are 0755 (when they don't exist, multilog makes them 0700 and - # they can't accessed with the non-root user). - STDOUT_ROTATE_DIR="$STDOUT_FILE-rotate" - STDERR_ROTATE_DIR="$STDERR_FILE-rotate" - mkdir -p -m 755 $STDOUT_ROTATE_DIR - mkdir -p -m 755 $STDERR_ROTATE_DIR - - exec 1> >( - multilog n2 "$STDOUT_ROTATE_DIR" - printf x >$DET_LOG_WAIT_FIFO - ) \ - 2> >( - multilog n2 "$STDERR_ROTATE_DIR" - printf x >$DET_LOG_WAIT_FIFO - ) - - ((DET_LOG_WAIT_COUNT += 2)) -fi - -export PATH="/run/determined/pythonuserbase/bin:$PATH" -if [ -z "$DET_PYTHON_EXECUTABLE" ]; then - export DET_PYTHON_EXECUTABLE="python3" -fi - -if ! "$DET_PYTHON_EXECUTABLE" --version >/dev/null 2>&1; then - echo "{\"log\": \"error: unable to find python3 as '$DET_PYTHON_EXECUTABLE'\n\", \"timestamp\": \"$(date --rfc-3339=seconds)\"}" >&2 - echo "{\"log\": \"please install python3 or set the environment variable DET_PYTHON_EXECUTABLE=/path/to/python3\n\", \"timestamp\": "$(date --rfc-3339=seconds)"}" >&2 - exit 1 -fi - -if [ -z "$DET_SKIP_PIP_INSTALL" ]; then - "$DET_PYTHON_EXECUTABLE" -m pip install -q --user /opt/determined/wheels/determined*.whl -else - if ! "$DET_PYTHON_EXECUTABLE" -c "import determined" >/dev/null 2>&1; then - echo "{\"log\": \"error: unable run without determined package\n\", \"timestamp\": \"$(date --rfc-3339=seconds)\"}" >&2 - exit 1 - fi -fi - -# Intercept stdout/stderr and send content to DET_MASTER via the log API. -# When completed, write a single character to the DET_LOG_WAIT_FIFO to signal -# completion of one procesor. -exec 1> >( - "$DET_PYTHON_EXECUTABLE" /run/determined/enrich_task_logs.py --stdtype stdout >&1 - printf x >$DET_LOG_WAIT_FIFO -) \ -2> >( - "$DET_PYTHON_EXECUTABLE" /run/determined/enrich_task_logs.py --stdtype stderr >&2 - printf x >$DET_LOG_WAIT_FIFO -) - -((DET_LOG_WAIT_COUNT += 2)) - -if [ "$DET_RESOURCES_TYPE" == "slurm-job" ]; then - # Each container sends the Determined Master a notification that it's - # running, so that the Determined Master knows whether to set the state - # of the experiment to "Pulling", meaning some nodes are pulling down - # the image, or "Running", meaning that all containers are running. - # - # Note: This is not related to logging, but since task-logging-setup.sh - # gets called by all the entrypoint scripts, it seemed like the logical - # place to add it, without having to modify each entrypoint script. - "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --notify_container_running -fi - -# A task may output carriage return characters (\r) to do something mildly fancy -# with the terminal like update a progress bar in place on one line. Python's -# tqdm library is a common way to do this. That works poorly with our logging, -# since Fluent Bit interprets everything as one line, causing it to mash -# everything together and buffer the output for way too long. Since we're not -# going to do anything like interpreting the carriage returns in our log -# displays, here we simply replace them all with newlines to get a reasonable -# effect in those cases. This must be after the multilog exec, since exec -# redirections are applied in reverse order. -# -# When completed, write a single character to the DET_LOG_WAIT_FIFO to signal -# completion of one procesor. 
-exec > >( - stdbuf -o0 tr '\r' '\n' - printf x >$DET_LOG_WAIT_FIFO -) 2> >( - stdbuf -o0 tr '\r' '\n' >&2 - printf x >$DET_LOG_WAIT_FIFO -) - -((DET_LOG_WAIT_COUNT += 2)) - -# As shell exits, wait for stdout/stderr processors to complete -trap 'source /run/determined/task-logging-teardown.sh' EXIT diff --git a/master/static/srv/task-logging-teardown.sh b/master/static/srv/task-logging-teardown.sh deleted file mode 100644 index 97af181a25aa..000000000000 --- a/master/static/srv/task-logging-teardown.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env bash - -# Replace overridden stdout and stderr with original and close them, since the -# command is finished. -exec >&1- >&2- 1>&$ORIGINAL_STDOUT 2>&$ORIGINAL_STDERR - -# We use the bash builtin printf for getting the epoch time in seconds. -# This requires bash 4.2 (from 2011) and it depends on strftime(3) supporting -# the %s directive, which is not in posix. -epoch_seconds() { - printf '%(%s)T\n' -1 -} - -# Wait for up to DET_LOG_WAIT_TIME seconds for the logging to finish -# At this point the launching entry point script has exited and we are -# waiting for the child logging processes to complete. The child -# processes will exit when reaching EOF of the log stream they are -# processing, so in the normal case they terminate quickly and -# each stream processor writes a single character to the DET_LOG_WAIT_FIFO -# to indicate they are no longer waiting. -# -# This wait time it to handle the case when log stream procesors have -# not exited yet -- either becuase someone is still writing to the stream, -# or the DET_MASTER is not reachable and therefor we are slow in flushing -# the logs to the master. -# -# After this wait, the container entrypoint immediately exits and all -# processing within the container is SIGKILLed without any opportunity -# for any furhter processing, so avoiding a premature exit is important. -waitfor="${DET_LOG_WAIT_TIME:-30}" - -deadline="$(($(epoch_seconds) + waitfor))" -timeout="$((deadline - $(epoch_seconds)))" - -# read returns 1 on EOF or >128 with timeout, but it's a fifo so that is OK. -# For read's -t timeout feature to work, we need to open the fifo for -# reading and writing for some reason, which is what the `<>` is for. -# See https://stackoverflow.com/a/6448737. -read -N $DET_LOG_WAIT_COUNT -t "$timeout" <>"$DET_LOG_WAIT_FIFO" || true diff --git a/master/static/srv/task-setup.sh b/master/static/srv/task-setup.sh new file mode 100644 index 000000000000..fc334099be15 --- /dev/null +++ b/master/static/srv/task-setup.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +STDOUT_FILE=/run/determined/train/logs/stdout.log +STDERR_FILE=/run/determined/train/logs/stderr.log + +mkdir -p "$(dirname "$STDOUT_FILE")" "$(dirname "$STDERR_FILE")" + +# Create symbolic links from well-known files to this process's STDOUT and +# STDERR. Anything written to those files will be inserted into the output +# streams of this process, allowing distributed training logs to route through +# individual containers rather than all going through SSH back to agent 0. +ln -sf /proc/$$/fd/1 "$STDOUT_FILE" +ln -sf /proc/$$/fd/2 "$STDERR_FILE" + +export PATH="/run/determined/pythonuserbase/bin:$PATH" +if [ -z "$DET_PYTHON_EXECUTABLE" ]; then + export DET_PYTHON_EXECUTABLE="python3" +fi + +if ! 
"$DET_PYTHON_EXECUTABLE" --version >/dev/null 2>&1; then + echo "{\"log\": \"error: unable to find python3 as '$DET_PYTHON_EXECUTABLE'\n\", \"timestamp\": \"$(date --rfc-3339=seconds)\"}" >&2 + echo "{\"log\": \"please install python3 or set the environment variable DET_PYTHON_EXECUTABLE=/path/to/python3\n\", \"timestamp\": "$(date --rfc-3339=seconds)"}" >&2 + exit 1 +fi + +if [ -z "$DET_SKIP_PIP_INSTALL" ]; then + "$DET_PYTHON_EXECUTABLE" -m pip install -q --user /opt/determined/wheels/determined*.whl +else + if ! "$DET_PYTHON_EXECUTABLE" -c "import determined" >/dev/null 2>&1; then + echo "{\"log\": \"error: unable run without determined package\n\", \"timestamp\": \"$(date --rfc-3339=seconds)\"}" >&2 + exit 1 + fi +fi + +if [ "$DET_RESOURCES_TYPE" == "slurm-job" ]; then + # Each container sends the Determined Master a notification that it's + # running, so that the Determined Master knows whether to set the state + # of the experiment to "Pulling", meaning some nodes are pulling down + # the image, or "Running", meaning that all containers are running. + "$DET_PYTHON_EXECUTABLE" -m determined.exec.prep_container --notify_container_running +fi diff --git a/master/static/srv/task-signal-handling.sh b/master/static/srv/task-signal-handling.sh deleted file mode 100644 index 638ff31b6d28..000000000000 --- a/master/static/srv/task-signal-handling.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env bash - -trap_and_forward_signals() { - handle_signals() { - sig="$1" - shift - trapped_signal="yes" - if [ "${wait_child_pid+x}" ]; then - # If the child process isn't alive yet, then this is OK, whoever can just resend the signal. - kill -s "$sig" "${wait_child_pid}" 2>/dev/null - fi - } - - trap_and_capture_signal() { - func="$1" - shift - for sig in "$@"; do - trap "$func $sig" "$sig" - done - } - - unset wait_child_pid - unset trapped_signal - trap_and_capture_signal 'handle_signals' TERM INT SIGUSR1 SIGUSR2 -} - -wait_and_handle_signals() { - wait_child_pid=$1 - - while true; do - set +e - wait $wait_child_pid - wait_child_exit=$? - set -e - - # When a signal is sent to the shell, it will interrupt waits, after all traps have run. To - # discern if the wait unblocked because of a signal or process exit, we set "trapped_signal" - # in traps and check it here. - if [ -z "${trapped_signal+x}" ]; then - exit $wait_child_exit - else - unset trapped_signal - fi - done -} diff --git a/master/static/srv/tensorboard-entrypoint.sh b/master/static/srv/tensorboard-entrypoint.sh index f5853355f203..b8b2947615a8 100755 --- a/master/static/srv/tensorboard-entrypoint.sh +++ b/master/static/srv/tensorboard-entrypoint.sh @@ -1,7 +1,6 @@ #!/bin/bash -source /run/determined/task-signal-handling.sh -source /run/determined/task-logging-setup.sh +source /run/determined/task-setup.sh set -e @@ -26,7 +25,5 @@ READINESS_REGEX="TensorBoard contains metrics" WAITING_REGEX="TensorBoard waits on metrics" TENSORBOARD_VERSION=$("$DET_PYTHON_EXECUTABLE" -c "import tensorboard; print(tensorboard.__version__)") -trap_and_forward_signals "$DET_PYTHON_EXECUTABLE" -m determined.exec.tensorboard "$TENSORBOARD_VERSION" "$@" \ - > >(tee -p >("$DET_PYTHON_EXECUTABLE" /run/determined/check_ready_logs.py --ready-regex "$READINESS_REGEX" --waiting-regex "$WAITING_REGEX")) & -wait_and_handle_signals $! + > >(tee -p >("$DET_PYTHON_EXECUTABLE" /run/determined/check_ready_logs.py --ready-regex "$READINESS_REGEX" --waiting-regex "$WAITING_REGEX"))