diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 96f4a9b209584..b55fec42afc2e 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -184,14 +184,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): retry_max_seconds: int = 20 """Max number of seconds to wait between retries""" http_client: Union[Any, None] = None - """Optional httpx.Client. Only used for sync invocations. Must specify + """Optional httpx.Client. Only used for sync invocations. Must specify http_async_client as well if you'd like a custom client for async invocations. """ http_async_client: Union[Any, None] = None - """Optional httpx.AsyncClient. Only used for async invocations. Must specify + """Optional httpx.AsyncClient. Only used for async invocations. Must specify http_client as well if you'd like a custom client for sync invocations.""" check_embedding_ctx_length: bool = True - """Whether to check the token length of inputs and automatically split inputs + """Whether to check the token length of inputs and automatically split inputs longer than embedding_ctx_length.""" class Config: @@ -525,9 +525,7 @@ def embed_documents( engine = cast(str, self.deployment) return self._get_len_safe_embeddings(texts, engine=engine) - async def aembed_documents( - self, texts: List[str], chunk_size: Optional[int] = 0 - ) -> List[List[float]]: + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Call out to OpenAI's embedding endpoint async for embedding search docs. Args: @@ -540,9 +538,9 @@ async def aembed_documents( """ if not self.check_embedding_ctx_length: embeddings: List[List[float]] = [] - for text in texts: + for i in range(0, len(texts), self.chunk_size): response = await self.async_client.create( - input=text, **self._invocation_params + input=texts[i : i + self.chunk_size], **self._invocation_params ) if not isinstance(response, dict): response = response.dict()