Skip to content

Commit

Permalink
fix: user can only list models with correct permissions + small fixes…
Browse files Browse the repository at this point in the history
… in workspace filtering in get models (determined-ai#681)
  • Loading branch information
nrajanee authored and determined-ci committed Jun 21, 2023
1 parent de5bd47 commit 8d20ab2
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 53 deletions.
33 changes: 30 additions & 3 deletions e2e_tests/tests/cluster/test_model_registry_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def all_operations(
assert db_version.name == "Test 2021"

model_obj.move_to_workspace(workspace_name="Uncategorized")
models = determined_obj.get_models(workspace_name="Uncategorized")
models = determined_obj.get_models(workspace_names=["Uncategorized"])
assert model_obj.name in [m.name for m in models]
return model_obj, "Uncategorized"


def view_operations(determined_obj: Determined, model: model.Model, workspace_name: str) -> None:
db_model = determined_obj.get_model(model.name)
assert db_model.name == model.name
models = determined_obj.get_models(workspace_name=workspace_name)
models = determined_obj.get_models(workspace_names=[workspace_name])
assert db_model.name in [m.name for m in models]


Expand All @@ -87,6 +87,7 @@ def test_model_registry_rbac() -> None:
test_user_editor_creds = api_utils.create_test_user()
test_user_workspace_admin_creds = api_utils.create_test_user()
test_user_viewer_creds = api_utils.create_test_user()
test_user_with_no_perms_creds = api_utils.create_test_user()
test_user_model_registry_viewer_creds = api_utils.create_test_user()
admin_session = api_utils.determined_test_session(admin=True)
with setup_workspaces(admin_session) as [test_workspace]:
Expand Down Expand Up @@ -212,7 +213,7 @@ def test_model_registry_rbac() -> None:
model_1, current_model_workspace = all_operations(
determined_obj=d, test_workspace=test_workspace, checkpoint=checkpoint
)
print(test_user_model_registry_viewer_creds.username)

with logged_in_user(test_user_model_registry_viewer_creds):
d = Determined(master_url)
user_with_view_perms_test(
Expand All @@ -225,6 +226,32 @@ def test_model_registry_rbac() -> None:
determined_obj=d, workspace_name=current_model_workspace, model=model_1
)

with logged_in_user(test_user_with_no_perms_creds):
d = Determined(master_url)
with pytest.raises(Exception) as e:
d.get_models()
assert "doesn't have view permissions" in str(e.value)

# Unassign view permissions to a certain workspace.
# List should return models only in workspaces with permissions.
with logged_in_user(ADMIN_CREDENTIALS):
det_cmd(
[
"rbac",
"unassign-role",
"ModelRegistryViewer",
"--username-to-assign",
test_user_model_registry_viewer_creds.username,
"--workspace-name",
test_workspace.name,
],
check=True,
)
with logged_in_user(test_user_model_registry_viewer_creds):
d = Determined(master_url)
models = d.get_models()
assert test_workspace.id not in [m.workspace_id for m in models]

with logged_in_user(test_user_editor_creds):
d = Determined(master_url)
model = d.get_model(model_1.name)
Expand Down
18 changes: 9 additions & 9 deletions master/internal/model/authz_permissive.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@ import (
// ModelAuthZPermissive is the permission implementation.
type ModelAuthZPermissive struct{}

// CanGetModels always returns true and a nil error.
func (a *ModelAuthZPermissive) CanGetModels(ctx context.Context, curUser model.User,
workspaceID int32,
) (canGetModel bool, serverError error) {
_, _ = (&ModelAuthZRBAC{}).CanGetModels(ctx, curUser, workspaceID)
return (&ModelAuthZBasic{}).CanGetModels(ctx, curUser, workspaceID)
// CanGetModels calls RBAC authz but enforces basic authz..
func (a *ModelAuthZPermissive) CanGetModels(ctx context.Context,
curUser model.User, workspaceIDs []int32,
) (workspaceIDsWithPermsFilter []int32, canGetModels bool, serverError error) {
_, _, _ = (&ModelAuthZRBAC{}).CanGetModels(ctx, curUser, workspaceIDs) //nolint:dogsled
return (&ModelAuthZBasic{}).CanGetModels(ctx, curUser, workspaceIDs)
}

// CanGetModel always returns true and a nil error.
// CanGetModel calls RBAC authz but enforces basic authz..
func (a *ModelAuthZPermissive) CanGetModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) (canGetModel bool, serverError error) {
_, _ = (&ModelAuthZRBAC{}).CanGetModel(ctx, curUser, m, workspaceID)
return (&ModelAuthZBasic{}).CanGetModel(ctx, curUser, m, workspaceID)
}

// CanEditModel always returns true and a nil error.
// CanEditModel calls RBAC authz but enforces basic authz..
func (a *ModelAuthZPermissive) CanEditModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) error {
_ = (&ModelAuthZRBAC{}).CanEditModel(ctx, curUser, m, workspaceID)
return (&ModelAuthZBasic{}).CanEditModel(ctx, curUser, m, workspaceID)
}

// CanCreateModel always returns true and a nil error.
// CanCreateModel calls RBAC authz but enforces basic authz..
func (a *ModelAuthZPermissive) CanCreateModel(ctx context.Context,
curUser model.User, workspaceID int32,
) error {
Expand Down
62 changes: 48 additions & 14 deletions master/internal/model/authz_rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,62 @@ func addExpInfo(
}
}

// CanGetModels always returns true and a nil error.
func (a *ModelAuthZRBAC) CanGetModels(ctx context.Context, curUser model.User, workspaceID int32,
) (canGetModel bool, serverError error) {
// CanGetModels checks if a user has permissions to view models.
func (a *ModelAuthZRBAC) CanGetModels(ctx context.Context, curUser model.User, workspaceIDs []int32,
) (workspaceIDsWithPermsFilter []int32, canGetModels bool, serverError error) {
fields := audit.ExtractLogFields(ctx)
addExpInfo(curUser, fields, fmt.Sprintf("all models in %d", workspaceID),
addExpInfo(curUser, fields, fmt.Sprintf("all models in workspaces %v", workspaceIDs),
[]rbacv1.PermissionType{rbacv1.PermissionType_PERMISSION_TYPE_VIEW_MODEL_REGISTRY})
defer func() {
fields["permissionGranted"] = canGetModel
fields["permissionGranted"] = canGetModels
audit.Log(fields)
}()

if err := db.DoesPermissionMatch(ctx, curUser.ID, &workspaceID,
rbacv1.PermissionType_PERMISSION_TYPE_VIEW_MODEL_REGISTRY); err != nil {
if _, ok := err.(authz.PermissionDeniedError); ok {
return false, nil
assignmentsMap, err := rbac.GetPermissionSummary(ctx, curUser.ID)
if err != nil {
return workspaceIDs, false, err
}

workspacesIDsWithPermsSet := make(map[int32]bool)
var workspacesIDsWithPerms []int32

for role, roleAssignments := range assignmentsMap {
for _, permission := range role.Permissions {
if permission.ID == int(
rbacv1.PermissionType_PERMISSION_TYPE_VIEW_MODEL_REGISTRY) {
for _, assignment := range roleAssignments {
if assignment.Scope.WorkspaceID.Valid {
workspacesIDsWithPermsSet[assignment.Scope.WorkspaceID.Int32] = true
workspacesIDsWithPerms = append(workspacesIDsWithPerms, assignment.Scope.WorkspaceID.Int32)
} else {
// if permission is global, return true and the list provided by user.
return workspaceIDs, true, nil
}
}
}
}
return false, err
}
return true, nil

if workspacesIDsWithPerms == nil {
return nil, false, nil // user doesn't have permissions to see models in any workspace.
}

for _, givenWID := range workspaceIDs {
if _, ok := workspacesIDsWithPermsSet[givenWID]; !ok {
return nil, false, nil
// user doesn't have permissions to see models in the user given list of workspaces.
}
}

if workspaceIDs != nil {
return workspaceIDs, true, nil // at this point the user given workspaceIDs
// could be smaller than the workspaces with permissions.
}

return workspacesIDsWithPerms, true, nil
}

// CanGetModel always returns true and a nil error.
// CanGetModel checks if a user has permissions to view model.
func (a *ModelAuthZRBAC) CanGetModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) (canGetModel bool, serverError error) {
Expand All @@ -80,7 +114,7 @@ func (a *ModelAuthZRBAC) CanGetModel(ctx context.Context, curUser model.User,
return true, nil
}

// CanEditModel always returns true and a nil error.
// CanEditModel checks is user has permissions to edit models.
func (a *ModelAuthZRBAC) CanEditModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) (err error) {
Expand All @@ -95,7 +129,7 @@ func (a *ModelAuthZRBAC) CanEditModel(ctx context.Context, curUser model.User,
rbacv1.PermissionType_PERMISSION_TYPE_EDIT_MODEL_REGISTRY)
}

// CanCreateModel always returns true and a nil error.
// CanCreateModel checks is user has permissions to create models.
func (a *ModelAuthZRBAC) CanCreateModel(ctx context.Context,
curUser model.User, workspaceID int32,
) (err error) {
Expand Down
3 changes: 1 addition & 2 deletions master/internal/plugin/oauth/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/scrypt"

Expand Down Expand Up @@ -113,7 +112,7 @@ func (s *Service) ValidateRequest(c echo.Context) (bool, error) {

token, err := s.server.Manager.LoadAccessToken(bearer)
if err != nil {
logrus.WithError(err).Error("failed to load access token")
log.WithError(err).Error("failed to load access token")
return false, nil
}
return token != nil, nil
Expand Down
2 changes: 1 addition & 1 deletion master/internal/project/authz_rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func permCheck(
}

// CanGetProject returns true if user has "VIEW_PROJECT" globally
// or on a given workspace scope scope and false if not along
// or on a given workspace scope and false if not along
// with a serverError in case of a database failure.
func (a *ProjectAuthZRBAC) CanGetProject(
ctx context.Context, curUser model.User, project *projectv1.Project,
Expand Down
16 changes: 9 additions & 7 deletions master/internal/rm/dispatcherrm/dispatcher_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ var messagePatternsOfInterest = []*regexp.Regexp{
}

// containerInfo stores the data sent by the container in the
// "NotifyContainerRunning" message, so that we can keep keep track of which
// "NotifyContainerRunning" message, so that we can keep track of which
// containers are running.
type containerInfo struct {
nodeName string
Expand Down Expand Up @@ -374,7 +374,8 @@ func (m *launcherMonitor) isJobBeingMonitored(dispatchID string) bool {
// jobs are them removed from further consideration.
func (m *launcherMonitor) processWatchedJobs(
ctx *actor.Context,
processingWatchedJobs *bool) {
processingWatchedJobs *bool,
) {
defer setBoolean(processingWatchedJobs, false, &m.processingWatchedJobsMutex)

var job launcherJob
Expand Down Expand Up @@ -531,7 +532,8 @@ func (m *launcherMonitor) clearJobsToRemoveMap() {
// status check was made to the launcher.
func (m *launcherMonitor) getDispatchIDsSortedByLastJobStatusCheckTime(
monitoredJobs map[string]launcherJob,
ctx *actor.Context) []string {
ctx *actor.Context,
) []string {
// Obtain a read lock.
m.mu.RLock()
defer m.mu.RUnlock()
Expand Down Expand Up @@ -738,7 +740,7 @@ func (m *launcherMonitor) getDispatchStatus(
resp, r, err := m.apiClient.MonitoringApi.
GetEnvironmentStatus(m.authContext(ctx), owner, dispatchID).
Refresh(true).
Execute()
Execute() //nolint:bodyclose
if err != nil {
// This may happen if the job is canceled before the launcher creates
// the environment files containing status. Wouldn't expect this to
Expand Down Expand Up @@ -819,7 +821,7 @@ func calculateJobExitStatus(
func getJobExitMessages(resp launcher.DispatchInfo) []string {
var result []string
for _, event := range resp.GetEvents() {
if "com.cray.analytics.capsules.dispatcher.shasta.ShastaDispatcher" == *event.Reporter {
if *event.Reporter == "com.cray.analytics.capsules.dispatcher.shasta.ShastaDispatcher" {
// Ignore general dispatcher messages, only want carrier messages
continue
}
Expand Down Expand Up @@ -851,7 +853,7 @@ func (m *launcherMonitor) getTaskLogsFromDispatcher(
// the log file content, read the payload name from the launcher.
if len(job.payloadName) == 0 {
manifest, resp, err := m.apiClient.MonitoringApi.GetEnvironmentDetails(
m.authContext(ctx), job.user, dispatchID).Execute()
m.authContext(ctx), job.user, dispatchID).Execute() //nolint:bodyclose
if err != nil {
ctx.Log().WithError(err).Warnf(
"For dispatchID %s, unable to access environment details, response {%v}",
Expand All @@ -868,7 +870,7 @@ func (m *launcherMonitor) getTaskLogsFromDispatcher(

logFile, httpResponse, err := m.apiClient.MonitoringApi.LoadEnvironmentLog(
m.authContext(ctx), job.user, dispatchID, logFileName,
).Range_(logRange).Execute()
).Range_(logRange).Execute() //nolint:bodyclose
if err != nil {
ctx.Log().WithError(err).Warnf("For dispatchID %s, unable to access %s, response {%v}",
dispatchID, logFileName, httpResponse)
Expand Down
Loading

0 comments on commit 8d20ab2

Please sign in to comment.