add chat_completion (#31)
* add chat_completion in BaseLLMAPIClient + AnthropicClient

* add get_chat_tokens_count in BaseLLMAPIClient + AnthropicClient + OpenAIClient
uripeled2 authored Jul 22, 2023
1 parent 980f72e commit 66af541
Showing 8 changed files with 238 additions and 47 deletions.
33 changes: 29 additions & 4 deletions README.md
@@ -12,12 +12,14 @@ any flexibility (API params, endpoints etc.). *We also provide sync version, see
more details below in Usage section.

## Base Interface
The package exposes two simple interfaces for communicating with LLMs (In the future, we
The package exposes two simple interfaces for seamless integration with LLMs (In the future, we
will expand the interface to support more tasks like list models, edits, etc.):
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional
from enum import Enum
from dataclasses_json import dataclass_json, config
from aiohttp import ClientSession


@@ -30,6 +32,20 @@ class BaseLLMClient(ABC):
raise NotImplementedError()


class Role(Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"


@dataclass_json
@dataclass
class ChatMessage:
role: Role = field(metadata=config(encoder=lambda role: role.value, decoder=Role))
content: str
name: Optional[str] = field(default=None, metadata=config(exclude=lambda name: name is None))
example: bool = field(default=False, metadata=config(exclude=lambda _: True))


@dataclass
class LLMAPIClientConfig:
@@ -49,8 +65,15 @@ class BaseLLMAPIClient(BaseLLMClient, ABC):
temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
raise NotImplementedError()

async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
max_tokens: int = 16, model: Optional[str] = None, **kwargs) -> list[str]:
raise NotImplementedError()

async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
raise NotImplementedError()

async def get_chat_tokens_count(self, messages: list[ChatMessage], **kwargs) -> int:
raise NotImplementedError()
```

## Requirements
@@ -109,10 +132,12 @@ async def main():
llm_client = OpenAIClient(LLMAPIClientConfig(OPENAI_API_KEY, session, default_model="text-davinci-003",
headers={"OpenAI-Organization": OPENAI_ORG_ID})) # The headers are optional
text = "This is indeed a test"
messages = [ChatMessage(role=Role.USER, content="Hello!"),
ChatMessage(role=Role.SYSTEM, content="Hi there! How can I assist you today?")]

print("number of tokens:", await llm_client.get_tokens_count(text)) # 5
print("generated chat:", await llm_client.chat_completion(
messages=[ChatMessage(role=Role.USER, content="Hello!")], model="gpt-3.5-turbo")) # ['Hi there! How can I assist you today?']
print("number of tokens for chat completion:", await llm_client.get_chat_tokens_count(messages, model="gpt-3.5-turbo")) # 23
print("generated chat:", await llm_client.chat_completion(messages, model="gpt-3.5-turbo")) # ['Hi there! How can I assist you today?']
print("generated text:", await llm_client.text_completion(text)) # [' string\n\nYes, this is a test string. Test strings are used to']
print("generated embedding:", await llm_client.embedding(text)) # [0.0023064255, -0.009327292, ...]
```
@@ -190,7 +215,7 @@ Contributions are welcome! Please check out the todos below, and feel free to op
- [ ] Cohere
- [x] Add support for more functions via LLMs
- [x] embeddings
- [ ] chat
- [x] chat
- [ ] list models
- [ ] edits
- [ ] more
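
As a rough companion to the README example above, the same chat calls should also work against the new AnthropicClient. The sketch below is an assumption based on this diff, not README content: the API key and model name are placeholders, and depending on your setup you may also need to pass an `anthropic-version` header via `headers`.

```python
import asyncio

from aiohttp import ClientSession

from llm_client.llm_api_client.anthropic_client import AnthropicClient
from llm_client.llm_api_client.base_llm_api_client import LLMAPIClientConfig, ChatMessage, Role

ANTHROPIC_API_KEY = "..."  # placeholder


async def main():
    async with ClientSession() as session:
        # default_model is a placeholder; use whichever Claude completion model you have access to
        llm_client = AnthropicClient(LLMAPIClientConfig(ANTHROPIC_API_KEY, session,
                                                        default_model="claude-instant-1"))
        messages = [ChatMessage(role=Role.USER, content="Hello!")]
        print(await llm_client.get_chat_tokens_count(messages))           # prompt token count
        print(await llm_client.chat_completion(messages, max_tokens=10))  # ['...completion text...']


asyncio.run(main())
```
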
6 changes: 3 additions & 3 deletions llm_client/__init__.py
@@ -1,10 +1,10 @@
__version__ = "0.7.0"
__version__ = "0.8.0"

from llm_client.base_llm_client import BaseLLMClient

# load api clients
try:
from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig, ChatMessage, Role
from llm_client.llm_api_client.llm_api_client_factory import LLMAPIClientFactory, LLMAPIClientType
# load base-api clients
try:
@@ -15,7 +15,7 @@
pass
# load apis with different dependencies
try:
from llm_client.llm_api_client.openai_client import OpenAIClient, ChatMessage, Role
from llm_client.llm_api_client.openai_client import OpenAIClient
except ImportError:
pass
try:
32 changes: 31 additions & 1 deletion llm_client/llm_api_client/anthropic_client.py
@@ -2,7 +2,7 @@

from anthropic import AsyncAnthropic

from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig, ChatMessage, Role
from llm_client.consts import PROMPT_KEY

COMPLETE_PATH = "complete"
@@ -13,6 +13,11 @@
VERSION_HEADER = "anthropic-version"
ACCEPT_VALUE = "application/json"
MAX_TOKENS_KEY = "max_tokens_to_sample"
USER_PREFIX = "Human:"
ASSISTANT_PREFIX = "Assistant:"
START_PREFIX = "\n\n"
SYSTEM_START_PREFIX = "<admin>"
SYSTEM_END_PREFIX = "</admin>"


class AnthropicClient(BaseLLMAPIClient):
@@ -26,6 +31,10 @@ def __init__(self, config: LLMAPIClientConfig):
self._headers[ACCEPT_HEADER] = ACCEPT_VALUE
self._headers[AUTH_HEADER] = self._api_key

async def chat_completion(self, messages: list[ChatMessage], model: Optional[str] = None,
max_tokens: Optional[int] = None, temperature: float = 1, **kwargs) -> list[str]:
return await self.text_completion(self.messages_to_text(messages), model, max_tokens, temperature, **kwargs)

async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
temperature: float = 1, top_p: Optional[float] = None,
**kwargs) -> \
@@ -45,5 +54,26 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, max_to
response_json = await response.json()
return [response_json[COMPLETIONS_KEY]]

async def get_chat_tokens_count(self, messages: list[ChatMessage], **kwargs) -> int:
return await self.get_tokens_count(self.messages_to_text(messages), **kwargs)

async def get_tokens_count(self, text: str, **kwargs) -> int:
return await self._anthropic.count_tokens(text)

def messages_to_text(self, messages: list[ChatMessage]) -> str:
prompt = START_PREFIX
prompt += START_PREFIX.join(map(self._message_to_prompt, messages))
if messages[-1].role != Role.ASSISTANT:
prompt += START_PREFIX
prompt += self._message_to_prompt(ChatMessage(role=Role.ASSISTANT, content=""))
return prompt.rstrip()

@staticmethod
def _message_to_prompt(message: ChatMessage) -> str:
if message.role == Role.USER:
return f"{USER_PREFIX} {message.content}"
if message.role == Role.ASSISTANT:
return f"{ASSISTANT_PREFIX} {message.content}"
if message.role == Role.SYSTEM:
return f"{USER_PREFIX} {SYSTEM_START_PREFIX}{message.content}{SYSTEM_END_PREFIX}"
raise ValueError(f"Unknown role: {message.role}")
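
To make the Human/Assistant formatting concrete, here is a small sketch of what `messages_to_text` produces for a system + user exchange; the expected string is the one asserted in the tests further down, and the import path is the one added in this commit.

```python
from llm_client.llm_api_client.base_llm_api_client import ChatMessage, Role

messages = [ChatMessage(role=Role.SYSTEM, content="Be nice!"),
            ChatMessage(role=Role.USER, content="Why is the sky blue?")]

# AnthropicClient.messages_to_text(messages) renders system messages as
# <admin>-wrapped Human turns and appends an empty Assistant turn, yielding:
#
#   "\n\nHuman: <admin>Be nice!</admin>\n\nHuman: Why is the sky blue?\n\nAssistant:"
```
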
25 changes: 25 additions & 0 deletions llm_client/llm_api_client/base_llm_api_client.py
@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional

from dataclasses_json import dataclass_json, config

try:
from aiohttp import ClientSession
except ImportError:
@@ -11,6 +14,21 @@
from llm_client.consts import MODEL_KEY


class Role(Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"


@dataclass_json
@dataclass
class ChatMessage:
role: Role = field(metadata=config(encoder=lambda role: role.value, decoder=Role))
content: str
name: Optional[str] = field(default=None, metadata=config(exclude=lambda name: name is None))
example: bool = field(default=False, metadata=config(exclude=lambda _: True))


@dataclass
class LLMAPIClientConfig:
api_key: str
@@ -33,9 +51,16 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, max_to
temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
raise NotImplementedError()

async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
max_tokens: int = 16, model: Optional[str] = None, **kwargs) -> list[str]:
raise NotImplementedError()

async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
raise NotImplementedError()

async def get_chat_tokens_count(self, messages: list[ChatMessage], **kwargs) -> int:
raise NotImplementedError()

def _set_model_in_kwargs(self, kwargs, model: Optional[str]) -> None:
if model is not None:
kwargs[MODEL_KEY] = model
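
A quick sketch of how the `dataclasses_json` metadata above plays out when a message is serialized, which is what `OpenAIClient.chat_completion` relies on via `message.to_dict()` (the output shown is an expectation, not taken from the diff):

```python
from llm_client.llm_api_client.base_llm_api_client import ChatMessage, Role

msg = ChatMessage(role=Role.USER, content="Hello!")
print(msg.to_dict())  # {'role': 'user', 'content': 'Hello!'}
# The Role enum is encoded by its value, a None `name` is dropped, and
# `example` is always excluded, so the payload matches the shape the
# OpenAI chat endpoint expects.
```
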
77 changes: 58 additions & 19 deletions llm_client/llm_api_client/openai_client.py
@@ -1,31 +1,24 @@
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from typing import Optional

import openai
import tiktoken
from dataclasses_json import dataclass_json, config
from tiktoken import Encoding

from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig, ChatMessage
from llm_client.consts import PROMPT_KEY

INPUT_KEY = "input"


class Role(Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"


@dataclass_json
@dataclass
class ChatMessage:
role: Role = field(metadata=config(encoder=lambda role: role.value, decoder=Role))
content: str
name: Optional[str] = field(default=None, metadata=config(exclude=lambda name: name is None))
MODEL_NAME_TO_TOKENS_PER_MESSAGE_AND_TOKENS_PER_NAME = {
"gpt-3.5-turbo-0613": (3, 1),
"gpt-3.5-turbo-16k-0613": (3, 1),
"gpt-4-0314": (3, 1),
"gpt-4-32k-0314": (3, 1),
"gpt-4-0613": (3, 1),
"gpt-4-32k-0613": (3, 1),
# every message follows <|start|>{role/name}\n{content}<|end|>\n, if there's a name, the role is omitted
"gpt-3.5-turbo-0301": (4, -1),
}


class OpenAIClient(BaseLLMAPIClient):
@@ -46,7 +39,8 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, temper
return [choice.text for choice in completions.choices]

async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
max_tokens: int = 16, top_p: float = 1, model: Optional[str] = None, **kwargs) -> list[str]:
max_tokens: int = 16, top_p: float = 1, model: Optional[str] = None, **kwargs) \
-> list[str]:
self._set_model_in_kwargs(kwargs, model)
kwargs["messages"] = [message.to_dict() for message in messages]
kwargs["temperature"] = temperature
@@ -66,6 +60,51 @@ async def get_tokens_count(self, text: str, model: Optional[str] = None, **kwarg
model = self._default_model
return len(self._get_relevant_tokeniser(model).encode(text))

async def get_chat_tokens_count(self, messages: list[ChatMessage], model: Optional[str] = None) -> int:
"""
This is based on:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
model = self._get_model_name_for_tokeniser(model)
encoding = self._get_relevant_tokeniser(model)
tokens_per_message, tokens_per_name = MODEL_NAME_TO_TOKENS_PER_MESSAGE_AND_TOKENS_PER_NAME[model]
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
num_tokens += len(encoding.encode(message.content))
num_tokens += len(encoding.encode(message.role.value))
if message.name:
num_tokens += len(encoding.encode(message.name))
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

def _get_model_name_for_tokeniser(self, model: Optional[str] = None) -> str:
if model is None:
model = self._default_model
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
return model
elif model == "gpt-3.5-turbo-0301":
return model
elif "gpt-3.5-turbo" in model:
print("Warning: gpt-3.5-turbo may update over time. Returning tokeniser assuming gpt-3.5-turbo-0613.")
return "gpt-3.5-turbo-0613"
elif "gpt-4" in model:
print("Warning: gpt-4 may update over time. Returning tokeniser assuming gpt-4-0613.")
return "gpt-4-0613"
else:
raise NotImplementedError(
f"""not implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)

@staticmethod
@lru_cache(maxsize=40)
def _get_relevant_tokeniser(model: str) -> Encoding:
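
For intuition, a small worked sketch of the counting rule above, using the same two messages as the README example; the per-string token splits come from tiktoken's encoding for the model, so the exact numbers are tokeniser-dependent.

```python
import tiktoken

from llm_client.llm_api_client.base_llm_api_client import ChatMessage, Role

# Mirror of get_chat_tokens_count for gpt-3.5-turbo-0613
# (tokens_per_message = 3, no `name` set, plus 3 reply-priming tokens).
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")
messages = [ChatMessage(role=Role.USER, content="Hello!"),
            ChatMessage(role=Role.SYSTEM, content="Hi there! How can I assist you today?")]

num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
for message in messages:
    num_tokens += 3  # tokens_per_message
    num_tokens += len(encoding.encode(message.role.value))
    num_tokens += len(encoding.encode(message.content))

print(num_tokens)  # 23, matching the README example above
```
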
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -20,7 +20,10 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
dependencies = ["aiohttp >=3.0.0,<4.0.0"]
dependencies = [
"aiohttp >=3.0.0,<4.0.0",
"dataclasses_json >= 0.5.0"
]
dynamic = ["version"]

[project.urls]
@@ -37,7 +40,6 @@ test = [
openai = [
"openai >=0.27.4",
"tiktoken >=0.3.3",
"dataclasses_json >= 0.5.0"
]
huggingface = [
"transformers >= 4.0.0"
50 changes: 48 additions & 2 deletions tests/llm_api_client/anthropic_client/test_anthropic_client.py
@@ -1,9 +1,13 @@
from unittest.mock import AsyncMock

import pytest

from llm_client import LLMAPIClientFactory, LLMAPIClientType
from llm_client import LLMAPIClientFactory, LLMAPIClientType, ChatMessage
from llm_client.consts import PROMPT_KEY, MODEL_KEY
from llm_client.llm_api_client.anthropic_client import AUTH_HEADER, COMPLETIONS_KEY, MAX_TOKENS_KEY, ACCEPT_HEADER, \
ACCEPT_VALUE, VERSION_HEADER, AnthropicClient
ACCEPT_VALUE, VERSION_HEADER, AnthropicClient, USER_PREFIX, ASSISTANT_PREFIX, START_PREFIX, SYSTEM_START_PREFIX, \
SYSTEM_END_PREFIX
from llm_client.llm_api_client.base_llm_api_client import Role


@pytest.mark.asyncio
@@ -14,6 +18,48 @@ async def test_get_llm_api_client__with_anthropic(config):

assert isinstance(actual, AnthropicClient)

@pytest.mark.asyncio
async def test_chat_completion_sanity(llm_client):
text_completion_mock = AsyncMock(return_value=["completion text"])
llm_client.text_completion = text_completion_mock

actual = await llm_client.chat_completion(messages=[ChatMessage(Role.USER, "Why is the sky blue?")], max_tokens=10)

assert actual == ["completion text"]
text_completion_mock.assert_awaited_once_with(f"{START_PREFIX}{USER_PREFIX} Why is the sky blue?"
f"{START_PREFIX}{ASSISTANT_PREFIX}", None, 10, 1)


@pytest.mark.asyncio
async def test_chat_completion_with_assistant_in_the_end(llm_client):
text_completion_mock = AsyncMock(return_value=["completion text"])
llm_client.text_completion = text_completion_mock

actual = await llm_client.chat_completion(messages=[ChatMessage(Role.USER, "Why is the sky blue?"),
ChatMessage(Role.ASSISTANT, "Answer - ")], temperature=10)

assert actual == ["completion text"]
text_completion_mock.assert_awaited_once_with(f"{START_PREFIX}{USER_PREFIX} Why is the sky blue?"
f"{START_PREFIX}{ASSISTANT_PREFIX} Answer -", None, None,
10)


@pytest.mark.asyncio
async def test_chat_completion_with_system(llm_client):
text_completion_mock = AsyncMock(return_value=["completion text"])
llm_client.text_completion = text_completion_mock

actual = await llm_client.chat_completion(messages=[ChatMessage(Role.SYSTEM, "Be nice!"),
ChatMessage(Role.USER, "Why is the sky blue?")], max_tokens=10,
temperature=2)

assert actual == ["completion text"]
text_completion_mock.assert_awaited_once_with(f"{START_PREFIX}{USER_PREFIX} "
f"{SYSTEM_START_PREFIX}Be nice!{SYSTEM_END_PREFIX}{START_PREFIX}"
f"{USER_PREFIX} Why is the sky blue?"
f"{START_PREFIX}{ASSISTANT_PREFIX}", None, 10, 2)


@pytest.mark.asyncio
async def test_text_completion__sanity(mock_aioresponse, llm_client, complete_url, anthropic_version):
mock_aioresponse.post(