From 641977d66a4757b60e11b01f05b611af39328123 Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 5 Jan 2024 16:50:53 -0700 Subject: [PATCH] fix: don't drop active_user on expired tokens [MLG-1494] The authentication logic would drop expired tokens, and then ask for a password and, if that worked, add the fresh token back to the token store. However, the initial token drop had a side-effect of dropping the active user, which was not being restored. After this fix, the TokenStore doesn't try to be so smart. Instead, the logout-related CLI commands will explicitly clear the active user, but tokens dropped due to expiration are unaffected. --- .../determined/common/api/authentication.py | 17 ++++- harness/tests/cli/test_auth.py | 66 +------------------ harness/tests/cli/util.py | 13 +++- .../tests/{cli => common/api}/auth_v0.json | 0 .../tests/common/api/test_authentication.py | 58 ++++++++++++++++ harness/tests/common/api/test_certs.py | 33 ++++++++++ 6 files changed, 119 insertions(+), 68 deletions(-) rename harness/tests/{cli => common/api}/auth_v0.json (100%) create mode 100644 harness/tests/common/api/test_authentication.py create mode 100644 harness/tests/common/api/test_certs.py diff --git a/harness/determined/common/api/authentication.py b/harness/determined/common/api/authentication.py index 3f55e44768b..eff03425e15 100644 --- a/harness/determined/common/api/authentication.py +++ b/harness/determined/common/api/authentication.py @@ -204,6 +204,9 @@ def logout( ) -> None: """ Logout if there is an active session for this master/username pair, otherwise do nothing. + + Additionally, if the user happens to be the active user, drop the active user from the + TokenStore. """ master_address = master_address or util.get_default_master_address() @@ -215,6 +218,9 @@ def logout( if session_user is None: return + if session_user == token_store.get_active_user(): + token_store.clear_active() + session_token = token_store.get_token(session_user) if session_token is None: @@ -240,6 +246,8 @@ def logout_all(master_address: Optional[str], cert: Optional[certs.Cert]) -> Non for user in users: logout(master_address, user, cert) + token_store.clear_active() + def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool: """ @@ -258,7 +266,8 @@ def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) class TokenStore: """ TokenStore is a class for reading/updating a persistent store of user authentication tokens. - TokenStore can remember tokens for many users for each of many masters. + TokenStore can remember tokens for many users for each of many masters. It can also remembers + one "active user" for each master, which is set via `det user login`. All updates to the file follow a read-modify-write pattern, and use file locks to protect the integrity of the underlying file cache. @@ -304,8 +313,6 @@ def drop_user(self, username: str) -> None: tokens = substore.setdefault("tokens", {}) if username in tokens: del tokens[username] - if substore.get("active_user") == username: - del substore["active_user"] def set_token(self, username: str, token: str) -> None: with self._persistent_store() as substore: @@ -319,6 +326,10 @@ def set_active(self, username: str) -> None: raise api.errors.UnauthenticatedException(username=username) substore["active_user"] = username + def clear_active(self) -> None: + with self._persistent_store() as substore: + substore.pop("active_user", None) + @contextlib.contextmanager def _persistent_store(self) -> Iterator[Dict[str, Any]]: """ diff --git a/harness/tests/cli/test_auth.py b/harness/tests/cli/test_auth.py index 8a32bef881b..fd68411d7c3 100644 --- a/harness/tests/cli/test_auth.py +++ b/harness/tests/cli/test_auth.py @@ -1,9 +1,6 @@ import collections import contextlib import itertools -import json -import shutil -from pathlib import Path from typing import Any, Iterator, List, Tuple, no_type_check from unittest import mock @@ -12,25 +9,10 @@ from responses import matchers, registries from determined.cli import cli -from determined.common.api import authentication, certs +from determined.common.api import authentication from tests.cli import util -from tests.confdir import use_test_config_dir MOCK_MASTER_URL = "http://localhost:8080" -AUTH_V0_PATH = Path(__file__).parent / "auth_v0.json" -UNTRUSTED_CERT_PATH = Path(__file__).parents[1] / "common" / "untrusted-root" / "127.0.0.1-ca.crt" -AUTH_JSON = { - "version": 1, - "masters": { - "http://localhost:8080": { - "active_user": "bob", - "tokens": { - "determined": "det.token", - "bob": "bob.token", - }, - } - }, -} class ScenarioSetMeta(type): @@ -342,6 +324,7 @@ def __init__(self, user: str) -> None: self.user = user def expect(self, rsps: responses.RequestsMock, mts: util.MockTokenStore, scenario: Any) -> None: + mts.get_active_user(retval=None) mts.get_token(self.user, retval="cache_token" if scenario.user_in_cache else None) @@ -420,48 +403,3 @@ def test_logout(scenario_set: Logout) -> None: else: cmd = ["user", "logout"] cli.main(cmd) - - -def test_auth_json_v0_upgrade() -> None: - with use_test_config_dir() as config_dir: - auth_json_path = config_dir / "auth.json" - shutil.copy2(AUTH_V0_PATH, auth_json_path) - ts = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path) - - assert ts.get_active_user() == "determined" - assert ts.get_token("determined") == "v2.public.this.is.a.test" - - ts.set_token("determined", "ai") - - ts2 = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path) - assert ts2.get_token("determined") == "ai" - - with auth_json_path.open() as fin: - data = json.load(fin) - assert data.get("version") == 1 - assert "masters" in data and list(data["masters"].keys()) == [MOCK_MASTER_URL] - - -def test_cert_v0_upgrade() -> None: - with use_test_config_dir() as config_dir: - cert_path = config_dir / "master.crt" - shutil.copy2(UNTRUSTED_CERT_PATH, cert_path) - with cert_path.open() as fin: - cert_data = fin.read() - - cert = certs.default_load(MOCK_MASTER_URL) - assert isinstance(cert.bundle, str) - with open(cert.bundle) as fin: - loaded_cert_data = fin.read() - assert loaded_cert_data.endswith(cert_data) - assert not cert_path.exists() - - v1_certs_path = config_dir / "certs.json" - assert v1_certs_path.exists() - - # Load once again from v1. - cert2 = certs.default_load(MOCK_MASTER_URL) - assert isinstance(cert2.bundle, str) - with open(cert2.bundle) as fin: - loaded_cert_data = fin.read() - assert loaded_cert_data.endswith(cert_data) diff --git a/harness/tests/cli/util.py b/harness/tests/cli/util.py index 0b35c346e47..cbe674c68fe 100644 --- a/harness/tests/cli/util.py +++ b/harness/tests/cli/util.py @@ -56,7 +56,11 @@ def _match_call(self, call: Any) -> Any: if self._strict: if self._ncalls == len(self._exp_calls): raise ValueError(f"unexpected call to TokenStore: {call}") - assert self._exp_calls[self._ncalls] == call + if self._exp_calls[self._ncalls] != call: + raise ValueError( + f"mismstached call to TokenStore: expected {self._exp_calls[self._ncalls]} " + f"but got {call}" + ) retval = self._retvals[self._ncalls] self._ncalls += 1 return retval @@ -91,6 +95,10 @@ def set_active(self, username: str) -> None: self._exp_calls.append(("set_active", username)) self._retvals.append(None) + def clear_active(self) -> None: + self._exp_calls.append("clear_active") + self._retvals.append(None) + class MockTokenStoreInstance: def __init__(self, mts: MockTokenStore) -> None: @@ -114,6 +122,9 @@ def set_token(self, username: str, token: str) -> None: def set_active(self, username: str) -> None: self._mts._match_call(("set_active", username)) + def clear_active(self) -> None: + self._mts._match_call("clear_active") + @contextlib.contextmanager def standard_cli_rsps() -> Iterator[responses.RequestsMock]: diff --git a/harness/tests/cli/auth_v0.json b/harness/tests/common/api/auth_v0.json similarity index 100% rename from harness/tests/cli/auth_v0.json rename to harness/tests/common/api/auth_v0.json diff --git a/harness/tests/common/api/test_authentication.py b/harness/tests/common/api/test_authentication.py new file mode 100644 index 00000000000..b44615a0d51 --- /dev/null +++ b/harness/tests/common/api/test_authentication.py @@ -0,0 +1,58 @@ +import contextlib +import json +import pathlib +import shutil +from typing import Optional + +import pytest +import responses +from responses import registries + +from determined.common.api import authentication +from tests import confdir +from tests.cli import util + +MOCK_MASTER_URL = "http://localhost:8080" +AUTH_V0_PATH = pathlib.Path(__file__).parent / "auth_v0.json" + + +@pytest.mark.parametrize("active_user", ["alice", "bob", None]) +def test_logout_clears_active_user(active_user: Optional[str]) -> None: + with contextlib.ExitStack() as es: + es.enter_context(util.setenv_optional("DET_MASTER", MOCK_MASTER_URL)) + rsps = es.enter_context( + responses.RequestsMock( + registry=registries.OrderedRegistry, + assert_all_requests_are_fired=True, + ) + ) + mts = es.enter_context(util.MockTokenStore(strict=True)) + + mts.get_active_user(retval=active_user) + if active_user == "alice": + mts.clear_active() + mts.get_token("alice", retval="token") + mts.drop_user("alice") + rsps.post(f"{MOCK_MASTER_URL}/api/v1/auth/logout", status=200) + + authentication.logout(MOCK_MASTER_URL, "alice", None) + + +def test_auth_json_v0_upgrade() -> None: + with confdir.use_test_config_dir() as config_dir: + auth_json_path = config_dir / "auth.json" + shutil.copy2(AUTH_V0_PATH, auth_json_path) + ts = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path) + + assert ts.get_active_user() == "determined" + assert ts.get_token("determined") == "v2.public.this.is.a.test" + + ts.set_token("determined", "ai") + + ts2 = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path) + assert ts2.get_token("determined") == "ai" + + with auth_json_path.open() as fin: + data = json.load(fin) + assert data.get("version") == 1 + assert "masters" in data and list(data["masters"].keys()) == [MOCK_MASTER_URL] diff --git a/harness/tests/common/api/test_certs.py b/harness/tests/common/api/test_certs.py new file mode 100644 index 00000000000..36c50c568f3 --- /dev/null +++ b/harness/tests/common/api/test_certs.py @@ -0,0 +1,33 @@ +import pathlib +import shutil + +from determined.common.api import certs +from tests import confdir + +MOCK_MASTER_URL = "http://localhost:8080" +UNTRUSTED_CERT_PATH = pathlib.Path(__file__).parents[1] / "untrusted-root" / "127.0.0.1-ca.crt" + + +def test_cert_v0_upgrade() -> None: + with confdir.use_test_config_dir() as config_dir: + cert_path = config_dir / "master.crt" + shutil.copy2(UNTRUSTED_CERT_PATH, cert_path) + with cert_path.open() as fin: + cert_data = fin.read() + + cert = certs.default_load(MOCK_MASTER_URL) + assert isinstance(cert.bundle, str) + with open(cert.bundle) as fin: + loaded_cert_data = fin.read() + assert loaded_cert_data.endswith(cert_data) + assert not cert_path.exists() + + v1_certs_path = config_dir / "certs.json" + assert v1_certs_path.exists() + + # Load once again from v1. + cert2 = certs.default_load(MOCK_MASTER_URL) + assert isinstance(cert2.bundle, str) + with open(cert2.bundle) as fin: + loaded_cert_data = fin.read() + assert loaded_cert_data.endswith(cert_data)