diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 3a86ed9a5..acaf09680 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -31,23 +31,6 @@ type ReplicaEntry struct { RestartPolicy commonOp.RestartPolicy } -// ExtractMPICurrentCondition will return the first job condition for MPI -func ExtractMPICurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { - if jobConditions != nil { - sort.Slice(jobConditions, func(i, j int) bool { - return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time) - }) - - for _, jc := range jobConditions { - if jc.Status == v1.ConditionTrue { - return jc, nil - } - } - } - - return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions) -} - // ExtractCurrentCondition will return the first job condition for tensorflow/pytorch func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { if jobConditions != nil { @@ -60,14 +43,17 @@ func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.Jo return jc, nil } } + return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions) } - - return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions) + return commonOp.JobCondition{}, nil } // GetPhaseInfo will return the phase of kubeflow job func GetPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Time, taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) { + if len(currentCondition.Type) == 0 { + return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil + } switch currentCondition.Type { case commonOp.JobCreated: return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index 8914e976f..7f74e3f58 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -18,7 +18,7 @@ import ( meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -func TestExtractMPICurrentCondition(t *testing.T) { +func TestExtractCurrentCondition(t *testing.T) { jobCreated := commonOp.JobCondition{ Type: commonOp.JobCreated, Status: corev1.ConditionTrue, @@ -31,35 +31,21 @@ func TestExtractMPICurrentCondition(t *testing.T) { jobCreated, jobRunningActive, } - currentCondition, err := ExtractMPICurrentCondition(jobConditions) + currentCondition, err := ExtractCurrentCondition(jobConditions) assert.NoError(t, err) assert.Equal(t, currentCondition, jobCreated) jobConditions = nil - currentCondition, err = ExtractMPICurrentCondition(jobConditions) - assert.Error(t, err) + currentCondition, err = ExtractCurrentCondition(jobConditions) + assert.NoError(t, err) assert.Equal(t, currentCondition, commonOp.JobCondition{}) - assert.Equal(t, err, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions)) -} -func TestExtractCurrentCondition(t *testing.T) { - jobCreated := commonOp.JobCondition{ - Type: commonOp.JobCreated, - Status: corev1.ConditionTrue, - } - jobRunningActive := commonOp.JobCondition{ - Type: commonOp.JobRunning, - Status: corev1.ConditionFalse, - } - jobConditions := []commonOp.JobCondition{ - jobCreated, - jobRunningActive, - } - currentCondition, err := ExtractCurrentCondition(jobConditions) + currentCondition, err = ExtractCurrentCondition(nil) assert.NoError(t, err) - assert.Equal(t, currentCondition, jobCreated) + assert.Equal(t, currentCondition, commonOp.JobCondition{}) - jobConditions = nil + jobUnknown := commonOp.JobCondition{Type: "unknown"} + jobConditions = []commonOp.JobCondition{jobUnknown} currentCondition, err = ExtractCurrentCondition(jobConditions) assert.Error(t, err) assert.Equal(t, currentCondition, commonOp.JobCondition{}) @@ -67,10 +53,17 @@ func TestExtractCurrentCondition(t *testing.T) { } func TestGetPhaseInfo(t *testing.T) { + jobCreating := commonOp.JobCondition{} + taskPhase, err := GetPhaseInfo(jobCreating, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + jobCreated := commonOp.JobCondition{ Type: commonOp.JobCreated, } - taskPhase, err := GetPhaseInfo(jobCreated, time.Now(), pluginsCore.TaskInfo{}) + taskPhase, err = GetPhaseInfo(jobCreated, time.Now(), pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) assert.NotNil(t, taskPhase.Info()) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config.go new file mode 100644 index 000000000..7803e9097 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config.go @@ -0,0 +1,32 @@ +package common + +import ( + "time" + + pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" + "github.com/flyteorg/flytestdlib/config" +) + +//go:generate pflags Config --default-var=defaultConfig + +var ( + defaultConfig = Config{ + Timeout: config.Duration{Duration: 1 * time.Minute}, + } + + configSection = pluginsConfig.MustRegisterSubSection("kf-operator", &defaultConfig) +) + +// Config is config for 'pytorch' plugin +type Config struct { + // If kubeflow operator doesn't update the status of the task after this timeout, the task will be considered failed. + Timeout config.Duration `json:"timeout,omitempty"` +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config_flags.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config_flags.go new file mode 100755 index 000000000..9fb5c0297 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config_flags.go @@ -0,0 +1,55 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package common + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "timeout"), defaultConfig.Timeout.String(), "") + return cmdFlags +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config_flags_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config_flags_test.go new file mode 100755 index 000000000..0afdf456c --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/config_flags_test.go @@ -0,0 +1,116 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package common + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_timeout", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.Timeout.String() + + cmdFlags.Set("timeout", testValue) + if vString, err := cmdFlags.GetString("timeout"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Timeout) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 92c864340..7b6974ef8 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -209,7 +209,10 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext if err != nil { return pluginsCore.PhaseInfoUndefined, err } - currentCondition, err := common.ExtractMPICurrentCondition(app.Status.Conditions) + if app.Status.StartTime == nil && app.CreationTimestamp.Add(common.GetConfig().Timeout.Duration).Before(time.Now()) { + return pluginsCore.PhaseInfoUndefined, fmt.Errorf("kubeflow operator hasn't updated the mpi custom resource since creation time %v", app.CreationTimestamp) + } + currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions) if err != nil { return pluginsCore.PhaseInfoUndefined, err } @@ -223,7 +226,6 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext } return common.GetMPIPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) - } func init() { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index fbced8085..4a442b29a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -281,7 +281,7 @@ func dummyMPIJobResource(mpiResourceHandler mpiOperatorResourceHandler, Status: mpiOp.JobStatus{ Conditions: jobConditions, ReplicaStatuses: nil, - StartTime: nil, + StartTime: &v1.Time{Time: time.Now()}, CompletionTime: nil, LastReconcileTime: nil, }, diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 8652dc7d3..d6f18b1a3 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -231,6 +231,9 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont return pluginsCore.PhaseInfoUndefined, err } + if app.Status.StartTime == nil && app.CreationTimestamp.Add(common.GetConfig().Timeout.Duration).Before(time.Now()) { + return pluginsCore.PhaseInfoUndefined, fmt.Errorf("kubeflow operator hasn't updated the pytorch custom resource since creation time %v", app.CreationTimestamp) + } currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 8122a85e3..69b2cb556 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -293,8 +293,9 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl return &kubeflowv1.PyTorchJob{ ObjectMeta: v1.ObjectMeta{ - Name: jobName, - Namespace: jobNamespace, + CreationTimestamp: v1.Time{Time: time.Now()}, + Name: jobName, + Namespace: jobNamespace, }, Spec: resource.(*kubeflowv1.PyTorchJob).Spec, Status: commonOp.JobStatus{ diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 5e2ed948c..fb006935b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -209,6 +209,10 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginC return pluginsCore.PhaseInfoUndefined, err } + if app.Status.StartTime == nil && app.CreationTimestamp.Add(common.GetConfig().Timeout.Duration).Before(time.Now()) { + return pluginsCore.PhaseInfoUndefined, fmt.Errorf("kubeflow operator hasn't updated the tensorflow custom resource since creation time %v", app.CreationTimestamp) + } + currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 254d3efc1..5ec5658d8 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -283,7 +283,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso Status: commonOp.JobStatus{ Conditions: jobConditions, ReplicaStatuses: nil, - StartTime: nil, + StartTime: &v1.Time{Time: time.Now()}, CompletionTime: nil, LastReconcileTime: nil, },