From 76424f010350929917d7ee1aac601a4a3a220c3b Mon Sep 17 00:00:00 2001 From: Kristine Kunapuli Date: Wed, 18 Sep 2024 16:13:13 -0400 Subject: [PATCH 1/5] chore: check task config policy priority limit for [CM-490] --- .../internal/command/command_job_service.go | 19 +++ .../postgres_task_config_policy.go | 39 +++++++ .../postgres_task_config_policy_intg_test.go | 109 ++++++++++++++++++ .../configpolicy/task_config_policy.go | 39 +++++++ .../task_config_policy_intg_test.go | 84 ++++++++++++++ .../configpolicy/task_config_policy_test.go | 31 +++++ master/internal/experiment_job_service.go | 28 ++++- .../rm/agentrm/agent_resource_manager.go | 5 + .../dispatcher_resource_manager.go | 5 + .../kubernetes_resource_manager.go | 5 + master/internal/rm/multirm/multirm.go | 5 + master/internal/rm/resource_manager_iface.go | 1 + 12 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 master/internal/configpolicy/task_config_policy_intg_test.go create mode 100644 master/internal/configpolicy/task_config_policy_test.go diff --git a/master/internal/command/command_job_service.go b/master/internal/command/command_job_service.go index 92ad1850075..ffbfd0c9d63 100644 --- a/master/internal/command/command_job_service.go +++ b/master/internal/command/command_job_service.go @@ -6,8 +6,10 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/determined-ai/determined/master/internal/config" + "github.com/determined-ai/determined/master/internal/configpolicy" "github.com/determined-ai/determined/master/internal/rm/rmerrors" "github.com/determined-ai/determined/master/internal/sproto" + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/proto/pkg/jobv1" ) @@ -47,6 +49,23 @@ func (c *Command) SetJobPriority(priority int) error { if priority < 1 || priority > 99 { return fmt.Errorf("priority must be between 1 and 99") } + + // Returns an error if RM does not implement priority. + if smallerHigher, err := c.rm.SmallerValueIsHigherPriority(); err == nil { + ok, err := configpolicy.PriorityAllowed( + int(c.GenericCommandSpec.Metadata.WorkspaceID), + model.NTSCType, + priority, + smallerHigher, + ) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("priority exceeds task config policy's priority_limit") + } + } + err := c.setNTSCPriority(priority, true) if err != nil { c.syslog.WithError(err).Info("setting command job priority") diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index b160ef5d477..aecb338ae83 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -2,6 +2,8 @@ package configpolicy import ( "context" + "database/sql" + "encoding/json" "fmt" "strings" @@ -88,6 +90,43 @@ func GetTaskConfigPolicies(ctx context.Context, return &tcp, nil } +// GetPriorityLimit reads the priority limit for the given scope and workload type. +// It returns found=false if no limit exists. +func GetPriorityLimit(ctx context.Context, scope *int, workloadType string) (limit int, found bool, err error) { + if !ValidWorkloadType(workloadType) { + return 0, false, fmt.Errorf("invalid workload type: %s", workloadType) + } + + wkspQuery := wkspIDQuery + if scope == nil { + wkspQuery = wkspIDGlobalQuery + } + + var constraints model.Constraints + var constraintsStr string + err = db.Bun().NewSelect(). + Table("task_config_policies"). + Column("constraints"). + Where(wkspQuery, scope). + Where("workload_type = ?", workloadType). + Scan(ctx, &constraintsStr) + + if err == sql.ErrNoRows { + return 0, false, nil + } else if err != nil { + return 0, false, fmt.Errorf("error retrieving priority limit: %w", err) + } + + if err = json.Unmarshal([]byte(constraintsStr), &constraints); err != nil { + return 0, false, err + } + if constraints.PriorityLimit != nil { + return *constraints.PriorityLimit, true, nil + } + + return 0, false, nil +} + // DeleteConfigPolicies deletes the invariant experiment config and constraints for the // given scope (global or workspace-level) and workload type. func DeleteConfigPolicies(ctx context.Context, diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index a7dcd69a0b2..2c4121908f3 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -6,6 +6,7 @@ package configpolicy import ( "context" "encoding/json" + "fmt" "testing" "time" @@ -206,6 +207,114 @@ func TestSetTaskConfigPolicies(t *testing.T) { require.ErrorContains(t, err, "violates foreign key constraint") } +func TestWorkspaceGetPriorityLimit(t *testing.T) { + ctx := context.Background() + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + user := db.RequireMockUser(t, pgDB) + + // add a workspace to use + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + defer func() { + err := db.CleanupMockWorkspace([]int32{int32(w.ID)}) + if err != nil { + log.Errorf("error when cleaning up mock workspaces") + } + }() + + // add priority limit for workspace NTSC + wkspLimit := 20 + constraints := fmt.Sprintf(`{"priority_limit": %d}`, wkspLimit) + wkspInput := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: &w.ID, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + // err = SetNTSCConfigPolicies(ctx, &wkspInput) + err = SetTaskConfigPolicies(ctx, &wkspInput) + require.NoError(t, err) + + // get priority limit; should match workspace + res, found, err := GetPriorityLimit(ctx, &w.ID, model.NTSCType) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, wkspLimit, res) + + // get limit for workspace that does not exist + wkspIDDoesNotExist := 404 + _, found, err = GetPriorityLimit(ctx, &wkspIDDoesNotExist, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // read global + _, found, err = GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // read experiment + _, found, err = GetPriorityLimit(ctx, &w.ID, model.ExperimentType) + require.NoError(t, err) + require.False(t, found) + + // read invalid workload type + _, found, err = GetPriorityLimit(ctx, &w.ID, "bogus") + require.Error(t, err) + require.False(t, found) +} + +func TestGlobalGetPriorityLimit(t *testing.T) { + ctx := context.Background() + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + user := db.RequireMockUser(t, pgDB) + + // add a workspace to use + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + defer func() { + err := db.CleanupMockWorkspace([]int32{int32(w.ID)}) + if err != nil { + log.Errorf("error when cleaning up mock workspaces") + } + }() + + // add priority limit for global NTSC + globalLimit := 5 + constraints := fmt.Sprintf(`{"priority_limit": %d}`, globalLimit) + globalInput := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: nil, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + err = SetTaskConfigPolicies(ctx, &globalInput) + require.NoError(t, err) + + // get priority limit + res, found, err := GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, globalLimit, res) + + // read experiment + _, found, err = GetPriorityLimit(ctx, nil, model.ExperimentType) + require.NoError(t, err) + require.False(t, found) + + // read invalid workload type + _, found, err = GetPriorityLimit(ctx, nil, "bogus") + require.Error(t, err) + require.False(t, found) +} + // Test the enforcement of the primary key on the task_config_polciies table. func TestTaskConfigPoliciesUnique(t *testing.T) { ctx := context.Background() diff --git a/master/internal/configpolicy/task_config_policy.go b/master/internal/configpolicy/task_config_policy.go index ca2f07926e5..2848b1dec11 100644 --- a/master/internal/configpolicy/task_config_policy.go +++ b/master/internal/configpolicy/task_config_policy.go @@ -1,6 +1,9 @@ package configpolicy import ( + "context" + "fmt" + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) @@ -24,3 +27,39 @@ type NTSCConfigPolicies struct { InvariantConfig *model.CommandConfig `json:"invariant_config"` Constraints *model.Constraints `json:"constraints"` } + +// PriorityAllowed returns true if the desired priority is within the task config policy limit. +func PriorityAllowed(wkspID int, workloadType string, priority int, smallerHigher bool) (bool, error) { + // Check if a priority limit has been set in task config policies. + // Global policies have highest precedence. + limit, found, err := GetPriorityLimit(context.TODO(), nil, workloadType) + if err != nil { + // TODO do we really want to block on this? + return false, fmt.Errorf("unable to fetch task config policy priority limit") + } + if found { + return priorityWithinLimit(priority, limit, smallerHigher), nil + } + + // TODO use COALESCE instead once postgres updates are complete. + // Workspace policies have second precedence. + limit, found, err = GetPriorityLimit(context.TODO(), &wkspID, workloadType) + if err != nil { + // TODO do we really want to block on this? + return false, fmt.Errorf("unable to fetch task config policy priority limit") + } + if found { + return priorityWithinLimit(priority, limit, smallerHigher), nil + } + + // No priority limit has been set. + return true, nil +} + +func priorityWithinLimit(userPriority int, adminLimit int, smallerHigher bool) bool { + if smallerHigher { + return userPriority >= adminLimit + } + + return userPriority <= adminLimit +} diff --git a/master/internal/configpolicy/task_config_policy_intg_test.go b/master/internal/configpolicy/task_config_policy_intg_test.go new file mode 100644 index 00000000000..a0a806d529f --- /dev/null +++ b/master/internal/configpolicy/task_config_policy_intg_test.go @@ -0,0 +1,84 @@ +package configpolicy + +import ( + "context" + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/etc" + "github.com/determined-ai/determined/master/pkg/model" +) + +func TestPriorityAllowed(t *testing.T) { + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + + // When no constraints are present, any priority is allowed. + ok, err := PriorityAllowed(1, model.NTSCType, 0, true) + require.NoError(t, err) + require.True(t, ok) + + wkspLimit := 50 + w := addWorkspacePriorityLimit(t, pgDB, wkspLimit) + + // Priority is outside workspace limit. + ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, true) + require.NoError(t, err) + require.False(t, ok) + + globalLimit := 42 + addGlobalPriorityLimit(t, pgDB, globalLimit) + + // Priority is within global limit. + ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, true) + require.NoError(t, err) + require.True(t, ok) + + // Priority is outside global limit. + ok, err = PriorityAllowed(w.ID+1, model.NTSCType, globalLimit-1, true) + require.NoError(t, err) + require.False(t, ok) +} + +func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, limit int) model.Workspace { + ctx := context.Background() + user := db.RequireMockUser(t, pgDB) + + // add a workspace to use + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + + constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit) + input := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: &w.ID, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + err = SetTaskConfigPolicies(ctx, &input) + require.NoError(t, err) + + return w +} + +func addGlobalPriorityLimit(t *testing.T, pgDB *db.PgDB, limit int) { + ctx := context.Background() + user := db.RequireMockUser(t, pgDB) + + constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit) + input := model.TaskConfigPolicies{ + WorkloadType: model.NTSCType, + WorkspaceID: nil, + Constraints: &constraints, + LastUpdatedBy: user.ID, + } + err := SetTaskConfigPolicies(ctx, &input) + require.NoError(t, err) +} diff --git a/master/internal/configpolicy/task_config_policy_test.go b/master/internal/configpolicy/task_config_policy_test.go new file mode 100644 index 00000000000..11ac52b1728 --- /dev/null +++ b/master/internal/configpolicy/task_config_policy_test.go @@ -0,0 +1,31 @@ +package configpolicy + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPriorityWithinLimit(t *testing.T) { + testCases := []struct { + name string + userPriority int + adminLimit int + smallerIsHigher bool + ok bool + }{ + {"smaller is higher - ok", 10, 1, true, true}, + {"smaller is higher - not ok", 10, 20, true, false}, + {"smaller is higher - equal", 20, 20, true, true}, + {"smaller is lower - ok", 11, 13, false, true}, + {"smaller is lower - not ok", 13, 11, false, false}, + {"smaller is lower - equal", 11, 11, false, true}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + ok := priorityWithinLimit(tt.userPriority, tt.adminLimit, tt.smallerIsHigher) + require.Equal(t, tt.ok, ok) + }) + } +} diff --git a/master/internal/experiment_job_service.go b/master/internal/experiment_job_service.go index 1749ffe1ed1..109590edb5e 100644 --- a/master/internal/experiment_job_service.go +++ b/master/internal/experiment_job_service.go @@ -2,6 +2,7 @@ package internal import ( "context" + "database/sql" "fmt" "strconv" @@ -9,7 +10,9 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/determined-ai/determined/master/internal/config" + "github.com/determined-ai/determined/master/internal/configpolicy" "github.com/determined-ai/determined/master/internal/workspace" + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/proto/pkg/jobv1" ) @@ -53,7 +56,30 @@ func (e *internalExperiment) SetJobPriority(priority int) error { if priority < 1 || priority > 99 { return fmt.Errorf("priority must be between 1 and 99") } - err := e.setPriority(&priority, true) + + workspaceModel, err := workspace.WorkspaceByProjectID(context.TODO(), e.ProjectID) + if err != nil && errors.Cause(err) != sql.ErrNoRows { + return err + } + wkspID := resolveWorkspaceID(workspaceModel) + + // Returns an error if RM does not implement priority. + if smallerHigher, err := e.rm.SmallerValueIsHigherPriority(); err == nil { + ok, err := configpolicy.PriorityAllowed( + wkspID, + model.ExperimentType, + priority, + smallerHigher, + ) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("priority exceeds task config policy's priority_limit") + } + } + + err = e.setPriority(&priority, true) if err != nil { e.syslog.WithError(err).Info("setting experiment job priority") } diff --git a/master/internal/rm/agentrm/agent_resource_manager.go b/master/internal/rm/agentrm/agent_resource_manager.go index c2e3cd14dc7..2482e320297 100644 --- a/master/internal/rm/agentrm/agent_resource_manager.go +++ b/master/internal/rm/agentrm/agent_resource_manager.go @@ -872,3 +872,8 @@ func (a *ResourceManager) stop() { pool.stop() } } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (a *ResourceManager) SmallerValueIsHigherPriority() (bool, error) { + return true, nil +} diff --git a/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go b/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go index fe82a0aaeb2..0777552369f 100644 --- a/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go +++ b/master/internal/rm/dispatcherrm/dispatcher_resource_manager.go @@ -2257,3 +2257,8 @@ func (m *DispatcherResourceManager) DisableSlot(*apiv1.DisableSlotRequest, ) (resp *apiv1.DisableSlotResponse, err error) { return nil, errNotSupportedOnHpcCluster } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (m *DispatcherResourceManager) SmallerValueIsHigherPriority() (bool, error) { + return false, fmt.Errorf("priority not implemented") +} diff --git a/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go b/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go index 9ec6065dd66..8eadb1291b2 100644 --- a/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go +++ b/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go @@ -726,3 +726,8 @@ func (k ResourceManager) DisableSlot( ) (resp *apiv1.DisableSlotResponse, err error) { return nil, rmerrors.ErrNotSupported } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (k *ResourceManager) SmallerValueIsHigherPriority() (bool, error) { + return false, nil +} diff --git a/master/internal/rm/multirm/multirm.go b/master/internal/rm/multirm/multirm.go index be782a09887..b096898682a 100644 --- a/master/internal/rm/multirm/multirm.go +++ b/master/internal/rm/multirm/multirm.go @@ -547,3 +547,8 @@ func (m *MultiRMRouter) fanOutRMCommand(f func(rm.ResourceManager) error) error } return nil } + +// SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. +func (m *MultiRMRouter) SmallerValueIsHigherPriority() (bool, error) { + return false, nil +} diff --git a/master/internal/rm/resource_manager_iface.go b/master/internal/rm/resource_manager_iface.go index 5bcb8bbd618..6b487e2c5bc 100644 --- a/master/internal/rm/resource_manager_iface.go +++ b/master/internal/rm/resource_manager_iface.go @@ -26,6 +26,7 @@ type ResourceManager interface { SetGroupPriority(sproto.SetGroupPriority) error ExternalPreemptionPending(sproto.PendingPreemption) error IsReattachableOnlyAfterStarted() bool + SmallerValueIsHigherPriority() (bool, error) // Resource pool stuff. GetResourcePools() (*apiv1.GetResourcePoolsResponse, error) From dd453037457437905b3caf48d9d63b5f88ad822a Mon Sep 17 00:00:00 2001 From: Kristine Kunapuli Date: Wed, 18 Sep 2024 16:51:59 -0400 Subject: [PATCH 2/5] refine tests --- .../postgres_task_config_policy_intg_test.go | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index 2c4121908f3..119436e5e4e 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -215,7 +215,7 @@ func TestWorkspaceGetPriorityLimit(t *testing.T) { db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) user := db.RequireMockUser(t, pgDB) - // add a workspace to use + // Add a workspace to use. w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) require.NoError(t, err) @@ -226,7 +226,12 @@ func TestWorkspaceGetPriorityLimit(t *testing.T) { } }() - // add priority limit for workspace NTSC + // No limit set. + _, found, err := GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // Add priority limit for workspace NTSC. wkspLimit := 20 constraints := fmt.Sprintf(`{"priority_limit": %d}`, wkspLimit) wkspInput := model.TaskConfigPolicies{ @@ -235,33 +240,33 @@ func TestWorkspaceGetPriorityLimit(t *testing.T) { Constraints: &constraints, LastUpdatedBy: user.ID, } - // err = SetNTSCConfigPolicies(ctx, &wkspInput) + err = SetTaskConfigPolicies(ctx, &wkspInput) require.NoError(t, err) - // get priority limit; should match workspace + // Get priority limit; should match workspace limit. res, found, err := GetPriorityLimit(ctx, &w.ID, model.NTSCType) require.NoError(t, err) require.True(t, found) require.Equal(t, wkspLimit, res) - // get limit for workspace that does not exist + // Get limit for a workspace that does not exist. wkspIDDoesNotExist := 404 _, found, err = GetPriorityLimit(ctx, &wkspIDDoesNotExist, model.NTSCType) require.NoError(t, err) require.False(t, found) - // read global + // Get global limit. _, found, err = GetPriorityLimit(ctx, nil, model.NTSCType) require.NoError(t, err) require.False(t, found) - // read experiment + // Get limit for other workload type. _, found, err = GetPriorityLimit(ctx, &w.ID, model.ExperimentType) require.NoError(t, err) require.False(t, found) - // read invalid workload type + // Try an invalid workload type. _, found, err = GetPriorityLimit(ctx, &w.ID, "bogus") require.Error(t, err) require.False(t, found) @@ -275,7 +280,7 @@ func TestGlobalGetPriorityLimit(t *testing.T) { db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) user := db.RequireMockUser(t, pgDB) - // add a workspace to use + // Add a workspace to use. w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) require.NoError(t, err) @@ -286,7 +291,12 @@ func TestGlobalGetPriorityLimit(t *testing.T) { } }() - // add priority limit for global NTSC + // No limit set. + _, found, err := GetPriorityLimit(ctx, nil, model.NTSCType) + require.NoError(t, err) + require.False(t, found) + + // Add priority limit for global NTSC. globalLimit := 5 constraints := fmt.Sprintf(`{"priority_limit": %d}`, globalLimit) globalInput := model.TaskConfigPolicies{ @@ -298,18 +308,18 @@ func TestGlobalGetPriorityLimit(t *testing.T) { err = SetTaskConfigPolicies(ctx, &globalInput) require.NoError(t, err) - // get priority limit + // Get priority limit, should be global limit. res, found, err := GetPriorityLimit(ctx, nil, model.NTSCType) require.NoError(t, err) require.True(t, found) require.Equal(t, globalLimit, res) - // read experiment + // Get limit for a different workload type. _, found, err = GetPriorityLimit(ctx, nil, model.ExperimentType) require.NoError(t, err) require.False(t, found) - // read invalid workload type + // Try an invalid workload type. _, found, err = GetPriorityLimit(ctx, nil, "bogus") require.Error(t, err) require.False(t, found) From e733c268450e13f6f79cdb8ddd068e5cf8b044b7 Mon Sep 17 00:00:00 2001 From: Kristine Kunapuli Date: Tue, 24 Sep 2024 10:31:22 -0400 Subject: [PATCH 3/5] incorporate feedback; implement multi-RM call --- .../configpolicy/task_config_policy.go | 5 +-- .../task_config_policy_intg_test.go | 14 +++---- master/internal/rm/multirm/multirm.go | 15 ++++++- .../internal/rm/multirm/multirm_intg_test.go | 39 +++++++++++++++++++ 4 files changed, 62 insertions(+), 11 deletions(-) diff --git a/master/internal/configpolicy/task_config_policy.go b/master/internal/configpolicy/task_config_policy.go index 2848b1dec11..a6d25794539 100644 --- a/master/internal/configpolicy/task_config_policy.go +++ b/master/internal/configpolicy/task_config_policy.go @@ -28,13 +28,12 @@ type NTSCConfigPolicies struct { Constraints *model.Constraints `json:"constraints"` } -// PriorityAllowed returns true if the desired priority is within the task config policy limit. +// PriorityAllowed returns true if the desired priority is within the limit set by task config policies. func PriorityAllowed(wkspID int, workloadType string, priority int, smallerHigher bool) (bool, error) { - // Check if a priority limit has been set in task config policies. + // Check if a priority limit has been set with a constraint policy. // Global policies have highest precedence. limit, found, err := GetPriorityLimit(context.TODO(), nil, workloadType) if err != nil { - // TODO do we really want to block on this? return false, fmt.Errorf("unable to fetch task config policy priority limit") } if found { diff --git a/master/internal/configpolicy/task_config_policy_intg_test.go b/master/internal/configpolicy/task_config_policy_intg_test.go index a0a806d529f..66af9ca612b 100644 --- a/master/internal/configpolicy/task_config_policy_intg_test.go +++ b/master/internal/configpolicy/task_config_policy_intg_test.go @@ -25,15 +25,17 @@ func TestPriorityAllowed(t *testing.T) { require.True(t, ok) wkspLimit := 50 - w := addWorkspacePriorityLimit(t, pgDB, wkspLimit) + user := db.RequireMockUser(t, pgDB) + w := addWorkspacePriorityLimit(t, pgDB, user, wkspLimit) // Priority is outside workspace limit. - ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, true) + smallerValueIsHigherPriority := true + ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, smallerValueIsHigherPriority) require.NoError(t, err) require.False(t, ok) globalLimit := 42 - addGlobalPriorityLimit(t, pgDB, globalLimit) + addGlobalPriorityLimit(t, pgDB, user, globalLimit) // Priority is within global limit. ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, true) @@ -46,9 +48,8 @@ func TestPriorityAllowed(t *testing.T) { require.False(t, ok) } -func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, limit int) model.Workspace { +func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) model.Workspace { ctx := context.Background() - user := db.RequireMockUser(t, pgDB) // add a workspace to use w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} @@ -68,9 +69,8 @@ func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, limit int) model.Wor return w } -func addGlobalPriorityLimit(t *testing.T, pgDB *db.PgDB, limit int) { +func addGlobalPriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) { ctx := context.Background() - user := db.RequireMockUser(t, pgDB) constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit) input := model.TaskConfigPolicies{ diff --git a/master/internal/rm/multirm/multirm.go b/master/internal/rm/multirm/multirm.go index b096898682a..ae20726064b 100644 --- a/master/internal/rm/multirm/multirm.go +++ b/master/internal/rm/multirm/multirm.go @@ -550,5 +550,18 @@ func (m *MultiRMRouter) fanOutRMCommand(f func(rm.ResourceManager) error) error // SmallerValueIsHigherPriority returns true if smaller priority values indicate a higher priority level. func (m *MultiRMRouter) SmallerValueIsHigherPriority() (bool, error) { - return false, nil + set := false + var smallerIsHigher bool + for _, rm := range m.rms { + s, err := rm.SmallerValueIsHigherPriority() + if err != nil { + return false, err + } + if set && s != smallerIsHigher { + return false, fmt.Errorf("multiRM resource managers use different priority ordering") + } + smallerIsHigher = s + set = true + } + return smallerIsHigher, nil } diff --git a/master/internal/rm/multirm/multirm_intg_test.go b/master/internal/rm/multirm/multirm_intg_test.go index 918b32a955e..722ccd87e11 100644 --- a/master/internal/rm/multirm/multirm_intg_test.go +++ b/master/internal/rm/multirm/multirm_intg_test.go @@ -777,6 +777,43 @@ func TestGetRM(t *testing.T) { } } +func TestSmallerValueIsHigherPriority(t *testing.T) { + defaultRM := &mocks.ResourceManager{} + otherRM := &mocks.ResourceManager{} + + m := &MultiRMRouter{ + defaultClusterName: defaultClusterName, + rms: map[string]rm.ResourceManager{ + defaultClusterName: defaultRM, + "rm1": otherRM, + }, + } + + t.Run("both RMs same type", func(t *testing.T) { + defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) + smallerIsHigher, err := m.SmallerValueIsHigherPriority() + require.NoError(t, err) + require.False(t, smallerIsHigher) + }) + + t.Run("RMs different type", func(t *testing.T) { + // This is not a supported mode in Determined but we still need to test it. + defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + _, err := m.SmallerValueIsHigherPriority() + require.Error(t, err) + }) + + t.Run("RMs with error", func(t *testing.T) { + // This is not a supported mode in Determined but we still need to test it. + defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, fmt.Errorf("error")) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + _, err := m.SmallerValueIsHigherPriority() + require.Error(t, err) + }) +} + func mockRM(poolName rm.ResourcePoolName) *mocks.ResourceManager { mockRM := mocks.ResourceManager{} mockRM.On("GetResourcePools").Return(&apiv1.GetResourcePoolsResponse{ @@ -810,5 +847,7 @@ func mockRM(poolName rm.ResourcePoolName) *mocks.ResourceManager { mockRM.On("DisableSlot", mock.Anything).Return(&apiv1.DisableSlotResponse{}, nil) mockRM.On("DefaultNamespace", mock.Anything).Return("default", nil) mockRM.On("VerifyNamespaceExists", mock.Anything).Return(nil) + + mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) return &mockRM } From 1596e797ef609bccdf3b689f23b03028e17338d7 Mon Sep 17 00:00:00 2001 From: Kristine Kunapuli Date: Tue, 24 Sep 2024 10:32:18 -0400 Subject: [PATCH 4/5] better test cases --- master/internal/rm/multirm/multirm_intg_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/master/internal/rm/multirm/multirm_intg_test.go b/master/internal/rm/multirm/multirm_intg_test.go index 722ccd87e11..e8572a4682d 100644 --- a/master/internal/rm/multirm/multirm_intg_test.go +++ b/master/internal/rm/multirm/multirm_intg_test.go @@ -790,11 +790,11 @@ func TestSmallerValueIsHigherPriority(t *testing.T) { } t.Run("both RMs same type", func(t *testing.T) { - defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) - otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) + defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) smallerIsHigher, err := m.SmallerValueIsHigherPriority() require.NoError(t, err) - require.False(t, smallerIsHigher) + require.True(t, smallerIsHigher) }) t.Run("RMs different type", func(t *testing.T) { @@ -808,7 +808,7 @@ func TestSmallerValueIsHigherPriority(t *testing.T) { t.Run("RMs with error", func(t *testing.T) { // This is not a supported mode in Determined but we still need to test it. defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, fmt.Errorf("error")) - otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) + otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) _, err := m.SmallerValueIsHigherPriority() require.Error(t, err) }) From 85f45ce714179d25a63a022b53c1e81104073044 Mon Sep 17 00:00:00 2001 From: Kristine Kunapuli Date: Tue, 24 Sep 2024 13:32:47 -0400 Subject: [PATCH 5/5] tests play nicely in parallel --- .../internal/rm/multirm/multirm_intg_test.go | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/master/internal/rm/multirm/multirm_intg_test.go b/master/internal/rm/multirm/multirm_intg_test.go index e8572a4682d..c313402eae1 100644 --- a/master/internal/rm/multirm/multirm_intg_test.go +++ b/master/internal/rm/multirm/multirm_intg_test.go @@ -778,18 +778,8 @@ func TestGetRM(t *testing.T) { } func TestSmallerValueIsHigherPriority(t *testing.T) { - defaultRM := &mocks.ResourceManager{} - otherRM := &mocks.ResourceManager{} - - m := &MultiRMRouter{ - defaultClusterName: defaultClusterName, - rms: map[string]rm.ResourceManager{ - defaultClusterName: defaultRM, - "rm1": otherRM, - }, - } - t.Run("both RMs same type", func(t *testing.T) { + defaultRM, otherRM, m := mockMultiRM() defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) smallerIsHigher, err := m.SmallerValueIsHigherPriority() @@ -799,6 +789,7 @@ func TestSmallerValueIsHigherPriority(t *testing.T) { t.Run("RMs different type", func(t *testing.T) { // This is not a supported mode in Determined but we still need to test it. + defaultRM, otherRM, m := mockMultiRM() defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil) _, err := m.SmallerValueIsHigherPriority() @@ -807,6 +798,7 @@ func TestSmallerValueIsHigherPriority(t *testing.T) { t.Run("RMs with error", func(t *testing.T) { // This is not a supported mode in Determined but we still need to test it. + defaultRM, otherRM, m := mockMultiRM() defaultRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, fmt.Errorf("error")) otherRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) _, err := m.SmallerValueIsHigherPriority() @@ -814,6 +806,21 @@ func TestSmallerValueIsHigherPriority(t *testing.T) { }) } +func mockMultiRM() (*mocks.ResourceManager, *mocks.ResourceManager, *MultiRMRouter) { + defaultRM := &mocks.ResourceManager{} + otherRM := &mocks.ResourceManager{} + + m := &MultiRMRouter{ + defaultClusterName: defaultClusterName, + rms: map[string]rm.ResourceManager{ + defaultClusterName: defaultRM, + "rm1": otherRM, + }, + } + + return defaultRM, otherRM, m +} + func mockRM(poolName rm.ResourcePoolName) *mocks.ResourceManager { mockRM := mocks.ResourceManager{} mockRM.On("GetResourcePools").Return(&apiv1.GetResourcePoolsResponse{ @@ -848,6 +855,5 @@ func mockRM(poolName rm.ResourcePoolName) *mocks.ResourceManager { mockRM.On("DefaultNamespace", mock.Anything).Return("default", nil) mockRM.On("VerifyNamespaceExists", mock.Anything).Return(nil) - mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(false, nil) return &mockRM }