Skip to content

Commit

Permalink
Add GeneratedNameMaxLength property (flyteorg#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
regadas committed Mar 24, 2021
1 parent 47a3c02 commit 98bfa7a
Show file tree
Hide file tree
Showing 16 changed files with 223 additions and 4 deletions.
21 changes: 21 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/core/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package core

import (
"context"
"fmt"
)

//go:generate mockery -all -case=underscore

// https://github.com/flyteorg/flytepropeller/blob/979fabe1d1b22b01645259a03b8096f227681d08/pkg/utils/encoder.go#L25-L26
const minGeneratedNameLength = 8

type TaskType = string

// A Lazy loading function, that will load the plugin. Plugins should be initialized in this method. It is guaranteed
Expand Down Expand Up @@ -34,6 +38,8 @@ type PluginEntry struct {
type PluginProperties struct {
// Instructs the execution engine to not attempt to cache lookup or write for the node.
DisableNodeLevelCaching bool
// Specifies the length of TaskExecutionID generated name. default: 50
GeneratedNameMaxLength *int
}

// Interface for the core Flyte plugin
Expand All @@ -52,3 +58,18 @@ type Plugin interface {
// Finalize is always called, after Handle or Abort. Finalize should be an idempotent operation
Finalize(ctx context.Context, tCtx TaskExecutionContext) error
}

// Loads and validates a plugin.
func LoadPlugin(ctx context.Context, iCtx SetupContext, entry PluginEntry) (Plugin, error) {
plugin, err := entry.LoadPlugin(ctx, iCtx)
if err != nil {
return nil, err
}

length := plugin.GetProperties().GeneratedNameMaxLength
if length != nil && *length < minGeneratedNameLength {
return nil, fmt.Errorf("GeneratedNameMaxLength needs to be greater then %d", minGeneratedNameLength)
}

return plugin, err
}
94 changes: 94 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package core_test

import (
"context"
"testing"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"gotest.tools/assert"
)

func TestLoadPlugin(t *testing.T) {
corePluginType := "core"

t.Run("valid", func(t *testing.T) {
corePlugin := &mocks.Plugin{}
corePlugin.On("GetID").Return(corePluginType)
corePlugin.OnGetProperties().Return(core.PluginProperties{})

corePluginEntry := core.PluginEntry{
ID: corePluginType,
RegisteredTaskTypes: []core.TaskType{corePluginType},
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
return corePlugin, nil
},
}
setupCtx := mocks.SetupContext{}
p, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry)
assert.NilError(t, err)
assert.Equal(t, corePluginType, p.GetID())
})

t.Run("valid GeneratedNameMaxLength", func(t *testing.T) {
corePlugin := &mocks.Plugin{}
corePlugin.On("GetID").Return(corePluginType)
length := 10
corePlugin.OnGetProperties().Return(core.PluginProperties{
GeneratedNameMaxLength: &length,
})

corePluginEntry := core.PluginEntry{
ID: corePluginType,
RegisteredTaskTypes: []core.TaskType{corePluginType},
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
return corePlugin, nil
},
}
setupCtx := mocks.SetupContext{}
p, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry)
assert.NilError(t, err)
assert.Equal(t, corePluginType, p.GetID())
})

t.Run("valid GeneratedNameMaxLength", func(t *testing.T) {
corePlugin := &mocks.Plugin{}
corePlugin.On("GetID").Return(corePluginType)
length := 10
corePlugin.OnGetProperties().Return(core.PluginProperties{
GeneratedNameMaxLength: &length,
})

corePluginEntry := core.PluginEntry{
ID: corePluginType,
RegisteredTaskTypes: []core.TaskType{corePluginType},
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
return corePlugin, nil
},
}
setupCtx := mocks.SetupContext{}
_, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry)
assert.NilError(t, err)
})

t.Run("invalid GeneratedNameMaxLength", func(t *testing.T) {
corePlugin := &mocks.Plugin{}
corePlugin.On("GetID").Return(corePluginType)
length := 5
corePlugin.OnGetProperties().Return(core.PluginProperties{
GeneratedNameMaxLength: &length,
})

corePluginEntry := core.PluginEntry{
ID: corePluginType,
RegisteredTaskTypes: []core.TaskType{corePluginType},
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
return corePlugin, nil
},
}
setupCtx := mocks.SetupContext{}
_, err := core.LoadPlugin(context.TODO(), &setupCtx, corePluginEntry)
assert.Error(t, err, "GeneratedNameMaxLength needs to be greater then 8")
})

}
32 changes: 32 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/k8s/mocks/plugin.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,18 @@ type PluginEntry struct {
DefaultForTaskTypes []pluginsCore.TaskType
// Returns a new KubeClient to be used instead of the internal controller-runtime client.
CustomKubeClient func(ctx context.Context) (pluginsCore.KubeClient, error)
}

// System level properties that this Plugin supports
type PluginProperties struct {
// Disables the inclusion of OwnerReferences in kubernetes resources that this plugin is responsible for.
// Disabling is only useful if resources will be created in a remote cluster.
DisableInjectOwnerReferences bool
// Boolean that indicates if finalizer injection should be disabled for resources that this plugin is
// responsible for.
DisableInjectFinalizer bool
// Specifies the length of TaskExecutionID generated name. default: 50
GeneratedNameMaxLength *int
}

