[Frontend] Clean up type annotations for mistral tokenizer #8314

Merged (1 commit) on Sep 10, 2024
5 changes: 3 additions & 2 deletions tests/async_engine/test_chat_template.py
@@ -1,6 +1,7 @@
 import pytest

-from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
+from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
+                                         load_chat_template)
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         add_generation_prompt=add_generation_prompt)

     # Call the function and get the result
-    result = apply_chat_template(
+    result = apply_hf_chat_template(
         tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
61 changes: 41 additions & 20 deletions vllm/entrypoints/chat_utils.py
@@ -23,6 +23,7 @@
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
 from pydantic import ConfigDict
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Required, TypeAlias, TypedDict

 from vllm.config import ModelConfig
@@ -31,7 +32,7 @@
 from vllm.multimodal.utils import (async_get_and_parse_audio,
                                    async_get_and_parse_image,
                                    get_and_parse_audio, get_and_parse_image)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

 logger = init_logger(__name__)

@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
             audio_url = _AudioParser(part)["audio_url"]

             mm_parser.parse_audio(audio_url["url"])
+        elif part_type == "refusal":
+            text = _RefusalParser(part)["refusal"]
+            texts.append(text)
Comment on lines +383 to +385 (Member Author):

I noticed that _RefusalParser got left out previously, so I'm adding it here.
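
For reference, a minimal sketch of the kind of message the new branch handles, assuming the OpenAI-style "refusal" content part shape that _RefusalParser parses (the message text here is made up):

    # Hypothetical assistant message containing a "refusal" content part;
    # the new elif branch extracts the refusal text and appends it to `texts`.
    message = {
        "role": "assistant",
        "content": [
            {"type": "refusal", "refusal": "I can't help with that request."},
        ],
    }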

         else:
             raise NotImplementedError(f"Unknown part type: {part_type}")

@@ -433,6 +437,21 @@ def _parse_chat_message_content(
     return result


+def _postprocess_messages(messages: List[ConversationMessage]) -> None:
+    # per the Transformers docs & maintainers, tool call arguments in
+    # assistant-role messages with tool_calls need to be dicts not JSON str -
+    # this is how tool-use chat templates will expect them moving forwards
+    # so, for messages that have tool_calls, parse the string (which we get
+    # from openAI format) to dict
+    for message in messages:
+        if (message["role"] == "assistant" and "tool_calls" in message
+                and isinstance(message["tool_calls"], list)):
+
+            for item in message["tool_calls"]:
+                item["function"]["arguments"] = json.loads(
+                    item["function"]["arguments"])
Comment on lines +440 to +452 (Member Author):

I assume that mistral tokenizers will be able to handle tool calls internally since the output conversation will not be used by mistral tokenizer.
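
To illustrate what _postprocess_messages does, a minimal before/after sketch (the tool call payload is hypothetical):

    # Hypothetical assistant message in OpenAI format: tool call arguments
    # arrive as a JSON string ...
    msg = {
        "role": "assistant",
        "tool_calls": [{
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Paris"}',
            }
        }],
    }
    _postprocess_messages([msg])
    # ... and are parsed in place into a dict, which is what tool-use chat
    # templates expect.
    assert msg["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}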



 def parse_chat_messages(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
@@ -446,6 +465,8 @@ def parse_chat_messages(

         conversation.extend(sub_messages)

+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()


@@ -462,41 +483,41 @@ def parse_chat_messages_futures(

         conversation.extend(sub_messages)

+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()


-def apply_chat_template(
-    tokenizer: AnyTokenizer,
+def apply_hf_chat_template(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: List[ConversationMessage],
     chat_template: Optional[str],
     *,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
-) -> Union[str, List[int]]:
+) -> str:
     if chat_template is None and tokenizer.chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
             "allowed, so you must provide a chat template if the tokenizer "
             "does not define one.")

-    # per the Transformers docs & maintainers, tool call arguments in
-    # assistant-role messages with tool_calls need to be dicts not JSON str -
-    # this is how tool-use chat templates will expect them moving forwards
-    # so, for messages that have tool_calls, parse the string (which we get
-    # from openAI format) to dict
-    for message in conversation:
-        if (message["role"] == "assistant" and "tool_calls" in message
-                and isinstance(message["tool_calls"], list)):
+    return tokenizer.apply_chat_template(
+        conversation=conversation,  # type: ignore[arg-type]
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )

-            for i in range(len(message["tool_calls"])):
-                args: str = message["tool_calls"][i]["function"]["arguments"]
-                parsed_args: Dict = json.loads(args)
-                message["tool_calls"][i]["function"]["arguments"] = parsed_args

-    prompt = tokenizer.apply_chat_template(
-        conversation=conversation,
+def apply_mistral_chat_template(
+    tokenizer: MistralTokenizer,
+    messages: List[ChatCompletionMessageParam],
+    chat_template: Optional[str],
+    **kwargs: Any,
+) -> List[int]:
+    return tokenizer.apply_chat_template(
+        messages=messages,
         chat_template=chat_template,
Comment on lines +516 to 521 (Contributor):

Suggested change:
-    chat_template: Optional[str],
     **kwargs: Any,
 ) -> List[int]:
     return tokenizer.apply_chat_template(
         messages=messages,
-        chat_template=chat_template,

maybe out of scope for this PR, but mistral tokenizers will actually never need a chat_template

Reply (Member Author):

Yeah, let's do this in another PR.

-        tokenize=tokenize,
         **kwargs,
     )
-    return prompt
26 changes: 18 additions & 8 deletions vllm/entrypoints/llm.py
@@ -6,7 +6,8 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         apply_chat_template,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
@@ -19,7 +20,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
@@ -393,12 +394,21 @@ def chat(
         conversation, mm_data = parse_chat_messages(messages, model_config,
                                                     tokenizer)

-        prompt = apply_chat_template(
-            tokenizer,
-            conversation,
-            chat_template=chat_template,
-            add_generation_prompt=add_generation_prompt,
-        )
+        prompt: Union[str, List[int]]
+        if isinstance(tokenizer, MistralTokenizer):
+            prompt = apply_mistral_chat_template(
+                tokenizer,
+                messages=messages,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )
+        else:
+            prompt = apply_hf_chat_template(
+                tokenizer,
+                conversation=conversation,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )
Comment on lines +397 to +411 (@DarkLight1337, Member Author, Sep 10, 2024):

The main part of this PR. Notice that mistral tokenizer uses messages while HF tokenizer uses conversation. This is cleaner than having different parsing logic inside parse_chat_messages as it avoids the need to handle different types of conversation when generating the output request.
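
A quick usage sketch of how this dispatch is exercised through LLM.chat; the model name and tokenizer_mode value are illustrative assumptions, not part of this diff:

    from vllm import LLM

    # Assumed setup: with a Mistral-format tokenizer the isinstance check above
    # routes to apply_mistral_chat_template (raw OpenAI-style messages), while
    # an HF tokenizer routes to apply_hf_chat_template (parsed conversation).
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
              tokenizer_mode="mistral")
    outputs = llm.chat([{"role": "user", "content": "Hello!"}])
    print(outputs[0].outputs[0].text)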


         inputs: PromptInputs
         if is_list_of(prompt, int):
48 changes: 30 additions & 18 deletions vllm/entrypoints/openai/serving_chat.py
@@ -11,7 +11,8 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         apply_chat_template,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          load_chat_template,
                                          parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
@@ -35,7 +36,7 @@
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import iterate_with_cancellation, random_uuid

 logger = init_logger(__name__)
@@ -121,15 +122,27 @@ async def create_chat_completion(
                 tool.model_dump() for tool in request.tools
             ]

-            prompt = apply_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=request.chat_template or self.chat_template,
-                add_generation_prompt=request.add_generation_prompt,
-                tools=tool_dicts,
-                documents=request.documents,
-                **(request.chat_template_kwargs or {}),
-            )
+            prompt: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=request.messages,
+                    chat_template=request.chat_template or self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    tools=tool_dicts,
+                    documents=request.documents,
+                    **(request.chat_template_kwargs or {}),
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=request.chat_template or self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    tools=tool_dicts,
+                    documents=request.documents,
+                    **(request.chat_template_kwargs or {}),
+                )
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
@@ -307,11 +320,10 @@ async def chat_completion_stream_generator(
                 # Send response to echo the input portion of the
                 # last message
                 if request.echo:
-                    last_msg_content: Optional[str] = ""
-                    if conversation and conversation[-1].get(
-                            "content") and conversation[-1].get(
-                                "role") == role:
-                        last_msg_content = conversation[-1]["content"]
+                    last_msg_content: str = ""
+                    if conversation and "content" in conversation[
+                            -1] and conversation[-1].get("role") == role:
+                        last_msg_content = conversation[-1]["content"] or ""

                     if last_msg_content:
                         for i in range(num_choices):
@@ -659,8 +671,8 @@ async def chat_completion_full_generator(

         if request.echo:
             last_msg_content = ""
-            if conversation and conversation[-1].get(
-                    "content") and conversation[-1].get("role") == role:
+            if conversation and "content" in conversation[-1] and conversation[
+                    -1].get("role") == role:
                 last_msg_content = conversation[-1]["content"] or ""

             for choice in choices:
25 changes: 18 additions & 7 deletions vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,8 @@

 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import (apply_chat_template,
+from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          load_chat_template,
                                          parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
@@ -18,6 +19,7 @@
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -66,6 +68,7 @@ async def create_tokenize(

         tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

+        prompt: Union[str, List[int]]
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config

@@ -77,12 +80,20 @@
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")

-            prompt = apply_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=self.chat_template,
-                add_generation_prompt=request.add_generation_prompt,
-            )
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=request.messages,
+                    chat_template=self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                )
         else:
             prompt = request.prompt

8 changes: 4 additions & 4 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -16,7 +16,7 @@
                                                      Tekkenizer)

 if TYPE_CHECKING:
-    from vllm.entrypoints.chat_utils import ConversationMessage
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


@dataclass
@@ -122,19 +122,19 @@ def get_added_vocab(self) -> List[str]:
         return []

     def encode(self, prompt: str) -> List[int]:
-        # `encode ` should only be used for prompt completion
+        # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)

     def apply_chat_template(self,
-                            conversation: List["ConversationMessage"],
+                            messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
                             **kwargs) -> List[int]:
         assert tools is None, "`tools` are not yet supported."

         request = ChatCompletionRequest(
-            messages=conversation)  # type: ignore[type-var]
+            messages=messages)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)

         # encode-decode to get clean prompt
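
For illustration, a hedged sketch of calling the retyped method after this change; the construction via from_pretrained is an assumption about this class's helpers, and the message content is made up:

    from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

    # Assumes MistralTokenizer exposes a from_pretrained() constructor.
    tokenizer = MistralTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3")
    # The method now takes OpenAI-style `messages` instead of a parsed
    # `conversation`, and returns prompt token IDs directly.
    token_ids = tokenizer.apply_chat_template(
        messages=[{"role": "user", "content": "Hello!"}])
    print(len(token_ids))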