Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve tool type, bump pydantic and outlines #1650

Merged
merged 8 commits on Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

225 changes: 113 additions & 112 deletions clients/python/poetry.lock

Large diffs are not rendered by default.

46 changes: 23 additions & 23 deletions clients/python/text_generation/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from typing import Optional, List, Union, Any

from text_generation.errors import ValidationError
Expand Down Expand Up @@ -32,7 +32,7 @@ class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str]
content: Optional[str] = None
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
Expand All @@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Any
usage: Optional[Any] = None


class Function(BaseModel):
Expand All @@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):

class ChoiceDelta(BaseModel):
role: str
content: Optional[str]
content: Optional[str] = None
tool_calls: Optional[ChoiceDeltaToolCall]


Expand Down Expand Up @@ -176,74 +176,74 @@ class Parameters(BaseModel):
# grammar to use for generation
grammar: Optional[Grammar] = None

@validator("best_of")
@field_validator("best_of")
def valid_best_of(cls, field_value, values):
if field_value is not None:
if field_value <= 0:
raise ValidationError("`best_of` must be strictly positive")
if field_value > 1 and values["seed"] is not None:
if field_value > 1 and values.data["seed"] is not None:
raise ValidationError("`seed` must not be set when `best_of` is > 1")
sampling = (
values["do_sample"]
| (values["temperature"] is not None)
| (values["top_k"] is not None)
| (values["top_p"] is not None)
| (values["typical_p"] is not None)
values.data["do_sample"]
| (values.data["temperature"] is not None)
| (values.data["top_k"] is not None)
| (values.data["top_p"] is not None)
| (values.data["typical_p"] is not None)
)
if field_value > 1 and not sampling:
raise ValidationError("you must use sampling when `best_of` is > 1")

return field_value

@validator("repetition_penalty")
@field_validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v

@validator("seed")
@field_validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0:
raise ValidationError("`seed` must be positive")
return v

@validator("temperature")
@field_validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive")
return v

@validator("top_k")
@field_validator("top_k")
def valid_top_k(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive")
return v

@validator("top_p")
@field_validator("top_p")
def valid_top_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
return v

@validator("truncate")
@field_validator("truncate")
def valid_truncate(cls, v):
if v is not None and v <= 0:
raise ValidationError("`truncate` must be strictly positive")
return v

@validator("typical_p")
@field_validator("typical_p")
def valid_typical_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v

@validator("top_n_tokens")
@field_validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v

@validator("grammar")
@field_validator("grammar")
def valid_grammar(cls, v):
if v is not None:
if v.type == GrammarType.Regex and not v.value:
Expand All @@ -261,15 +261,15 @@ class Request(BaseModel):
# Whether to stream output tokens
stream: bool = False

@validator("inputs")
@field_validator("inputs")
def valid_input(cls, v):
if not v:
raise ValidationError("`inputs` cannot be empty")
return v

@validator("stream")
@field_validator("stream")
def valid_best_of_stream(cls, field_value, values):
parameters = values["parameters"]
parameters = values.data["parameters"]
if (
parameters is not None
and parameters.best_of is not None
Expand Down
12 changes: 9 additions & 3 deletions integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Grammar,
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
)

DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
Expand All @@ -42,11 +43,16 @@ def serialize(
exclude=None,
matcher=None,
):
if isinstance(data, Response):
data = data.dict()
if (
isinstance(data, Response)
or isinstance(data, ChatComplete)
or isinstance(data, ChatCompletionChunk)
or isinstance(data, ChatCompletionComplete)
):
data = data.model_dump()

if isinstance(data, List):
data = [d.dict() for d in data]
data = [d.model_dump() for d in data]

data = self._filter(
data=data, depth=0, path=(), exclude=exclude, matcher=matcher
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"usage": null
}
],
"created": 1708957015,
"created": 1710795556,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079417,
"created": 1710795556,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079492,
"created": 1710795557,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,25 @@
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1709079493,
"created": 1710795557,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
"name": null
},
"id": "",
"index": 20,
"index": 0,
"type": "function"
}
},
"finish_reason": "eos_token",
"index": 20,
"index": 0,
"logprobs": null
}
],
"created": 1709087088,
"created": 1710795499,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
Expand Down
Loading
Loading