Skip to content

Commit

Permalink
partners[openai]: embeddings: supported chunk_size when check_embedding_ctx_length is disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jul 4, 2024
1 parent 0e916d0 commit b470364
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions libs/partners/openai/langchain_openai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
retry_max_seconds: int = 20
"""Max number of seconds to wait between retries"""
http_client: Union[Any, None] = None
"""Optional httpx.Client. Only used for sync invocations. Must specify
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
check_embedding_ctx_length: bool = True
"""Whether to check the token length of inputs and automatically split inputs
"""Whether to check the token length of inputs and automatically split inputs
longer than embedding_ctx_length."""

class Config:
Expand Down Expand Up @@ -525,9 +525,7 @@ def embed_documents(
engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(texts, engine=engine)

async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs.
Args:
Expand All @@ -540,9 +538,9 @@ async def aembed_documents(
"""
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
for i in range(0, len(texts), self.chunk_size):
response = await self.async_client.create(
input=text, **self._invocation_params
input=texts[i : i + self.chunk_size], **self._invocation_params
)
if not isinstance(response, dict):
response = response.dict()
Expand Down

0 comments on commit b470364

Please sign in to comment.