Skip to content

Commit

Permalink
openai: embeddings: supported chunk_size when check_embedding_ctx_len…
Browse files Browse the repository at this point in the history
…gth is disabled (#23767)

Chunking of the input array controlled by `self.chunk_size` is being
ignored when `self.check_embedding_ctx_length` is disabled. Effectively,
the chunk size is assumed to be equal to 1 in such a case. This is
surprising.

The PR takes into account `self.chunk_size` passed by the user.

---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
adubovik and efriis authored Sep 20, 2024
1 parent 864020e commit 3e2cb4e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
22 changes: 13 additions & 9 deletions libs/partners/openai/langchain_openai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
retry_max_seconds: int = 20
"""Max number of seconds to wait between retries"""
http_client: Union[Any, None] = None
"""Optional httpx.Client. Only used for sync invocations. Must specify
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
check_embedding_ctx_length: bool = True
"""Whether to check the token length of inputs and automatically split inputs
"""Whether to check the token length of inputs and automatically split inputs
longer than embedding_ctx_length."""

model_config = ConfigDict(
Expand Down Expand Up @@ -558,7 +558,7 @@ async def empty_embedding() -> List[float]:
return [e if e is not None else await empty_embedding() for e in embeddings]

def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
Expand All @@ -570,10 +570,13 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
response = self.client.create(input=text, **self._invocation_params)
for i in range(0, len(texts), self.chunk_size):
response = self.client.create(
input=texts[i : i + chunk_size_], **self._invocation_params
)
if not isinstance(response, dict):
response = response.dict()
embeddings.extend(r["embedding"] for r in response["data"])
Expand All @@ -585,7 +588,7 @@ def embed_documents(
return self._get_len_safe_embeddings(texts, engine=engine)

async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
Expand All @@ -597,11 +600,12 @@ async def aembed_documents(
Returns:
List of embeddings, one for each text.
"""
chunk_size_ = chunk_size or self.chunk_size
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
for i in range(0, len(texts), chunk_size_):
response = await self.async_client.create(
input=text, **self._invocation_params
input=texts[i : i + chunk_size_], **self._invocation_params
)
if not isinstance(response, dict):
response = response.dict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,12 @@ async def test_langchain_openai_embeddings_equivalent_to_raw_async() -> None:
.embedding
)
assert np.isclose(lc_output, direct_output).all()


def test_langchain_openai_embeddings_dimensions_large_num() -> None:
    """Embed a large batch of documents with a reduced dimension count.

    Sends 2000 short documents through ``embed_documents`` with
    ``dimensions=128`` and checks that one embedding is returned per
    document and that the first embedding has the requested width.
    NOTE(review): this is a live-API integration test — it performs real
    network calls to OpenAI.
    """
    texts = [f"foo bar {idx}" for idx in range(2000)]
    embedder = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=128)
    vectors = embedder.embed_documents(texts)
    assert len(vectors) == 2000
    assert len(vectors[0]) == 128

0 comments on commit 3e2cb4e

Please sign in to comment.