From 27da395f1f577b96da7d84ed86f89ba0af79a1b4 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Thu, 11 Apr 2024 21:25:48 +0900 Subject: [PATCH] Implement webhook validations for the PaddleJob (#2057) Signed-off-by: Yuki Iwai --- manifests/base/webhook/manifests.yaml | 20 ++ manifests/base/webhook/patch.yaml | 3 + .../v1/paddlepaddle_validation.go | 76 -------- .../v1/paddlepaddle_validation_test.go | 166 ---------------- .../paddlepaddle/paddlepaddle_controller.go | 7 - .../paddlepaddle_controller_suite_test.go | 30 ++- .../paddlepaddle/paddlepaddle_webhook.go | 120 ++++++++++++ .../paddlepaddle/paddlepaddle_webhook_test.go | 181 ++++++++++++++++++ pkg/webhooks/webhooks.go | 3 +- 9 files changed, 352 insertions(+), 254 deletions(-) delete mode 100644 pkg/apis/kubeflow.org/v1/paddlepaddle_validation.go delete mode 100644 pkg/apis/kubeflow.org/v1/paddlepaddle_validation_test.go create mode 100644 pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go create mode 100644 pkg/webhooks/paddlepaddle/paddlepaddle_webhook_test.go diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml index 984c6e63d4..c8a69845b7 100644 --- a/manifests/base/webhook/manifests.yaml +++ b/manifests/base/webhook/manifests.yaml @@ -4,6 +4,26 @@ kind: ValidatingWebhookConfiguration metadata: name: validating-webhook-configuration webhooks: +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v1-paddlejob + failurePolicy: Fail + name: validator.paddlejob.training-operator.kubeflow.org + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v1 + operations: + - CREATE + - UPDATE + resources: + - paddlejobs + sideEffects: None - admissionReviewVersions: - v1 clientConfig: diff --git a/manifests/base/webhook/patch.yaml b/manifests/base/webhook/patch.yaml index cba6fbf0d2..a02b11bf1c 100644 --- a/manifests/base/webhook/patch.yaml +++ b/manifests/base/webhook/patch.yaml @@ -7,6 +7,9 @@ - op: replace path: /webhooks/2/clientConfig/service/name value: training-operator +- op: replace + path: /webhooks/3/clientConfig/service/name + value: training-operator - op: replace path: /metadata/name value: validator.training-operator.kubeflow.org diff --git a/pkg/apis/kubeflow.org/v1/paddlepaddle_validation.go b/pkg/apis/kubeflow.org/v1/paddlepaddle_validation.go deleted file mode 100644 index 995763b30b..0000000000 --- a/pkg/apis/kubeflow.org/v1/paddlepaddle_validation.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2022 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "fmt" - - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" -) - -func ValidateV1PaddleJob(paddleJob *PaddleJob) error { - if errors := apimachineryvalidation.NameIsDNS1035Label(paddleJob.ObjectMeta.Name, false); errors != nil { - return fmt.Errorf("PaddleJob name is invalid: %v", errors) - } - if err := validatePaddleReplicaSpecs(paddleJob.Spec.PaddleReplicaSpecs); err != nil { - return err - } - return nil -} - -func validatePaddleReplicaSpecs(specs map[ReplicaType]*ReplicaSpec) error { - if specs == nil { - return fmt.Errorf("PaddleJobSpec is not valid") - } - for rType, value := range specs { - if value == nil || len(value.Template.Spec.Containers) == 0 { - return fmt.Errorf("PaddleJobSpec is not valid: containers definition expected in %v", rType) - } - // Make sure the replica type is valid. - validReplicaTypes := []ReplicaType{PaddleJobReplicaTypeMaster, PaddleJobReplicaTypeWorker} - - isValidReplicaType := false - for _, t := range validReplicaTypes { - if t == rType { - isValidReplicaType = true - break - } - } - - if !isValidReplicaType { - return fmt.Errorf("PaddleReplicaType is %v but must be one of %v", rType, validReplicaTypes) - } - - //Make sure the image is defined in the container - defaultContainerPresent := false - for _, container := range value.Template.Spec.Containers { - if container.Image == "" { - msg := fmt.Sprintf("PaddleJobSpec is not valid: Image is undefined in the container of %v", rType) - return fmt.Errorf(msg) - } - if container.Name == PaddleJobDefaultContainerName { - defaultContainerPresent = true - } - } - //Make sure there has at least one container named "paddle" - if !defaultContainerPresent { - msg := fmt.Sprintf("PaddleJobSpec is not valid: There is no container named %s in %v", PaddleJobDefaultContainerName, rType) - return fmt.Errorf(msg) - } - - } - - return nil -} diff --git a/pkg/apis/kubeflow.org/v1/paddlepaddle_validation_test.go b/pkg/apis/kubeflow.org/v1/paddlepaddle_validation_test.go deleted file mode 100644 index cd52f78fa1..0000000000 --- a/pkg/apis/kubeflow.org/v1/paddlepaddle_validation_test.go +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2022 The Kubeflow Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package v1 - -import ( - "testing" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" -) - -func TestValidateV1PaddleJob(t *testing.T) { - validPaddleReplicaSpecs := map[ReplicaType]*ReplicaSpec{ - PaddleJobReplicaTypeWorker: { - Replicas: ptr.To[int32](2), - RestartPolicy: RestartPolicyNever, - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{{ - Name: "paddle", - Image: "registry.baidubce.com/paddlepaddle/paddle:2.4.0rc0-cpu", - Command: []string{"python"}, - Args: []string{ - "-m", - "paddle.distributed.launch", - "run_check", - }, - Ports: []corev1.ContainerPort{{ - Name: "master", - ContainerPort: int32(37777), - }}, - ImagePullPolicy: corev1.PullAlways, - }}, - }, - }, - }, - } - - testCases := map[string]struct { - paddleJob *PaddleJob - wantErr bool - }{ - "valid paddleJob": { - paddleJob: &PaddleJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PaddleJobSpec{ - PaddleReplicaSpecs: validPaddleReplicaSpecs, - }, - }, - wantErr: false, - }, - "paddleJob name does not meet DNS1035": { - paddleJob: &PaddleJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "__test", - }, - Spec: PaddleJobSpec{ - PaddleReplicaSpecs: validPaddleReplicaSpecs, - }, - }, - wantErr: true, - }, - "no containers": { - paddleJob: &PaddleJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PaddleJobSpec{ - PaddleReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PaddleJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{}, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "image is empty": { - paddleJob: &PaddleJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PaddleJobSpec{ - PaddleReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PaddleJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "paddle", - Image: "", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "paddle default container name doesn't find": { - paddleJob: &PaddleJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PaddleJobSpec{ - PaddleReplicaSpecs: map[ReplicaType]*ReplicaSpec{ - PaddleJobReplicaTypeWorker: { - Template: corev1.PodTemplateSpec{ - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "", - Image: "gcr.io/kubeflow-ci/paddle-dist-mnist_test:1.0", - }, - }, - }, - }, - }, - }, - }, - }, - wantErr: true, - }, - "replicaSpec is nil": { - paddleJob: &PaddleJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - }, - Spec: PaddleJobSpec{ - PaddleReplicaSpecs: nil, - }, - }, - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - got := ValidateV1PaddleJob(tc.paddleJob) - if (got != nil) != tc.wantErr { - t.Fatalf("ValidateV1PaddleJob() error = %v, wantErr %v", got, tc.wantErr) - } - }) - } -} diff --git a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go index f48927db40..7cae58e9c5 100644 --- a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go +++ b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go @@ -131,13 +131,6 @@ func (r *PaddleJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( return ctrl.Result{}, client.IgnoreNotFound(err) } - if err = kubeflowv1.ValidateV1PaddleJob(paddlejob); err != nil { - logger.Error(err, "PaddleJob failed validation") - r.Recorder.Eventf(paddlejob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.PaddleJobKind, commonutil.JobFailedValidationReason), - "PaddleJob failed validation because %s", err) - return ctrl.Result{}, err - } - // Check if reconciliation is needed jobKey, err := common.KeyFunc(paddlejob) if err != nil { diff --git a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go index d1f555f909..5d3505cb71 100644 --- a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go +++ b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go @@ -16,14 +16,18 @@ package paddle import ( "context" + "crypto/tls" + "fmt" + "net" "path/filepath" "testing" + "time" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + paddlewebhook "github.com/kubeflow/training-operator/pkg/webhooks/paddlepaddle" . "github.com/onsi/ginkgo/v2" - "github.com/onsi/gomega" . "github.com/onsi/gomega" "k8s.io/client-go/kubernetes/scheme" ctrl "sigs.k8s.io/controller-runtime" @@ -32,6 +36,7 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -61,6 +66,9 @@ var _ = BeforeSuite(func() { testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")}, ErrorIfCRDPathMissing: true, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook", "manifests.yaml")}, + }, } cfg, err := testEnv.Start() @@ -82,19 +90,33 @@ var _ = BeforeSuite(func() { Metrics: metricsserver.Options{ BindAddress: "0", }, + WebhookServer: webhook.NewServer( + webhook.Options{ + Host: testEnv.WebhookInstallOptions.LocalServingHost, + Port: testEnv.WebhookInstallOptions.LocalServingPort, + CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir, + }), }) - Expect(err).NotTo(gomega.HaveOccurred()) + Expect(err).NotTo(HaveOccurred()) gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() r := NewReconciler(mgr, gangSchedulingSetupFunc) - - Expect(r.SetupWithManager(mgr, 1)).NotTo(gomega.HaveOccurred()) + Expect(r.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) + Expect(paddlewebhook.SetupWebhook(mgr)).NotTo(HaveOccurred()) go func() { defer GinkgoRecover() err = mgr.Start(testCtx) Expect(err).ToNot(HaveOccurred(), "failed to run manager") }() + + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort) + Eventually(func(g Gomega) { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn.Close()).NotTo(HaveOccurred()) + }).Should(Succeed()) }) var _ = AfterSuite(func() { diff --git a/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go new file mode 100644 index 0000000000..dab25d419a --- /dev/null +++ b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go @@ -0,0 +1,120 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package paddlepaddle + +import ( + "context" + "fmt" + "slices" + "strings" + + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +var ( + specPath = field.NewPath("spec") + paddleReplicaSpecPath = specPath.Child("paddleReplicaSpecs") +) + +type Webhook struct{} + +func SetupWebhook(mgr ctrl.Manager) error { + return ctrl.NewWebhookManagedBy(mgr). + For(&trainingoperator.PaddleJob{}). + WithValidator(&Webhook{}). + Complete() +} + +// +kubebuilder:webhook:path=/validate-kubeflow-org-v1-paddlejob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=paddlejobs,verbs=create;update,versions=v1,name=validator.paddlejob.training-operator.kubeflow.org,admissionReviewVersions=v1 + +var _ webhook.CustomValidator = &Webhook{} + +func (w Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { + job := obj.(*trainingoperator.PaddleJob) + log := ctrl.LoggerFrom(ctx).WithName("paddlejob-webhook") + log.V(5).Info("Validating create", "paddleJob", klog.KObj(job)) + return nil, validatePaddleJob(job).ToAggregate() +} + +func (w Webhook) ValidateUpdate(ctx context.Context, _, newObj runtime.Object) (admission.Warnings, error) { + job := newObj.(*trainingoperator.PaddleJob) + log := ctrl.LoggerFrom(ctx).WithName("paddlejob-webhook") + log.V(5).Info("Validating update", "paddleJob", klog.KObj(job)) + return nil, validatePaddleJob(job).ToAggregate() +} + +func (w Webhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { + return nil, nil +} + +func validatePaddleJob(job *trainingoperator.PaddleJob) field.ErrorList { + var allErrs field.ErrorList + if errors := apimachineryvalidation.NameIsDNS1035Label(job.Name, false); len(errors) != 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), job.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) + } + allErrs = append(allErrs, validateSpec(job.Spec.PaddleReplicaSpecs)...) + return allErrs +} + +func validateSpec(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(paddleReplicaSpecPath, "must be required")) + } + for rType, rSpec := range rSpecs { + rolePath := paddleReplicaSpecPath.Key(string(rType)) + containersPath := rolePath.Child("template").Child("spec").Child("containers") + + // Make sure the replica type is valid. + validReplicaTypes := []trainingoperator.ReplicaType{ + trainingoperator.PaddleJobReplicaTypeMaster, + trainingoperator.PaddleJobReplicaTypeWorker, + } + if !slices.Contains(validReplicaTypes, rType) { + allErrs = append(allErrs, field.NotSupported(rolePath, rType, validReplicaTypes)) + } + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containersPath, "must be specified")) + } + + // Make sure the image is defined in the container + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == trainingoperator.PaddleJobDefaultContainerName { + defaultContainerPresent = true + } + } + // Make sure there has at least one container named "paddle" + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %q", trainingoperator.PaddleJobDefaultContainerName))) + } + } + return allErrs +} diff --git a/pkg/webhooks/paddlepaddle/paddlepaddle_webhook_test.go b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook_test.go new file mode 100644 index 0000000000..087aa5d87e --- /dev/null +++ b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook_test.go @@ -0,0 +1,181 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package paddlepaddle + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/utils/ptr" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +func TestValidateV1PaddleJob(t *testing.T) { + validPaddleReplicaSpecs := map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PaddleJobReplicaTypeWorker: { + Replicas: ptr.To[int32](2), + RestartPolicy: trainingoperator.RestartPolicyNever, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "paddle", + Image: "registry.baidubce.com/paddlepaddle/paddle:2.4.0rc0-cpu", + Command: []string{"python"}, + Args: []string{ + "-m", + "paddle.distributed.launch", + "run_check", + }, + Ports: []corev1.ContainerPort{{ + Name: "master", + ContainerPort: int32(37777), + }}, + ImagePullPolicy: corev1.PullAlways, + }}, + }, + }, + }, + } + testCases := map[string]struct { + paddleJob *trainingoperator.PaddleJob + wantErr field.ErrorList + }{ + "valid paddleJob": { + paddleJob: &trainingoperator.PaddleJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PaddleJobSpec{ + PaddleReplicaSpecs: validPaddleReplicaSpecs, + }, + }, + }, + "paddleJob name does not meet DNS1035": { + paddleJob: &trainingoperator.PaddleJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "__test", + }, + Spec: trainingoperator.PaddleJobSpec{ + PaddleReplicaSpecs: validPaddleReplicaSpecs, + }, + }, + wantErr: field.ErrorList{ + field.Invalid(field.NewPath("metadata").Child("name"), "", ""), + }, + }, + "no containers": { + paddleJob: &trainingoperator.PaddleJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PaddleJobSpec{ + PaddleReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PaddleJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(paddleReplicaSpecPath.Key(string(trainingoperator.PaddleJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers"), ""), + field.Required(paddleReplicaSpecPath.Key(string(trainingoperator.PaddleJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers"), ""), + }, + }, + "image is empty": { + paddleJob: &trainingoperator.PaddleJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PaddleJobSpec{ + PaddleReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PaddleJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "paddle", + Image: "", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(paddleReplicaSpecPath.Key(string(trainingoperator.PaddleJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers").Index(0).Child("image"), ""), + }, + }, + "paddle default container name doesn't find": { + paddleJob: &trainingoperator.PaddleJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PaddleJobSpec{ + PaddleReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{ + trainingoperator.PaddleJobReplicaTypeWorker: { + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "", + Image: "gcr.io/kubeflow-ci/paddle-dist-mnist_test:1.0", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: field.ErrorList{ + field.Required(paddleReplicaSpecPath.Key(string(trainingoperator.PaddleJobReplicaTypeWorker)).Child("template").Child("spec").Child("containers"), ""), + }, + }, + "replicaSpec is nil": { + paddleJob: &trainingoperator.PaddleJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: trainingoperator.PaddleJobSpec{ + PaddleReplicaSpecs: nil, + }, + }, + wantErr: field.ErrorList{ + field.Required(paddleReplicaSpecPath, ""), + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := validatePaddleJob(tc.paddleJob) + if diff := cmp.Diff(tc.wantErr, got, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/webhooks/webhooks.go b/pkg/webhooks/webhooks.go index 040e7e56df..8bfb23c68f 100644 --- a/pkg/webhooks/webhooks.go +++ b/pkg/webhooks/webhooks.go @@ -20,6 +20,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/webhooks/paddlepaddle" "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" "github.com/kubeflow/training-operator/pkg/webhooks/tensorflow" "github.com/kubeflow/training-operator/pkg/webhooks/xgboost" @@ -34,7 +35,7 @@ var ( trainingoperator.MXJobKind: scaffold, trainingoperator.XGBoostJobKind: xgboost.SetupWebhook, trainingoperator.MPIJobKind: scaffold, - trainingoperator.PaddleJobKind: scaffold, + trainingoperator.PaddleJobKind: paddlepaddle.SetupWebhook, } )