Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Mar 8, 2023
1 parent 1c21fb8 commit 48d91a9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
7 changes: 1 addition & 6 deletions src/helm/benchmark/window_services/window_service_factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from helm.proxy.models import (
get_model,
get_model_names_with_tag,
Model,
WIDER_CONTEXT_WINDOW_TAG,
)
from helm.proxy.models import get_model, get_model_names_with_tag, Model, WIDER_CONTEXT_WINDOW_TAG
from .ai21_window_service import AI21WindowService
from .anthropic_window_service import AnthropicWindowService
from .cohere_window_service import CohereWindowService
Expand Down
16 changes: 15 additions & 1 deletion src/helm/proxy/clients/auto_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .ai21_client import AI21Client
from .aleph_alpha_client import AlephAlphaClient
from .anthropic_client import AnthropicClient
from .chat_gpt_client import ChatGPTClient
from .cohere_client import CohereClient
from .together_client import TogetherClient
from .google_client import GoogleClient
Expand Down Expand Up @@ -62,12 +63,25 @@ def get_client(self, request: Request) -> Client:
cache_config: CacheConfig = self._build_cache_config(organization)

if organization == "openai":
# TODO: add ChatGPT to the OpenAIClient when it's supported.
# We're using a separate client for now since we're using an unofficial Python library.
# See https://github.com/acheong08/ChatGPT/wiki/Setup on how to get a valid session token.
chat_gpt_client: ChatGPTClient = ChatGPTClient(
session_token=self.credentials.get("chatGPTSessionToken", ""),
lock_file_path=os.path.join(self.cache_path, "ChatGPT.lock"),
# TODO: use `cache_config` above. Since this feature is still experimental,
# save queries and responses in a separate collection.
cache_config=self._build_cache_config("ChatGPT"),
tokenizer_client=self.get_tokenizer_client("huggingface"),
)

org_id = self.credentials.get("openaiOrgId", None)
client = OpenAIClient(
api_key=self.credentials["openaiApiKey"],
cache_config=cache_config,
org_id=org_id,
tokenizer_client=self.get_tokenizer_client("huggingface"),
chat_gpt_client=chat_gpt_client,
org_id=org_id,
)
elif organization == "AlephAlpha":
client = AlephAlphaClient(api_key=self.credentials["alephAlphaKey"], cache_config=cache_config)
Expand Down
7 changes: 7 additions & 0 deletions src/helm/proxy/clients/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DecodeRequestResult,
)
from .client import Client, truncate_sequence, wrap_request_time
from .chat_gpt_client import ChatGPTClient

ORIGINAL_COMPLETION_ATTRIBUTES = openai.api_resources.completion.Completion.__bases__

Expand All @@ -24,18 +25,24 @@ def __init__(
api_key: str,
cache_config: CacheConfig,
tokenizer_client: Client,
chat_gpt_client: Optional[ChatGPTClient] = None,
org_id: Optional[str] = None,
):
self.org_id: Optional[str] = org_id
self.api_key: str = api_key
self.api_base: str = "https://api.openai.com/v1"
self.cache = Cache(cache_config)
self.tokenizer_client: Client = tokenizer_client
self.chat_gpt_client: Optional[ChatGPTClient] = chat_gpt_client

def _is_chat_model_engine(self, model_engine: str):
return model_engine.startswith("gpt-3.5")

def make_request(self, request: Request) -> RequestResult:
if request.model_engine == "chat-gpt":
assert self.chat_gpt_client is not None
return self.chat_gpt_client.make_request(request)

raw_request: Dict[str, Any]
if request.embedding:
raw_request = {
Expand Down

0 comments on commit 48d91a9

Please sign in to comment.