diff --git a/master/internal/command/command_job_service.go b/master/internal/command/command_job_service.go index 92ad1850075..ffbfd0c9d63 100644 --- a/master/internal/command/command_job_service.go +++ b/master/internal/command/command_job_service.go @@ -6,8 +6,10 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/determined-ai/determined/master/internal/config" + "github.com/determined-ai/determined/master/internal/configpolicy" "github.com/determined-ai/determined/master/internal/rm/rmerrors" "github.com/determined-ai/determined/master/internal/sproto" + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/proto/pkg/jobv1" ) @@ -47,6 +49,23 @@ func (c *Command) SetJobPriority(priority int) error { if priority < 1 || priority > 99 { return fmt.Errorf("priority must be between 1 and 99") } + + // Returns an error if RM does not implement priority. + if smallerHigher, err := c.rm.SmallerValueIsHigherPriority(); err == nil { + ok, err := configpolicy.PriorityAllowed( + int(c.GenericCommandSpec.Metadata.WorkspaceID), + model.NTSCType, + priority, + smallerHigher, + ) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("priority exceeds task config policy's priority_limit") + } + } + err := c.setNTSCPriority(priority, true) if err != nil { c.syslog.WithError(err).Info("setting command job priority") diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index a1c74a2da72..f4574bce429 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -3,6 +3,7 @@ package configpolicy import ( "context" "database/sql" + "encoding/json" "fmt" "strings" @@ -93,6 +94,43 @@ func GetTaskConfigPolicies(ctx context.Context, return &tcp, nil } +// GetPriorityLimit reads the priority limit for the given scope and workload type. +// It returns found=false if no limit exists. +func GetPriorityLimit(ctx context.Context, scope *int, workloadType string) (limit int, found bool, err error) { + if !ValidWorkloadType(workloadType) { + return 0, false, fmt.Errorf("invalid workload type: %s", workloadType) + } + + wkspQuery := wkspIDQuery + if scope == nil { + wkspQuery = wkspIDGlobalQuery + } + + var constraints model.Constraints + var constraintsStr string + err = db.Bun().NewSelect(). + Table("task_config_policies"). + Column("constraints"). + Where(wkspQuery, scope). + Where("workload_type = ?", workloadType). + Scan(ctx, &constraintsStr) + + if err == sql.ErrNoRows { + return 0, false, nil + } else if err != nil { + return 0, false, fmt.Errorf("error retrieving priority limit: %w", err) + } + + if err = json.Unmarshal([]byte(constraintsStr), &constraints); err != nil { + return 0, false, err + } + if constraints.PriorityLimit != nil { + return *constraints.PriorityLimit, true, nil + } + + return 0, false, nil +} + // DeleteConfigPolicies deletes the invariant experiment config and constraints for the // given scope (global or workspace-level) and workload type. func DeleteConfigPolicies(ctx context.Context, diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index a7dcd69a0b2..119436e5e4e 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -6,6 +6,7 @@ package configpolicy import ( "context" "encoding/json" + "fmt" "testing" "time" @@ -206,6 +207,124 @@ func TestSetTaskConfigPolicies(t *testing.T) { require.ErrorContains(t, err, "violates foreign key constraint") } +func TestWorkspaceGetPriorityLimit(t *testing.T) { + ctx := context.Background() + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + user := db.RequireMockUser(t, pgDB) + + // Add a workspace to use. + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + defer func() { + err := db.CleanupMockWorkspace([]int32{int32(w.ID)}) + if err != nil { + log.Errorf("error when cleaning up mock workspaces") + } + }() + + // No limit set. + _, found, err := GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // Add priority limit for workspace NTSC. + wkspLimit := 20 + constraints := fmt.Sprintf(`{"priority_limit": %d}`, wkspLimit) + wkspInput := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: &w.ID, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + + err = SetTaskConfigPolicies(ctx, &wkspInput) + require.NoError(t, err) + + // Get priority limit; should match workspace limit. + res, found, err := GetPriorityLimit(ctx, &w.ID, model.NTSCType) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, wkspLimit, res) + + // Get limit for a workspace that does not exist. + wkspIDDoesNotExist := 404 + _, found, err = GetPriorityLimit(ctx, &wkspIDDoesNotExist, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // Get global limit. + _, found, err = GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // Get limit for other workload type. + _, found, err = GetPriorityLimit(ctx, &w.ID, model.ExperimentType) + require.NoError(t, err) + require.False(t, found) + + // Try an invalid workload type. + _, found, err = GetPriorityLimit(ctx, &w.ID, "bogus") + require.Error(t, err) + require.False(t, found) +} + +func TestGlobalGetPriorityLimit(t *testing.T) { + ctx := context.Background() + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + user := db.RequireMockUser(t, pgDB) + + // Add a workspace to use. + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + defer func() { + err := db.CleanupMockWorkspace([]int32{int32(w.ID)}) + if err != nil { + log.Errorf("error when cleaning up mock workspaces") + } + }() + + // No limit set. + _, found, err := GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // Add priority limit for global NTSC. + globalLimit := 5 + constraints := fmt.Sprintf(`{"priority_limit": %d}`, globalLimit) + globalInput := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: nil, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + err = SetTaskConfigPolicies(ctx, &globalInput) + require.NoError(t, err) + + // Get priority limit, should be global limit. + res, found, err := GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, globalLimit, res) + + // Get limit for a different workload type. + _, found, err = GetPriorityLimit(ctx, nil, model.ExperimentType) + require.NoError(t, err) + require.False(t, found) + + // Try an invalid workload type. + _, found, err = GetPriorityLimit(ctx, nil, "bogus") + require.Error(t, err) + require.False(t, found) +} + // Test the enforcement of the primary key on the task_config_polciies table. func TestTaskConfigPoliciesUnique(t *testing.T) { ctx := context.Background() diff --git a/master/internal/configpolicy/task_config_policy.go b/master/internal/configpolicy/task_config_policy.go index ca2f07926e5..a6d25794539 100644 --- a/master/internal/configpolicy/task_config_policy.go +++ b/master/internal/configpolicy/task_config_policy.go @@ -1,6 +1,9 @@ package configpolicy import ( + "context" + "fmt" + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) @@ -24,3 +27,38 @@ type NTSCConfigPolicies struct { InvariantConfig *model.CommandConfig `json:"invariant_config"` Constraints *model.Constraints `json:"constraints"` } + +// PriorityAllowed returns true if the desired priority is within the limit set by task config policies. +func PriorityAllowed(wkspID int, workloadType string, priority int, smallerHigher bool) (bool, error) { + // Check if a priority limit has been set with a constraint policy. + // Global policies have highest precedence. + limit, found, err := GetPriorityLimit(context.TODO(), nil, workloadType) + if err != nil { + return false, fmt.Errorf("unable to fetch task config policy priority limit") + } + if found { + return priorityWithinLimit(priority, limit, smallerHigher), nil + } + + // TODO use COALESCE instead once postgres updates are complete. + // Workspace policies have second precedence. + limit, found, err = GetPriorityLimit(context.TODO(), &wkspID, workloadType) + if err != nil { + // TODO do we really want to block on this? + return false, fmt.Errorf("unable to fetch task config policy priority limit") + } + if found { + return priorityWithinLimit(priority, limit, smallerHigher), nil + } + + // No priority limit has been set. + return true, nil +} + +func priorityWithinLimit(userPriority int, adminLimit int, smallerHigher bool) bool { + if smallerHigher { + return userPriority >= adminLimit + } + + return userPriority <= adminLimit +} diff --git a/master/internal/configpolicy/task_config_policy_intg_test.go b/master/internal/configpolicy/task_config_policy_intg_test.go new file mode 100644 index 00000000000..66af9ca612b --- /dev/null +++ b/master/internal/configpolicy/task_config_policy_intg_test.go @@ -0,0 +1,84 @@ +package configpolicy + +import ( + "context" + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/etc" + "github.com/determined-ai/determined/master/pkg/model" +) + +func TestPriorityAllowed(t *testing.T) { + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + + // When no constraints are present, any priority is allowed. + ok, err := PriorityAllowed(1, model.NTSCType, 0, true) + require.NoError(t, err) + require.True(t, ok) + + wkspLimit := 50 + user := db.RequireMockUser(t, pgDB) + w := addWorkspacePriorityLimit(t, pgDB, user, wkspLimit) + + // Priority is outside workspace limit. + smallerValueIsHigherPriority := true + ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, smallerValueIsHigherPriority) + require.NoError(t, err) + require.False(t, ok) + + globalLimit := 42 + addGlobalPriorityLimit(t, pgDB, user, globalLimit) + + // Priority is within global limit. + ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, true) + require.NoError(t, err) + require.True(t, ok) + + // Priority is outside global limit. + ok, err = PriorityAllowed(w.ID+1, model.NTSCType, globalLimit-1, true) + require.NoError(t, err) + require.False(t, ok) +} + +func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) model.Workspace { + ctx := context.Background() + + // add a workspace to use + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + + constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit) + input := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: &w.ID, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + err = SetTaskConfigPolicies(ctx, &input) + require.NoError(t, err) + + return w +} + +func addGlobalPriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) { + ctx := context.Background() + + constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit) + input := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: nil, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + err := SetTaskConfigPolicies(ctx, &input) + require.NoError(t, err) +} diff --git a/master/internal/configpolicy/task_config_policy_test.go b/master/internal/configpolicy/task_config_policy_test.go new file mode 100644 index 00000000000..11ac52b1728 --- /dev/null +++ b/master/internal/configpolicy/task_config_policy_test.go @@ -0,0 +1,31 @@ +package configpolicy + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPriorityWithinLimit(t *testing.T) { + testCases := []struct { + name string + userPriority int + adminLimit int + smallerIsHigher bool + ok bool + }{ + {"smaller is higher - ok", 10, 1, true, true}, + {"smaller is higher - not ok", 10, 20, true, false}, + {"smaller is higher - equal", 20, 20, true, true}, + {"smaller is lower - ok", 11, 13, false, true}, + {"smaller is lower - not ok", 13, 11, false, false}, + {"smaller is lower - equal", 11, 11, false, true}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + ok := priorityWithinLimit(tt.userPriority, tt.adminLimit, tt.smallerIsHigher) + require.Equal(t, tt.ok, ok) + }) + } +} diff --git a/master/internal/experiment_job_service.go b/master/internal/experiment_job_service.go index 1749ffe1ed1..109590edb5e 100644 --- a/master/internal/experiment_job_service.go +++ b/master/internal/experiment_job_service.go @@ -2,6 +2,7 @@ package internal import ( "context" + "database/sql" "fmt" "strconv" @@ -9,7 +10,9 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/determined-ai/determined/master/internal/config" + "github.com/determined-ai/determined/master/internal/configpolicy" "github.com/determined-ai/determined/master/internal/workspace" + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/proto/pkg/jobv1" ) @@ -53,7 +56,30 @@ func (e *internalExperiment) SetJobPriority(priority int) error { if priority < 1 || priority > 99 { return fmt.Errorf("priority must be between 1 and 99") } - err := e.setPriority(&priority, true) + + workspaceModel, err := workspace.WorkspaceByProjectID(context.TODO(), e.ProjectID) + if err != nil && errors.Cause(err) != sql.ErrNoRows { + return err + } + wkspID := resolveWorkspaceID(workspaceModel) + + // Returns an error if RM does not implement priority. + if smallerHigher, err := e.rm.SmallerValueIsHigherPriority(); err == nil { + ok, err := configpolicy.PriorityAllowed( + wkspID, + model.ExperimentType, + priority, + smallerHigher, + ) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("priority exceeds task config policy's priority_limit") + } + } + + err = e.setPriority(&priority, true) if err != nil { e.syslog.WithError(err).Info("setting experiment job priority") } diff --git a/master/internal/rm/agentrm/agent_resource_manager.go b/master/internal/rm/agentrm/agent_resource_manager.go index c2e3cd14dc7..2482e320297 100644 --- a/master/internal/rm/agentrm/agent_resource_manager.go +++ b/master/internal/rm/agentrm/agent_resource_manager.go @@ -872,3 +872,8 @@ func (a *ResourceManager) stop() { pool.stop() } } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (a *ResourceManager) SmallerValueIsHigherPriority() (bool, error) { + return true, nil +} diff --git a/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go b/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go index fe82a0aaeb2..0777552369f 100644 --- a/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go +++ b/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go @@ -2257,3 +2257,8 @@ func (m *DispatcherResourceManager) DisableSlot(*apiv1.DisableSlotRequest, ) (resp *apiv1.DisableSlotResponse, err error) { return nil, errNotSupportedOnHpcCluster } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (m *DispatcherResourceManager) SmallerValueIsHigherPriority() (bool, error) { + return false, fmt.Errorf("priority not implemented") +} diff --git a/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go b/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go index 9ec6065dd66..8eadb1291b2 100644 --- a/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go +++ b/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go @@ -726,3 +726,8 @@ func (k ResourceManager) DisableSlot( ) (resp *apiv1.DisableSlotResponse, err error) { return nil, rmerrors.ErrNotSupported } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (k *ResourceManager) SmallerValueIsHigherPriority() (bool, error) { + return false, nil +} diff --git a/master/internal/rm/multirm/multirm.go b/master/internal/rm/multirm/multirm.go index be782a09887..ae20726064b 100644 --- a/master/internal/rm/multirm/multirm.go +++ b/master/internal/rm/multirm/multirm.go @@ -547,3 +547,21 @@ func (m *MultiRMRouter) fanOutRMCommand(f func(rm.ResourceManager) error) error } return nil } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (m *MultiRMRouter) SmallerValueIsHigherPriority() (bool, error) { + set := false + var smallerIsHigher bool + for _, rm := range m.rms { + s, err := rm.SmallerValueIsHigherPriority() + if err != nil { + return false, err + } + if set && s != smallerIsHigher { + return false, fmt.Errorf("multiRM resource managers use different priority ordering") + } + smallerIsHigher = s + set = true + } + return smallerIsHigher, nil +} diff --git a/master/internal/rm/multirm/multirm_intg_test.go b/master/internal/rm/multirm/multirm_intg_test.go index 918b32a955e..c313402eae1 100644 --- a/master/internal/rm/multirm/multirm_intg_test.go +++ b/master/internal/rm/multirm/multirm_intg_test.go @@ -777,6 +777,50 @@ func TestGetRM(t *testing.T) { } } +func TestSmallerValueIsHigherPriority(t *testing.T) { + t.Run("both RMs same type", func(t *testing.T) { + defaultRM, otherRM, m := mockMultiRM() + defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + smallerIsHigher, err := m.SmallerValueIsHigherPriority() + require.NoError(t, err) + require.True(t, smallerIsHigher) + }) + + t.Run("RMs different type", func(t *testing.T) { + // This is not a supported mode in Determined but we still need to test it. + defaultRM, otherRM, m := mockMultiRM() + defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + _, err := m.SmallerValueIsHigherPriority() + require.Error(t, err) + }) + + t.Run("RMs with error", func(t *testing.T) { + // This is not a supported mode in Determined but we still need to test it. + defaultRM, otherRM, m := mockMultiRM() + defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, fmt.Errorf("error")) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) + _, err := m.SmallerValueIsHigherPriority() + require.Error(t, err) + }) +} + +func mockMultiRM() (*mocks.ResourceManager, *mocks.ResourceManager, *MultiRMRouter) { + defaultRM := &mocks.ResourceManager{} + otherRM := &mocks.ResourceManager{} + + m := &MultiRMRouter{ + defaultClusterName: defaultClusterName, + rms: map[string]rm.ResourceManager{ + defaultClusterName: defaultRM, + "rm1": otherRM, + }, + } + + return defaultRM, otherRM, m +} + func mockRM(poolName rm.ResourcePoolName) *mocks.ResourceManager { mockRM := mocks.ResourceManager{} mockRM.On("GetResourcePools").Return(&apiv1.GetResourcePoolsResponse{ @@ -810,5 +854,6 @@ func mockRM(poolName rm.ResourcePoolName) *mocks.ResourceManager { mockRM.On("DisableSlot", mock.Anything).Return(&apiv1.DisableSlotResponse{}, nil) mockRM.On("DefaultNamespace", mock.Anything).Return("default", nil) mockRM.On("VerifyNamespaceExists", mock.Anything).Return(nil) + return &mockRM } diff --git a/master/internal/rm/resource_manager_iface.go b/master/internal/rm/resource_manager_iface.go index 5bcb8bbd618..6b487e2c5bc 100644 --- a/master/internal/rm/resource_manager_iface.go +++ b/master/internal/rm/resource_manager_iface.go @@ -26,6 +26,7 @@ type ResourceManager interface { SetGroupPriority(sproto.SetGroupPriority) error ExternalPreemptionPending(sproto.PendingPreemption) error IsReattachableOnlyAfterStarted() bool + SmallerValueIsHigherPriority() (bool, error) // Resource pool stuff. GetResourcePools() (*apiv1.GetResourcePoolsResponse, error)