Commit 18191da
fix(client): raise helpful error message for response_format misuse
RobertCraigie committed Aug 8, 2024
1 parent 631a2a7 commit 18191da
Showing 2 changed files with 46 additions and 0 deletions.
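
For context, a minimal sketch of the misuse this commit guards against, assuming a configured client and borrowing `MyModel` and the model name from the tests below — the call now fails fast with a TypeError instead of sending an unusable request:

    import pydantic
    from openai import OpenAI

    class MyModel(pydantic.BaseModel):
        a: str

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    # Passing the BaseModel class itself to create() is the misuse being caught;
    # after this commit the client raises a TypeError pointing to the parse() helper.
    client.chat.completions.create(
        messages=[{"role": "system", "content": "string"}],
        model="gpt-4o",
        response_format=MyModel,  # wrong: create() expects a response_format dict, not a model class
    )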
11 changes: 11 additions & 0 deletions src/openai/resources/chat/completions.py
@@ -2,10 +2,12 @@

from __future__ import annotations

import inspect
from typing import Dict, List, Union, Iterable, Optional, overload
from typing_extensions import Literal

import httpx
import pydantic

from ... import _legacy_response
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
@@ -647,6 +649,7 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
validate_response_format(response_format)
return self._post(
"/chat/completions",
body=maybe_transform(
@@ -1302,6 +1305,7 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
validate_response_format(response_format)
return await self._post(
"/chat/completions",
body=await async_maybe_transform(
@@ -1375,3 +1379,10 @@ def __init__(self, completions: AsyncCompletions) -> None:
self.create = async_to_streamed_response_wrapper(
completions.create,
)


def validate_response_format(response_format: object) -> None:
if inspect.isclass(response_format) and issubclass(response_format, pydantic.BaseModel):
raise TypeError(
"You tried to pass a `BaseModel` class to `chat.completions.create()`; You must use `beta.chat.completions.parse()` instead"
)
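
For comparison, a sketch of the supported path named in the error message: the beta `parse()` helper is the one that accepts a `BaseModel` class directly (the client setup, prompt, and `.parsed` access shown here are illustrative assumptions, not part of this diff):

    import pydantic
    from openai import OpenAI

    class MyModel(pydantic.BaseModel):
        a: str

    client = OpenAI()

    # parse() takes the pydantic class as response_format and returns a completion
    # whose message carries the parsed instance.
    completion = client.beta.chat.completions.parse(
        messages=[{"role": "user", "content": "Fill in field `a` with a short string."}],
        model="gpt-4o",
        response_format=MyModel,
    )
    print(completion.choices[0].message.parsed)  # an instance of MyModel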
35 changes: 35 additions & 0 deletions tests/api_resources/chat/test_completions.py
@@ -6,6 +6,7 @@
from typing import Any, cast

import pytest
import pydantic

from openai import OpenAI, AsyncOpenAI
from tests.utils import assert_matches_type
@@ -257,6 +258,23 @@ def test_streaming_response_create_overload_2(self, client: OpenAI) -> None:

assert cast(Any, response.is_closed) is True

@parametrize
def test_method_create_disallows_pydantic(self, client: OpenAI) -> None:
class MyModel(pydantic.BaseModel):
a: str

with pytest.raises(TypeError, match=r"You tried to pass a `BaseModel` class"):
client.chat.completions.create(
messages=[
{
"content": "string",
"role": "system",
}
],
model="gpt-4o",
response_format=cast(Any, MyModel),
)


class TestAsyncCompletions:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -498,3 +516,20 @@ async def test_streaming_response_create_overload_2(self, async_client: AsyncOpenAI) -> None:
await stream.close()

assert cast(Any, response.is_closed) is True

@parametrize
async def test_method_create_disallows_pydantic(self, async_client: AsyncOpenAI) -> None:
class MyModel(pydantic.BaseModel):
a: str

with pytest.raises(TypeError, match=r"You tried to pass a `BaseModel` class"):
await async_client.chat.completions.create(
messages=[
{
"content": "string",
"role": "system",
}
],
model="gpt-4o",
response_format=cast(Any, MyModel),
)
