-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Clean up type annotations for mistral tokenizer #8314
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -23,6 +23,7 @@ | |||||||||||||||||||||
# yapf: enable | ||||||||||||||||||||||
# pydantic needs the TypedDict from typing_extensions | ||||||||||||||||||||||
from pydantic import ConfigDict | ||||||||||||||||||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast | ||||||||||||||||||||||
from typing_extensions import Required, TypeAlias, TypedDict | ||||||||||||||||||||||
|
||||||||||||||||||||||
from vllm.config import ModelConfig | ||||||||||||||||||||||
|
@@ -31,7 +32,7 @@ | |||||||||||||||||||||
from vllm.multimodal.utils import (async_get_and_parse_audio, | ||||||||||||||||||||||
async_get_and_parse_image, | ||||||||||||||||||||||
get_and_parse_audio, get_and_parse_image) | ||||||||||||||||||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer | ||||||||||||||||||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer | ||||||||||||||||||||||
|
||||||||||||||||||||||
logger = init_logger(__name__) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts( | |||||||||||||||||||||
audio_url = _AudioParser(part)["audio_url"] | ||||||||||||||||||||||
|
||||||||||||||||||||||
mm_parser.parse_audio(audio_url["url"]) | ||||||||||||||||||||||
elif part_type == "refusal": | ||||||||||||||||||||||
text = _RefusalParser(part)["refusal"] | ||||||||||||||||||||||
texts.append(text) | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
raise NotImplementedError(f"Unknown part type: {part_type}") | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -433,6 +437,21 @@ def _parse_chat_message_content( | |||||||||||||||||||||
return result | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def _postprocess_messages(messages: List[ConversationMessage]) -> None: | ||||||||||||||||||||||
# per the Transformers docs & maintainers, tool call arguments in | ||||||||||||||||||||||
# assistant-role messages with tool_calls need to be dicts not JSON str - | ||||||||||||||||||||||
# this is how tool-use chat templates will expect them moving forwards | ||||||||||||||||||||||
# so, for messages that have tool_calls, parse the string (which we get | ||||||||||||||||||||||
# from openAI format) to dict | ||||||||||||||||||||||
for message in messages: | ||||||||||||||||||||||
if (message["role"] == "assistant" and "tool_calls" in message | ||||||||||||||||||||||
and isinstance(message["tool_calls"], list)): | ||||||||||||||||||||||
|
||||||||||||||||||||||
for item in message["tool_calls"]: | ||||||||||||||||||||||
item["function"]["arguments"] = json.loads( | ||||||||||||||||||||||
item["function"]["arguments"]) | ||||||||||||||||||||||
Comment on lines
+440
to
+452
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume that mistral tokenizers will be able to handle tool calls internally since the output |
||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def parse_chat_messages( | ||||||||||||||||||||||
messages: List[ChatCompletionMessageParam], | ||||||||||||||||||||||
model_config: ModelConfig, | ||||||||||||||||||||||
|
@@ -446,6 +465,8 @@ def parse_chat_messages( | |||||||||||||||||||||
|
||||||||||||||||||||||
conversation.extend(sub_messages) | ||||||||||||||||||||||
|
||||||||||||||||||||||
_postprocess_messages(conversation) | ||||||||||||||||||||||
|
||||||||||||||||||||||
return conversation, mm_tracker.all_mm_data() | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -462,41 +483,41 @@ def parse_chat_messages_futures( | |||||||||||||||||||||
|
||||||||||||||||||||||
conversation.extend(sub_messages) | ||||||||||||||||||||||
|
||||||||||||||||||||||
_postprocess_messages(conversation) | ||||||||||||||||||||||
|
||||||||||||||||||||||
return conversation, mm_tracker.all_mm_data() | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def apply_chat_template( | ||||||||||||||||||||||
tokenizer: AnyTokenizer, | ||||||||||||||||||||||
def apply_hf_chat_template( | ||||||||||||||||||||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], | ||||||||||||||||||||||
conversation: List[ConversationMessage], | ||||||||||||||||||||||
chat_template: Optional[str], | ||||||||||||||||||||||
*, | ||||||||||||||||||||||
tokenize: bool = False, # Different from HF's default | ||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||
) -> Union[str, List[int]]: | ||||||||||||||||||||||
) -> str: | ||||||||||||||||||||||
if chat_template is None and tokenizer.chat_template is None: | ||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||
"As of transformers v4.44, default chat template is no longer " | ||||||||||||||||||||||
"allowed, so you must provide a chat template if the tokenizer " | ||||||||||||||||||||||
"does not define one.") | ||||||||||||||||||||||
|
||||||||||||||||||||||
# per the Transformers docs & maintainers, tool call arguments in | ||||||||||||||||||||||
# assistant-role messages with tool_calls need to be dicts not JSON str - | ||||||||||||||||||||||
# this is how tool-use chat templates will expect them moving forwards | ||||||||||||||||||||||
# so, for messages that have tool_calls, parse the string (which we get | ||||||||||||||||||||||
# from openAI format) to dict | ||||||||||||||||||||||
for message in conversation: | ||||||||||||||||||||||
if (message["role"] == "assistant" and "tool_calls" in message | ||||||||||||||||||||||
and isinstance(message["tool_calls"], list)): | ||||||||||||||||||||||
return tokenizer.apply_chat_template( | ||||||||||||||||||||||
conversation=conversation, # type: ignore[arg-type] | ||||||||||||||||||||||
chat_template=chat_template, | ||||||||||||||||||||||
tokenize=tokenize, | ||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
for i in range(len(message["tool_calls"])): | ||||||||||||||||||||||
args: str = message["tool_calls"][i]["function"]["arguments"] | ||||||||||||||||||||||
parsed_args: Dict = json.loads(args) | ||||||||||||||||||||||
message["tool_calls"][i]["function"]["arguments"] = parsed_args | ||||||||||||||||||||||
|
||||||||||||||||||||||
prompt = tokenizer.apply_chat_template( | ||||||||||||||||||||||
conversation=conversation, | ||||||||||||||||||||||
def apply_mistral_chat_template( | ||||||||||||||||||||||
tokenizer: MistralTokenizer, | ||||||||||||||||||||||
messages: List[ChatCompletionMessageParam], | ||||||||||||||||||||||
chat_template: Optional[str], | ||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||
) -> List[int]: | ||||||||||||||||||||||
return tokenizer.apply_chat_template( | ||||||||||||||||||||||
messages=messages, | ||||||||||||||||||||||
chat_template=chat_template, | ||||||||||||||||||||||
Comment on lines
+516
to
521
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
maybe out of scope for this PR, but mistral tokenizers will actually never need a chat_template There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, let's do this in another PR. |
||||||||||||||||||||||
tokenize=tokenize, | ||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
return prompt |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,8 @@ | |
from vllm.engine.arg_utils import EngineArgs | ||
from vllm.engine.llm_engine import LLMEngine | ||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, | ||
apply_chat_template, | ||
apply_hf_chat_template, | ||
apply_mistral_chat_template, | ||
parse_chat_messages) | ||
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt | ||
from vllm.inputs.parse import parse_and_batch_prompt | ||
|
@@ -19,7 +20,7 @@ | |
from vllm.pooling_params import PoolingParams | ||
from vllm.prompt_adapter.request import PromptAdapterRequest | ||
from vllm.sampling_params import SamplingParams | ||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, | ||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, | ||
get_cached_tokenizer) | ||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup | ||
from vllm.usage.usage_lib import UsageContext | ||
|
@@ -393,12 +394,21 @@ def chat( | |
conversation, mm_data = parse_chat_messages(messages, model_config, | ||
tokenizer) | ||
|
||
prompt = apply_chat_template( | ||
tokenizer, | ||
conversation, | ||
chat_template=chat_template, | ||
add_generation_prompt=add_generation_prompt, | ||
) | ||
prompt: Union[str, List[int]] | ||
if isinstance(tokenizer, MistralTokenizer): | ||
prompt = apply_mistral_chat_template( | ||
tokenizer, | ||
messages=messages, | ||
chat_template=chat_template, | ||
add_generation_prompt=add_generation_prompt, | ||
) | ||
else: | ||
prompt = apply_hf_chat_template( | ||
tokenizer, | ||
conversation=conversation, | ||
chat_template=chat_template, | ||
add_generation_prompt=add_generation_prompt, | ||
) | ||
Comment on lines
+397
to
+411
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main part of this PR. Notice that mistral tokenizer uses |
||
|
||
inputs: PromptInputs | ||
if is_list_of(prompt, int): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that
_RefusalParser
got left out previously, so I'm adding it here.