Skip to content

Commit

Permalink
Neurips client (#1693)
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Aug 16, 2023
1 parent a2f2b68 commit 6d18584
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/helm/benchmark/static/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ models:
access: limited
num_parameters: 70000000000
release_date: 2022-01-01

# TODO: Remove Once we have configurable model names
- name: neurips/local
display_name: Local service
description: Local competition service
creator_organization: neurips
access: open
num_parameters: 1
release_date: 2021-12-01


# Anthropic
- name: anthropic/stanford-online-all-v4-s3
Expand Down
28 changes: 28 additions & 0 deletions src/helm/benchmark/window_services/http_model_window_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from .local_window_service import LocalWindowService
from .tokenizer_service import TokenizerService


# TODO: Remove Once we have configurable model names since this hardcodes the tokenizer name
class HTTPModelWindowServce(LocalWindowService):
def __init__(self, service: TokenizerService):
super().__init__(service)

@property
def max_sequence_length(self) -> int:
return 2048

@property
def max_request_length(self) -> int:
return self.max_sequence_length

@property
def end_of_text_token(self) -> str:
return "<|endoftext|>"

@property
def tokenizer_name(self) -> str:
return "neurips/local"

@property
def prefix_token(self) -> str:
return self.end_of_text_token
3 changes: 3 additions & 0 deletions src/helm/benchmark/window_services/window_service_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from .llama_window_service import LlamaWindowService, Llama2WindowService
from .window_service import WindowService
from .tokenizer_service import TokenizerService
from .http_model_window_service import HTTPModelWindowServce
from helm.proxy.clients.huggingface_client import get_huggingface_model_config
from helm.proxy.clients.remote_model_registry import get_remote_model

Expand Down Expand Up @@ -86,6 +87,8 @@ def get_window_service(model_name: str, service: TokenizerService) -> WindowServ
)
elif get_remote_model(model_name):
window_service = get_remote_window_service(service, model_name)
elif organization == "neurips":
window_service = HTTPModelWindowServce(service)
elif huggingface_model_config:
window_service = HuggingFaceWindowService(service=service, model_config=huggingface_model_config)
elif organization == "openai":
Expand Down
5 changes: 5 additions & 0 deletions src/helm/proxy/clients/auto_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from helm.proxy.retry import retry_request, NonRetriableException
from helm.proxy.clients.critique_client import CritiqueClient
from helm.proxy.clients.client import Client
from .http_model_client import HTTPModelClient
from helm.proxy.clients.huggingface_model_registry import get_huggingface_model_config
from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient

Expand Down Expand Up @@ -86,6 +87,8 @@ def _get_client(self, model: str) -> Client:
from helm.proxy.clients.huggingface_client import HuggingFaceClient

client = HuggingFaceClient(cache_config=cache_config)
elif organization == "neurips":
client = HTTPModelClient(cache_config=cache_config)
elif organization == "openai":
from helm.proxy.clients.chat_gpt_client import ChatGPTClient
from helm.proxy.clients.openai_client import OpenAIClient
Expand Down Expand Up @@ -216,6 +219,8 @@ def _get_tokenizer_client(self, tokenizer: str) -> Client:
from helm.proxy.clients.huggingface_client import HuggingFaceClient

client = HuggingFaceClient(cache_config=cache_config)
elif organization == "neurips":
client = HTTPModelClient(cache_config=cache_config)
elif organization in [
"bigscience",
"bigcode",
Expand Down
140 changes: 140 additions & 0 deletions src/helm/proxy/clients/http_model_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from dataclasses import asdict
from typing import Optional

from helm.common.cache import Cache, CacheConfig
from helm.common.request import (
Request,
RequestResult,
Sequence,
Token,
EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
)
from helm.common.tokenization_request import (
DecodeRequest,
DecodeRequestResult,
TokenizationRequest,
TokenizationRequestResult,
TokenizationToken,
)
from .client import Client, wrap_request_time

import requests


class HTTPModelClient(Client):
"""Implements a simple client for a model being served over HTTP."""

def __init__(
self,
cache_config: CacheConfig,
base_url: str = "http://localhost:8080",
timeout: int = 10,
do_cache: bool = False,
):
self.cache: Optional[Cache] = Cache(cache_config) if do_cache else None
self.base_url = base_url
self.timeout = timeout

def make_request(self, request: Request) -> RequestResult:
cache_key = asdict(request)
# This needs to match whatever we define in pedantic
if request.embedding:
return EMBEDDING_UNAVAILABLE_REQUEST_RESULT

raw_request = {
"prompt": request.prompt,
"temperature": 1e-7 if request.temperature == 0 else request.temperature,
"num_return_sequences": request.num_completions,
"max_new_tokens": request.max_tokens,
"top_p": request.top_p,
"echo_prompt": request.echo_prompt,
"top_k_per_token": request.top_k_per_token,
"stop_sequences": request.stop_sequences,
}

try:

def do_it():
url = f"{self.base_url}/process"
response = requests.post(url, json=raw_request, timeout=self.timeout)
response.raise_for_status()
response_data = response.json()
return response_data

if self.cache:
response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
else:
response, cached = do_it(), False

tokens = [
Token(text=token["text"], logprob=token["logprob"], top_logprobs=token["top_logprob"])
for token in response["tokens"]
]
completions = [Sequence(text=response["text"], logprob=response["logprob"], tokens=tokens)]

return RequestResult(
success=True,
cached=cached,
error=None,
completions=completions,
embedding=[],
request_time=response["request_time"],
)
except requests.exceptions.RequestException as e:
error: str = f"Request error: {e}"
return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
cache_key = asdict(request)
raw_request = {
"text": request.text,
"truncation": request.truncation,
"max_length": request.max_length,
}

try:

def do_it():
url = f"{self.base_url}/tokenize"
response = requests.post(url, json=raw_request)
response.raise_for_status()
response_data = response.json()
return response_data

if self.cache:
result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
else:
result, cached = do_it(), False
except Exception as e:
error: str = f"Local Model error: {e}"
return TokenizationRequestResult(success=False, cached=False, error=error, text="", tokens=[])

return TokenizationRequestResult(
success=True,
cached=cached,
text=request.text,
tokens=[TokenizationToken(value) for value in result["tokens"]],
request_time=result["request_time"],
)

def decode(self, request: DecodeRequest) -> DecodeRequestResult:
raise NotImplementedError("Not implemented yet.")
# cache_key = asdict(request)

# try:

# def do_it():
# url = f"{self.base_url}/decode"
# response = requests.post(url, json={"tokens": request.tokens})
# response.raise_for_status()
# response_data = response.json()
# return response_data

# result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
# except Exception as e:
# error: str = f"Local Model error: {e}"
# return DecodeRequestResult(success=False, cached=False, error=error, text="")

# return DecodeRequestResult(
# success=True, cached=cached, text=result["text"], request_time=result["request_time"]
# )
6 changes: 6 additions & 0 deletions src/helm/proxy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def engine(self) -> str:
# Over time, we should add more information there.

ALL_MODELS = [
# Local Model
Model(
group="neurips",
name="neurips/local",
tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, GPT2_TOKENIZER_TAG],
),
# AI21: https://studio.ai21.com/pricing
Model(
group="jurassic",
Expand Down

0 comments on commit 6d18584

Please sign in to comment.