From c7a3d2986acaf3b31844b39608d03265ad87bb04 Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Mon, 12 Aug 2024 08:48:17 +0100 Subject: [PATCH] fix(json schema): unravel `$ref`s alongside additional keys --- src/openai/lib/_pydantic.py | 73 ++++++++++++++++++++++++------ tests/lib/chat/test_completions.py | 50 +++++++++++++++++++- tests/lib/test_pydantic.py | 13 +++++- 3 files changed, 119 insertions(+), 17 deletions(-) diff --git a/src/openai/lib/_pydantic.py b/src/openai/lib/_pydantic.py index 85f147c236..ad3b6eb29f 100644 --- a/src/openai/lib/_pydantic.py +++ b/src/openai/lib/_pydantic.py @@ -10,12 +10,15 @@ def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]: - return _ensure_strict_json_schema(model_json_schema(model), path=()) + schema = model_json_schema(model) + return _ensure_strict_json_schema(schema, path=(), root=schema) def _ensure_strict_json_schema( json_schema: object, + *, path: tuple[str, ...], + root: dict[str, object], ) -> dict[str, Any]: """Mutates the given JSON schema to ensure it conforms to the `strict` standard that the API expects. 
@@ -23,6 +26,16 @@ def _ensure_strict_json_schema( if not is_dict(json_schema): raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}") + defs = json_schema.get("$defs") + if is_dict(defs): + for def_name, def_schema in defs.items(): + _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root) + + definitions = json_schema.get("definitions") + if is_dict(definitions): + for definition_name, definition_schema in definitions.items(): + _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root) + typ = json_schema.get("type") if typ == "object" and "additionalProperties" not in json_schema: json_schema["additionalProperties"] = False @@ -33,7 +46,7 @@ def _ensure_strict_json_schema( if is_dict(properties): json_schema["required"] = [prop for prop in properties.keys()] json_schema["properties"] = { - key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key)) + key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root) for key, prop_schema in properties.items() } @@ -41,40 +54,72 @@ def _ensure_strict_json_schema( # { 'type': 'array', 'items': {...} } items = json_schema.get("items") if is_dict(items): - json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items")) + json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root) # unions any_of = json_schema.get("anyOf") if is_list(any_of): json_schema["anyOf"] = [ - _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i))) for i, variant in enumerate(any_of) + _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root) + for i, variant in enumerate(any_of) ] # intersections all_of = json_schema.get("allOf") if is_list(all_of): if len(all_of) == 1: - json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"))) + json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", 
"0"), root=root)) json_schema.pop("allOf") else: json_schema["allOf"] = [ - _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of) + _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root) + for i, entry in enumerate(all_of) ] - defs = json_schema.get("$defs") - if is_dict(defs): - for def_name, def_schema in defs.items(): - _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name)) + # we can't use `$ref`s if there are also other properties defined, e.g. + # `{"$ref": "...", "description": "my description"}` + # + # so we unravel the ref + # `{"type": "string", "description": "my description"}` + ref = json_schema.get("$ref") + if ref and has_more_than_n_keys(json_schema, 1): + assert isinstance(ref, str), f"Received non-string $ref - {ref}" - definitions = json_schema.get("definitions") - if is_dict(definitions): - for definition_name, definition_schema in definitions.items(): - _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name)) + resolved = resolve_ref(root=root, ref=ref) + if not is_dict(resolved): + raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}") + + # properties from the json schema take priority over the ones on the `$ref` + json_schema.update({**resolved, **json_schema}) + json_schema.pop("$ref") return json_schema +def resolve_ref(*, root: dict[str, object], ref: str) -> object: + if not ref.startswith("#/"): + raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/") + + path = ref[2:].split("/") + resolved = root + for key in path: + value = resolved[key] + assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}" + resolved = value + + return resolved + + def is_dict(obj: object) -> TypeGuard[dict[str, object]]: # just pretend that we know there are only `str` keys # as that check is not worth the performance cost return _is_dict(obj) + + 
+def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool: + i = 0 + for _ in obj.keys(): + i += 1 + if i > n: + return True + return False diff --git a/tests/lib/chat/test_completions.py b/tests/lib/chat/test_completions.py index e406a5a3bc..d2189e7cb6 100644 --- a/tests/lib/chat/test_completions.py +++ b/tests/lib/chat/test_completions.py @@ -2,13 +2,14 @@ import os import json +from enum import Enum from typing import Any, Callable from typing_extensions import Literal, TypeVar import httpx import pytest from respx import MockRouter -from pydantic import BaseModel +from pydantic import Field, BaseModel from inline_snapshot import snapshot import openai @@ -133,6 +134,53 @@ class Location(BaseModel): ) +@pytest.mark.respx(base_url=base_url) +def test_parse_pydantic_model_enum(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None: + class Color(Enum): + """The detected color""" + + RED = "red" + BLUE = "blue" + GREEN = "green" + + class ColorDetection(BaseModel): + color: Color + hex_color_code: str = Field(description="The hex color code of the detected color") + + completion = _make_snapshot_request( + lambda c: c.beta.chat.completions.parse( + model="gpt-4o-2024-08-06", + messages=[ + {"role": "user", "content": "What color is a Coke can?"}, + ], + response_format=ColorDetection, + ), + content_snapshot=snapshot( + '{"id": "chatcmpl-9vK4UZVr385F2UgZlP1ShwPn2nFxG", "object": "chat.completion", "created": 1723448878, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"color\\":\\"red\\",\\"hex_color_code\\":\\"#FF0000\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 18, "completion_tokens": 14, "total_tokens": 32}, "system_fingerprint": "fp_845eaabc1f"}' + ), + mock_client=client, + respx_mock=respx_mock, + ) + + assert print_obj(completion.choices[0], monkeypatch) == snapshot( + """\ +ParsedChoice[ColorDetection]( + 
finish_reason='stop', + index=0, + logprobs=None, + message=ParsedChatCompletionMessage[ColorDetection]( + content='{"color":"red","hex_color_code":"#FF0000"}', + function_call=None, + parsed=ColorDetection(color=<Color.RED: 'red'>, hex_color_code='#FF0000'), + refusal=None, + role='assistant', + tool_calls=[] + ) +) +""" + ) + + @pytest.mark.respx(base_url=base_url) def test_parse_pydantic_model_multiple_choices( client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch diff --git a/tests/lib/test_pydantic.py b/tests/lib/test_pydantic.py index 568844eada..531a89df58 100644 --- a/tests/lib/test_pydantic.py +++ b/tests/lib/test_pydantic.py @@ -186,7 +186,12 @@ def test_enums() -> None: "parameters": { "$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}}, "properties": { - "color": {"description": "The detected color", "$ref": "#/$defs/Color"}, + "color": { + "description": "The detected color", + "enum": ["red", "blue", "green"], + "title": "Color", + "type": "string", + }, "hex_color_code": { "description": "The hex color code of the detected color", "title": "Hex Color Code", @@ -207,7 +212,11 @@ def test_enums() -> None: "strict": True, "parameters": { "properties": { - "color": {"description": "The detected color", "$ref": "#/definitions/Color"}, + "color": { + "description": "The detected color", + "title": "Color", + "enum": ["red", "blue", "green"], + }, "hex_color_code": { "description": "The hex color code of the detected color", "title": "Hex Color Code",