Streaming working with nnsight streaming #60

Open · wants to merge 5 commits into base: dev
9 changes: 4 additions & 5 deletions compose/dev/.env
@@ -1,6 +1,7 @@
# RabbitMQ Ports
DEV_RABBITMQ_PORT=5673
RABBITMQ_INTERNAL_PORT=5672
# Broker Ports
DEV_BROKER_PORT=6379
BROKER_INTERNAL_PORT=6379
BROKER_PROTOCOL=redis://

# MinIO Ports
DEV_MINIO_PORT=27018
@@ -42,8 +43,6 @@ LOKI_INTERNAL_PORT=3100
N_DEVICES=$N_DEVICES

# Credentials and Other Configs
RABBITMQ_DEFAULT_USER=guest
RABBITMQ_DEFAULT_PASS=guest
GRAFANA_ADMIN_USER=admin
GRAFANA_ADMIN_PASSWORD=admin
HOST_IP=$HOST_IP
11 changes: 4 additions & 7 deletions compose/dev/docker-compose.yml
@@ -1,11 +1,8 @@
services:
rabbitmq:
image: rabbitmq:3.11.28
environment:
RABBITMQ_DEFAULT_USER: ${RABBITMQ_DEFAULT_USER}
RABBITMQ_DEFAULT_PASS: ${RABBITMQ_DEFAULT_PASS}
message_broker:
image: redis:latest
ports:
- ${DEV_RABBITMQ_PORT}:${RABBITMQ_INTERNAL_PORT}
- ${DEV_BROKER_PORT}:${BROKER_INTERNAL_PORT}

minio:
image: minio/minio:latest
@@ -48,7 +45,7 @@ services:
- ${DEV_API_PORT}:${API_INTERNAL_PORT}
environment:
OBJECT_STORE_URL: ${HOST_IP}:${DEV_MINIO_PORT}
RMQ_URL: amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@${HOST_IP}:${DEV_RABBITMQ_PORT}/
BROKER_URL: ${BROKER_PROTOCOL}@${HOST_IP}:${DEV_BROKER_PORT}/
WORKERS: 1
RAY_ADDRESS: ray://${HOST_IP}:${DEV_RAY_CLIENT_PORT}
LOKI_URL: http://${HOST_IP}:${DEV_LOKI_PORT}/loki/api/v1/push
194 changes: 147 additions & 47 deletions ray/deployments/base.py
@@ -1,24 +1,35 @@
import asyncio
import gc
from concurrent.futures import ThreadPoolExecutor, Future, TimeoutError
from functools import partial, wraps
import traceback
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from functools import wraps
from typing import Any, Dict

import ray
import socketio
import torch
from minio import Minio
from pydantic import BaseModel, ConfigDict
from torch.amp import autocast
from torch.cuda import (max_memory_allocated, memory_allocated,
reset_peak_memory_stats)
from torch.cuda import (
Review comment (Member): It seems like you have specific preferences on how to format things (using parentheses, vertically listing arguments, using double quotes), which differ from how I normally do things. Maybe we should add a custom linter to standardize things and keep the code style consistent? (A possible starting point is sketched after the import block below.)

max_memory_allocated,
memory_allocated,
reset_peak_memory_stats,
)
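As a possible starting point for the linter suggestion above, a pyproject.toml fragment along these lines would standardize quoting, import order, and line wrapping; the tool choice and settings are assumptions, not something agreed in this PR:

# Hypothetical pyproject.toml fragment; black and isort are one option among several.
[tool.black]
line-length = 88
target-version = ["py310"]

[tool.isort]
profile = "black"

Running both (for example via pre-commit) would settle the parentheses, quote, and spacing questions raised in these review comments automatically.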
from transformers import PreTrainedModel

from nnsight.contexts.backends.RemoteBackend import RemoteMixin
from nnsight.models.mixins import RemoteableMixin
from nnsight.schema.Request import StreamValueModel

from ...logging import load_logger
from ...metrics import NDIFGauge
from ...schema import (BackendRequestModel, BackendResponseModel,
BackendResultModel)
from ...schema import (
BackendRequestModel,
BackendResponseModel,
BackendResultModel,
)
from . import protocols


@@ -45,6 +56,8 @@ def __init__(
secure=False,
)

self.sio = socketio.SimpleClient(reconnection_attempts=10)
Review comment (Member): Environment variable?
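A minimal sketch of how the reconnection setting could come from an environment variable instead; the SIO_RECONNECTION_ATTEMPTS name and default are assumptions, not part of this PR:

import os

import socketio


def make_sio_client() -> socketio.SimpleClient:
    # Read the reconnection budget from the environment, falling back to the
    # value currently hardcoded in the diff above.
    attempts = int(os.environ.get("SIO_RECONNECTION_ATTEMPTS", "10"))
    return socketio.SimpleClient(reconnection_attempts=attempts)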


self.logger = load_logger(
service_name=str(self.__class__), logger_name="ray.serve"
)
@@ -84,7 +97,7 @@ def __init__(
dtype: str | torch.dtype,
*args,
extra_kwargs: Dict[str, Any] = {},
**kwargs
**kwargs,
) -> None:

super().__init__(*args, **kwargs)
@@ -103,42 +116,65 @@ def __init__(
device_map=device_map,
dispatch=dispatch,
torch_dtype=dtype,
**extra_kwargs
**extra_kwargs,
)

if dispatch:
self.model._model.requires_grad_(False)

torch.cuda.empty_cache()

async def __call__(self, request: BackendRequestModel) -> Any:
self.request: BackendRequestModel

protocols.LogProtocol.set(lambda *args: self.log(*args))
protocols.ServerStreamingDownloadProtocol.set(
lambda *args: self.stream_send(*args)
)

protocols.ServerStreamingUploadProtocol.set(
lambda *args: self.stream_receive(*args)
)

async def __call__(self, request: BackendRequestModel) -> None:
"""Executes the model service pipeline:

1.) Pre-processing
2.) Execution
3.) Post-processing
4.) Cleanup

Args:
request (BackendRequestModel): Request.
"""

self.request = weakref.proxy(request)

try:

result = None

protocols.LogProtocol.put(partial(self.log, request=request))
self.pre(request)

self.pre(request)

with autocast(device_type="cuda", dtype=torch.get_default_dtype()):

result = self.execute(request)

if isinstance(result, Future):
result = result.result(timeout=self.execution_timeout)

self.post(request, result)

except TimeoutError as e:

exception = Exception(f"Job took longer than timeout: {self.execution_timeout} seconds")

self.exception(request, exception)


exception = Exception(
f"Job took longer than timeout: {self.execution_timeout} seconds"
)

self.exception(exception)

except Exception as e:

self.exception(request, e)
self.exception(e)

finally:

@@ -163,17 +199,29 @@ def check_health(self):
### ABSTRACT METHODS #################################

def pre(self, request: BackendRequestModel):
"""Logic to execute before execution.

Args:
request (BackendRequestModel): Request.
"""

request.create_response(
self.respond(
status=BackendResponseModel.JobStatus.RUNNING,
description="Your job has started running.",
logger=self.logger,
gauge=self.gauge,
).respond(self.api_url, self.object_store)

)

request.object = ray.get(request.object)

def execute(self, request: BackendRequestModel):
def execute(self, request: BackendRequestModel) -> Any:
"""Execute request.

Args:
request (BackendRequestModel): Request.

Returns:
Any: Result.
"""

# For tracking peak GPU usage
reset_peak_memory_stats()
@@ -190,7 +238,13 @@ def execute(self, request: BackendRequestModel):

return result, obj, gpu_mem

def post(self, request: BackendRequestModel, result: Any):
def post(self, request: BackendRequestModel, result: Any) -> None:
"""Logic to execute after execution with result from `.execute`.

Args:
request (BackendRequestModel): Request.
result (Any): Result.
"""

obj: RemoteMixin = result[1]
gpu_mem: int = result[2]
@@ -201,40 +255,86 @@ def post(self, request: BackendRequestModel, result: Any):
value=obj.remote_backend_postprocess_result(result),
).save(self.object_store)

# Send COMPLETED response.
request.create_response(
self.respond(
status=BackendResponseModel.JobStatus.COMPLETED,
description="Your job has been completed.",
logger=self.logger,
gauge=self.gauge,
gpu_mem=gpu_mem,
).respond(self.api_url, self.object_store)
)

def exception(self, request: BackendRequestModel, exception: Exception):
def exception(self, exception: Exception):
"""Logic to execute of there was an exception.

request.create_response(
status=BackendResponseModel.JobStatus.ERROR,
description=str(exception),
logger=self.logger,
gauge=self.gauge,
).respond(self.api_url, self.object_store)
Args:
exception (Exception): Exception.
"""

description = traceback.format_exc()

self.respond(
status=BackendResponseModel.JobStatus.ERROR, description=description
)

def cleanup(self):
"""Logic to execute to clean up memory after execution result is post-processed."""

if self.sio.connected:

self.sio.disconnect()

self.model._model.zero_grad()

gc.collect()

torch.cuda.empty_cache()

def log(self, data: Any, request: BackendRequestModel):
def log(self, *data):
"""Logic to execute for logging data during execution.

Args:
data (Any): Data to log.
"""

description = "".join([str(_data) for _data in data])

self.respond(
status=BackendResponseModel.JobStatus.LOG, description=description
)

def stream_send(self, data: Any):
"""Logic to execute to stream data back during execution.

Args:
data (Any): Data to stream back.
"""

self.respond(status=BackendResponseModel.JobStatus.STREAM, data=data)

def stream_receive(self, *args):

return StreamValueModel(**self.sio.receive()[1]).deserialize(self.model)

def stream_connect(self):

if self.sio.client is None:

self.sio.connected = False

self.sio.connect(
f"{self.api_url}?job_id={self.request.id}",
socketio_path="/ws/socket.io",
transports=["websocket"],
wait_timeout=10,
)

def respond(self, **kwargs) -> None:

if self.request.session_id is not None:

self.stream_connect()

request.create_response(
status=BackendResponseModel.JobStatus.LOG,
description=str(data),
logger=self.logger,
gauge=self.gauge,
).respond(self.api_url, self.object_store)
self.request.create_response(
**kwargs, logger=self.logger, gauge=self.gauge
).respond(self.sio, self.object_store)


class BaseModelDeploymentArgs(BaseDeploymentArgs):
16 changes: 10 additions & 6 deletions ray/deployments/distributed_model.py
Review comment (Member): Cannot comment directly because it is from a previous PR, but this shouldn't be hardcoded:

@serve.deployment(
    ray_actor_options={"num_gpus": 1, "num_cpus": 2},
    health_check_timeout_s=1200,
)
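A hedged sketch of how the resource request could be made configurable rather than hardcoded; the environment-variable names, defaults, and the ExampleDeployment stand-in are illustrative assumptions, not code from this repository:

import os

from ray import serve

# Resource requests pulled from the environment; @serve.deployment accepts the
# same ray_actor_options and health_check_timeout_s fields shown in the comment above.
NUM_GPUS = float(os.environ.get("DEPLOYMENT_NUM_GPUS", "1"))
NUM_CPUS = float(os.environ.get("DEPLOYMENT_NUM_CPUS", "2"))


@serve.deployment(
    ray_actor_options={"num_gpus": NUM_GPUS, "num_cpus": NUM_CPUS},
    health_check_timeout_s=1200,
)
class ExampleDeployment:  # stand-in for the real deployment class
    pass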

Review comment (Member): Additionally, I noticed the following in ModelDeployment.__init__():

            extra_kwargs={"meta_buffers": False, "patch_llama_scan": False},

Is this 405b-specific, in that this logic would not allow deploying a non-llama distributed model? If so, is there a way to pass this in only for 405b? Otherwise, should this be indicated more clearly?
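One possible shape for making the llama-specific flags opt-in rather than hardcoded, sketched with illustrative class and field names (not the actual schema in this repository):

from typing import Any, Dict

from pydantic import BaseModel


class ExampleDeploymentArgs(BaseModel):
    # Model-specific loading kwargs travel with the deployment args, so a
    # non-llama model simply omits them.
    model_key: str
    extra_kwargs: Dict[str, Any] = {}


# A 405b deployment would then pass the flags explicitly:
args_405b = ExampleDeploymentArgs(
    model_key="example-llama-405b",  # placeholder key, not a real identifier
    extra_kwargs={"meta_buffers": False, "patch_llama_scan": False},
)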

@@ -197,17 +197,21 @@ def pre(self, request: BackendRequestModel):

torch.distributed.barrier()

def post(self, request: BackendRequestModel, result: Any):
def post(self, *args, **kwargs):
Review comment (Member): The way you space things in this file is inconsistent with how you space things in the other files (e.g. base.py).

if self.head:
super().post(request, result)
super().post(*args, **kwargs)

def exception(self, request: BackendRequestModel, exception: Exception):
def exception(self, *args, **kwargs):
if self.head:
super().exception(request, exception)
super().exception(*args, **kwargs)

def log(self, data: Any, request: BackendRequestModel):
def log(self, *args, **kwargs):
if self.head:
super().log(data, request)
super().log(*args, **kwargs)

def stream_send(self, *args, **kwargs):
if self.head:
super().stream_send(*args, **kwargs)


class DistributedModelDeploymentArgs(BaseModelDeploymentArgs):
11 changes: 9 additions & 2 deletions ray/deployments/model.py
@@ -1,9 +1,10 @@
import os
from ray import serve

from ...schema.Request import BackendRequestModel

from .base import BaseModelDeployment, BaseModelDeploymentArgs, threaded

from ..util import set_cuda_env_var
class ThreadedModelDeployment(BaseModelDeployment):

@threaded
@@ -18,7 +19,13 @@ def execute(self, request: BackendRequestModel):
health_check_timeout_s=1200,
)
class ModelDeployment(ThreadedModelDeployment):
pass

def __init__(self, *args, **kwargs):

if os.environ.get("CUDA_VISIBLE_DEVICES","") == "":
set_cuda_env_var()

super().__init__(*args, **kwargs)

def app(args: BaseModelDeploymentArgs) -> serve.Application:
