Skip to content
This repository has been archived by the owner on Aug 17, 2023. It is now read-only.

Commit

Permalink
adapt train operator
Browse files Browse the repository at this point in the history
  • Loading branch information
吴雨羲 committed Jan 24, 2022
1 parent fce1c05 commit c4d945a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
14 changes: 7 additions & 7 deletions kubeflow/fairing/deployers/pytorchjob/pytorchjob.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import logging
from kubernetes import client as k8s_client

from kubeflow.pytorchjob import V1ReplicaSpec
from kubeflow.pytorchjob import V1PyTorchJob
from kubeflow.pytorchjob import V1PyTorchJobSpec

from kubeflow.fairing.constants import constants
from kubeflow.fairing.deployers.job.job import Job
from kubeflow.training import V1PyTorchJob
from kubeflow.training import V1PyTorchJobSpec
from kubeflow.training import V1ReplicaSpec, V1RunPolicy
from kubernetes import client as k8s_client

logger = logging.getLogger(__name__)


class PyTorchJob(Job):
""" Handle all the k8s' template building to create pytorch
training job using Kubeflow PyTorch Operator"""

def __init__(self, namespace=None, master_count=1, worker_count=1,
runs=1, job_name=None, stream_log=True, labels=None,
pod_spec_mutators=None, cleanup=False, annotations=None,
Expand Down Expand Up @@ -81,12 +81,12 @@ def generate_deployment_spec(self, pod_template_spec):

pytorchjob = V1PyTorchJob(
api_version=constants.PYTORCH_JOB_GROUP + "/" + \
constants.PYTORCH_JOB_VERSION,
constants.PYTORCH_JOB_VERSION,
kind=constants.PYTORCH_JOB_KIND,
metadata=k8s_client.V1ObjectMeta(name=self.job_name,
generate_name=constants.PYTORCH_JOB_DEFAULT_NAME,
labels=self.labels),
spec=V1PyTorchJobSpec(pytorch_replica_specs=pytorch_replica_specs)
spec=V1PyTorchJobSpec(pytorch_replica_specs=pytorch_replica_specs, run_policy=V1RunPolicy(clean_pod_policy="None"))
)

return pytorchjob
Expand Down
14 changes: 8 additions & 6 deletions kubeflow/fairing/deployers/tfjob/tfjob.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import logging
import copy
from kubernetes import client as k8s_client
import logging

from kubeflow.tfjob import V1ReplicaSpec
from kubeflow.tfjob import V1TFJob
from kubeflow.tfjob import V1TFJobSpec
from kubernetes import client as k8s_client

from kubeflow.fairing.constants import constants
from kubeflow.fairing.deployers.job.job import Job
from kubeflow.training import V1ReplicaSpec, V1RunPolicy
from kubeflow.training import V1TFJob
from kubeflow.training import V1TFJobSpec

logger = logging.getLogger(__name__)


class TfJob(Job):
""" Handle all the k8s' template building to create tensorflow
training job using Kubeflow TFOperator"""

def __init__(self, namespace=None, worker_count=1, ps_count=0,
chief_count=0, runs=1, job_name=None, stream_log=True,
labels=None, pod_spec_mutators=None, cleanup=False, annotations=None,
Expand Down Expand Up @@ -90,7 +91,8 @@ def generate_deployment_spec(self, pod_template_spec):
metadata=k8s_client.V1ObjectMeta(name=self.job_name,
generate_name=constants.TF_JOB_DEFAULT_NAME,
labels=self.labels),
spec=V1TFJobSpec(tf_replica_specs=tf_replica_specs)
spec=V1TFJobSpec(tf_replica_specs=tf_replica_specs, run_policy=V1RunPolicy(clean_pod_policy="None"))

)

return tfjob
Expand Down
4 changes: 2 additions & 2 deletions kubeflow/fairing/kubernetes/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from kubernetes import client, config, watch
from kfserving import KFServingClient

from kubeflow.tfjob import TFJobClient
from kubeflow.pytorchjob import PyTorchJobClient
from kubeflow.training import TFJobClient
from kubeflow.training import PyTorchJobClient

from kubeflow.fairing.utils import is_running_in_k8s, camel_to_snake
from kubeflow.fairing.constants import constants
Expand Down

0 comments on commit c4d945a

Please sign in to comment.