Skip to content

Commit

Permalink
feat: improve tools to include name and add tests (huggingface#1693)
Browse files Browse the repository at this point in the history
This PR makes tool calling aware of the name of the function selected. 

Fixes:
huggingface#1657

Thank you @puppetm4st3r for the helpful snippets, large parts of this PR
are simply refactors of the code shared 🙏

Note: opened as a draft PR because small tweaks are needed before merging.
  • Loading branch information
drbh authored and kdamaszk committed May 27, 2024
1 parent 55737e5 commit 111061c
Show file tree
Hide file tree
Showing 11 changed files with 428 additions and 174 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"usage": null
}
],
"created": 1710795556,
"created": 1712874856,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"arguments": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
"location": "Brooklyn"
},
"description": null,
"name": "get_current_weather"
},
"id": 0,
"type": "function"
Expand All @@ -27,14 +26,14 @@
"usage": null
}
],
"created": 1710795556,
"created": 1712782670,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
"completion_tokens": 37,
"prompt_tokens": 524,
"total_tokens": 561
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"arguments": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
"location": "Brooklyn"
},
"description": null,
"name": "get_current_weather"
},
"id": 0,
"type": "function"
Expand All @@ -27,14 +26,14 @@
"usage": null
}
],
"created": 1710795557,
"created": 1712787937,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
"completion_tokens": 37,
"prompt_tokens": 524,
"total_tokens": 561
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"arguments": {
"format": "celsius",
"location": "New York, NY"
}
},
"description": null,
"name": "get_current_weather"
},
"id": 0,
"type": "function"
Expand All @@ -26,14 +26,14 @@
"usage": null
}
],
"created": 1710795557,
"created": 1712852394,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 21,
"prompt_tokens": 187,
"total_tokens": 208
"completion_tokens": 48,
"prompt_tokens": 320,
"total_tokens": 368
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": null,
"name": "notify_error"
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1712852597,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 496,
"total_tokens": 535
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"logprobs": null
}
],
"created": 1710795499,
"created": 1712788218,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
Expand Down
42 changes: 42 additions & 0 deletions integration-tests/models/test_chat_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")
def flash_llama_chat_handle(launcher):
    """Launch a TinyLlama chat server (2 shards, grammar support enabled) once per module."""
    server = launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
    )
    with server as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_chat(flash_llama_chat_handle):
    """Wait (up to 300s) for the launched server to report healthy, then hand back its client."""
    handle = flash_llama_chat_handle
    await handle.health(300)
    return handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
    """Plain chat completion (no tools): seeded generation must match the pinned
    text and the stored response snapshot.

    Fix: the coroutine was missing ``@pytest.mark.asyncio``, so pytest would
    collect it without ever awaiting it (the test body never ran). The sibling
    async tests in test_tools_llama.py carry both markers.
    """
    response = await flash_llama_chat.chat(
        max_tokens=100,
        seed=1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    # seed=1 makes the generation deterministic, so an exact string match is safe.
    assert (
        response.choices[0].message.content
        == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
    )
    assert response == response_snapshot
109 changes: 59 additions & 50 deletions integration-tests/models/test_tools_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,34 +71,7 @@ async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
]


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)

assert (
response.choices[0].message.content
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
Expand All @@ -121,23 +94,19 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
"id": 0,
"type": "function",
}
]
assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
Expand All @@ -163,23 +132,20 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
"id": 0,
"type": "function",
}
]

assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
Expand Down Expand Up @@ -209,15 +175,16 @@ async def test_flash_llama_grammar_tools_choice(
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {"format": "celsius", "location": "New York, NY"},
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
}
]

assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
Expand Down Expand Up @@ -246,5 +213,47 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses:
count += 1

assert count == 20
assert count == 38
assert response == response_snapshot


@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=8,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=False,
)

assert responses.choices[0].message.content == None
assert responses.choices[0].message.tool_calls == [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": None,
"name": "notify_error",
},
"id": 0,
"type": "function",
}
]

assert responses == response_snapshot
Loading

0 comments on commit 111061c

Please sign in to comment.