
update README.md
uripeled2 committed Jul 22, 2023
1 parent 97ee4e7 commit 7a23a0c
Showing 1 changed file with 29 additions and 4 deletions.
README.md
@@ -12,12 +12,14 @@
any flexibility (API params, endpoints etc.). *We also provide a sync version; see
more details below in the Usage section.

## Base Interface
-The package exposes two simple interfaces for communicating with LLMs (In the future, we
+The package exposes two simple interfaces for seamless integration with LLMs (In the future, we
will expand the interface to support more tasks like list models, edits, etc.):
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional
+from enum import Enum
+from dataclasses_json import dataclass_json, config
from aiohttp import ClientSession


@@ -30,6 +32,20 @@ class BaseLLMClient(ABC):
        raise NotImplementedError()


+class Role(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"


+@dataclass_json
+@dataclass
+class ChatMessage:
+    role: Role = field(metadata=config(encoder=lambda role: role.value, decoder=Role))
+    content: str
+    name: Optional[str] = field(default=None, metadata=config(exclude=lambda name: name is None))
+    example: bool = field(default=False, metadata=config(exclude=lambda _: True))


@dataclass
class LLMAPIClientConfig:
@@ -49,8 +65,15 @@ class BaseLLMAPIClient(BaseLLMClient, ABC):
                              temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
        raise NotImplementedError()

+    async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
+                              max_tokens: int = 16, model: Optional[str] = None, **kwargs) -> list[str]:
+        raise NotImplementedError()

    async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
        raise NotImplementedError()

+    async def get_chat_tokens_count(self, messages: list[ChatMessage], **kwargs) -> int:
+        raise NotImplementedError()
```
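For illustration only, not part of this commit: because `ChatMessage` is decorated with `@dataclass_json`, it gains `to_dict()`/`to_json()` methods, and the `config(...)` metadata above shapes the payload: `role` is encoded by its value, `name` is omitted when unset, and `example` is never serialized. A minimal sketch of what that produces:

```python
# Sketch (assumes dataclasses-json is installed and the classes above are defined).
msg = ChatMessage(role=Role.USER, content="Hello!")
print(msg.to_dict())  # {'role': 'user', 'content': 'Hello!'} (name and example are excluded)
```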

## Requirements
@@ -109,10 +132,12 @@ async def main():
```python
    llm_client = OpenAIClient(LLMAPIClientConfig(OPENAI_API_KEY, session, default_model="text-davinci-003",
                                                 headers={"OpenAI-Organization": OPENAI_ORG_ID}))  # The headers are optional
    text = "This is indeed a test"
+    messages = [ChatMessage(role=Role.USER, content="Hello!"),
+                ChatMessage(role=Role.SYSTEM, content="Hi there! How can I assist you today?")]

    print("number of tokens:", await llm_client.get_tokens_count(text))  # 5
-    print("generated chat:", await llm_client.chat_completion(
-        messages=[ChatMessage(role=Role.USER, content="Hello!")], model="gpt-3.5-turbo"))  # ['Hi there! How can I assist you today?']
+    print("number of tokens for chat completion:", await llm_client.get_chat_tokens_count(messages, model="gpt-3.5-turbo"))  # 23
+    print("generated chat:", await llm_client.chat_completion(messages, model="gpt-3.5-turbo"))  # ['Hi there! How can I assist you today?']
    print("generated text:", await llm_client.text_completion(text))  # [' string\n\nYes, this is a test string. Test strings are used to']
    print("generated embedding:", await llm_client.embedding(text))  # [0.0023064255, -0.009327292, ...]
```
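As a usage note, not part of this commit: `main()` here is a coroutine, so a script would drive it with `asyncio`. A minimal sketch, assuming the rest of `main()` from the full README (imports, `OPENAI_API_KEY`, and the `aiohttp` session setup) is in place:

```python
import asyncio

# Drive the async example above; main() is the coroutine from the Usage section.
asyncio.run(main())
```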
@@ -190,7 +215,7 @@ Contributions are welcome! Please check out the todos below, and feel free to op
  - [ ] Cohere
- [x] Add support for more functions via LLMs
  - [x] embeddings
-  - [ ] chat
+  - [x] chat
  - [ ] list models
  - [ ] edits
  - [ ] more