Skip to content

Commit

Permalink
fix: use task sessions in Core API [MD-509] (#9860)
Browse files Browse the repository at this point in the history
use task session tokens in CoreContext instead of user tokens
  • Loading branch information
azhou-determined authored Aug 23, 2024
1 parent 3ee88bb commit a55af74
Show file tree
Hide file tree
Showing 18 changed files with 122 additions and 34 deletions.
2 changes: 1 addition & 1 deletion harness/determined/common/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from determined.common.api import authentication, errors, metric, bindings
from determined.common.api._session import BaseSession, UnauthSession, Session
from determined.common.api._session import BaseSession, UnauthSession, Session, TaskSession
from determined.common.api._util import (
PageOpts,
get_ntsc_details,
Expand Down
21 changes: 20 additions & 1 deletion harness/determined/common/api/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ class Session(BaseSession):
By far, most BaseSessions in the codebase will be this Session subclass.
"""

AUTH_HEADER = "Authorization"

def __init__(
self,
master: str,
Expand Down Expand Up @@ -308,10 +310,27 @@ def _make_http_session(self) -> requests.Session:
server_hostname=self.cert.name if self.cert else None,
verify=self.cert.bundle if self.cert else None,
max_retries=self._max_retries,
headers={"Authorization": f"Bearer {self.token}"},
headers={self.AUTH_HEADER: f"Bearer {self.token}"},
)


class TaskSession(Session):
    """
    A ``Session`` that authenticates its requests with a task session token.

    The only difference from ``Session`` is the name of the request header that carries the
    bearer token; everything else is inherited unchanged.

    Sessions created from user input (e.g. the CLI or the SDK) should keep using ``Session``,
    which authenticates with a user token.  Task session tokens mainly differ in having a longer
    expiration, so ``TaskSession`` is intended for internal sessions that may have to last for
    the whole of a long training job (e.g. the Core API).
    """

    # Header name used for task session tokens; overrides ``Session.AUTH_HEADER``.
    AUTH_HEADER = "Grpc-Metadata-x-allocation-token"


class _HTTPSAdapter(adapters.HTTPAdapter):
"""Overrides the hostname checked against for TLS verification.
Expand Down
25 changes: 25 additions & 0 deletions harness/determined/common/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def get_det_password_from_env() -> Optional[str]:
return os.environ.get("DET_PASS")


def get_det_session_token_from_env() -> Optional[str]:
    """Return the value of ``DET_SESSION_TOKEN`` from the environment, or ``None`` if unset."""
    return os.getenv("DET_SESSION_TOKEN")


def login(
master_address: str,
username: str,
Expand Down Expand Up @@ -307,6 +311,27 @@ def logout_all(master_address: str, cert: Optional[certs.Cert]) -> None:
logout(master_address, user, cert)


def login_from_task(
    master_address: str,
    cert: Optional[certs.Cert],
) -> "api.TaskSession":
    """
    Build a ``TaskSession`` for ``master_address`` out of the task's environment.

    The session token is read via ``get_det_session_token_from_env()`` and the username via
    ``get_det_username_from_env()``; both must be present and non-empty.  The returned session
    is what subsequent requests should be authenticated with.

    Only call this on-cluster, from inside a task container.

    Arguments:
        master_address: address of the master the session talks to.
        cert: optional certificate configuration, passed through to the session.

    Raises:
        ValueError: if the session token or the username is missing from the environment.
    """
    # Validate the token first so a missing token is reported before a missing username.
    token = get_det_session_token_from_env()
    if not token:
        raise ValueError("DET_SESSION_TOKEN environment variable not set.")

    user = get_det_username_from_env()
    if not user:
        raise ValueError("DET_USER environment variable not set.")

    task_session = api.TaskSession(
        master=master_address,
        username=user,
        token=token,
        cert=cert,
    )
    return task_session


def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool:
"""
Find out whether the given token is valid by attempting to use it
Expand Down
7 changes: 4 additions & 3 deletions harness/determined/core/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,10 @@ def init(

# We are on the cluster.
cert = certs.default_load(info.master_url)
session = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
util.get_max_retries_config()
)
session = authentication.login_from_task(
master_address=info.master_url,
cert=cert,
).with_retry(util.get_max_retries_config())

if distributed is None:
if len(info.container_addrs) > 1 or len(info.slot_ids) > 1:
Expand Down
2 changes: 1 addition & 1 deletion harness/determined/exec/gc_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def patch_checkpoints(storage_ids_to_resources: Dict[str, Dict[str, int]]) -> No

cert = certs.default_load(info.master_url)
# With backoff retries for 64 seconds
sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
sess = authentication.login_from_task(info.master_url, cert=cert).with_retry(
urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
)

Expand Down
2 changes: 1 addition & 1 deletion harness/determined/exec/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def trigger_preemption(signum: int, frame: types.FrameType) -> None:
logger.info("SIGTERM: Preemption imminent.")
# Notify the master that we need to be preempted
cert = certs.default_load(info.master_url)
sess = authentication.login_with_cache(info.master_url, cert=cert)
sess = authentication.login_from_task(info.master_url, cert=cert)
sess.post(f"/api/v1/allocations/{info.allocation_id}/signals/pending_preemption")


Expand Down
2 changes: 1 addition & 1 deletion harness/determined/exec/prep_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def do_proxy(sess: api.Session, allocation_id: str) -> None:

cert = certs.default_load(info.master_url)
# With backoff retries for 64 seconds
sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
sess = authentication.login_from_task(info.master_url, cert=cert).with_retry(
urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
)

Expand Down
8 changes: 4 additions & 4 deletions harness/determined/experimental/core_v2/_core_context_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import determined as det
from determined import core, experimental, tensorboard
from determined.common import constants, storage, util
from determined.common import api, constants, storage, util
from determined.common.api import authentication, certs

logger = logging.getLogger("determined.core")
Expand Down Expand Up @@ -43,9 +43,9 @@ def _make_v2_context(

# We are on the cluster.
cert = certs.default_load(info.master_url)
session = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
util.get_max_retries_config()
)
session: api.Session = authentication.login_from_task(
info.master_url, cert=cert
).with_retry(util.get_max_retries_config())
else:
unmanaged = True

Expand Down
2 changes: 1 addition & 1 deletion harness/determined/launch/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def main(script: List[str]) -> int:
# Mark sshd containers as daemon containers that the master should kill when all non-daemon
# containers (deepspeed launcher, in this case) have exited.
cert = certs.default_load(info.master_url)
sess = authentication.login_with_cache(info.master_url, cert=cert)
sess = authentication.login_from_task(info.master_url, cert=cert)
sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon")

# Wrap it in a pid_server to ensure that we can't hang if a worker fails.
Expand Down
2 changes: 1 addition & 1 deletion harness/determined/launch/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def main(hvd_args: List[str], script: List[str], autohorovod: bool) -> int:
# Mark sshd containers as daemon resources that the master should kill when all non-daemon
# containers (horovodrun, in this case) have exited.
cert = certs.default_load(info.master_url)
sess = authentication.login_with_cache(info.master_url, cert=cert)
sess = authentication.login_from_task(info.master_url, cert=cert)
sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon")

pid_server_cmd, run_sshd_command = create_sshd_worker_cmd(
Expand Down
27 changes: 26 additions & 1 deletion harness/tests/cli/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from determined.cli import cli
from determined.common import api
from determined.common.api import authentication
from determined.common.api import authentication, certs
from tests.cli import util

MOCK_MASTER_URL = "http://localhost:8080"
Expand Down Expand Up @@ -439,3 +439,28 @@ def test_logout_all() -> None:
mts.clear_active()

cli.main(["user", "logout", "--all"])


def test_login_from_task() -> None:
    """
    ``login_from_task`` should return a session populated from ``DET_SESSION_TOKEN`` and
    ``DET_USER`` plus the given cert, and requests made through that session should carry the
    token as a bearer value in the session's ``AUTH_HEADER`` header.
    """
    token = "abcde12345"
    user = "abababa"
    cert = certs.Cert()
    env_overrides = {"DET_SESSION_TOKEN": token, "DET_USER": user}

    with contextlib.ExitStack() as stack:
        # Configure environment variables (entered in insertion order: token, then user).
        for name, value in env_overrides.items():
            stack.enter_context(util.setenv_optional(name, value))

        http_mock = responses.RequestsMock(
            registry=registries.OrderedRegistry, assert_all_requests_are_fired=True
        )
        with http_mock as rsps:
            sess = authentication.login_from_task(master_address=MOCK_MASTER_URL, cert=cert)
            assert sess.token == token
            assert sess.username == user
            assert sess.cert == cert

            # The mocked endpoint only matches when the task auth header is sent.
            expected_headers = {sess.AUTH_HEADER: f"Bearer {token}"}
            rsps.get(
                f"{MOCK_MASTER_URL}/api/v1/me",
                status=200,
                match=[matchers.header_matcher(expected_headers)],
            )
            sess.get("/api/v1/me")
18 changes: 12 additions & 6 deletions harness/tests/launch/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def mock_process(cmd: List[str], *args: Any, **kwargs: Any) -> Any:

mock_subprocess.side_effect = mock_process

with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
deeplaunch.main(script)

mock_cluster_info.assert_called_once()
Expand Down Expand Up @@ -149,7 +149,7 @@ def mock_process(cmd: List[str], *args: Any, **kwargs: Any) -> Any:

mock_subprocess.side_effect = mock_process

with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
with pytest.raises(ValueError, match="no sshd greeting"):
deeplaunch.main(script)

Expand Down Expand Up @@ -195,7 +195,7 @@ def test_launch_one_slot(
log_redirect_cmd = deeplaunch.create_log_redirect_cmd()
launch_cmd = pid_server_cmd + deepspeed_cmd + pid_client_cmd + log_redirect_cmd + script

with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
deeplaunch.main(script)

mock_cluster_info.assert_called_once()
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_launch_fail(mock_cluster_info: mock.MagicMock, mock_subprocess: mock.Ma
log_redirect_cmd = deeplaunch.create_log_redirect_cmd()
launch_cmd = pid_server_cmd + deepspeed_cmd + pid_client_cmd + log_redirect_cmd + script

with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
assert deeplaunch.main(script) == 1

mock_cluster_info.assert_called_once()
Expand All @@ -235,7 +235,7 @@ def test_launch_fail(mock_cluster_info: mock.MagicMock, mock_subprocess: mock.Ma

@mock.patch("subprocess.Popen")
@mock.patch("determined.get_cluster_info")
@mock.patch("determined.common.api.authentication.login_with_cache")
@mock.patch("determined.common.api.authentication.login_from_task")
def test_launch_worker(
mock_login: mock.MagicMock,
mock_cluster_info: mock.MagicMock,
Expand All @@ -245,7 +245,13 @@ def test_launch_worker(
mock_cluster_info.return_value = cluster_info
mock_session = mock.MagicMock()
mock_login.return_value = mock_session
with test_util.set_resources_id_env_var():
with test_util.set_env_vars(
{
"DET_RESOURCES_ID": "resourcesId",
"DET_SESSION_TOKEN": cluster_info.session_token,
"DET_USER": "abcd",
}
):
deeplaunch.main(["script"])

mock_cluster_info.assert_called_once()
Expand Down
12 changes: 9 additions & 3 deletions harness/tests/launch/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_horovod_chief(
os.environ.pop("DET_CHIEF_IP", None)
os.environ.pop("USE_HOROVOD", None)

with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
assert launch.horovod.main(hvd_args, script, autohorovod) == 99

if autohorovod and nnodes == 1 and nslots == 1:
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_horovod_chief(

@mock.patch("subprocess.Popen")
@mock.patch("determined.get_cluster_info")
@mock.patch("determined.common.api.authentication.login_with_cache")
@mock.patch("determined.common.api.authentication.login_from_task")
def test_sshd_worker(
mock_login: mock.MagicMock,
mock_cluster_info: mock.MagicMock,
Expand All @@ -143,7 +143,13 @@ def test_sshd_worker(

os.environ.pop("DET_CHIEF_IP", None)

with test_util.set_resources_id_env_var():
with test_util.set_env_vars(
{
"DET_RESOURCES_ID": "resourcesId",
"DET_SESSION_TOKEN": info.session_token,
"DET_USER": "abcd",
}
):
assert launch.horovod.main(hvd_args, script, True) == 99

mock_cluster_info.assert_called_once()
Expand Down
2 changes: 1 addition & 1 deletion harness/tests/launch/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_launch_list() -> None:
@mock.patch("determined.common.storage.validate_config")
def test_launch_script(mock_validate_config: mock.MagicMock) -> None:
# Use runpy to actually run the whole launch script.
with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
with test_util.set_mock_cluster_info(["0.0.0.1"], 0, 1) as info:
# Successful entrypoints exit 0.
info.trial._config["entrypoint"] = ["true"]
Expand Down
4 changes: 2 additions & 2 deletions harness/tests/launch/test_torch_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_launch_single_slot(
script = ["python3", "-m", "determined.exec.harness", "my_module:MyTrial"]
override_args = ["--max_restarts", "1"]

with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
launch.torch_distributed.main(override_args, script)

launch_cmd = launch.torch_distributed.create_pid_server_cmd(
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_launch_distributed(

mock_subprocess.return_value = mock_proc

with test_util.set_resources_id_env_var():
with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}):
assert launch.torch_distributed.main(override_args, script) == mock_success_code

launch_cmd = launch.torch_distributed.create_pid_server_cmd(
Expand Down
8 changes: 5 additions & 3 deletions harness/tests/launch/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ def set_mock_cluster_info(


@contextlib.contextmanager
def set_resources_id_env_var() -> Iterator[None]:
def set_env_vars(env_vars: Dict[str, str]) -> Iterator[None]:
try:
os.environ["DET_RESOURCES_ID"] = "resourcesId"
for k, v in env_vars.items():
os.environ[k] = v
yield
finally:
del os.environ["DET_RESOURCES_ID"]
for k, _ in env_vars.items():
del os.environ[k]


def parse_args_check(positive_cases: Dict, negative_cases: Dict, parse_func: Callable) -> None:
Expand Down
2 changes: 1 addition & 1 deletion master/static/srv/check_idle.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def main():
notebook_server = f"https://127.0.0.1:{port}/proxy/{notebook_id}"
master_url = api.canonicalize_master_url(os.environ["DET_MASTER"])
cert = certs.default_load(master_url)
sess = authentication.login_with_cache(master_url, cert=cert)
sess = authentication.login_from_task(master_url, cert=cert)
try:
idle_type = IdleType[os.environ["NOTEBOOK_IDLE_TYPE"].upper()]
except KeyError:
Expand Down
10 changes: 7 additions & 3 deletions master/static/srv/check_ready_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None):
cert = certs.default_load(master_url)
# This only runs on-cluster, so it is expected the username and session token are present in the
# environment.
sess = authentication.login_with_cache(master_url, cert=cert)
sess = authentication.login_from_task(master_url, cert=cert)
allocation_id = str(os.environ["DET_ALLOCATION_ID"])
for line in sys.stdin:
if ready.match(line):
Expand All @@ -49,11 +49,15 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Read STDIN for a match and mark a task as ready")
parser = argparse.ArgumentParser(
description="Read STDIN for a match and mark a task as ready"
)
parser.add_argument(
"--ready-regex", type=str, help="the pattern to match task ready", required=True
)
parser.add_argument("--waiting-regex", type=str, help="the pattern to match task waiting")
parser.add_argument(
"--waiting-regex", type=str, help="the pattern to match task waiting"
)
args = parser.parse_args()

ready_regex = re.compile(args.ready_regex)
Expand Down

0 comments on commit a55af74

Please sign in to comment.