Streaming working with nnsight streaming #60

Open

wants to merge 5 commits into base: dev
Conversation

JadenFiotto-Kaufman (Member)

No description provided.


def __new__(cls, service: str):
    """Singleton pattern to ensure only one instance of the gauge per service."""
    if service not in cls._instances:
        instance = super(NDIFGauge, cls).__new__(cls)
        instance.service = service
        instance._gauge = instance._initialize_gauge()
        if service != 'ray': # Only initialize the network gauge if the service is not 'ray'
        if (

Why this notation?
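
For context, the notation in question appears to be the parenthesized multi-line condition style. A generic, runnable illustration follows; it is not taken from the diff, and some_other_condition is a placeholder:

service = "app"
some_other_condition = True  # placeholder value, not from the PR

# Parentheses let a long boolean expression span several lines without
# backslash continuations; autoformatters such as black emit this style.
if (
    service != "ray"
    and some_other_condition
):
    pass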

import torch
from minio import Minio
from pydantic import BaseModel, ConfigDict
from torch.amp import autocast
from torch.cuda import (max_memory_allocated, memory_allocated,
                        reset_peak_memory_stats)
from torch.cuda import (

It seems like you have specific preferences on how to format things (using parentheses, vertically listing arguments, using double quotes) which differ from how I do things normally. Maybe we should have a custom linter to standardize things and keep the code style consistent?

@@ -45,6 +56,8 @@ def __init__(
            secure=False,
        )

        self.sio = socketio.SimpleClient(reconnection_attempts=10)

Environment variable?
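
If this suggestion is taken, a minimal sketch could look like the following; the NDIF_SIO_RECONNECTION_ATTEMPTS variable name is an assumption, not something defined in this PR:

import os

import socketio

# Hypothetical env var; falls back to the currently hardcoded value of 10.
RECONNECTION_ATTEMPTS = int(os.environ.get("NDIF_SIO_RECONNECTION_ATTEMPTS", "10"))

sio = socketio.SimpleClient(reconnection_attempts=RECONNECTION_ATTEMPTS)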


        self.sio.connect(
            f"{self.api_url}?job_id={request.id}",
            socketio_path="/ws/socket.io",

Environment variables? (socketio path and wait timeout)
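
A sketch of what that could look like; the variable names and the placeholder URL are assumptions, not part of the PR:

import os

import socketio

# Hypothetical env vars, defaulting to the literals currently in the code.
SOCKETIO_PATH = os.environ.get("NDIF_SOCKETIO_PATH", "/ws/socket.io")
WAIT_TIMEOUT = float(os.environ.get("NDIF_SIO_WAIT_TIMEOUT", "10"))

sio = socketio.SimpleClient(reconnection_attempts=10)
sio.connect(
    "http://localhost:80?job_id=example",  # placeholder; the real code builds this from self.api_url and request.id
    socketio_path=SOCKETIO_PATH,
    transports=["websocket"],
    wait_timeout=WAIT_TIMEOUT,
)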


    await _blocking_response(response)

@sm.on("stream_upload")
async def stream_upload(session_id: str, value: Dict):

Doesn't seem like session_id is used here (along with many of the sm.on decorators). Is this just because the session manager always passes session_id as the first arg?
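
If the first positional argument is indeed always supplied by the session manager but never used, one common convention is to underscore it so linters do not flag it. A sketch only; in the PR the function is registered with the @sm.on("stream_upload") decorator, omitted here to keep the snippet self-contained:

from typing import Dict

async def stream_upload(_session_id: str, value: Dict) -> None:
    # The leading underscore signals that the first argument, always supplied
    # by the session manager, is intentionally unused here.
    ...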

@@ -31,31 +42,50 @@ class NumericJobStatus(Enum):
    COMPLETED = 4
    LOG = 5
    ERROR = 6
    STREAM = 7

We might want to add a status for "AUTHENTICATED" between "RECEIVED" and "APPROVED"
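
If that status were added, the enum might be renumbered roughly as follows. This is a sketch only: it assumes the members not shown in the hunk above are RECEIVED, APPROVED, and RUNNING, and it does not address whether renumbering the values breaks any consumers of the numeric codes.

from enum import Enum

class NumericJobStatus(Enum):
    RECEIVED = 1
    AUTHENTICATED = 2  # proposed: sits between RECEIVED and APPROVED
    APPROVED = 3
    RUNNING = 4        # assumed member, not visible in the hunk
    COMPLETED = 5
    LOG = 6
    ERROR = 7
    STREAM = 8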

socketio_path="/ws/socket.io",
transports=["websocket"],
wait_timeout=10,
)

async def __call__(self, request: BackendRequestModel):

async def __call__(self, request: BackendRequestModel) -> None:

@@ -205,17 +205,21 @@ def pre(self, request: BackendRequestModel):

        torch.distributed.barrier()

    def post(self, request: BackendRequestModel, result: Any):
    def post(self, *args, **kwargs):

The way you space things here is inconsistent both with how you space things elsewhere in this file and with how you space things in the other files (e.g. base.py).

params = environ.get("QUERY_STRING")
params = dict(x.split("=") for x in params.split("&"))

if "job_id" in params:
@MichaelRipa MichaelRipa Sep 29, 2024

I think params.keys() makes more sense from a readability standpoint. Also, what happens if job_id is not in params? Should it log an error or a debug message?
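
A sketch of what handling the missing key could look like. The helper name and log level are placeholders, and the manual split is swapped for urllib.parse.parse_qsl, which also copes with empty query strings:

import logging
from typing import Optional
from urllib.parse import parse_qsl

logger = logging.getLogger(__name__)

def job_id_from_query_string(query_string: str) -> Optional[str]:
    # parse_qsl tolerates empty strings and stray fragments better than
    # splitting on "&" and "=" by hand.
    params = dict(parse_qsl(query_string))
    if "job_id" not in params:
        logger.debug("No job_id in query string: %r", query_string)
        return None
    return params["job_id"]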

@sm.on("stream_upload")
async def stream_upload(session_id: str, value: Dict):

    value_model = StreamValueModel(**value)

You create a pydantic BaseModel instance from the input value JSON just to create a model key? Isn't there a clearer and more efficient way of doing this?
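
If the only thing needed from the payload is the model key, plain dictionary access would avoid building the full model. A sketch; the model_key field name is inferred from the comment above, and whether pydantic validation is actually wanted here for error reporting is a separate question:

from typing import Dict

def model_key_from_payload(value: Dict) -> str:
    # Hypothetical helper: reads the single field directly instead of
    # validating the whole payload. A KeyError signals a malformed payload.
    return value["model_key"]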


This whole file needs docstrings IMO; it is very abstract. Also, I feel like this should be in schema.


Cannot comment directly because it is from a previous PR, but this shouldn't be hardcoded:

@serve.deployment(
    ray_actor_options={"num_gpus": 1, "num_cpus": 2},
    health_check_timeout_s=1200,
)
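
One way to avoid hardcoding these values is to read them from the environment (or a config file) at import time. A sketch only; the NDIF_* variable names are placeholders and are not defined anywhere in this PR:

import os

from ray import serve

@serve.deployment(
    ray_actor_options={
        "num_gpus": float(os.environ.get("NDIF_DEPLOYMENT_NUM_GPUS", "1")),
        "num_cpus": float(os.environ.get("NDIF_DEPLOYMENT_NUM_CPUS", "2")),
    },
    health_check_timeout_s=int(os.environ.get("NDIF_HEALTH_CHECK_TIMEOUT_S", "1200")),
)
class ModelDeployment:
    ...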


Additionally, I noticed the following in ModelDeployment.__init__():

            extra_kwargs={"meta_buffers": False, "patch_llama_scan": False},

Is this 405b-specific, in that this logic will not allow you to deploy a non-llama distributed model? If so, is there a way to have this passed in only for 405b? Otherwise, should this be indicated more clearly?
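
If these kwargs really only apply to the Llama 405B deployment, one option is to gate them on the model key instead of passing them unconditionally. A rough sketch with hypothetical names; LLAMA_405B_KEYS and build_extra_kwargs are not from the PR, and the key strings are illustrative:

# Hypothetical allow-list of model keys that need the llama-specific patches.
LLAMA_405B_KEYS = {
    "meta-llama/Meta-Llama-3.1-405B",
    "meta-llama/Meta-Llama-3.1-405B-Instruct",
}

def build_extra_kwargs(model_key: str) -> dict:
    # Only 405B llama deployments disable meta_buffers / patch_llama_scan;
    # every other distributed model gets no extra kwargs.
    if model_key in LLAMA_405B_KEYS:
        return {"meta_buffers": False, "patch_llama_scan": False}
    return {}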

@AdamBelfki3 self-requested a review September 30, 2024 14:30
@MichaelRipa mentioned this pull request Sep 30, 2024