Skip to content

Commit

Permalink
Adding v2 trainjob validation webhook
Browse files Browse the repository at this point in the history
fixing runtime

Signed-off-by: Akshay Chitneni <[email protected]>
  • Loading branch information
Akshay Chitneni committed Nov 4, 2024
1 parent 9e46f9d commit a3ea261
Show file tree
Hide file tree
Showing 19 changed files with 454 additions and 81 deletions.
6 changes: 6 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ const (
// JobInitializer is the Job name for the initializer.
JobInitializer string = "initializer"

// JobExporter is the Job name for the exporter.
JobExporter string = "exporter"

// ContainerModelInitializer is the container name for the model initializer.
ContainerModelInitializer string = "model-initializer"

// ContainerModelExporter is the container name for the model exporter.
ContainerModelExporter string = "model-exporter"

// ContainerDatasetInitializer is the container name for the dataset initializer.
ContainerDatasetInitializer string = "dataset-initializer"

Expand Down
16 changes: 3 additions & 13 deletions pkg/controller.v2/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ package controllerv2

import (
"context"
"errors"
"fmt"
runtimeUtils "github.com/kubeflow/training-operator/pkg/util.v2/runtime"

"github.com/go-logr/logr"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
Expand All @@ -34,8 +33,6 @@ import (
jobruntimes "github.com/kubeflow/training-operator/pkg/runtime.v2"
)

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type TrainJobReconciler struct {
log logr.Logger
client client.Client
Expand Down Expand Up @@ -73,10 +70,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
log := ctrl.LoggerFrom(ctx)

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtimeRefGK := runtimeUtils.RuntimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
return fmt.Errorf("%w: %s", runtimeUtils.ErrorUnsupportedRuntime, runtimeRefGK)
}
objs, err := runtime.NewObjects(ctx, trainJob)
if err != nil {
Expand Down Expand Up @@ -117,13 +114,6 @@ func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *k
return nil
}

func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
}
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
b := ctrl.NewControllerManagedBy(mgr).
For(&kubeflowv2.TrainJob{})
Expand Down
16 changes: 11 additions & 5 deletions pkg/runtime.v2/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"errors"
"fmt"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
Expand Down Expand Up @@ -64,14 +65,19 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
}

func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
clusterTrainingRuntime := &kubeflowv2.ClusterTrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.ClusterTrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, clusterTrainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.getRuntimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy, clusterTrainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: clusterTrainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
49 changes: 31 additions & 18 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *kubeflowv2.T
func (r *TrainingRuntime) buildObjects(
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy,
) ([]client.Object, error) {

info := r.getRuntimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)
if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
}

func (r *TrainingRuntime) getRuntimeInfo(
ctx context.Context, trainJob *kubeflowv2.TrainJob, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec, mlPolicy *kubeflowv2.MLPolicy, podGroupPolicy *kubeflowv2.PodGroupPolicy) *runtime.Info {

propagationLabels := jobSetTemplateSpec.Labels
if propagationLabels == nil && trainJob.Spec.Labels != nil {
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
Expand Down Expand Up @@ -112,19 +132,7 @@ func (r *TrainingRuntime) buildObjects(

info := runtime.NewInfo(opts...)

if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
return info
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
Expand All @@ -136,14 +144,19 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
}

func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
trainingRuntime := &kubeflowv2.TrainingRuntime{}
if err := r.client.Get(ctx, client.ObjectKey{
Namespace: old.Namespace,
Name: old.Spec.RuntimeRef.Name,
}, &kubeflowv2.TrainingRuntime{}); err != nil {
Namespace: new.Namespace,
Name: new.Spec.RuntimeRef.Name,
}, trainingRuntime); err != nil {
return nil, field.ErrorList{
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef,
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
}
}
return r.framework.RunCustomValidationPlugins(old, new)
info := r.getRuntimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy)
jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: trainingRuntime.Spec.Template.Spec,
}
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
}
4 changes: 2 additions & 2 deletions pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(info *runtime.Info, trainJob
return nil
}

