Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: fix batch race condition in FakeListChatModel #26924

Merged
merged 5 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions libs/core/langchain_core/language_models/fake_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig


class FakeMessagesListChatModel(BaseChatModel):
Expand Down Expand Up @@ -128,6 +129,33 @@ async def _astream(
def _identifying_params(self) -> dict[str, Any]:
return {"responses": self.responses}

# manually override batch to preserve batch ordering with no concurrency
def batch(
self,
inputs: list[Any],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> list[BaseMessage]:
if isinstance(config, list):
return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
return [self.invoke(m, config, **kwargs) for m in inputs]

async def abatch(
self,
inputs: list[Any],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> list[BaseMessage]:
if isinstance(config, list):
# do Not use an async iterator here because need explicit ordering
return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
# do Not use an async iterator here because need explicit ordering
return [await self.ainvoke(m, config, **kwargs) for m in inputs]


class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""
Expand Down
19 changes: 18 additions & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
from langchain_core.language_models import (
FakeListChatModel,
GenericFakeChatModel,
ParrotFakeChatModel,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
Expand Down Expand Up @@ -205,3 +209,16 @@ def test_chat_model_inputs() -> None:
assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
content="blah"
)


def test_fake_list_chat_model_batch() -> None:
expected = [
_any_id_ai_message(content="a"),
_any_id_ai_message(content="b"),
_any_id_ai_message(content="c"),
]
for _ in range(20):
# run this 20 times to test race condition in batch
fake = FakeListChatModel(responses=["a", "b", "c"])
resp = fake.batch(["1", "2", "3"])
assert resp == expected
Loading