Skip to content

Commit

Permalink
chore: check task config policy priority limit for [CM-490] (#9958)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkunapuli committed Sep 26, 2024
1 parent 8bc08e5 commit ac8fbf6
Show file tree
Hide file tree
Showing 13 changed files with 435 additions and 1 deletion.
19 changes: 19 additions & 0 deletions master/internal/command/command_job_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/internal/rm/rmerrors"
"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/proto/pkg/jobv1"
)

Expand Down Expand Up @@ -47,6 +49,23 @@ func (c *Command) SetJobPriority(priority int) error {
if priority < 1 || priority > 99 {
return fmt.Errorf("priority must be between 1 and 99")
}

// Returns an error if RM does not implement priority.
if smallerHigher, err := c.rm.SmallerValueIsHigherPriority(); err == nil {
ok, err := configpolicy.PriorityAllowed(
int(c.GenericCommandSpec.Metadata.WorkspaceID),
model.NTSCType,
priority,
smallerHigher,
)
if err != nil {
return err
}
if !ok {
return fmt.Errorf("priority exceeds task config policy's priority_limit")
}
}

err := c.setNTSCPriority(priority, true)
if err != nil {
c.syslog.WithError(err).Info("setting command job priority")
Expand Down
38 changes: 38 additions & 0 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package configpolicy
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"

Expand Down Expand Up @@ -93,6 +94,43 @@ func GetTaskConfigPolicies(ctx context.Context,
return &tcp, nil
}

// GetPriorityLimit reads the priority limit for the given scope and workload type.
// It returns found=false if no limit exists.
func GetPriorityLimit(ctx context.Context, scope *int, workloadType string) (limit int, found bool, err error) {
if !ValidWorkloadType(workloadType) {
return 0, false, fmt.Errorf("invalid workload type: %s", workloadType)
}

wkspQuery := wkspIDQuery
if scope == nil {
wkspQuery = wkspIDGlobalQuery
}

var constraints model.Constraints
var constraintsStr string
err = db.Bun().NewSelect().
Table("task_config_policies").
Column("constraints").
Where(wkspQuery, scope).
Where("workload_type = ?", workloadType).
Scan(ctx, &constraintsStr)

if err == sql.ErrNoRows {
return 0, false, nil
} else if err != nil {
return 0, false, fmt.Errorf("error retrieving priority limit: %w", err)
}

if err = json.Unmarshal([]byte(constraintsStr), &constraints); err != nil {
return 0, false, err
}
if constraints.PriorityLimit != nil {
return *constraints.PriorityLimit, true, nil
}

return 0, false, nil
}

// DeleteConfigPolicies deletes the invariant experiment config and constraints for the
// given scope (global or workspace-level) and workload type.
func DeleteConfigPolicies(ctx context.Context,
Expand Down
119 changes: 119 additions & 0 deletions master/internal/configpolicy/postgres_task_config_policy_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package configpolicy
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -206,6 +207,124 @@ func TestSetTaskConfigPolicies(t *testing.T) {
require.ErrorContains(t, err, "violates foreign key constraint")
}

func TestWorkspaceGetPriorityLimit(t *testing.T) {
ctx := context.Background()
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
defer cleanup()
db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)
user := db.RequireMockUser(t, pgDB)

// Add a workspace to use.
w := model.Workspace{Name: uuid.NewString(), UserID: user.ID}
_, err := db.Bun().NewInsert().Model(&w).Exec(ctx)
require.NoError(t, err)
defer func() {
err := db.CleanupMockWorkspace([]int32{int32(w.ID)})
if err != nil {
log.Errorf("error when cleaning up mock workspaces")
}
}()

// No limit set.
_, found, err := GetPriorityLimit(ctx, nil, model.NTSCType)
require.NoError(t, err)
require.False(t, found)

// Add priority limit for workspace NTSC.
wkspLimit := 20
constraints := fmt.Sprintf(`{"priority_limit": %d}`, wkspLimit)
wkspInput := model.TaskConfigPolicies{
WorkloadType: model.NTSCType,
WorkspaceID: &w.ID,
Constraints: &constraints,
LastUpdatedBy: user.ID,
}

err = SetTaskConfigPolicies(ctx, &wkspInput)
require.NoError(t, err)

// Get priority limit; should match workspace limit.
res, found, err := GetPriorityLimit(ctx, &w.ID, model.NTSCType)
require.NoError(t, err)
require.True(t, found)
require.Equal(t, wkspLimit, res)

// Get limit for a workspace that does not exist.
wkspIDDoesNotExist := 404
_, found, err = GetPriorityLimit(ctx, &wkspIDDoesNotExist, model.NTSCType)
require.NoError(t, err)
require.False(t, found)

// Get global limit.
_, found, err = GetPriorityLimit(ctx, nil, model.NTSCType)
require.NoError(t, err)
require.False(t, found)

// Get limit for other workload type.
_, found, err = GetPriorityLimit(ctx, &w.ID, model.ExperimentType)
require.NoError(t, err)
require.False(t, found)

// Try an invalid workload type.
_, found, err = GetPriorityLimit(ctx, &w.ID, "bogus")
require.Error(t, err)
require.False(t, found)
}

func TestGlobalGetPriorityLimit(t *testing.T) {
ctx := context.Background()
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
defer cleanup()
db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)
user := db.RequireMockUser(t, pgDB)

// Add a workspace to use.
w := model.Workspace{Name: uuid.NewString(), UserID: user.ID}
_, err := db.Bun().NewInsert().Model(&w).Exec(ctx)
require.NoError(t, err)
defer func() {
err := db.CleanupMockWorkspace([]int32{int32(w.ID)})
if err != nil {
log.Errorf("error when cleaning up mock workspaces")
}
}()

// No limit set.
_, found, err := GetPriorityLimit(ctx, nil, model.NTSCType)
require.NoError(t, err)
require.False(t, found)

// Add priority limit for global NTSC.
globalLimit := 5
constraints := fmt.Sprintf(`{"priority_limit": %d}`, globalLimit)
globalInput := model.TaskConfigPolicies{
WorkloadType: model.NTSCType,
WorkspaceID: nil,
Constraints: &constraints,
LastUpdatedBy: user.ID,
}
err = SetTaskConfigPolicies(ctx, &globalInput)
require.NoError(t, err)

// Get priority limit, should be global limit.
res, found, err := GetPriorityLimit(ctx, nil, model.NTSCType)
require.NoError(t, err)
require.True(t, found)
require.Equal(t, globalLimit, res)

// Get limit for a different workload type.
_, found, err = GetPriorityLimit(ctx, nil, model.ExperimentType)
require.NoError(t, err)
require.False(t, found)

// Try an invalid workload type.
_, found, err = GetPriorityLimit(ctx, nil, "bogus")
require.Error(t, err)
require.False(t, found)
}

// Test the enforcement of the primary key on the task_config_polciies table.
func TestTaskConfigPoliciesUnique(t *testing.T) {
ctx := context.Background()
Expand Down
38 changes: 38 additions & 0 deletions master/internal/configpolicy/task_config_policy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package configpolicy

import (
"context"
"fmt"

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
)
Expand All @@ -24,3 +27,38 @@ type NTSCConfigPolicies struct {
InvariantConfig *model.CommandConfig `json:"invariant_config"`
Constraints *model.Constraints `json:"constraints"`
}

// PriorityAllowed returns true if the desired priority is within the limit set by task config policies.
func PriorityAllowed(wkspID int, workloadType string, priority int, smallerHigher bool) (bool, error) {
// Check if a priority limit has been set with a constraint policy.
// Global policies have highest precedence.
limit, found, err := GetPriorityLimit(context.TODO(), nil, workloadType)
if err != nil {
return false, fmt.Errorf("unable to fetch task config policy priority limit")
}
if found {
return priorityWithinLimit(priority, limit, smallerHigher), nil
}

// TODO use COALESCE instead once postgres updates are complete.
// Workspace policies have second precedence.
limit, found, err = GetPriorityLimit(context.TODO(), &wkspID, workloadType)
if err != nil {
// TODO do we really want to block on this?
return false, fmt.Errorf("unable to fetch task config policy priority limit")
}
if found {
return priorityWithinLimit(priority, limit, smallerHigher), nil
}

// No priority limit has been set.
return true, nil
}

func priorityWithinLimit(userPriority int, adminLimit int, smallerHigher bool) bool {
if smallerHigher {
return userPriority >= adminLimit
}

return userPriority <= adminLimit
}
84 changes: 84 additions & 0 deletions master/internal/configpolicy/task_config_policy_intg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package configpolicy

import (
"context"
"fmt"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/pkg/etc"
"github.com/determined-ai/determined/master/pkg/model"
)

func TestPriorityAllowed(t *testing.T) {
require.NoError(t, etc.SetRootPath(db.RootFromDB))
pgDB, cleanup := db.MustResolveNewPostgresDatabase(t)
defer cleanup()
db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB)

// When no constraints are present, any priority is allowed.
ok, err := PriorityAllowed(1, model.NTSCType, 0, true)
require.NoError(t, err)
require.True(t, ok)

wkspLimit := 50
user := db.RequireMockUser(t, pgDB)
w := addWorkspacePriorityLimit(t, pgDB, user, wkspLimit)

// Priority is outside workspace limit.
smallerValueIsHigherPriority := true
ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, smallerValueIsHigherPriority)
require.NoError(t, err)
require.False(t, ok)

globalLimit := 42
addGlobalPriorityLimit(t, pgDB, user, globalLimit)

// Priority is within global limit.
ok, err = PriorityAllowed(w.ID, model.NTSCType, wkspLimit-1, true)
require.NoError(t, err)
require.True(t, ok)

// Priority is outside global limit.
ok, err = PriorityAllowed(w.ID+1, model.NTSCType, globalLimit-1, true)
require.NoError(t, err)
require.False(t, ok)
}

func addWorkspacePriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) model.Workspace {
ctx := context.Background()

// add a workspace to use
w := model.Workspace{Name: uuid.NewString(), UserID: user.ID}
_, err := db.Bun().NewInsert().Model(&w).Exec(ctx)
require.NoError(t, err)

constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit)
input := model.TaskConfigPolicies{
WorkloadType: model.NTSCType,
WorkspaceID: &w.ID,
Constraints: &constraints,
LastUpdatedBy: user.ID,
}
err = SetTaskConfigPolicies(ctx, &input)
require.NoError(t, err)

return w
}

func addGlobalPriorityLimit(t *testing.T, pgDB *db.PgDB, user model.User, limit int) {
ctx := context.Background()

constraints := fmt.Sprintf(`{"priority_limit": %d}`, limit)
input := model.TaskConfigPolicies{
WorkloadType: model.NTSCType,
WorkspaceID: nil,
Constraints: &constraints,
LastUpdatedBy: user.ID,
}
err := SetTaskConfigPolicies(ctx, &input)
require.NoError(t, err)
}
31 changes: 31 additions & 0 deletions master/internal/configpolicy/task_config_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package configpolicy

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestPriorityWithinLimit(t *testing.T) {
testCases := []struct {
name string
userPriority int
adminLimit int
smallerIsHigher bool
ok bool
}{
{"smaller is higher - ok", 10, 1, true, true},
{"smaller is higher - not ok", 10, 20, true, false},
{"smaller is higher - equal", 20, 20, true, true},
{"smaller is lower - ok", 11, 13, false, true},
{"smaller is lower - not ok", 13, 11, false, false},
{"smaller is lower - equal", 11, 11, false, true},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
ok := priorityWithinLimit(tt.userPriority, tt.adminLimit, tt.smallerIsHigher)
require.Equal(t, tt.ok, ok)
})
}
}
Loading

0 comments on commit ac8fbf6

Please sign in to comment.