fix: improve tool type, bump pydantic and outlines (huggingface#1650)
This PR resolves a couple of issues:

- [X] adjusts the tool response to align with OpenAI's tool call response type
- [X] bumps pydantic to `2.6.4` in all apps (resolves a dependency issue when running tests); the validator changes this entails are sketched after this list
- [X] bumps the `outlines` version and fixes the import for its new name
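
For context, the pydantic bump means every `@validator` in `clients/python/text_generation/types.py` becomes a `@field_validator`, and cross-field access moves from a plain `values` dict to the `.data` attribute of a `ValidationInfo` object. Below is a minimal sketch of that pattern, using a hypothetical trimmed-down model rather than the real `Parameters` class, and a plain `ValueError` instead of the client's own error type:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError, field_validator


class SamplingParams(BaseModel):
    # Hypothetical stand-in for the real Parameters model in types.py.
    do_sample: bool = False
    temperature: Optional[float] = None
    best_of: Optional[int] = None

    # pydantic v2: @field_validator replaces @validator, and previously
    # validated fields are read from info.data instead of a `values` dict.
    @field_validator("best_of")
    @classmethod
    def valid_best_of(cls, field_value, info):
        if field_value is not None and field_value <= 0:
            raise ValueError("`best_of` must be strictly positive")
        sampling = info.data.get("do_sample") or (info.data.get("temperature") is not None)
        if field_value is not None and field_value > 1 and not sampling:
            raise ValueError("you must use sampling when `best_of` is > 1")
        return field_value


# best_of > 1 without any sampling option enabled is rejected at construction time.
try:
    SamplingParams(best_of=2)
except ValidationError as exc:
    print(exc)
```

The real `Parameters` model raises `ValidationError` from `text_generation.errors`; the sketch raises `ValueError` so that pydantic reports the failure through its own `ValidationError`.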
drbh authored and kdamaszk committed Apr 25, 2024
1 parent b36c0f8 commit ab074c8
Showing 17 changed files with 367 additions and 275 deletions.
225 changes: 113 additions & 112 deletions clients/python/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
@@ -12,7 +12,7 @@ repository = "https://github.com/huggingface/text-generation-inference"

[tool.poetry.dependencies]
python = "^3.7"
pydantic = "> 1.10, < 3"
pydantic = "> 2, < 3"
aiohttp = "^3.8"
huggingface-hub = ">= 0.12, < 1.0"

46 changes: 23 additions & 23 deletions clients/python/text_generation/types.py
@@ -1,5 +1,5 @@
from enum import Enum
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from typing import Optional, List, Union, Any

from text_generation.errors import ValidationError
@@ -32,7 +32,7 @@ class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str]
content: Optional[str] = None
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
@@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Any
usage: Optional[Any] = None


class Function(BaseModel):
@@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):

class ChoiceDelta(BaseModel):
role: str
content: Optional[str]
content: Optional[str] = None
tool_calls: Optional[ChoiceDeltaToolCall]


@@ -176,74 +176,74 @@ class Parameters(BaseModel):
# grammar to use for generation
grammar: Optional[Grammar] = None

@validator("best_of")
@field_validator("best_of")
def valid_best_of(cls, field_value, values):
if field_value is not None:
if field_value <= 0:
raise ValidationError("`best_of` must be strictly positive")
if field_value > 1 and values["seed"] is not None:
if field_value > 1 and values.data["seed"] is not None:
raise ValidationError("`seed` must not be set when `best_of` is > 1")
sampling = (
values["do_sample"]
| (values["temperature"] is not None)
| (values["top_k"] is not None)
| (values["top_p"] is not None)
| (values["typical_p"] is not None)
values.data["do_sample"]
| (values.data["temperature"] is not None)
| (values.data["top_k"] is not None)
| (values.data["top_p"] is not None)
| (values.data["typical_p"] is not None)
)
if field_value > 1 and not sampling:
raise ValidationError("you must use sampling when `best_of` is > 1")

return field_value

@validator("repetition_penalty")
@field_validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v

@validator("seed")
@field_validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0:
raise ValidationError("`seed` must be positive")
return v

@validator("temperature")
@field_validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive")
return v

@validator("top_k")
@field_validator("top_k")
def valid_top_k(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive")
return v

@validator("top_p")
@field_validator("top_p")
def valid_top_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
return v

@validator("truncate")
@field_validator("truncate")
def valid_truncate(cls, v):
if v is not None and v <= 0:
raise ValidationError("`truncate` must be strictly positive")
return v

@validator("typical_p")
@field_validator("typical_p")
def valid_typical_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v

@validator("top_n_tokens")
@field_validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v

@validator("grammar")
@field_validator("grammar")
def valid_grammar(cls, v):
if v is not None:
if v.type == GrammarType.Regex and not v.value:
Expand All @@ -261,15 +261,15 @@ class Request(BaseModel):
# Whether to stream output tokens
stream: bool = False

@validator("inputs")
@field_validator("inputs")
def valid_input(cls, v):
if not v:
raise ValidationError("`inputs` cannot be empty")
return v

@validator("stream")
@field_validator("stream")
def valid_best_of_stream(cls, field_value, values):
parameters = values["parameters"]
parameters = values.data["parameters"]
if (
parameters is not None
and parameters.best_of is not None
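
A second recurring change in `types.py` above is the added `= None` defaults on fields like `content` and `usage`: under pydantic v2, annotating a field as `Optional[...]` no longer implies a default, so such fields become required unless a default is given explicitly. A small illustration with hypothetical model names (not the actual `Message` class):

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class MessageWithDefault(BaseModel):
    role: str
    content: Optional[str] = None  # stays optional in pydantic v2


class MessageWithoutDefault(BaseModel):
    role: str
    content: Optional[str]  # required in pydantic v2, despite being Optional


MessageWithDefault(role="assistant")  # OK, content defaults to None

try:
    MessageWithoutDefault(role="assistant")
except ValidationError as exc:
    print(exc)  # reports that `content` is a required field
```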
12 changes: 9 additions & 3 deletions integration-tests/conftest.py
@@ -25,6 +25,7 @@
Grammar,
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
)

DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -42,11 +42,16 @@ def serialize(
exclude=None,
matcher=None,
):
if isinstance(data, Response):
data = data.dict()
if (
isinstance(data, Response)
or isinstance(data, ChatComplete)
or isinstance(data, ChatCompletionChunk)
or isinstance(data, ChatCompletionComplete)
):
data = data.model_dump()

if isinstance(data, List):
data = [d.dict() for d in data]
data = [d.model_dump() for d in data]

data = self._filter(
data=data, depth=0, path=(), exclude=exclude, matcher=matcher
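
The `serialize` helper also switches from pydantic v1's `.dict()` to v2's `model_dump()`, and it now recognizes the chat-completion response types in addition to `Response`. A tiny sketch of the renamed methods, using a hypothetical model rather than one from the client:

```python
from pydantic import BaseModel


class Token(BaseModel):
    # Hypothetical model, used only to show the renamed serialization methods.
    id: int
    text: str


tok = Token(id=0, text="hello")

# The pydantic v1 spelling (`tok.dict()`) is deprecated in v2; the updated
# conftest.py uses the v2 replacement instead:
print(tok.model_dump())       # {'id': 0, 'text': 'hello'}
print(tok.model_dump_json())  # '{"id":0,"text":"hello"}'
```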
@@ -13,7 +13,7 @@
"usage": null
}
],
"created": 1708957015,
"created": 1710795556,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
@@ -8,24 +8,26 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079417,
"created": 1710795556,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
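
The snapshot above reflects the tool-type fix from the PR title: `tool_calls` changes from a single object to a list, matching OpenAI's chat-completion schema where an assistant message carries an array of tool calls. A hypothetical consumer loop over a payload shaped like the snapshot (not produced by a live server):

```python
import json

# Stand-in for one choice's message from a chat completion that used tools.
message = json.loads("""
{
  "role": "assistant",
  "content": null,
  "tool_calls": [
    {
      "id": 0,
      "type": "function",
      "function": {
        "name": "tools",
        "description": null,
        "parameters": {"format": "celsius", "location": "New York, NY", "num_days": 14}
      }
    }
  ]
}
""")

# Because tool_calls is now a list, clients iterate instead of reading a single object.
for call in message["tool_calls"] or []:
    print(call["function"]["name"], call["function"]["parameters"])
```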
@@ -8,24 +8,26 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079492,
"created": 1710795557,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
@@ -8,23 +8,25 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079493,
"created": 1710795557,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
@@ -10,16 +10,16 @@
"name": null
},
"id": "",
"index": 20,
"index": 0,
"type": "function"
}
},
"finish_reason": "eos_token",
"index": 20,
"index": 0,
"logprobs": null
}
],
"created": 1709087088,
"created": 1710795499,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",