openai: embeddings: supported chunk_size when check_embedding_ctx_length is disabled #23767

Merged: 5 commits, Sep 20, 2024
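For context on the diff below: with check_embedding_ctx_length disabled, embed_documents previously sent one API request per input text and ignored chunk_size entirely; this change batches the inputs, one request per chunk. A minimal usage sketch of the fixed behavior (the model name, corpus size, and chunk size are illustrative, and running it requires OpenAI credentials):

from langchain_openai import OpenAIEmbeddings

# With check_embedding_ctx_length=False, the token-length check is skipped;
# after this change the inputs are still sent in batches of `chunk_size`
# texts per request instead of one request per text.
embedder = OpenAIEmbeddings(
    model="text-embedding-3-small",  # illustrative model choice
    check_embedding_ctx_length=False,
    chunk_size=1000,  # the field's default, shown here for clarity
)

documents = [f"document {i}" for i in range(2500)]
vectors = embedder.embed_documents(documents)  # 3 requests: 1000 + 1000 + 500
assert len(vectors) == 2500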
libs/partners/openai/langchain_openai/embeddings/base.py (22 changes: 13 additions & 9 deletions)
@@ -254,14 +254,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     retry_max_seconds: int = 20
     """Max number of seconds to wait between retries"""
     http_client: Union[Any, None] = None
-    """Optional httpx.Client. Only used for sync invocations. Must specify 
+    """Optional httpx.Client. Only used for sync invocations. Must specify
         http_async_client as well if you'd like a custom client for async invocations.
     """
     http_async_client: Union[Any, None] = None
-    """Optional httpx.AsyncClient. Only used for async invocations. Must specify 
+    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
         http_client as well if you'd like a custom client for sync invocations."""
     check_embedding_ctx_length: bool = True
-    """Whether to check the token length of inputs and automatically split inputs 
+    """Whether to check the token length of inputs and automatically split inputs
         longer than embedding_ctx_length."""
 
     model_config = ConfigDict(
@@ -558,7 +558,7 @@ async def empty_embedding() -> List[float]:
         return [e if e is not None else await empty_embedding() for e in embeddings]
 
     def embed_documents(
-        self, texts: List[str], chunk_size: Optional[int] = 0
+        self, texts: List[str], chunk_size: int | None = None
     ) -> List[List[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.
@@ -570,10 +570,13 @@ def embed_documents(
         Returns:
             List of embeddings, one for each text.
         """
+        chunk_size_ = chunk_size or self.chunk_size
         if not self.check_embedding_ctx_length:
             embeddings: List[List[float]] = []
-            for text in texts:
-                response = self.client.create(input=text, **self._invocation_params)
+            for i in range(0, len(texts), chunk_size_):
+                response = self.client.create(
+                    input=texts[i : i + chunk_size_], **self._invocation_params
+                )
                 if not isinstance(response, dict):
                     response = response.dict()
                 embeddings.extend(r["embedding"] for r in response["data"])
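The heart of the fix is this stride-and-slice loop: range() walks the list in steps of the resolved chunk size, and each slice becomes one request body. A self-contained sketch of the same pattern (batch_slices is a hypothetical helper, not a function in the PR):

from typing import Iterator, List

def batch_slices(texts: List[str], chunk_size: int) -> Iterator[List[str]]:
    # Same loop shape as the new embed_documents code: stride by chunk_size,
    # and let the slice pick up the (possibly shorter) final batch.
    for i in range(0, len(texts), chunk_size):
        yield texts[i : i + chunk_size]

sizes = [len(batch) for batch in batch_slices(["x"] * 2500, 1000)]
assert sizes == [1000, 1000, 500]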
@@ -585,7 +588,7 @@ def embed_documents(
         return self._get_len_safe_embeddings(texts, engine=engine)
 
     async def aembed_documents(
-        self, texts: List[str], chunk_size: Optional[int] = 0
+        self, texts: List[str], chunk_size: int | None = None
     ) -> List[List[float]]:
         """Call out to OpenAI's embedding endpoint async for embedding search docs.
@@ -597,11 +600,12 @@ async def aembed_documents(
         Returns:
             List of embeddings, one for each text.
         """
+        chunk_size_ = chunk_size or self.chunk_size
         if not self.check_embedding_ctx_length:
             embeddings: List[List[float]] = []
-            for text in texts:
+            for i in range(0, len(texts), chunk_size_):
                 response = await self.async_client.create(
-                    input=text, **self._invocation_params
+                    input=texts[i : i + chunk_size_], **self._invocation_params
                 )
                 if not isinstance(response, dict):
                     response = response.dict()
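The async hunk above mirrors the sync one. One subtlety in the shared fallback chunk_size_ = chunk_size or self.chunk_size: because it tests truthiness rather than is None, an explicit chunk_size=0 also falls back to the instance default, which is why moving the parameter default from 0 to None does not change behavior for existing callers. A small sketch of that resolution (resolve_chunk_size is a hypothetical helper, not in the PR):

from typing import Optional

def resolve_chunk_size(chunk_size: Optional[int], default: int = 1000) -> int:
    # Same semantics as `chunk_size or self.chunk_size`: any falsy value
    # (None or 0) falls back to the default.
    return chunk_size or default

assert resolve_chunk_size(None) == 1000
assert resolve_chunk_size(0) == 1000  # 0 is falsy, so it also falls back
assert resolve_chunk_size(256) == 256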
@@ -61,3 +61,12 @@ async def test_langchain_openai_embeddings_equivalent_to_raw_async() -> None:
         .embedding
     )
     assert np.isclose(lc_output, direct_output).all()
+
+
+def test_langchain_openai_embeddings_dimensions_large_num() -> None:
+    """Test openai embeddings."""
+    documents = [f"foo bar {i}" for i in range(2000)]
+    embedding = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=128)
+    output = embedding.embed_documents(documents)
+    assert len(output) == 2000
+    assert len(output[0]) == 128
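The new integration test embeds 2000 documents, twice the default chunk_size of 1000, so a single embed_documents call must fan out across multiple batched requests. An async counterpart would look like this (a hypothetical test, not part of this PR, and it needs OpenAI credentials to run):

import asyncio

from langchain_openai import OpenAIEmbeddings

async def check_async_batching() -> None:
    # Hypothetical mirror of the new sync test, exercising aembed_documents.
    embedding = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=128)
    documents = [f"foo bar {i}" for i in range(2000)]
    output = await embedding.aembed_documents(documents)
    assert len(output) == 2000
    assert len(output[0]) == 128

asyncio.run(check_async_batching())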