From c4d945aa8e466f23822bf17b5a1be4bad72b5a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E9=9B=A8=E7=BE=B2?= Date: Mon, 24 Jan 2022 18:41:33 +0800 Subject: [PATCH] adapt train operator --- .../fairing/deployers/pytorchjob/pytorchjob.py | 14 +++++++------- kubeflow/fairing/deployers/tfjob/tfjob.py | 14 ++++++++------ kubeflow/fairing/kubernetes/manager.py | 4 ++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/kubeflow/fairing/deployers/pytorchjob/pytorchjob.py b/kubeflow/fairing/deployers/pytorchjob/pytorchjob.py index 8d70bda9..f6d6dfb9 100644 --- a/kubeflow/fairing/deployers/pytorchjob/pytorchjob.py +++ b/kubeflow/fairing/deployers/pytorchjob/pytorchjob.py @@ -1,12 +1,11 @@ import logging -from kubernetes import client as k8s_client - -from kubeflow.pytorchjob import V1ReplicaSpec -from kubeflow.pytorchjob import V1PyTorchJob -from kubeflow.pytorchjob import V1PyTorchJobSpec from kubeflow.fairing.constants import constants from kubeflow.fairing.deployers.job.job import Job +from kubeflow.training import V1PyTorchJob +from kubeflow.training import V1PyTorchJobSpec +from kubeflow.training import V1ReplicaSpec, V1RunPolicy +from kubernetes import client as k8s_client logger = logging.getLogger(__name__) @@ -14,6 +13,7 @@ class PyTorchJob(Job): """ Handle all the k8s' template building to create pytorch training job using Kubeflow PyTorch Operator""" + def __init__(self, namespace=None, master_count=1, worker_count=1, runs=1, job_name=None, stream_log=True, labels=None, pod_spec_mutators=None, cleanup=False, annotations=None, @@ -81,12 +81,12 @@ def generate_deployment_spec(self, pod_template_spec): pytorchjob = V1PyTorchJob( api_version=constants.PYTORCH_JOB_GROUP + "/" + \ - constants.PYTORCH_JOB_VERSION, + constants.PYTORCH_JOB_VERSION, kind=constants.PYTORCH_JOB_KIND, metadata=k8s_client.V1ObjectMeta(name=self.job_name, generate_name=constants.PYTORCH_JOB_DEFAULT_NAME, labels=self.labels), - spec=V1PyTorchJobSpec(pytorch_replica_specs=pytorch_replica_specs) + spec=V1PyTorchJobSpec(pytorch_replica_specs=pytorch_replica_specs, run_policy=V1RunPolicy(clean_pod_policy="None")) ) return pytorchjob diff --git a/kubeflow/fairing/deployers/tfjob/tfjob.py b/kubeflow/fairing/deployers/tfjob/tfjob.py index 6a2f0b0d..4de92924 100644 --- a/kubeflow/fairing/deployers/tfjob/tfjob.py +++ b/kubeflow/fairing/deployers/tfjob/tfjob.py @@ -1,13 +1,13 @@ -import logging import copy -from kubernetes import client as k8s_client +import logging -from kubeflow.tfjob import V1ReplicaSpec -from kubeflow.tfjob import V1TFJob -from kubeflow.tfjob import V1TFJobSpec +from kubernetes import client as k8s_client from kubeflow.fairing.constants import constants from kubeflow.fairing.deployers.job.job import Job +from kubeflow.training import V1ReplicaSpec, V1RunPolicy +from kubeflow.training import V1TFJob +from kubeflow.training import V1TFJobSpec logger = logging.getLogger(__name__) @@ -15,6 +15,7 @@ class TfJob(Job): """ Handle all the k8s' template building to create tensorflow training job using Kubeflow TFOperator""" + def __init__(self, namespace=None, worker_count=1, ps_count=0, chief_count=0, runs=1, job_name=None, stream_log=True, labels=None, pod_spec_mutators=None, cleanup=False, annotations=None, @@ -90,7 +91,8 @@ def generate_deployment_spec(self, pod_template_spec): metadata=k8s_client.V1ObjectMeta(name=self.job_name, generate_name=constants.TF_JOB_DEFAULT_NAME, labels=self.labels), - spec=V1TFJobSpec(tf_replica_specs=tf_replica_specs) + spec=V1TFJobSpec(tf_replica_specs=tf_replica_specs, run_policy=V1RunPolicy(clean_pod_policy="None")) + ) return tfjob diff --git a/kubeflow/fairing/kubernetes/manager.py b/kubeflow/fairing/kubernetes/manager.py index 71bb5e7c..d18bcb39 100644 --- a/kubeflow/fairing/kubernetes/manager.py +++ b/kubeflow/fairing/kubernetes/manager.py @@ -5,8 +5,8 @@ from kubernetes import client, config, watch from kfserving import KFServingClient -from kubeflow.tfjob import TFJobClient -from kubeflow.pytorchjob import PyTorchJobClient +from kubeflow.training import TFJobClient +from kubeflow.training import PyTorchJobClient from kubeflow.fairing.utils import is_running_in_k8s, camel_to_snake from kubeflow.fairing.constants import constants