release v0.2.2
hiyouga committed Feb 24, 2024
1 parent ead91ee commit 2176b0b
Showing 15 changed files with 278 additions and 154 deletions.
.github/workflows/tests.yml (2 changes: 1 addition & 1 deletion)

@@ -22,7 +22,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install black ruff
+          python -m pip install ruff
       - name: Check quality
         run: |
Makefile (4 changes: 2 additions & 2 deletions)

@@ -3,9 +3,9 @@
 check_dirs := src tests
 
 quality:
-	black --check $(check_dirs)
 	ruff $(check_dirs)
+	ruff format --check $(check_dirs)
 
 style:
-	black $(check_dirs)
 	ruff $(check_dirs) --fix
+	ruff format $(check_dirs)
README.md (34 changes: 34 additions & 0 deletions)

@@ -20,6 +20,40 @@ pip install -U imitater
 python -m imitater.service.app -c config/example.yaml
 ```
 
+<details><summary>Show configuration instructions.</summary>
+
+Add an OpenAI model.
+
+```yaml
+- name: Display name
+- token: OpenAI token
+```
+Add a chat model.
+```yaml
+- name: Display name
+- path: Model name on hub or model path
+- device: Device IDs
+- port: Port ID
+- maxlen: Maximum model length (optional)
+- agent_type: Agent type (optional) {react, aligned}
+- template: Template jinja file (optional)
+- gen_config: Generation config folder (optional)
+```
+Add an embedding model.
+```yaml
+- name: Display name
+- path: Model name on hub or model path
+- device: Device IDs (does not support multi-GPU)
+- port: Port ID
+- batch_size: Batch size (optional)
+```
+</details>
+> [!NOTE]
+> [Chat template](https://huggingface.co/docs/transformers/chat_templating) is required for the chat models.
+>
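For orientation, here is a minimal sketch of what a `config/example.yaml` assembled from the lists above might look like. The top-level grouping keys (`openai`, `chat`, `embed`) and every value (names, paths, ports, device IDs) are assumptions for illustration only; the `config/example.yaml` shipped with the repository is the authoritative reference.

```yaml
# Hypothetical config sketch; keys and values are illustrative, not from the repo.
openai:
  - name: gpt-4                      # display name
    token: sk-...                    # OpenAI token
chat:
  - name: qwen-chat                  # display name
    path: Qwen/Qwen1.5-7B-Chat       # model name on hub or local path
    device: [0]                      # device IDs
    port: 8010
    maxlen: 4096                     # optional
    agent_type: react                # optional: react or aligned
embed:
  - name: bge-embed                  # display name
    path: BAAI/bge-small-en-v1.5     # model name on hub or local path
    device: [1]                      # single GPU only
    port: 8020
    batch_size: 16                   # optional
```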
pyproject.toml (31 changes: 11 additions & 20 deletions)

@@ -2,30 +2,21 @@
 requires = ["setuptools>=61.0"]
 build-backend = "setuptools.build_meta"
 
-[tool.black]
+[tool.ruff]
+target-version = "py310"
 line-length = 119
-target-version = ["py310"]
+indent-width = 4
 
-[tool.ruff]
+[tool.ruff.lint]
 ignore = ["C901", "E501", "E741", "W605"]
 select = ["C", "E", "F", "I", "W"]
-line-length = 119
 
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
 lines-after-imports = 2
 known-third-party = ["infinity_emb", "openai", "torch", "transformers", "vllm"]
 
-[isort]
-default_section = "FIRSTPARTY"
-known_third_party = [
-    "infinity_emb",
-    "torch",
-    "transformers",
-    "vllm"
-]
-line_length = 119
-lines_after_imports = 2
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-ensure_newline_before_comments = true
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
requirements.txt (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 numpy
 sse-starlette
-infinity-emb[torch]
+infinity-emb[torch]==0.0.17
 openai>=1.5.0
 transformers>=4.37.2
 vllm>=0.3.0
src/imitater/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
__version__ = "0.2.1"
__version__ = "0.2.2"
src/imitater/agent/types.py (3 changes: 1 addition & 2 deletions)

@@ -29,8 +29,7 @@ def extract_tool(self, answer: str, tools: List[Dict[str, Any]]) -> Union[str, Tuple[str, str]]:
             tools: the tool specification in the OpenAI format.
 
         Returns:
-            name, arguments (if tool call exists): the tool name with JSON formatted arguments.
-            response (if tool call does not exist): the assistant response.
+            response | (name, arguments): response text or tool name with JSON arguments if tool exists.
         """
         ...
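The union return type that this docstring now documents is easiest to consume with an `isinstance` check. A minimal sketch of the calling side (the handler name and dict shape are illustrative, not from the repository):

```python
from typing import Any, Dict, Tuple, Union

def handle_extracted(result: Union[str, Tuple[str, str]]) -> Dict[str, Any]:
    if isinstance(result, tuple):
        # A tool call was found: (tool name, JSON-formatted arguments).
        name, arguments = result
        return {"type": "tool_call", "name": name, "arguments": arguments}
    # No tool call: the string is the plain assistant response.
    return {"type": "text", "content": result}
```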

src/imitater/model/chat_model.py (83 changes: 71 additions & 12 deletions)

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, fields
-from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Tuple, Union
 
 from transformers import AutoTokenizer, GenerationConfig
 from typing_extensions import Self
@@ -17,6 +17,14 @@
 
 @dataclass
 class ChatConfig:
+    r"""
+    Creates configuration for a chat model.
+
+    Methods:
+        add_cli_args: adds arguments to an argument parser.
+        from_cli_args: builds configuration based on the command line arguments.
+    """
+
     name: str
     path: str
     device: List[int]
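The `add_cli_args`/`from_cli_args` pair described here mirrors the pattern vLLM uses for its engine arguments: register one flag per dataclass field, then copy the parsed values back into the dataclass. A sketch of how the two compose; the flag names are assumptions derived from the field names:

```python
# Hypothetical usage of ChatConfig's CLI helpers; flag names are assumed.
from argparse import ArgumentParser

parser = ArgumentParser()
ChatConfig.add_cli_args(parser)  # presumably registers --name, --path, --device, ...
args = parser.parse_args(["--name", "chat0", "--path", "Qwen/Qwen1.5-7B-Chat", "--device", "0"])
config = ChatConfig.from_cli_args(args)  # rebuilds a ChatConfig from the parsed namespace
```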
@@ -44,6 +52,15 @@ def from_cli_args(cls, args: "Namespace") -> Self:
 
 
 class ChatModel:
+    r"""
+    Creates a chat model for chat completions.
+
+    Methods:
+        chat: generates chat completions.
+        stream_chat: streams chat completions.
+        function_call: generates tool calls.
+    """
+
     def __init__(self, config: "ChatConfig") -> None:
         config.path = try_download_model_from_ms(config.path)
         self.config = config
@@ -89,7 +106,7 @@ def _load_generation_config(self) -> None:
             self._generation_config.top_p = 1.0
 
         if not self._generation_config.max_new_tokens:
-            self._generation_config.max_new_tokens = 1024
+            self._generation_config.max_new_tokens = 2048
 
         if isinstance(self._generation_config.eos_token_id, int):
             self._generation_config.eos_token_id = [self._generation_config.eos_token_id]
@@ -121,6 +138,22 @@ async def _generate(
         return result_generator
 
     async def chat(self, messages: List[Dict[str, str]], request_id: str, **gen_kwargs) -> Tuple[str, int, int]:
+        r"""
+        Generates chat completions.
+
+        Args:
+            messages: input messages.
+            request_id: request ID.
+            temperature: generation parameter.
+            top_p: generation parameter.
+            max_tokens: generation parameter.
+            stop_token_ids: generation parameter.
+
+        Returns:
+            generated_text: the generated text.
+            prompt_tokens: the number of prompt tokens.
+            completion_tokens: the number of completion tokens.
+        """
         generated_text, prompt_tokens, completion_tokens = "", 0, 0
         generator = await self._generate(messages, request_id, **gen_kwargs)
         async for result in generator:
@@ -133,7 +166,21 @@ async def chat(self, messages: List[Dict[str, str]], request_id: str, **gen_kwargs) -> Tuple[str, int, int]:
 
     async def stream_chat(
         self, messages: List[Dict[str, str]], request_id: str, **gen_kwargs
-    ) -> Generator[str, None, None]:
+    ) -> AsyncGenerator[str, None]:
+        r"""
+        Streams chat completions.
+
+        Args:
+            messages: input messages.
+            request_id: request ID.
+            temperature: generation parameter.
+            top_p: generation parameter.
+            max_tokens: generation parameter.
+            stop_token_ids: generation parameter.
+
+        Yields:
+            generated_token: the generated token.
+        """
         generated_text = ""
         generator = await self._generate(messages, request_id, **gen_kwargs)
         async for result in generator:
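The corrected annotation reflects how callers actually use `stream_chat`: calling it returns an async generator that is iterated directly with `async for`, with no intermediate `await`. A hypothetical consumer (the setup is illustrative; `model` is assumed to be an initialized `ChatModel`):

```python
async def stream_demo(model: "ChatModel") -> None:
    # Messages follow the OpenAI chat format; request_id is caller-chosen.
    messages = [{"role": "user", "content": "Hello!"}]
    async for delta in model.stream_chat(messages, request_id="req-0"):
        print(delta, end="", flush=True)  # each item is the newly generated piece of text
```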
@@ -144,21 +191,33 @@ async def stream_chat(
 
     async def function_call(
         self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]], request_id: str, **gen_kwargs
     ) -> Tuple[Union[str, Tuple[str, str]], int, int]:
-        generated_text, prompt_tokens, completion_tokens = "", 0, 0
+        r"""
+        Generates tool calls.
+
+        Args:
+            messages: input messages.
+            tools: tools available.
+            request_id: request ID.
+            temperature: generation parameter.
+            top_p: generation parameter.
+            max_tokens: generation parameter.
+            stop_token_ids: generation parameter.
+
+        Returns:
+            response | (name, arguments): response text or tool name with JSON arguments if tool exists.
+            prompt_tokens: the number of prompt tokens.
+            completion_tokens: the number of completion tokens.
+        """
         agent_messages = self._agent.build_prompt(messages, tools)
         stop_word = self._agent.get_stop_word()
         if stop_word is not None:
-            gen_kwargs["stop_token_ids"] = [self._tokenizer.encode(stop_word)[0]]
+            stop_word_id = self._tokenizer.convert_tokens_to_ids(self._tokenizer.tokenize(stop_word)[0])
+            gen_kwargs["stop_token_ids"] = gen_kwargs.pop("stop_token_ids", []) + [stop_word_id]
 
-        generator = await self._generate(agent_messages, request_id, **gen_kwargs)
-        async for result in generator:
-            if result.finished:
-                generated_text = result.outputs[0].text
-                prompt_tokens = len(result.prompt_token_ids)
-                completion_tokens = len(result.outputs[0].token_ids)
+        generated_text, prompt_tokens, completion_tokens = await self.chat(agent_messages, request_id, **gen_kwargs)
 
         if stop_word is not None:
-            stop_token = self._tokenizer.decode(gen_kwargs["stop_token_ids"])
+            stop_token = self._tokenizer.convert_ids_to_tokens(stop_word_id)
             if generated_text.endswith(stop_token):
                 generated_text = generated_text[: -len(stop_token)]
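The stop-word change fixes a subtle bug: `encode(stop_word)[0]` is the first id of the encoded sequence, and tokenizers that prepend special tokens (Llama-style BOS, for example) return that special token's id instead of the stop word's first token. Tokenizing first and converting that token to an id avoids this. A sketch of the difference (the model name is illustrative; any tokenizer that prepends BOS behaves the same way):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # illustrative model
stop_word = "Observation:"

old_id = tok.encode(stop_word)[0]  # first id of the sequence: BOS here, not the stop word
new_id = tok.convert_tokens_to_ids(tok.tokenize(stop_word)[0])  # first real token of the word

print(old_id == tok.bos_token_id)  # True: the old stop id would almost never fire
print(new_id)
```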
src/imitater/model/embed_model.py (34 changes: 33 additions & 1 deletion)

@@ -16,6 +16,14 @@

 @dataclass
 class EmbedConfig:
+    r"""
+    Creates configuration for an embedding model.
+
+    Methods:
+        add_cli_args: adds arguments to an argument parser.
+        from_cli_args: builds configuration based on the command line arguments.
+    """
+
     name: str
     path: str
     device: List[int]
@@ -37,6 +45,15 @@ def from_cli_args(cls, args: "Namespace") -> Self:
 
 
 class EmbedModel:
+    r"""
+    Creates an embedding model for text embeddings.
+
+    Methods:
+        startup: starts the embedding engine.
+        shutdown: stops the embedding engine.
+        embed: calculates text embeddings.
+    """
+
     def __init__(self, config: "EmbedConfig") -> None:
         config.path = try_download_model_from_ms(config.path)
         self.config = config
@@ -50,15 +67,30 @@ def _init_infinity_engine(self) -> None:
         self._engine = AsyncEmbeddingEngine(
             model_name_or_path=self.config.path,
             batch_size=self.config.batch_size,
-            engine="torch",
             device="cuda",
         )
 
     async def startup(self) -> None:
+        r"""
+        Starts the embedding engine.
+        """
         await self._engine.astart()
 
     async def shutdown(self) -> None:
+        r"""
+        Stops the embedding engine.
+        """
         await self._engine.astop()
 
     async def embed(self, texts: List[str]) -> Tuple[List["NDArray[float32]"], int]:
+        r"""
+        Calculates the text embeddings.
+
+        Args:
+            texts: the batched text input.
+
+        Returns:
+            embeddings: the batched embeddings.
+            usage: the number of input tokens.
+        """
         return await self._engine.embed(texts)
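Putting the three documented methods together, a hypothetical driver for `EmbedModel`; the config values are illustrative, and `EmbedConfig` is assumed to carry exactly the fields listed in the README:

```python
import asyncio

from imitater.model.embed_model import EmbedConfig, EmbedModel

async def main() -> None:
    config = EmbedConfig(name="embed0", path="BAAI/bge-small-en-v1.5", device=[0], port=8020, batch_size=16)
    model = EmbedModel(config)
    await model.startup()  # boots the infinity-emb engine
    try:
        embeddings, usage = await model.embed(["hello", "world"])  # returns (vectors, token count)
        print(len(embeddings), "embeddings from", usage, "input tokens")
    finally:
        await model.shutdown()  # always stop the engine

asyncio.run(main())
```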