diff --git a/flyteadmin_config.yaml b/flyteadmin_config.yaml index 964f83a81..46d79a449 100644 --- a/flyteadmin_config.yaml +++ b/flyteadmin_config.yaml @@ -114,7 +114,7 @@ externalEvents: eventTypes: all Logger: show-source: true - level: 6 + level: 5 storage: type: stow stow: @@ -129,7 +129,7 @@ storage: secret_key: miniostorage signedUrl: stowConfigOverride: - endpoint: http://localhost:30084 + endpoint: http://localhost:30002 cache: max_size_mbs: 10 target_gc_percent: 100 @@ -162,16 +162,23 @@ queues: - critical - tags: - default -task_resources: - defaults: - cpu: 100m - memory: 200Mi - storage: 100M - limits: - cpu: 500m - gpu: 1 - memory: 300Mi - storage: 10G +#task_resources: +# defaults: +# cpu: 100m +# memory: 200Mi +# ephemeralStorage: 100M +# limits: +# cpu: 500m +# memory: 300Mi +# ephemeralStorage: 10G +#task_resources: +# defaults: +# cpu: 100m +#task_resources: +# defaults: +# ephemeralStorage: 500M +# limits: +# ephemeralStorage: 10G task_type_whitelist: sparkonk8s: - project: my_queue_1 diff --git a/go.mod b/go.mod index 6710d33bb..6d78ffdd1 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/cloudevents/sdk-go/v2 v2.8.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.12.0+incompatible - github.com/flyteorg/flyteidl v1.3.14 + github.com/flyteorg/flyteidl v1.5.0 github.com/flyteorg/flyteplugins v1.0.40 github.com/flyteorg/flytepropeller v1.1.70 github.com/flyteorg/flytestdlib v1.0.15 diff --git a/go.sum b/go.sum index f2fe7c70b..8e166ce64 100644 --- a/go.sum +++ b/go.sum @@ -312,8 +312,8 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8= -github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= +github.com/flyteorg/flyteidl v1.5.0 h1:vdaA5Cg9eqi5NMuASSod/AE7RXlHvzdWjSL9abDyd/M= +github.com/flyteorg/flyteidl v1.5.0/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I= github.com/flyteorg/flyteplugins v1.0.40 h1:RTsYingqmqr13qBbi4CB2ArXDHNHUOkAF+HTLJQiQ/s= github.com/flyteorg/flyteplugins v1.0.40/go.mod h1:qyUPqVspLcLGJpKxVwHDWf+kBpOGuItOxCaF6zAmDio= github.com/flyteorg/flytepropeller v1.1.70 h1:/d1qqz13rdVADM85ST70eerAdBstJJz9UUB/mNSZi0w= diff --git a/pkg/clusterresource/controller.go b/pkg/clusterresource/controller.go index 348d61c5e..c86ea4550 100644 --- a/pkg/clusterresource/controller.go +++ b/pkg/clusterresource/controller.go @@ -687,7 +687,7 @@ func NewClusterResourceControllerFromConfig(ctx context.Context, scope promutils repo := repositories.NewGormRepo( db, errors2.NewPostgresErrorTransformer(dbScope.NewSubScope("errors")), dbScope) - adminDataProvider = impl2.NewDatabaseAdminDataProvider(repo, configuration, resources.NewResourceManager(repo, configuration.ApplicationConfiguration())) + adminDataProvider = impl2.NewDatabaseAdminDataProvider(repo, configuration, resources.NewResourceManager(repo, configuration)) } return NewClusterResourceController(adminDataProvider, listTargetsProvider, scope), nil diff --git a/pkg/common/testutils/common.go b/pkg/common/testutils/common.go index 7df3db16f..8ac70974d 100644 --- a/pkg/common/testutils/common.go +++ b/pkg/common/testutils/common.go @@ -1,6 +1,9 @@ package testutils -import "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" +import ( + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "k8s.io/apimachinery/pkg/api/resource" +) // Convenience method to wrap verbose boilerplate for initializing a PluginOverrides MatchingAttributes. func GetPluginOverridesAttributes(vals map[string][]string) *admin.MatchingAttributes { @@ -19,3 +22,7 @@ func GetPluginOverridesAttributes(vals map[string][]string) *admin.MatchingAttri }, } } + +func GetPtr(quantity resource.Quantity) *resource.Quantity { + return &quantity +} diff --git a/pkg/executioncluster/impl/random_cluster_selector.go b/pkg/executioncluster/impl/random_cluster_selector.go index 10bdd7714..799afd4a7 100644 --- a/pkg/executioncluster/impl/random_cluster_selector.go +++ b/pkg/executioncluster/impl/random_cluster_selector.go @@ -172,7 +172,7 @@ func NewRandomClusterSelector(listTargets interfaces.ListTargetsInterface, confi } return &RandomClusterSelector{ labelWeightedRandomMap: labelWeightedRandomMap, - resourceManager: resources.NewResourceManager(db, config.ApplicationConfiguration()), + resourceManager: resources.NewResourceManager(db, config), equalWeightedAllClusters: equalWeightedAllClusters, ListTargetsInterface: listTargets, defaultExecutionLabel: defaultExecutionLabel, diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index 4f4689715..4fe2c4614 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -10,8 +10,6 @@ import ( "github.com/flyteorg/flyteadmin/plugins" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyteadmin/auth" "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" @@ -184,12 +182,49 @@ func (m *ExecutionManager) addPluginOverrides(ctx context.Context, executionID * return nil, nil } -// TODO: Delete this code usage after the flyte v0.17.0 release +// defaults should be a coalesce of task defaults, and platform defaults. +// task limits should be the limits from the task, coalesced with the defaults from step one +// then both should be limited by any platform limits. +// anything 0 or empty is not set. +// if both requests and limits end up empty, return nil. if one is empty, return nil for it +func (m *ExecutionManager) getResources(ctx context.Context, taskResources *core.Resources, platformResources workflowengineInterfaces.TaskResources) *core.Resources { + + // requests: coalesce(task request, platform default) + // limits: coalesce(task limits, task requests) + // check that defaults and limits are both below platform limit + var requestSet runtimeInterfaces.TaskResourceSet + var limitSet runtimeInterfaces.TaskResourceSet + if taskResources != nil && taskResources.GetRequests() != nil { + requestSet = util.GetTaskResourcesAndCoalesce(ctx, taskResources.GetRequests(), platformResources.Defaults) + } else { + requestSet = platformResources.Defaults + } + if taskResources != nil && taskResources.GetLimits() != nil { + limitSet = util.GetTaskResourcesAndCoalesce(ctx, taskResources.GetLimits(), requestSet) + } else { + limitSet = requestSet + } + adjustedRequestSet := util.ConstrainTaskResourceSet(ctx, requestSet, platformResources.Limits) + adjustedLimitSet := util.ConstrainTaskResourceSet(ctx, limitSet, platformResources.Limits) + + // convert the sets back to core.Resources + requestEntries := util.ConvertTaskResourceSetToCoreResources(adjustedRequestSet) + limitEntries := util.ConvertTaskResourceSetToCoreResources(adjustedLimitSet) + if len(requestEntries) == 0 && len(limitEntries) == 0 { + return nil + } + res := core.Resources{} + if len(requestEntries) > 0 { + res.Requests = requestEntries + } + if len(limitEntries) > 0 { + res.Limits = limitEntries + } + + return &res +} + // Assumes input contains a compiled task with a valid container resource execConfig. -// -// Note: The system will assign a system-default value for request but for limit it will deduce it from the request -// itself => Limit := Min([Some-Multiplier X Request], System-Max). For now we are using a multiplier of 1. In -// general we recommend the users to set limits close to requests for more predictability in the system. func (m *ExecutionManager) setCompiledTaskDefaults(ctx context.Context, task *core.CompiledTask, platformTaskResources workflowengineInterfaces.TaskResources) { @@ -204,99 +239,7 @@ func (m *ExecutionManager) setCompiledTaskDefaults(ctx context.Context, task *co return } - if task.Template.GetContainer().Resources == nil { - // In case of no resources on the container, create empty requests and limits - // so the container will still have resources configure properly - task.Template.GetContainer().Resources = &core.Resources{ - Requests: []*core.Resources_ResourceEntry{}, - Limits: []*core.Resources_ResourceEntry{}, - } - } - - var finalizedResourceRequests = make([]*core.Resources_ResourceEntry, 0) - var finalizedResourceLimits = make([]*core.Resources_ResourceEntry, 0) - - // The IDL representation for container-type tasks represents resources as a list with string quantities. - // In order to easily reason about them we convert them to a set where we can O(1) fetch specific resources (e.g. CPU) - // and represent them as comparable quantities rather than strings. - taskResourceRequirements := util.GetCompleteTaskResourceRequirements(ctx, task.Template.Id, task) - - cpu := flytek8s.AdjustOrDefaultResource(taskResourceRequirements.Defaults.CPU, taskResourceRequirements.Limits.CPU, - platformTaskResources.Defaults.CPU, platformTaskResources.Limits.CPU) - finalizedResourceRequests = append(finalizedResourceRequests, &core.Resources_ResourceEntry{ - Name: core.Resources_CPU, - Value: cpu.Request.String(), - }) - finalizedResourceLimits = append(finalizedResourceLimits, &core.Resources_ResourceEntry{ - Name: core.Resources_CPU, - Value: cpu.Limit.String(), - }) - - memory := flytek8s.AdjustOrDefaultResource(taskResourceRequirements.Defaults.Memory, taskResourceRequirements.Limits.Memory, - platformTaskResources.Defaults.Memory, platformTaskResources.Limits.Memory) - finalizedResourceRequests = append(finalizedResourceRequests, &core.Resources_ResourceEntry{ - Name: core.Resources_MEMORY, - Value: memory.Request.String(), - }) - finalizedResourceLimits = append(finalizedResourceLimits, &core.Resources_ResourceEntry{ - Name: core.Resources_MEMORY, - Value: memory.Limit.String(), - }) - - // Only assign ephemeral storage when it is either requested or limited in the task definition, or a platform - // default exists. - if !taskResourceRequirements.Defaults.EphemeralStorage.IsZero() || - !taskResourceRequirements.Limits.EphemeralStorage.IsZero() || - !platformTaskResources.Defaults.EphemeralStorage.IsZero() { - ephemeralStorage := flytek8s.AdjustOrDefaultResource(taskResourceRequirements.Defaults.EphemeralStorage, taskResourceRequirements.Limits.EphemeralStorage, - platformTaskResources.Defaults.EphemeralStorage, platformTaskResources.Limits.EphemeralStorage) - finalizedResourceRequests = append(finalizedResourceRequests, &core.Resources_ResourceEntry{ - Name: core.Resources_EPHEMERAL_STORAGE, - Value: ephemeralStorage.Request.String(), - }) - finalizedResourceLimits = append(finalizedResourceLimits, &core.Resources_ResourceEntry{ - Name: core.Resources_EPHEMERAL_STORAGE, - Value: ephemeralStorage.Limit.String(), - }) - } - - // Only assign storage when it is either requested or limited in the task definition, or a platform - // default exists. - if !taskResourceRequirements.Defaults.Storage.IsZero() || - !taskResourceRequirements.Limits.Storage.IsZero() || - !platformTaskResources.Defaults.Storage.IsZero() { - storageResource := flytek8s.AdjustOrDefaultResource(taskResourceRequirements.Defaults.Storage, taskResourceRequirements.Limits.Storage, - platformTaskResources.Defaults.Storage, platformTaskResources.Limits.Storage) - finalizedResourceRequests = append(finalizedResourceRequests, &core.Resources_ResourceEntry{ - Name: core.Resources_STORAGE, - Value: storageResource.Request.String(), - }) - finalizedResourceLimits = append(finalizedResourceLimits, &core.Resources_ResourceEntry{ - Name: core.Resources_STORAGE, - Value: storageResource.Limit.String(), - }) - } - - // Only assign gpu when it is either requested or limited in the task definition, or a platform default exists. - if !taskResourceRequirements.Defaults.GPU.IsZero() || - !taskResourceRequirements.Limits.GPU.IsZero() || - !platformTaskResources.Defaults.GPU.IsZero() { - gpu := flytek8s.AdjustOrDefaultResource(taskResourceRequirements.Defaults.GPU, taskResourceRequirements.Limits.GPU, - platformTaskResources.Defaults.GPU, platformTaskResources.Limits.GPU) - finalizedResourceRequests = append(finalizedResourceRequests, &core.Resources_ResourceEntry{ - Name: core.Resources_GPU, - Value: gpu.Request.String(), - }) - finalizedResourceLimits = append(finalizedResourceLimits, &core.Resources_ResourceEntry{ - Name: core.Resources_GPU, - Value: gpu.Limit.String(), - }) - } - - task.Template.GetContainer().Resources = &core.Resources{ - Requests: finalizedResourceRequests, - Limits: finalizedResourceLimits, - } + task.Template.GetContainer().Resources = m.getResources(ctx, task.Template.GetContainer().Resources, platformTaskResources) } // Fetches inherited execution metadata including the parent node execution db model id and the source execution model id @@ -1623,7 +1566,7 @@ func NewExecutionManager(db repositoryInterfaces.Repository, pluginRegistry *plu "size in bytes of serialized execution outputs"), } - resourceManager := resources.NewResourceManager(db, config.ApplicationConfiguration()) + resourceManager := resources.NewResourceManager(db, config) return &ExecutionManager{ db: db, config: config, diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index e8d16d348..7ff41936a 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -9,7 +9,6 @@ import ( "github.com/flyteorg/flyteadmin/plugins" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/benbjohnson/clock" @@ -87,12 +86,12 @@ var testCluster = "C1" var outputURI = "output uri" var resourceDefaults = runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - Memory: resource.MustParse("200Gi"), + CPU: testutils.GetPtr(resource.MustParse("200m")), + Memory: testutils.GetPtr(resource.MustParse("200Gi")), } var resourceLimits = runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - Memory: resource.MustParse("500Gi"), + CPU: testutils.GetPtr(resource.MustParse("300m")), + Memory: testutils.GetPtr(resource.MustParse("500Gi")), } func getLegacySpec() *admin.ExecutionSpec { @@ -3754,16 +3753,16 @@ func TestSetDefaults(t *testing.T) { execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) execManager.(*ExecutionManager).setCompiledTaskDefaults(context.Background(), task, workflowengineInterfaces.TaskResources{ Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("4"), - Memory: resource.MustParse("200Gi"), - EphemeralStorage: resource.MustParse("500Mi"), + CPU: testutils.GetPtr(resource.MustParse("200m")), + GPU: testutils.GetPtr(resource.MustParse("4")), + Memory: testutils.GetPtr(resource.MustParse("200Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("500Mi")), }, Limits: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), + CPU: testutils.GetPtr(resource.MustParse("300m")), + GPU: testutils.GetPtr(resource.MustParse("8")), + Memory: testutils.GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("501Mi")), }, }) assert.True(t, proto.Equal( @@ -3839,15 +3838,15 @@ func TestSetDefaults_MissingRequests_ExistingRequestsPreserved(t *testing.T) { execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) execManager.(*ExecutionManager).setCompiledTaskDefaults(context.Background(), task, workflowengineInterfaces.TaskResources{ Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("4"), - Memory: resource.MustParse("200Gi"), + CPU: testutils.GetPtr(resource.MustParse("200m")), + GPU: testutils.GetPtr(resource.MustParse("4")), + Memory: testutils.GetPtr(resource.MustParse("200Gi")), }, Limits: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), + CPU: testutils.GetPtr(resource.MustParse("300m")), + GPU: testutils.GetPtr(resource.MustParse("8")), // Because only the limit is set, this resource should not be injected. - EphemeralStorage: resource.MustParse("100"), + EphemeralStorage: testutils.GetPtr(resource.MustParse("100")), }, }) assert.True(t, proto.Equal( @@ -3888,10 +3887,10 @@ func TestSetDefaults_MissingRequests_ExistingRequestsPreserved(t *testing.T) { func TestSetDefaults_OptionalRequiredResources(t *testing.T) { taskConfigLimits := runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("1"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), + CPU: testutils.GetPtr(resource.MustParse("300m")), + GPU: testutils.GetPtr(resource.MustParse("1")), + Memory: testutils.GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("501Mi")), } task := &core.CompiledTask{ @@ -3917,8 +3916,8 @@ func TestSetDefaults_OptionalRequiredResources(t *testing.T) { execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) execManager.(*ExecutionManager).setCompiledTaskDefaults(context.Background(), task, workflowengineInterfaces.TaskResources{ Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - Memory: resource.MustParse("200Gi"), + CPU: testutils.GetPtr(resource.MustParse("200m")), + Memory: testutils.GetPtr(resource.MustParse("200Gi")), }, Limits: taskConfigLimits, }) @@ -3957,9 +3956,9 @@ func TestSetDefaults_OptionalRequiredResources(t *testing.T) { execManager.(*ExecutionManager).setCompiledTaskDefaults(context.Background(), task, workflowengineInterfaces.TaskResources{ Limits: taskConfigLimits, Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - Memory: resource.MustParse("200Gi"), - EphemeralStorage: resource.MustParse("1"), + CPU: testutils.GetPtr(resource.MustParse("200m")), + Memory: testutils.GetPtr(resource.MustParse("200Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("1")), }, }) assert.True(t, proto.Equal( @@ -3998,7 +3997,202 @@ func TestSetDefaults_OptionalRequiredResources(t *testing.T) { task.Template.GetContainer()), fmt.Sprintf("%+v", task.Template.GetContainer())) }) + t.Run("respect non-required ddresources when defaults exist in config", func(t *testing.T) { + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor) + execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + + zeroedLimitTask := &core.CompiledTask{ + Template: &core.TaskTemplate{ + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + }, + Limits: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "0", + }, + { + Name: core.Resources_MEMORY, + Value: "0", + }, + }, + }, + }, + }, + Id: &taskIdentifier, + }, + } + + execManager.(*ExecutionManager).setCompiledTaskDefaults(context.Background(), zeroedLimitTask, workflowengineInterfaces.TaskResources{ + Limits: taskConfigLimits, + Defaults: runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("200m")), // should be ignored + Memory: testutils.GetPtr(resource.MustParse("200Gi")), // should get merged in + EphemeralStorage: testutils.GetPtr(resource.MustParse("100Mi")), // should get merged in + }, + }) + assert.True(t, proto.Equal( + &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "200Gi", + }, + { + Name: core.Resources_EPHEMERAL_STORAGE, + Value: "100Mi", + }, + }, + Limits: []*core.Resources_ResourceEntry{ + // zeros requested by the task are preserved + { + Name: core.Resources_CPU, + Value: "0", + }, + { + Name: core.Resources_MEMORY, + Value: "0", + }, + { + // task limit should be set to task request, which was 100Mi, the 501Mi from the taskConfigLimits + // should have just been used to gate the request. + Name: core.Resources_EPHEMERAL_STORAGE, + Value: "100Mi", + }, + }, + }, + }, + zeroedLimitTask.Template.GetContainer()), fmt.Sprintf("Received value: %+v", zeroedLimitTask.Template.GetContainer())) + + // Test the same task again but with different system limits + loweredEStoragePlatformLimits := runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("300m")), + GPU: testutils.GetPtr(resource.MustParse("1")), + Memory: nil, + EphemeralStorage: testutils.GetPtr(resource.MustParse("99Mi")), + } + execManager.(*ExecutionManager).setCompiledTaskDefaults(context.Background(), zeroedLimitTask, workflowengineInterfaces.TaskResources{ + Limits: loweredEStoragePlatformLimits, + Defaults: runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("200m")), // should be ignored + Memory: testutils.GetPtr(resource.MustParse("200Gi")), // should get merged in + EphemeralStorage: testutils.GetPtr(resource.MustParse("100Mi")), // should get merged in + }, + }) + assert.True(t, proto.Equal( + &core.Container{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "200Gi", + }, + { + Name: core.Resources_EPHEMERAL_STORAGE, + Value: "99Mi", + }, + }, + Limits: []*core.Resources_ResourceEntry{ + // zeros requested by the task are preserved + { + Name: core.Resources_CPU, + Value: "0", + }, + { + Name: core.Resources_MEMORY, + Value: "0", + }, + { + // task limit should be set to task request, which was 100Mi, the 501Mi from the taskConfigLimits + // should have just been used to gate the request. + Name: core.Resources_EPHEMERAL_STORAGE, + Value: "99Mi", + }, + }, + }, + }, + zeroedLimitTask.Template.GetContainer()), fmt.Sprintf("Received value: %+v", zeroedLimitTask.Template.GetContainer())) + }) +} + +func TestGetResourcesDirectly(t *testing.T) { + taskConfigLimits := runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("300m")), + GPU: testutils.GetPtr(resource.MustParse("1")), + Memory: testutils.GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("501Mi")), + } + + t.Run("zero handling 1", func(t *testing.T) { + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &defaultTestExecutor) + execManager := NewExecutionManager(repositoryMocks.NewMockRepository(), r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + + taskResources := &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "0", + }, + }, + } + platformResources := workflowengineInterfaces.TaskResources{ + Defaults: runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("200m")), + Memory: testutils.GetPtr(resource.MustParse("200Gi")), + }, + Limits: taskConfigLimits, + } + + result := execManager.(*ExecutionManager).getResources(context.Background(), taskResources, platformResources) + + assert.True(t, proto.Equal( + &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "0", + }, + }, + Limits: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "200m", + }, + { + Name: core.Resources_MEMORY, + Value: "0", + }, + }, + }, + result), fmt.Sprintf("Actual: %+v", result)) + }) } + func TestCreateSingleTaskExecution(t *testing.T) { repository := getMockRepositoryForExecTest() var getCalledCount = 0 diff --git a/pkg/manager/impl/executions/queues.go b/pkg/manager/impl/executions/queues.go index b97254970..999a9fc52 100644 --- a/pkg/manager/impl/executions/queues.go +++ b/pkg/manager/impl/executions/queues.go @@ -114,7 +114,7 @@ func NewQueueAllocator(config runtimeInterfaces.Configuration, db repoInterfaces queueAllocator := queueAllocatorImpl{ config: config, db: db, - resourceManager: resources.NewResourceManager(db, config.ApplicationConfiguration()), + resourceManager: resources.NewResourceManager(db, config), } return &queueAllocator } diff --git a/pkg/manager/impl/resources/resource_manager.go b/pkg/manager/impl/resources/resource_manager.go index d65658991..3bc1e2fb1 100644 --- a/pkg/manager/impl/resources/resource_manager.go +++ b/pkg/manager/impl/resources/resource_manager.go @@ -26,7 +26,7 @@ import ( type ResourceManager struct { db repo_interface.Repository - config runtimeInterfaces.ApplicationConfiguration + config runtimeInterfaces.Configuration } func (m *ResourceManager) GetResource(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) { @@ -57,6 +57,40 @@ func (m *ResourceManager) GetResource(ctx context.Context, request interfaces.Re }, nil } +func (m *ResourceManager) GetResourcesList(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponseList, error) { + resources, err := m.db.ResourceRepo().GetRows(ctx, repo_interface.ResourceID{ + ResourceType: request.ResourceType.String(), + Project: request.Project, + Domain: request.Domain, + Workflow: request.Workflow, + LaunchPlan: request.LaunchPlan, + }) + if err != nil { + return nil, err + } + logger.Debugf(ctx, "Retrieved %d rows listing resource type %s", len(resources), request.ResourceType.String()) + + var attributes = make([]*admin.MatchingAttributes, 0, len(resources)) + for _, resource := range resources { + var attr admin.MatchingAttributes + err = proto.Unmarshal(resource.Attributes, &attr) + if err != nil { + return nil, errors.NewFlyteAdminErrorf( + codes.Internal, "Failed to decode resource attribute with err: %v", err) + } + attributes = append(attributes, &attr) + } + + return &interfaces.ResourceResponseList{ + ResourceType: request.ResourceType.String(), + Project: request.Project, + Domain: request.Domain, + Workflow: request.Workflow, + LaunchPlan: request.LaunchPlan, + AttributeList: attributes, + }, nil +} + func (m *ResourceManager) createOrMergeUpdateWorkflowAttributes( ctx context.Context, request admin.WorkflowAttributesUpdateRequest, model models.Resource, resourceType admin.MatchableResource) (*admin.WorkflowAttributesUpdateResponse, error) { @@ -97,7 +131,7 @@ func (m *ResourceManager) UpdateWorkflowAttributes( *admin.WorkflowAttributesUpdateResponse, error) { var resource admin.MatchableResource var err error - if resource, err = validation.ValidateWorkflowAttributesUpdateRequest(ctx, m.db, m.config, request); err != nil { + if resource, err = validation.ValidateWorkflowAttributesUpdateRequest(ctx, m.db, m.config.ApplicationConfiguration(), request); err != nil { return nil, err } @@ -119,7 +153,26 @@ func (m *ResourceManager) UpdateWorkflowAttributes( func (m *ResourceManager) GetWorkflowAttributes( ctx context.Context, request admin.WorkflowAttributesGetRequest) ( *admin.WorkflowAttributesGetResponse, error) { - if err := validation.ValidateWorkflowAttributesGetRequest(ctx, m.db, m.config, request); err != nil { + + // if the request is a task resource request, then call that logic designed to merge task resources from + // different levels along with base config + if request.ResourceType == admin.MatchableResource_TASK_RESOURCE { + r := repo_interface.ResourceID{Project: request.Project, Domain: request.Domain, Workflow: request.Workflow, ResourceType: request.ResourceType.String()} + matchingAttributes, err := m.HandleGetTaskResourceRequest(ctx, r) + if err != nil { + return nil, err + } + return &admin.WorkflowAttributesGetResponse{ + Attributes: &admin.WorkflowAttributes{ + Project: request.Project, + Domain: request.Domain, + Workflow: request.Workflow, + MatchingAttributes: matchingAttributes, + }, + }, nil + } + + if err := validation.ValidateWorkflowAttributesGetRequest(ctx, m.db, m.config.ApplicationConfiguration(), request); err != nil { return nil, err } workflowAttributesModel, err := m.db.ResourceRepo().Get( @@ -138,7 +191,7 @@ func (m *ResourceManager) GetWorkflowAttributes( func (m *ResourceManager) DeleteWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesDeleteRequest) (*admin.WorkflowAttributesDeleteResponse, error) { - if err := validation.ValidateWorkflowAttributesDeleteRequest(ctx, m.db, m.config, request); err != nil { + if err := validation.ValidateWorkflowAttributesDeleteRequest(ctx, m.db, m.config.ApplicationConfiguration(), request); err != nil { return nil, err } if err := m.db.ResourceRepo().Delete( @@ -202,41 +255,98 @@ func (m *ResourceManager) GetProjectAttributesBase(ctx context.Context, request }, nil } +// HandleGetTaskResourceRequest needs to merge results from multiple layers in the db, along with configuration value +func (m *ResourceManager) HandleGetTaskResourceRequest(ctx context.Context, request repo_interface.ResourceID) (*admin.MatchingAttributes, error) { + if err := validation.ValidateProjectExists(ctx, m.db, request.Project); err != nil { + return nil, err + } + + var attrs []admin.TaskResourceAttributes + attrs = []admin.TaskResourceAttributes{} + + rrList, err := m.GetResourcesList(ctx, interfaces.ResourceRequest{ + Project: request.Project, + Domain: request.Domain, + Workflow: request.Workflow, + LaunchPlan: "", + ResourceType: admin.MatchableResource_TASK_RESOURCE, + }) + + if err != nil { + ec, ok := err.(errors.FlyteAdminError) + if ok && ec.Code() == codes.NotFound { + logger.Debug(ctx, "HandleGetTaskResourceRequest did not find any task resources, falling back") + } else { + return nil, err + } + } else { + logger.Debugf(ctx, "HandleGetTaskResourceRequest returned [%d] task resources, combining with config", len(rrList.AttributeList)) + for _, rr := range rrList.AttributeList { + if rr.GetTaskResourceAttributes() != nil { + attrs = append(attrs, *rr.GetTaskResourceAttributes()) + } + } + } + + attrs = append(attrs, m.config.TaskResourceConfiguration().GetAsAttribute()) + responseAttributes := util.MergeDownTaskResources(attrs...) + + return &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: responseAttributes, + }, + }, nil +} + // GetProjectAttributes combines the call to the database to get the Project level settings with // Admin server level configuration. // Note this merge is only done for WorkflowExecutionConfig +// This merge should be done for the following matchable-resource types: +// TASK_RESOURCE, WORKFLOW_EXECUTION_CONFIG // This code should be removed pending implementation of a complete settings implementation. func (m *ResourceManager) GetProjectAttributes(ctx context.Context, request admin.ProjectAttributesGetRequest) ( *admin.ProjectAttributesGetResponse, error) { + // if the request is a task resource request, then call that logic designed to merge task resources from + // different levels along with base config + if request.ResourceType == admin.MatchableResource_TASK_RESOURCE { + r := repo_interface.ResourceID{Project: request.Project, Domain: "", ResourceType: request.ResourceType.String()} + matchingAttributes, err := m.HandleGetTaskResourceRequest(ctx, r) + if err != nil { + return nil, err + } + return &admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: request.Project, + MatchingAttributes: matchingAttributes, + }, + }, nil + } + getResponse, err := m.GetProjectAttributesBase(ctx, request) - configLevelDefaults := m.config.GetTopLevelConfig().GetAsWorkflowExecutionConfig() + + // Return as missing if missing and not one of the two matchable resources that are merged with system level config if err != nil { ec, ok := err.(errors.FlyteAdminError) - if ok && ec.Code() == codes.NotFound && request.ResourceType == admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG { - // TODO: Will likely be removed after overarching settings project is done - return &admin.ProjectAttributesGetResponse{ - Attributes: &admin.ProjectAttributes{ - Project: request.Project, - MatchingAttributes: &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ - WorkflowExecutionConfig: &configLevelDefaults, - }, - }, - }, - }, nil + if ok && ec.Code() == codes.NotFound && (request.ResourceType == admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG || request.ResourceType == admin.MatchableResource_TASK_RESOURCE) { + logger.Debugf(ctx, "Attributes not found, but look for system fallback %s", request.ResourceType.String()) + } else { + return nil, err } - return nil, err - } - // If found, then merge result with the default values for the platform - // TODO: Remove this logic once the overarching settings project is done. Those endpoints should take - // default configuration into account. - responseAttributes := getResponse.Attributes.GetMatchingAttributes().GetWorkflowExecutionConfig() - if responseAttributes != nil { - logger.Warningf(ctx, "Merging response %s with defaults %s", responseAttributes, configLevelDefaults) - tmp := util.MergeIntoExecConfig(*responseAttributes, &configLevelDefaults) - responseAttributes = &tmp + + // Merge with system level config if appropriate + if request.ResourceType == admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG { + var responseAttributes *admin.WorkflowExecutionConfig + configLevelDefaults := m.config.ApplicationConfiguration().GetAsWorkflowExecutionAttribute() + if getResponse == nil || getResponse.Attributes == nil || getResponse.Attributes.GetMatchingAttributes() == nil || getResponse.Attributes.GetMatchingAttributes().GetWorkflowExecutionConfig() == nil { + responseAttributes = &configLevelDefaults + } else { + logger.Debugf(ctx, "Merging workflow config %s with defaults %s", responseAttributes, configLevelDefaults) + responseAttributes = getResponse.Attributes.GetMatchingAttributes().GetWorkflowExecutionConfig() + tmp := util.MergeIntoExecConfig(*responseAttributes, &configLevelDefaults) + responseAttributes = &tmp + } return &admin.ProjectAttributesGetResponse{ Attributes: &admin.ProjectAttributes{ Project: request.Project, @@ -247,6 +357,27 @@ func (m *ResourceManager) GetProjectAttributes(ctx context.Context, request admi }, }, }, nil + } else if request.ResourceType == admin.MatchableResource_TASK_RESOURCE { + // todo: delete this, handled above + var responseAttributes *admin.TaskResourceAttributes + configLevelDefaults := m.config.TaskResourceConfiguration().GetAsAttribute() + if getResponse == nil || getResponse.Attributes == nil || getResponse.Attributes.GetMatchingAttributes() == nil || getResponse.Attributes.GetMatchingAttributes().GetTaskResourceAttributes() == nil { + responseAttributes = &configLevelDefaults + } else { + logger.Debugf(ctx, "Merging taskresources %v with system config %v", responseAttributes, configLevelDefaults) + responseAttributes = getResponse.Attributes.GetMatchingAttributes().GetTaskResourceAttributes() + responseAttributes = util.MergeDownTaskResources(*responseAttributes, configLevelDefaults) + } + return &admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: request.Project, + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: responseAttributes, + }, + }, + }, + }, nil } return getResponse, nil @@ -289,6 +420,7 @@ func (m *ResourceManager) createOrMergeUpdateProjectDomainAttributes( } return nil, err } + // TODO: does this belong here? feels like the error should be better handled and not returned updatedModel, err := transformers.MergeUpdatePluginAttributes( ctx, existing, resourceType, &resourceID, request.Attributes.MatchingAttributes) if err != nil { @@ -342,7 +474,7 @@ func (m *ResourceManager) UpdateProjectDomainAttributes( *admin.ProjectDomainAttributesUpdateResponse, error) { var resource admin.MatchableResource var err error - if resource, err = validation.ValidateProjectDomainAttributesUpdateRequest(ctx, m.db, m.config, request); err != nil { + if resource, err = validation.ValidateProjectDomainAttributesUpdateRequest(ctx, m.db, m.config.ApplicationConfiguration(), request); err != nil { return nil, err } ctx = contextutils.WithProjectDomain(ctx, request.Attributes.Project, request.Attributes.Domain) @@ -364,7 +496,25 @@ func (m *ResourceManager) UpdateProjectDomainAttributes( func (m *ResourceManager) GetProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( *admin.ProjectDomainAttributesGetResponse, error) { - if err := validation.ValidateProjectDomainAttributesGetRequest(ctx, m.db, m.config, request); err != nil { + + // if the request is a task resource request, then call that logic designed to merge task resources from + // different levels along with base config + if request.ResourceType == admin.MatchableResource_TASK_RESOURCE { + r := repo_interface.ResourceID{Project: request.Project, Domain: request.Domain, ResourceType: request.ResourceType.String()} + matchingAttributes, err := m.HandleGetTaskResourceRequest(ctx, r) + if err != nil { + return nil, err + } + return &admin.ProjectDomainAttributesGetResponse{ + Attributes: &admin.ProjectDomainAttributes{ + Project: request.Project, + Domain: request.Domain, + MatchingAttributes: matchingAttributes, + }, + }, nil + } + + if err := validation.ValidateProjectDomainAttributesGetRequest(ctx, m.db, m.config.ApplicationConfiguration(), request); err != nil { return nil, err } projectAttributesModel, err := m.db.ResourceRepo().Get( @@ -383,7 +533,7 @@ func (m *ResourceManager) GetProjectDomainAttributes( func (m *ResourceManager) DeleteProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesDeleteRequest) (*admin.ProjectDomainAttributesDeleteResponse, error) { - if err := validation.ValidateProjectDomainAttributesDeleteRequest(ctx, m.db, m.config, request); err != nil { + if err := validation.ValidateProjectDomainAttributesDeleteRequest(ctx, m.db, m.config.ApplicationConfiguration(), request); err != nil { return nil, err } if err := m.db.ResourceRepo().Delete( @@ -417,7 +567,7 @@ func (m *ResourceManager) ListAll(ctx context.Context, request admin.ListMatchab }, nil } -func NewResourceManager(db repo_interface.Repository, config runtimeInterfaces.ApplicationConfiguration) interfaces.ResourceInterface { +func NewResourceManager(db repo_interface.Repository, config runtimeInterfaces.Configuration) interfaces.ResourceInterface { return &ResourceManager{ db: db, config: config, diff --git a/pkg/manager/impl/resources/resource_manager_test.go b/pkg/manager/impl/resources/resource_manager_test.go index 8c587937c..ed6b617af 100644 --- a/pkg/manager/impl/resources/resource_manager_test.go +++ b/pkg/manager/impl/resources/resource_manager_test.go @@ -5,6 +5,7 @@ import ( runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "k8s.io/apimachinery/pkg/api/resource" // pkg/runtime/interfaces/application_configuration.go "testing" @@ -54,7 +55,7 @@ func TestUpdateWorkflowAttributes(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateWorkflowAttributes(context.Background(), request) assert.Nil(t, err) assert.True(t, createOrUpdateCalled) @@ -95,7 +96,7 @@ func TestUpdateWorkflowAttributes_CreateOrMerge(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateWorkflowAttributes(context.Background(), request) assert.NoError(t, err) assert.True(t, createOrUpdateCalled) @@ -144,7 +145,7 @@ func TestUpdateWorkflowAttributes_CreateOrMerge(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateWorkflowAttributes(context.Background(), request) assert.NoError(t, err) assert.True(t, createOrUpdateCalled) @@ -174,7 +175,7 @@ func TestGetWorkflowAttributes(t *testing.T) { Attributes: expectedSerializedAttrs, }, nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) response, err := manager.GetWorkflowAttributes(context.Background(), request) assert.Nil(t, err) assert.True(t, proto.Equal(&admin.WorkflowAttributesGetResponse{ @@ -203,7 +204,7 @@ func TestDeleteWorkflowAttributes(t *testing.T) { assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), ID.ResourceType) return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.DeleteWorkflowAttributes(context.Background(), request) assert.Nil(t, err) } @@ -229,7 +230,7 @@ func TestUpdateProjectDomainAttributes(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) assert.Nil(t, err) assert.True(t, createOrUpdateCalled) @@ -268,7 +269,7 @@ func TestUpdateProjectDomainAttributes_CreateOrMerge(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) assert.NoError(t, err) assert.True(t, createOrUpdateCalled) @@ -315,7 +316,7 @@ func TestUpdateProjectDomainAttributes_CreateOrMerge(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateProjectDomainAttributes(context.Background(), request) assert.NoError(t, err) assert.True(t, createOrUpdateCalled) @@ -343,7 +344,7 @@ func TestGetProjectDomainAttributes(t *testing.T) { Attributes: expectedSerializedAttrs, }, nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) response, err := manager.GetProjectDomainAttributes(context.Background(), request) assert.Nil(t, err) assert.True(t, proto.Equal(&admin.ProjectDomainAttributesGetResponse{ @@ -369,7 +370,7 @@ func TestDeleteProjectDomainAttributes(t *testing.T) { assert.Equal(t, admin.MatchableResource_EXECUTION_QUEUE.String(), ID.ResourceType) return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.DeleteProjectDomainAttributes(context.Background(), request) assert.Nil(t, err) } @@ -394,7 +395,7 @@ func TestUpdateProjectAttributes(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateProjectAttributes(context.Background(), request) assert.Nil(t, err) assert.True(t, createOrUpdateCalled) @@ -451,7 +452,7 @@ func TestUpdateProjectAttributes_CreateOrMerge(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateProjectAttributes(context.Background(), request) assert.NoError(t, err) assert.True(t, createOrUpdateCalled) @@ -497,21 +498,21 @@ func TestUpdateProjectAttributes_CreateOrMerge(t *testing.T) { createOrUpdateCalled = true return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.UpdateProjectAttributes(context.Background(), request) assert.NoError(t, err) assert.True(t, createOrUpdateCalled) }) } -func TestGetProjectAttributes(t *testing.T) { +func TestGetProjectAttributesWec(t *testing.T) { request := admin.ProjectAttributesGetRequest{ Project: project, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, } db := mocks.NewMockRepository() - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( ctx context.Context, ID repoInterfaces.ResourceID) (models.Resource, error) { @@ -557,8 +558,10 @@ func TestGetProjectAttributes_ConfigLookup(t *testing.T) { // return not found to trigger loading from config return models.Resource{}, errors.NewFlyteAdminError(codes.NotFound, "not found message") } - config := runtimeMocks.MockApplicationProvider{} - manager := NewResourceManager(db, &config) + appConfiguration := runtimeMocks.MockApplicationProvider{} + config := runtimeMocks.NewMockConfigurationProvider(&appConfiguration, nil, nil, nil, nil, nil).(*runtimeMocks.MockConfigurationProvider) + + manager := NewResourceManager(db, config) t.Run("config 1", func(t *testing.T) { appConfig := runtimeInterfaces.ApplicationConfig{ @@ -567,7 +570,7 @@ func TestGetProjectAttributes_ConfigLookup(t *testing.T) { Labels: map[string]string{"lab1": "name"}, OutputLocationPrefix: "s3://test-bucket", } - config.SetTopLevelConfig(appConfig) + appConfiguration.SetTopLevelConfig(appConfig) response, err := manager.GetProjectAttributes(context.Background(), request) assert.Nil(t, err) @@ -599,7 +602,7 @@ func TestGetProjectAttributes_ConfigLookup(t *testing.T) { MaxParallelism: 3, AssumableIamRole: "myrole", } - config.SetTopLevelConfig(appConfig) + appConfiguration.SetTopLevelConfig(appConfig) response, err := manager.GetProjectAttributes(context.Background(), request) assert.Nil(t, err) @@ -625,7 +628,7 @@ func TestGetProjectAttributes_ConfigLookup(t *testing.T) { MaxParallelism: 3, Annotations: map[string]string{"ann1": "val1"}, } - config.SetTopLevelConfig(appConfig) + appConfiguration.SetTopLevelConfig(appConfig) response, err := manager.GetProjectAttributes(context.Background(), request) assert.Nil(t, err) @@ -653,7 +656,7 @@ func TestGetProjectAttributes_ConfigLookup(t *testing.T) { Labels: map[string]string{"lab1": "name"}, OutputLocationPrefix: "s3://test-bucket", } - config.SetTopLevelConfig(appConfig) + appConfiguration.SetTopLevelConfig(appConfig) request := admin.ProjectAttributesGetRequest{ Project: project, ResourceType: admin.MatchableResource_EXECUTION_QUEUE, @@ -667,6 +670,457 @@ func TestGetProjectAttributes_ConfigLookup(t *testing.T) { }) } +func TestGetProjectAttributesTaskResource(t *testing.T) { + request := admin.ProjectAttributesGetRequest{ + Project: project, + ResourceType: admin.MatchableResource_TASK_RESOURCE, + } + db := mocks.NewMockRepository() + + manager := NewResourceManager(db, testutils.GetMockConfiguration()) + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + + assert.Equal(t, project, ID.Project) + assert.Equal(t, "", ID.Domain) + assert.Equal(t, "", ID.Workflow) + assert.Equal(t, admin.MatchableResource_TASK_RESOURCE.String(), ID.ResourceType) + expectedSerializedAttrs, _ := proto.Marshal(testutils.TaskResourcesSample) + return []models.Resource{ + { + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }, + }, nil + } + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: testutils.TaskResourcesSample, + }, + }, response)) + + // unrecognized errors are thrown + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + + return []models.Resource{}, errors.NewFlyteAdminErrorf(5323, "random code") + } + _, err = manager.GetProjectAttributes(context.Background(), request) + assert.Error(t, err) +} + +func TestGetProjectAttributesMTaskResource(t *testing.T) { + request := admin.ProjectAttributesGetRequest{ + Project: project, + ResourceType: admin.MatchableResource_TASK_RESOURCE, + } + db := mocks.NewMockRepository() + + manager := NewResourceManager(db, testutils.GetMockConfiguration()) + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + + assert.Equal(t, project, ID.Project) + assert.Equal(t, "", ID.Domain) + assert.Equal(t, "", ID.Workflow) + assert.Equal(t, admin.MatchableResource_TASK_RESOURCE.String(), ID.ResourceType) + expectedSerializedAttrs, _ := proto.Marshal(testutils.TaskResourcesSample) + return []models.Resource{ + { + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }, + }, nil + } + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: testutils.TaskResourcesSample, + }, + }, response)) +} + +func TestGetProjectAttributes_MergeTaskResourceConfigWithMatchableResources(t *testing.T) { + request := admin.ProjectAttributesGetRequest{ + Project: project, + ResourceType: admin.MatchableResource_TASK_RESOURCE, + } + db := mocks.NewMockRepository() + config := runtimeMocks.NewMockConfigurationProvider(nil, nil, nil, nil, nil, nil).(*runtimeMocks.MockConfigurationProvider) + + manager := NewResourceManager(db, config) + + t.Run("config no replacement, all from db", func(t *testing.T) { + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Gpu: "2", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "3", + Memory: "200Mi", + EphemeralStorage: "300Gi", + }, + }, + }, + } + expectedSerializedAttrs, _ := proto.Marshal(matchingAttributes) + return []models.Resource{ + { + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }, + }, nil + } + + taskConfiguration := runtimeMocks.NewMockTaskResourceConfiguration( + runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("200Mi")), + GPU: testutils.GetPtr(resource.MustParse("0")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("100Gi")), + }, + runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("1Gi")), + GPU: testutils.GetPtr(resource.MustParse("1")), + }, + ) + config.SetTaskResourceConfiguration(taskConfiguration) + + response, err := manager.GetProjectAttributes(context.Background(), request) + + assert.Nil(t, err) + tra := response.GetAttributes().GetMatchingAttributes().GetTaskResourceAttributes() + assert.True(t, proto.Equal(&admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Gpu: "2", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "3", + Memory: "200Mi", + EphemeralStorage: "300Gi", + }, + }, tra)) + }) + + t.Run("config merge some from config", func(t *testing.T) { + // returned by the database as matchable resource + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + EphemeralStorage: "300Gi", + }, + }, + }, + } + expectedSerializedAttrs, _ := proto.Marshal(matchingAttributes) + return []models.Resource{ + { + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }}, nil + } + + taskConfiguration := runtimeMocks.NewMockTaskResourceConfiguration( + runtimeInterfaces.TaskResourceSet{ + Memory: testutils.GetPtr(resource.MustParse("200Mi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("100Gi")), + }, + runtimeInterfaces.TaskResourceSet{ + Memory: testutils.GetPtr(resource.MustParse("1Gi")), + }, + ) + config.SetTaskResourceConfiguration(taskConfiguration) + + response, err := manager.GetProjectAttributes(context.Background(), request) + + assert.Nil(t, err) + tra := response.GetAttributes().GetMatchingAttributes().GetTaskResourceAttributes() + assert.True(t, proto.Equal(&admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Memory: "1Gi", + EphemeralStorage: "300Gi", + }, + }, tra)) + }) + + t.Run("config merge all limits from config", func(t *testing.T) { + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{}, + }, + }, + } + expectedSerializedAttrs, _ := proto.Marshal(matchingAttributes) + return []models.Resource{ + { + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }, + }, nil + } + + taskConfiguration := runtimeMocks.NewMockTaskResourceConfiguration( + runtimeInterfaces.TaskResourceSet{ + Memory: testutils.GetPtr(resource.MustParse("200Mi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("101Gi")), + }, + runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("1Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("153Gi")), + }, + ) + config.SetTaskResourceConfiguration(taskConfiguration) + + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + tra := response.GetAttributes().GetMatchingAttributes().GetTaskResourceAttributes() + assert.True(t, proto.Equal(&admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Memory: "1Gi", + EphemeralStorage: "153Gi", + }, + }, tra)) + }) + + t.Run("base config limits limit matchable resources", func(t *testing.T) { + // returned by the database as matchable resource + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "3", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "4", + }, + }, + }, + } + expectedSerializedAttrs, _ := proto.Marshal(matchingAttributes) + return []models.Resource{ + { + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }}, nil + } + + // This is the fake base system level configuration. + taskConfiguration := runtimeMocks.NewMockTaskResourceConfiguration( + runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("200Mi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("5Gi")), + }, + runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("1Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("10Gi")), + }, + ) + config.SetTaskResourceConfiguration(taskConfiguration) + + response, err := manager.GetProjectAttributes(context.Background(), request) + + assert.Nil(t, err) + tra := response.GetAttributes().GetMatchingAttributes().GetTaskResourceAttributes() + assert.True(t, proto.Equal(&admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "3", // unch, because the matchable resource overrides the base config. + Memory: "100Mi", // unch, lower than system limit of 1Gi + EphemeralStorage: "10Gi", // cut from 100 to 10, because we inherit the system limit of 10Gi + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "4", // unch, base config overridden by the matchable resource + Memory: "1Gi", // inherited from the system limit + EphemeralStorage: "10Gi", // inherited from the system limit + }, + }, tra)) + }) + + t.Run("config override even works if 0", func(t *testing.T) { + // returned by the database as matchable resource + db.ResourceRepo().(*mocks.MockResourceRepo).GetRowsFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) ([]models.Resource, error) { + matchingAttributes := &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "2", + Memory: "2Gi", + Gpu: "0", + }, + Limits: &admin.TaskResourceSpec{ + Memory: "0", + }, + }, + }, + } + expectedSerializedAttrs, _ := proto.Marshal(matchingAttributes) + return []models.Resource{ + { + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }}, nil + } + + // This is the fake base system level configuration. + taskConfiguration := runtimeMocks.NewMockTaskResourceConfiguration( + runtimeInterfaces.TaskResourceSet{ + Memory: testutils.GetPtr(resource.MustParse("200Mi")), + GPU: testutils.GetPtr(resource.MustParse("1")), + }, + runtimeInterfaces.TaskResourceSet{ + Memory: testutils.GetPtr(resource.MustParse("1Gi")), + }, + ) + config.SetTaskResourceConfiguration(taskConfiguration) + + response, err := manager.GetProjectAttributes(context.Background(), request) + + assert.Nil(t, err) + tra := response.GetAttributes().GetMatchingAttributes().GetTaskResourceAttributes() + assert.True(t, proto.Equal(&admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "0", + Memory: "2Gi", + }, + Limits: &admin.TaskResourceSpec{ + Memory: "0", + }, + }, tra)) + }) +} + +func TestGetProjectAttributes_TaskResourceConfigLookup(t *testing.T) { + request := admin.ProjectAttributesGetRequest{ + Project: project, + ResourceType: admin.MatchableResource_TASK_RESOURCE, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) (models.Resource, error) { + // return not found to trigger loading from config + return models.Resource{}, errors.NewFlyteAdminError(codes.NotFound, "not found message") + } + config := runtimeMocks.NewMockConfigurationProvider(nil, nil, nil, nil, nil, nil).(*runtimeMocks.MockConfigurationProvider) + + manager := NewResourceManager(db, config) + + // what is the treatment of 0s that are explicitly set + // what about missing things + t.Run("config 1", func(t *testing.T) { + taskConfiguration := runtimeMocks.NewMockTaskResourceConfiguration( + runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("200Mi")), + GPU: testutils.GetPtr(resource.MustParse("0")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("100Gi")), + }, + runtimeInterfaces.TaskResourceSet{ + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("1Gi")), + GPU: testutils.GetPtr(resource.MustParse("1")), + }, + ) + config.SetTaskResourceConfiguration(taskConfiguration) + + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "0", + Memory: "200Mi", + Storage: "", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "1", + Memory: "1Gi", + Storage: "", + EphemeralStorage: "", + }, + }, + }, + }, + }, + }, response)) + }) +} + func TestDeleteProjectAttributes(t *testing.T) { request := admin.ProjectAttributesDeleteRequest{ Project: project, @@ -680,7 +1134,7 @@ func TestDeleteProjectAttributes(t *testing.T) { assert.Equal(t, admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG.String(), ID.ResourceType) return nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) _, err := manager.DeleteProjectAttributes(context.Background(), request) assert.Nil(t, err) } @@ -711,7 +1165,7 @@ func TestGetResource(t *testing.T) { Attributes: expectedSerializedAttrs, }, nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) response, err := manager.GetResource(context.Background(), request) assert.Nil(t, err) assert.Equal(t, request.Project, response.Project) @@ -762,7 +1216,7 @@ func TestListAllResources(t *testing.T) { }, }, nil } - manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + manager := NewResourceManager(db, testutils.GetMockConfiguration()) response, err := manager.ListAll(context.Background(), admin.ListMatchableAttributesRequest{ ResourceType: admin.MatchableResource_CLUSTER_RESOURCE, }) diff --git a/pkg/manager/impl/task_manager.go b/pkg/manager/impl/task_manager.go index b4346fcd9..eae111bca 100644 --- a/pkg/manager/impl/task_manager.go +++ b/pkg/manager/impl/task_manager.go @@ -272,7 +272,7 @@ func NewTaskManager( ClosureSizeBytes: scope.MustNewSummary("closure_size_bytes", "size in bytes of serialized task closure"), Registered: labeled.NewCounter("num_registered", "count of registered tasks", scope), } - resourceManager := resources.NewResourceManager(db, config.ApplicationConfiguration()) + resourceManager := resources.NewResourceManager(db, config) return &TaskManager{ db: db, config: config, diff --git a/pkg/manager/impl/testutils/attributes.go b/pkg/manager/impl/testutils/attributes.go index 92d276d3e..50d407d95 100644 --- a/pkg/manager/impl/testutils/attributes.go +++ b/pkg/manager/impl/testutils/attributes.go @@ -26,3 +26,22 @@ var WorkflowExecutionConfigSample = &admin.MatchingAttributes{ }, }, } + +var TaskResourcesSample = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Gpu: "2", + Memory: "100Mi", + EphemeralStorage: "100Gi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "3", + Memory: "200Mi", + EphemeralStorage: "300Gi", + }, + }, + }, +} diff --git a/pkg/manager/impl/testutils/config.go b/pkg/manager/impl/testutils/config.go index c2fe20139..799f4f1c8 100644 --- a/pkg/manager/impl/testutils/config.go +++ b/pkg/manager/impl/testutils/config.go @@ -4,6 +4,7 @@ import ( "github.com/flyteorg/flyteadmin/pkg/common" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" + "k8s.io/apimachinery/pkg/api/resource" ) func GetApplicationConfigWithDefaultDomains() runtimeInterfaces.ApplicationConfiguration { @@ -32,3 +33,28 @@ func GetApplicationConfigWithDefaultDomains() runtimeInterfaces.ApplicationConfi }}) return &config } + +func GetPtr(quantity resource.Quantity) *resource.Quantity { + return &quantity +} + +func GetSampleTaskResourceConfiguration() runtimeInterfaces.TaskResourceConfiguration { + resourceDefaults := runtimeInterfaces.TaskResourceSet{ + CPU: GetPtr(resource.MustParse("200m")), + Memory: GetPtr(resource.MustParse("200Gi")), + GPU: GetPtr(resource.MustParse("0")), + } + resourceLimits := runtimeInterfaces.TaskResourceSet{ + CPU: GetPtr(resource.MustParse("300m")), + Memory: GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: GetPtr(resource.MustParse("10Gi")), + } + + return runtimeMocks.NewMockTaskResourceConfiguration(resourceDefaults, resourceLimits) +} + +func GetMockConfiguration() runtimeInterfaces.Configuration { + appConfig := GetApplicationConfigWithDefaultDomains() + taskResourceConfig := GetSampleTaskResourceConfiguration() + return runtimeMocks.NewMockConfigurationProvider(appConfig, nil, nil, taskResourceConfig, nil, nil) +} diff --git a/pkg/manager/impl/util/resources.go b/pkg/manager/impl/util/resources.go index f09695723..faaa7455e 100644 --- a/pkg/manager/impl/util/resources.go +++ b/pkg/manager/impl/util/resources.go @@ -2,7 +2,9 @@ package util import ( "context" - "fmt" + + "github.com/flyteorg/flyteadmin/pkg/errors" + "google.golang.org/grpc/codes" "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" @@ -10,78 +12,146 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/logger" + "github.com/golang/protobuf/proto" "k8s.io/apimachinery/pkg/api/resource" ) // parseQuantityNoError parses the k8s defined resource quantity gracefully masking errors. -func parseQuantityNoError(ctx context.Context, ownerID, name, value string) resource.Quantity { +func parseQuantityNoError(ctx context.Context, value string) *resource.Quantity { + if value == "" { + return nil + } q, err := resource.ParseQuantity(value) if err != nil { - logger.Infof(ctx, "Failed to parse owner's [%s] resource [%s]'s value [%s] with err: %v", ownerID, name, value, err) + logger.Infof(ctx, "Failed to value [%s] with err: %v", value, err) + return nil } - return q + return &q } -// getTaskResourcesAsSet converts a list of flyteidl `ResourceEntry` messages into a singular `TaskResourceSet`. -func getTaskResourcesAsSet(ctx context.Context, identifier *core.Identifier, - resourceEntries []*core.Resources_ResourceEntry, resourceName string) runtimeInterfaces.TaskResourceSet { +func ConvertTaskResourceSetToCoreResources(resources runtimeInterfaces.TaskResourceSet) []*core.Resources_ResourceEntry { + var resourceEntries = make([]*core.Resources_ResourceEntry, 0) + + if resources.CPU != nil { + resourceEntries = append(resourceEntries, &core.Resources_ResourceEntry{ + Name: core.Resources_CPU, + Value: resources.CPU.String(), + }) + } + + if resources.Memory != nil { + resourceEntries = append(resourceEntries, &core.Resources_ResourceEntry{ + Name: core.Resources_MEMORY, + Value: resources.Memory.String(), + }) + } + + if resources.EphemeralStorage != nil { + resourceEntries = append(resourceEntries, &core.Resources_ResourceEntry{ + Name: core.Resources_EPHEMERAL_STORAGE, + Value: resources.EphemeralStorage.String(), + }) + } + + if resources.GPU != nil { + resourceEntries = append(resourceEntries, &core.Resources_ResourceEntry{ + Name: core.Resources_GPU, + Value: resources.GPU.String(), + }) + } + return resourceEntries +} + +// GetTaskResourcesAndCoalesce takes a set of Flyte IDL ResourceEntry's and fills in missing fields from the coalesce +func GetTaskResourcesAndCoalesce(ctx context.Context, + resourceEntries []*core.Resources_ResourceEntry, coalesce runtimeInterfaces.TaskResourceSet) runtimeInterfaces.TaskResourceSet { result := runtimeInterfaces.TaskResourceSet{} for _, entry := range resourceEntries { + q := parseQuantityNoError(ctx, entry.Value) switch entry.Name { case core.Resources_CPU: - result.CPU = parseQuantityNoError(ctx, identifier.String(), fmt.Sprintf("%v.cpu", resourceName), entry.Value) + result.CPU = q case core.Resources_MEMORY: - result.Memory = parseQuantityNoError(ctx, identifier.String(), fmt.Sprintf("%v.memory", resourceName), entry.Value) + result.Memory = q case core.Resources_EPHEMERAL_STORAGE: - result.EphemeralStorage = parseQuantityNoError(ctx, identifier.String(), - fmt.Sprintf("%v.ephemeral storage", resourceName), entry.Value) + result.EphemeralStorage = q + case core.Resources_STORAGE: + result.Storage = q case core.Resources_GPU: - result.GPU = parseQuantityNoError(ctx, identifier.String(), "gpu", entry.Value) + result.GPU = q + default: + logger.Warnf(ctx, "Unknown resource type [%s]", entry.Name) } } + if result.CPU == nil && coalesce.CPU != nil { + result.CPU = coalesce.CPU + } + if result.Memory == nil && coalesce.Memory != nil { + result.Memory = coalesce.Memory + } + if result.EphemeralStorage == nil && coalesce.EphemeralStorage != nil { + result.EphemeralStorage = coalesce.EphemeralStorage + } + if result.GPU == nil && coalesce.GPU != nil { + result.GPU = coalesce.GPU + } return result } -// GetCompleteTaskResourceRequirements parses the resource requests and limits from the `TaskTemplate` Container. -func GetCompleteTaskResourceRequirements(ctx context.Context, identifier *core.Identifier, task *core.CompiledTask) workflowengineInterfaces.TaskResources { - return workflowengineInterfaces.TaskResources{ - Defaults: getTaskResourcesAsSet(ctx, identifier, task.GetTemplate().GetContainer().Resources.Requests, "requests"), - Limits: getTaskResourcesAsSet(ctx, identifier, task.GetTemplate().GetContainer().Resources.Limits, "limits"), +// getTaskResourcesAsSet converts a list of flyteidl `ResourceEntry` messages into a singular `TaskResourceSet`. +func getTaskResourcesAsSet(ctx context.Context, + resourceEntries []*core.Resources_ResourceEntry) runtimeInterfaces.TaskResourceSet { + + result := runtimeInterfaces.TaskResourceSet{} + for _, entry := range resourceEntries { + switch entry.Name { + case core.Resources_CPU: + result.CPU = parseQuantityNoError(ctx, entry.Value) + case core.Resources_MEMORY: + result.Memory = parseQuantityNoError(ctx, entry.Value) + case core.Resources_EPHEMERAL_STORAGE: + result.EphemeralStorage = parseQuantityNoError(ctx, entry.Value) + case core.Resources_GPU: + result.GPU = parseQuantityNoError(ctx, entry.Value) + } } + + return result } // fromAdminProtoTaskResourceSpec parses the flyteidl `TaskResourceSpec` message into a `TaskResourceSet`. func fromAdminProtoTaskResourceSpec(ctx context.Context, spec *admin.TaskResourceSpec) runtimeInterfaces.TaskResourceSet { result := runtimeInterfaces.TaskResourceSet{} if len(spec.Cpu) > 0 { - result.CPU = parseQuantityNoError(ctx, "project", "cpu", spec.Cpu) + result.CPU = parseQuantityNoError(ctx, spec.Cpu) } if len(spec.Memory) > 0 { - result.Memory = parseQuantityNoError(ctx, "project", "memory", spec.Memory) + result.Memory = parseQuantityNoError(ctx, spec.Memory) } if len(spec.Storage) > 0 { - result.Storage = parseQuantityNoError(ctx, "project", "storage", spec.Storage) + result.Storage = parseQuantityNoError(ctx, spec.Storage) } if len(spec.EphemeralStorage) > 0 { - result.EphemeralStorage = parseQuantityNoError(ctx, "project", "ephemeral storage", spec.EphemeralStorage) + result.EphemeralStorage = parseQuantityNoError(ctx, spec.EphemeralStorage) } if len(spec.Gpu) > 0 { - result.GPU = parseQuantityNoError(ctx, "project", "gpu", spec.Gpu) + result.GPU = parseQuantityNoError(ctx, spec.Gpu) } return result } -// GetTaskResources returns the most specific default and limit task resources for the specified id. This first checks -// if there is a matchable resource(s) defined, and uses the highest priority one, otherwise it falls back to using the -// flyteadmin default configured values. +// GetTaskResources returns a merged set of all the requests, and limits for the given request. +// This will combine all layers matched and merge missing resource types. That is, if CPU is set at the project level +// and memory is set at the project/domain/workflow level, this will return both. +// Admin default system wide configuration is also merged in. func GetTaskResources(ctx context.Context, id *core.Identifier, resourceManager interfaces.ResourceInterface, taskResourceConfig runtimeInterfaces.TaskResourceConfiguration) workflowengineInterfaces.TaskResources { @@ -98,23 +168,172 @@ func GetTaskResources(ctx context.Context, id *core.Identifier, resourceManager request.Workflow = id.Name } - resource, err := resourceManager.GetResource(ctx, request) + var attrs = make([]admin.TaskResourceAttributes, 0) + + // Get list of all task resources. + resrc, err := resourceManager.GetResourcesList(ctx, request) if err != nil { - logger.Warningf(ctx, "Failed to fetch override values when assigning task resource default values for [%+v]: %v", - id, err) + ec, ok := err.(errors.FlyteAdminError) + if ok && ec.Code() == codes.NotFound { + logger.Debug(ctx, "HandleGetTaskResourceRequest did not find any task resources, falling back") + } else { + logger.Warningf(ctx, "Failed to fetch override values when assigning task resource default values for [%+v]: %v", + id, err) + } + } else if resrc != nil { + logger.Debugf(ctx, "GetTaskResources returned [%d] task resources, combining with config", len(resrc.AttributeList)) + for _, rr := range resrc.AttributeList { + if rr.GetTaskResourceAttributes() != nil { + attrs = append(attrs, *rr.GetTaskResourceAttributes()) + } + } } - logger.Debugf(ctx, "Assigning task requested resources for [%+v]", id) + attrs = append(attrs, taskResourceConfig.GetAsAttribute()) + responseAttributes := MergeDownTaskResources(attrs...) + var taskResourceAttributes = workflowengineInterfaces.TaskResources{} - if resource != nil && resource.Attributes != nil && resource.Attributes.GetTaskResourceAttributes() != nil { - taskResourceAttributes.Defaults = fromAdminProtoTaskResourceSpec(ctx, resource.Attributes.GetTaskResourceAttributes().Defaults) - taskResourceAttributes.Limits = fromAdminProtoTaskResourceSpec(ctx, resource.Attributes.GetTaskResourceAttributes().Limits) - } else { - taskResourceAttributes = workflowengineInterfaces.TaskResources{ - Defaults: taskResourceConfig.GetDefaults(), - Limits: taskResourceConfig.GetLimits(), - } + + if responseAttributes.GetLimits() != nil { + taskResourceAttributes.Limits = fromAdminProtoTaskResourceSpec(ctx, responseAttributes.GetLimits()) + } + if responseAttributes.GetDefaults() != nil { + taskResourceAttributes.Defaults = fromAdminProtoTaskResourceSpec(ctx, responseAttributes.GetDefaults()) } return taskResourceAttributes } + +// MergeTaskResourceSpec merges two TaskResourceSpecs. No notion of quantity comparison, just whether or not fields +// are empty strings. +func MergeTaskResourceSpec(high, low *admin.TaskResourceSpec) *admin.TaskResourceSpec { + if high == nil && low == nil { + return nil + } else if high == nil && low != nil { + // Return nil if all fields are empty strings. This is just done for this case, preserving the behavior that if + // the higher priority thing is nil, an empty lower priority object shouldn't make it non-nil. + if low.Cpu == "" && low.Gpu == "" && low.Memory == "" && low.EphemeralStorage == "" { + return nil + } + res := proto.Clone(low).(*admin.TaskResourceSpec) + return res + } else if high != nil && low == nil { + res := proto.Clone(high).(*admin.TaskResourceSpec) + return res + } + + res := proto.Clone(high).(*admin.TaskResourceSpec) + if res.GetCpu() == "" && low.GetCpu() != "" { + res.Cpu = low.Cpu + } + if res.GetGpu() == "" && low.GetGpu() != "" { + res.Gpu = low.Gpu + } + if res.GetMemory() == "" && low.GetMemory() != "" { + res.Memory = low.Memory + } + if res.GetEphemeralStorage() == "" && low.GetEphemeralStorage() != "" { + res.EphemeralStorage = low.EphemeralStorage + } + return res +} + +// MergeTaskResourceAttributes will merge without error, taking non-empty strings from lower priority +// and filling in missing higher priority fields. +func MergeTaskResourceAttributes(high, low admin.TaskResourceAttributes) admin.TaskResourceAttributes { + res := proto.Clone(&high).(*admin.TaskResourceAttributes) + res.Defaults = MergeTaskResourceSpec(high.GetDefaults(), low.GetDefaults()) + res.Limits = MergeTaskResourceSpec(high.GetLimits(), low.GetLimits()) + return *res +} + +// ConstrainTaskResourceSpec takes two TaskResourceSpecs and returns a new one, limiting the first argument, to the +// values of the second arg (maxes), for each resource type, if it exists and is non-zero. A zero is taken to mean +// no limit. This function parses the strings into resource.Quantity objects, and compares them using k8s Cmp. +func ConstrainTaskResourceSpec(spec admin.TaskResourceSpec, maxes admin.TaskResourceSpec) admin.TaskResourceSpec { + res := proto.Clone(&spec).(*admin.TaskResourceSpec) + if maxes.GetCpu() != "" && spec.GetCpu() != "" { + maxCPU := resource.MustParse(maxes.GetCpu()) + specCPU := resource.MustParse(spec.GetCpu()) + if specCPU.Cmp(maxCPU) == 1 && !maxCPU.IsZero() { + res.Cpu = maxes.GetCpu() + } + } + if maxes.GetGpu() != "" && spec.GetGpu() != "" { + maxGpu := resource.MustParse(maxes.GetGpu()) + specGpu := resource.MustParse(spec.GetGpu()) + if specGpu.Cmp(maxGpu) == 1 && !maxGpu.IsZero() { + res.Gpu = maxes.GetGpu() + } + } + if maxes.GetMemory() != "" && spec.GetMemory() != "" { + maxMemory := resource.MustParse(maxes.GetMemory()) + specMemory := resource.MustParse(spec.GetMemory()) + if specMemory.Cmp(maxMemory) == 1 && !maxMemory.IsZero() { + res.Memory = maxes.GetMemory() + } + } + if maxes.GetEphemeralStorage() != "" && spec.GetEphemeralStorage() != "" { + maxEphemeralStorage := resource.MustParse(maxes.GetEphemeralStorage()) + specEphemeralStorage := resource.MustParse(spec.GetEphemeralStorage()) + if specEphemeralStorage.Cmp(maxEphemeralStorage) == 1 && !maxEphemeralStorage.IsZero() { + res.EphemeralStorage = maxes.GetEphemeralStorage() + } + } + return *res +} + +func quantityToString(q *resource.Quantity) string { + if q == nil { + return "" + } + return q.String() +} + +// ConstrainTaskResourceSet is the same as ConstrainTaskResourceSpec, but for TaskResourceSet. +// Converts to TaskResourceSpec, and then calls the limiting function for that, and convert back. +func ConstrainTaskResourceSet(ctx context.Context, spec runtimeInterfaces.TaskResourceSet, maxes runtimeInterfaces.TaskResourceSet) runtimeInterfaces.TaskResourceSet { + + specAsResourceSpec := admin.TaskResourceSpec{ + Cpu: quantityToString(spec.CPU), + Gpu: quantityToString(spec.GPU), + Memory: quantityToString(spec.Memory), + EphemeralStorage: quantityToString(spec.EphemeralStorage), + } + maxesAsResourceSpec := admin.TaskResourceSpec{ + Cpu: quantityToString(maxes.CPU), + Gpu: quantityToString(maxes.GPU), + Memory: quantityToString(maxes.Memory), + EphemeralStorage: quantityToString(maxes.EphemeralStorage), + } + + r := ConstrainTaskResourceSpec(specAsResourceSpec, maxesAsResourceSpec) + resourceSet := fromAdminProtoTaskResourceSpec(ctx, &r) + return resourceSet +} + +func ConformLimits(attr admin.TaskResourceAttributes) admin.TaskResourceAttributes { + maxes := admin.TaskResourceSpec{} + if attr.GetLimits() != nil { + maxes = *attr.GetLimits() + } + if attr.GetDefaults() != nil { + x := ConstrainTaskResourceSpec(*attr.GetDefaults(), maxes) + attr.Defaults = &x + } + return attr +} + +// MergeDownTaskResources does not today check that the defaults are below the limits when setting, therefore +// go through the list from high to low priority, first merge the various types, and then resolve inconsistencies +// around quantities. +// - If set, must be limit >= default +func MergeDownTaskResources(highToLowPriorityTaskResourceAttributes ...admin.TaskResourceAttributes) *admin.TaskResourceAttributes { + // Merge each one down, checking each condition + merged := admin.TaskResourceAttributes{} + for _, attr := range highToLowPriorityTaskResourceAttributes { + merged = MergeTaskResourceAttributes(merged, attr) + } + merged = ConformLimits(merged) + return &merged +} diff --git a/pkg/manager/impl/util/resources_test.go b/pkg/manager/impl/util/resources_test.go index f4180c4b5..0d0465d4e 100644 --- a/pkg/manager/impl/util/resources_test.go +++ b/pkg/manager/impl/util/resources_test.go @@ -26,21 +26,24 @@ var workflowIdentifier = core.Identifier{ Version: "version", } +// GetPtr because golang +func GetPtr(quantity resource.Quantity) *resource.Quantity { + return &quantity +} + func TestGetTaskResources(t *testing.T) { taskConfig := runtimeMocks.MockTaskResourceConfiguration{} taskConfig.Defaults = runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("200Gi"), - EphemeralStorage: resource.MustParse("500Mi"), - Storage: resource.MustParse("400Mi"), + CPU: GetPtr(resource.MustParse("200m")), + GPU: GetPtr(resource.MustParse("8")), + Memory: GetPtr(resource.MustParse("200Gi")), + EphemeralStorage: GetPtr(resource.MustParse("500Mi")), } taskConfig.Limits = runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), - Storage: resource.MustParse("450Mi"), + CPU: GetPtr(resource.MustParse("300m")), + GPU: GetPtr(resource.MustParse("8")), + Memory: GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: GetPtr(resource.MustParse("501Mi")), } t.Run("use runtime application values", func(t *testing.T) { @@ -57,50 +60,50 @@ func TestGetTaskResources(t *testing.T) { } taskResourceAttrs := GetTaskResources(context.TODO(), &workflowIdentifier, &resourceManager, &taskConfig) + assert.EqualValues(t, taskResourceAttrs, workflowengineInterfaces.TaskResources{ Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("200Gi"), - EphemeralStorage: resource.MustParse("500Mi"), - Storage: resource.MustParse("400Mi"), + CPU: GetPtr(resource.MustParse("200m")), + GPU: GetPtr(resource.MustParse("8")), + Memory: GetPtr(resource.MustParse("200Gi")), + EphemeralStorage: GetPtr(resource.MustParse("500Mi")), }, Limits: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), - Storage: resource.MustParse("450Mi"), + CPU: GetPtr(resource.MustParse("300m")), + GPU: GetPtr(resource.MustParse("8")), + Memory: GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: GetPtr(resource.MustParse("501Mi")), }, }) }) + t.Run("use specific overrides", func(t *testing.T) { resourceManager := managerMocks.MockResourceManager{} - resourceManager.GetResourceFunc = func(ctx context.Context, - request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { + resourceManager.GetResourcesListFunc = func(ctx context.Context, + request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponseList, error) { assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, Workflow: workflowIdentifier.Name, ResourceType: admin.MatchableResource_TASK_RESOURCE, }) - return &managerInterfaces.ResourceResponse{ - Attributes: &admin.MatchingAttributes{ - Target: &admin.MatchingAttributes_TaskResourceAttributes{ - TaskResourceAttributes: &admin.TaskResourceAttributes{ - Defaults: &admin.TaskResourceSpec{ - Cpu: "1200m", - Gpu: "18", - Memory: "1200Gi", - EphemeralStorage: "1500Mi", - Storage: "1400Mi", - }, - Limits: &admin.TaskResourceSpec{ - Cpu: "300m", - Gpu: "8", - Memory: "500Gi", - EphemeralStorage: "501Mi", - Storage: "450Mi", + return &managerInterfaces.ResourceResponseList{ + AttributeList: []*admin.MatchingAttributes{ + { + Target: &admin.MatchingAttributes_TaskResourceAttributes{ + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1200m", + Gpu: "18", + Memory: "1200Gi", + EphemeralStorage: "1500Mi", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "300m", + Gpu: "8", + Memory: "500Gi", + EphemeralStorage: "501Mi", + }, }, }, }, @@ -110,18 +113,16 @@ func TestGetTaskResources(t *testing.T) { taskResourceAttrs := GetTaskResources(context.TODO(), &workflowIdentifier, &resourceManager, &taskConfig) assert.EqualValues(t, taskResourceAttrs, workflowengineInterfaces.TaskResources{ Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("1200m"), - GPU: resource.MustParse("18"), - Memory: resource.MustParse("1200Gi"), - EphemeralStorage: resource.MustParse("1500Mi"), - Storage: resource.MustParse("1400Mi"), + CPU: GetPtr(resource.MustParse("300m")), + GPU: GetPtr(resource.MustParse("8")), + Memory: GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: GetPtr(resource.MustParse("501Mi")), }, Limits: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("300m"), - GPU: resource.MustParse("8"), - Memory: resource.MustParse("500Gi"), - EphemeralStorage: resource.MustParse("501Mi"), - Storage: resource.MustParse("450Mi"), + CPU: GetPtr(resource.MustParse("300m")), + GPU: GetPtr(resource.MustParse("8")), + Memory: GetPtr(resource.MustParse("500Gi")), + EphemeralStorage: GetPtr(resource.MustParse("501Mi")), }, }) }) @@ -131,21 +132,19 @@ func TestFromAdminProtoTaskResourceSpec(t *testing.T) { taskResourceSet := fromAdminProtoTaskResourceSpec(context.TODO(), &admin.TaskResourceSpec{ Cpu: "1", Memory: "100", - Storage: "200", EphemeralStorage: "300", Gpu: "2", }) assert.EqualValues(t, runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("1"), - Memory: resource.MustParse("100"), - Storage: resource.MustParse("200"), - EphemeralStorage: resource.MustParse("300"), - GPU: resource.MustParse("2"), + CPU: GetPtr(resource.MustParse("1")), + Memory: GetPtr(resource.MustParse("100")), + EphemeralStorage: GetPtr(resource.MustParse("300")), + GPU: GetPtr(resource.MustParse("2")), }, taskResourceSet) } func TestGetTaskResourcesAsSet(t *testing.T) { - taskResources := getTaskResourcesAsSet(context.TODO(), &core.Identifier{}, []*core.Resources_ResourceEntry{ + taskResources := getTaskResourcesAsSet(context.TODO(), []*core.Resources_ResourceEntry{ { Name: core.Resources_CPU, Value: "100", @@ -162,68 +161,11 @@ func TestGetTaskResourcesAsSet(t *testing.T) { Name: core.Resources_GPU, Value: "400", }, - }, "request") + }) assert.True(t, taskResources.CPU.Equal(resource.MustParse("100"))) assert.True(t, taskResources.Memory.Equal(resource.MustParse("200"))) assert.True(t, taskResources.EphemeralStorage.Equal(resource.MustParse("300"))) assert.True(t, taskResources.GPU.Equal(resource.MustParse("400"))) } -func TestGetCompleteTaskResourceRequirements(t *testing.T) { - taskResources := GetCompleteTaskResourceRequirements(context.TODO(), &core.Identifier{}, &core.CompiledTask{ - Template: &core.TaskTemplate{ - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Resources: &core.Resources{ - Requests: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "100", - }, - { - Name: core.Resources_MEMORY, - Value: "200", - }, - { - Name: core.Resources_EPHEMERAL_STORAGE, - Value: "300", - }, - { - Name: core.Resources_GPU, - Value: "400", - }, - }, - Limits: []*core.Resources_ResourceEntry{ - { - Name: core.Resources_CPU, - Value: "200", - }, - { - Name: core.Resources_MEMORY, - Value: "400", - }, - { - Name: core.Resources_EPHEMERAL_STORAGE, - Value: "600", - }, - { - Name: core.Resources_GPU, - Value: "800", - }, - }, - }, - }, - }, - }, - }) - - assert.True(t, taskResources.Defaults.CPU.Equal(resource.MustParse("100"))) - assert.True(t, taskResources.Defaults.Memory.Equal(resource.MustParse("200"))) - assert.True(t, taskResources.Defaults.EphemeralStorage.Equal(resource.MustParse("300"))) - assert.True(t, taskResources.Defaults.GPU.Equal(resource.MustParse("400"))) - - assert.True(t, taskResources.Limits.CPU.Equal(resource.MustParse("200"))) - assert.True(t, taskResources.Limits.Memory.Equal(resource.MustParse("400"))) - assert.True(t, taskResources.Limits.EphemeralStorage.Equal(resource.MustParse("600"))) - assert.True(t, taskResources.Limits.GPU.Equal(resource.MustParse("800"))) -} +// TODO add test with default limits diff --git a/pkg/manager/impl/util/shared.go b/pkg/manager/impl/util/shared.go index bf9490473..feb6e0931 100644 --- a/pkg/manager/impl/util/shared.go +++ b/pkg/manager/impl/util/shared.go @@ -288,6 +288,7 @@ func GetMatchableResource(ctx context.Context, resourceManager interfaces.Resour // new object with the merged changes. // After settings project is done, can move this function back to execution manager. Currently shared with resource. func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec shared.WorkflowExecutionConfigInterface) admin.WorkflowExecutionConfig { + if workflowExecConfig.GetMaxParallelism() == 0 && spec.GetMaxParallelism() > 0 { workflowExecConfig.MaxParallelism = spec.GetMaxParallelism() } diff --git a/pkg/manager/impl/validation/task_validator.go b/pkg/manager/impl/validation/task_validator.go index c8625ec4b..3c2573abe 100644 --- a/pkg/manager/impl/validation/task_validator.go +++ b/pkg/manager/impl/validation/task_validator.go @@ -125,18 +125,19 @@ func ValidateTask( func taskResourceSetToMap( resourceSet runtimeInterfaces.TaskResourceSet) map[core.Resources_ResourceName]*resource.Quantity { + resourceMap := make(map[core.Resources_ResourceName]*resource.Quantity) - if !resourceSet.CPU.IsZero() { - resourceMap[core.Resources_CPU] = &resourceSet.CPU + if resourceSet.CPU != nil && !resourceSet.CPU.IsZero() { + resourceMap[core.Resources_CPU] = resourceSet.CPU } - if !resourceSet.Memory.IsZero() { - resourceMap[core.Resources_MEMORY] = &resourceSet.Memory + if resourceSet.Memory != nil && !resourceSet.Memory.IsZero() { + resourceMap[core.Resources_MEMORY] = resourceSet.Memory } - if !resourceSet.GPU.IsZero() { - resourceMap[core.Resources_GPU] = &resourceSet.GPU + if resourceSet.GPU != nil && !resourceSet.GPU.IsZero() { + resourceMap[core.Resources_GPU] = resourceSet.GPU } - if !resourceSet.EphemeralStorage.IsZero() { - resourceMap[core.Resources_EPHEMERAL_STORAGE] = &resourceSet.EphemeralStorage + if resourceSet.EphemeralStorage != nil && !resourceSet.EphemeralStorage.IsZero() { + resourceMap[core.Resources_EPHEMERAL_STORAGE] = resourceSet.EphemeralStorage } return resourceMap } diff --git a/pkg/manager/impl/validation/task_validator_test.go b/pkg/manager/impl/validation/task_validator_test.go index 78ec0309c..df60a80cd 100644 --- a/pkg/manager/impl/validation/task_validator_test.go +++ b/pkg/manager/impl/validation/task_validator_test.go @@ -23,9 +23,9 @@ import ( func getMockTaskResources() workflowengineInterfaces.TaskResources { return workflowengineInterfaces.TaskResources{ Limits: runtimeInterfaces.TaskResourceSet{ - Memory: resource.MustParse("500Mi"), - CPU: resource.MustParse("200m"), - GPU: resource.MustParse("8"), + Memory: testutils.GetPtr(resource.MustParse("500Mi")), + CPU: testutils.GetPtr(resource.MustParse("200m")), + GPU: testutils.GetPtr(resource.MustParse("8")), }, } } @@ -227,10 +227,10 @@ func TestValidateTaskTypeWhitelist(t *testing.T) { func TestTaskResourceSetToMap(t *testing.T) { resourceSet := runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("100Mi"), - GPU: resource.MustParse("2"), - Memory: resource.MustParse("1.5Gi"), - EphemeralStorage: resource.MustParse("500Mi"), + CPU: testutils.GetPtr(resource.MustParse("100Mi")), + GPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("1.5Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("500Mi")), } resourceSetMap := taskResourceSetToMap(resourceSet) assert.Len(t, resourceSetMap, 4) @@ -394,7 +394,7 @@ func TestValidateTaskResources_LimitGreaterThanConfig(t *testing.T) { err := validateTaskResources(&core.Identifier{ Name: "name", }, runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("1Gi"), + CPU: testutils.GetPtr(resource.MustParse("1Gi")), }, []*core.Resources_ResourceEntry{ { @@ -414,7 +414,7 @@ func TestValidateTaskResources_DefaultGreaterThanConfig(t *testing.T) { err := validateTaskResources(&core.Identifier{ Name: "name", }, runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("1Gi"), + CPU: testutils.GetPtr(resource.MustParse("1Gi")), }, []*core.Resources_ResourceEntry{ { @@ -448,7 +448,7 @@ func TestValidateTaskResources_GPULimitGreaterThanConfig(t *testing.T) { err := validateTaskResources(&core.Identifier{ Name: "name", }, runtimeInterfaces.TaskResourceSet{ - GPU: resource.MustParse("1"), + GPU: testutils.GetPtr(resource.MustParse("1")), }, []*core.Resources_ResourceEntry{ { @@ -468,7 +468,7 @@ func TestValidateTaskResources_GPUDefaultGreaterThanConfig(t *testing.T) { err := validateTaskResources(&core.Identifier{ Name: "name", }, runtimeInterfaces.TaskResourceSet{ - GPU: resource.MustParse("1"), + GPU: testutils.GetPtr(resource.MustParse("1")), }, []*core.Resources_ResourceEntry{ { diff --git a/pkg/manager/interfaces/resource.go b/pkg/manager/interfaces/resource.go index 42d4a3c9b..8638762a4 100644 --- a/pkg/manager/interfaces/resource.go +++ b/pkg/manager/interfaces/resource.go @@ -11,6 +11,7 @@ type ResourceInterface interface { ListAll(ctx context.Context, request admin.ListMatchableAttributesRequest) ( *admin.ListMatchableAttributesResponse, error) GetResource(ctx context.Context, request ResourceRequest) (*ResourceResponse, error) + GetResourcesList(ctx context.Context, request ResourceRequest) (*ResourceResponseList, error) UpdateProjectAttributes(ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( *admin.ProjectAttributesUpdateResponse, error) @@ -51,3 +52,12 @@ type ResourceResponse struct { ResourceType string Attributes *admin.MatchingAttributes } + +type ResourceResponseList struct { + Project string + Domain string + Workflow string + LaunchPlan string + ResourceType string + AttributeList []*admin.MatchingAttributes +} diff --git a/pkg/manager/mocks/resource.go b/pkg/manager/mocks/resource.go index 1339e10d7..5e67e3430 100644 --- a/pkg/manager/mocks/resource.go +++ b/pkg/manager/mocks/resource.go @@ -24,6 +24,7 @@ type DeleteProjectDomainFunc func(ctx context.Context, request admin.ProjectDoma type ListResourceFunc func(ctx context.Context, request admin.ListMatchableAttributesRequest) ( *admin.ListMatchableAttributesResponse, error) type GetResourceFunc func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) +type GetResourcesListFunc func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponseList, error) type MockResourceManager struct { updateProjectDomainFunc UpdateProjectDomainFunc @@ -31,6 +32,7 @@ type MockResourceManager struct { DeleteFunc DeleteProjectDomainFunc ListFunc ListResourceFunc GetResourceFunc GetResourceFunc + GetResourcesListFunc GetResourcesListFunc updateProjectAttrsFunc UpdateProjectAttrsFunc getProjectAttrFunc GetProjectAttrFunc deleteProjectAttrFunc DeleteProjectAttrFunc @@ -43,6 +45,13 @@ func (m *MockResourceManager) GetResource(ctx context.Context, request interface return nil, nil } +func (m *MockResourceManager) GetResourcesList(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponseList, error) { + if m.GetResourcesListFunc != nil { + return m.GetResourcesListFunc(ctx, request) + } + return nil, nil +} + func (m *MockResourceManager) UpdateWorkflowAttributes(ctx context.Context, request admin.WorkflowAttributesUpdateRequest) ( *admin.WorkflowAttributesUpdateResponse, error) { panic("implement me") diff --git a/pkg/repositories/database.go b/pkg/repositories/database.go index fdc663ac6..b86dfae03 100644 --- a/pkg/repositories/database.go +++ b/pkg/repositories/database.go @@ -61,6 +61,7 @@ func GetDB(ctx context.Context, dbConfig *database.DbConfig, logConfig *logger.C if dbConfig == nil { panic("Cannot initialize database repository from empty db config") } + gormConfig := &gorm.Config{ Logger: database.GetGormLogger(ctx, logConfig), DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating, diff --git a/pkg/repositories/gormimpl/resource_repo.go b/pkg/repositories/gormimpl/resource_repo.go index 1f94abce5..9d22ed6e0 100644 --- a/pkg/repositories/gormimpl/resource_repo.go +++ b/pkg/repositories/gormimpl/resource_repo.go @@ -153,6 +153,51 @@ func (r *ResourceRepo) GetProjectLevel(ctx context.Context, ID interfaces.Resour return resources[0], nil } +// GetRows returns rows at the given specificity, and lower, in descending order of specificity. +// For example, if the resource ID has project, domain, and workflow, this will return the rows for an exact match, +// as well as, just project and just project and domain, in that order. +// Get returns the most-specific attribute setting for the given ResourceType. +func (r *ResourceRepo) GetRows(ctx context.Context, ID interfaces.ResourceID) ([]models.Resource, error) { + if ID.ResourceType == "" { + return []models.Resource{}, r.errorTransformer.ToFlyteAdminError(flyteAdminDbErrors.GetInvalidInputError(fmt.Sprintf("%v", ID))) + } + var resources []models.Resource + timer := r.metrics.GetDuration.Start() + + txWhereClause := "resource_type = ? AND domain IN (?) AND project IN (?) AND workflow IN (?) AND launch_plan IN (?)" + project := []string{""} + if ID.Project != "" { + project = append(project, ID.Project) + } + + domain := []string{""} + if ID.Domain != "" { + domain = append(domain, ID.Domain) + } + + workflow := []string{""} + if ID.Workflow != "" { + workflow = append(workflow, ID.Workflow) + } + + launchPlan := []string{""} + if ID.LaunchPlan != "" { + launchPlan = append(launchPlan, ID.LaunchPlan) + } + + tx := r.db.Where(txWhereClause, ID.ResourceType, domain, project, workflow, launchPlan) + tx.Order(priorityDescending).Find(&resources) + timer.Stop() + + if (tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound)) || len(resources) == 0 { + return []models.Resource{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, + "Resource [%+v] not found", ID) + } else if tx.Error != nil { + return []models.Resource{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return resources, nil +} + func (r *ResourceRepo) GetRaw(ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { if ID.Domain == "" || ID.ResourceType == "" { return models.Resource{}, r.errorTransformer.ToFlyteAdminError(flyteAdminDbErrors.GetInvalidInputError(fmt.Sprintf("%v", ID))) diff --git a/pkg/repositories/interfaces/resource_repo.go b/pkg/repositories/interfaces/resource_repo.go index 5ffbe0127..8e18dedff 100644 --- a/pkg/repositories/interfaces/resource_repo.go +++ b/pkg/repositories/interfaces/resource_repo.go @@ -7,18 +7,20 @@ import ( ) type ResourceRepoInterface interface { - // Inserts or updates an existing Type model into the database store. + // CreateOrUpdate inserts or updates an existing Type model into the database store. CreateOrUpdate(ctx context.Context, input models.Resource) error - // Returns a matching Type model based on hierarchical resolution. + // Get returns a matching Type model based on hierarchical resolution. Get(ctx context.Context, ID ResourceID) (models.Resource, error) - // Returns a matching Type model. + // GetRows returns a matching Type model based on hierarchical resolution. + GetRows(ctx context.Context, ID ResourceID) ([]models.Resource, error) + // GetRaw returns a matching Type model. GetRaw(ctx context.Context, ID ResourceID) (models.Resource, error) // GetProjectLevel returns the Project level resource entry, if any, even if there is a higher // specificity resource. GetProjectLevel(ctx context.Context, ID ResourceID) (models.Resource, error) - // Lists all resources + // ListAll resources ListAll(ctx context.Context, resourceType string) ([]models.Resource, error) - // Deletes a matching Type model when it exists. + // Delete a matching Type model when it exists. Delete(ctx context.Context, ID ResourceID) error } diff --git a/pkg/repositories/mocks/resource.go b/pkg/repositories/mocks/resource.go index 0def6351e..d31d84abe 100644 --- a/pkg/repositories/mocks/resource.go +++ b/pkg/repositories/mocks/resource.go @@ -10,12 +10,15 @@ import ( type CreateOrUpdateResourceFunction func(ctx context.Context, input models.Resource) error type GetResourceFunction func(ctx context.Context, ID interfaces.ResourceID) ( models.Resource, error) +type GetRowsFunction func(ctx context.Context, ID interfaces.ResourceID) ( + []models.Resource, error) type ListAllResourcesFunction func(ctx context.Context, resourceType string) ([]models.Resource, error) type DeleteResourceFunction func(ctx context.Context, ID interfaces.ResourceID) error type MockResourceRepo struct { CreateOrUpdateFunction CreateOrUpdateResourceFunction GetFunction GetResourceFunction + GetRowsFunction GetRowsFunction DeleteFunction DeleteResourceFunction ListAllFunction ListAllResourcesFunction } @@ -35,6 +38,15 @@ func (r *MockResourceRepo) Get(ctx context.Context, ID interfaces.ResourceID) ( return models.Resource{}, nil } +func (r *MockResourceRepo) GetRows(ctx context.Context, ID interfaces.ResourceID) ( + []models.Resource, error) { + + if r.GetRowsFunction != nil { + return r.GetRowsFunction(ctx, ID) + } + return []models.Resource{}, nil +} + func (r *MockResourceRepo) GetRaw(ctx context.Context, ID interfaces.ResourceID) ( models.Resource, error) { if r.GetFunction != nil { diff --git a/pkg/rpc/adminservice/base.go b/pkg/rpc/adminservice/base.go index 77c78f480..edaeb4b2e 100644 --- a/pkg/rpc/adminservice/base.go +++ b/pkg/rpc/adminservice/base.go @@ -176,7 +176,7 @@ func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, confi NodeExecutionManager: nodeExecutionManager, TaskExecutionManager: taskExecutionManager, ProjectManager: manager.NewProjectManager(repo, configuration), - ResourceManager: resources.NewResourceManager(repo, configuration.ApplicationConfiguration()), + ResourceManager: resources.NewResourceManager(repo, configuration), MetricsManager: manager.NewMetricsManager(workflowManager, executionManager, nodeExecutionManager, taskExecutionManager, adminScope.NewSubScope("metrics_manager")), Metrics: InitMetrics(adminScope), diff --git a/pkg/runtime/application_config_provider.go b/pkg/runtime/application_config_provider.go index 3b8b0a270..bbc7d0ebf 100644 --- a/pkg/runtime/application_config_provider.go +++ b/pkg/runtime/application_config_provider.go @@ -3,6 +3,7 @@ package runtime import ( "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytestdlib/config" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/database" @@ -119,6 +120,36 @@ func (p *ApplicationConfigurationProvider) GetCloudEventsConfig() *interfaces.Cl return cloudEventsConfig.GetConfig().(*interfaces.CloudEventsConfig) } +// GetAsWorkflowExecutionAttribute returns the WorkflowExecutionConfig as extracted from the base system configuration +// admin has been loaded with. +func (p *ApplicationConfigurationProvider) GetAsWorkflowExecutionAttribute() admin.WorkflowExecutionConfig { + // These values should always be set as their fallback values equals to their zero value or nil, + // providing a sensible default even if the actual value was not set. + a := p.GetTopLevelConfig() + + wec := admin.WorkflowExecutionConfig{ + MaxParallelism: a.GetMaxParallelism(), + OverwriteCache: a.GetOverwriteCache(), + Interruptible: a.GetInterruptible(), + } + + // For the others, we only add the field when the field is set in the config. + if a.GetSecurityContext().RunAs.GetK8SServiceAccount() != "" || a.GetSecurityContext().RunAs.GetIamRole() != "" { + wec.SecurityContext = a.GetSecurityContext() + } + if a.GetRawOutputDataConfig().OutputLocationPrefix != "" { + wec.RawOutputDataConfig = a.GetRawOutputDataConfig() + } + if len(a.GetLabels().Values) > 0 { + wec.Labels = a.GetLabels() + } + if len(a.GetAnnotations().Values) > 0 { + wec.Annotations = a.GetAnnotations() + } + + return wec +} + func NewApplicationConfigurationProvider() interfaces.ApplicationConfiguration { return &ApplicationConfigurationProvider{} } diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go index 16b1f921d..2e715b1cb 100644 --- a/pkg/runtime/interfaces/application_configuration.go +++ b/pkg/runtime/interfaces/application_configuration.go @@ -166,33 +166,6 @@ func (a *ApplicationConfig) GetOverwriteCache() bool { return a.OverwriteCache } -// GetAsWorkflowExecutionConfig returns the WorkflowExecutionConfig as extracted from this object -func (a *ApplicationConfig) GetAsWorkflowExecutionConfig() admin.WorkflowExecutionConfig { - // These values should always be set as their fallback values equals to their zero value or nil, - // providing a sensible default even if the actual value was not set. - wec := admin.WorkflowExecutionConfig{ - MaxParallelism: a.GetMaxParallelism(), - OverwriteCache: a.GetOverwriteCache(), - Interruptible: a.GetInterruptible(), - } - - // For the others, we only add the field when the field is set in the config. - if a.GetSecurityContext().RunAs.GetK8SServiceAccount() != "" || a.GetSecurityContext().RunAs.GetIamRole() != "" { - wec.SecurityContext = a.GetSecurityContext() - } - if a.GetRawOutputDataConfig().OutputLocationPrefix != "" { - wec.RawOutputDataConfig = a.GetRawOutputDataConfig() - } - if len(a.GetLabels().Values) > 0 { - wec.Labels = a.GetLabels() - } - if len(a.GetAnnotations().Values) > 0 { - wec.Annotations = a.GetAnnotations() - } - - return wec -} - // This section holds common config for AWS type AWSConfig struct { Region string `json:"region"` @@ -556,4 +529,5 @@ type ApplicationConfiguration interface { GetDomainsConfig() *DomainsConfig GetExternalEventsConfig() *ExternalEventsConfig GetCloudEventsConfig() *CloudEventsConfig + GetAsWorkflowExecutionAttribute() admin.WorkflowExecutionConfig } diff --git a/pkg/runtime/interfaces/task_resource_configuration.go b/pkg/runtime/interfaces/task_resource_configuration.go index 8c5565169..9c01b4bf7 100644 --- a/pkg/runtime/interfaces/task_resource_configuration.go +++ b/pkg/runtime/interfaces/task_resource_configuration.go @@ -1,17 +1,22 @@ package interfaces -import "k8s.io/apimachinery/pkg/api/resource" +import ( + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "k8s.io/apimachinery/pkg/api/resource" +) +// TaskResourceSet changed all this to pointers type TaskResourceSet struct { - CPU resource.Quantity `json:"cpu"` - GPU resource.Quantity `json:"gpu"` - Memory resource.Quantity `json:"memory"` - Storage resource.Quantity `json:"storage"` - EphemeralStorage resource.Quantity `json:"ephemeralStorage"` + CPU *resource.Quantity `json:"cpu"` + GPU *resource.Quantity `json:"gpu"` + Memory *resource.Quantity `json:"memory"` + Storage *resource.Quantity `json:"storage"` + EphemeralStorage *resource.Quantity `json:"ephemeralStorage"` } -// Provides default values for task resource limits and defaults. +// TaskResourceConfiguration provides default values for task resource limits and defaults. type TaskResourceConfiguration interface { GetDefaults() TaskResourceSet GetLimits() TaskResourceSet + GetAsAttribute() admin.TaskResourceAttributes } diff --git a/pkg/runtime/mocks/mock_application_provider.go b/pkg/runtime/mocks/mock_application_provider.go index 13079619c..0cab4a939 100644 --- a/pkg/runtime/mocks/mock_application_provider.go +++ b/pkg/runtime/mocks/mock_application_provider.go @@ -2,6 +2,7 @@ package mocks import ( "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytestdlib/database" ) @@ -79,3 +80,29 @@ func (p *MockApplicationProvider) SetCloudEventsConfig(cloudEventConfig interfac func (p *MockApplicationProvider) GetCloudEventsConfig() *interfaces.CloudEventsConfig { return &p.cloudEventConfig } + +func (p *MockApplicationProvider) GetAsWorkflowExecutionAttribute() admin.WorkflowExecutionConfig { + a := p.GetTopLevelConfig() + + wec := admin.WorkflowExecutionConfig{ + MaxParallelism: a.GetMaxParallelism(), + OverwriteCache: a.GetOverwriteCache(), + Interruptible: a.GetInterruptible(), + } + + // For the others, we only add the field when the field is set in the config. + if a.GetSecurityContext().RunAs.GetK8SServiceAccount() != "" || a.GetSecurityContext().RunAs.GetIamRole() != "" { + wec.SecurityContext = a.GetSecurityContext() + } + if a.GetRawOutputDataConfig().OutputLocationPrefix != "" { + wec.RawOutputDataConfig = a.GetRawOutputDataConfig() + } + if len(a.GetLabels().Values) > 0 { + wec.Labels = a.GetLabels() + } + if len(a.GetAnnotations().Values) > 0 { + wec.Annotations = a.GetAnnotations() + } + + return wec +} diff --git a/pkg/runtime/mocks/mock_configuration_provider.go b/pkg/runtime/mocks/mock_configuration_provider.go index 7af3e2e35..8f9b28f00 100644 --- a/pkg/runtime/mocks/mock_configuration_provider.go +++ b/pkg/runtime/mocks/mock_configuration_provider.go @@ -34,6 +34,10 @@ func (p *MockConfigurationProvider) TaskResourceConfiguration() interfaces.TaskR return p.taskResourceConfiguration } +func (p *MockConfigurationProvider) SetTaskResourceConfiguration(t interfaces.TaskResourceConfiguration) { + p.taskResourceConfiguration = t +} + func (p *MockConfigurationProvider) WhitelistConfiguration() interfaces.WhitelistConfiguration { return p.whitelistConfiguration } diff --git a/pkg/runtime/mocks/mock_task_resource_provider.go b/pkg/runtime/mocks/mock_task_resource_provider.go index e1e99a45b..4d0ebcca1 100644 --- a/pkg/runtime/mocks/mock_task_resource_provider.go +++ b/pkg/runtime/mocks/mock_task_resource_provider.go @@ -1,12 +1,42 @@ package mocks -import "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" +import ( + "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" +) type MockTaskResourceConfiguration struct { Defaults interfaces.TaskResourceSet Limits interfaces.TaskResourceSet } +func (c *MockTaskResourceConfiguration) ConstructTaskResourceSpec(a interfaces.TaskResourceSet) admin.TaskResourceSpec { + res := admin.TaskResourceSpec{} + if a.CPU != nil { + res.Cpu = a.CPU.String() + } + if a.GPU != nil { + res.Gpu = a.GPU.String() + } + if a.Memory != nil { + res.Memory = a.Memory.String() + } + if a.EphemeralStorage != nil { + res.EphemeralStorage = a.EphemeralStorage.String() + } + return res +} + +func (c *MockTaskResourceConfiguration) GetAsAttribute() admin.TaskResourceAttributes { + defaults := c.ConstructTaskResourceSpec(c.GetDefaults()) + limits := c.ConstructTaskResourceSpec(c.GetLimits()) + + return admin.TaskResourceAttributes{ + Defaults: &defaults, + Limits: &limits, + } +} + func (c *MockTaskResourceConfiguration) GetDefaults() interfaces.TaskResourceSet { return c.Defaults } diff --git a/pkg/runtime/task_resource_provider.go b/pkg/runtime/task_resource_provider.go index d8baa2fe1..0ed9d6d76 100644 --- a/pkg/runtime/task_resource_provider.go +++ b/pkg/runtime/task_resource_provider.go @@ -2,22 +2,15 @@ package runtime import ( "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytestdlib/config" - "k8s.io/apimachinery/pkg/api/resource" ) const taskResourceKey = "task_resources" var taskResourceConfig = config.MustRegisterSection(taskResourceKey, &TaskResourceSpec{ - Defaults: interfaces.TaskResourceSet{ - CPU: resource.MustParse("2"), - Memory: resource.MustParse("200Mi"), - }, - Limits: interfaces.TaskResourceSet{ - CPU: resource.MustParse("2"), - Memory: resource.MustParse("1Gi"), - GPU: resource.MustParse("1"), - }, + Defaults: interfaces.TaskResourceSet{}, + Limits: interfaces.TaskResourceSet{}, }) type TaskResourceSpec struct { @@ -25,7 +18,7 @@ type TaskResourceSpec struct { Limits interfaces.TaskResourceSet `json:"limits"` } -// Implementation of an interfaces.TaskResourceConfiguration +// TaskResourceProvider Implementation of an interfaces.TaskResourceConfiguration type TaskResourceProvider struct{} func (p *TaskResourceProvider) GetDefaults() interfaces.TaskResourceSet { @@ -36,6 +29,34 @@ func (p *TaskResourceProvider) GetLimits() interfaces.TaskResourceSet { return taskResourceConfig.GetConfig().(*TaskResourceSpec).Limits } +// ConstructTaskResourceSpec takes the configuration struct and turns it into the protobuf struct +func (p *TaskResourceProvider) ConstructTaskResourceSpec(a interfaces.TaskResourceSet) admin.TaskResourceSpec { + res := admin.TaskResourceSpec{} + if a.CPU != nil { + res.Cpu = a.CPU.String() + } + if a.GPU != nil { + res.Gpu = a.GPU.String() + } + if a.Memory != nil { + res.Memory = a.Memory.String() + } + if a.EphemeralStorage != nil { + res.EphemeralStorage = a.EphemeralStorage.String() + } + return res +} + +func (p *TaskResourceProvider) GetAsAttribute() admin.TaskResourceAttributes { + defaults := p.ConstructTaskResourceSpec(p.GetDefaults()) + limits := p.ConstructTaskResourceSpec(p.GetLimits()) + + return admin.TaskResourceAttributes{ + Defaults: &defaults, + Limits: &limits, + } +} + func NewTaskResourceProvider() interfaces.TaskResourceConfiguration { return &TaskResourceProvider{} } diff --git a/pkg/runtime/task_resource_provider_test.go b/pkg/runtime/task_resource_provider_test.go new file mode 100644 index 000000000..8ad903fef --- /dev/null +++ b/pkg/runtime/task_resource_provider_test.go @@ -0,0 +1,21 @@ +package runtime + +import ( + "testing" + + "github.com/flyteorg/flyteadmin/pkg/common/testutils" + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "k8s.io/apimachinery/pkg/api/resource" +) + +func TestNewTaskResourceProvider(t *testing.T) { + tt := &TaskResourceSpec{ + Defaults: interfaces.TaskResourceSet{ + GPU: testutils.GetPtr(resource.MustParse("0")), + }, + } + assert.True(t, tt.Defaults.GPU.IsZero()) + assert.Nil(t, tt.Defaults.Storage) +} diff --git a/pkg/workflowengine/impl/prepare_execution.go b/pkg/workflowengine/impl/prepare_execution.go index f2a778e27..dff9ccbae 100644 --- a/pkg/workflowengine/impl/prepare_execution.go +++ b/pkg/workflowengine/impl/prepare_execution.go @@ -42,6 +42,7 @@ func addPermissions(securityCtx *core.SecurityContext, roleNameKey string, flyte func addExecutionOverrides(taskPluginOverrides []*admin.PluginOverride, workflowExecutionConfig *admin.WorkflowExecutionConfig, recoveryExecution *core.WorkflowExecutionIdentifier, taskResources *interfaces.TaskResources, flyteWf *v1alpha1.FlyteWorkflow) { + executionConfig := v1alpha1.ExecutionConfig{ TaskPluginImpls: make(map[string]v1alpha1.TaskPluginOverride), RecoveryExecution: v1alpha1.WorkflowExecutionIdentifier{ @@ -67,37 +68,37 @@ func addExecutionOverrides(taskPluginOverrides []*admin.PluginOverride, } if taskResources != nil { var requests = v1alpha1.TaskResourceSpec{} - if !taskResources.Defaults.CPU.IsZero() { - requests.CPU = taskResources.Defaults.CPU + if taskResources.Defaults.CPU != nil && !taskResources.Defaults.CPU.IsZero() { + requests.CPU = *taskResources.Defaults.CPU } - if !taskResources.Defaults.Memory.IsZero() { - requests.Memory = taskResources.Defaults.Memory + if taskResources.Defaults.Memory != nil && !taskResources.Defaults.Memory.IsZero() { + requests.Memory = *taskResources.Defaults.Memory } - if !taskResources.Defaults.EphemeralStorage.IsZero() { - requests.EphemeralStorage = taskResources.Defaults.EphemeralStorage + if taskResources.Defaults.EphemeralStorage != nil && !taskResources.Defaults.EphemeralStorage.IsZero() { + requests.EphemeralStorage = *taskResources.Defaults.EphemeralStorage } - if !taskResources.Defaults.Storage.IsZero() { - requests.Storage = taskResources.Defaults.Storage + if taskResources.Defaults.Storage != nil && !taskResources.Defaults.Storage.IsZero() { + requests.Storage = *taskResources.Defaults.Storage } - if !taskResources.Defaults.GPU.IsZero() { - requests.GPU = taskResources.Defaults.GPU + if taskResources.Defaults.GPU != nil && !taskResources.Defaults.GPU.IsZero() { + requests.GPU = *taskResources.Defaults.GPU } var limits = v1alpha1.TaskResourceSpec{} - if !taskResources.Limits.CPU.IsZero() { - limits.CPU = taskResources.Limits.CPU + if taskResources.Limits.CPU != nil && !taskResources.Limits.CPU.IsZero() { + limits.CPU = *taskResources.Limits.CPU } - if !taskResources.Limits.Memory.IsZero() { - limits.Memory = taskResources.Limits.Memory + if taskResources.Limits.Memory != nil && !taskResources.Limits.Memory.IsZero() { + limits.Memory = *taskResources.Limits.Memory } - if !taskResources.Limits.EphemeralStorage.IsZero() { - limits.EphemeralStorage = taskResources.Limits.EphemeralStorage + if taskResources.Limits.EphemeralStorage != nil && !taskResources.Limits.EphemeralStorage.IsZero() { + limits.EphemeralStorage = *taskResources.Limits.EphemeralStorage } - if !taskResources.Limits.Storage.IsZero() { - limits.Storage = taskResources.Limits.Storage + if taskResources.Limits.Storage != nil && !taskResources.Limits.Storage.IsZero() { + limits.Storage = *taskResources.Limits.Storage } - if !taskResources.Limits.GPU.IsZero() { - limits.GPU = taskResources.Limits.GPU + if taskResources.Limits.GPU != nil && !taskResources.Limits.GPU.IsZero() { + limits.GPU = *taskResources.Limits.GPU } executionConfig.TaskResources = v1alpha1.TaskResources{ Requests: requests, diff --git a/pkg/workflowengine/impl/prepare_execution_test.go b/pkg/workflowengine/impl/prepare_execution_test.go index 38e155636..a55d5d888 100644 --- a/pkg/workflowengine/impl/prepare_execution_test.go +++ b/pkg/workflowengine/impl/prepare_execution_test.go @@ -4,6 +4,8 @@ import ( "testing" "time" + "github.com/flyteorg/flyteadmin/pkg/common/testutils" + "github.com/flyteorg/flyteadmin/pkg/workflowengine/interfaces" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" @@ -125,15 +127,15 @@ func TestAddExecutionOverrides(t *testing.T) { workflow := &v1alpha1.FlyteWorkflow{} addExecutionOverrides(nil, nil, nil, &interfaces.TaskResources{ Defaults: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("1"), - Memory: resource.MustParse("100Gi"), + CPU: testutils.GetPtr(resource.MustParse("1")), + Memory: testutils.GetPtr(resource.MustParse("100Gi")), }, Limits: runtimeInterfaces.TaskResourceSet{ - CPU: resource.MustParse("2"), - Memory: resource.MustParse("200Gi"), - Storage: resource.MustParse("5Gi"), - EphemeralStorage: resource.MustParse("1Gi"), - GPU: resource.MustParse("1"), + CPU: testutils.GetPtr(resource.MustParse("2")), + Memory: testutils.GetPtr(resource.MustParse("200Gi")), + Storage: testutils.GetPtr(resource.MustParse("5Gi")), + EphemeralStorage: testutils.GetPtr(resource.MustParse("1Gi")), + GPU: testutils.GetPtr(resource.MustParse("1")), }, }, workflow) assert.EqualValues(t, v1alpha1.TaskResourceSpec{