func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
func (f *Framework) RunCustomValidationPlugins(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
var aggregatedWarnings admission.Warnings
var aggregatedErrors field.ErrorList
for _, plugin := range f.customValidationPlugins {
warnings, errs := plugin.Validate(oldObj, newObj)
warnings, errs := plugin.Validate(runtimeJobTemplate, info, oldObj, newObj)
if len(warnings) != 0 {
aggregatedWarnings = append(aggregatedWarnings, warnings...)
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/runtime.v2/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func TestNew(t *testing.T) {
customValidationPlugins: []framework.CustomValidationPlugin{
&mpi.MPI{},
&torch.Torch{},
&jobset.JobSet{},
},
watchExtensionPlugins: []framework.WatchExtensionPlugin{
&coscheduling.CoScheduling{},
Expand Down Expand Up @@ -364,7 +365,9 @@ func TestRunCustomValidationPlugins(t *testing.T) {
if err != nil {
t.Fatal(err)
}
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
runtimeInfo := runtime.NewInfo()
jobSetTemplate := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test")
warnings, errs := fwk.RunCustomValidationPlugins(jobSetTemplate, runtimeInfo, tc.oldObj, tc.newObj)
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime.v2/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface {

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

type ComponentBuilderPlugin interface {
Expand Down
105 changes: 105 additions & 0 deletions pkg/runtime.v2/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ package jobset
import (
"context"
"fmt"
util_v2 "github.com/kubeflow/training-operator/pkg/util.v2"
"k8s.io/apimachinery/pkg/util/validation/field"
"maps"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

"github.com/go-logr/logr"
"k8s.io/apimachinery/pkg/api/equality"
Expand Down Expand Up @@ -50,6 +53,7 @@ type JobSet struct {

var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
var _ framework.CustomValidationPlugin = (*JobSet)(nil)

const Name = constants.JobSetKind

Expand Down Expand Up @@ -140,3 +144,104 @@ func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {
},
}
}

func (j *JobSet) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {

var allErrs field.ErrorList
specPath := field.NewPath("spec")
runtimeRefPath := specPath.Child("runtimeRef")

jobSet, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet)
if !ok {
return nil, nil
}

if newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Input != nil {
if !util_v2.JobExists(jobSet.Spec.ReplicatedJobs, constants.JobInitializer) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, "trainingRuntime should have initializer job when trainJob is configured with input modelConfig"))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !util_v2.ContainerExists(job.Template.Spec.Template.Spec.InitContainers, constants.ContainerModelInitializer) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, "trainingRuntime should have container with name - model-initializer in the initializer job"))
}
}
}
}
}

if newObj.Spec.DatasetConfig != nil {
if !util_v2.JobExists(jobSet.Spec.ReplicatedJobs, constants.JobInitializer) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, "trainingRuntime should have initializer job when trainJob is configured with datasetConfig"))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobInitializer {
if !util_v2.ContainerExists(job.Template.Spec.Template.Spec.InitContainers, constants.ContainerDatasetInitializer) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, "trainingRuntime should have container with name - dataset-initializer in the initializer job"))
}
}
}
}
}

if newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Output != nil {
if !util_v2.JobExists(jobSet.Spec.ReplicatedJobs, constants.JobExporter) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, "trainingRuntime should have exporter job when trainJob is configured with output modelConfig"))
} else {
for _, job := range jobSet.Spec.ReplicatedJobs {
if job.Name == constants.JobExporter {
if !util_v2.ContainerExists(job.Template.Spec.Template.Spec.InitContainers, constants.ContainerModelExporter) {
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, "trainingRuntime should have container with name - model-exporter in the exporter job"))
}
}
}
}
}

if len(newObj.Spec.PodSpecOverrides) != 0 {
podSpecOverridesPath := specPath.Child("podSpecOverrides")
jobsMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
jobsMap[job.Name] = true
}
// validate if jobOverrides are valid
for idx, override := range newObj.Spec.PodSpecOverrides {
for _, job := range override.TargetJobs {
if _, found := jobsMap[job.Name]; !found {
allErrs = append(allErrs, field.Invalid(podSpecOverridesPath, newObj.Spec.PodSpecOverrides, fmt.Sprintf("job: %s, configured in the podOverride should be present in the referenced training runtime", job)))
}
}
if len(override.Containers) != 0 {
// validate if containerOverrides are valid
containerMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
for _, container := range job.Template.Spec.Template.Spec.Containers {
containerMap[container.Name] = true
}
}
containerOverridePath := podSpecOverridesPath.Index(idx)
for _, container := range override.Containers {
if _, found := containerMap[container.Name]; !found {
allErrs = append(allErrs, field.Invalid(containerOverridePath, override.Containers, fmt.Sprintf("container: %s, configured in the containerOverride should be present in the referenced training runtime", container.Name)))
}
}
}
if len(override.InitContainers) != 0 {
// validate if initContainerOverrides are valid
initContainerMap := map[string]bool{}
for _, job := range jobSet.Spec.ReplicatedJobs {
for _, initContainer := range job.Template.Spec.Template.Spec.InitContainers {
initContainerMap[initContainer.Name] = true
}
}
initContainerOverridePath := podSpecOverridesPath.Index(idx)
for _, container := range override.Containers {
if _, found := initContainerMap[container.Name]; !found {
allErrs = append(allErrs, field.Invalid(initContainerOverridePath, override.InitContainers, fmt.Sprintf("initContainer: %s, configured in the initContainerOverride should be present in the referenced training runtime", container.Name)))
}
}
}
}
}
return nil, allErrs
}
16 changes: 13 additions & 3 deletions pkg/runtime.v2/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mpi

import (
"context"
"strconv"

"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -55,7 +56,16 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJob)
return nil
}

// TODO: Need to implement validations for MPIJob.
func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
specPath := field.NewPath("spec")
if newJobObj.Spec.Trainer != nil {
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
if runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil {
if _, err := strconv.Atoi(*newJobObj.Spec.Trainer.NumProcPerNode); err != nil {
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value"))
}
}
}
return nil, allErrs
}
Loading

0 comments on commit a3ea261

Please sign in to comment.