Skip to content

Commit

Permalink
fix: don't drop active_user on expired tokens [MLG-1494] (#8653)
Browse files Browse the repository at this point in the history
The authentication logic would drop expired tokens, and then ask for a
password and, if that worked, add the fresh token back to the token
store.

However, the initial token drop had a side-effect of dropping the active
user, which was not being restored.

After this fix, the TokenStore doesn't try to be so smart.  Instead, the
logout-related CLI commands will explicitly clear the active user, but
tokens dropped due to expiration are unaffected.
  • Loading branch information
rb-determined-ai authored Jan 10, 2024
1 parent 9dc4fc8 commit fd4c1a0
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 68 deletions.
17 changes: 14 additions & 3 deletions harness/determined/common/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ def logout(
) -> None:
"""
Logout if there is an active session for this master/username pair, otherwise do nothing.
Additionally, if the user happens to be the active user, drop the active user from the
TokenStore.
"""

master_address = master_address or util.get_default_master_address()
Expand All @@ -215,6 +218,9 @@ def logout(
if session_user is None:
return

if session_user == token_store.get_active_user():
token_store.clear_active()

session_token = token_store.get_token(session_user)

if session_token is None:
Expand All @@ -240,6 +246,8 @@ def logout_all(master_address: Optional[str], cert: Optional[certs.Cert]) -> Non
for user in users:
logout(master_address, user, cert)

token_store.clear_active()


def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool:
"""
Expand All @@ -258,7 +266,8 @@ def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert])
class TokenStore:
"""
TokenStore is a class for reading/updating a persistent store of user authentication tokens.
TokenStore can remember tokens for many users for each of many masters.
TokenStore can remember tokens for many users for each of many masters. It can also remembers
one "active user" for each master, which is set via `det user login`.
All updates to the file follow a read-modify-write pattern, and use file locks to protect the
integrity of the underlying file cache.
Expand Down Expand Up @@ -304,8 +313,6 @@ def drop_user(self, username: str) -> None:
tokens = substore.setdefault("tokens", {})
if username in tokens:
del tokens[username]
if substore.get("active_user") == username:
del substore["active_user"]

def set_token(self, username: str, token: str) -> None:
with self._persistent_store() as substore:
Expand All @@ -319,6 +326,10 @@ def set_active(self, username: str) -> None:
raise api.errors.UnauthenticatedException(username=username)
substore["active_user"] = username

def clear_active(self) -> None:
with self._persistent_store() as substore:
substore.pop("active_user", None)

@contextlib.contextmanager
def _persistent_store(self) -> Iterator[Dict[str, Any]]:
"""
Expand Down
66 changes: 2 additions & 64 deletions harness/tests/cli/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import collections
import contextlib
import itertools
import json
import shutil
from pathlib import Path
from typing import Any, Iterator, List, Tuple, no_type_check
from unittest import mock

Expand All @@ -12,25 +9,10 @@
from responses import matchers, registries

from determined.cli import cli
from determined.common.api import authentication, certs
from determined.common.api import authentication
from tests.cli import util
from tests.confdir import use_test_config_dir

MOCK_MASTER_URL = "http://localhost:8080"
AUTH_V0_PATH = Path(__file__).parent / "auth_v0.json"
UNTRUSTED_CERT_PATH = Path(__file__).parents[1] / "common" / "untrusted-root" / "127.0.0.1-ca.crt"
AUTH_JSON = {
"version": 1,
"masters": {
"http://localhost:8080": {
"active_user": "bob",
"tokens": {
"determined": "det.token",
"bob": "bob.token",
},
}
},
}


class ScenarioSetMeta(type):
Expand Down Expand Up @@ -342,6 +324,7 @@ def __init__(self, user: str) -> None:
self.user = user

def expect(self, rsps: responses.RequestsMock, mts: util.MockTokenStore, scenario: Any) -> None:
mts.get_active_user(retval=None)
mts.get_token(self.user, retval="cache_token" if scenario.user_in_cache else None)


Expand Down Expand Up @@ -420,48 +403,3 @@ def test_logout(scenario_set: Logout) -> None:
else:
cmd = ["user", "logout"]
cli.main(cmd)


def test_auth_json_v0_upgrade() -> None:
with use_test_config_dir() as config_dir:
auth_json_path = config_dir / "auth.json"
shutil.copy2(AUTH_V0_PATH, auth_json_path)
ts = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path)

assert ts.get_active_user() == "determined"
assert ts.get_token("determined") == "v2.public.this.is.a.test"

ts.set_token("determined", "ai")

ts2 = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path)
assert ts2.get_token("determined") == "ai"

with auth_json_path.open() as fin:
data = json.load(fin)
assert data.get("version") == 1
assert "masters" in data and list(data["masters"].keys()) == [MOCK_MASTER_URL]


def test_cert_v0_upgrade() -> None:
with use_test_config_dir() as config_dir:
cert_path = config_dir / "master.crt"
shutil.copy2(UNTRUSTED_CERT_PATH, cert_path)
with cert_path.open() as fin:
cert_data = fin.read()

cert = certs.default_load(MOCK_MASTER_URL)
assert isinstance(cert.bundle, str)
with open(cert.bundle) as fin:
loaded_cert_data = fin.read()
assert loaded_cert_data.endswith(cert_data)
assert not cert_path.exists()

v1_certs_path = config_dir / "certs.json"
assert v1_certs_path.exists()

# Load once again from v1.
cert2 = certs.default_load(MOCK_MASTER_URL)
assert isinstance(cert2.bundle, str)
with open(cert2.bundle) as fin:
loaded_cert_data = fin.read()
assert loaded_cert_data.endswith(cert_data)
13 changes: 12 additions & 1 deletion harness/tests/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def _match_call(self, call: Any) -> Any:
if self._strict:
if self._ncalls == len(self._exp_calls):
raise ValueError(f"unexpected call to TokenStore: {call}")
assert self._exp_calls[self._ncalls] == call
if self._exp_calls[self._ncalls] != call:
raise ValueError(
f"mismstached call to TokenStore: expected {self._exp_calls[self._ncalls]} "
f"but got {call}"
)
retval = self._retvals[self._ncalls]
self._ncalls += 1
return retval
Expand Down Expand Up @@ -91,6 +95,10 @@ def set_active(self, username: str) -> None:
self._exp_calls.append(("set_active", username))
self._retvals.append(None)

def clear_active(self) -> None:
self._exp_calls.append("clear_active")
self._retvals.append(None)


class MockTokenStoreInstance:
def __init__(self, mts: MockTokenStore) -> None:
Expand All @@ -114,6 +122,9 @@ def set_token(self, username: str, token: str) -> None:
def set_active(self, username: str) -> None:
self._mts._match_call(("set_active", username))

def clear_active(self) -> None:
self._mts._match_call("clear_active")


@contextlib.contextmanager
def standard_cli_rsps() -> Iterator[responses.RequestsMock]:
Expand Down
File renamed without changes.
58 changes: 58 additions & 0 deletions harness/tests/common/api/test_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import contextlib
import json
import pathlib
import shutil
from typing import Optional

import pytest
import responses
from responses import registries

from determined.common.api import authentication
from tests import confdir
from tests.cli import util

MOCK_MASTER_URL = "http://localhost:8080"
AUTH_V0_PATH = pathlib.Path(__file__).parent / "auth_v0.json"


@pytest.mark.parametrize("active_user", ["alice", "bob", None])
def test_logout_clears_active_user(active_user: Optional[str]) -> None:
with contextlib.ExitStack() as es:
es.enter_context(util.setenv_optional("DET_MASTER", MOCK_MASTER_URL))
rsps = es.enter_context(
responses.RequestsMock(
registry=registries.OrderedRegistry,
assert_all_requests_are_fired=True,
)
)
mts = es.enter_context(util.MockTokenStore(strict=True))

mts.get_active_user(retval=active_user)
if active_user == "alice":
mts.clear_active()
mts.get_token("alice", retval="token")
mts.drop_user("alice")
rsps.post(f"{MOCK_MASTER_URL}/api/v1/auth/logout", status=200)

authentication.logout(MOCK_MASTER_URL, "alice", None)


def test_auth_json_v0_upgrade() -> None:
with confdir.use_test_config_dir() as config_dir:
auth_json_path = config_dir / "auth.json"
shutil.copy2(AUTH_V0_PATH, auth_json_path)
ts = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path)

assert ts.get_active_user() == "determined"
assert ts.get_token("determined") == "v2.public.this.is.a.test"

ts.set_token("determined", "ai")

ts2 = authentication.TokenStore(MOCK_MASTER_URL, auth_json_path)
assert ts2.get_token("determined") == "ai"

with auth_json_path.open() as fin:
data = json.load(fin)
assert data.get("version") == 1
assert "masters" in data and list(data["masters"].keys()) == [MOCK_MASTER_URL]
33 changes: 33 additions & 0 deletions harness/tests/common/api/test_certs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pathlib
import shutil

from determined.common.api import certs
from tests import confdir

MOCK_MASTER_URL = "http://localhost:8080"
UNTRUSTED_CERT_PATH = pathlib.Path(__file__).parents[1] / "untrusted-root" / "127.0.0.1-ca.crt"


def test_cert_v0_upgrade() -> None:
with confdir.use_test_config_dir() as config_dir:
cert_path = config_dir / "master.crt"
shutil.copy2(UNTRUSTED_CERT_PATH, cert_path)
with cert_path.open() as fin:
cert_data = fin.read()

cert = certs.default_load(MOCK_MASTER_URL)
assert isinstance(cert.bundle, str)
with open(cert.bundle) as fin:
loaded_cert_data = fin.read()
assert loaded_cert_data.endswith(cert_data)
assert not cert_path.exists()

v1_certs_path = config_dir / "certs.json"
assert v1_certs_path.exists()

# Load once again from v1.
cert2 = certs.default_load(MOCK_MASTER_URL)
assert isinstance(cert2.bundle, str)
with open(cert2.bundle) as fin:
loaded_cert_data = fin.read()
assert loaded_cert_data.endswith(cert_data)

0 comments on commit fd4c1a0

Please sign in to comment.