diff --git a/flyteplugins/go/tasks/pluginmachinery/core/plugin.go b/flyteplugins/go/tasks/pluginmachinery/core/plugin.go index a4314816d3..add87cc1d5 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/plugin.go @@ -2,10 +2,14 @@ package core import ( "context" + "fmt" ) //go:generate mockery -all -case=underscore +// https://github.com/flyteorg/flytepropeller/blob/979fabe1d1b22b01645259a03b8096f227681d08/pkg/utils/encoder.go#L25-L26 +const minGeneratedNameLength = 8 + type TaskType = string // A Lazy loading function, that will load the plugin. Plugins should be initialized in this method. It is guaranteed @@ -34,6 +38,8 @@ type PluginEntry struct { type PluginProperties struct { // Instructs the execution engine to not attempt to cache lookup or write for the node. DisableNodeLevelCaching bool + // Specifies the length of TaskExecutionID generated name. default: 50 + GeneratedNameMaxLength *int } // Interface for the core Flyte plugin @@ -52,3 +58,18 @@ type Plugin interface { // Finalize is always called, after Handle or Abort. Finalize should be an idempotent operation Finalize(ctx context.Context, tCtx TaskExecutionContext) error } + +// Loads and validates a plugin. +func LoadPlugin(ctx context.Context, iCtx SetupContext, entry PluginEntry) (Plugin, error) { + plugin, err := entry.LoadPlugin(ctx, iCtx) + if err != nil { + return nil, err + } + + length := plugin.GetProperties().GeneratedNameMaxLength + if length != nil && *length < minGeneratedNameLength { + return nil, fmt.Errorf("GeneratedNameMaxLength needs to be greater then %d", minGeneratedNameLength) + } + + return plugin, err +} diff --git a/flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go b/flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go new file mode 100644 index 0000000000..ea299a6aac --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go @@ -0,0 +1,94 @@ +package core_test + +import ( + "context" + "testing" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "gotest.tools/assert" +) + +func TestLoadPlugin(t *testing.T) { + corePluginType := "core" + + t.Run("valid", func(t *testing.T) { + corePlugin := &mocks.Plugin{} + corePlugin.On("GetID").Return(corePluginType) + corePlugin.OnGetProperties().Return(core.PluginProperties{}) + + corePluginEntry := core.PluginEntry{ + ID: corePluginType, + RegisteredTaskTypes: []core.TaskType{corePluginType}, + LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { + return corePlugin, nil + }, + } + setupCtx := mocks.SetupContext{} + p, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry) + assert.NilError(t, err) + assert.Equal(t, corePluginType, p.GetID()) + }) + + t.Run("valid GeneratedNameMaxLength", func(t *testing.T) { + corePlugin := &mocks.Plugin{} + corePlugin.On("GetID").Return(corePluginType) + length := 10 + corePlugin.OnGetProperties().Return(core.PluginProperties{ + GeneratedNameMaxLength: &length, + }) + + corePluginEntry := core.PluginEntry{ + ID: corePluginType, + RegisteredTaskTypes: []core.TaskType{corePluginType}, + LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { + return corePlugin, nil + }, + } + setupCtx := mocks.SetupContext{} + p, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry) + assert.NilError(t, err) + assert.Equal(t, corePluginType, p.GetID()) + }) + + t.Run("valid GeneratedNameMaxLength", func(t *testing.T) { + corePlugin := &mocks.Plugin{} + corePlugin.On("GetID").Return(corePluginType) + length := 10 + corePlugin.OnGetProperties().Return(core.PluginProperties{ + GeneratedNameMaxLength: &length, + }) + + corePluginEntry := core.PluginEntry{ + ID: corePluginType, + RegisteredTaskTypes: []core.TaskType{corePluginType}, + LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { + return corePlugin, nil + }, + } + setupCtx := mocks.SetupContext{} + _, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry) + assert.NilError(t, err) + }) + + t.Run("invalid GeneratedNameMaxLength", func(t *testing.T) { + corePlugin := &mocks.Plugin{} + corePlugin.On("GetID").Return(corePluginType) + length := 5 + corePlugin.OnGetProperties().Return(core.PluginProperties{ + GeneratedNameMaxLength: &length, + }) + + corePluginEntry := core.PluginEntry{ + ID: corePluginType, + RegisteredTaskTypes: []core.TaskType{corePluginType}, + LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { + return corePlugin, nil + }, + } + setupCtx := mocks.SetupContext{} + _, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry) + assert.Error(t, err, "GeneratedNameMaxLength needs to be greater then 8") + }) + +} diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin.go index 6418108c30..0859458276 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin.go @@ -101,6 +101,38 @@ func (_m *Plugin) BuildResource(ctx context.Context, taskCtx core.TaskExecutionC return r0, r1 } +type Plugin_GetProperties struct { + *mock.Call +} + +func (_m Plugin_GetProperties) Return(_a0 k8s.PluginProperties) *Plugin_GetProperties { + return &Plugin_GetProperties{Call: _m.Call.Return(_a0)} +} + +func (_m *Plugin) OnGetProperties() *Plugin_GetProperties { + c := _m.On("GetProperties") + return &Plugin_GetProperties{Call: c} +} + +func (_m *Plugin) OnGetPropertiesMatch(matchers ...interface{}) *Plugin_GetProperties { + c := _m.On("GetProperties", matchers...) + return &Plugin_GetProperties{Call: c} +} + +// GetProperties provides a mock function with given fields: +func (_m *Plugin) GetProperties() k8s.PluginProperties { + ret := _m.Called() + + var r0 k8s.PluginProperties + if rf, ok := ret.Get(0).(func() k8s.PluginProperties); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(k8s.PluginProperties) + } + + return r0 +} + type Plugin_GetTaskPhase struct { *mock.Call } diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 529314e2ff..6db3d4b3cd 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -34,12 +34,18 @@ type PluginEntry struct { DefaultForTaskTypes []pluginsCore.TaskType // Returns a new KubeClient to be used instead of the internal controller-runtime client. CustomKubeClient func(ctx context.Context) (pluginsCore.KubeClient, error) +} + +// System level properties that this Plugin supports +type PluginProperties struct { // Disables the inclusion of OwnerReferences in kubernetes resources that this plugin is responsible for. // Disabling is only useful if resources will be created in a remote cluster. DisableInjectOwnerReferences bool // Boolean that indicates if finalizer injection should be disabled for resources that this plugin is // responsible for. DisableInjectFinalizer bool + // Specifies the length of TaskExecutionID generated name. default: 50 + GeneratedNameMaxLength *int } // Special context passed in to plugins when checking task phase @@ -76,4 +82,7 @@ type Plugin interface { // any operations that might take a long time (limits are configured system-wide) should be offloaded to the // background. GetTaskPhase(ctx context.Context, pluginContext PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) + + // Properties desired by the plugin + GetProperties() PluginProperties } diff --git a/flyteplugins/go/tasks/plugins/k8s/container/container.go b/flyteplugins/go/tasks/plugins/k8s/container/container.go index 10667d5c23..84a35cd680 100755 --- a/flyteplugins/go/tasks/plugins/k8s/container/container.go +++ b/flyteplugins/go/tasks/plugins/k8s/container/container.go @@ -22,6 +22,10 @@ const ( type Plugin struct { } +func (Plugin) GetProperties() k8s.PluginProperties { + return k8s.PluginProperties{} +} + func (Plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { pod := r.(*v1.Pod) diff --git a/flyteplugins/go/tasks/plugins/k8s/container/container_test.go b/flyteplugins/go/tasks/plugins/k8s/container/container_test.go index c3f905ff97..df32ccdc3c 100755 --- a/flyteplugins/go/tasks/plugins/k8s/container/container_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/container/container_test.go @@ -20,6 +20,7 @@ import ( pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" ) var resourceRequirements = &v1.ResourceRequirements{ @@ -190,3 +191,9 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase()) }) } + +func TestContainerTaskExecutor_GetProperties(t *testing.T) { + plugin := Plugin{} + expected := k8s.PluginProperties{} + assert.Equal(t, expected, plugin.GetProperties()) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 12be3d2b84..005e97771c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -31,6 +31,10 @@ type pytorchOperatorResourceHandler struct { // Sanity test that the plugin implements method of k8s.Plugin var _ k8s.Plugin = pytorchOperatorResourceHandler{} +func (pytorchOperatorResourceHandler) GetProperties() k8s.PluginProperties { + return k8s.PluginProperties{} +} + // Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s // resources. func (pytorchOperatorResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (client.Object, error) { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 8e3968c954..7f3d458bba 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -10,6 +10,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/logs" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -355,3 +356,9 @@ func TestGetLogs(t *testing.T) { assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[1].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[2].Uri) } + +func TestGetProperties(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + expected := k8s.PluginProperties{} + assert.Equal(t, expected, pytorchResourceHandler.GetProperties()) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 56fb95338c..382da309e9 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -31,6 +31,10 @@ type tensorflowOperatorResourceHandler struct { // Sanity test that the plugin implements method of k8s.Plugin var _ k8s.Plugin = tensorflowOperatorResourceHandler{} +func (tensorflowOperatorResourceHandler) GetProperties() k8s.PluginProperties { + return k8s.PluginProperties{} +} + // Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s // resources. func (tensorflowOperatorResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (client.Object, error) { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 05d8791f15..1c60386a8e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -10,6 +10,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/logs" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -364,3 +365,9 @@ func TestGetLogs(t *testing.T) { assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-psReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[2].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[3].Uri) } + +func TestGetProperties(t *testing.T) { + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + expected := k8s.PluginProperties{} + assert.Equal(t, expected, tensorflowResourceHandler.GetProperties()) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin.go index 1f545adb64..ec04050516 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin.go @@ -25,6 +25,10 @@ type awsSagemakerPlugin struct { TaskType pluginsCore.TaskType } +func (awsSagemakerPlugin) GetProperties() k8s.PluginProperties { + return k8s.PluginProperties{} +} + func (m awsSagemakerPlugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) (client.Object, error) { if m.TaskType == trainingJobTaskType || m.TaskType == customTrainingJobTaskType { return &trainingjobv1.TrainingJob{}, nil diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test.go index 613a78fb91..94eb0f9a77 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test.go @@ -9,10 +9,12 @@ import ( "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/stretchr/testify/assert" hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob" trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" ) func Test_awsSagemakerPlugin_BuildIdentityResource(t *testing.T) { @@ -55,6 +57,12 @@ func Test_awsSagemakerPlugin_BuildIdentityResource(t *testing.T) { } } +func TestGetProperties(t *testing.T) { + plugin := awsSagemakerPlugin{} + expected := k8s.PluginProperties{} + assert.Equal(t, expected, plugin.GetProperties()) +} + func init() { labeled.SetMetricKeys(contextutils.NamespaceKey) } diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go index 905dbd12b5..ca92e70bcd 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -73,6 +73,10 @@ type sidecarJob struct { PrimaryContainerName string } +func (sidecarResourceHandler) GetProperties() k8s.PluginProperties { + return k8s.PluginProperties{} +} + func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { var podSpec k8sv1.PodSpec var primaryContainerName string diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go index da67a3613b..d1006156d1 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go @@ -26,6 +26,7 @@ import ( pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" ) const ResourceNvidiaGPU = "nvidia.com/gpu" @@ -494,3 +495,9 @@ func TestDemystifiedSidecarStatus_PrimaryMissing(t *testing.T) { assert.Nil(t, err) assert.Equal(t, pluginsCore.PhasePermanentFailure, phaseInfo.Phase()) } + +func TestGetProperties(t *testing.T) { + handler := &sidecarResourceHandler{} + expected := k8s.PluginProperties{} + assert.Equal(t, expected, handler.GetProperties()) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 6f9b62683b..4a04699736 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -45,10 +45,6 @@ var sparkTaskType = "spark" type sparkResourceHandler struct { } -func (sparkResourceHandler) GetProperties() pluginsCore.PluginProperties { - return pluginsCore.PluginProperties{} -} - func validateSparkJob(sparkJob *plugins.SparkJob) error { if sparkJob == nil { return fmt.Errorf("empty sparkJob") @@ -61,6 +57,10 @@ func validateSparkJob(sparkJob *plugins.SparkJob) error { return nil } +func (sparkResourceHandler) GetProperties() k8s.PluginProperties { + return k8s.PluginProperties{} +} + // Creates a new Job that will execute the main container as well as any generated types the result from the execution. func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 8070f5b2f0..394d06d5d2 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/stretchr/testify/mock" @@ -470,3 +471,9 @@ func TestBuildResourceSpark(t *testing.T) { assert.NotNil(t, err) assert.Nil(t, resource) } + +func TestGetPropertiesSpark(t *testing.T) { + sparkResourceHandler := sparkResourceHandler{} + expected := k8s.PluginProperties{} + assert.Equal(t, expected, sparkResourceHandler.GetProperties()) +}