Skip to content

Commit

Permalink
fix: user can only list models with correct permissions + small fixes…
Browse files Browse the repository at this point in the history
… in workspace filtering in get models (determined-ai#681)

[excluding e2e_tests changes]
  • Loading branch information
nrajanee authored and eecsliu committed Apr 16, 2024
1 parent e30636f commit f3adc8e
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 50 deletions.
18 changes: 9 additions & 9 deletions master/internal/model/authz_permissive.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@ import (
// ModelAuthZPermissive is the permission implementation.
type ModelAuthZPermissive struct{}

// CanGetModels always returns true and a nil error.
func (a *ModelAuthZPermissive) CanGetModels(ctx context.Context, curUser model.User,
workspaceID int32,
) (canGetModel bool, serverError error) {
_, _ = (&ModelAuthZRBAC{}).CanGetModels(ctx, curUser, workspaceID)
return (&ModelAuthZBasic{}).CanGetModels(ctx, curUser, workspaceID)
// CanGetModels calls RBAC authz but enforces basic authz.
func (a *ModelAuthZPermissive) CanGetModels(ctx context.Context,
curUser model.User, workspaceIDs []int32,
) (workspaceIDsWithPermsFilter []int32, canGetModels bool, serverError error) {
_, _, _ = (&ModelAuthZRBAC{}).CanGetModels(ctx, curUser, workspaceIDs) //nolint:dogsled
return (&ModelAuthZBasic{}).CanGetModels(ctx, curUser, workspaceIDs)
}

// CanGetModel always returns true and a nil error.
// CanGetModel calls RBAC authz but enforces basic authz.
func (a *ModelAuthZPermissive) CanGetModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) (canGetModel bool, serverError error) {
_, _ = (&ModelAuthZRBAC{}).CanGetModel(ctx, curUser, m, workspaceID)
return (&ModelAuthZBasic{}).CanGetModel(ctx, curUser, m, workspaceID)
}

// CanEditModel always returns true and a nil error.
// CanEditModel calls RBAC authz but enforces basic authz.
func (a *ModelAuthZPermissive) CanEditModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) error {
_ = (&ModelAuthZRBAC{}).CanEditModel(ctx, curUser, m, workspaceID)
return (&ModelAuthZBasic{}).CanEditModel(ctx, curUser, m, workspaceID)
}

// CanCreateModel always returns true and a nil error.
// CanCreateModel calls RBAC authz but enforces basic authz.
func (a *ModelAuthZPermissive) CanCreateModel(ctx context.Context,
curUser model.User, workspaceID int32,
) error {
Expand Down
62 changes: 48 additions & 14 deletions master/internal/model/authz_rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,62 @@ func addExpInfo(
}
}

// CanGetModels always returns true and a nil error.
func (a *ModelAuthZRBAC) CanGetModels(ctx context.Context, curUser model.User, workspaceID int32,
) (canGetModel bool, serverError error) {
// CanGetModels checks if a user has permissions to view models.
func (a *ModelAuthZRBAC) CanGetModels(ctx context.Context, curUser model.User, workspaceIDs []int32,
) (workspaceIDsWithPermsFilter []int32, canGetModels bool, serverError error) {
fields := audit.ExtractLogFields(ctx)
addExpInfo(curUser, fields, fmt.Sprintf("all models in %d", workspaceID),
addExpInfo(curUser, fields, fmt.Sprintf("all models in workspaces %v", workspaceIDs),
[]rbacv1.PermissionType{rbacv1.PermissionType_PERMISSION_TYPE_VIEW_MODEL_REGISTRY})
defer func() {
fields["permissionGranted"] = canGetModel
fields["permissionGranted"] = canGetModels
audit.Log(fields)
}()

if err := db.DoesPermissionMatch(ctx, curUser.ID, &workspaceID,
rbacv1.PermissionType_PERMISSION_TYPE_VIEW_MODEL_REGISTRY); err != nil {
if _, ok := err.(authz.PermissionDeniedError); ok {
return false, nil
assignmentsMap, err := rbac.GetPermissionSummary(ctx, curUser.ID)
if err != nil {
return workspaceIDs, false, err
}

workspacesIDsWithPermsSet := make(map[int32]bool)
var workspacesIDsWithPerms []int32

for role, roleAssignments := range assignmentsMap {
for _, permission := range role.Permissions {
if permission.ID == int(
rbacv1.PermissionType_PERMISSION_TYPE_VIEW_MODEL_REGISTRY) {
for _, assignment := range roleAssignments {
if assignment.Scope.WorkspaceID.Valid {
workspacesIDsWithPermsSet[assignment.Scope.WorkspaceID.Int32] = true
workspacesIDsWithPerms = append(workspacesIDsWithPerms, assignment.Scope.WorkspaceID.Int32)
} else {
// if permission is global, return true and the list provided by user.
return workspaceIDs, true, nil
}
}
}
}
return false, err
}
return true, nil

if workspacesIDsWithPerms == nil {
return nil, false, nil // user doesn't have permissions to see models in any workspace.
}

for _, givenWID := range workspaceIDs {
if _, ok := workspacesIDsWithPermsSet[givenWID]; !ok {
return nil, false, nil
// user doesn't have permissions to see models in the user given list of workspaces.
}
}

if workspaceIDs != nil {
return workspaceIDs, true, nil // at this point the user given workspaceIDs
// could be smaller than the workspaces with permissions.
}

return workspacesIDsWithPerms, true, nil
}

// CanGetModel always returns true and a nil error.
// CanGetModel checks if a user has permissions to view a model.
func (a *ModelAuthZRBAC) CanGetModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) (canGetModel bool, serverError error) {
Expand All @@ -80,7 +114,7 @@ func (a *ModelAuthZRBAC) CanGetModel(ctx context.Context, curUser model.User,
return true, nil
}

// CanEditModel always returns true and a nil error.
// CanEditModel checks if a user has permissions to edit models.
func (a *ModelAuthZRBAC) CanEditModel(ctx context.Context, curUser model.User,
m *modelv1.Model, workspaceID int32,
) (err error) {
Expand All @@ -95,7 +129,7 @@ func (a *ModelAuthZRBAC) CanEditModel(ctx context.Context, curUser model.User,
rbacv1.PermissionType_PERMISSION_TYPE_EDIT_MODEL_REGISTRY)
}

// CanCreateModel always returns true and a nil error.
// CanCreateModel checks if a user has permissions to create models.
func (a *ModelAuthZRBAC) CanCreateModel(ctx context.Context,
curUser model.User, workspaceID int32,
) (err error) {
Expand Down
3 changes: 1 addition & 2 deletions master/internal/plugin/oauth/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/scrypt"

Expand Down Expand Up @@ -113,7 +112,7 @@ func (s *Service) ValidateRequest(c echo.Context) (bool, error) {

token, err := s.server.Manager.LoadAccessToken(bearer)
if err != nil {
logrus.WithError(err).Error("failed to load access token")
log.WithError(err).Error("failed to load access token")
return false, nil
}
return token != nil, nil
Expand Down
2 changes: 1 addition & 1 deletion master/internal/project/authz_rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func permCheck(
}

// CanGetProject returns true if user has "VIEW_PROJECT" globally
// or on a given workspace scope scope and false if not along
// or on a given workspace scope and false if not along
// with a serverError in case of a database failure.
func (a *ProjectAuthZRBAC) CanGetProject(
ctx context.Context, curUser model.User, project *projectv1.Project,
Expand Down
16 changes: 9 additions & 7 deletions master/internal/rm/dispatcherrm/dispatcher_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ var messagePatternsOfInterest = []*regexp.Regexp{
}

// containerInfo stores the data sent by the container in the
// "NotifyContainerRunning" message, so that we can keep keep track of which
// "NotifyContainerRunning" message, so that we can keep track of which
// containers are running.
type containerInfo struct {
nodeName string
Expand Down Expand Up @@ -374,7 +374,8 @@ func (m *launcherMonitor) isJobBeingMonitored(dispatchID string) bool {
// jobs are then removed from further consideration.
func (m *launcherMonitor) processWatchedJobs(
ctx *actor.Context,
processingWatchedJobs *bool) {
processingWatchedJobs *bool,
) {
defer setBoolean(processingWatchedJobs, false, &m.processingWatchedJobsMutex)

var job launcherJob
Expand Down Expand Up @@ -531,7 +532,8 @@ func (m *launcherMonitor) clearJobsToRemoveMap() {
// status check was made to the launcher.
func (m *launcherMonitor) getDispatchIDsSortedByLastJobStatusCheckTime(
monitoredJobs map[string]launcherJob,
ctx *actor.Context) []string {
ctx *actor.Context,
) []string {
// Obtain a read lock.
m.mu.RLock()
defer m.mu.RUnlock()
Expand Down Expand Up @@ -738,7 +740,7 @@ func (m *launcherMonitor) getDispatchStatus(
resp, r, err := m.apiClient.MonitoringApi.
GetEnvironmentStatus(m.authContext(ctx), owner, dispatchID).
Refresh(true).
Execute()
Execute() //nolint:bodyclose
if err != nil {
// This may happen if the job is canceled before the launcher creates
// the environment files containing status. Wouldn't expect this to
Expand Down Expand Up @@ -819,7 +821,7 @@ func calculateJobExitStatus(
func getJobExitMessages(resp launcher.DispatchInfo) []string {
var result []string
for _, event := range resp.GetEvents() {
if "com.cray.analytics.capsules.dispatcher.shasta.ShastaDispatcher" == *event.Reporter {
if *event.Reporter == "com.cray.analytics.capsules.dispatcher.shasta.ShastaDispatcher" {
// Ignore general dispatcher messages, only want carrier messages
continue
}
Expand Down Expand Up @@ -851,7 +853,7 @@ func (m *launcherMonitor) getTaskLogsFromDispatcher(
// the log file content, read the payload name from the launcher.
if len(job.payloadName) == 0 {
manifest, resp, err := m.apiClient.MonitoringApi.GetEnvironmentDetails(
m.authContext(ctx), job.user, dispatchID).Execute()
m.authContext(ctx), job.user, dispatchID).Execute() //nolint:bodyclose
if err != nil {
ctx.Log().WithError(err).Warnf(
"For dispatchID %s, unable to access environment details, response {%v}",
Expand All @@ -868,7 +870,7 @@ func (m *launcherMonitor) getTaskLogsFromDispatcher(

logFile, httpResponse, err := m.apiClient.MonitoringApi.LoadEnvironmentLog(
m.authContext(ctx), job.user, dispatchID, logFileName,
).Range_(logRange).Execute()
).Range_(logRange).Execute() //nolint:bodyclose
if err != nil {
ctx.Log().WithError(err).Warnf("For dispatchID %s, unable to access %s, response {%v}",
dispatchID, logFileName, httpResponse)
Expand Down
41 changes: 25 additions & 16 deletions master/internal/rm/dispatcherrm/dispatcher_resource_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type schedulerTick struct{}
const actionCoolDown = 500 * time.Millisecond

// hpcResources is a data type describing the HPC resources available
// to Slurm on on the Launcher node.
// to Slurm on the Launcher node.
// Example output of the HPC resource details from the Launcher.
// ---
// partitions:
Expand All @@ -72,8 +72,8 @@ const actionCoolDown = 500 * time.Millisecond
// - totalAvailableNodes: 293
// ...more partitions.
type hpcResources struct {
Partitions []hpcPartitionDetails `json:"partitions,flow"`
Nodes []hpcNodeDetails `json:"nodes,flow"`
Partitions []hpcPartitionDetails `json:"partitions,flow"` //nolint:staticcheck
Nodes []hpcNodeDetails `json:"nodes,flow"` //nolint:staticcheck
DefaultComputePoolPartition string `json:"defaultComputePoolPartition"`
DefaultAuxPoolPartition string `json:"defaultAuxPoolPartition"`
}
Expand Down Expand Up @@ -879,7 +879,8 @@ func (m *dispatcherResourceManager) waitForDispatchTerminalState(ctx *actor.Cont

func (m *dispatcherResourceManager) startLauncherJob(
ctx *actor.Context,
msg StartDispatcherResources) {
msg StartDispatcherResources,
) {
req, ok := m.reqList.TaskByHandler(msg.TaskActor)
if !ok {
sendResourceStateChangedErrorResponse(ctx, errors.New("no such task"), msg,
Expand Down Expand Up @@ -922,6 +923,7 @@ func (m *dispatcherResourceManager) startLauncherJob(

if impersonatedUser == root && m.rmConfig.UserName != root {
sendResourceStateChangedErrorResponse(ctx,
//nolint:stylecheck
fmt.Errorf(
"You are logged in as Determined user '%s', however the user ID on the "+
"target HPC cluster for this user has either not been configured, or has "+
Expand Down Expand Up @@ -961,7 +963,8 @@ func (m *dispatcherResourceManager) startLauncherJob(
// Adds the mapping of dispatch ID to allocation ID.
func (m *dispatcherResourceManager) addDispatchIDToAllocationMap(
dispatchID string,
allocationID model.AllocationID) {
allocationID model.AllocationID,
) {
// Read/Write lock blocks other readers and writers.
m.dispatchIDToAllocationIDMutex.Lock()
defer m.dispatchIDToAllocationIDMutex.Unlock()
Expand All @@ -971,7 +974,8 @@ func (m *dispatcherResourceManager) addDispatchIDToAllocationMap(

// Removes the mapping from dispatch ID to allocation ID.
func (m *dispatcherResourceManager) removeDispatchIDFromAllocationIDMap(
dispatchID string) {
dispatchID string,
) {
// Read/Write lock blocks other readers and writers.
m.dispatchIDToAllocationIDMutex.Lock()
defer m.dispatchIDToAllocationIDMutex.Unlock()
Expand All @@ -981,7 +985,8 @@ func (m *dispatcherResourceManager) removeDispatchIDFromAllocationIDMap(

// Gets the allocation ID for the specified dispatch ID.
func (m *dispatcherResourceManager) getAllocationIDFromDispatchID(
dispatchID string) (model.AllocationID, bool) {
dispatchID string,
) (model.AllocationID, bool) {
// Read lock allows multiple readers, but blocks writers.
m.dispatchIDToAllocationIDMutex.RLock()
defer m.dispatchIDToAllocationIDMutex.RUnlock()
Expand All @@ -994,7 +999,8 @@ func (m *dispatcherResourceManager) getAllocationIDFromDispatchID(
// Adds the mapping of dispatch ID to HPC job ID.
func (m *dispatcherResourceManager) addDispatchIDToHpcJobIDMap(
dispatchID string,
hpcJobID string) {
hpcJobID string,
) {
// Read/Write lock blocks other readers and writers.
m.dispatchIDToHPCJobIDMutex.Lock()
defer m.dispatchIDToHPCJobIDMutex.Unlock()
Expand All @@ -1004,7 +1010,8 @@ func (m *dispatcherResourceManager) addDispatchIDToHpcJobIDMap(

// Removes the mapping from dispatch ID to HPC job ID.
func (m *dispatcherResourceManager) removeDispatchIDFromHpcJobIDMap(
dispatchID string) {
dispatchID string,
) {
// Read/Write lock blocks other readers and writers.
m.dispatchIDToHPCJobIDMutex.Lock()
defer m.dispatchIDToHPCJobIDMutex.Unlock()
Expand All @@ -1014,7 +1021,8 @@ func (m *dispatcherResourceManager) removeDispatchIDFromHpcJobIDMap(

// Gets the HPC job ID for the specified dispatch ID.
func (m *dispatcherResourceManager) getHpcJobIDFromDispatchID(
dispatchID string) (string, bool) {
dispatchID string,
) (string, bool) {
// Read lock allows multiple readers, but blocks writers.
m.dispatchIDToHPCJobIDMutex.RLock()
defer m.dispatchIDToHPCJobIDMutex.RUnlock()
Expand Down Expand Up @@ -1343,7 +1351,8 @@ func (m *dispatcherResourceManager) resolveSlotType(

// Retrieves the launcher version and logs an error if it does not meet the minimum required version.
func (m *dispatcherResourceManager) getAndCheckLauncherVersion(ctx *actor.Context) {
resp, _, err := m.apiClient.InfoApi.GetServerVersion(m.authContext(ctx)).Execute()
resp, _, err := m.apiClient.InfoApi.GetServerVersion(m.authContext(ctx)).
Execute() //nolint:bodyclose
if err == nil {
if checkMinimumLauncherVersion(resp) {
if !m.launcherVersionIsOK {
Expand Down Expand Up @@ -1398,7 +1407,7 @@ func (m *dispatcherResourceManager) fetchHpcResourceDetails(ctx *actor.Context)
Launch(m.authContext(ctx)).
Manifest(*m.hpcResourcesManifest).
Impersonate(impersonatedUser).
Execute()
Execute() //nolint:bodyclose
if err != nil {
if r != nil && (r.StatusCode == http.StatusUnauthorized ||
r.StatusCode == http.StatusForbidden) {
Expand Down Expand Up @@ -1445,7 +1454,7 @@ func (m *dispatcherResourceManager) fetchHpcResourceDetails(ctx *actor.Context)
// long of a delay for us to deal with.
resp, _, err := m.apiClient.MonitoringApi.
LoadEnvironmentLog(m.authContext(ctx), owner, dispatchID, logFileName).
Execute()
Execute() //nolint:bodyclose
if err != nil {
ctx.Log().WithError(err).Errorf("failed to retrieve HPC Resource details. response: {%v}", resp)
return
Expand Down Expand Up @@ -1574,7 +1583,7 @@ func (m *dispatcherResourceManager) terminateDispatcherJob(ctx *actor.Context,
var err error
var response *http.Response
if _, response, err = m.apiClient.RunningApi.TerminateRunning(m.authContext(ctx),
owner, dispatchID).Force(true).Execute(); err != nil {
owner, dispatchID).Force(true).Execute(); err != nil { //nolint:bodyclose
if response == nil || response.StatusCode != 404 {
ctx.Log().WithError(err).Errorf("Failed to terminate job with Dispatch ID %s, response: {%v}",
dispatchID, response)
Expand All @@ -1601,7 +1610,7 @@ func (m *dispatcherResourceManager) removeDispatchEnvironment(
ctx.Log().Debugf("Deleting environment with DispatchID %s", dispatchID)

if response, err := m.apiClient.MonitoringApi.DeleteEnvironment(m.authContext(ctx),
owner, dispatchID).Execute(); err != nil {
owner, dispatchID).Execute(); err != nil { //nolint:bodyclose
if response == nil || response.StatusCode != 404 {
ctx.Log().WithError(err).Errorf("Failed to remove environment for Dispatch ID %s, response:{%v}",
dispatchID, response)
Expand Down Expand Up @@ -1642,7 +1651,7 @@ func (m *dispatcherResourceManager) sendManifestToDispatcher(
Launch(m.authContext(ctx)).
Manifest(*manifest).
Impersonate(impersonatedUser).
Execute()
Execute() //nolint:bodyclose
if err != nil {
httpStatus := ""
if response != nil {
Expand Down
3 changes: 2 additions & 1 deletion master/pkg/tasks/dispatcher_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ func (t *TaskSpec) ToDispatcherManifest(
// So use /var/tmp here to eliminate spurious error logs. We avoid using /tmp
// here because dispatcher-wrapper.sh by default relinks /tmp to
// a container-private directory and if it is in use we fail with EBUSY.
// nolint:dupword
workDir := t.WorkDir
if workDir == DefaultWorkDir {
workDir = varTmp
Expand Down Expand Up @@ -233,7 +234,7 @@ func (t *TaskSpec) ToDispatcherManifest(
customParams["pbsArgs"] = pbsArgs

if containerRunType == podman {
var portMappings []string = *getPortMappings(t)
portMappings := *getPortMappings(t)
if len(portMappings) != 0 {
customParams["ports"] = portMappings
}
Expand Down

0 comments on commit f3adc8e

Please sign in to comment.