From 294d98c8f89b412287d6d021b438627fee153382 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 14 Aug 2024 02:03:02 +0200 Subject: [PATCH] core: Add N(naming) ruff rules --- .../langchain_core/_api/beta_decorator.py | 4 +- libs/core/langchain_core/_api/deprecation.py | 4 +- libs/core/langchain_core/exceptions.py | 4 +- .../langchain_core/language_models/base.py | 3 +- .../language_models/chat_models.py | 3 +- .../langchain_core/language_models/llms.py | 2 + .../langchain_core/output_parsers/base.py | 6 +- .../langchain_core/output_parsers/pydantic.py | 2 + .../core/langchain_core/output_parsers/xml.py | 12 ++-- libs/core/langchain_core/prompts/base.py | 2 + libs/core/langchain_core/runnables/base.py | 17 +++++- .../langchain_core/runnables/configurable.py | 4 ++ .../langchain_core/runnables/fallbacks.py | 4 ++ .../langchain_core/runnables/graph_ascii.py | 24 ++++---- .../langchain_core/runnables/passthrough.py | 4 ++ libs/core/langchain_core/runnables/utils.py | 13 ++++- libs/core/langchain_core/tools.py | 2 +- .../langchain_core/tracers/langchain_v1.py | 2 +- libs/core/langchain_core/tracers/schemas.py | 2 +- libs/core/langchain_core/utils/aiter.py | 2 +- .../core/langchain_core/vectorstores/utils.py | 32 +++++----- libs/core/pyproject.toml | 10 +++- .../unit_tests/fake/test_fake_chat_model.py | 50 ++++++++-------- .../language_models/chat_models/test_base.py | 18 +++--- .../tests/unit_tests/prompts/test_loading.py | 8 +-- .../tests/unit_tests/runnables/test_graph.py | 5 ++ .../unit_tests/runnables/test_runnable.py | 36 ++++++------ .../runnables/test_runnable_events_v1.py | 38 ++++++------ .../runnables/test_runnable_events_v2.py | 34 +++++------ libs/core/tests/unit_tests/stubs.py | 8 +-- libs/core/tests/unit_tests/test_messages.py | 26 ++++----- libs/core/tests/unit_tests/test_tools.py | 10 ++-- .../unit_tests/utils/test_function_calling.py | 58 +++++++++---------- .../unit_tests/vectorstores/test_in_memory.py | 16 ++--- .../vectorstores/test_vectorstore.py | 1 + 35 files changed, 266 insertions(+), 200 deletions(-) diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index 9ad1cd545f043..10b4e83ec1691 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -154,7 +154,7 @@ def warn_if_direct_instance( _name = _name or obj.fget.__qualname__ old_doc = obj.__doc__ - class _beta_property(property): + class _BetaProperty(property): """A beta property.""" def __init__(self, fget=None, fset=None, fdel=None, doc=None): @@ -185,7 +185,7 @@ def __set_name__(self, owner, set_name): def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: """Finalize the property.""" - return _beta_property( + return _BetaProperty( fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc ) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index b48215d0d6ca1..05360f108b9c3 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -236,7 +236,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: _name = _name or cast(Union[Type, Callable], obj.fget).__qualname__ old_doc = obj.__doc__ - class _deprecated_property(property): + class _DeprecatedProperty(property): """A deprecated property.""" def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def] @@ -269,7 +269,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> 
T: """Finalize the property.""" return cast( T, - _deprecated_property( + _DeprecatedProperty( fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc ), ) diff --git a/libs/core/langchain_core/exceptions.py b/libs/core/langchain_core/exceptions.py index b0d46ffae399f..f8cfde5a2a946 100644 --- a/libs/core/langchain_core/exceptions.py +++ b/libs/core/langchain_core/exceptions.py @@ -3,7 +3,7 @@ from typing import Any, Optional -class LangChainException(Exception): +class LangChainException(Exception): # noqa: N818 """General LangChain exception.""" @@ -11,7 +11,7 @@ class TracerException(LangChainException): """Base class for exceptions in tracers module.""" -class OutputParserException(ValueError, LangChainException): +class OutputParserException(ValueError, LangChainException): # noqa: N818 """Exception that output parsers should raise to signify a parsing error. This exists to differentiate parsing errors from other code or execution errors diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 262a3b4e54883..d76d06ccff569 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -17,7 +17,7 @@ Union, ) -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, override from langchain_core._api import deprecated from langchain_core.messages import ( @@ -126,6 +126,7 @@ def set_verbose(cls, verbose: Optional[bool]) -> bool: return verbose @property + @override def InputType(self) -> TypeAlias: """Get the input type for this runnable.""" from langchain_core.prompt_values import ( diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index a8615faae3507..abe5d0cdc27a2 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -23,7 +23,7 @@ cast, ) -from typing_extensions import TypedDict +from typing_extensions import TypedDict, override from langchain_core._api import deprecated from langchain_core.caches import BaseCache @@ -260,6 +260,7 @@ class Config: # --- Runnable methods --- @property + @override def OutputType(self) -> Any: """Get the output type for this runnable.""" return AnyMessage diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 03ae2be5e2381..f5d322b4ce768 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -36,6 +36,7 @@ stop_after_attempt, wait_exponential, ) +from typing_extensions import override from langchain_core._api import deprecated from langchain_core.caches import BaseCache @@ -314,6 +315,7 @@ def raise_deprecation(cls, values: Dict) -> Dict: # --- Runnable methods --- @property + @override def OutputType(self) -> Type[str]: """Get the input type for this runnable.""" return str diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index caac385b6bd87..c6a693d3f69e3 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -13,7 +13,7 @@ Union, ) -from typing_extensions import get_args +from typing_extensions import get_args, override from langchain_core.language_models import LanguageModelOutput from langchain_core.messages import AnyMessage, BaseMessage @@ -68,11 +68,13 @@ class BaseGenerationOutputParser( """Base 
class to parse the output of an LLM call.""" @property + @override def InputType(self) -> Any: """Return the input type for the parser.""" return Union[str, AnyMessage] @property + @override def OutputType(self) -> Type[T]: """Return the output type for the parser.""" # even though mypy complains this isn't valid, @@ -153,11 +155,13 @@ def _type(self) -> str: """ # noqa: E501 @property + @override def InputType(self) -> Any: """Return the input type for the parser.""" return Union[str, AnyMessage] @property + @override def OutputType(self) -> Type[T]: """Return the output type for the parser. diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index b48dca9d28a65..8dfcf7ca442e9 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -2,6 +2,7 @@ from typing import Generic, List, Type import pydantic # pydantic: ignore +from typing_extensions import override from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser @@ -101,6 +102,7 @@ def _type(self) -> str: return "pydantic" @property + @override def OutputType(self) -> Type[TBaseModel]: """Return the pydantic model.""" return self.pydantic_object diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 238c7f6d14b46..d872c96a6cf6d 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,6 +1,6 @@ import re import xml -import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET # noqa: N817 from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union from xml.etree.ElementTree import TreeBuilder @@ -45,14 +45,14 @@ def __init__(self, parser: Literal["defusedxml", "xml"]) -> None: """ if parser == "defusedxml": try: - from defusedxml import ElementTree as DET # type: ignore + import defusedxml.ElementTree # type: ignore except ImportError: raise ImportError( "defusedxml is not installed. " "Please install it to use the defusedxml parser." "You can install it with `pip install defusedxml` " ) - _parser = DET.DefusedXMLParser(target=TreeBuilder()) + _parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder()) else: _parser = None self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) @@ -188,7 +188,7 @@ def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: # likely if you're reading this you can move them to the top of the file if self.parser == "defusedxml": try: - from defusedxml import ElementTree as DET # type: ignore + import defusedxml.ElementTree # type: ignore except ImportError: raise ImportError( "defusedxml is not installed. 
" @@ -196,9 +196,9 @@ def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: "You can install it with `pip install defusedxml`" "See https://github.com/tiran/defusedxml for more details" ) - _ET = DET # Use the defusedxml parser + _et = defusedxml.ElementTree # Use the defusedxml parser else: - _ET = ET # Use the standard library parser + _et = ET # Use the standard library parser match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 8ba8e5607116c..de31eac093824 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -18,6 +18,7 @@ ) import yaml +from typing_extensions import override from langchain_core.output_parsers.base import BaseOutputParser from langchain_core.prompt_values import ( @@ -103,6 +104,7 @@ class Config: arbitrary_types_allowed = True @property + @override def OutputType(self) -> Any: """Return the output type of the prompt.""" return Union[StringPromptValue, ChatPromptValueConcrete] diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 619d86368d705..e5bc527cf1438 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -35,7 +35,7 @@ overload, ) -from typing_extensions import Literal, get_args +from typing_extensions import Literal, get_args, override from langchain_core._api import beta_decorator from langchain_core.load.dump import dumpd @@ -253,7 +253,7 @@ def get_name( return name @property - def InputType(self) -> Type[Input]: + def InputType(self) -> Type[Input]: # noqa: N802 """The type of input this Runnable accepts specified as a type annotation.""" for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] type_args = get_args(cls) @@ -266,7 +266,7 @@ def InputType(self) -> Type[Input]: ) @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> Type[Output]: # noqa: N802 """The type of output this Runnable produces specified as a type annotation.""" for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] type_args = get_args(cls) @@ -2669,11 +2669,13 @@ class Config: arbitrary_types_allowed = True @property + @override def InputType(self) -> Type[Input]: """The type of the input to the Runnable.""" return self.first.InputType @property + @override def OutputType(self) -> Type[Output]: """The type of the output of the Runnable.""" return self.last.OutputType @@ -3422,6 +3424,7 @@ def get_name( return super().get_name(suffix, name=name) @property + @override def InputType(self) -> Any: """The type of the input to the Runnable.""" for step in self.steps__.values(): @@ -3915,6 +3918,7 @@ def __init__( pass @property + @override def InputType(self) -> Any: func = getattr(self, "_transform", None) or getattr(self, "_atransform") try: @@ -3928,6 +3932,7 @@ def InputType(self) -> Any: return Any @property + @override def OutputType(self) -> Any: func = getattr(self, "_transform", None) or getattr(self, "_atransform") try: @@ -4151,6 +4156,7 @@ def __init__( pass @property + @override def InputType(self) -> Any: """The type of the input to this Runnable.""" func = getattr(self, "func", None) or getattr(self, "afunc") @@ -4207,6 +4213,7 @@ def get_input_schema( return super().get_input_schema(config) @property + @override def OutputType(self) -> Any: """The type of the output of this Runnable as a type annotation. 
@@ -4734,6 +4741,7 @@ class Config: arbitrary_types_allowed = True @property + @override def InputType(self) -> Any: return List[self.bound.InputType] # type: ignore[name-defined] @@ -4749,6 +4757,7 @@ def get_input_schema( ) @property + @override def OutputType(self) -> Type[List[Output]]: return List[self.bound.OutputType] # type: ignore[name-defined] @@ -5036,6 +5045,7 @@ def get_name( return self.bound.get_name(suffix, name=name) @property + @override def InputType(self) -> Type[Input]: return ( cast(Type[Input], self.custom_input_type) @@ -5044,6 +5054,7 @@ def InputType(self) -> Type[Input]: @property + @override def OutputType(self) -> Type[Output]: return ( cast(Type[Output], self.custom_output_type) diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index cf7e9ce1dc7bc..97b59599fa15c 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -20,6 +20,8 @@ ) from weakref import WeakValueDictionary +from typing_extensions import override + from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( @@ -71,10 +73,12 @@ def get_lc_namespace(cls) -> List[str]: return ["langchain", "schema", "runnable"] @property + @override def InputType(self) -> Type[Input]: return self.default.InputType @property + @override def OutputType(self) -> Type[Output]: return self.default.OutputType diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index b3249b47cc42e..3ea7048a56234 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -18,6 +18,8 @@ cast, ) +from typing_extensions import override + from langchain_core.load.dump import dumpd from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import Runnable, RunnableSerializable @@ -111,10 +113,12 @@ class Config: arbitrary_types_allowed = True @property + @override def InputType(self) -> Type[Input]: return self.runnable.InputType @property + @override def OutputType(self) -> Type[Output]: return self.runnable.OutputType diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index 46677213f81c3..7e9225c275d99 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -245,27 +245,27 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str: # NOTE: coordinates might be negative, so we need to shift # everything to the positive plane before we actually draw it. 
- Xs = [] - Ys = [] + xlist = [] + ylist = [] sug = _build_sugiyama_layout(vertices, edges) for vertex in sug.g.sV: # NOTE: moving boxes w/2 to the left - Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0) - Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0) - Ys.append(vertex.view.xy[1]) - Ys.append(vertex.view.xy[1] + vertex.view.h) + xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0) + xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0) + ylist.append(vertex.view.xy[1]) + ylist.append(vertex.view.xy[1] + vertex.view.h) for edge in sug.g.sE: for x, y in edge.view._pts: - Xs.append(x) - Ys.append(y) + xlist.append(x) + ylist.append(y) - minx = min(Xs) - miny = min(Ys) - maxx = max(Xs) - maxy = max(Ys) + minx = min(xlist) + miny = min(ylist) + maxx = max(xlist) + maxy = max(ylist) canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1 canvas_lines = int(round(maxy - miny)) diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index fe89b1f933362..8f878dcbf5259 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -21,6 +21,8 @@ cast, ) +from typing_extensions import override + from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import ( Other, @@ -198,10 +200,12 @@ def get_lc_namespace(cls) -> List[str]: return ["langchain", "schema", "runnable"] @property + @override def InputType(self) -> Any: return self.input_type or Any @property + @override def OutputType(self) -> Any: return self.input_type or Any diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 1d71f47d1bd14..8aeeeff3466e3 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -30,7 +30,7 @@ Union, ) -from typing_extensions import TypeGuard +from typing_extensions import TypeGuard, override from langchain_core.pydantic_v1 import BaseConfig, BaseModel from langchain_core.pydantic_v1 import create_model as _create_model_base @@ -136,6 +136,7 @@ def __init__(self, name: str, keys: Set[str]) -> None: self.name = name self.keys = keys + @override def visit_Subscript(self, node: ast.Subscript) -> Any: """Visit a subscript node. @@ -155,6 +156,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: # we've found a subscript access on the name we're looking for self.keys.add(node.slice.value) + @override def visit_Call(self, node: ast.Call) -> Any: """Visit a call node. @@ -183,6 +185,7 @@ class IsFunctionArgDict(ast.NodeVisitor): def __init__(self) -> None: self.keys: Set[str] = set() + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. @@ -197,6 +200,7 @@ def visit_Lambda(self, node: ast.Lambda) -> Any: input_arg_name = node.args.args[0].arg IsLocalDict(input_arg_name, self.keys).visit(node.body) + @override def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: """Visit a function definition. @@ -211,6 +215,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: input_arg_name = node.args.args[0].arg IsLocalDict(input_arg_name, self.keys).visit(node) + @override def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: """Visit an async function definition. @@ -233,6 +238,7 @@ def __init__(self) -> None: self.loads: Set[str] = set() self.stores: Set[str] = set() + @override def visit_Name(self, node: ast.Name) -> Any: """Visit a name node. 
@@ -247,6 +253,7 @@ def visit_Name(self, node: ast.Name) -> Any: elif isinstance(node.ctx, ast.Store): self.stores.add(node.id) + @override def visit_Attribute(self, node: ast.Attribute) -> Any: """Visit an attribute node. @@ -273,6 +280,7 @@ class FunctionNonLocals(ast.NodeVisitor): def __init__(self) -> None: self.nonlocals: Set[str] = set() + @override def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: """Visit a function definition. @@ -286,6 +294,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: visitor.visit(node) self.nonlocals.update(visitor.loads - visitor.stores) + @override def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: """Visit an async function definition. @@ -299,6 +308,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: visitor.visit(node) self.nonlocals.update(visitor.loads - visitor.stores) + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. @@ -321,6 +331,7 @@ def __init__(self) -> None: self.source: Optional[str] = None self.count = 0 + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 0dfe3f877e9f8..358b18b8bc384 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -263,7 +263,7 @@ def create_schema_from_function( ) -class ToolException(Exception): +class ToolException(Exception): # noqa: N818 """Optional exception that tool throws when execution error occurs. When this exception is thrown, the agent will not stop working, diff --git a/libs/core/langchain_core/tracers/langchain_v1.py b/libs/core/langchain_core/tracers/langchain_v1.py index bf1237d66abbe..ea1c882ea67da 100644 --- a/libs/core/langchain_core/tracers/langchain_v1.py +++ b/libs/core/langchain_core/tracers/langchain_v1.py @@ -9,7 +9,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any: ) -def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: +def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802 """Throw an error because this has been replaced by LangChainTracer.""" raise RuntimeError( "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead." diff --git a/libs/core/langchain_core/tracers/schemas.py b/libs/core/langchain_core/tracers/schemas.py index dd8982eabd4a6..c169aefeb01ae 100644 --- a/libs/core/langchain_core/tracers/schemas.py +++ b/libs/core/langchain_core/tracers/schemas.py @@ -16,7 +16,7 @@ @deprecated("0.1.0", alternative="Use string instead.", removal="0.3.0") -def RunTypeEnum() -> Type[RunTypeEnumDep]: +def RunTypeEnum() -> Type[RunTypeEnumDep]: # noqa: N802 """RunTypeEnum.""" warnings.warn( "RunTypeEnum is deprecated. Please directly use a string instead" diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index 0748e272da84d..e84b38ac02af0 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -238,7 +238,7 @@ async def aclose(self) -> None: atee = Tee -class aclosing(AbstractAsyncContextManager): +class aclosing(AbstractAsyncContextManager): # noqa: N801 """Async context manager for safely finalizing an asynchronously cleaned-up resource such as an async generator, calling its ``aclose()`` method. 
diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 5bcf756747bf0..dc0ba871d2878 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -17,12 +17,12 @@ logger = logging.getLogger(__name__) -def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: +def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: """Row-wise cosine similarity between two equal-width matrices. Args: - X: A matrix of shape (n, m). - Y: A matrix of shape (k, m). + x: A matrix of shape (n, m). + y: A matrix of shape (k, m). Returns: A matrix of shape (n, k) where each element (i, j) is the cosine similarity @@ -40,33 +40,33 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: "Please install numpy with `pip install numpy`." ) - if len(X) == 0 or len(Y) == 0: + if len(x) == 0 or len(y) == 0: return np.array([]) - X = np.array(X) - Y = np.array(Y) - if X.shape[1] != Y.shape[1]: + x = np.array(x) + y = np.array(y) + if x.shape[1] != y.shape[1]: raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}." + f"Number of columns in X and Y must be the same. X has shape {x.shape} " + f"and Y has shape {y.shape}." ) try: import simsimd as simd # type: ignore - X = np.array(X, dtype=np.float32) - Y = np.array(Y, dtype=np.float32) - Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) - return Z + x = np.array(x, dtype=np.float32) + y = np.array(y, dtype=np.float32) + z = 1 - np.array(simd.cdist(x, y, metric="cosine")) + return z except ImportError: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want " "to use simsimd please install with `pip install simsimd`." ) - X_norm = np.linalg.norm(X, axis=1) - Y_norm = np.linalg.norm(Y, axis=1) + x_norm = np.linalg.norm(x, axis=1) + y_norm = np.linalg.norm(y, axis=1) # Ignore divide by zero errors run time warnings as those are handled below. 
with np.errstate(divide="ignore", invalid="ignore"): - similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm) similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index e8bb2df1e6e65..0748027a45fb5 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -41,7 +41,15 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "E", "F", "I", "T201",] +select = [ "E", "F", "I", "N", "T201",] + +[tool.ruff.lint.pep8-naming] +classmethod-decorators = [ + "classmethod", + "langchain_core.pydantic_v1.validator", + "langchain_core.pydantic_v1.root_validator", + "langchain_core.utils.pydantic.pre_init" +] [tool.coverage.run] omit = [ "tests/*",] diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index feff8a4a7bafa..56e6c19c4a5e4 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -10,9 +10,9 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from tests.unit_tests.stubs import ( AnyStr, - _AnyIdAIMessage, - _AnyIdAIMessageChunk, - _AnyIdHumanMessage, + _any_id_ai_message, + _any_id_ai_message_chunk, + _any_id_human_message, ) @@ -21,11 +21,11 @@ def test_generic_fake_chat_model_invoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = model.invoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") response = model.invoke("kitty") - assert response == _AnyIdAIMessage(content="goodbye") + assert response == _any_id_ai_message(content="goodbye") response = model.invoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") async def test_generic_fake_chat_model_ainvoke() -> None: @@ -33,11 +33,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = await model.ainvoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") response = await model.ainvoke("kitty") - assert response == _AnyIdAIMessage(content="goodbye") + assert response == _any_id_ai_message(content="goodbye") response = await model.ainvoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") async def test_generic_fake_chat_model_stream() -> None: @@ -50,17 +50,17 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=infinite_cycle) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 chunks = [chunk for chunk in model.stream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + 
_any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 @@ -145,9 +145,9 @@ async def test_generic_fake_chat_model_astream_log() -> None: ] final = log_patches[-1] assert final.state["streamed_output"] == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1 @@ -196,9 +196,9 @@ async def on_llm_new_token( # New model results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) assert results == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert tokens == ["hello", " ", "goodbye"] assert len({chunk.id for chunk in results}) == 1 @@ -207,6 +207,8 @@ async def on_llm_new_token( def test_chat_model_inputs() -> None: fake = ParrotFakeChatModel() - assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello") - assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah") - assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah") + assert fake.invoke("hello") == _any_id_human_message(content="hello") + assert fake.invoke([("ai", "blah")]) == _any_id_ai_message(content="blah") + assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message( + content="blah" + ) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index e137d96460e4c..fb3fbec6ba12b 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -25,7 +25,7 @@ FakeAsyncCallbackHandler, FakeCallbackHandler, ) -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk @pytest.fixture @@ -145,10 +145,10 @@ def _llm_type(self) -> str: model = ModelWithGenerate() chunks = [chunk for chunk in model.stream("anything")] - assert chunks == [_AnyIdAIMessage(content="hello")] + assert chunks == [_any_id_ai_message(content="hello")] chunks = [chunk async for chunk in model.astream("anything")] - assert chunks == [_AnyIdAIMessage(content="hello")] + assert chunks == [_any_id_ai_message(content="hello")] async def test_astream_implementation_fallback_to_stream() -> None: @@ -183,15 +183,15 @@ def _llm_type(self) -> str: model = ModelWithSyncStream() chunks = [chunk for chunk in model.stream("anything")] assert chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 assert type(model)._astream == BaseChatModel._astream astream_chunks = [chunk async for chunk in model.astream("anything")] assert astream_chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in astream_chunks}) == 1 @@ 
-228,8 +228,8 @@ def _llm_type(self) -> str: model = ModelWithAsyncStream() chunks = [chunk async for chunk in model.astream("anything")] assert chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index d7092aa94c630..76af9259e80b2 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -25,7 +25,7 @@ def change_directory(dir: Path) -> Iterator: os.chdir(origin) -def test_loading_from_YAML() -> None: +def test_loading_from_yaml() -> None: """Test loading from yaml file.""" prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml") expected_prompt = PromptTemplate( @@ -36,7 +36,7 @@ def test_loading_from_YAML() -> None: assert prompt == expected_prompt -def test_loading_from_JSON() -> None: +def test_loading_from_json() -> None: """Test loading from json file.""" prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json") expected_prompt = PromptTemplate( @@ -46,14 +46,14 @@ def test_loading_from_JSON() -> None: assert prompt == expected_prompt -def test_loading_jinja_from_JSON() -> None: +def test_loading_jinja_from_json() -> None: """Test that loading jinja2 format prompts from JSON raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json" with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): load_prompt(prompt_path) -def test_loading_jinja_from_YAML() -> None: +def test_loading_jinja_from_yaml() -> None: """Test that loading jinja2 format prompts from YAML raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml" with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 3d18a7f8f94e5..9545fb5e6317f 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,6 +1,7 @@ from typing import Optional from syrupy import SnapshotAssertion +from typing_extensions import override from langchain_core.language_models import FakeListLLM from langchain_core.output_parsers.list import CommaSeparatedListOutputParser @@ -221,9 +222,11 @@ def test_runnable_get_graph_with_invalid_input_type() -> None: class InvalidInputTypeRunnable(Runnable[int, int]): @property + @override def InputType(self) -> type: raise TypeError() + @override def invoke( self, input: int, @@ -243,9 +246,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None: class InvalidOutputTypeRunnable(Runnable[int, int]): @property + @override def OutputType(self) -> type: raise TypeError() + @override def invoke( self, input: int, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 25125be8c10ef..7f1de0f100f98 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -90,7 +90,7 @@ ) from langchain_core.tracers.context import collect_runs from tests.unit_tests.pydantic_utils import _schema -from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk class 
FakeTracer(BaseTracer): @@ -1831,7 +1831,7 @@ def test_prompt_with_chat_model( tracer = FakeTracer() assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == _AnyIdAIMessage(content="foo") + ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1856,8 +1856,8 @@ def test_prompt_with_chat_model( ], dict(callbacks=[tracer]), ) == [ - _AnyIdAIMessage(content="foo"), - _AnyIdAIMessage(content="foo"), + _any_id_ai_message(content="foo"), + _any_id_ai_message(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -1897,9 +1897,9 @@ def test_prompt_with_chat_model( assert [ *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) ] == [ - _AnyIdAIMessageChunk(content="f"), - _AnyIdAIMessageChunk(content="o"), - _AnyIdAIMessageChunk(content="o"), + _any_id_ai_message_chunk(content="f"), + _any_id_ai_message_chunk(content="o"), + _any_id_ai_message_chunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -1937,7 +1937,7 @@ async def test_prompt_with_chat_model_async( tracer = FakeTracer() assert await chain.ainvoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == _AnyIdAIMessage(content="foo") + ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1962,8 +1962,8 @@ async def test_prompt_with_chat_model_async( ], dict(callbacks=[tracer]), ) == [ - _AnyIdAIMessage(content="foo"), - _AnyIdAIMessage(content="foo"), + _any_id_ai_message(content="foo"), + _any_id_ai_message(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -2006,9 +2006,9 @@ async def test_prompt_with_chat_model_async( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) ] == [ - _AnyIdAIMessageChunk(content="f"), - _AnyIdAIMessageChunk(content="o"), - _AnyIdAIMessageChunk(content="o"), + _any_id_ai_message_chunk(content="f"), + _any_id_ai_message_chunk(content="o"), + _any_id_ai_message_chunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -2675,7 +2675,7 @@ def test_prompt_with_chat_model_and_parser( HumanMessage(content="What is your name?"), ] ) - assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") + assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar") assert tracer.runs == snapshot @@ -2810,7 +2810,7 @@ def test_seq_dict_prompt_llm( ), ] ) - assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") + assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar") assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 4 @@ -2856,7 +2856,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": _AnyIdAIMessage(content="i'm a chatbot"), + "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", } assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -3066,7 
+3066,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": _AnyIdAIMessage(content="i'm a chatbot"), + "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", "passthrough": ChatPromptValue( messages=[ @@ -3275,7 +3275,7 @@ async def test_map_astream() -> None: assert streamed_chunks[0] in [ {"passthrough": prompt.invoke({"question": "What is your name?"})}, {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert all(len(c.keys()) == 1 for c in streamed_chunks) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index c2230f560a5fa..73d4c0ede78e7 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -30,7 +30,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import tool -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: @@ -501,7 +501,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -510,7 +510,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -519,7 +519,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -528,7 +528,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"output": _AnyIdAIMessageChunk(content="hello world!")}, + "data": {"output": _any_id_ai_message_chunk(content="hello world!")}, "event": "on_chat_model_end", "metadata": {"a": "b"}, "name": "my_model", @@ -573,7 +573,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -586,7 +586,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -599,7 +599,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": 
"b", @@ -619,7 +619,9 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: [ { "generation_info": None, - "message": _AnyIdAIMessage(content="hello world!"), + "message": _any_id_ai_message( + content="hello world!" + ), "text": "hello world!", "type": "ChatGeneration", } @@ -641,7 +643,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -650,7 +652,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -695,7 +697,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -708,7 +710,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -721,7 +723,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -741,7 +743,9 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: [ { "generation_info": None, - "message": _AnyIdAIMessage(content="hello world!"), + "message": _any_id_ai_message( + content="hello world!" 
+ ), "text": "hello world!", "type": "ChatGeneration", } @@ -763,7 +767,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -772,7 +776,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 6132c0efb0d90..95eceb8b85270 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -53,7 +53,7 @@ from tests.unit_tests.runnables.test_runnable_events_v1 import ( _assert_events_equal_allow_superset_metadata, ) -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: @@ -539,7 +539,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -552,7 +552,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -565,7 +565,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -579,7 +579,7 @@ async def test_astream_events_from_model() -> None: }, { "data": { - "output": _AnyIdAIMessageChunk(content="hello world!"), + "output": _any_id_ai_message_chunk(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -646,7 +646,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -659,7 +659,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -672,7 +672,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -687,7 +687,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": _AnyIdAIMessage(content="hello world!"), + "output": 
_any_id_ai_message(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -701,7 +701,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -710,7 +710,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -755,7 +755,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -768,7 +768,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -781,7 +781,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -796,7 +796,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": _AnyIdAIMessage(content="hello world!"), + "output": _any_id_ai_message(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -810,7 +810,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -819,7 +819,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/stubs.py b/libs/core/tests/unit_tests/stubs.py index b752364e3af5d..95f36b72b0fbf 100644 --- a/libs/core/tests/unit_tests/stubs.py +++ b/libs/core/tests/unit_tests/stubs.py @@ -16,28 +16,28 @@ def __eq__(self, other: Any) -> bool: # subclassed strings. 
-def _AnyIdDocument(**kwargs: Any) -> Document: +def _any_id_document(**kwargs: Any) -> Document: """Create a document with an id field.""" message = Document(**kwargs) message.id = AnyStr() return message -def _AnyIdAIMessage(**kwargs: Any) -> AIMessage: +def _any_id_ai_message(**kwargs: Any) -> AIMessage: """Create ai message with an any id field.""" message = AIMessage(**kwargs) message.id = AnyStr() return message -def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk: +def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk: """Create ai message with an any id field.""" message = AIMessageChunk(**kwargs) message.id = AnyStr() return message -def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: +def _any_id_human_message(**kwargs: Any) -> HumanMessage: """Create a human with an any id field.""" message = HumanMessage(**kwargs) message.id = AnyStr() diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 7fdbfc5e66baa..5ca6a5d22adbd 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -768,7 +768,7 @@ def test_convert_to_messages() -> None: @pytest.mark.parametrize( - "MessageClass", + "message_class", [ AIMessage, AIMessageChunk, @@ -777,39 +777,39 @@ def test_convert_to_messages() -> None: SystemMessage, ], ) -def test_message_name(MessageClass: Type) -> None: - msg = MessageClass(content="foo", name="bar") +def test_message_name(message_class: Type) -> None: + msg = message_class(content="foo", name="bar") assert msg.name == "bar" - msg2 = MessageClass(content="foo", name=None) + msg2 = message_class(content="foo", name=None) assert msg2.name is None - msg3 = MessageClass(content="foo") + msg3 = message_class(content="foo") assert msg3.name is None @pytest.mark.parametrize( - "MessageClass", + "message_class", [FunctionMessage, FunctionMessageChunk], ) -def test_message_name_function(MessageClass: Type) -> None: +def test_message_name_function(message_class: Type) -> None: # functionmessage doesn't support name=None - msg = MessageClass(name="foo", content="bar") + msg = message_class(name="foo", content="bar") assert msg.name == "foo" @pytest.mark.parametrize( - "MessageClass", + "message_class", [ChatMessage, ChatMessageChunk], ) -def test_message_name_chat(MessageClass: Type) -> None: - msg = MessageClass(content="foo", role="user", name="bar") +def test_message_name_chat(message_class: Type) -> None: + msg = message_class(content="foo", role="user", name="bar") assert msg.name == "bar" - msg2 = MessageClass(content="foo", role="user", name=None) + msg2 = message_class(content="foo", role="user", name=None) assert msg2.name is None - msg3 = MessageClass(content="foo", role="user") + msg3 = message_class(content="foo", role="user") assert msg3.name is None diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 47d13863a10cb..ad3c835ab71f5 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1359,7 +1359,7 @@ def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any: return y -class fooSchema(BaseModel): +class fooSchema(BaseModel): # noqa: N801 """foo.""" x: int = Field(..., description="abc") @@ -1466,14 +1466,14 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None: def test_tool_inherited_injected_arg() -> None: - class barSchema(BaseModel): + class BarSchema(BaseModel): """bar.""" y: Annotated[str, "foobar comment", InjectedToolArg()] = 
Field( ..., description="123" ) - class fooSchema(barSchema): + class FooSchema(BarSchema): """foo.""" x: int = Field(..., description="abc") @@ -1481,14 +1481,14 @@ class fooSchema(barSchema): class InheritedInjectedArgTool(BaseTool): name: str = "foo" description: str = "foo." - args_schema: Type[BaseModel] = fooSchema + args_schema: Type[BaseModel] = FooSchema def _run(self, x: int, y: str) -> Any: return y tool_ = InheritedInjectedArgTool() assert tool_.get_input_schema().schema() == { - "title": "fooSchema", + "title": "FooSchema", "description": "foo.", "type": "object", "properties": { diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index daa981d31434b..9bc01b39de6b7 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -46,7 +46,7 @@ @pytest.fixture() def pydantic() -> Type[BaseModel]: - class dummy_function(BaseModel): + class dummy_function(BaseModel): # noqa: N801 """dummy function""" arg1: int = Field(..., description="foo") @@ -56,7 +56,7 @@ class dummy_function(BaseModel): @pytest.fixture() -def Annotated_function() -> Callable: +def annotated_function() -> Callable: def dummy_function( arg1: ExtensionsAnnotated[int, "foo"], arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"], @@ -112,7 +112,7 @@ def _run(self, *args: Any, **kwargs: Any) -> Any: @pytest.fixture() def dummy_pydantic() -> Type[BaseModel]: - class dummy_function(BaseModel): + class dummy_function(BaseModel): # noqa: N801 """dummy function""" arg1: int = Field(..., description="foo") @@ -123,7 +123,7 @@ class dummy_function(BaseModel): @pytest.fixture() def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]: - class dummy_function(BaseModelV2Maybe): + class dummy_function(BaseModelV2Maybe): # noqa: N801 """dummy function""" arg1: int = FieldV2Maybe(..., description="foo") @@ -136,7 +136,7 @@ class dummy_function(BaseModelV2Maybe): @pytest.fixture() def dummy_typing_typed_dict() -> Type: - class dummy_function(TypingTypedDict): + class dummy_function(TypingTypedDict): # noqa: N801 """dummy function""" arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821 @@ -147,7 +147,7 @@ class dummy_function(TypingTypedDict): @pytest.fixture() def dummy_typing_typed_dict_docstring() -> Type: - class dummy_function(TypingTypedDict): + class dummy_function(TypingTypedDict): # noqa: N801 """dummy function Args: @@ -163,7 +163,7 @@ class dummy_function(TypingTypedDict): @pytest.fixture() def dummy_extensions_typed_dict() -> Type: - class dummy_function(ExtensionsTypedDict): + class dummy_function(ExtensionsTypedDict): # noqa: N801 """dummy function""" arg1: ExtensionsAnnotated[int, ..., "foo"] @@ -174,7 +174,7 @@ class dummy_function(ExtensionsTypedDict): @pytest.fixture() def dummy_extensions_typed_dict_docstring() -> Type: - class dummy_function(ExtensionsTypedDict): + class dummy_function(ExtensionsTypedDict): # noqa: N801 """dummy function Args: @@ -234,7 +234,7 @@ def test_convert_to_openai_function( function: Callable, dummy_tool: BaseTool, json_schema: Dict, - Annotated_function: Callable, + annotated_function: Callable, dummy_pydantic: Type[BaseModel], runnable: Runnable, dummy_typing_typed_dict: Type, @@ -267,7 +267,7 @@ def test_convert_to_openai_function( expected, Dummy.dummy_function, DummyWithClassMethod.dummy_function, - Annotated_function, + annotated_function, dummy_pydantic, dummy_typing_typed_dict, dummy_typing_typed_dict_docstring, @@ 
-454,20 +454,20 @@ def test__convert_typed_dict_to_openai_function( use_extension_typed_dict: bool, use_extension_annotated: bool ) -> None: if use_extension_typed_dict: - TypedDict = ExtensionsTypedDict + typed_dict = ExtensionsTypedDict else: - TypedDict = TypingTypedDict + typed_dict = TypingTypedDict if use_extension_annotated: - Annotated = TypingAnnotated + annotated = ExtensionsAnnotated else: - Annotated = TypingAnnotated + annotated = TypingAnnotated - class SubTool(TypedDict): + class SubTool(typed_dict): """Subtool docstring""" - args: Annotated[Dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore + args: annotated[Dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore - class Tool(TypedDict): + class Tool(typed_dict): """Docstring Args: @@ -477,20 +477,20 @@ class Tool(TypedDict): arg1: str arg2: Union[int, str, bool] arg3: Optional[List[SubTool]] - arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722 - arg5: Annotated[Optional[float], None] - arg6: Annotated[ + arg4: annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722 + arg5: annotated[Optional[float], None] + arg6: annotated[ Optional[Sequence[Mapping[str, Tuple[Iterable[Any], SubTool]]]], [] ] - arg7: Annotated[List[SubTool], ...] - arg8: Annotated[Tuple[SubTool], ...] - arg9: Annotated[Sequence[SubTool], ...] - arg10: Annotated[Iterable[SubTool], ...] - arg11: Annotated[Set[SubTool], ...] - arg12: Annotated[Dict[str, SubTool], ...] - arg13: Annotated[Mapping[str, SubTool], ...] - arg14: Annotated[MutableMapping[str, SubTool], ...] - arg15: Annotated[bool, False, "flag"] # noqa: F821 # type: ignore + arg7: annotated[List[SubTool], ...] + arg8: annotated[Tuple[SubTool], ...] + arg9: annotated[Sequence[SubTool], ...] + arg10: annotated[Iterable[SubTool], ...] + arg11: annotated[Set[SubTool], ...] + arg12: annotated[Dict[str, SubTool], ...] + arg13: annotated[Mapping[str, SubTool], ...] + arg14: annotated[MutableMapping[str, SubTool], ...] 
+ arg15: annotated[bool, False, "flag"] # noqa: F821 # type: ignore expected = { "name": "Tool", diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py index 8d30afac3caf4..1b1e898462ffe 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py +++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py @@ -10,7 +10,7 @@ from langchain_core.documents import Document from langchain_core.embeddings.fake import DeterministicFakeEmbedding from langchain_core.vectorstores import InMemoryVectorStore -from tests.unit_tests.stubs import AnyStr, _AnyIdDocument +from tests.unit_tests.stubs import AnyStr, _any_id_document class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite): @@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None: # Check sync version output = store.similarity_search("foo", k=1) - assert output == [_AnyIdDocument(page_content="foo")] + assert output == [_any_id_document(page_content="foo")] # Check async version output = await store.asimilarity_search("bar", k=2) assert output == [ - _AnyIdDocument(page_content="bar"), - _AnyIdDocument(page_content="baz"), + _any_id_document(page_content="bar"), + _any_id_document(page_content="baz"), ] @@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None: # make sure we can k > docstore size output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1) assert len(output) == len(texts) - assert output[0] == _AnyIdDocument(page_content="foo") - assert output[1] == _AnyIdDocument(page_content="foy") + assert output[0] == _any_id_document(page_content="foo") + assert output[1] == _any_id_document(page_content="foy") # Check async version output = await docsearch.amax_marginal_relevance_search( "foo", k=10, lambda_mult=0.1 ) assert len(output) == len(texts) - assert output[0] == _AnyIdDocument(page_content="foo") - assert output[1] == _AnyIdDocument(page_content="foy") + assert output[0] == _any_id_document(page_content="foo") + assert output[1] == _any_id_document(page_content="foy") async def test_inmemory_dump_load(tmp_path: Path) -> None: diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py index 971315752b8c2..52ff3685a97d7 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py +++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py @@ -49,6 +49,7 @@ def add_texts( def get_by_ids(self, ids: Sequence[str], /) -> List[Document]: return [self.store[id] for id in ids if id in self.store] + @classmethod def from_texts( # type: ignore cls, texts: List[str],
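Editor's note (appended after the patch; not part of it): for reviewers unfamiliar with ruff's pep8-naming ("N") rules enabled in pyproject.toml above, here is a minimal sketch of the violations these rules report and the two remedies the patch uses: renaming private symbols, and a targeted `# noqa` where a rename would break public API. The rule codes (N801, N802, N817, N818) are ruff's real pep8-naming codes; the file and class names below are hypothetical.

# naming_demo.py -- run `ruff check --select N naming_demo.py` to see the diagnostics.
import xml.etree.ElementTree as ET  # noqa: N817 -- CamelCase module imported as acronym, suppressed as in xml.py above


class _beta_property(property):  # flagged N801: class name should use CapWords
    """Private symbol, so the patch simply renames it (here: to _BetaProperty)."""


class ToolException(Exception):  # noqa: N818 -- exception name should end in "Error"
    """Part of the public API, so the patch suppresses the rule instead of renaming."""


class MyRunnable:
    @property
    def InputType(self) -> type:  # noqa: N802 -- function name should be lowercase
        """Kept CamelCase for backward compatibility, as on Runnable.InputType."""
        return str

The new `[tool.ruff.lint.pep8-naming] classmethod-decorators` table serves a related purpose: it tells the first-argument checks (N804/N805) that pydantic's `validator`/`root_validator` and `langchain_core.utils.pydantic.pre_init` produce classmethods, so their `cls` first argument is not reported as a naming violation.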