rebase + comments
drisspg committed Aug 10, 2023
1 parent 826facb commit 02e5583
Showing 4 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/helm/benchmark/window_services/{neruips_local_window_service.py → httpmodel_window_service.py}
@@ -2,7 +2,7 @@
 from .tokenizer_service import TokenizerService


-class NeuripsWindowService(LocalWindowService):
+class HTTPModelWindowServce(LocalWindowService):
     def __init__(self, service: TokenizerService):
         super().__init__(service)

4 changes: 2 additions & 2 deletions src/helm/benchmark/window_services/window_service_factory.py
@@ -51,7 +51,7 @@
 from .llama_window_service import LlamaWindowService, Llama2WindowService
 from .window_service import WindowService
 from .tokenizer_service import TokenizerService
-from .neruips_local_window_service import NeuripsWindowService
+from .httpmodel_window_service import HTTPModelWindowServce
 from helm.proxy.clients.huggingface_client import get_huggingface_model_config
 from helm.proxy.clients.remote_model_registry import get_remote_model

@@ -88,7 +88,7 @@ def get_window_service(model_name: str, service: TokenizerService) -> WindowServ
         elif get_remote_model(model_name):
             window_service = get_remote_window_service(service, model_name)
         elif organization == "neurips":
-            window_service = NeuripsWindowService(service)
+            window_service = HTTPModelWindowServce(service)
         elif huggingface_model_config:
             window_service = HuggingFaceWindowService(service=service, model_config=huggingface_model_config)
         elif organization == "openai":
6 changes: 3 additions & 3 deletions src/helm/proxy/clients/auto_client.py
@@ -18,7 +18,7 @@
 from helm.proxy.retry import retry_request, NonRetriableException
 from helm.proxy.clients.critique_client import CritiqueClient
 from helm.proxy.clients.client import Client
-from .http_client import HTTPClient
+from .http_client import HTTPModelClient
 from helm.proxy.clients.huggingface_model_registry import get_huggingface_model_config
 from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient

@@ -88,7 +88,7 @@ def _get_client(self, model: str) -> Client:

                 client = HuggingFaceClient(cache_config=cache_config)
             elif organization == "neurips":
-                client = HTTPClient(cache_config=cache_config)
+                client = HTTPModelClient(cache_config=cache_config)
             elif organization == "openai":
                 from helm.proxy.clients.chat_gpt_client import ChatGPTClient
                 from helm.proxy.clients.openai_client import OpenAIClient
@@ -220,7 +220,7 @@ def _get_tokenizer_client(self, tokenizer: str) -> Client:

                 client = HuggingFaceClient(cache_config=cache_config)
             elif organization == "neurips":
-                client = HTTPClient(cache_config=cache_config)
+                client = HTTPModelClient(cache_config=cache_config)
             elif organization in [
                 "bigscience",
                 "bigcode",
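Both dispatch sites above, like the window-service factory earlier in this commit, pick an implementation from the organization prefix of the model name. Below is a minimal, self-contained sketch of that pattern, assuming HELM-style "organization/model" names such as "neurips/local"; FakeHTTPModelClient, FakeOpenAIClient, and the dict-based registry are hypothetical stand-ins, not the if/elif chain HELM actually uses in the diff.

from typing import Callable, Dict


# Hypothetical stand-ins for the real clients; only the dispatch pattern is the point.
class FakeHTTPModelClient:
    def __init__(self, cache_config: dict) -> None:
        self.cache_config = cache_config


class FakeOpenAIClient:
    def __init__(self, cache_config: dict) -> None:
        self.cache_config = cache_config


# Registry from organization prefix to client constructor.
CLIENTS: Dict[str, Callable[[dict], object]] = {
    "neurips": FakeHTTPModelClient,
    "openai": FakeOpenAIClient,
}


def get_client(model_name: str, cache_config: dict) -> object:
    """Pick a client based on the organization prefix of an 'organization/model' name."""
    organization = model_name.split("/")[0]
    if organization not in CLIENTS:
        raise ValueError(f"Unknown organization: {organization}")
    return CLIENTS[organization](cache_config)


if __name__ == "__main__":
    client = get_client("neurips/local", {"path": "cache.sqlite"})
    print(type(client).__name__)  # -> FakeHTTPModelClient

In the diff itself the heavier client imports (ChatGPTClient, OpenAIClient) stay inside their branch, so they are only imported when that organization is actually requested.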
21 changes: 10 additions & 11 deletions src/helm/proxy/clients/http_client.py
@@ -1,4 +1,5 @@
 from dataclasses import asdict
+from typing import Optional

 from helm.common.cache import Cache, CacheConfig
 from helm.common.request import (
@@ -20,8 +21,8 @@
 import requests


-class HTTPClient(Client):
-    """Implements a simple HTTP client."""
+class HTTPModelClient(Client):
+    """Implements a simple client for a model being served over HTTP."""

     def __init__(
         self,
@@ -31,25 +32,21 @@ def __init__(
         timeout: int = 10,
         do_cache: bool = False,
     ):
-        self.cache = Cache(cache_config)
+        self.do_cache = do_cache
+        self.cache: Optional[Cache] = Cache(cache_config) if do_cache else None
         self.base_url = base_url
         self.port = port
         self.timeout = timeout
-        self.do_cache = do_cache

     def make_request(self, request: Request) -> RequestResult:
         cache_key = asdict(request)
-        # This needs to match whatever we define in pedantic
         if request.embedding:
             return EMBEDDING_UNAVAILABLE_REQUEST_RESULT

-        # Only a single stop sequence is supported as we can only pass in a single value for `eos_token_id`
-        if len(request.stop_sequences) > 1:
-            raise ValueError("More than one stop sequence is not supported.")

         raw_request = {
             "prompt": request.prompt,
-            "temperature": 1e-7 if request.temperature == 0 else request.temperature,
+            "temperature": request.temperature,
             "num_return_sequences": request.num_completions,
             "max_new_tokens": request.max_tokens,
             "top_p": request.top_p,
@@ -112,8 +109,10 @@ def do_it():
                 response.raise_for_status()
                 response_data = response.json()
                 return response_data
-
-            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            if self.do_cache:
+                result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            else:
+                result, cached = do_it(), False
         except Exception as e:
             error: str = f"Local Model error: {e}"
             return TokenizationRequestResult(
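The main functional change in this file is that the Cache is only constructed when do_cache=True, which is why self.cache becomes Optional[Cache] and the new from typing import Optional appears; the lookup falls back to calling the request function directly with cached=False. Below is a minimal sketch of that pattern, not the actual HELM code: InMemoryCache and TinyHTTPishClient are hypothetical stand-ins for helm.common.cache.Cache and HTTPModelClient, and the (value, cached) return shape mirrors the result, cached pair in the diff.

from typing import Any, Callable, Dict, Optional, Tuple


class InMemoryCache:
    """Toy stand-in for helm.common.cache.Cache: get(key, compute) -> (value, cached)."""

    def __init__(self) -> None:
        self._data: Dict[Any, Any] = {}

    def get(self, key: Any, compute: Callable[[], Any]) -> Tuple[Any, bool]:
        if key in self._data:
            return self._data[key], True
        value = compute()
        self._data[key] = value
        return value, False


class TinyHTTPishClient:
    """Toy stand-in for HTTPModelClient showing the optional-cache wiring."""

    def __init__(self, do_cache: bool = False) -> None:
        # Mirror the diff: only allocate a cache when caching is requested.
        self.do_cache = do_cache
        self.cache: Optional[InMemoryCache] = InMemoryCache() if do_cache else None

    def make_request(self, prompt: str) -> Tuple[str, bool]:
        def do_it() -> str:
            # Placeholder for the real HTTP call to the model server.
            return f"completion for: {prompt}"

        if self.do_cache and self.cache is not None:
            return self.cache.get(prompt, do_it)
        return do_it(), False


if __name__ == "__main__":
    client = TinyHTTPishClient(do_cache=True)
    print(client.make_request("hello"))  # computed on the first call: cached=False
    print(client.make_request("hello"))  # second call hits the cache: cached=True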
