fix(json schema): unravel $refs alongside additional keys
RobertCraigie committed Aug 12, 2024
1 parent 53d964d commit c7a3d29
Showing 3 changed files with 119 additions and 17 deletions.
73 changes: 59 additions & 14 deletions src/openai/lib/_pydantic.py
@@ -10,19 +10,32 @@


 def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    return _ensure_strict_json_schema(model_json_schema(model), path=())
+    schema = model_json_schema(model)
+    return _ensure_strict_json_schema(schema, path=(), root=schema)
 
 
 def _ensure_strict_json_schema(
     json_schema: object,
     *,
     path: tuple[str, ...],
+    root: dict[str, object],
 ) -> dict[str, Any]:
     """Mutates the given JSON schema to ensure it conforms to the `strict` standard
     that the API expects.
     """
     if not is_dict(json_schema):
         raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
 
+    defs = json_schema.get("$defs")
+    if is_dict(defs):
+        for def_name, def_schema in defs.items():
+            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)
+
+    definitions = json_schema.get("definitions")
+    if is_dict(definitions):
+        for definition_name, definition_schema in definitions.items():
+            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)
+
     typ = json_schema.get("type")
     if typ == "object" and "additionalProperties" not in json_schema:
         json_schema["additionalProperties"] = False
@@ -33,48 +33,80 @@ def _ensure_strict_json_schema(
     if is_dict(properties):
         json_schema["required"] = [prop for prop in properties.keys()]
         json_schema["properties"] = {
-            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key))
+            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
             for key, prop_schema in properties.items()
         }
 
     # arrays
     # { 'type': 'array', 'items': {...} }
     items = json_schema.get("items")
     if is_dict(items):
-        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"))
+        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)
 
     # unions
     any_of = json_schema.get("anyOf")
     if is_list(any_of):
         json_schema["anyOf"] = [
-            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i))) for i, variant in enumerate(any_of)
+            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
+            for i, variant in enumerate(any_of)
         ]
 
     # intersections
     all_of = json_schema.get("allOf")
     if is_list(all_of):
         if len(all_of) == 1:
-            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0")))
+            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
             json_schema.pop("allOf")
         else:
             json_schema["allOf"] = [
-                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of)
+                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
+                for i, entry in enumerate(all_of)
             ]
 
-    defs = json_schema.get("$defs")
-    if is_dict(defs):
-        for def_name, def_schema in defs.items():
-            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name))
+    # we can't use `$ref`s if there are also other properties defined, e.g.
+    # `{"$ref": "...", "description": "my description"}`
+    #
+    # so we unravel the ref
+    # `{"type": "string", "description": "my description"}`
+    ref = json_schema.get("$ref")
+    if ref and has_more_than_n_keys(json_schema, 1):
+        assert isinstance(ref, str), f"Received non-string $ref - {ref}"
 
-    definitions = json_schema.get("definitions")
-    if is_dict(definitions):
-        for definition_name, definition_schema in definitions.items():
-            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name))
+        resolved = resolve_ref(root=root, ref=ref)
+        if not is_dict(resolved):
+            raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")
+
+        # properties from the json schema take priority over the ones on the `$ref`
+        json_schema.update({**resolved, **json_schema})
+        json_schema.pop("$ref")
 
     return json_schema
 
 
+def resolve_ref(*, root: dict[str, object], ref: str) -> object:
+    if not ref.startswith("#/"):
+        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
+
+    path = ref[2:].split("/")
+    resolved = root
+    for key in path:
+        value = resolved[key]
+        assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
+        resolved = value
+
+    return resolved
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
     return _is_dict(obj)
+
+
+def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
+    i = 0
+    for _ in obj.keys():
+        i += 1
+        if i > n:
+            return True
+    return False
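
For context, the effect of the new `$ref` unraveling can be seen with a small sketch (not part of this commit). It assumes pydantic v2 and imports the internal helper directly from `openai.lib._pydantic`; the printed result mirrors the updated expectations in tests/lib/test_pydantic.py below.

# Sketch only; `openai.lib._pydantic` is an internal module and may move.
from enum import Enum

from pydantic import BaseModel, Field
from openai.lib._pydantic import to_strict_json_schema


class Color(Enum):
    RED = "red"
    BLUE = "blue"
    GREEN = "green"


class ColorDetection(BaseModel):
    color: Color = Field(description="The detected color")


schema = to_strict_json_schema(ColorDetection)
# Previously the property kept the `$ref` next to its extra keys, e.g.
#   {"description": "The detected color", "$ref": "#/$defs/Color"}
# which the strict schema format cannot express. With this change the
# referenced definition is resolved via `resolve_ref` and inlined, roughly:
#   {"description": "The detected color", "enum": ["red", "blue", "green"],
#    "title": "Color", "type": "string"}
print(schema["properties"]["color"])

Note that `$defs` and `definitions` are now walked at the top of `_ensure_strict_json_schema`, presumably so that any definition inlined later has already been made strict, and the new `root` argument gives `resolve_ref` the full schema to resolve `#/...` pointers against.
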
50 changes: 49 additions & 1 deletion tests/lib/chat/test_completions.py
@@ -2,13 +2,14 @@

 import os
 import json
+from enum import Enum
 from typing import Any, Callable
 from typing_extensions import Literal, TypeVar
 
 import httpx
 import pytest
 from respx import MockRouter
-from pydantic import BaseModel
+from pydantic import Field, BaseModel
 from inline_snapshot import snapshot
 
 import openai
@@ -133,6 +134,53 @@ class Location(BaseModel):
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+def test_parse_pydantic_model_enum(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    class Color(Enum):
+        """The detected color"""
+
+        RED = "red"
+        BLUE = "blue"
+        GREEN = "green"
+
+    class ColorDetection(BaseModel):
+        color: Color
+        hex_color_code: str = Field(description="The hex color code of the detected color")
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "user", "content": "What color is a Coke can?"},
+            ],
+            response_format=ColorDetection,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9vK4UZVr385F2UgZlP1ShwPn2nFxG", "object": "chat.completion", "created": 1723448878, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"color\\":\\"red\\",\\"hex_color_code\\":\\"#FF0000\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 18, "completion_tokens": 14, "total_tokens": 32}, "system_fingerprint": "fp_845eaabc1f"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion.choices[0], monkeypatch) == snapshot(
+        """\
+ParsedChoice[ColorDetection](
+    finish_reason='stop',
+    index=0,
+    logprobs=None,
+    message=ParsedChatCompletionMessage[ColorDetection](
+        content='{"color":"red","hex_color_code":"#FF0000"}',
+        function_call=None,
+        parsed=ColorDetection(color=<Color.RED: 'red'>, hex_color_code='#FF0000'),
+        refusal=None,
+        role='assistant',
+        tool_calls=[]
+    )
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_parse_pydantic_model_multiple_choices(
     client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
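
Outside the test harness, the new enum test corresponds to ordinary client usage along these lines (a sketch, not part of the commit; it assumes a configured `OpenAI` client with an API key available):

from enum import Enum

from openai import OpenAI
from pydantic import BaseModel, Field


class Color(Enum):
    """The detected color"""

    RED = "red"
    BLUE = "blue"
    GREEN = "green"


class ColorDetection(BaseModel):
    color: Color
    hex_color_code: str = Field(description="The hex color code of the detected color")


client = OpenAI()
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What color is a Coke can?"}],
    response_format=ColorDetection,
)
# The enum field exercises the `$ref` handling fixed in this commit; `parsed`
# is an instance of ColorDetection, e.g. color=<Color.RED: 'red'>.
print(completion.choices[0].message.parsed)
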
13 changes: 11 additions & 2 deletions tests/lib/test_pydantic.py
@@ -186,7 +186,12 @@ def test_enums() -> None:
"parameters": {
"$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
"properties": {
"color": {"description": "The detected color", "$ref": "#/$defs/Color"},
"color": {
"description": "The detected color",
"enum": ["red", "blue", "green"],
"title": "Color",
"type": "string",
},
"hex_color_code": {
"description": "The hex color code of the detected color",
"title": "Hex Color Code",
@@ -207,7 +212,11 @@ def test_enums() -> None:
"strict": True,
"parameters": {
"properties": {
"color": {"description": "The detected color", "$ref": "#/definitions/Color"},
"color": {
"description": "The detected color",
"title": "Color",
"enum": ["red", "blue", "green"],
},
"hex_color_code": {
"description": "The hex color code of the detected color",
"title": "Hex Color Code",
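
The `test_enums` changes above cover the function-tools path; a self-contained sketch of that path (again not part of the commit, and the `Field` descriptions here are illustrative assumptions) looks like:

# Sketch only (not part of this commit); `pydantic_function_tool` is the public
# helper that tests/lib/test_pydantic.py exercises.
from enum import Enum

import openai
from pydantic import BaseModel, Field


class Color(Enum):
    RED = "red"
    BLUE = "blue"
    GREEN = "green"


class ColorDetection(BaseModel):
    color: Color = Field(description="The detected color")
    hex_color_code: str = Field(description="The hex color code of the detected color")


tool = openai.pydantic_function_tool(ColorDetection)
# With this commit, the "color" property in the generated parameters inlines the
# enum definition (description, enum, title, type) instead of keeping a `$ref`
# next to the description.
print(tool["function"]["parameters"]["properties"]["color"])
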
