From ab95159bffed6a563a432096ea0d37c87b069542 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 25 Feb 2024 01:36:10 +0800
Subject: [PATCH] release v0.2.3

---
 requirements.txt                 |  6 ++----
 src/imitater/__init__.py         |  2 +-
 src/imitater/model/chat_model.py | 17 +++++++----------
 3 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 78a603d..ef43b15 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,4 @@
-numpy
-sse-starlette
 infinity-emb[torch]==0.0.17
 openai>=1.5.0
-transformers>=4.37.2
-vllm>=0.3.0
+sse-starlette
+vllm==0.3.2
diff --git a/src/imitater/__init__.py b/src/imitater/__init__.py
index b5fdc75..d31c31e 100644
--- a/src/imitater/__init__.py
+++ b/src/imitater/__init__.py
@@ -1 +1 @@
-__version__ = "0.2.2"
+__version__ = "0.2.3"
diff --git a/src/imitater/model/chat_model.py b/src/imitater/model/chat_model.py
index b8969c6..4b057e3 100644
--- a/src/imitater/model/chat_model.py
+++ b/src/imitater/model/chat_model.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, fields
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 from transformers import AutoTokenizer, GenerationConfig
 from typing_extensions import Self
@@ -12,8 +12,6 @@
 if TYPE_CHECKING:
     from argparse import ArgumentParser, Namespace
 
-    from vllm import RequestOutput
-
 
 @dataclass
 class ChatConfig:
@@ -120,9 +118,7 @@ def _load_generation_config(self) -> None:
             {"additional_special_tokens": extra_special_tokens}, replace_additional_special_tokens=False
         )
 
-    async def _generate(
-        self, messages: List[Dict[str, str]], request_id: str, **gen_kwargs
-    ) -> AsyncIterator["RequestOutput"]:
+    async def _generate(self, messages: List[Dict[str, str]], request_id: str, **gen_kwargs):
         input_ids = self._tokenizer.apply_chat_template(
             conversation=messages, tokenize=True, add_generation_prompt=True
         )
@@ -157,7 +153,7 @@ async def chat(self, messages: List[Dict[str, str]], request_id: str, **gen_kwar
         generated_text, prompt_tokens, completion_tokens = "", 0, 0
         generator = await self._generate(messages, request_id, **gen_kwargs)
         async for result in generator:
-            if result.finished:
+            if not result.finished:
                 generated_text = result.outputs[0].text
                 prompt_tokens = len(result.prompt_token_ids)
                 completion_tokens = len(result.outputs[0].token_ids)
@@ -184,9 +180,10 @@ async def stream_chat(
         generated_text = ""
         generator = await self._generate(messages, request_id, **gen_kwargs)
         async for result in generator:
-            delta_text = result.outputs[0].text[len(generated_text) :]
-            generated_text = result.outputs[0].text
-            yield delta_text
+            if not result.finished:
+                delta_text = result.outputs[0].text[len(generated_text) :]
+                generated_text = result.outputs[0].text
+                yield delta_text
 
     async def function_call(
         self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]], request_id: str, **gen_kwargs
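-- 
Note, not part of the patch (everything below the "-- " delimiter is dropped when
the patch is applied): the chat() and stream_chat() loops above now guard on
"not result.finished", reading cumulative text and token counts from the
incremental outputs instead of the final finished one. Below is a minimal,
self-contained sketch of how such a loop consumes vLLM's AsyncLLMEngine under
the pinned vllm==0.3.2. The model name, sampling values, prompt, and request id
are illustrative placeholders, not values taken from this repository; imitater
itself builds the engine from its own config and applies a chat template first.

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


async def main() -> None:
    # Placeholder model for the sketch; any local or hub model path works.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(temperature=0.7, max_tokens=64)

    generated_text = ""
    # engine.generate(...) is an async generator of RequestOutput objects.
    # Each one carries the cumulative text, so the printable delta is the
    # suffix beyond what was already emitted -- the same trick stream_chat() uses.
    async for result in engine.generate("Hello, my name is", params, request_id="demo-0"):
        if not result.finished:  # mirrors the guard added in this patch
            delta_text = result.outputs[0].text[len(generated_text):]
            generated_text = result.outputs[0].text
            print(delta_text, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())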