Skip to content

Commit

Permalink
add default async (#11141)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Oct 4, 2023
1 parent 88c5349 commit 106608b
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 48 deletions.
11 changes: 3 additions & 8 deletions libs/langchain/langchain/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union

Expand Down Expand Up @@ -97,12 +96,6 @@ async def ainvoke(
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if type(self)._acall == Chain._acall:
# If the chain does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)

config = config or {}
return await self.acall(
input,
Expand Down Expand Up @@ -246,7 +239,9 @@ async def _acall(
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
raise NotImplementedError("Async call not supported for this chain type.")
return await asyncio.get_running_loop().run_in_executor(
None, self._call, inputs, run_manager
)

def __call__(
self,
Expand Down
5 changes: 1 addition & 4 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,7 @@ async def _agenerate(
) -> ChatResult:
"""Top Level call"""
return await asyncio.get_running_loop().run_in_executor(
None,
partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
),
None, partial(self._generate, **kwargs), messages, stop, run_manager
)

def _stream(
Expand Down
27 changes: 6 additions & 21 deletions libs/langchain/langchain/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,6 @@ async def ainvoke(
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async invoke, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)

config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
Expand Down Expand Up @@ -319,13 +313,6 @@ async def abatch(
) -> List[str]:
if not inputs:
return []

if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.batch, **kwargs), inputs, config
)

config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency")

Expand Down Expand Up @@ -478,7 +465,9 @@ async def _agenerate(
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompts."""
raise NotImplementedError()
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), prompts, stop, run_manager
)

def _stream(
self,
Expand Down Expand Up @@ -1035,7 +1024,9 @@ async def _acall(
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
raise NotImplementedError()
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._call, **kwargs), prompt, stop, run_manager
)

def _generate(
self,
Expand Down Expand Up @@ -1064,12 +1055,6 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
if type(self)._acall == LLM._acall:
# model doesn't implement async call, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, prompts, stop, run_manager, **kwargs)
)

"""Run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
Expand Down
6 changes: 5 additions & 1 deletion libs/langchain/langchain/schema/document.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Sequence

from langchain.load.serializable import Serializable
Expand Down Expand Up @@ -72,7 +74,6 @@ def transform_documents(
A list of transformed Documents.
"""

@abstractmethod
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
Expand All @@ -84,3 +85,6 @@ async def atransform_documents(
Returns:
A list of transformed Documents.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)
9 changes: 7 additions & 2 deletions libs/langchain/langchain/schema/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List

Expand All @@ -15,8 +16,12 @@ def embed_query(self, text: str) -> List[float]:

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)

async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)
10 changes: 5 additions & 5 deletions libs/langchain/langchain/schema/retriever.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
import warnings
from abc import ABC, abstractmethod
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional

Expand Down Expand Up @@ -121,10 +123,6 @@ async def ainvoke(
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> List[Document]:
if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)

config = config or {}
return await self.aget_relevant_documents(
input,
Expand Down Expand Up @@ -156,7 +154,9 @@ async def _aget_relevant_documents(
Returns:
List of relevant documents
"""
raise NotImplementedError()
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._get_relevant_documents, run_manager=run_manager), query
)

def get_relevant_documents(
self,
Expand Down
8 changes: 6 additions & 2 deletions libs/langchain/langchain/schema/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ async def aadd_texts(
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)

def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
Expand Down Expand Up @@ -451,7 +453,9 @@ async def afrom_texts(
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
)

def _get_retriever_tags(self) -> List[str]:
"""Get tags for retriever."""
Expand Down
6 changes: 5 additions & 1 deletion libs/langchain/langchain/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@

from __future__ import annotations

import asyncio
import copy
import logging
import pathlib
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import partial
from io import BytesIO, StringIO
from typing import (
AbstractSet,
Expand Down Expand Up @@ -284,7 +286,9 @@ async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them."""
raise NotImplementedError
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)


class CharacterTextSplitter(TextSplitter):
Expand Down
4 changes: 0 additions & 4 deletions libs/langchain/langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,6 @@ async def ainvoke(
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
if type(self)._arun == BaseTool._arun:
# If the tool does not implement async, fall back to default implementation
return await super().ainvoke(input, config, **kwargs)

config = config or {}
return await self.arun(
input,
Expand Down

0 comments on commit 106608b

Please sign in to comment.