// Special context passed in to plugins when checking task phase
Expand Down Expand Up @@ -76,4 +82,7 @@ type Plugin interface {
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
GetTaskPhase(ctx context.Context, pluginContext PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error)

// Properties desired by the plugin
GetProperties() PluginProperties
}
4 changes: 4 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ const (
type Plugin struct {
}

func (Plugin) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

func (Plugin) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) {

pod := r.(*v1.Pod)
Expand Down
7 changes: 7 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/container/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
)

var resourceRequirements = &v1.ResourceRequirements{
Expand Down Expand Up @@ -190,3 +191,9 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) {
assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase())
})
}

func TestContainerTaskExecutor_GetProperties(t *testing.T) {
plugin := Plugin{}
expected := k8s.PluginProperties{}
assert.Equal(t, expected, plugin.GetProperties())
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ type pytorchOperatorResourceHandler struct {
// Sanity test that the plugin implements method of k8s.Plugin
var _ k8s.Plugin = pytorchOperatorResourceHandler{}

func (pytorchOperatorResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

// Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s
// resources.
func (pytorchOperatorResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (client.Object, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/flyteorg/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
Expand Down Expand Up @@ -355,3 +356,9 @@ func TestGetLogs(t *testing.T) {
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[1].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[2].Uri)
}

func TestGetProperties(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}
expected := k8s.PluginProperties{}
assert.Equal(t, expected, pytorchResourceHandler.GetProperties())
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ type tensorflowOperatorResourceHandler struct {
// Sanity test that the plugin implements method of k8s.Plugin
var _ k8s.Plugin = tensorflowOperatorResourceHandler{}

func (tensorflowOperatorResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

// Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s
// resources.
func (tensorflowOperatorResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (client.Object, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/flyteorg/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
Expand Down Expand Up @@ -364,3 +365,9 @@ func TestGetLogs(t *testing.T) {
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-psReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[2].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[3].Uri)
}

func TestGetProperties(t *testing.T) {
tensorflowResourceHandler := tensorflowOperatorResourceHandler{}
expected := k8s.PluginProperties{}
assert.Equal(t, expected, tensorflowResourceHandler.GetProperties())
}
4 changes: 4 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type awsSagemakerPlugin struct {
TaskType pluginsCore.TaskType
}

func (awsSagemakerPlugin) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

func (m awsSagemakerPlugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) (client.Object, error) {
if m.TaskType == trainingJobTaskType || m.TaskType == customTrainingJobTaskType {
return &trainingjobv1.TrainingJob{}, nil
Expand Down
8 changes: 8 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"
"github.com/stretchr/testify/assert"

hpojobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hyperparametertuningjob"
trainingjobv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/trainingjob"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
)

func Test_awsSagemakerPlugin_BuildIdentityResource(t *testing.T) {
Expand Down Expand Up @@ -55,6 +57,12 @@ func Test_awsSagemakerPlugin_BuildIdentityResource(t *testing.T) {
}
}

func TestGetProperties(t *testing.T) {
plugin := awsSagemakerPlugin{}
expected := k8s.PluginProperties{}
assert.Equal(t, expected, plugin.GetProperties())
}

func init() {
labeled.SetMetricKeys(contextutils.NamespaceKey)
}
4 changes: 4 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ type sidecarJob struct {
PrimaryContainerName string
}

func (sidecarResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
var podSpec k8sv1.PodSpec
var primaryContainerName string
Expand Down
7 changes: 7 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
)

const ResourceNvidiaGPU = "nvidia.com/gpu"
Expand Down Expand Up @@ -494,3 +495,9 @@ func TestDemystifiedSidecarStatus_PrimaryMissing(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, pluginsCore.PhasePermanentFailure, phaseInfo.Phase())
}

func TestGetProperties(t *testing.T) {
handler := &sidecarResourceHandler{}
expected := k8s.PluginProperties{}
assert.Equal(t, expected, handler.GetProperties())
}
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ var sparkTaskType = "spark"
type sparkResourceHandler struct {
}

func (sparkResourceHandler) GetProperties() pluginsCore.PluginProperties {
return pluginsCore.PluginProperties{}
}

func validateSparkJob(sparkJob *plugins.SparkJob) error {
if sparkJob == nil {
return fmt.Errorf("empty sparkJob")
Expand All @@ -61,6 +57,10 @@ func validateSparkJob(sparkJob *plugins.SparkJob) error {
return nil
}

func (sparkResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

// Creates a new Job that will execute the main container as well as any generated types the result from the execution.
func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
Expand Down
7 changes: 7 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"

"github.com/stretchr/testify/mock"

Expand Down Expand Up @@ -470,3 +471,9 @@ func TestBuildResourceSpark(t *testing.T) {
assert.NotNil(t, err)
assert.Nil(t, resource)
}

func TestGetPropertiesSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}
expected := k8s.PluginProperties{}
assert.Equal(t, expected, sparkResourceHandler.GetProperties())
}

0 comments on commit 98bfa7a

Please sign in to comment.