Skip to content

Commit

Permalink
fix: don't drop active_user on expired tokens [MLG-1494]
Browse files Browse the repository at this point in the history
The authentication logic would drop expired tokens, and then ask for a
password and, if that worked, add the fresh token back to the token
store.

However, the initial token drop had a side-effect of dropping the active
user, which was not being restored.

After this fix, the TokenStore doesn't try to be so smart.  Instead, the
logout-related CLI commands will explicitly clear the active user, but
tokens dropped due to expiration are unaffected.
  • Loading branch information
rb-determined-ai committed Jan 5, 2024
1 parent ab08345 commit d7357d7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
11 changes: 9 additions & 2 deletions harness/determined/common/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ def logout(
if session_user is None:
return

if session_user == token_store.get_active_user():
token_store.clear_active()

session_token = token_store.get_token(session_user)

if session_token is None:
Expand All @@ -240,6 +243,8 @@ def logout_all(master_address: Optional[str], cert: Optional[certs.Cert]) -> Non
for user in users:
logout(master_address, user, cert)

token_store.clear_active()


def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool:
"""
Expand Down Expand Up @@ -304,8 +309,6 @@ def drop_user(self, username: str) -> None:
tokens = substore.setdefault("tokens", {})
if username in tokens:
del tokens[username]
if substore.get("active_user") == username:
del substore["active_user"]

def set_token(self, username: str, token: str) -> None:
with self._persistent_store() as substore:
Expand All @@ -319,6 +322,10 @@ def set_active(self, username: str) -> None:
raise api.errors.UnauthenticatedException(username=username)
substore["active_user"] = username

def clear_active(self) -> None:
with self._persistent_store() as substore:
substore.pop("active_user", None)

@contextlib.contextmanager
def _persistent_store(self) -> Iterator[Dict[str, Any]]:
"""
Expand Down
27 changes: 26 additions & 1 deletion harness/tests/cli/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import shutil
from pathlib import Path
from typing import Any, Iterator, List, Tuple, no_type_check
from typing import Any, Iterator, List, Optional, Tuple, no_type_check
from unittest import mock

import pytest
Expand Down Expand Up @@ -342,6 +342,7 @@ def __init__(self, user: str) -> None:
self.user = user

def expect(self, rsps: responses.RequestsMock, mts: util.MockTokenStore, scenario: Any) -> None:
mts.get_active_user(retval=None)
mts.get_token(self.user, retval="cache_token" if scenario.user_in_cache else None)


Expand Down Expand Up @@ -422,6 +423,30 @@ def test_logout(scenario_set: Logout) -> None:
cli.main(cmd)


@pytest.mark.parametrize("active_user", ["alice", "bob", None])
def test_logout_clears_active_user(active_user: Optional[str]) -> None:
with contextlib.ExitStack() as es:
es.enter_context(util.setenv_optional("DET_MASTER", MOCK_MASTER_URL))
rsps = es.enter_context(
responses.RequestsMock(
registry=registries.OrderedRegistry,
assert_all_requests_are_fired=True,
)
)
mts = es.enter_context(util.MockTokenStore(strict=True))

util.expect_get_info(rsps)

mts.get_active_user(retval=active_user)
if active_user == "alice":
mts.clear_active()
mts.get_token("alice", retval="token")
mts.drop_user("alice")
rsps.post(f"{MOCK_MASTER_URL}/api/v1/auth/logout", status=200)

cli.main(["-u", "alice", "user", "logout"])


def test_auth_json_v0_upgrade() -> None:
with use_test_config_dir() as config_dir:
auth_json_path = config_dir / "auth.json"
Expand Down
12 changes: 11 additions & 1 deletion harness/tests/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def _match_call(self, call: Any) -> Any:
if self._strict:
if self._ncalls == len(self._exp_calls):
raise ValueError(f"unexpected call to TokenStore: {call}")
assert self._exp_calls[self._ncalls] == call
if self._exp_calls[self._ncalls] != call:
raise ValueError(
f"mismstached call to TokenStore: expected {self._exp_calls[self._ncalls]} "
f"but got {call}"
)
retval = self._retvals[self._ncalls]
self._ncalls += 1
return retval
Expand Down Expand Up @@ -91,6 +95,9 @@ def set_active(self, username: str) -> None:
self._exp_calls.append(("set_active", username))
self._retvals.append(None)

def clear_active(self) -> None:
self._exp_calls.append("clear_active")
self._retvals.append(None)

class MockTokenStoreInstance:
def __init__(self, mts: MockTokenStore) -> None:
Expand All @@ -114,6 +121,9 @@ def set_token(self, username: str, token: str) -> None:
def set_active(self, username: str) -> None:
self._mts._match_call(("set_active", username))

def clear_active(self) -> None:
self._mts._match_call("clear_active")


@contextlib.contextmanager
def standard_cli_rsps() -> Iterator[responses.RequestsMock]:
Expand Down

0 comments on commit d7357d7

Please sign in to comment.