diff --git a/harness/determined/common/api/__init__.py b/harness/determined/common/api/__init__.py index b70004ba126..90f8abbd147 100644 --- a/harness/determined/common/api/__init__.py +++ b/harness/determined/common/api/__init__.py @@ -1,5 +1,5 @@ from determined.common.api import authentication, errors, metric, bindings -from determined.common.api._session import BaseSession, UnauthSession, Session +from determined.common.api._session import BaseSession, UnauthSession, Session, TaskSession from determined.common.api._util import ( PageOpts, get_ntsc_details, diff --git a/harness/determined/common/api/_session.py b/harness/determined/common/api/_session.py index e7c5f306e6e..5baeccbe365 100644 --- a/harness/determined/common/api/_session.py +++ b/harness/determined/common/api/_session.py @@ -281,6 +281,8 @@ class Session(BaseSession): By far, most BaseSessions in the codebase will be this Session subclass. """ + AUTH_HEADER = "Authorization" + def __init__( self, master: str, @@ -308,10 +310,27 @@ def _make_http_session(self) -> requests.Session: server_hostname=self.cert.name if self.cert else None, verify=self.cert.bundle if self.cert else None, max_retries=self._max_retries, - headers={"Authorization": f"Bearer {self.token}"}, + headers={self.AUTH_HEADER: f"Bearer {self.token}"}, ) +class TaskSession(Session): + """ + ``TaskSession`` is a subclass of ``Session`` designed to be used for authenticating requests + using a task session token. It simply overrides the authentication header name used for + requests. + + Most sessions that are created from user input should use ``Session`` instead, which + authenticates requests using a user token (i.e. the CLI, SDK). + + Task session tokens really only have a longer expiration, and should be used for internal + sessions that may persist throughout a long training job (i.e. Core API). 
+ """ + + # Authentication header name for task session tokens + AUTH_HEADER = "Grpc-Metadata-x-allocation-token" + + class _HTTPSAdapter(adapters.HTTPAdapter): """Overrides the hostname checked against for TLS verification. diff --git a/harness/determined/common/api/authentication.py b/harness/determined/common/api/authentication.py index 7eea446c6b0..50109125b80 100644 --- a/harness/determined/common/api/authentication.py +++ b/harness/determined/common/api/authentication.py @@ -101,6 +101,10 @@ def get_det_password_from_env() -> Optional[str]: return os.environ.get("DET_PASS") +def get_det_session_token_from_env() -> Optional[str]: + return os.environ.get("DET_SESSION_TOKEN") + + def login( master_address: str, username: str, @@ -307,6 +311,27 @@ def logout_all(master_address: str, cert: Optional[certs.Cert]) -> None: logout(master_address, user, cert) +def login_from_task( + master_address: str, + cert: Optional[certs.Cert], +) -> "api.TaskSession": + """ + Creates a ``TaskSession`` from environment variables to be used for authenticating subsequent + requests. + + This method should only be called on-cluster, from inside a task container. + """ + session_token = get_det_session_token_from_env() + if not session_token: + raise ValueError("DET_SESSION_TOKEN environment variable not set.") + + username = get_det_username_from_env() + if not username: + raise ValueError("DET_USER environment variable not set.") + + return api.TaskSession(master=master_address, username=username, token=session_token, cert=cert) + + def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool: """ Find out whether the given token is valid by attempting to use it diff --git a/harness/determined/core/_context.py b/harness/determined/core/_context.py index 827a355b952..1e5f200f6db 100644 --- a/harness/determined/core/_context.py +++ b/harness/determined/core/_context.py @@ -242,9 +242,10 @@ def init( # We are on the cluster. 
cert = certs.default_load(info.master_url) - session = authentication.login_with_cache(info.master_url, cert=cert).with_retry( - util.get_max_retries_config() - ) + session = authentication.login_from_task( + master_address=info.master_url, + cert=cert, + ).with_retry(util.get_max_retries_config()) if distributed is None: if len(info.container_addrs) > 1 or len(info.slot_ids) > 1: diff --git a/harness/determined/exec/gc_checkpoints.py b/harness/determined/exec/gc_checkpoints.py index 8baa28e0370..0693962f39f 100644 --- a/harness/determined/exec/gc_checkpoints.py +++ b/harness/determined/exec/gc_checkpoints.py @@ -27,7 +27,7 @@ def patch_checkpoints(storage_ids_to_resources: Dict[str, Dict[str, int]]) -> No cert = certs.default_load(info.master_url) # With backoff retries for 64 seconds - sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry( + sess = authentication.login_from_task(info.master_url, cert=cert).with_retry( urllib3.util.retry.Retry(total=6, backoff_factor=0.5) ) diff --git a/harness/determined/exec/launch.py b/harness/determined/exec/launch.py index 9ef077a119d..96502ed2c78 100644 --- a/harness/determined/exec/launch.py +++ b/harness/determined/exec/launch.py @@ -22,7 +22,7 @@ def trigger_preemption(signum: int, frame: types.FrameType) -> None: logger.info("SIGTERM: Preemption imminent.") # Notify the master that we need to be preempted cert = certs.default_load(info.master_url) - sess = authentication.login_with_cache(info.master_url, cert=cert) + sess = authentication.login_from_task(info.master_url, cert=cert) sess.post(f"/api/v1/allocations/{info.allocation_id}/signals/pending_preemption") diff --git a/harness/determined/exec/prep_container.py b/harness/determined/exec/prep_container.py index 18e27e76584..392bcdda1e7 100644 --- a/harness/determined/exec/prep_container.py +++ b/harness/determined/exec/prep_container.py @@ -388,7 +388,7 @@ def do_proxy(sess: api.Session, allocation_id: str) -> None: cert = 
certs.default_load(info.master_url) # With backoff retries for 64 seconds - sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry( + sess = authentication.login_from_task(info.master_url, cert=cert).with_retry( urllib3.util.retry.Retry(total=6, backoff_factor=0.5) ) diff --git a/harness/determined/experimental/core_v2/_core_context_v2.py b/harness/determined/experimental/core_v2/_core_context_v2.py index f3cb3623813..3a0c4c624a2 100644 --- a/harness/determined/experimental/core_v2/_core_context_v2.py +++ b/harness/determined/experimental/core_v2/_core_context_v2.py @@ -6,7 +6,7 @@ import determined as det from determined import core, experimental, tensorboard -from determined.common import constants, storage, util +from determined.common import api, constants, storage, util from determined.common.api import authentication, certs logger = logging.getLogger("determined.core") @@ -43,9 +43,9 @@ def _make_v2_context( # We are on the cluster. cert = certs.default_load(info.master_url) - session = authentication.login_with_cache(info.master_url, cert=cert).with_retry( - util.get_max_retries_config() - ) + session: api.Session = authentication.login_from_task( + info.master_url, cert=cert + ).with_retry(util.get_max_retries_config()) else: unmanaged = True diff --git a/harness/determined/launch/deepspeed.py b/harness/determined/launch/deepspeed.py index 347306a1ff1..6cff0c91d91 100644 --- a/harness/determined/launch/deepspeed.py +++ b/harness/determined/launch/deepspeed.py @@ -286,7 +286,7 @@ def main(script: List[str]) -> int: # Mark sshd containers as daemon containers that the master should kill when all non-daemon # containers (deepspeed launcher, in this case) have exited. 
cert = certs.default_load(info.master_url) - sess = authentication.login_with_cache(info.master_url, cert=cert) + sess = authentication.login_from_task(info.master_url, cert=cert) sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon") # Wrap it in a pid_server to ensure that we can't hang if a worker fails. diff --git a/harness/determined/launch/horovod.py b/harness/determined/launch/horovod.py index a30f5bce83a..899a290c364 100644 --- a/harness/determined/launch/horovod.py +++ b/harness/determined/launch/horovod.py @@ -128,7 +128,7 @@ def main(hvd_args: List[str], script: List[str], autohorovod: bool) -> int: # Mark sshd containers as daemon resources that the master should kill when all non-daemon # containers (horovodrun, in this case) have exited. cert = certs.default_load(info.master_url) - sess = authentication.login_with_cache(info.master_url, cert=cert) + sess = authentication.login_from_task(info.master_url, cert=cert) sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon") pid_server_cmd, run_sshd_command = create_sshd_worker_cmd( diff --git a/harness/tests/cli/test_auth.py b/harness/tests/cli/test_auth.py index 7a7cd6f66e0..7209a7716db 100644 --- a/harness/tests/cli/test_auth.py +++ b/harness/tests/cli/test_auth.py @@ -10,7 +10,7 @@ from determined.cli import cli from determined.common import api -from determined.common.api import authentication +from determined.common.api import authentication, certs from tests.cli import util MOCK_MASTER_URL = "http://localhost:8080" @@ -439,3 +439,28 @@ def test_logout_all() -> None: mts.clear_active() cli.main(["user", "logout", "--all"]) + + +def test_login_from_task() -> None: + mock_session_token = "abcde12345" + mock_user = "abababa" + mock_cert = certs.Cert() + with contextlib.ExitStack() as es: + # Configure environment variables. 
+ es.enter_context(util.setenv_optional("DET_SESSION_TOKEN", mock_session_token)) + es.enter_context(util.setenv_optional("DET_USER", mock_user)) + + with responses.RequestsMock( + registry=registries.OrderedRegistry, assert_all_requests_are_fired=True + ) as rsps: + sess = authentication.login_from_task(master_address=MOCK_MASTER_URL, cert=mock_cert) + assert sess.token == mock_session_token + assert sess.username == mock_user + assert sess.cert == mock_cert + + rsps.get( + f"{MOCK_MASTER_URL}/api/v1/me", + status=200, + match=[matchers.header_matcher({sess.AUTH_HEADER: f"Bearer {mock_session_token}"})], + ) + sess.get("/api/v1/me") diff --git a/harness/tests/launch/test_deepspeed.py b/harness/tests/launch/test_deepspeed.py index c690a7dc6f8..c5ee63015db 100644 --- a/harness/tests/launch/test_deepspeed.py +++ b/harness/tests/launch/test_deepspeed.py @@ -75,7 +75,7 @@ def mock_process(cmd: List[str], *args: Any, **kwargs: Any) -> Any: mock_subprocess.side_effect = mock_process - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): deeplaunch.main(script) mock_cluster_info.assert_called_once() @@ -149,7 +149,7 @@ def mock_process(cmd: List[str], *args: Any, **kwargs: Any) -> Any: mock_subprocess.side_effect = mock_process - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): with pytest.raises(ValueError, match="no sshd greeting"): deeplaunch.main(script) @@ -195,7 +195,7 @@ def test_launch_one_slot( log_redirect_cmd = deeplaunch.create_log_redirect_cmd() launch_cmd = pid_server_cmd + deepspeed_cmd + pid_client_cmd + log_redirect_cmd + script - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): deeplaunch.main(script) mock_cluster_info.assert_called_once() @@ -223,7 +223,7 @@ def test_launch_fail(mock_cluster_info: mock.MagicMock, mock_subprocess: mock.Ma log_redirect_cmd = 
deeplaunch.create_log_redirect_cmd() launch_cmd = pid_server_cmd + deepspeed_cmd + pid_client_cmd + log_redirect_cmd + script - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): assert deeplaunch.main(script) == 1 mock_cluster_info.assert_called_once() @@ -235,7 +235,7 @@ def test_launch_fail(mock_cluster_info: mock.MagicMock, mock_subprocess: mock.Ma @mock.patch("subprocess.Popen") @mock.patch("determined.get_cluster_info") -@mock.patch("determined.common.api.authentication.login_with_cache") +@mock.patch("determined.common.api.authentication.login_from_task") def test_launch_worker( mock_login: mock.MagicMock, mock_cluster_info: mock.MagicMock, @@ -245,7 +245,13 @@ def test_launch_worker( mock_cluster_info.return_value = cluster_info mock_session = mock.MagicMock() mock_login.return_value = mock_session - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars( + { + "DET_RESOURCES_ID": "resourcesId", + "DET_SESSION_TOKEN": cluster_info.session_token, + "DET_USER": "abcd", + } + ): deeplaunch.main(["script"]) mock_cluster_info.assert_called_once() diff --git a/harness/tests/launch/test_horovod.py b/harness/tests/launch/test_horovod.py index 18735322cb0..1da62f3b99c 100644 --- a/harness/tests/launch/test_horovod.py +++ b/harness/tests/launch/test_horovod.py @@ -86,7 +86,7 @@ def test_horovod_chief( os.environ.pop("DET_CHIEF_IP", None) os.environ.pop("USE_HOROVOD", None) - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): assert launch.horovod.main(hvd_args, script, autohorovod) == 99 if autohorovod and nnodes == 1 and nslots == 1: @@ -116,7 +116,7 @@ def test_horovod_chief( @mock.patch("subprocess.Popen") @mock.patch("determined.get_cluster_info") -@mock.patch("determined.common.api.authentication.login_with_cache") +@mock.patch("determined.common.api.authentication.login_from_task") def test_sshd_worker( mock_login: 
mock.MagicMock, mock_cluster_info: mock.MagicMock, @@ -143,7 +143,13 @@ def test_sshd_worker( os.environ.pop("DET_CHIEF_IP", None) - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars( + { + "DET_RESOURCES_ID": "resourcesId", + "DET_SESSION_TOKEN": info.session_token, + "DET_USER": "abcd", + } + ): assert launch.horovod.main(hvd_args, script, True) == 99 mock_cluster_info.assert_called_once() diff --git a/harness/tests/launch/test_launch.py b/harness/tests/launch/test_launch.py index 0a3cde56e08..9bd29b0c7aa 100644 --- a/harness/tests/launch/test_launch.py +++ b/harness/tests/launch/test_launch.py @@ -42,7 +42,7 @@ def test_launch_list() -> None: @mock.patch("determined.common.storage.validate_config") def test_launch_script(mock_validate_config: mock.MagicMock) -> None: # Use runpy to actually run the whole launch script. - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): with test_util.set_mock_cluster_info(["0.0.0.1"], 0, 1) as info: # Successful entrypoints exit 0. 
info.trial._config["entrypoint"] = ["true"] diff --git a/harness/tests/launch/test_torch_distributed.py b/harness/tests/launch/test_torch_distributed.py index 8e4cbabcf0c..9a09c0f73a6 100644 --- a/harness/tests/launch/test_torch_distributed.py +++ b/harness/tests/launch/test_torch_distributed.py @@ -38,7 +38,7 @@ def test_launch_single_slot( script = ["python3", "-m", "determined.exec.harness", "my_module:MyTrial"] override_args = ["--max_restarts", "1"] - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): launch.torch_distributed.main(override_args, script) launch_cmd = launch.torch_distributed.create_pid_server_cmd( @@ -78,7 +78,7 @@ def test_launch_distributed( mock_subprocess.return_value = mock_proc - with test_util.set_resources_id_env_var(): + with test_util.set_env_vars({"DET_RESOURCES_ID": "resourcesId"}): assert launch.torch_distributed.main(override_args, script) == mock_success_code launch_cmd = launch.torch_distributed.create_pid_server_cmd( diff --git a/harness/tests/launch/test_util.py b/harness/tests/launch/test_util.py index 3186c1ae8d6..319a20dde04 100644 --- a/harness/tests/launch/test_util.py +++ b/harness/tests/launch/test_util.py @@ -64,12 +64,14 @@ def set_mock_cluster_info( @contextlib.contextmanager -def set_resources_id_env_var() -> Iterator[None]: +def set_env_vars(env_vars: Dict[str, str]) -> Iterator[None]: try: - os.environ["DET_RESOURCES_ID"] = "resourcesId" + for k, v in env_vars.items(): + os.environ[k] = v yield finally: - del os.environ["DET_RESOURCES_ID"] + for k in env_vars: + del os.environ[k] def parse_args_check(positive_cases: Dict, negative_cases: Dict, parse_func: Callable) -> None: diff --git a/master/static/srv/check_idle.py b/master/static/srv/check_idle.py index 9949065b1ae..f80c4986655 100755 --- a/master/static/srv/check_idle.py +++ b/master/static/srv/check_idle.py @@ -83,7 +83,7 @@ def main(): notebook_server = 
f"https://127.0.0.1:{port}/proxy/{notebook_id}" master_url = api.canonicalize_master_url(os.environ["DET_MASTER"]) cert = certs.default_load(master_url) - sess = authentication.login_with_cache(master_url, cert=cert) + sess = authentication.login_from_task(master_url, cert=cert) try: idle_type = IdleType[os.environ["NOTEBOOK_IDLE_TYPE"].upper()] except KeyError: diff --git a/master/static/srv/check_ready_logs.py b/master/static/srv/check_ready_logs.py index ee409b10b19..51e1a4ba186 100644 --- a/master/static/srv/check_ready_logs.py +++ b/master/static/srv/check_ready_logs.py @@ -37,7 +37,7 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None): cert = certs.default_load(master_url) # This only runs on-cluster, so it is expected the username and session token are present in the # environment. - sess = authentication.login_with_cache(master_url, cert=cert) + sess = authentication.login_from_task(master_url, cert=cert) allocation_id = str(os.environ["DET_ALLOCATION_ID"]) for line in sys.stdin: if ready.match(line): @@ -49,11 +49,15 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Read STDIN for a match and mark a task as ready") + parser = argparse.ArgumentParser( + description="Read STDIN for a match and mark a task as ready" + ) parser.add_argument( "--ready-regex", type=str, help="the pattern to match task ready", required=True ) - parser.add_argument("--waiting-regex", type=str, help="the pattern to match task waiting") + parser.add_argument( + "--waiting-regex", type=str, help="the pattern to match task waiting" + ) args = parser.parse_args() ready_regex = re.compile(args.ready_regex)