
Refactored all grpc methods in method_table #202

Merged (10 commits, Jun 6, 2023)
2 changes: 1 addition & 1 deletion examples/local/conversational-query-example.py
@@ -5,7 +5,7 @@
 import mii
 
 # gpt2
-name = "microsoft/DialoGPT-medium"
+name = "microsoft/DialoGPT-large"
 
 print(f"Querying {name}...")
15 changes: 7 additions & 8 deletions mii/client.py
@@ -8,7 +8,7 @@
 import mii
 from mii.utils import get_task
 from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
-from mii.constants import GRPC_MAX_MSG_SIZE
+from mii.constants import GRPC_MAX_MSG_SIZE, Tasks
 from mii.method_table import GRPC_METHOD_TABLE


@@ -66,13 +66,10 @@ async def _request_async_response(self, request_dict, **query_kwargs):
         if self.task not in GRPC_METHOD_TABLE:
             raise ValueError(f"unknown task: {self.task}")
 
-        conversions = GRPC_METHOD_TABLE[self.task]
-        proto_request = conversions["pack_request_to_proto"](request_dict,
-                                                             **query_kwargs)
-        proto_response = await getattr(self.stub, conversions["method"])(proto_request)
-        return conversions["unpack_response_from_proto"](
-            proto_response
-        ) if "unpack_response_from_proto" in conversions else proto_response
+        task_methods = GRPC_METHOD_TABLE[self.task]
+        proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
+        proto_response = await getattr(self.stub, task_methods.method)(proto_request)
+        return task_methods.unpack_response_from_proto(proto_response)
 
     def query(self, request_dict, **query_kwargs):
         return self.asyncio_loop.run_until_complete(
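The client-side change above replaces the old per-task dict of conversion functions with an object exposing a fixed interface. As a rough sketch of the interface the new client code appears to assume for each GRPC_METHOD_TABLE entry (the base class, docstrings, and default run_inference below are illustrative, not MII's actual definitions):

from abc import ABC, abstractmethod

class TaskMethods(ABC):
    """Hypothetical shape of a GRPC_METHOD_TABLE entry after this refactor."""

    @property
    @abstractmethod
    def method(self):
        """Name of the gRPC stub method to call, e.g. 'GeneratorReply'."""

    @abstractmethod
    def pack_request_to_proto(self, request_dict, **query_kwargs):
        """Convert a user request dict into the task's protobuf request."""

    @abstractmethod
    def unpack_request_from_proto(self, proto_request):
        """Convert a protobuf request into (args, kwargs) for the pipeline."""

    def run_inference(self, inference_pipeline, args, kwargs):
        # Default behavior: invoke the pipeline directly. Tasks with session
        # state can override this (see the server-side changes below).
        return inference_pipeline(*args, **kwargs)

    @abstractmethod
    def pack_response_to_proto(self, response, time_taken, model_time_taken):
        """Wrap the pipeline output and timing info into a protobuf response."""

    @abstractmethod
    def unpack_response_from_proto(self, proto_response):
        """Convert the protobuf response into a plain Python result."""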
@@ -91,6 +88,7 @@ async def create_session_async(self, session_id):
             modelresponse_pb2.SessionID(session_id=session_id))
 
     def create_session(self, session_id):
+        assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'."
         return self.asyncio_loop.run_until_complete(
             self.create_session_async(session_id))

@@ -99,6 +97,7 @@ async def destroy_session_async(self, session_id):
         )
 
     def destroy_session(self, session_id):
+        assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'."
         self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))
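With these assertions, sessions are now explicitly a text-generation-only client feature. A hedged usage sketch (the deployment name and prompt are invented; mii.mii_query_handle is MII's usual way to obtain a client handle; passing session_id as a query kwarg follows the pre-refactor server code, which popped it from the unpacked kwargs):

import mii

# Deployment name and prompt are illustrative only.
generator = mii.mii_query_handle("dialogpt-deployment")
generator.create_session("chat-0")  # AssertionError for non-text-generation tasks
result = generator.query({"query": "DeepSpeed is"}, session_id="chat-0")
generator.destroy_session("chat-0")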


40 changes: 15 additions & 25 deletions mii/grpc_related/modelresponse_server.py
@@ -13,7 +13,7 @@
 import threading
 import time
 
-from mii.constants import GRPC_MAX_MSG_SIZE, CREATE_SESSION_METHOD, DESTROY_SESSION_METHOD, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT
+from mii.constants import GRPC_MAX_MSG_SIZE, CREATE_SESSION_METHOD, DESTROY_SESSION_METHOD, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, Tasks
 from mii.method_table import GRPC_METHOD_TABLE
 from mii.client import create_channel
 from mii.utils import get_task, unpack_proto_query_kwargs
@@ -41,8 +41,7 @@ class ModelResponse(ServiceBase):
     def __init__(self, inference_pipeline):
         super().__init__()
         self.inference_pipeline = inference_pipeline
-        self.method_name_to_task = {m["method"]: t for t, m in GRPC_METHOD_TABLE.items()}
-        self.session_context = {}
+        self.method_name_to_task = {m.method: t for t, m in GRPC_METHOD_TABLE.items()}
         self.lock = threading.Lock()

def _get_model_time(self, model, sum_times=False):
@@ -63,15 +62,18 @@ def _get_model_time(self, model, sum_times=False):
         return model_time
 
     def CreateSession(self, request, context):
-        if request.session_id in self.session_context:
-            raise ValueError(f"session {request.session_id} already exists")
-        self.session_context[request.session_id] = None
+        task_methods = GRPC_METHOD_TABLE[Tasks.TEXT_GENERATION]
+        task_methods.create_session(request.session_id)
         return google_dot_protobuf_dot_empty__pb2.Empty()
 
     def DestroySession(self, request, context):
-        if request.session_id not in self.session_context:
-            raise ValueError(f"session {request.session_id} does not exist")
-        del self.session_context[request.session_id]
+        # TODO improve this so the task is not hard-coded
+        task = self.inference_pipeline.task
+        if task != "text-generation":
+            raise Exception("Incorrect task: Cannot destroy session")
+        else:
+            task_methods = GRPC_METHOD_TABLE[Tasks.TEXT_GENERATION]
+            task_methods.destroy_session(request.session_id)
         return google_dot_protobuf_dot_empty__pb2.Empty()
 
     def _run_inference(self, method_name, request_proto):
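Server-side, the session bookkeeping that previously lived in ModelResponse.session_context moves behind the method table: CreateSession and DestroySession now delegate to the text-generation entry. A sketch of how that entry might absorb the deleted logic (the class name and field are hypothetical, mirroring the removed server code; packing/unpacking methods omitted):

class TextGenerationMethods(TaskMethods):
    def __init__(self):
        # session_id -> stored conversation state, as the server used to keep.
        self.session_context = {}

    def create_session(self, session_id):
        if session_id in self.session_context:
            raise ValueError(f"session {session_id} already exists")
        self.session_context[session_id] = None

    def destroy_session(self, session_id):
        if session_id not in self.session_context:
            raise ValueError(f"session {session_id} does not exist")
        del self.session_context[session_id]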
@@ -82,32 +84,20 @@ def _run_inference(self, method_name, request_proto):
         if task not in GRPC_METHOD_TABLE:
             raise ValueError(f"unknown task: {task}")
 
-        conversions = GRPC_METHOD_TABLE[task]
-        args, kwargs = conversions["unpack_request_from_proto"](request_proto)
-
-        session_id = kwargs.pop("session_id", None)
-        if session_id and "preprocess_session" in GRPC_METHOD_TABLE[task]:
-            args, kwargs = GRPC_METHOD_TABLE[task]["preprocess_session"](session_id, self.session_context, args, kwargs)
+        task_methods = GRPC_METHOD_TABLE[task]
+        args, kwargs = task_methods.unpack_request_from_proto(request_proto)
 
         start = time.time()
         with self.lock:
-            response = self.inference_pipeline(*args, **kwargs)
+            response = task_methods.run_inference(self.inference_pipeline, args, kwargs)
         end = time.time()
 
-        if session_id and "postprocess_session" in GRPC_METHOD_TABLE[task]:
-            response = GRPC_METHOD_TABLE[task]["postprocess_session"](session_id,
-                                                                      self.session_context,
-                                                                      args,
-                                                                      kwargs,
-                                                                      response)
-
         model_time = self._get_model_time(self.inference_pipeline.model,
                                           sum_times=True) if hasattr(
                                               self.inference_pipeline,
                                               "model") else -1
 
-        return conversions["pack_response_to_proto"](response, end - start, model_time)
+        return task_methods.pack_response_to_proto(response, end - start, model_time)
 
     def GeneratorReply(self, request, context):
         return self._run_inference("GeneratorReply", request)