Add Non-persistent deployment type (#197)
Add Non-persistent deployment type

---------

Co-authored-by: Tosin Segun <insanechills.com>
Co-authored-by: Michael Wyatt <[email protected]>
3 people committed Jun 22, 2023
1 parent dc5ab44 commit dac178e
Showing 13 changed files with 359 additions and 81 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -178,6 +178,24 @@ mii.deploy(...
           mii_config=mii_configs)
```

**Non-persistent Deployment**

You can enable a non-persistent deployment, which lets you run queries without standing up a server. The non-persistent deployment acts as a simplified interface to DeepSpeed-Inference for use cases that do not require a persistent model server process. Setting the `deployment_type` to `NON_PERSISTENT` in `mii.deploy(...)` activates this option.

```python
...
mii.deploy(deployment_name=DEPLOYMENT_NAME,
           deployment_type=mii.constants.DeploymentType.NON_PERSISTENT,
           ...
           )

generator = mii.mii_query_handle(DEPLOYMENT_NAME)
result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30)

```

You can find a complete example [here](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/non_persistent).
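
Based on the `mii/terminate.py` change in this commit, the same `mii.terminate(...)` call used for persistent deployments also cleans up a non-persistent one. A minimal sketch, continuing the snippet above:

```python
import mii

# For a non-persistent deployment, terminate() drops the in-process pipeline
# from the mii.non_persistent_models registry instead of shutting down a server.
mii.terminate(DEPLOYMENT_NAME)
```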

Any HTTP client can be used to call the APIs. An example of using curl is:
```bash
# Assume deployment_name and restful_api_port are set to bloom560m_deployment and 28080 respectively:
16 changes: 16 additions & 0 deletions examples/non_persistent/text-generation-bloom560-example.py
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import mii

mii_configs = {"tensor_parallel": 1, "dtype": "fp16"}
name = "bloom560m"
mii.deploy(task='text-generation',
           model="bigscience/bloom-560m",
           deployment_name=name + "_deployment",
           deployment_type=mii.constants.DeploymentType.NON_PERSISTENT,
           mii_config=mii_configs)
generator = mii.mii_query_handle(name + "_deployment")
result = generator.query({'query': ["DeepSpeed is the", "Seattle is"]})
print(result)
1 change: 1 addition & 0 deletions mii/__init__.py
@@ -14,6 +14,7 @@
from .grpc_related.proto import modelresponse_pb2_grpc

__version__ = "0.0.0"
non_persistent_models = {}
try:
    from .version import __version__
except ImportError:
38 changes: 38 additions & 0 deletions mii/client.py
@@ -34,6 +34,11 @@ def mii_query_handle(deployment_name):
    Returns:
        query_handle: A query handle with a single method `.query(request_dictionary)` using which queries can be sent to the model.
    """

    if deployment_name in mii.non_persistent_models:
        _, task = mii.non_persistent_models[deployment_name]
        return MIINonPersistentClient(task, deployment_name)

    task_name, mii_configs = _get_deployment_info(deployment_name)
    if mii_configs.enable_load_balancing:
        return MIIClient(task_name, "localhost", mii_configs.port_number)
@@ -156,6 +161,39 @@ def destroy_session(self, session_id):
        client.destroy_session(session_id)


class MIINonPersistentClient:
    def __init__(self, task, deployment_name):
        self.task = task
        self.deployment_name = deployment_name

    def query(self, request_dict, **query_kwargs):
        assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found"
        task_methods = GRPC_METHOD_TABLE[self.task]
        inference_pipeline = mii.non_persistent_models[self.deployment_name][0]

        if self.task == Tasks.QUESTION_ANSWERING:
            if 'question' not in request_dict or 'context' not in request_dict:
                raise Exception(
                    "Question Answering Task requires 'question' and 'context' keys")
            args = (request_dict["question"], request_dict["context"])
            kwargs = query_kwargs

        elif self.task == Tasks.CONVERSATIONAL:
            conv = task_methods.create_conversation(request_dict, **query_kwargs)
            args = (conv, )
            kwargs = {}

        else:
            args = (request_dict['query'], )
            kwargs = query_kwargs

        return task_methods.run_inference(inference_pipeline, args, kwargs)

    def terminate(self):
        print(f"Terminating {self.deployment_name}...")
        del mii.non_persistent_models[self.deployment_name]


def terminate_restful_gateway(deployment_name):
    _, mii_configs = _get_deployment_info(deployment_name)
    if mii_configs.enable_restful_api:
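
For orientation, a hedged sketch of how a caller exercises the new client through the existing handle API (the deployment name is hypothetical, and a prior `mii.deploy(..., deployment_type=NON_PERSISTENT)` in the same process is assumed):

```python
import mii

# mii_query_handle returns MIINonPersistentClient when the name is found in
# the mii.non_persistent_models registry, otherwise a gRPC-backed client.
handle = mii.mii_query_handle("my_np_deployment")
result = handle.query({"query": ["DeepSpeed is"]}, max_new_tokens=30)
print(result)

# Removes the (pipeline, task) entry from mii.non_persistent_models.
handle.terminate()
```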
1 change: 1 addition & 0 deletions mii/constants.py
@@ -9,6 +9,7 @@
class DeploymentType(enum.Enum):
    LOCAL = 1
    AML = 2
    NON_PERSISTENT = 3


MII_CONFIGS_KEY = 'mii_configs'
42 changes: 29 additions & 13 deletions mii/deployment.py
@@ -4,14 +4,15 @@
# DeepSpeed Team
import torch
import string

import os
import mii

from deepspeed.launcher.runner import fetch_hostfile

from .constants import DeploymentType, MII_MODEL_PATH_DEFAULT
from .utils import logger
from .constants import DeploymentType, MII_MODEL_PATH_DEFAULT, MODEL_PROVIDER_MAP
from .utils import logger, get_task_name, get_provider_name
from .models.score import create_score_file
from .models import load_models
from .config import ReplicaConfig, LoadBalancerConfig


@@ -65,6 +66,7 @@ def deploy(task,
        If deployment_type is `LOCAL`, returns just the name of the deployment that can be used to create a query handle using `mii.mii_query_handle(deployment_name)`
    """

    # parse and validate mii config
    mii_config = mii.config.MIIConfig(**mii_config)
    if enable_zero:
@@ -125,21 +127,35 @@
    lb_config = LoadBalancerConfig(port=mii_config.port_number,
                                   replica_configs=replica_configs)

    create_score_file(deployment_name=deployment_name,
                      deployment_type=deployment_type,
                      task=task,
                      model_name=model,
                      ds_optimize=enable_deepspeed,
                      ds_zero=enable_zero,
                      ds_config=ds_config,
                      mii_config=mii_config,
                      model_path=model_path,
                      lb_config=lb_config)
    if deployment_type != DeploymentType.NON_PERSISTENT:
        create_score_file(deployment_name=deployment_name,
                          deployment_type=deployment_type,
                          task=task,
                          model_name=model,
                          ds_optimize=enable_deepspeed,
                          ds_zero=enable_zero,
                          ds_config=ds_config,
                          mii_config=mii_config,
                          model_path=model_path,
                          lb_config=lb_config)

    if deployment_type == DeploymentType.AML:
        _deploy_aml(deployment_name=deployment_name, model_name=model, version=version)
    elif deployment_type == DeploymentType.LOCAL:
        return _deploy_local(deployment_name, model_path=model_path)
    elif deployment_type == DeploymentType.NON_PERSISTENT:
        assert not mii_config.enable_load_balancing, "Cannot use load balancing with the non-persistent deployment type"
        assert int(os.getenv('WORLD_SIZE', '1')) == mii_config.tensor_parallel, "World size must equal tensor_parallel. When using the non-persistent deployment type, please launch with `deepspeed --num_gpus <tensor_parallel>`"
        provider = MODEL_PROVIDER_MAP[get_provider_name(model, task)]
        mii.non_persistent_models[deployment_name] = (load_models(get_task_name(task),
                                                                  model,
                                                                  model_path,
                                                                  enable_deepspeed,
                                                                  enable_zero,
                                                                  provider,
                                                                  mii_config),
                                                      task)
    else:
        raise Exception(f"Unknown deployment type: {deployment_type}")

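
The assertions above encode two constraints: load balancing is unsupported, and `WORLD_SIZE` must already match `tensor_parallel` because no server processes are spawned. A hedged sketch of a multi-GPU non-persistent deployment (file name hypothetical; launch it with `deepspeed --num_gpus 2 np_example.py`):

```python
# np_example.py
import mii

# With NON_PERSISTENT, deploy() loads the model into this process (or this
# group of launcher-spawned ranks) and registers it in mii.non_persistent_models.
mii.deploy(task="text-generation",
           model="bigscience/bloom-560m",
           deployment_name="bloom560m_np",
           deployment_type=mii.constants.DeploymentType.NON_PERSISTENT,
           mii_config={"tensor_parallel": 2, "dtype": "fp16"})
```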
31 changes: 24 additions & 7 deletions mii/method_table.py
@@ -2,9 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from transformers import Conversation
from abc import ABC, abstractmethod

from transformers import Conversation
from mii.constants import Tasks
from mii.grpc_related.proto import modelresponse_pb2
from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs
@@ -179,6 +178,28 @@ class ConversationalMethods(TaskMethods):
    def method(self):
        return "ConversationalReply"

    def create_conversation(self, request, **kwargs):
        if isinstance(request, dict):
            assert 'text' in request and 'past_user_inputs' in request and 'generated_responses' in request, "Conversation requires 'text', 'past_user_inputs', and 'generated_responses' keys"
            text = request['text']
            conversation_id = request.get('conversation_id')
            past_user_inputs = request['past_user_inputs']
            generated_responses = request['generated_responses']
        else:
            text = request.text
            conversation_id = request.conversation_id
            past_user_inputs = request.past_user_inputs
            generated_responses = request.generated_responses

        conv = Conversation(text=text,
                            conversation_id=conversation_id,
                            past_user_inputs=past_user_inputs,
                            generated_responses=generated_responses,
                            **kwargs)
        return conv

    def pack_response_to_proto(self, conv, time_taken, model_time_taken):
        return modelresponse_pb2.ConversationReply(
            conversation_id=conv.uuid,
@@ -189,11 +210,7 @@ def pack_response_to_proto(self, conv, time_taken, model_time_taken):

    def unpack_request_from_proto(self, request):
        kwargs = unpack_proto_query_kwargs(request.query_kwargs)
        conv = Conversation(text=request.text,
                            conversation_id=request.conversation_id,
                            past_user_inputs=request.past_user_inputs,
                            generated_responses=request.generated_responses,
                            **kwargs)
        conv = self.create_conversation(request, **kwargs)
        args = (conv, )
        kwargs = {}
        return args, kwargs
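
The refactor above lets the gRPC path and the new non-persistent path share one helper that accepts either a dict or a protobuf request. A hedged sketch of the dict form (import paths assumed from the usage in `mii/client.py`):

```python
from mii.constants import Tasks
from mii.method_table import GRPC_METHOD_TABLE

# Build a transformers.Conversation from a plain request dict, as
# MIINonPersistentClient.query does for conversational deployments.
task_methods = GRPC_METHOD_TABLE[Tasks.CONVERSATIONAL]
conv = task_methods.create_conversation({
    "text": "What is DeepSpeed?",
    "past_user_inputs": [],
    "generated_responses": [],
})
```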
11 changes: 2 additions & 9 deletions mii/server.py
@@ -13,7 +13,7 @@
from collections import defaultdict

import mii
from mii.utils import get_num_gpus, logger
from mii.utils import get_num_gpus, logger, get_provider_name
from mii.config import ReplicaConfig


@@ -120,14 +120,7 @@ def _build_server_args(self,
server_args_str += " --ds-optimize" if ds_optimize else ""

# XXX: fetch model provider based on model name in a more general way
if model_name == "gpt-neox":
provider = mii.constants.MODEL_PROVIDER_NAME_EA
elif ("bigscience/bloom" == model_name) or ("microsoft/bloom" in model_name):
provider = mii.constants.MODEL_PROVIDER_NAME_HF_LLM
elif self.task == mii.Tasks.TEXT2IMG:
provider = mii.constants.MODEL_PROVIDER_NAME_DIFFUSERS
else:
provider = mii.constants.MODEL_PROVIDER_NAME_HF
provider = get_provider_name(model_name, self.task)
server_args_str += f" --provider {provider}"

server_args_str += f" --config {b64_config_str}"
3 changes: 3 additions & 0 deletions mii/terminate.py
@@ -10,6 +10,9 @@
def terminate(deployment_name):
    mii.utils.logger.info(f"Terminating server for {deployment_name}")
    generator = mii.mii_query_handle(deployment_name)
    if deployment_name in mii.non_persistent_models:
        generator.terminate()
        return
    try:
        generator.query({'query': ''})
    except grpc.aio._call.AioRpcError as error:
13 changes: 12 additions & 1 deletion mii/utils.py
@@ -8,7 +8,6 @@
import importlib
import torch
import mii

from huggingface_hub import HfApi

from mii.constants import (CONVERSATIONAL_NAME,
@@ -209,6 +208,18 @@ def get_num_gpus(mii_configs):
    return num_gpus


def get_provider_name(model_name, task):
    if model_name == "gpt-neox":
        provider = mii.constants.MODEL_PROVIDER_NAME_EA
    elif ("bigscience/bloom" == model_name) or ("microsoft/bloom" in model_name):
        provider = mii.constants.MODEL_PROVIDER_NAME_HF_LLM
    elif task == mii.Tasks.TEXT2IMG:
        provider = mii.constants.MODEL_PROVIDER_NAME_DIFFUSERS
    else:
        provider = mii.constants.MODEL_PROVIDER_NAME_HF
    return provider


log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
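
A quick hedged illustration of the consolidated helper (the `Tasks.TEXT_GENERATION` member name is assumed from `mii.constants.Tasks`):

```python
import mii
from mii.utils import get_provider_name

# Routing follows the if/elif chain above: exact and substring name matches
# first, then the task, then the Hugging Face default.
print(get_provider_name("gpt-neox", mii.Tasks.TEXT_GENERATION))          # EleutherAI provider
print(get_provider_name("bigscience/bloom", mii.Tasks.TEXT_GENERATION))  # HF large-model provider
print(get_provider_name("distilgpt2", mii.Tasks.TEXT_GENERATION))        # default HF provider
```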
52 changes: 1 addition & 51 deletions tests/test_local_deployment.py
@@ -8,46 +8,16 @@
from types import SimpleNamespace
import json
import requests

from .utils import * # noqa: F401
import mii


def validate_config(config):
    if (config.model in ['bert-base-uncased']) and (config.mii_config['dtype'] == 'fp16'):
        pytest.skip(f"Model {config.model} not supported for FP16")
    elif config.mii_config['dtype'] == "fp32" and "bloom" in config.model:
        pytest.skip('bloom does not support fp32')


''' These fixtures provide default values for the deployment config '''


@pytest.fixture(scope="function", params=['fp16'])
def dtype(request):
return request.param


@pytest.fixture(scope="function", params=[1])
def tensor_parallel(request):
return request.param


@pytest.fixture(scope="function", params=[50050])
def port_number(request):
return request.param


@pytest.fixture(scope="function", params=[False])
def load_with_sys_mem(request):
return request.param


@pytest.fixture(scope="function", params=[False])
def enable_load_balancing(request):
return request.param


@pytest.fixture(scope="function", params=[False])
def enable_restful_api(request):
return request.param
@@ -58,21 +28,6 @@ def restful_api_port(request):
    return request.param


@pytest.fixture(scope="function", params=[True])
def enable_deepspeed(request):
return request.param


@pytest.fixture(scope="function", params=[False])
def enable_zero(request):
return request.param


@pytest.fixture(scope="function", params=[{}])
def ds_config(request):
return request.param


''' These fixtures provide a local deployment and ensure teardown '''


@@ -130,11 +85,6 @@ def deployment_config(task_name: str,
    return config


@pytest.fixture(scope="function", params=[None])
def expected_failure(request):
return request.param


@pytest.fixture(scope="function")
def local_deployment(deployment_config, expected_failure):
if expected_failure is not None: