Skip to content

Commit

Permalink
fix: cli is not a library!
Browse files Browse the repository at this point in the history
Direct calls into the CLI are problematic, because the CLI uses a
variety of authentication techniques, some of which are only configured
by the cli's __main__() codepath.  That means that cli auth and cert
singletons have been configured all over e2e tests in order to support
direct CLI calls.

Avoiding direct calls to cli functions is a step in the direction of
cleaning up our authentication story.
  • Loading branch information
rb-determined-ai committed Sep 15, 2023
1 parent ff2e16d commit 591eedc
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 99 deletions.
31 changes: 17 additions & 14 deletions e2e_tests/tests/cluster/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import _socket
import pytest

from determined.cli import command
from determined.common import api
from determined.common.api import authentication, bindings, certs
from tests import config as conf
Expand Down Expand Up @@ -63,10 +62,10 @@ def test_trial_logs() -> None:
@pytest.mark.parametrize(
"task_type,task_config,log_regex",
[
(command.TaskTypeCommand, {"entrypoint": ["echo", "hello"]}, re.compile("^.*hello.*$")),
(command.TaskTypeNotebook, {}, re.compile("^.*Jupyter Server .* is running.*$")),
(command.TaskTypeShell, {}, re.compile("^.*Server listening on.*$")),
(command.TaskTypeTensorBoard, {}, re.compile("^.*TensorBoard .* at .*$")),
("command", {"entrypoint": ["echo", "hello"]}, re.compile("^.*hello.*$")),
("notebook", {}, re.compile("^.*Jupyter Server .* is running.*$")),
("shell", {}, re.compile("^.*Server listening on.*$")),
("tensorboard", {}, re.compile("^.*TensorBoard .* at .*$")),
],
)
def test_task_logs(task_type: str, task_config: Dict[str, Any], log_regex: Any) -> None:
Expand All @@ -78,28 +77,25 @@ def test_task_logs(task_type: str, task_config: Dict[str, Any], log_regex: Any)
rps = bindings.get_GetResourcePools(session)
assert rps.resourcePools and len(rps.resourcePools) > 0, "missing resource pool"

if (
rps.resourcePools[0].type == bindings.v1ResourcePoolType.K8S
and task_type == command.TaskTypeCommand
):
if rps.resourcePools[0].type == bindings.v1ResourcePoolType.K8S and task_type == "command":
# TODO(DET-6712): Investigate intermittent slowness with K8s command logs.
pytest.skip("DET-6712: Investigate intermittent slowness with K8s command logs")

if task_type == command.TaskTypeTensorBoard:
if task_type == "tensorboard":
exp_id = exp.run_basic_test(
conf.fixtures_path("no_op/single.yaml"),
conf.fixtures_path("no_op"),
1,
)
treq = bindings.v1LaunchTensorboardRequest(config=task_config, experimentIds=[exp_id])
task_id = bindings.post_LaunchTensorboard(session, body=treq).tensorboard.id
elif task_type == command.TaskTypeNotebook:
elif task_type == "notebook":
nreq = bindings.v1LaunchNotebookRequest(config=task_config)
task_id = bindings.post_LaunchNotebook(session, body=nreq).notebook.id
elif task_type == command.TaskTypeCommand:
elif task_type == "command":
creq = bindings.v1LaunchCommandRequest(config=task_config)
task_id = bindings.post_LaunchCommand(session, body=creq).command.id
elif task_type == command.TaskTypeShell:
elif task_type == "shell":
sreq = bindings.v1LaunchShellRequest(config=task_config)
task_id = bindings.post_LaunchShell(session, body=sreq).shell.id
else:
Expand All @@ -121,7 +117,14 @@ def task_log_fields(follow: Optional[bool] = None) -> Iterable[LogFields]:
raise TimeoutError(f"timed out waiting for {task_type} with id {task_id}")

finally:
command._kill(master_url, task_type, task_id)
if task_type == "tensorboard":
bindings.post_KillTensorboard(session, tensorboardId=task_id)
elif task_type == "notebook":
bindings.post_KillNotebook(session, notebookId=task_id)
elif task_type == "command":
bindings.post_KillCommand(session, commandId=task_id)
elif task_type == "shell":
bindings.post_KillShell(session, shellId=task_id)


