Skip to content
This repository has been archived by the owner on Aug 17, 2023. It is now read-only.

Commit

Permalink
clean kfserving class
Browse files Browse the repository at this point in the history
  • Loading branch information
吴雨羲 committed Jan 25, 2022
1 parent 935a949 commit 42cbd71
Showing 1 changed file with 4 additions and 28 deletions.
32 changes: 4 additions & 28 deletions kubeflow/fairing/deployers/kfserving/kfserving.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import logging
import uuid

from kfserving import V1alpha2ONNXSpec
from kfserving import V1beta1InferenceService
from kfserving import V1beta1InferenceServiceSpec
from kfserving import V1beta1PredictorSpec
Expand All @@ -37,7 +36,7 @@
class KFServing(DeployerInterface):
"""Serves a prediction endpoint using Kubeflow KFServing."""

def __init__(self, framework, default_storage_uri=None, canary_storage_uri=None,
def __init__(self, framework, default_storage_uri=None,
canary_traffic_percent=0, namespace=None, labels=None, annotations=None,
custom_default_container=None, custom_canary_container=None,
isvc_name=None, stream_log=False, cleanup=False, config_file=None,
Expand All @@ -46,7 +45,6 @@ def __init__(self, framework, default_storage_uri=None, canary_storage_uri=None,
:param framework: The framework for the InferenceService, such as Tensorflow,
XGBoost and ScikitLearn etc.
:param default_storage_uri: URI pointing to Saved Model assets for default service.
:param canary_storage_uri: URI pointing to Saved Model assets for canary service.
:param canary_traffic_percent: The amount of traffic to sent to the canary, defaults to 0.
:param namespace: The k8s namespace where the InferenceService will be deployed.
:param labels: Labels for the InferenceService, separate with commas if have more than one.
Expand All @@ -69,7 +67,6 @@ def __init__(self, framework, default_storage_uri=None, canary_storage_uri=None,
self.framework = framework
self.isvc_name = isvc_name
self.default_storage_uri = default_storage_uri
self.canary_storage_uri = canary_storage_uri
self.canary_traffic_percent = canary_traffic_percent
self.annotations = annotations
self.set_labels(labels)
Expand Down Expand Up @@ -131,24 +128,9 @@ def deploy(self, isvc): # pylint:disable=arguments-differ,unused-argument

def generate_isvc(self):
""" generate InferenceService """

api_version = constants.KFSERVING_GROUP + '/' + constants.KFSERVING_VERSION
default_predictor, canary_predictor = None, None

if self.framework == 'custom':
default_predictor = self.generate_predictor_spec(
self.framework, container=self.custom_default_container)
else:
default_predictor = self.generate_predictor_spec(
self.framework, storage_uri=self.default_storage_uri)

if self.framework != 'custom' and self.canary_storage_uri is not None:
canary_predictor = self.generate_predictor_spec(
self.framework, storage_uri=self.canary_storage_uri)
if self.framework == 'custom' and self.custom_canary_container is not None:
canary_predictor = self.generate_predictor_spec(
self.framework, container=self.custom_canary_container)

default_predictor = self.generate_predictor_spec(
self.framework, storage_uri=self.default_storage_uri)
return V1beta1InferenceService(api_version=api_version,
kind=constants.KFSERVING_KIND,
metadata=k8s_client.V1ObjectMeta(
Expand All @@ -159,16 +141,13 @@ def generate_isvc(self):
predictor=default_predictor
))

def generate_predictor_spec(self, framework, storage_uri=None, container=None):
def generate_predictor_spec(self, framework, storage_uri=None):
'''Generate predictor spec according to framework and
default_storage_uri or custom container.
'''
if self.framework == 'tensorflow':
predictor = V1beta1PredictorSpec(
tensorflow=V1beta1TFServingSpec(storage_uri=storage_uri))
elif self.framework == 'onnx':
predictor = V1beta1PredictorSpec(
onnx=V1alpha2ONNXSpec(storage_uri=storage_uri))
elif self.framework == 'pytorch':
predictor = V1beta1PredictorSpec(
pytorch=V1beta1TorchServeSpec(storage_uri=storage_uri))
Expand All @@ -181,9 +160,6 @@ def generate_predictor_spec(self, framework, storage_uri=None, container=None):
elif self.framework == 'xgboost':
predictor = V1beta1PredictorSpec(
xgboost=V1beta1XGBoostSpec(storage_uri=storage_uri))
# elif self.framework == 'custom':
# predictor = V1beta1PredictorSpec(
# custom=V1alpha2CustomSpec(container=container))
else:
raise RuntimeError("Unsupported framework {}".format(framework))
return predictor
Expand Down

0 comments on commit 42cbd71

Please sign in to comment.