Skip to content

Commit

Permalink
fix: cli is not a library!
Browse files Browse the repository at this point in the history
Direct calls into the CLI are problematic, because the CLI uses a
variety of authentication techniques, some of which are only configured
by the cli's __main__() codepath.  That means that cli auth and cert
singletons have been configured all over e2e tests in order to support
direct CLI calls.

Avoiding direct calls to cli functions is a step in the direction of
cleaning up our authentication story.
  • Loading branch information
rb-determined-ai committed Sep 18, 2023
1 parent ff2e16d commit 92b36be
Show file tree
Hide file tree
Showing 13 changed files with 216 additions and 203 deletions.
31 changes: 17 additions & 14 deletions e2e_tests/tests/cluster/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import _socket
import pytest

from determined.cli import command
from determined.common import api
from determined.common.api import authentication, bindings, certs
from tests import config as conf
Expand Down Expand Up @@ -63,10 +62,10 @@ def test_trial_logs() -> None:
@pytest.mark.parametrize(
"task_type,task_config,log_regex",
[
(command.TaskTypeCommand, {"entrypoint": ["echo", "hello"]}, re.compile("^.*hello.*$")),
(command.TaskTypeNotebook, {}, re.compile("^.*Jupyter Server .* is running.*$")),
(command.TaskTypeShell, {}, re.compile("^.*Server listening on.*$")),
(command.TaskTypeTensorBoard, {}, re.compile("^.*TensorBoard .* at .*$")),
("command", {"entrypoint": ["echo", "hello"]}, re.compile("^.*hello.*$")),
("notebook", {}, re.compile("^.*Jupyter Server .* is running.*$")),
("shell", {}, re.compile("^.*Server listening on.*$")),
("tensorboard", {}, re.compile("^.*TensorBoard .* at .*$")),
],
)
def test_task_logs(task_type: str, task_config: Dict[str, Any], log_regex: Any) -> None:
Expand All @@ -78,28 +77,25 @@ def test_task_logs(task_type: str, task_config: Dict[str, Any], log_regex: Any)
rps = bindings.get_GetResourcePools(session)
assert rps.resourcePools and len(rps.resourcePools) > 0, "missing resource pool"

if (
rps.resourcePools[0].type == bindings.v1ResourcePoolType.K8S
and task_type == command.TaskTypeCommand
):
if rps.resourcePools[0].type == bindings.v1ResourcePoolType.K8S and task_type == "command":
# TODO(DET-6712): Investigate intermittent slowness with K8s command logs.
pytest.skip("DET-6712: Investigate intermittent slowness with K8s command logs")

if task_type == command.TaskTypeTensorBoard:
if task_type == "tensorboard":
exp_id = exp.run_basic_test(
conf.fixtures_path("no_op/single.yaml"),
conf.fixtures_path("no_op"),
1,
)
treq = bindings.v1LaunchTensorboardRequest(config=task_config, experimentIds=[exp_id])
task_id = bindings.post_LaunchTensorboard(session, body=treq).tensorboard.id
elif task_type == command.TaskTypeNotebook:
elif task_type == "notebook":
nreq = bindings.v1LaunchNotebookRequest(config=task_config)
task_id = bindings.post_LaunchNotebook(session, body=nreq).notebook.id
elif task_type == command.TaskTypeCommand:
elif task_type == "command":
creq = bindings.v1LaunchCommandRequest(config=task_config)
task_id = bindings.post_LaunchCommand(session, body=creq).command.id
elif task_type == command.TaskTypeShell:
elif task_type == "shell":
sreq = bindings.v1LaunchShellRequest(config=task_config)
task_id = bindings.post_LaunchShell(session, body=sreq).shell.id
else:
Expand All @@ -121,7 +117,14 @@ def task_log_fields(follow: Optional[bool] = None) -> Iterable[LogFields]:
raise TimeoutError(f"timed out waiting for {task_type} with id {task_id}")

finally:
command._kill(master_url, task_type, task_id)
if task_type == "tensorboard":
bindings.post_KillTensorboard(session, tensorboardId=task_id)
elif task_type == "notebook":
bindings.post_KillNotebook(session, notebookId=task_id)
elif task_type == "command":
bindings.post_KillCommand(session, commandId=task_id)
elif task_type == "shell":
bindings.post_KillShell(session, shellId=task_id)