def check_logs(
Expand Down
10 changes: 4 additions & 6 deletions e2e_tests/tests/cluster/test_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import pytest

from determined import cli
from determined.cli.user_groups import group_name_to_group_id, usernames_to_user_ids
from determined.common import api
from determined.common.api import authentication, bindings, errors
from tests import api_utils, utils
Expand Down Expand Up @@ -64,7 +62,7 @@ def create_users_with_gloabl_roles(user_roles: List[List[str]]) -> List[authenti
user = bindings.v1User(username=api_utils.get_random_string(), admin=False, active=True)
creds = api_utils.create_test_user(True, user=user)
for role in roles:
cli.rbac.assign_role(
api.assign_role(
utils.CliArgsMock(
username_to_assign=creds.username,
role_name=role,
Expand Down Expand Up @@ -110,7 +108,7 @@ def create_workspaces_with_users(
if rid not in rid_to_creds:
rid_to_creds[rid] = create_test_user()
for role in roles:
cli.rbac.assign_role(
api.assign_role(
utils.CliArgsMock(
username_to_assign=rid_to_creds[rid].username,
workspace_name=workspace.name,
Expand Down Expand Up @@ -576,8 +574,8 @@ def test_rbac_describe_role() -> None:
)

sess = api_utils.determined_test_session(ADMIN_CREDENTIALS)
user_id = usernames_to_user_ids(sess, [test_user_creds.username])[0]
group_id = group_name_to_group_id(sess, group_name)
user_id = api.usernames_to_user_ids(sess, [test_user_creds.username])[0]
group_id = api.group_name_to_group_id(sess, group_name)

det_cmd(
["rbac", "assign-role", "Viewer", "--username-to-assign", test_user_creds.username],
Expand Down
55 changes: 8 additions & 47 deletions harness/determined/cli/rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from determined.common import api
from determined.common.api import authentication, bindings
from determined.common.declarative_argparse import Arg, Cmd
from determined.common.experimental import session

rbac_flag_disabled_message = (
"RBAC commands require the Determined Enterprise Edition "
Expand Down Expand Up @@ -134,7 +133,7 @@ def list_roles(args: Namespace) -> None:


def role_with_assignment_to_dict(
session: session.Session,
session: api.Session,
r: bindings.v1RoleWithAssignments,
assignment: bindings.v1RoleAssignment,
) -> Dict[str, Any]:
Expand Down Expand Up @@ -318,54 +317,16 @@ def describe_role(args: Namespace) -> None:
print()


def create_assignment_request(
session: session.Session, args: Namespace
) -> Tuple[List[bindings.v1UserRoleAssignment], List[bindings.v1GroupRoleAssignment]]:
if (args.username_to_assign is None) == (args.group_name_to_assign is None):
raise api.errors.BadRequestException(
"must provide exactly one of --username-to-assign or --group-name-to-assign"
)

role = bindings.v1Role(roleId=role_name_to_role_id(session, args.role_name))

workspace_id = None
if args.workspace_name is not None:
workspace_id = workspace.workspace_by_name(session, args.workspace_name).id
role_assign = bindings.v1RoleAssignment(role=role, scopeWorkspaceId=workspace_id)

if args.username_to_assign is not None:
user_id = user_groups.usernames_to_user_ids(session, [args.username_to_assign])[0]
return [bindings.v1UserRoleAssignment(userId=user_id, roleAssignment=role_assign)], []

group_id = user_groups.group_name_to_group_id(session, args.group_name_to_assign)
return [], [bindings.v1GroupRoleAssignment(groupId=group_id, roleAssignment=role_assign)]


@authentication.required
@require_feature_flag("rbacEnabled", rbac_flag_disabled_message)
def assign_role(args: Namespace) -> None:
session = setup_session(args)
user_assign, group_assign = create_assignment_request(session, args)
req = bindings.v1AssignRolesRequest(
userRoleAssignments=user_assign, groupRoleAssignments=group_assign
api.assign_role(
session=setup_session(args),
role_name=args.role_name,
workspace_name=args.workspace_name,
userspace_to_assign=args.userspace_to_assign,
group_name_to_assign=args.group_name_to_assign,
)
bindings.post_AssignRoles(session, body=req)

scope = " globally"
if args.workspace_name:
scope = f" to workspace {args.workspace_name}"
if len(user_assign) > 0:
role_id = user_assign[0].roleAssignment.role.roleId
print(
f"assigned role '{args.role_name}' with ID {role_id} "
+ f"to user '{args.username_to_assign}' with ID {user_assign[0].userId}{scope}"
)
else:
role_id = group_assign[0].roleAssignment.role.roleId
print(
f"assigned role '{args.role_name}' with ID {role_id} "
+ f"to group '{args.group_name_to_assign}' with ID {group_assign[0].groupId}{scope}"
)


@authentication.required
Expand Down Expand Up @@ -393,7 +354,7 @@ def unassign_role(args: Namespace) -> None:
)


def role_name_to_role_id(session: session.Session, role_name: str) -> int:
def role_name_to_role_id(session: api.Session, role_name: str) -> int:
req = bindings.v1ListRolesRequest(limit=499, offset=0)
resp = bindings.post_ListRoles(session=session, body=req)
for r in resp.roles:
Expand Down
31 changes: 0 additions & 31 deletions harness/determined/cli/user_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,37 +145,6 @@ def delete_group(args: Namespace) -> None:
print("Skipping group deletion.")


def usernames_to_user_ids(session: Session, usernames: List[str]) -> List[int]:
usernames_to_ids: Dict[str, Optional[int]] = {u: None for u in usernames}
users = bindings.get_GetUsers(session).users or []
for user in users:
if user.username in usernames_to_ids:
usernames_to_ids[user.username] = user.id

missing_users = []
user_ids = []
for username, user_id in usernames_to_ids.items():
if user_id is None:
missing_users.append(username)
else:
user_ids.append(user_id)

if missing_users:
raise api.errors.BadRequestException(
f"could not find users for usernames {', '.join(missing_users)}"
)
return user_ids


def group_name_to_group_id(session: Session, group_name: str) -> int:
body = bindings.v1GetGroupsRequest(name=group_name, limit=1, offset=0)
resp = bindings.post_GetGroups(session, body=body)
groups = resp.groups
if groups is None or len(groups) != 1 or groups[0].group.groupId is None:
raise api.errors.BadRequestException(f"could not find user group name {group_name}")
return groups[0].group.groupId


args_description = [
Cmd(
"user-group",
Expand Down
3 changes: 3 additions & 0 deletions harness/determined/common/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
wait_for_ntsc_state,
task_is_ready,
NTSC_Kind,
assign_role,
usernames_to_user_ids,
group_name_to_group_id,
)
from determined.common.api.authentication import Authentication, salt_and_hash
from determined.common.api.logs import (
Expand Down
100 changes: 99 additions & 1 deletion harness/determined/common/api/_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import enum
from typing import Callable, Iterator, Optional, Set, Tuple, TypeVar, Union
from typing import Callable, Iterator, List, Optional, Set, Tuple, TypeVar, Union

import urllib3

Expand Down Expand Up @@ -135,3 +135,101 @@ def _task_is_done_loading() -> Tuple[bool, Optional[str]]:

err_msg = util.wait_for(_task_is_done_loading, timeout=300, interval=1)
return err_msg


def create_assignment_request(
session: api.Session,
role_name: str,
workspace_name: Optional[str] = None,
username_to_assign: Optional[str] = None,
group_name_to_assign: Optional[str] = None,
) -> Tuple[List[bindings.v1UserRoleAssignment], List[bindings.v1GroupRoleAssignment]]:
if (username_to_assign is None) == (group_name_to_assign is None):
# XXX: very weird that we use api.errors.BadRequestException here!
raise api.errors.BadRequestException(
"must provide exactly one of --username-to-assign or --group-name-to-assign"
)

role = bindings.v1Role(roleId=role_name_to_role_id(session, role_name))

workspace_id = None
if workspace_name is not None:
workspace_id = workspace.workspace_by_name(session, workspace_name).id
role_assign = bindings.v1RoleAssignment(role=role, scopeWorkspaceId=workspace_id)

if username_to_assign is not None:
user_id = user_groups.usernames_to_user_ids(session, [username_to_assign])[0]
return [bindings.v1UserRoleAssignment(userId=user_id, roleAssignment=role_assign)], []

group_id = user_groups.group_name_to_group_id(session, group_name_to_assign)
return [], [bindings.v1GroupRoleAssignment(groupId=group_id, roleAssignment=role_assign)]


def assign_role(
session: api.Session,
role_name: str,
workspace_name: Optional[str] = None,
username_to_assign: Optional[str] = None,
group_name_to_assign: Optional[str] = None,
) -> None:
"""
assign_role is a CLI endpoint that is also used in e2e tests.
"""
user_assign, group_assign = create_assignment_request(
session=session,
role_name=role_name,
workspace_name=workspace_name,
userspace_to_assign=userspace_to_assign,
group_name_to_assign=group_name_to_assign,
)
req = bindings.v1AssignRolesRequest(
userRoleAssignments=user_assign, groupRoleAssignments=group_assign
)
bindings.post_AssignRoles(session, body=req)

scope = " globally"
if workspace_name:
scope = f" to workspace {workspace_name}"
if len(user_assign) > 0:
role_id = user_assign[0].roleAssignment.role.roleId
print(
f"assigned role '{role_name}' with ID {role_id} "
+ f"to user '{username_to_assign}' with ID {user_assign[0].userId}{scope}"
)
else:
role_id = group_assign[0].roleAssignment.role.roleId
print(
f"assigned role '{role_name}' with ID {role_id} "
+ f"to group '{group_name_to_assign}' with ID {group_assign[0].groupId}{scope}"
)


def usernames_to_user_ids(session: api.Session, usernames: List[str]) -> List[int]:
usernames_to_ids: Dict[str, Optional[int]] = {u: None for u in usernames}
users = bindings.get_GetUsers(session).users or []
for user in users:
if user.username in usernames_to_ids:
usernames_to_ids[user.username] = user.id

missing_users = []
user_ids = []
for username, user_id in usernames_to_ids.items():
if user_id is None:
missing_users.append(username)
else:
user_ids.append(user_id)

if missing_users:
raise api.errors.BadRequestException(
f"could not find users for usernames {', '.join(missing_users)}"
)
return user_ids


def group_name_to_group_id(session: api.Session, group_name: str) -> int:
body = bindings.v1GetGroupsRequest(name=group_name, limit=1, offset=0)
resp = bindings.post_GetGroups(session, body=body)
groups = resp.groups
if groups is None or len(groups) != 1 or groups[0].group.groupId is None:
raise api.errors.BadRequestException(f"could not find user group name {group_name}")
return groups[0].group.groupId

0 comments on commit 591eedc

Please sign in to comment.