Streaming working with nnsight streaming #60
base: dev
Changes from all commits
c59a88f
afb6492
932d62a
2f66083
c4605ab
@@ -1,24 +1,35 @@
 import asyncio
 import gc
-from concurrent.futures import ThreadPoolExecutor, Future, TimeoutError
-from functools import partial, wraps
+import traceback
+import weakref
+from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
+from functools import wraps
 from typing import Any, Dict

 import ray
+import socketio
 import torch
 from minio import Minio
 from pydantic import BaseModel, ConfigDict
 from torch.amp import autocast
-from torch.cuda import (max_memory_allocated, memory_allocated,
-                        reset_peak_memory_stats)
+from torch.cuda import (
+    max_memory_allocated,
+    memory_allocated,
+    reset_peak_memory_stats,
+)
 from transformers import PreTrainedModel

 from nnsight.contexts.backends.RemoteBackend import RemoteMixin
 from nnsight.models.mixins import RemoteableMixin
+from nnsight.schema.Request import StreamValueModel

 from ...logging import load_logger
 from ...metrics import NDIFGauge
-from ...schema import (BackendRequestModel, BackendResponseModel,
-                       BackendResultModel)
+from ...schema import (
+    BackendRequestModel,
+    BackendResponseModel,
+    BackendResultModel,
+)
 from . import protocols
@@ -45,6 +56,8 @@ def __init__(
             secure=False,
         )

+        self.sio = socketio.SimpleClient(reconnection_attempts=10)
+
         self.logger = load_logger(
             service_name=str(self.__class__), logger_name="ray.serve"
         )

Review comment (on the `socketio.SimpleClient` line): Environment variable?
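The reviewer's question appears to refer to the hardcoded `reconnection_attempts=10`. If that value were moved out of the code, a minimal sketch could look like the following (the `SIO_RECONNECTION_ATTEMPTS` variable name is an assumption for illustration, not something defined by this PR):

    import os

    import socketio

    # Hypothetical environment variable; the default mirrors the value
    # currently hardcoded in the PR.
    reconnection_attempts = int(os.environ.get("SIO_RECONNECTION_ATTEMPTS", "10"))

    sio = socketio.SimpleClient(reconnection_attempts=reconnection_attempts)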
@@ -84,7 +97,7 @@ def __init__(
         dtype: str | torch.dtype,
         *args,
         extra_kwargs: Dict[str, Any] = {},
-        **kwargs
+        **kwargs,
     ) -> None:

         super().__init__(*args, **kwargs)
@@ -103,42 +116,65 @@ def __init__(
             device_map=device_map,
             dispatch=dispatch,
             torch_dtype=dtype,
-            **extra_kwargs
+            **extra_kwargs,
         )

         if dispatch:
             self.model._model.requires_grad_(False)

         torch.cuda.empty_cache()

-    async def __call__(self, request: BackendRequestModel) -> Any:
+        self.request: BackendRequestModel
+
+        protocols.LogProtocol.set(lambda *args: self.log(*args))
+        protocols.ServerStreamingDownloadProtocol.set(
+            lambda *args: self.stream_send(*args)
+        )
+
+        protocols.ServerStreamingUploadProtocol.set(
+            lambda *args: self.stream_receive(*args)
+        )
+
+    async def __call__(self, request: BackendRequestModel) -> None:
+        """Executes the model service pipeline:
+
+        1.) Pre-processing
+        2.) Execution
+        3.) Post-processing
+        4.) Cleanup
+
+        Args:
+            request (BackendRequestModel): Request.
+        """
+
+        self.request = weakref.proxy(request)
+
         try:

             result = None

-            protocols.LogProtocol.put(partial(self.log, request=request))
-            self.pre(request)
+
+            self.pre(request)

             with autocast(device_type="cuda", dtype=torch.get_default_dtype()):

                 result = self.execute(request)

                 if isinstance(result, Future):
                     result = result.result(timeout=self.execution_timeout)

             self.post(request, result)

         except TimeoutError as e:

-            exception = Exception(f"Job took longer than timeout: {self.execution_timeout} seconds")
-
-            self.exception(request, exception)
+            exception = Exception(
+                f"Job took longer than timeout: {self.execution_timeout} seconds"
+            )
+
+            self.exception(exception)

         except Exception as e:

-            self.exception(request, e)
+            self.exception(e)

         finally:
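For readers following the streaming changes, the `protocols.*.set(...)` registrations added to `__init__` above follow a plain callback-registration pattern: each protocol class holds one callable, and the execution code invokes it whenever a value needs to be logged or streamed. The sketch below is an assumed shape for illustration only, not the actual nnsight/NDIF implementation:

    from typing import Any, Callable, Optional


    class ServerStreamingDownloadProtocol:
        """Illustrative stand-in: holds one callback that execution code calls
        whenever a value should be streamed back to the client."""

        _callback: Optional[Callable[..., Any]] = None

        @classmethod
        def set(cls, callback: Callable[..., Any]) -> None:
            cls._callback = callback

        @classmethod
        def send(cls, *args: Any) -> None:
            if cls._callback is not None:
                cls._callback(*args)


    # The deployment registers its sender once per replica, e.g.:
    # ServerStreamingDownloadProtocol.set(lambda *args: deployment.stream_send(*args))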
@@ -163,17 +199,29 @@ def check_health(self):
     ### ABSTRACT METHODS #################################

     def pre(self, request: BackendRequestModel):
+        """Logic to execute before execution.
+
+        Args:
+            request (BackendRequestModel): Request.
+            result (Any): Result.
+        """

-        request.create_response(
+        self.respond(
             status=BackendResponseModel.JobStatus.RUNNING,
             description="Your job has started running.",
-            logger=self.logger,
-            gauge=self.gauge,
-        ).respond(self.api_url, self.object_store)
+        )

         request.object = ray.get(request.object)

-    def execute(self, request: BackendRequestModel):
+    def execute(self, request: BackendRequestModel) -> Any:
+        """Execute request.
+
+        Args:
+            request (BackendRequestModel): Request.
+
+        Returns:
+            Any: Result.
+        """

         # For tracking peak GPU usage
         reset_peak_memory_stats()
@@ -190,7 +238,13 @@ def execute(self, request: BackendRequestModel):

         return result, obj, gpu_mem

-    def post(self, request: BackendRequestModel, result: Any):
+    def post(self, request: BackendRequestModel, result: Any) -> None:
+        """Logic to execute after execution with result from `.execute`.
+
+        Args:
+            request (BackendRequestModel): Request.
+            result (Any): Result.
+        """

         obj: RemoteMixin = result[1]
         gpu_mem: int = result[2]
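The `gpu_mem` value returned by `execute` and attached to the COMPLETED response appears to rely on torch's peak-memory counters (`reset_peak_memory_stats` before the run, `max_memory_allocated` after). A minimal sketch of that pattern, independent of this PR's actual bookkeeping and assuming a CUDA device is available:

    from torch.cuda import max_memory_allocated, memory_allocated, reset_peak_memory_stats


    def run_with_peak_tracking(fn, *args, **kwargs):
        """Run fn and report how much GPU memory it peaked at beyond the baseline."""
        baseline = memory_allocated()  # memory already held before the call
        reset_peak_memory_stats()      # restart the peak counter
        result = fn(*args, **kwargs)
        peak_delta = max_memory_allocated() - baseline
        return result, peak_delta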
@@ -201,40 +255,86 @@ def post(self, request: BackendRequestModel, result: Any):
             value=obj.remote_backend_postprocess_result(result),
         ).save(self.object_store)

         # Send COMPLETED response.
-        request.create_response(
+        self.respond(
             status=BackendResponseModel.JobStatus.COMPLETED,
             description="Your job has been completed.",
-            logger=self.logger,
-            gauge=self.gauge,
             gpu_mem=gpu_mem,
-        ).respond(self.api_url, self.object_store)
+        )

-    def exception(self, request: BackendRequestModel, exception: Exception):
+    def exception(self, exception: Exception):
+        """Logic to execute of there was an exception.

-        request.create_response(
-            status=BackendResponseModel.JobStatus.ERROR,
-            description=str(exception),
-            logger=self.logger,
-            gauge=self.gauge,
-        ).respond(self.api_url, self.object_store)
+        Args:
+            exception (Exception): Exception.
+        """
+
+        description = traceback.format_exc()
+
+        self.respond(
+            status=BackendResponseModel.JobStatus.ERROR, description=description
+        )

     def cleanup(self):
         """Logic to execute to clean up memory after execution result is post-processed."""

+        if self.sio.connected:
+
+            self.sio.disconnect()
+
         self.model._model.zero_grad()

         gc.collect()

         torch.cuda.empty_cache()

-    def log(self, data: Any, request: BackendRequestModel):
+    def log(self, *data):
+        """Logic to execute for logging data during execution.
+
+        Args:
+            data (Any): Data to log.
+        """
+
+        description = "".join([str(_data) for _data in data])
+
+        self.respond(
+            status=BackendResponseModel.JobStatus.LOG, description=description
+        )
+
+    def stream_send(self, data: Any):
+        """Logic to execute to stream data back during execution.
+
+        Args:
+            data (Any): Data to stream back.
+        """
+
+        self.respond(status=BackendResponseModel.JobStatus.STREAM, data=data)
+
+    def stream_receive(self, *args):
+
+        return StreamValueModel(**self.sio.receive()[1]).deserialize(self.model)
+
+    def stream_connect(self):
+
+        if self.sio.client is None:
+
+            self.sio.connected = False
+
+            self.sio.connect(
+                f"{self.api_url}?job_id={self.request.id}",
+                socketio_path="/ws/socket.io",
+                transports=["websocket"],
+                wait_timeout=10,
+            )
+
+    def respond(self, **kwargs) -> None:
+
+        if self.request.session_id is not None:
+
+            self.stream_connect()

-        request.create_response(
-            status=BackendResponseModel.JobStatus.LOG,
-            description=str(data),
-            logger=self.logger,
-            gauge=self.gauge,
-        ).respond(self.api_url, self.object_store)
+        self.request.create_response(
+            **kwargs, logger=self.logger, gauge=self.gauge
+        ).respond(self.sio, self.object_store)


 class BaseModelDeploymentArgs(BaseDeploymentArgs):
Review comment: Cannot comment directly because it is from a previous PR, but this shouldn't be hardcoded:

Review comment: Additionally, I noticed the following in [code block not captured]: Is this 405b specific, in that this logic will not allow you to deploy a non-Llama distributed model? If so, is there a way to have this passed in only for 405b? Otherwise, should this be indicated more clearly?
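On the hardcoding/405b question, one possible direction, sketched here as an assumption rather than a concrete proposal for this PR (field names are hypothetical): expose the model-specific switch as an optional field on the deployment args, so a non-Llama distributed model simply leaves it unset and the deployment branches on the arg instead of on the model name.

    from typing import Any, Dict, Optional

    from pydantic import BaseModel


    class DistributedModelDeploymentArgs(BaseModel):
        """Illustrative reduction of the real args class; only the idea matters."""

        model_key: str
        # Hypothetical: per-model parallelism/settings supplied by the Ray Serve
        # config for models that need them (e.g. 405b), instead of being hardcoded.
        model_specific_config: Optional[Dict[str, Any]] = None


    # In the deployment (sketch):
    # if args.model_specific_config is not None:
    #     apply_model_specific_setup(args.model_specific_config)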
@@ -197,17 +197,21 @@ def pre(self, request: BackendRequestModel):

         torch.distributed.barrier()

-    def post(self, request: BackendRequestModel, result: Any):
+    def post(self, *args, **kwargs):
         if self.head:
-            super().post(request, result)
+            super().post(*args, **kwargs)

-    def exception(self, request: BackendRequestModel, exception: Exception):
+    def exception(self, *args, **kwargs):
         if self.head:
-            super().exception(request, exception)
+            super().exception(*args, **kwargs)

-    def log(self, data: Any, request: BackendRequestModel):
+    def log(self, *args, **kwargs):
         if self.head:
-            super().log(data, request)
+            super().log(*args, **kwargs)

+    def stream_send(self, *args, **kwargs):
+        if self.head:
+            super().stream_send(*args, **kwargs)


 class DistributedModelDeploymentArgs(BaseModelDeploymentArgs):

Review comment (on the `post` override): The way you space things here is inconsistent with how you space things elsewhere in this file and in the other files (e.g. …).
Review comment: It seems like you have specific preferences on how to format things (using parentheses, vertically listing arguments, using double quotes) which differ from how I do things normally. Maybe we should have a custom linter to standardize things and keep the code style consistent?