def check_logs(
Expand Down
43 changes: 20 additions & 23 deletions e2e_tests/tests/cluster/test_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

import pytest

from determined import cli
from determined.cli.user_groups import group_name_to_group_id, usernames_to_user_ids
from determined.common import api
from determined.common.api import authentication, bindings, errors
from tests import api_utils, utils
from tests import api_utils
from tests.api_utils import configure_token_store, create_test_user, determined_test_session
from tests.cluster.test_workspace_org import setup_workspaces

Expand Down Expand Up @@ -59,18 +57,17 @@ def create_users_with_gloabl_roles(user_roles: List[List[str]]) -> List[authenti
user_roles: list of roles to assign to each user, one entry per user.
"""
user_creds: List[authentication.Credentials] = []
with logged_in_user(ADMIN_CREDENTIALS):
for roles in user_roles:
user = bindings.v1User(username=api_utils.get_random_string(), admin=False, active=True)
creds = api_utils.create_test_user(True, user=user)
for role in roles:
cli.rbac.assign_role(
utils.CliArgsMock(
username_to_assign=creds.username,
role_name=role,
)
)
user_creds.append(creds)
sess = api_utils.determined_test_session(admin=True)
for roles in user_roles:
user = bindings.v1User(username=api_utils.get_random_string(), admin=False, active=True)
creds = api_utils.create_test_user(True, user=user)
for role in roles:
api.assign_role(
session=sess,
role_name=role,
username_to_assign=creds.username,
)
user_creds.append(creds)
return user_creds


Expand Down Expand Up @@ -102,6 +99,7 @@ def create_workspaces_with_users(
]
]
"""
sess = api_utils.determined_test_session(admin=True)
configure_token_store(ADMIN_CREDENTIALS)
rid_to_creds: Dict[int, authentication.Credentials] = {}
with setup_workspaces(count=len(assignments_list)) as workspaces:
Expand All @@ -110,12 +108,11 @@ def create_workspaces_with_users(
if rid not in rid_to_creds:
rid_to_creds[rid] = create_test_user()
for role in roles:
cli.rbac.assign_role(
utils.CliArgsMock(
username_to_assign=rid_to_creds[rid].username,
workspace_name=workspace.name,
role_name=role,
)
api.assign_role(
session=sess,
role_name=role,
workspace_name=workspace.name,
username_to_assign=rid_to_creds[rid].username,
)
yield workspaces, rid_to_creds

Expand Down Expand Up @@ -576,8 +573,8 @@ def test_rbac_describe_role() -> None:
)

sess = api_utils.determined_test_session(ADMIN_CREDENTIALS)
user_id = usernames_to_user_ids(sess, [test_user_creds.username])[0]
group_id = group_name_to_group_id(sess, group_name)
user_id = api.usernames_to_user_ids(sess, [test_user_creds.username])[0]
group_id = api.group_name_to_group_id(sess, group_name)

det_cmd(
["rbac", "assign-role", "Viewer", "--username-to-assign", test_user_creds.username],
Expand Down
19 changes: 0 additions & 19 deletions e2e_tests/tests/utils.py

This file was deleted.

1 change: 0 additions & 1 deletion harness/determined/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
login_sdk_client,
print_warnings,
wait_ntsc_ready,
not_found_errs,
)
from determined.cli import (
agent,
Expand Down
12 changes: 0 additions & 12 deletions harness/determined/cli/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,3 @@ def wait_ntsc_ready(session: api.Session, ntsc_type: api.NTSC_Kind, eid: str) ->
loading_animator.clear(msg)
if err_msg:
raise errors.CliError(err_msg)


# not_found_errs mirrors NotFoundErrs from the golang api/errors.go. In the cases where
# Python errors override the golang errors, this ensures the error messages stay consistent.
def not_found_errs(
category: str, name: str, session: api.Session
) -> api.errors.BadRequestException:
resp = bindings.get_GetMaster(session)
msg = f"{category} '{name}' not found"
if not resp.to_json().get("rbacEnabled"):
return api.errors.NotFoundException(msg)
return api.errors.NotFoundException(msg + ", please check your permissions.")
2 changes: 1 addition & 1 deletion harness/determined/cli/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def list_tasks(args: Namespace) -> None:
params: Dict[str, Any] = {}

if "workspace_name" in args and args.workspace_name is not None:
workspace = cli.workspace.workspace_by_name(cli.setup_session(args), args.workspace_name)
workspace = api.workspace_by_name(cli.setup_session(args), args.workspace_name)

params["workspaceId"] = workspace.id

Expand Down
2 changes: 1 addition & 1 deletion harness/determined/cli/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def experiment_logs(args: Namespace) -> None:
sess = cli.setup_session(args)
trials = bindings.get_GetExperimentTrials(sess, experimentId=args.experiment_id).trials
if len(trials) == 0:
raise cli.not_found_errs("experiment", args.experiment_id, sess)
raise api.not_found_errs("experiment", args.experiment_id, sess)
first_trial_id = sorted(t_id.id for t_id in trials)[0]
try:
logs = api.trial_logs(
Expand Down
8 changes: 4 additions & 4 deletions harness/determined/cli/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from determined.common.api import authentication, bindings, errors
from determined.common.declarative_argparse import Arg, Cmd

from .workspace import list_workspace_projects, pagination_args, workspace_by_name
from .workspace import list_workspace_projects, pagination_args


def render_experiments(args: Namespace, experiments: Sequence[bindings.v1Experiment]) -> None:
Expand Down Expand Up @@ -61,10 +61,10 @@ def render_project(project: bindings.v1Project) -> None:
def project_by_name(
sess: api.Session, workspace_name: str, project_name: str
) -> Tuple[bindings.v1Workspace, bindings.v1Project]:
w = workspace_by_name(sess, workspace_name)
w = api.workspace_by_name(sess, workspace_name)
p = bindings.get_GetWorkspaceProjects(sess, id=w.id, name=project_name).projects
if len(p) == 0:
raise cli.not_found_errs("project", project_name, sess)
raise api.not_found_errs("project", project_name, sess)
return (w, p[0])


Expand Down Expand Up @@ -102,7 +102,7 @@ def list_project_experiments(args: Namespace) -> None:
@authentication.required
def create_project(args: Namespace) -> None:
sess = cli.setup_session(args)
w = workspace_by_name(sess, args.workspace_name)
w = api.workspace_by_name(sess, args.workspace_name)
content = bindings.v1PostProjectRequest(
name=args.name, description=args.description, workspaceId=w.id
)
Expand Down
88 changes: 20 additions & 68 deletions harness/determined/cli/rbac.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from argparse import Namespace
from collections import namedtuple
from typing import Any, Dict, List, Set, Tuple
from typing import Any, Dict, List, Set

import determined.cli.render
from determined.cli import (
default_pagination_args,
render,
require_feature_flag,
setup_session,
user_groups,
workspace,
)
from determined.cli import default_pagination_args, render, require_feature_flag, setup_session
from determined.common import api
from determined.common.api import authentication, bindings
from determined.common.declarative_argparse import Arg, Cmd
from determined.common.experimental import session

rbac_flag_disabled_message = (
"RBAC commands require the Determined Enterprise Edition "
Expand Down Expand Up @@ -134,7 +126,7 @@ def list_roles(args: Namespace) -> None:


def role_with_assignment_to_dict(
session: session.Session,
session: api.Session,
r: bindings.v1RoleWithAssignments,
assignment: bindings.v1RoleAssignment,
) -> Dict[str, Any]:
Expand Down Expand Up @@ -162,7 +154,7 @@ def role_with_assignment_to_dict(
@require_feature_flag("rbacEnabled", rbac_flag_disabled_message)
def list_users_roles(args: Namespace) -> None:
session = setup_session(args)
user_id = user_groups.usernames_to_user_ids(session, [args.username])[0]
user_id = api.usernames_to_user_ids(session, [args.username])[0]
resp = bindings.get_GetRolesAssignedToUser(session, userId=user_id)
if args.json:
determined.cli.render.print_json(resp.to_json())
Expand Down Expand Up @@ -198,7 +190,7 @@ def list_users_roles(args: Namespace) -> None:
@require_feature_flag("rbacEnabled", rbac_flag_disabled_message)
def list_groups_roles(args: Namespace) -> None:
session = setup_session(args)
group_id = user_groups.group_name_to_group_id(session, args.group_name)
group_id = api.group_name_to_group_id(session, args.group_name)
resp = bindings.get_GetRolesAssignedToGroup(session, groupId=group_id)
if args.json:
determined.cli.render.print_json(resp.to_json())
Expand Down Expand Up @@ -238,7 +230,7 @@ def list_groups_roles(args: Namespace) -> None:
@require_feature_flag("rbacEnabled", rbac_flag_disabled_message)
def describe_role(args: Namespace) -> None:
session = setup_session(args)
role_id = role_name_to_role_id(session, args.role_name)
role_id = api.role_name_to_role_id(session, args.role_name)
req = bindings.v1GetRolesByIDRequest(roleIds=[role_id])
resp = bindings.post_GetRolesByID(session, body=req)
if args.json:
Expand Down Expand Up @@ -318,61 +310,30 @@ def describe_role(args: Namespace) -> None:
print()


def create_assignment_request(
session: session.Session, args: Namespace
) -> Tuple[List[bindings.v1UserRoleAssignment], List[bindings.v1GroupRoleAssignment]]:
if (args.username_to_assign is None) == (args.group_name_to_assign is None):
raise api.errors.BadRequestException(
"must provide exactly one of --username-to-assign or --group-name-to-assign"
)

role = bindings.v1Role(roleId=role_name_to_role_id(session, args.role_name))

workspace_id = None
if args.workspace_name is not None:
workspace_id = workspace.workspace_by_name(session, args.workspace_name).id
role_assign = bindings.v1RoleAssignment(role=role, scopeWorkspaceId=workspace_id)

if args.username_to_assign is not None:
user_id = user_groups.usernames_to_user_ids(session, [args.username_to_assign])[0]
return [bindings.v1UserRoleAssignment(userId=user_id, roleAssignment=role_assign)], []

group_id = user_groups.group_name_to_group_id(session, args.group_name_to_assign)
return [], [bindings.v1GroupRoleAssignment(groupId=group_id, roleAssignment=role_assign)]


@authentication.required
@require_feature_flag("rbacEnabled", rbac_flag_disabled_message)
def assign_role(args: Namespace) -> None:
session = setup_session(args)
user_assign, group_assign = create_assignment_request(session, args)
req = bindings.v1AssignRolesRequest(
userRoleAssignments=user_assign, groupRoleAssignments=group_assign
api.assign_role(
session=setup_session(args),
role_name=args.role_name,
workspace_name=args.workspace_name,
username_to_assign=args.username_to_assign,
group_name_to_assign=args.group_name_to_assign,
)
bindings.post_AssignRoles(session, body=req)

scope = " globally"
if args.workspace_name:
scope = f" to workspace {args.workspace_name}"
if len(user_assign) > 0:
role_id = user_assign[0].roleAssignment.role.roleId
print(
f"assigned role '{args.role_name}' with ID {role_id} "
+ f"to user '{args.username_to_assign}' with ID {user_assign[0].userId}{scope}"
)
else:
role_id = group_assign[0].roleAssignment.role.roleId
print(
f"assigned role '{args.role_name}' with ID {role_id} "
+ f"to group '{args.group_name_to_assign}' with ID {group_assign[0].groupId}{scope}"
)


@authentication.required
@require_feature_flag("rbacEnabled", rbac_flag_disabled_message)
def unassign_role(args: Namespace) -> None:
session = setup_session(args)
user_assign, group_assign = create_assignment_request(session, args)
user_assign, group_assign = api.create_assignment_request(
session=session,
role_name=args.role_name,
workspace_name=args.workspace_name,
username_to_assign=args.username_to_assign,
group_name_to_assign=args.group_name_to_assign,
)

req = bindings.v1RemoveAssignmentsRequest(
userRoleAssignments=user_assign, groupRoleAssignments=group_assign
)
Expand All @@ -393,15 +354,6 @@ def unassign_role(args: Namespace) -> None:
)


def role_name_to_role_id(session: session.Session, role_name: str) -> int:
req = bindings.v1ListRolesRequest(limit=499, offset=0)
resp = bindings.post_ListRoles(session=session, body=req)
for r in resp.roles:
if r.name == role_name and r.roleId is not None:
return r.roleId
raise api.errors.BadRequestException(f"could not find role name {role_name}")


args_description = [
Cmd(
"rbac",
Expand Down
Loading

0 comments on commit 92b36be

Please sign in to comment.