diff --git a/config/samples/kubeflow.org_v1_pytorchjob.yaml b/config/samples/kubeflow.org_v1_pytorchjob.yaml
deleted file mode 100644
index d08d097256..0000000000
--- a/config/samples/kubeflow.org_v1_pytorchjob.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-apiVersion: kubeflow.org/v1
-kind: PyTorchJob
-metadata:
-  name: pytorchjob-sample
-spec:
-  # Add fields here
-  foo: bar
diff --git a/config/samples/kubeflow.org_v1_xgboostjob.yaml b/config/samples/kubeflow.org_v1_xgboostjob.yaml
deleted file mode 100644
index d61bfcb27f..0000000000
--- a/config/samples/kubeflow.org_v1_xgboostjob.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-apiVersion: kubeflow.org/v1
-kind: XGBoostJob
-metadata:
-  name: xgboostjob-sample
-spec:
-  # Add fields here
-  foo: bar
diff --git a/examples/xgboost/xgboostjob.yaml b/examples/xgboost/xgboostjob.yaml
new file mode 100644
index 0000000000..b7a39a997a
--- /dev/null
+++ b/examples/xgboost/xgboostjob.yaml
@@ -0,0 +1,42 @@
+apiVersion: kubeflow.org/v1
+kind: XGBoostJob
+metadata:
+  name: xgboost-dist-iris-test-train
+spec:
+  xgbReplicaSpecs:
+    Master:
+      replicas: 1
+      restartPolicy: Never
+      template:
+        spec:
+          containers:
+            - name: xgboostjob
+              image: docker.io/merlintang/xgboost-dist-iris:1.1
+              ports:
+                - containerPort: 9991
+                  name: xgboostjob-port
+              imagePullPolicy: Always
+              args:
+                - --job_type=Train
+                - --xgboost_parameter=objective:multi:softprob,num_class:3
+                - --n_estimators=10
+                - --learning_rate=0.1
+                - --model_path=/tmp/xgboost-model
+                - --model_storage_type=local
+    Worker:
+      replicas: 2
+      restartPolicy: ExitCode
+      template:
+        spec:
+          containers:
+            - name: xgboostjob
+              image: docker.io/merlintang/xgboost-dist-iris:1.1
+              ports:
+                - containerPort: 9991
+                  name: xgboostjob-port
+              imagePullPolicy: Always
+              args:
+                - --job_type=Train
+                - --xgboost_parameter="objective:multi:softprob,num_class:3"
+                - --n_estimators=10
+                - --learning_rate=0.1
diff --git a/main.go b/main.go
index d8e135abc7..bb74bcf70a 100644
--- a/main.go
+++ b/main.go
@@ -34,6 +34,7 @@ import (
 	pytorchv1 "github.com/kubeflow/tf-operator/pkg/apis/pytorch/v1"
 	xgboostv1 "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
 	pytorchcontroller "github.com/kubeflow/tf-operator/pkg/controller.v1/pytorch"
+	xgboostcontroller "github.com/kubeflow/tf-operator/pkg/controller.v1/xgboost"
 	//+kubebuilder:scaffold:imports
 )
 
@@ -80,10 +81,16 @@ func main() {
 		os.Exit(1)
 	}
 
+	// TODO: We need a general manager to which all remaining reconcilers are added.
+	// Based on the user configuration, we start different controllers.
 	if err = pytorchcontroller.NewReconciler(mgr).SetupWithManager(mgr); err != nil {
 		setupLog.Error(err, "unable to create controller", "controller", "PyTorchJob")
 		os.Exit(1)
 	}
+	if err = xgboostcontroller.NewReconciler(mgr).SetupWithManager(mgr); err != nil {
+		setupLog.Error(err, "unable to create controller", "controller", "XGBoostJob")
+		os.Exit(1)
+	}
 	//+kubebuilder:scaffold:builder
 
 	if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil {
diff --git a/pkg/controller.v1/xgboost/expectation.go b/pkg/controller.v1/xgboost/expectation.go
new file mode 100644
index 0000000000..27ec7eff4b
--- /dev/null
+++ b/pkg/controller.v1/xgboost/expectation.go
@@ -0,0 +1,110 @@
+/*
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package xgboost
+
+import (
+	"fmt"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	"github.com/kubeflow/common/pkg/controller.v1/common"
+	"github.com/kubeflow/common/pkg/controller.v1/expectation"
+	v1 "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
+	"github.com/sirupsen/logrus"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+	"sigs.k8s.io/controller-runtime/pkg/event"
+	"sigs.k8s.io/controller-runtime/pkg/reconcile"
+)
+
+// satisfiedExpectations returns true if the required adds/dels for the given job have been observed.
+// Add/del counts are established by the controller at sync time, and updated as controllees are
+// observed by the controller manager.
+func (r *XGBoostJobReconciler) satisfiedExpectations(xgbJob *v1.XGBoostJob) bool {
+	satisfied := false
+	key, err := common.KeyFunc(xgbJob)
+	if err != nil {
+		utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", xgbJob, err))
+		return false
+	}
+	for rtype := range xgbJob.Spec.XGBReplicaSpecs {
+		// Check the expectations of the pods.
+		expectationPodsKey := expectation.GenExpectationPodsKey(key, string(rtype))
+		satisfied = satisfied || r.Expectations.SatisfiedExpectations(expectationPodsKey)
+		// Check the expectations of the services.
+		expectationServicesKey := expectation.GenExpectationServicesKey(key, string(rtype))
+		satisfied = satisfied || r.Expectations.SatisfiedExpectations(expectationServicesKey)
+	}
+	return satisfied
+}
+
+// onDependentCreateFunc modifies the expectations when a dependent (pod/service) creation is observed.
+func onDependentCreateFunc(r reconcile.Reconciler) func(event.CreateEvent) bool {
+	return func(e event.CreateEvent) bool {
+		xgbr, ok := r.(*XGBoostJobReconciler)
+		if !ok {
+			return true
+		}
+		rtype := e.Object.GetLabels()[commonv1.ReplicaTypeLabel]
+		if len(rtype) == 0 {
+			return false
+		}
+
+		logrus.Info("Update on create function ", xgbr.ControllerName(), " create object ", e.Object.GetName())
+		if controllerRef := metav1.GetControllerOf(e.Object); controllerRef != nil {
+			var expectKey string
+			if _, ok := e.Object.(*corev1.Pod); ok {
+				expectKey = expectation.GenExpectationPodsKey(e.Object.GetNamespace()+"/"+controllerRef.Name, rtype)
+			}
+
+			if _, ok := e.Object.(*corev1.Service); ok {
+				expectKey = expectation.GenExpectationServicesKey(e.Object.GetNamespace()+"/"+controllerRef.Name, rtype)
+			}
+			xgbr.Expectations.CreationObserved(expectKey)
+			return true
+		}
+
+		return true
+	}
+}
+
+// onDependentDeleteFunc modifies the expectations when a dependent (pod/service) deletion is observed.
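+// Like onDependentCreateFunc, the returned function has the shape of a controller-runtime
+// event predicate: returning false filters the event out before it is enqueued.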
+func onDependentDeleteFunc(r reconcile.Reconciler) func(event.DeleteEvent) bool {
+	return func(e event.DeleteEvent) bool {
+		xgbr, ok := r.(*XGBoostJobReconciler)
+		if !ok {
+			return true
+		}
+
+		rtype := e.Object.GetLabels()[commonv1.ReplicaTypeLabel]
+		if len(rtype) == 0 {
+			return false
+		}
+
+		logrus.Info("Update on delete function ", xgbr.ControllerName(), " delete object ", e.Object.GetName())
+		if controllerRef := metav1.GetControllerOf(e.Object); controllerRef != nil {
+			var expectKey string
+			if _, ok := e.Object.(*corev1.Pod); ok {
+				expectKey = expectation.GenExpectationPodsKey(e.Object.GetNamespace()+"/"+controllerRef.Name, rtype)
+			}
+
+			if _, ok := e.Object.(*corev1.Service); ok {
+				expectKey = expectation.GenExpectationServicesKey(e.Object.GetNamespace()+"/"+controllerRef.Name, rtype)
+			}
+
+			xgbr.Expectations.DeletionObserved(expectKey)
+			return true
+		}
+
+		return true
+	}
+}
diff --git a/pkg/controller.v1/xgboost/job.go b/pkg/controller.v1/xgboost/job.go
new file mode 100644
index 0000000000..18044f57bf
--- /dev/null
+++ b/pkg/controller.v1/xgboost/job.go
@@ -0,0 +1,225 @@
+/*
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package xgboost
+
+import (
+	"context"
+	"fmt"
+	"reflect"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	commonutil "github.com/kubeflow/common/pkg/util"
+	logger "github.com/kubeflow/common/pkg/util"
+	xgboostv1 "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
+	"github.com/sirupsen/logrus"
+	corev1 "k8s.io/api/core/v1"
+	k8sv1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/errors"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/kubernetes/scheme"
+	"sigs.k8s.io/controller-runtime/pkg/event"
+	"sigs.k8s.io/controller-runtime/pkg/reconcile"
+)
+
+// Reasons for job events.
+const (
+	FailedDeleteJobReason     = "FailedDeleteJob"
+	SuccessfulDeleteJobReason = "SuccessfulDeleteJob"
+	// xgboostJobCreatedReason is added in a job when it is created.
+	xgboostJobCreatedReason = "XGBoostJobCreated"
+
+	xgboostJobSucceededReason  = "XGBoostJobSucceeded"
+	xgboostJobRunningReason    = "XGBoostJobRunning"
+	xgboostJobFailedReason     = "XGBoostJobFailed"
+	xgboostJobRestartingReason = "XGBoostJobRestarting"
+)
+
+// DeleteJob deletes the job.
+func (r *XGBoostJobReconciler) DeleteJob(job interface{}) error {
+	xgboostjob, ok := job.(*xgboostv1.XGBoostJob)
+	if !ok {
+		return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob)
+	}
+	if err := r.Delete(context.Background(), xgboostjob); err != nil {
+		r.recorder.Eventf(xgboostjob, corev1.EventTypeWarning, FailedDeleteJobReason, "Error deleting: %v", err)
+		r.Log.Error(err, "failed to delete job", "namespace", xgboostjob.Namespace, "name", xgboostjob.Name)
+		return err
+	}
+	r.recorder.Eventf(xgboostjob, corev1.EventTypeNormal, SuccessfulDeleteJobReason, "Deleted job: %v", xgboostjob.Name)
+	r.Log.Info("job deleted", "namespace", xgboostjob.Namespace, "name", xgboostjob.Name)
+	return nil
+}
+
+// GetJobFromInformerCache returns the Job from the informer cache.
+func (r *XGBoostJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
+	job := &xgboostv1.XGBoostJob{}
+	// The default reader for XGBoostJob is the cache reader.
+	err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
+	if err != nil {
+		if errors.IsNotFound(err) {
+			r.Log.Error(err, "xgboost job not found", "namespace", namespace, "name", name)
+		} else {
+			r.Log.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name)
+		}
+		return nil, err
+	}
+	return job, nil
+}
+
+// GetJobFromAPIClient returns the Job from the API server.
+func (r *XGBoostJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) {
+	job := &xgboostv1.XGBoostJob{}
+
+	// TODO (Jeffwan@): consider reading from the API server directly.
+	err := r.Client.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
+	if err != nil {
+		if errors.IsNotFound(err) {
+			r.Log.Error(err, "xgboost job not found", "namespace", namespace, "name", name)
+		} else {
+			r.Log.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name)
+		}
+		return nil, err
+	}
+	return job, nil
+}
+
+// UpdateJobStatus updates the job status and job conditions.
+func (r *XGBoostJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, jobStatus *commonv1.JobStatus) error {
+	xgboostJob, ok := job.(*xgboostv1.XGBoostJob)
+	if !ok {
+		return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostJob)
+	}
+
+	for rtype, spec := range replicas {
+		status := jobStatus.ReplicaStatuses[rtype]
+
+		succeeded := status.Succeeded
+		expected := *(spec.Replicas) - succeeded
+		running := status.Active
+		failed := status.Failed
+
+		logrus.Infof("XGBoostJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d, failed=%d",
+			xgboostJob.Name, rtype, expected, running, succeeded, failed)
+
+		if rtype == commonv1.ReplicaType(xgboostv1.XGBoostReplicaTypeMaster) {
+			if running > 0 {
+				msg := fmt.Sprintf("XGBoostJob %s is running.", xgboostJob.Name)
+				err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, xgboostJobRunningReason, msg)
+				if err != nil {
+					logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
+					return err
+				}
+			}
+			// When the master succeeds, the job is finished.
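+			// Here expected = *spec.Replicas - succeeded, so expected == 0 means every
+			// master replica has already reached the Succeeded phase.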
+			if expected == 0 {
+				msg := fmt.Sprintf("XGBoostJob %s is successfully completed.", xgboostJob.Name)
+				logrus.Info(msg)
+				r.Recorder.Event(xgboostJob, k8sv1.EventTypeNormal, xgboostJobSucceededReason, msg)
+				if jobStatus.CompletionTime == nil {
+					now := metav1.Now()
+					jobStatus.CompletionTime = &now
+				}
+				err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobSucceeded, xgboostJobSucceededReason, msg)
+				if err != nil {
+					logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
+					return err
+				}
+				return nil
+			}
+		}
+		if failed > 0 {
+			if spec.RestartPolicy == commonv1.RestartPolicyExitCode {
+				msg := fmt.Sprintf("XGBoostJob %s is restarting because %d %s replica(s) failed.", xgboostJob.Name, failed, rtype)
+				r.Recorder.Event(xgboostJob, k8sv1.EventTypeWarning, xgboostJobRestartingReason, msg)
+				err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRestarting, xgboostJobRestartingReason, msg)
+				if err != nil {
+					logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
+					return err
+				}
+			} else {
+				msg := fmt.Sprintf("XGBoostJob %s has failed because %d %s replica(s) failed.", xgboostJob.Name, failed, rtype)
+				r.Recorder.Event(xgboostJob, k8sv1.EventTypeNormal, xgboostJobFailedReason, msg)
+				if xgboostJob.Status.CompletionTime == nil {
+					now := metav1.Now()
+					xgboostJob.Status.CompletionTime = &now
+				}
+				err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobFailed, xgboostJobFailedReason, msg)
+				if err != nil {
+					logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
+					return err
+				}
+			}
+		}
+	}
+
+	// Some workers are still running; leave a running condition.
+	msg := fmt.Sprintf("XGBoostJob %s is running.", xgboostJob.Name)
+	logger.LoggerForJob(xgboostJob).Infof(msg)
+
+	if err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, xgboostJobRunningReason, msg); err != nil {
+		logger.LoggerForJob(xgboostJob).Error(err, "failed to update XGBoost Job conditions")
+		return err
+	}
+
+	return nil
+}
+
+// UpdateJobStatusInApiServer updates the job status in the API server.
+func (r *XGBoostJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus *commonv1.JobStatus) error {
+	xgboostjob, ok := job.(*xgboostv1.XGBoostJob)
+	if !ok {
+		return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob)
+	}
+
+	// The job status passed in may differ from the status in the job; update on the basis of the passed-in one.
+	if !reflect.DeepEqual(&xgboostjob.Status.JobStatus, jobStatus) {
+		xgboostjob = xgboostjob.DeepCopy()
+		xgboostjob.Status.JobStatus = *jobStatus.DeepCopy()
+	}
+
+	if err := r.Update(context.Background(), xgboostjob); err != nil {
+		logger.LoggerForJob(xgboostjob).Error(err, "failed to update XGBoost Job conditions in the API server")
+		return err
+	}
+
+	return nil
+}
+
+// onOwnerCreateFunc modifies the creation condition.
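+// The returned function applies scheme defaults to the job and appends the JobCreated
+// condition when a create event for an XGBoostJob is observed.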
+func onOwnerCreateFunc(r reconcile.Reconciler) func(event.CreateEvent) bool {
+	return func(e event.CreateEvent) bool {
+		xgboostJob, ok := e.Object.(*xgboostv1.XGBoostJob)
+		if !ok {
+			return true
+		}
+		scheme.Scheme.Default(xgboostJob)
+		msg := fmt.Sprintf("xgboostJob %s is created.", e.Object.GetName())
+		logrus.Info(msg)
+
+		// Apply the default run policy.
+		if xgboostJob.Spec.RunPolicy.CleanPodPolicy == nil {
+			xgboostJob.Spec.RunPolicy.CleanPodPolicy = &defaultCleanPodPolicy
+		}
+
+		if err := commonutil.UpdateJobConditions(&xgboostJob.Status.JobStatus, commonv1.JobCreated, xgboostJobCreatedReason, msg); err != nil {
+			log.Log.Error(err, "append job condition error")
+			return false
+		}
+		return true
+	}
+}
diff --git a/pkg/controller.v1/xgboost/pod.go b/pkg/controller.v1/xgboost/pod.go
new file mode 100644
index 0000000000..254cddca36
--- /dev/null
+++ b/pkg/controller.v1/xgboost/pod.go
@@ -0,0 +1,142 @@
+/*
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package xgboost
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+	"strings"
+
+	"k8s.io/apimachinery/pkg/api/meta"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	xgboostv1 "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
+	corev1 "k8s.io/api/core/v1"
+)
+
+// GetPodsForJob returns the pods managed by the job. This can be achieved by selecting pods using label key
+// "job-name", i.e. all pods created by the job will come with label "job-name" = <job-name>.
+func (r *XGBoostJobReconciler) GetPodsForJob(obj interface{}) ([]*corev1.Pod, error) {
+	job, err := meta.Accessor(obj)
+	if err != nil {
+		return nil, err
+	}
+	// List all pods to include those that don't match the selector anymore
+	// but have a ControllerRef pointing to this controller.
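+	// Note: adoption/orphaning is not handled here yet (see the matching TODO in
+	// service.go); the result is scoped purely by the GenLabels selector.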
+	podlist := &corev1.PodList{}
+	err = r.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName())))
+	if err != nil {
+		return nil, err
+	}
+
+	return convertPodList(podlist.Items), nil
+}
+
+// convertPodList converts a pod list to a list of pod pointers.
+func convertPodList(list []corev1.Pod) []*corev1.Pod {
+	if list == nil {
+		return nil
+	}
+	ret := make([]*corev1.Pod, 0, len(list))
+	for i := range list {
+		ret = append(ret, &list[i])
+	}
+	return ret
+}
+
+// SetPodEnv sets the pod env for:
+// - the XGBoost Rabit Tracker and workers
+// - the LightGBM master and workers
+func SetPodEnv(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
+	xgboostjob, ok := job.(*xgboostv1.XGBoostJob)
+	if !ok {
+		return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob)
+	}
+
+	rank, err := strconv.Atoi(index)
+	if err != nil {
+		return err
+	}
+
+	// Add the master offset for worker pods.
+	if strings.EqualFold(rtype, string(xgboostv1.XGBoostReplicaTypeWorker)) {
+		masterSpec := xgboostjob.Spec.XGBReplicaSpecs[commonv1.ReplicaType(xgboostv1.XGBoostReplicaTypeMaster)]
+		masterReplicas := int(*masterSpec.Replicas)
+		rank += masterReplicas
+	}
+
+	masterAddr := computeMasterAddr(xgboostjob.Name, strings.ToLower(string(xgboostv1.XGBoostReplicaTypeMaster)), strconv.Itoa(0))
+
+	masterPort, err := GetPortFromXGBoostJob(xgboostjob, xgboostv1.XGBoostReplicaTypeMaster)
+	if err != nil {
+		return err
+	}
+
+	totalReplicas := computeTotalReplicas(xgboostjob)
+
+	var workerPort int32
+	var workerAddrs []string
+
+	if totalReplicas > 1 {
+		workerPortTemp, err := GetPortFromXGBoostJob(xgboostjob, xgboostv1.XGBoostReplicaTypeWorker)
+		if err != nil {
+			return err
+		}
+		workerPort = workerPortTemp
+		workerAddrs = make([]string, totalReplicas-1)
+		for i := range workerAddrs {
+			workerAddrs[i] = computeMasterAddr(xgboostjob.Name, strings.ToLower(string(xgboostv1.XGBoostReplicaTypeWorker)), strconv.Itoa(i))
+		}
+	}
+
+	for i := range podTemplate.Spec.Containers {
+		if len(podTemplate.Spec.Containers[i].Env) == 0 {
+			podTemplate.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
+		}
+		podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
+			Name:  "MASTER_PORT",
+			Value: strconv.Itoa(int(masterPort)),
+		})
+		podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
+			Name:  "MASTER_ADDR",
+			Value: masterAddr,
+		})
+		podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
+			Name:  "WORLD_SIZE",
+			Value: strconv.Itoa(int(totalReplicas)),
+		})
+		podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
+			Name:  "RANK",
+			Value: strconv.Itoa(rank),
+		})
+		podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
+			Name:  "PYTHONUNBUFFERED",
+			Value: "0",
+		})
+		// These variables are only used for LightGBM jobs.
+		if totalReplicas > 1 {
+			podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
+				Name:  "WORKER_PORT",
+				Value: strconv.Itoa(int(workerPort)),
+			})
+			podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
+				Name:  "WORKER_ADDRS",
+				Value: strings.Join(workerAddrs, ","),
+			})
+		}
+	}
+
+	return nil
+}
diff --git a/pkg/controller.v1/xgboost/service.go b/pkg/controller.v1/xgboost/service.go
new file mode 100644
index 0000000000..0edd964510
--- /dev/null
+++ b/pkg/controller.v1/xgboost/service.go
@@ -0,0 +1,54 @@
+/*
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package xgboost
+
+import (
+	"context"
+	"fmt"
+
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/meta"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+// GetServicesForJob returns the services managed by the job. This can be achieved by selecting services using
+// label key "job-name", i.e. all services created by the job will come with label "job-name" = <job-name>.
+func (r *XGBoostJobReconciler) GetServicesForJob(obj interface{}) ([]*corev1.Service, error) {
+	job, err := meta.Accessor(obj)
+	if err != nil {
+		return nil, fmt.Errorf("%+v is not a type of XGBoostJob", job)
+	}
+	// List all services to include those that don't match the selector anymore
+	// but have a ControllerRef pointing to this controller.
+	serviceList := &corev1.ServiceList{}
+	err = r.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(job.GetName())))
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO: support adopting/orphaning
+	ret := convertServiceList(serviceList.Items)
+
+	return ret, nil
+}
+
+// convertServiceList converts a service list to a list of service pointers.
+func convertServiceList(list []corev1.Service) []*corev1.Service {
+	if list == nil {
+		return nil
+	}
+	ret := make([]*corev1.Service, 0, len(list))
+	for i := range list {
+		ret = append(ret, &list[i])
+	}
+	return ret
+}
diff --git a/pkg/controller.v1/xgboost/suite_test.go b/pkg/controller.v1/xgboost/suite_test.go
index 4a7734a170..a8545e66e1 100644
--- a/pkg/controller.v1/xgboost/suite_test.go
+++ b/pkg/controller.v1/xgboost/suite_test.go
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package controllers
+package xgboost
 
 import (
 	"path/filepath"
diff --git a/pkg/controller.v1/xgboost/util.go b/pkg/controller.v1/xgboost/util.go
new file mode 100644
index 0000000000..f7fd3c3661
--- /dev/null
+++ b/pkg/controller.v1/xgboost/util.go
@@ -0,0 +1,168 @@
+/*
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package xgboost
+
+import (
+	"fmt"
+	"os"
+	"strings"
+	"time"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	xgboostv1 "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	kubeclientset "k8s.io/client-go/kubernetes"
+	restclientset "k8s.io/client-go/rest"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+	volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned"
+)
+
+// TODO (Jeffwan@): Find an elegant way to either use a delegating reader or directly use clients.
+
+// getClientReaderFromClient tries to extract a client reader from the client; the client
+// reader reads cluster info from the API client.
+func getClientReaderFromClient(c client.Client) (client.Reader, error) {
+	//if dr, err := getDelegatingReader(c); err != nil {
+	//	return nil, err
+	//} else {
+	//	return dr.ClientReader, nil
+	//}
+
+	//return dr, nil
+
+	return nil, nil
+}
+
+// getDelegatingReader tries to extract a DelegatingReader from the client.
+//func getDelegatingReader(c client.Client) (*client.DelegatingReader, error) {
+//	dc, ok := c.(*client.DelegatingClient)
+//	if !ok {
+//		return nil, errors.New("cannot convert from Client to DelegatingClient")
+//	}
+//	dr, ok := dc.Reader.(*client.DelegatingReader)
+//	if !ok {
+//		return nil, errors.New("cannot convert from DelegatingClient.Reader to DelegatingReader")
+//	}
+//	return dr, nil
+//}
+
+func computeMasterAddr(jobName, rtype, index string) string {
+	n := jobName + "-" + rtype + "-" + index
+	return strings.Replace(n, "/", "-", -1)
+}
+
+// GetPortFromXGBoostJob gets the port of the xgboost container.
+func GetPortFromXGBoostJob(job *xgboostv1.XGBoostJob, rtype xgboostv1.XGBoostJobReplicaType) (int32, error) {
+	containers := job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(rtype)].Template.Spec.Containers
+	for _, container := range containers {
+		if container.Name == xgboostv1.DefaultContainerName {
+			ports := container.Ports
+			for _, port := range ports {
+				if port.Name == xgboostv1.DefaultContainerPortName {
+					return port.ContainerPort, nil
+				}
+			}
+		}
+	}
+	return -1, fmt.Errorf("failed to find the port")
+}
+
+func computeTotalReplicas(obj metav1.Object) int32 {
+	job := obj.(*xgboostv1.XGBoostJob)
+	jobReplicas := int32(0)
+
+	if len(job.Spec.XGBReplicaSpecs) == 0 {
+		return jobReplicas
+	}
+	for _, r := range job.Spec.XGBReplicaSpecs {
+		if r.Replicas == nil {
+			continue
+		}
+		jobReplicas += *r.Replicas
+	}
+	return jobReplicas
+}
+
+func createClientSets(config *restclientset.Config) (kubeclientset.Interface, kubeclientset.Interface, volcanoclient.Interface, error) {
+	if config == nil {
+		return nil, nil, nil, fmt.Errorf("nil rest config for client sets")
+	}
+
+	kubeClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "xgboostjob-operator"))
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	leaderElectionClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "leader-election"))
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	volcanoClientSet, err := volcanoclient.NewForConfig(restclientset.AddUserAgent(config, "volcano"))
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	return kubeClientSet, leaderElectionClientSet, volcanoClientSet, nil
+}
+
+func homeDir() string {
+	if h := os.Getenv("HOME"); h != "" {
+		return h
+	}
+	return os.Getenv("USERPROFILE") // windows
+}
+
+func isGangSchedulerSet(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool {
+	for _, spec := range replicas {
+		if spec.Template.Spec.SchedulerName == gangSchedulerName {
+			return true
+		}
+	}
+	return false
+}
+
+// FakeWorkQueue implements RateLimitingInterface but actually does nothing.
+type FakeWorkQueue struct{}
+
+// Add WorkQueue Add method
+func (f *FakeWorkQueue) Add(item interface{}) {}
+
+// Len WorkQueue Len method
+func (f *FakeWorkQueue) Len() int { return 0 }
+
+// Get WorkQueue Get method
+func (f *FakeWorkQueue) Get() (item interface{}, shutdown bool) { return nil, false }
+
+// Done WorkQueue Done method
+func (f *FakeWorkQueue) Done(item interface{}) {}
+
+// ShutDown WorkQueue ShutDown method
+func (f *FakeWorkQueue) ShutDown() {}
+
+// ShuttingDown WorkQueue ShuttingDown method
+func (f *FakeWorkQueue) ShuttingDown() bool { return true }
+
+// AddAfter WorkQueue AddAfter method
+func (f *FakeWorkQueue) AddAfter(item interface{}, duration time.Duration) {}
+
+// AddRateLimited WorkQueue AddRateLimited method
+func (f *FakeWorkQueue) AddRateLimited(item interface{}) {}
+
+// Forget WorkQueue Forget method
+func (f *FakeWorkQueue) Forget(item interface{}) {}
+
+// NumRequeues WorkQueue NumRequeues method
+func (f *FakeWorkQueue) NumRequeues(item interface{}) int { return 0 }
diff --git a/pkg/controller.v1/xgboost/xgboostjob_controller.go b/pkg/controller.v1/xgboost/xgboostjob_controller.go
index 59e6154a90..0392fc5817 100644
--- a/pkg/controller.v1/xgboost/xgboostjob_controller.go
+++ b/pkg/controller.v1/xgboost/xgboostjob_controller.go
@@ -12,10 +12,22 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package controllers
+package xgboost
 
 import (
 	"context"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+	"github.com/kubeflow/common/pkg/controller.v1/common"
+	"github.com/kubeflow/common/pkg/controller.v1/control"
+	"github.com/kubeflow/common/pkg/controller.v1/expectation"
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/errors"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	"k8s.io/client-go/kubernetes/scheme"
+	"k8s.io/client-go/tools/record"
+	"sigs.k8s.io/controller-runtime/pkg/manager"
+	"sigs.k8s.io/controller-runtime/pkg/reconcile"
 
 	"github.com/go-logr/logr"
 	"k8s.io/apimachinery/pkg/runtime"
@@ -25,37 +37,182 @@ import (
 	xgboostv1 "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
 )
 
+const (
+	controllerName      = "xgboostjob-operator"
+	labelXGBoostJobRole = "xgboostjob-job-role"
+	gangSchedulerName   = "volcano"
+)
+
+var (
+	jobOwnerKey           = ".metadata.controller"
+	defaultTTLSeconds     = int32(100)
+	defaultCleanPodPolicy = commonv1.CleanPodPolicyNone
+)
+
+func NewReconciler(mgr manager.Manager) *XGBoostJobReconciler {
+	r := &XGBoostJobReconciler{
+		Client: mgr.GetClient(),
+		Log:    ctrl.Log.WithName("controllers").WithName("XGBoostJob"),
+		Scheme: mgr.GetScheme(),
+	}
+	r.recorder = mgr.GetEventRecorderFor(r.ControllerName())
+
+	// Create clients.
+	kubeClientSet, _, volcanoClientSet, err := createClientSets(ctrl.GetConfigOrDie())
+	if err != nil {
+		r.Log.Error(err, "error building kubeclientset")
+	}
+
+	// Initialize the common job controller.
+	r.JobController = common.JobController{
+		Controller:   r,
+		Expectations: expectation.NewControllerExpectations(),
+		// TODO: add batch scheduler check later.
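+		// Gang scheduling stays disabled until that check lands; isGangSchedulerSet
+		// in util.go already detects volcano-scheduled replica specs.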
+		Config:           common.JobControllerConfiguration{EnableGangScheduling: false},
+		WorkQueue:        &FakeWorkQueue{},
+		Recorder:         r.recorder,
+		KubeClientSet:    kubeClientSet,
+		VolcanoClientSet: volcanoClientSet,
+		PodControl:       control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder},
+		ServiceControl:   control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder},
+	}
+
+	return r
+}
+
 // XGBoostJobReconciler reconciles a XGBoostJob object
 type XGBoostJobReconciler struct {
+	common.JobController
 	client.Client
-	Log    logr.Logger
-	Scheme *runtime.Scheme
+	Log      logr.Logger
+	Scheme   *runtime.Scheme
+	recorder record.EventRecorder
 }
 
+// Reconcile reads the state of the cluster for a XGBoostJob object and makes changes based on the state read
+// and what is in the XGBoostJob.Spec.
+// Automatically generate RBAC rules to allow the Controller to read and write Deployments
+// +kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete
+// +kubebuilder:rbac:groups=apps,resources=deployments/status,verbs=get;update;patch
 //+kubebuilder:rbac:groups=kubeflow.org,resources=xgboostjobs,verbs=get;list;watch;create;update;patch;delete
 //+kubebuilder:rbac:groups=kubeflow.org,resources=xgboostjobs/status,verbs=get;update;patch
 //+kubebuilder:rbac:groups=kubeflow.org,resources=xgboostjobs/finalizers,verbs=update
-// Reconcile is part of the main kubernetes reconciliation loop which aims to
-// move the current state of the cluster closer to the desired state.
-// TODO(user): Modify the Reconcile function to compare the state specified by
-// the XGBoostJob object against the actual cluster state, and then
-// perform operations to make the cluster state reflect the state specified by
-// the user.
-//
-// For more details, check Reconcile and its Result here:
-// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.7.2/pkg/reconcile
 func (r *XGBoostJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
-	_ = r.Log.WithValues("xgboostjob", req.NamespacedName)
+	logger := r.Log.WithValues("xgboostjob", req.NamespacedName)
+
+	xgboostjob := &xgboostv1.XGBoostJob{}
+	err := r.Get(context.Background(), req.NamespacedName, xgboostjob)
+	if err != nil {
+		if errors.IsNotFound(err) {
+			// Object not found, return. Created objects are automatically garbage collected.
+			// For additional cleanup logic use finalizers.
+			return ctrl.Result{}, nil
+		}
+		// Error reading the object - requeue the request.
+		return ctrl.Result{}, err
+	}
+
+	// Check whether reconcile is required.
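+	// satisfiedExpectations reports whether every pod/service add or delete this
+	// controller requested has been observed in the informer cache; syncing before
+	// that would act on stale state.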
+	needSync := r.satisfiedExpectations(xgboostjob)
+
+	if !needSync || xgboostjob.DeletionTimestamp != nil {
+		logger.Info("reconcile cancelled, the job does not need to sync or has been deleted",
+			"sync", needSync, "deleted", xgboostjob.DeletionTimestamp != nil)
+		return reconcile.Result{}, nil
+	}
+
+	// Set defaults for the xgboost job.
+	scheme.Scheme.Default(xgboostjob)
+
+	// Use the common library to reconcile the job's pods and services.
+	err = r.ReconcileJobs(xgboostjob, xgboostjob.Spec.XGBReplicaSpecs, xgboostjob.Status.JobStatus, &xgboostjob.Spec.RunPolicy)
+	if err != nil {
+		logger.V(2).Error(err, "Reconcile XGBoost Job error")
+		return ctrl.Result{}, err
+	}
+
+	return reconcile.Result{}, nil
+}
+
+func (r *XGBoostJobReconciler) ControllerName() string {
+	return controllerName
+}
 
-	// your logic here
+func (r *XGBoostJobReconciler) GetAPIGroupVersionKind() schema.GroupVersionKind {
+	return xgboostv1.SchemeBuilder.GroupVersion.WithKind(xgboostv1.Kind)
+}
+
+func (r *XGBoostJobReconciler) GetAPIGroupVersion() schema.GroupVersion {
+	return xgboostv1.GroupVersion
+}
+
+func (r *XGBoostJobReconciler) GetGroupNameLabelValue() string {
+	return xgboostv1.GroupName
+}
 
-	return ctrl.Result{}, nil
+func (r *XGBoostJobReconciler) GetDefaultContainerName() string {
+	return xgboostv1.DefaultContainerName
+}
+
+func (r *XGBoostJobReconciler) GetDefaultContainerPortName() string {
+	return xgboostv1.DefaultContainerPortName
+}
+
+func (r *XGBoostJobReconciler) GetJobRoleKey() string {
+	return labelXGBoostJobRole
+}
+
+func (r *XGBoostJobReconciler) IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec,
+	rtype commonv1.ReplicaType, index int) bool {
+	return string(rtype) == string(xgboostv1.XGBoostReplicaTypeMaster)
+}
+
+// SetClusterSpec sets the cluster spec for the pod.
+func (r *XGBoostJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
+	return SetPodEnv(job, podTemplate, rtype, index)
+}
 
 // SetupWithManager sets up the controller with the Manager.
 func (r *XGBoostJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
+	// Set up the FieldIndexer to inform the manager that this controller owns pods and services,
+	// so that it will automatically call Reconcile on the owning XGBoostJob when a Pod or Service
+	// changes, is deleted, etc.
+	if err := mgr.GetFieldIndexer().IndexField(context.Background(), &corev1.Pod{}, jobOwnerKey, func(rawObj client.Object) []string {
+		pod := rawObj.(*corev1.Pod)
+		owner := metav1.GetControllerOf(pod)
+		if owner == nil {
+			return nil
+		}
+
+		// Make sure the owner is the XGBoostJob controller.
+		if owner.APIVersion != r.GetAPIGroupVersion().String() || owner.Kind != r.GetAPIGroupVersionKind().Kind {
+			return nil
+		}
+
+		return []string{owner.Name}
+	}); err != nil {
+		return err
+	}
+
+	if err := mgr.GetFieldIndexer().IndexField(context.Background(), &corev1.Service{}, jobOwnerKey, func(rawObj client.Object) []string {
+		svc := rawObj.(*corev1.Service)
+		owner := metav1.GetControllerOf(svc)
+		if owner == nil {
+			return nil
+		}
+
+		if owner.APIVersion != r.GetAPIGroupVersion().String() || owner.Kind != r.GetAPIGroupVersionKind().Kind {
+			return nil
+		}
+
+		return []string{owner.Name}
+	}); err != nil {
+		return err
+	}
+
 	return ctrl.NewControllerManagedBy(mgr).
 		For(&xgboostv1.XGBoostJob{}).
+		Owns(&corev1.Pod{}).
+		Owns(&corev1.Service{}).
 		Complete(r)
 }