From 798a763ca8acf429e472287680ed288b3302d4ea Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 18 Sep 2024 14:01:59 +0200 Subject: [PATCH] core: Put Python version as a project requirement so it is considered by ruff --- .../langchain_core/_api/beta_decorator.py | 5 +- libs/core/langchain_core/_api/deprecation.py | 9 +- libs/core/langchain_core/agents.py | 3 +- .../langchain_core/beta/runnables/context.py | 36 +++--- libs/core/langchain_core/caches.py | 3 +- libs/core/langchain_core/callbacks/base.py | 5 +- libs/core/langchain_core/callbacks/manager.py | 8 +- libs/core/langchain_core/chat_history.py | 3 +- libs/core/langchain_core/chat_loaders.py | 4 +- libs/core/langchain_core/chat_sessions.py | 3 +- .../langchain_core/document_loaders/base.py | 3 +- .../document_loaders/blob_loaders.py | 2 +- .../document_loaders/langsmith.py | 3 +- libs/core/langchain_core/documents/base.py | 3 +- .../langchain_core/documents/compressor.py | 3 +- .../langchain_core/documents/transformers.py | 3 +- .../langchain_core/embeddings/embeddings.py | 9 +- libs/core/langchain_core/embeddings/fake.py | 13 +-- .../langchain_core/example_selectors/base.py | 10 +- .../example_selectors/length_based.py | 14 +-- .../langchain_core/graph_vectorstores/base.py | 5 +- .../graph_vectorstores/links.py | 5 +- libs/core/langchain_core/indexing/api.py | 6 +- libs/core/langchain_core/indexing/base.py | 3 +- .../core/langchain_core/indexing/in_memory.py | 11 +- .../langchain_core/language_models/base.py | 10 +- .../language_models/chat_models.py | 12 +- .../langchain_core/language_models/fake.py | 13 ++- .../language_models/fake_chat_models.py | 51 ++++----- .../langchain_core/language_models/llms.py | 34 +++--- libs/core/langchain_core/load/load.py | 22 ++-- libs/core/langchain_core/load/mapping.py | 10 +- libs/core/langchain_core/load/serializable.py | 22 ++-- libs/core/langchain_core/messages/ai.py | 22 ++-- libs/core/langchain_core/messages/base.py | 5 +- libs/core/langchain_core/messages/chat.py | 6 +- libs/core/langchain_core/messages/function.py | 6 +- libs/core/langchain_core/messages/human.py | 8 +- libs/core/langchain_core/messages/modifier.py | 4 +- libs/core/langchain_core/messages/system.py | 8 +- libs/core/langchain_core/messages/tool.py | 18 +-- libs/core/langchain_core/messages/utils.py | 10 +- .../langchain_core/output_parsers/json.py | 3 +- .../langchain_core/output_parsers/list.py | 5 +- .../output_parsers/openai_functions.py | 25 +++-- .../output_parsers/openai_tools.py | 23 ++-- .../langchain_core/output_parsers/pydantic.py | 9 +- .../langchain_core/output_parsers/string.py | 3 +- .../output_parsers/transform.py | 3 +- .../core/langchain_core/output_parsers/xml.py | 15 +-- .../langchain_core/outputs/chat_result.py | 4 +- libs/core/langchain_core/prompt_values.py | 3 +- libs/core/langchain_core/prompts/base.py | 10 +- libs/core/langchain_core/prompts/chat.py | 17 ++- .../prompts/few_shot_with_templates.py | 12 +- libs/core/langchain_core/prompts/image.py | 4 +- libs/core/langchain_core/prompts/loading.py | 6 +- libs/core/langchain_core/prompts/pipeline.py | 10 +- libs/core/langchain_core/prompts/string.py | 4 +- .../core/langchain_core/prompts/structured.py | 23 ++-- libs/core/langchain_core/retrievers.py | 4 +- libs/core/langchain_core/runnables/base.py | 99 ++++++++++------- libs/core/langchain_core/runnables/branch.py | 19 +--- libs/core/langchain_core/runnables/config.py | 9 +- .../langchain_core/runnables/configurable.py | 16 +-- .../langchain_core/runnables/fallbacks.py | 42 +++---- libs/core/langchain_core/runnables/graph.py | 2 +- .../langchain_core/runnables/graph_ascii.py | 3 +- .../langchain_core/runnables/graph_mermaid.py | 10 +- libs/core/langchain_core/runnables/history.py | 12 +- .../langchain_core/runnables/passthrough.py | 12 +- libs/core/langchain_core/runnables/retry.py | 66 ++++++----- libs/core/langchain_core/runnables/router.py | 7 +- libs/core/langchain_core/runnables/schema.py | 3 +- libs/core/langchain_core/runnables/utils.py | 19 ++-- libs/core/langchain_core/stores.py | 25 ++--- libs/core/langchain_core/structured_query.py | 3 +- libs/core/langchain_core/sys_info.py | 4 +- libs/core/langchain_core/tools/base.py | 7 +- libs/core/langchain_core/tools/convert.py | 14 +-- libs/core/langchain_core/tools/render.py | 4 +- libs/core/langchain_core/tools/simple.py | 2 +- libs/core/langchain_core/tools/structured.py | 4 +- .../core/langchain_core/tracers/_streaming.py | 3 +- libs/core/langchain_core/tracers/base.py | 2 +- libs/core/langchain_core/tracers/context.py | 2 +- libs/core/langchain_core/tracers/core.py | 3 +- .../core/langchain_core/tracers/evaluation.py | 5 +- .../langchain_core/tracers/event_stream.py | 9 +- .../core/langchain_core/tracers/log_stream.py | 4 +- .../langchain_core/tracers/memory_stream.py | 3 +- .../langchain_core/tracers/root_listeners.py | 3 +- .../langchain_core/tracers/run_collector.py | 4 +- libs/core/langchain_core/tracers/stdout.py | 4 +- libs/core/langchain_core/utils/aiter.py | 37 +++---- libs/core/langchain_core/utils/formatting.py | 5 +- .../langchain_core/utils/function_calling.py | 28 ++--- libs/core/langchain_core/utils/html.py | 7 +- libs/core/langchain_core/utils/input.py | 6 +- libs/core/langchain_core/utils/iter.py | 25 ++--- libs/core/langchain_core/utils/json_schema.py | 3 +- libs/core/langchain_core/utils/mustache.py | 10 +- libs/core/langchain_core/utils/pydantic.py | 28 +++-- libs/core/langchain_core/utils/strings.py | 4 +- libs/core/langchain_core/utils/utils.py | 15 +-- libs/core/langchain_core/vectorstores/base.py | 4 +- .../langchain_core/vectorstores/in_memory.py | 3 +- .../core/langchain_core/vectorstores/utils.py | 4 +- libs/core/pyproject.toml | 3 + .../unit_tests/_api/test_beta_decorator.py | 4 +- .../tests/unit_tests/_api/test_deprecation.py | 4 +- .../unit_tests/caches/test_in_memory_cache.py | 4 +- .../callbacks/test_dispatch_custom_event.py | 14 +-- .../chat_history/test_chat_history.py | 10 +- libs/core/tests/unit_tests/conftest.py | 4 +- .../unit_tests/document_loaders/test_base.py | 4 +- .../unit_tests/example_selectors/test_base.py | 8 +- .../example_selectors/test_similarity.py | 19 ++-- libs/core/tests/unit_tests/fake/callbacks.py | 8 +- .../unit_tests/fake/test_fake_chat_model.py | 16 +-- .../indexing/test_in_memory_indexer.py | 2 +- .../unit_tests/indexing/test_indexing.py | 5 +- .../language_models/chat_models/test_base.py | 31 +++--- .../language_models/chat_models/test_cache.py | 4 +- .../language_models/llms/test_base.py | 19 ++-- .../language_models/llms/test_cache.py | 6 +- .../unit_tests/load/test_serializable.py | 4 +- .../tests/unit_tests/messages/test_utils.py | 19 ++-- .../output_parsers/test_base_parsers.py | 5 +- .../unit_tests/output_parsers/test_json.py | 5 +- .../output_parsers/test_list_parser.py | 11 +- .../output_parsers/test_openai_functions.py | 4 +- .../output_parsers/test_openai_tools.py | 5 +- .../output_parsers/test_xml_parser.py | 2 +- .../tests/unit_tests/prompts/test_chat.py | 12 +- .../tests/unit_tests/prompts/test_few_shot.py | 25 +++-- .../tests/unit_tests/prompts/test_loading.py | 2 +- .../tests/unit_tests/prompts/test_prompt.py | 4 +- .../unit_tests/prompts/test_structured.py | 10 +- libs/core/tests/unit_tests/pydantic_utils.py | 4 +- .../tests/unit_tests/runnables/test_config.py | 4 +- .../unit_tests/runnables/test_configurable.py | 4 +- .../unit_tests/runnables/test_context.py | 4 +- .../unit_tests/runnables/test_fallbacks.py | 27 ++--- .../unit_tests/runnables/test_history.py | 23 ++-- .../unit_tests/runnables/test_runnable.py | 104 ++++++++---------- .../runnables/test_runnable_events_v1.py | 23 ++-- .../runnables/test_runnable_events_v2.py | 25 ++--- .../runnables/test_tracing_interops.py | 3 +- .../tests/unit_tests/runnables/test_utils.py | 6 +- .../tests/unit_tests/stores/test_in_memory.py | 6 +- libs/core/tests/unit_tests/test_imports.py | 3 +- libs/core/tests/unit_tests/test_messages.py | 10 +- libs/core/tests/unit_tests/test_tools.py | 79 +++++++------ .../unit_tests/tracers/test_langchain.py | 4 +- .../unit_tests/tracers/test_memory_stream.py | 2 +- .../core/tests/unit_tests/utils/test_aiter.py | 6 +- .../unit_tests/utils/test_function_calling.py | 68 +++++------- libs/core/tests/unit_tests/utils/test_iter.py | 4 +- .../tests/unit_tests/utils/test_pydantic.py | 10 +- .../core/tests/unit_tests/utils/test_utils.py | 6 +- .../vectorstores/test_vectorstore.py | 3 +- 162 files changed, 919 insertions(+), 1001 deletions(-) diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index d27a27a1c834c..1c421b5b01b46 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -14,7 +14,8 @@ import functools import inspect import warnings -from typing import Any, Callable, Generator, Type, TypeVar, Union, cast +from collections.abc import Generator +from typing import Any, Callable, TypeVar, Union, cast from langchain_core._api.internal import is_caller_internal @@ -26,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning): # PUBLIC API -T = TypeVar("T", bound=Union[Callable[..., Any], Type]) +T = TypeVar("T", bound=Union[Callable[..., Any], type]) def beta( diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 8df4c1893dbe5..33d59e819e596 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -14,11 +14,10 @@ import functools import inspect import warnings +from collections.abc import Generator from typing import ( Any, Callable, - Generator, - Type, TypeVar, Union, cast, @@ -41,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning): # Last Any should be FieldInfoV1 but this leads to circular imports -T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any]) +T = TypeVar("T", bound=Union[type, Callable[..., Any], Any]) def _validate_deprecation_params( @@ -262,7 +261,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: if not _obj_type: _obj_type = "attribute" wrapped = None - _name = _name or cast(Union[Type, Callable], obj.fget).__qualname__ + _name = _name or cast(Union[type, Callable], obj.fget).__qualname__ old_doc = obj.__doc__ class _deprecated_property(property): @@ -304,7 +303,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: ) else: - _name = _name or cast(Union[Type, Callable], obj).__qualname__ + _name = _name or cast(Union[type, Callable], obj).__qualname__ if not _obj_type: # edge case: when a function is within another function # within a test, this will call it a "method" not a "function" diff --git a/libs/core/langchain_core/agents.py b/libs/core/langchain_core/agents.py index 3e76591f22920..393933174739e 100644 --- a/libs/core/langchain_core/agents.py +++ b/libs/core/langchain_core/agents.py @@ -25,7 +25,8 @@ from __future__ import annotations import json -from typing import Any, Literal, Sequence, Union +from collections.abc import Sequence +from typing import Any, Literal, Union from langchain_core.load.serializable import Serializable from langchain_core.messages import ( diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 3b95bf2859b4d..70cc1fae324d4 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -1,19 +1,13 @@ import asyncio import threading from collections import defaultdict +from collections.abc import Awaitable, Mapping, Sequence from functools import partial from itertools import groupby from typing import ( Any, - Awaitable, Callable, - DefaultDict, - Dict, - List, - Mapping, Optional, - Sequence, - Type, TypeVar, Union, ) @@ -30,7 +24,7 @@ from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output T = TypeVar("T") -Values = Dict[Union[asyncio.Event, threading.Event], Any] +Values = dict[Union[asyncio.Event, threading.Event], Any] CONTEXT_CONFIG_PREFIX = "__context__/" CONTEXT_CONFIG_SUFFIX_GET = "/get" CONTEXT_CONFIG_SUFFIX_SET = "/set" @@ -70,10 +64,10 @@ def _key_from_id(id_: str) -> str: def _config_with_context( config: RunnableConfig, - steps: List[Runnable], + steps: list[Runnable], setter: Callable, getter: Callable, - event_cls: Union[Type[threading.Event], Type[asyncio.Event]], + event_cls: Union[type[threading.Event], type[asyncio.Event]], ) -> RunnableConfig: if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): return config @@ -99,10 +93,10 @@ def _config_with_context( } values: Values = {} - events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict( + events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict( event_cls ) - context_funcs: Dict[str, Callable[[], Any]] = {} + context_funcs: dict[str, Callable[[], Any]] = {} for key, group in grouped_by_key.items(): getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)] setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)] @@ -129,7 +123,7 @@ def _config_with_context( def aconfig_with_context( config: RunnableConfig, - steps: List[Runnable], + steps: list[Runnable], ) -> RunnableConfig: """Asynchronously patch a runnable config with context getters and setters. @@ -145,7 +139,7 @@ def aconfig_with_context( def config_with_context( config: RunnableConfig, - steps: List[Runnable], + steps: list[Runnable], ) -> RunnableConfig: """Patch a runnable config with context getters and setters. @@ -165,13 +159,13 @@ class ContextGet(RunnableSerializable): prefix: str = "" - key: Union[str, List[str]] + key: Union[str, list[str]] def __str__(self) -> str: return f"ContextGet({_print_keys(self.key)})" @property - def ids(self) -> List[str]: + def ids(self) -> list[str]: prefix = self.prefix + "/" if self.prefix else "" keys = self.key if isinstance(self.key, list) else [self.key] return [ @@ -180,7 +174,7 @@ def ids(self) -> List[str]: ] @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: return super().config_specs + [ ConfigurableFieldSpec( id=id_, @@ -256,7 +250,7 @@ def __str__(self) -> str: return f"ContextSet({_print_keys(list(self.keys.keys()))})" @property - def ids(self) -> List[str]: + def ids(self) -> list[str]: prefix = self.prefix + "/" if self.prefix else "" return [ f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}" @@ -264,7 +258,7 @@ def ids(self) -> List[str]: ] @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: mapper_config_specs = [ s for mapper in self.keys.values() @@ -364,7 +358,7 @@ def create_scope(scope: str, /) -> "PrefixContext": return PrefixContext(prefix=scope) @staticmethod - def getter(key: Union[str, List[str]], /) -> ContextGet: + def getter(key: Union[str, list[str]], /) -> ContextGet: return ContextGet(key=key) @staticmethod @@ -385,7 +379,7 @@ class PrefixContext: def __init__(self, prefix: str = ""): self.prefix = prefix - def getter(self, key: Union[str, List[str]], /) -> ContextGet: + def getter(self, key: Union[str, list[str]], /) -> ContextGet: return ContextGet(key=key, prefix=self.prefix) def setter( diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py index 236d2e4875ec3..e93ce79694003 100644 --- a/libs/core/langchain_core/caches.py +++ b/libs/core/langchain_core/caches.py @@ -23,7 +23,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Optional from langchain_core.outputs import Generation from langchain_core.runnables import run_in_executor diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 7137ca287dfda..c6e9090f78963 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -3,7 +3,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from uuid import UUID from tenacity import RetryCallState @@ -1070,4 +1071,4 @@ def remove_metadata(self, keys: list[str]) -> None: self.inheritable_metadata.pop(key) -Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] +Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]] diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 6155987dcd344..d8685e51d0219 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -5,19 +5,15 @@ import logging import uuid from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager from contextvars import copy_context from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, Callable, - Coroutine, - Generator, Optional, - Sequence, - Type, TypeVar, Union, cast, @@ -2352,7 +2348,7 @@ def _configure( and handler_class is not None ) if var.get() is not None or create_one: - var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)() + var_handler = var.get() or cast(type[BaseCallbackHandler], handler_class)() if handler_class is None: if not any( handler is var_handler # direct pointer comparison diff --git a/libs/core/langchain_core/chat_history.py b/libs/core/langchain_core/chat_history.py index 77138f3d8fbee..9b3eb09b03a05 100644 --- a/libs/core/langchain_core/chat_history.py +++ b/libs/core/langchain_core/chat_history.py @@ -18,7 +18,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Sequence, Union +from collections.abc import Sequence +from typing import Union from pydantic import BaseModel, Field diff --git a/libs/core/langchain_core/chat_loaders.py b/libs/core/langchain_core/chat_loaders.py index 93c14a4f8ffdb..51e65133160e2 100644 --- a/libs/core/langchain_core/chat_loaders.py +++ b/libs/core/langchain_core/chat_loaders.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Iterator, List +from collections.abc import Iterator from langchain_core.chat_sessions import ChatSession @@ -15,7 +15,7 @@ def lazy_load(self) -> Iterator[ChatSession]: An iterator of chat sessions. """ - def load(self) -> List[ChatSession]: + def load(self) -> list[ChatSession]: """Eagerly load the chat sessions into memory. Returns: diff --git a/libs/core/langchain_core/chat_sessions.py b/libs/core/langchain_core/chat_sessions.py index d43754729ae4a..ededbc3155e80 100644 --- a/libs/core/langchain_core/chat_sessions.py +++ b/libs/core/langchain_core/chat_sessions.py @@ -1,6 +1,7 @@ """**Chat Sessions** are a collection of messages and function calls.""" -from typing import Sequence, TypedDict +from collections.abc import Sequence +from typing import TypedDict from langchain_core.messages import BaseMessage diff --git a/libs/core/langchain_core/document_loaders/base.py b/libs/core/langchain_core/document_loaders/base.py index 87955ee584095..540ef888c3afe 100644 --- a/libs/core/langchain_core/document_loaders/base.py +++ b/libs/core/langchain_core/document_loaders/base.py @@ -3,7 +3,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional +from collections.abc import AsyncIterator, Iterator +from typing import TYPE_CHECKING, Optional from langchain_core.documents import Document from langchain_core.runnables import run_in_executor diff --git a/libs/core/langchain_core/document_loaders/blob_loaders.py b/libs/core/langchain_core/document_loaders/blob_loaders.py index bb7ed0128774c..3c0d1986f7336 100644 --- a/libs/core/langchain_core/document_loaders/blob_loaders.py +++ b/libs/core/langchain_core/document_loaders/blob_loaders.py @@ -8,7 +8,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Iterable +from collections.abc import Iterable # Re-export Blob and PathLike for backwards compatibility from langchain_core.documents.base import Blob as Blob diff --git a/libs/core/langchain_core/document_loaders/langsmith.py b/libs/core/langchain_core/document_loaders/langsmith.py index 9da48851d0672..39fda02af5792 100644 --- a/libs/core/langchain_core/document_loaders/langsmith.py +++ b/libs/core/langchain_core/document_loaders/langsmith.py @@ -1,7 +1,8 @@ import datetime import json import uuid -from typing import Any, Callable, Iterator, Optional, Sequence, Union +from collections.abc import Iterator, Sequence +from typing import Any, Callable, Optional, Union from langsmith import Client as LangSmithClient diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index 0cb357d8bb444..27029635c9094 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -2,9 +2,10 @@ import contextlib import mimetypes +from collections.abc import Generator from io import BufferedReader, BytesIO from pathlib import PurePath -from typing import Any, Generator, Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union, cast from pydantic import ConfigDict, Field, field_validator, model_validator diff --git a/libs/core/langchain_core/documents/compressor.py b/libs/core/langchain_core/documents/compressor.py index 95c8cd9a96394..31ae2901a7bbd 100644 --- a/libs/core/langchain_core/documents/compressor.py +++ b/libs/core/langchain_core/documents/compressor.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from pydantic import BaseModel diff --git a/libs/core/langchain_core/documents/transformers.py b/libs/core/langchain_core/documents/transformers.py index a11e20f4e5078..12167f820f92c 100644 --- a/libs/core/langchain_core/documents/transformers.py +++ b/libs/core/langchain_core/documents/transformers.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any from langchain_core.runnables.config import run_in_executor diff --git a/libs/core/langchain_core/embeddings/embeddings.py b/libs/core/langchain_core/embeddings/embeddings.py index f2b0e83f80af4..39c0eb42a8953 100644 --- a/libs/core/langchain_core/embeddings/embeddings.py +++ b/libs/core/langchain_core/embeddings/embeddings.py @@ -1,7 +1,6 @@ """**Embeddings** interface.""" from abc import ABC, abstractmethod -from typing import List from langchain_core.runnables.config import run_in_executor @@ -35,7 +34,7 @@ class Embeddings(ABC): """ @abstractmethod - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs. Args: @@ -46,7 +45,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: """ @abstractmethod - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Embed query text. Args: @@ -56,7 +55,7 @@ def embed_query(self, text: str) -> List[float]: Embedding. """ - async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs. Args: @@ -67,7 +66,7 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """ return await run_in_executor(None, self.embed_documents, texts) - async def aembed_query(self, text: str) -> List[float]: + async def aembed_query(self, text: str) -> list[float]: """Asynchronous Embed query text. Args: diff --git a/libs/core/langchain_core/embeddings/fake.py b/libs/core/langchain_core/embeddings/fake.py index 2c8221425ccd5..b5286a9a09ce8 100644 --- a/libs/core/langchain_core/embeddings/fake.py +++ b/libs/core/langchain_core/embeddings/fake.py @@ -2,7 +2,6 @@ # Please do not add additional fake embedding model implementations here. import hashlib -from typing import List from pydantic import BaseModel @@ -51,15 +50,15 @@ class FakeEmbeddings(Embeddings, BaseModel): size: int """The size of the embedding vector.""" - def _get_embedding(self) -> List[float]: + def _get_embedding(self) -> list[float]: import numpy as np # type: ignore[import-not-found, import-untyped] return list(np.random.normal(size=self.size)) - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: return [self._get_embedding() for _ in texts] - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: return self._get_embedding() @@ -106,7 +105,7 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel): size: int """The size of the embedding vector.""" - def _get_embedding(self, seed: int) -> List[float]: + def _get_embedding(self, seed: int) -> list[float]: import numpy as np # type: ignore[import-not-found, import-untyped] # set the seed for the random generator @@ -117,8 +116,8 @@ def _get_seed(self, text: str) -> int: """Get a seed for the random generator, using the hash of the text.""" return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8 - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: return [self._get_embedding(seed=self._get_seed(_)) for _ in texts] - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: return self._get_embedding(seed=self._get_seed(text)) diff --git a/libs/core/langchain_core/example_selectors/base.py b/libs/core/langchain_core/example_selectors/base.py index 80fde2ae37f00..a70c680e83630 100644 --- a/libs/core/langchain_core/example_selectors/base.py +++ b/libs/core/langchain_core/example_selectors/base.py @@ -1,7 +1,7 @@ """Interface for selecting examples to include in prompts.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any from langchain_core.runnables import run_in_executor @@ -10,14 +10,14 @@ class BaseExampleSelector(ABC): """Interface for selecting examples to include in prompts.""" @abstractmethod - def add_example(self, example: Dict[str, str]) -> Any: + def add_example(self, example: dict[str, str]) -> Any: """Add new example to store. Args: example: A dictionary with keys as input variables and values as their values.""" - async def aadd_example(self, example: Dict[str, str]) -> Any: + async def aadd_example(self, example: dict[str, str]) -> Any: """Async add new example to store. Args: @@ -27,14 +27,14 @@ async def aadd_example(self, example: Dict[str, str]) -> Any: return await run_in_executor(None, self.add_example, example) @abstractmethod - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict]: """Select which examples to use based on the inputs. Args: input_variables: A dictionary with keys as input variables and values as their values.""" - async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: + async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: """Async select which examples to use based on the inputs. Args: diff --git a/libs/core/langchain_core/example_selectors/length_based.py b/libs/core/langchain_core/example_selectors/length_based.py index d9e6b6ea7623a..e393e2a1dee62 100644 --- a/libs/core/langchain_core/example_selectors/length_based.py +++ b/libs/core/langchain_core/example_selectors/length_based.py @@ -1,7 +1,7 @@ """Select examples based on length.""" import re -from typing import Callable, Dict, List +from typing import Callable from pydantic import BaseModel, Field, model_validator from typing_extensions import Self @@ -17,7 +17,7 @@ def _get_length_based(text: str) -> int: class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): """Select examples based on length.""" - examples: List[dict] + examples: list[dict] """A list of the examples that the prompt template expects.""" example_prompt: PromptTemplate @@ -29,10 +29,10 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): max_length: int = 2048 """Max length for the prompt, beyond which examples are cut.""" - example_text_lengths: List[int] = Field(default_factory=list) # :meta private: + example_text_lengths: list[int] = Field(default_factory=list) # :meta private: """Length of each example.""" - def add_example(self, example: Dict[str, str]) -> None: + def add_example(self, example: dict[str, str]) -> None: """Add new example to list. Args: @@ -43,7 +43,7 @@ def add_example(self, example: Dict[str, str]) -> None: string_example = self.example_prompt.format(**example) self.example_text_lengths.append(self.get_text_length(string_example)) - async def aadd_example(self, example: Dict[str, str]) -> None: + async def aadd_example(self, example: dict[str, str]) -> None: """Async add new example to list. Args: @@ -62,7 +62,7 @@ def post_init(self) -> Self: self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples] return self - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict]: """Select which examples to use based on the input lengths. Args: @@ -86,7 +86,7 @@ def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: i += 1 return examples - async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: + async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: """Async select which examples to use based on the input lengths. Args: diff --git a/libs/core/langchain_core/graph_vectorstores/base.py b/libs/core/langchain_core/graph_vectorstores/base.py index d3e7da93ffb4b..e087bf1dc5f56 100644 --- a/libs/core/langchain_core/graph_vectorstores/base.py +++ b/libs/core/langchain_core/graph_vectorstores/base.py @@ -1,13 +1,10 @@ from __future__ import annotations from abc import abstractmethod +from collections.abc import AsyncIterable, Collection, Iterable, Iterator from typing import ( Any, - AsyncIterable, ClassVar, - Collection, - Iterable, - Iterator, Optional, ) diff --git a/libs/core/langchain_core/graph_vectorstores/links.py b/libs/core/langchain_core/graph_vectorstores/links.py index 7464eb4aa956b..7b18638e43fdc 100644 --- a/libs/core/langchain_core/graph_vectorstores/links.py +++ b/libs/core/langchain_core/graph_vectorstores/links.py @@ -1,5 +1,6 @@ +from collections.abc import Iterable from dataclasses import dataclass -from typing import Iterable, List, Literal, Union +from typing import Literal, Union from langchain_core._api import beta from langchain_core.documents import Document @@ -41,7 +42,7 @@ def bidir(kind: str, tag: str) -> "Link": @beta() -def get_links(doc: Document) -> List[Link]: +def get_links(doc: Document) -> list[Link]: """Get the links from a document. Args: diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py index 7a9e612c783f1..388feb5ef02ae 100644 --- a/libs/core/langchain_core/indexing/api.py +++ b/libs/core/langchain_core/indexing/api.py @@ -5,17 +5,13 @@ import hashlib import json import uuid +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence from itertools import islice from typing import ( Any, - AsyncIterable, - AsyncIterator, Callable, - Iterable, - Iterator, Literal, Optional, - Sequence, TypedDict, TypeVar, Union, diff --git a/libs/core/langchain_core/indexing/base.py b/libs/core/langchain_core/indexing/base.py index de4445783b1df..5ce5422c3d507 100644 --- a/libs/core/langchain_core/indexing/base.py +++ b/libs/core/langchain_core/indexing/base.py @@ -3,7 +3,8 @@ import abc import time from abc import ABC, abstractmethod -from typing import Any, Optional, Sequence, TypedDict +from collections.abc import Sequence +from typing import Any, Optional, TypedDict from langchain_core._api import beta from langchain_core.documents import Document diff --git a/libs/core/langchain_core/indexing/in_memory.py b/libs/core/langchain_core/indexing/in_memory.py index acc4d3f958451..4983ecfe1bf66 100644 --- a/libs/core/langchain_core/indexing/in_memory.py +++ b/libs/core/langchain_core/indexing/in_memory.py @@ -1,5 +1,6 @@ import uuid -from typing import Any, Dict, List, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Optional, cast from pydantic import Field @@ -22,7 +23,7 @@ class InMemoryDocumentIndex(DocumentIndex): .. versionadded:: 0.2.29 """ - store: Dict[str, Document] = Field(default_factory=dict) + store: dict[str, Document] = Field(default_factory=dict) top_k: int = 4 def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse: @@ -43,7 +44,7 @@ def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse: return UpsertResponse(succeeded=ok_ids, failed=[]) - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse: + def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse: """Delete by ID.""" if ids is None: raise ValueError("IDs must be provided for deletion") @@ -59,7 +60,7 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteRespon succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[] ) - def get(self, ids: Sequence[str], /, **kwargs: Any) -> List[Document]: + def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]: """Get by ids.""" found_documents = [] @@ -71,7 +72,7 @@ def get(self, ids: Sequence[str], /, **kwargs: Any) -> List[Document]: def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: counts_by_doc = [] for document in self.store.values(): diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index b21e54189ba71..dc1106205f357 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -1,16 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from functools import lru_cache +from collections.abc import Mapping, Sequence +from functools import cache from typing import ( TYPE_CHECKING, Any, Callable, - List, Literal, - Mapping, Optional, - Sequence, TypeVar, Union, ) @@ -52,7 +50,7 @@ class LangSmithParams(TypedDict, total=False): """Stop words for generation.""" -@lru_cache(maxsize=None) # Cache the tokenizer +@cache # Cache the tokenizer def get_tokenizer() -> Any: """Get a GPT-2 tokenizer instance. @@ -158,7 +156,7 @@ def InputType(self) -> TypeAlias: return Union[ str, Union[StringPromptValue, ChatPromptValueConcrete], - List[AnyMessage], + list[AnyMessage], ] @abstractmethod diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 869c584854fef..9553cb61225aa 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -3,21 +3,19 @@ import asyncio import inspect import json +import typing import uuid import warnings from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterator, Sequence from functools import cached_property from operator import itemgetter from typing import ( TYPE_CHECKING, Any, - AsyncIterator, Callable, - Dict, - Iterator, Literal, Optional, - Sequence, Union, cast, ) @@ -1121,18 +1119,18 @@ def dict(self, **kwargs: Any) -> dict: def bind_tools( self, - tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006 + tools: Sequence[Union[typing.Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006 **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: raise NotImplementedError() def with_structured_output( self, - schema: Union[Dict, type], # noqa: UP006 + schema: Union[typing.Dict, type], # noqa: UP006 *, include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006 + ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006 """Model wrapper that returns outputs formatted to match the given schema. Args: diff --git a/libs/core/langchain_core/language_models/fake.py b/libs/core/langchain_core/language_models/fake.py index 92c506ba0991a..9465a4d93f163 100644 --- a/libs/core/langchain_core/language_models/fake.py +++ b/libs/core/langchain_core/language_models/fake.py @@ -1,6 +1,7 @@ import asyncio import time -from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional +from collections.abc import AsyncIterator, Iterator, Mapping +from typing import Any, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -14,7 +15,7 @@ class FakeListLLM(LLM): """Fake LLM for testing purposes.""" - responses: List[str] + responses: list[str] """List of responses to return in order.""" # This parameter should be removed from FakeListLLM since # it's only used by sub-classes. @@ -37,7 +38,7 @@ def _llm_type(self) -> str: def _call( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -52,7 +53,7 @@ def _call( async def _acall( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -90,7 +91,7 @@ def stream( input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> Iterator[str]: result = self.invoke(input, config) @@ -110,7 +111,7 @@ async def astream( input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> AsyncIterator[str]: result = await self.ainvoke(input, config) diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 23f17d88c741f..3c41c1d462f50 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -3,7 +3,8 @@ import asyncio import re import time -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast +from collections.abc import AsyncIterator, Iterator +from typing import Any, Optional, Union, cast from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -17,7 +18,7 @@ class FakeMessagesListChatModel(BaseChatModel): """Fake ChatModel for testing purposes.""" - responses: List[BaseMessage] + responses: list[BaseMessage] """List of responses to **cycle** through in order.""" sleep: Optional[float] = None """Sleep time in seconds between responses.""" @@ -26,8 +27,8 @@ class FakeMessagesListChatModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -51,7 +52,7 @@ class FakeListChatModelError(Exception): class FakeListChatModel(SimpleChatModel): """Fake ChatModel for testing purposes.""" - responses: List[str] + responses: list[str] """List of responses to **cycle** through in order.""" sleep: Optional[float] = None i: int = 0 @@ -65,8 +66,8 @@ def _llm_type(self) -> str: def _call( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -80,8 +81,8 @@ def _call( def _stream( self, - messages: List[BaseMessage], - stop: Union[List[str], None] = None, + messages: list[BaseMessage], + stop: Union[list[str], None] = None, run_manager: Union[CallbackManagerForLLMRun, None] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: @@ -103,8 +104,8 @@ def _stream( async def _astream( self, - messages: List[BaseMessage], - stop: Union[List[str], None] = None, + messages: list[BaseMessage], + stop: Union[list[str], None] = None, run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: @@ -124,7 +125,7 @@ async def _astream( yield ChatGenerationChunk(message=AIMessageChunk(content=c)) @property - def _identifying_params(self) -> Dict[str, Any]: + def _identifying_params(self) -> dict[str, Any]: return {"responses": self.responses} @@ -133,8 +134,8 @@ class FakeChatModel(SimpleChatModel): def _call( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -142,8 +143,8 @@ def _call( async def _agenerate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -157,7 +158,7 @@ def _llm_type(self) -> str: return "fake-chat-model" @property - def _identifying_params(self) -> Dict[str, Any]: + def _identifying_params(self) -> dict[str, Any]: return {"key": "fake"} @@ -186,8 +187,8 @@ class GenericFakeChatModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -202,8 +203,8 @@ def _generate( def _stream( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: @@ -231,7 +232,7 @@ def _stream( # Use a regular expression to split on whitespace with a capture group # so that we can preserve the whitespace in the output. assert isinstance(content, str) - content_chunks = cast(List[str], re.split(r"(\s)", content)) + content_chunks = cast(list[str], re.split(r"(\s)", content)) for token in content_chunks: chunk = ChatGenerationChunk( @@ -249,7 +250,7 @@ def _stream( for fkey, fvalue in value.items(): if isinstance(fvalue, str): # Break function call by `,` - fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue)) + fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue)) for fvalue_chunk in fvalue_chunks: chunk = ChatGenerationChunk( message=AIMessageChunk( @@ -306,8 +307,8 @@ class ParrotFakeChatModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index c1125f6202b4f..7458efd312075 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -10,16 +10,12 @@ import uuid import warnings from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterator, Sequence from pathlib import Path from typing import ( Any, - AsyncIterator, Callable, - Dict, - Iterator, - List, Optional, - Sequence, Union, cast, ) @@ -448,7 +444,7 @@ def batch( return [g[0].text for g in llm_result.generations] except Exception as e: if return_exceptions: - return cast(List[str], [e for _ in inputs]) + return cast(list[str], [e for _ in inputs]) else: raise e else: @@ -494,7 +490,7 @@ async def abatch( return [g[0].text for g in llm_result.generations] except Exception as e: if return_exceptions: - return cast(List[str], [e for _ in inputs]) + return cast(list[str], [e for _ in inputs]) else: raise e else: @@ -883,13 +879,13 @@ def generate( assert run_name is None or ( isinstance(run_name, list) and len(run_name) == len(prompts) ) - callbacks = cast(List[Callbacks], callbacks) - tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) + callbacks = cast(list[Callbacks], callbacks) + tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts))) metadata_list = cast( - List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) + list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts)) ) run_name_list = run_name or cast( - List[Optional[str]], ([None] * len(prompts)) + list[Optional[str]], ([None] * len(prompts)) ) callback_managers = [ CallbackManager.configure( @@ -910,9 +906,9 @@ def generate( cast(Callbacks, callbacks), self.callbacks, self.verbose, - cast(List[str], tags), + cast(list[str], tags), self.tags, - cast(Dict[str, Any], metadata), + cast(dict[str, Any], metadata), self.metadata, ) ] * len(prompts) @@ -1116,13 +1112,13 @@ async def agenerate( assert run_name is None or ( isinstance(run_name, list) and len(run_name) == len(prompts) ) - callbacks = cast(List[Callbacks], callbacks) - tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) + callbacks = cast(list[Callbacks], callbacks) + tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts))) metadata_list = cast( - List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) + list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts)) ) run_name_list = run_name or cast( - List[Optional[str]], ([None] * len(prompts)) + list[Optional[str]], ([None] * len(prompts)) ) callback_managers = [ AsyncCallbackManager.configure( @@ -1143,9 +1139,9 @@ async def agenerate( cast(Callbacks, callbacks), self.callbacks, self.verbose, - cast(List[str], tags), + cast(list[str], tags), self.tags, - cast(Dict[str, Any], metadata), + cast(dict[str, Any], metadata), self.metadata, ) ] * len(prompts) diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index 62a286c558c8a..ba150a5a0095b 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -1,7 +1,7 @@ import importlib import json import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from langchain_core._api import beta from langchain_core.load.mapping import ( @@ -34,11 +34,11 @@ class Reviver: def __init__( self, - secrets_map: Optional[Dict[str, str]] = None, - valid_namespaces: Optional[List[str]] = None, + secrets_map: Optional[dict[str, str]] = None, + valid_namespaces: Optional[list[str]] = None, secrets_from_env: bool = True, additional_import_mappings: Optional[ - Dict[Tuple[str, ...], Tuple[str, ...]] + dict[tuple[str, ...], tuple[str, ...]] ] = None, ) -> None: """Initialize the reviver. @@ -73,7 +73,7 @@ def __init__( else ALL_SERIALIZABLE_MAPPINGS ) - def __call__(self, value: Dict[str, Any]) -> Any: + def __call__(self, value: dict[str, Any]) -> Any: if ( value.get("lc", None) == 1 and value.get("type", None) == "secret" @@ -154,10 +154,10 @@ def __call__(self, value: Dict[str, Any]) -> Any: def loads( text: str, *, - secrets_map: Optional[Dict[str, str]] = None, - valid_namespaces: Optional[List[str]] = None, + secrets_map: Optional[dict[str, str]] = None, + valid_namespaces: Optional[list[str]] = None, secrets_from_env: bool = True, - additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None, + additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None, ) -> Any: """Revive a LangChain class from a JSON string. Equivalent to `load(json.loads(text))`. @@ -190,10 +190,10 @@ def loads( def load( obj: Any, *, - secrets_map: Optional[Dict[str, str]] = None, - valid_namespaces: Optional[List[str]] = None, + secrets_map: Optional[dict[str, str]] = None, + valid_namespaces: Optional[list[str]] = None, secrets_from_env: bool = True, - additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None, + additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None, ) -> Any: """Revive a LangChain class from a JSON object. Use this if you already have a parsed JSON object, eg. from `json.load` or `orjson.loads`. diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index 0b63bf247d861..17e50947f3ca5 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -18,11 +18,9 @@ version of LangChain where the code was in a different location. """ -from typing import Dict, Tuple - # First value is the value that it is serialized as # Second value is the path to load it from -SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { +SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { ("langchain", "schema", "messages", "AIMessage"): ( "langchain_core", "messages", @@ -535,7 +533,7 @@ # Needed for backwards compatibility for old versions of LangChain where things # Were in different place -_OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { +_OG_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { ("langchain", "schema", "AIMessage"): ( "langchain_core", "messages", @@ -583,7 +581,7 @@ # Needed for backwards compatibility for a few versions where we serialized # with langchain_core paths. -OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { +OLD_CORE_NAMESPACES_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { ("langchain_core", "messages", "ai", "AIMessage"): ( "langchain_core", "messages", @@ -937,7 +935,7 @@ ), } -_JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { +_JS_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { ("langchain_core", "messages", "AIMessage"): ( "langchain_core", "messages", diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py index 5dc5d84727a52..f42ca89211644 100644 --- a/libs/core/langchain_core/load/serializable.py +++ b/libs/core/langchain_core/load/serializable.py @@ -1,8 +1,6 @@ from abc import ABC from typing import ( Any, - Dict, - List, Literal, Optional, TypedDict, @@ -25,9 +23,9 @@ class BaseSerialized(TypedDict): """ lc: int - id: List[str] + id: list[str] name: NotRequired[str] - graph: NotRequired[Dict[str, Any]] + graph: NotRequired[dict[str, Any]] class SerializedConstructor(BaseSerialized): @@ -39,7 +37,7 @@ class SerializedConstructor(BaseSerialized): """ type: Literal["constructor"] - kwargs: Dict[str, Any] + kwargs: dict[str, Any] class SerializedSecret(BaseSerialized): @@ -125,7 +123,7 @@ def is_lc_serializable(cls) -> bool: return False @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. For example, if the class is `langchain.llms.openai.OpenAI`, then the @@ -134,7 +132,7 @@ def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".") @property - def lc_secrets(self) -> Dict[str, str]: + def lc_secrets(self) -> dict[str, str]: """A map of constructor argument names to secret ids. For example, @@ -143,7 +141,7 @@ def lc_secrets(self) -> Dict[str, str]: return dict() @property - def lc_attributes(self) -> Dict: + def lc_attributes(self) -> dict: """List of attribute names that should be included in the serialized kwargs. These attributes must be accepted by the constructor. @@ -152,7 +150,7 @@ def lc_attributes(self) -> Dict: return {} @classmethod - def lc_id(cls) -> List[str]: + def lc_id(cls) -> list[str]: """A unique identifier for this class for serialization purposes. The unique identifier is a list of strings that describes the path @@ -315,8 +313,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool: def _replace_secrets( - root: Dict[Any, Any], secrets_map: Dict[str, str] -) -> Dict[Any, Any]: + root: dict[Any, Any], secrets_map: dict[str, str] +) -> dict[Any, Any]: result = root.copy() for path, secret_id in secrets_map.items(): [*parts, last] = path.split(".") @@ -344,7 +342,7 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented: Returns: SerializedNotImplemented """ - _id: List[str] = [] + _id: list[str] = [] try: if hasattr(obj, "__name__"): _id = [*obj.__module__.split("."), obj.__name__] diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index b7c4cd9d92150..9564a3bc063d3 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import model_validator from typing_extensions import Self, TypedDict @@ -69,9 +69,9 @@ class AIMessage(BaseMessage): At the moment, this is ignored by most models. Usage is discouraged. """ - tool_calls: List[ToolCall] = [] + tool_calls: list[ToolCall] = [] """If provided, tool calls associated with the message.""" - invalid_tool_calls: List[InvalidToolCall] = [] + invalid_tool_calls: list[InvalidToolCall] = [] """If provided, tool calls with parsing errors associated with the message.""" usage_metadata: Optional[UsageMetadata] = None """If provided, usage metadata for a message, such as token counts. @@ -83,7 +83,7 @@ class AIMessage(BaseMessage): """The type of the message (used for deserialization). Defaults to "ai".""" def __init__( - self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any ) -> None: """Pass in content as positional arg. @@ -94,7 +94,7 @@ def __init__( super().__init__(content=content, **kwargs) @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Returns: @@ -104,7 +104,7 @@ def get_lc_namespace(cls) -> List[str]: return ["langchain", "schema", "messages"] @property - def lc_attributes(self) -> Dict: + def lc_attributes(self) -> dict: """Attrs to be serialized even if they are derived from other init args.""" return { "tool_calls": self.tool_calls, @@ -137,7 +137,7 @@ def _backwards_compat_tool_calls(cls, values: dict) -> Any: # Ensure "type" is properly set on all tool call-like dicts. if tool_calls := values.get("tool_calls"): - updated: List = [] + updated: list = [] for tc in tool_calls: updated.append( create_tool_call(**{k: v for k, v in tc.items() if k != "type"}) @@ -178,7 +178,7 @@ def pretty_repr(self, html: bool = False) -> str: base = super().pretty_repr(html=html) lines = [] - def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]: + def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]: lines = [ f" {tc.get('name', 'Tool')} ({tc.get('id')})", f" Call ID: {tc.get('id')}", @@ -218,11 +218,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): """The type of the message (used for deserialization). Defaults to "AIMessageChunk".""" - tool_call_chunks: List[ToolCallChunk] = [] + tool_call_chunks: list[ToolCallChunk] = [] """If provided, tool call chunks associated with the message.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Returns: @@ -232,7 +232,7 @@ def get_lc_namespace(cls) -> List[str]: return ["langchain", "schema", "messages"] @property - def lc_attributes(self) -> Dict: + def lc_attributes(self) -> dict: """Attrs to be serialized even if they are derived from other init args.""" return { "tool_calls": self.tool_calls, diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 574ef1c947d08..96fa595fd0e2b 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional, Union, cast from pydantic import ConfigDict, Field, field_validator @@ -143,7 +144,7 @@ def merge_content( merged = [merged] + content # type: ignore elif isinstance(content, list): # If both are lists - merged = merge_lists(cast(List, merged), content) # type: ignore + merged = merge_lists(cast(list, merged), content) # type: ignore # If the first content is a list, and the second content is a string else: # If the last element of the first content is a string diff --git a/libs/core/langchain_core/messages/chat.py b/libs/core/langchain_core/messages/chat.py index 7d6205239b18c..e05be83343a4b 100644 --- a/libs/core/langchain_core/messages/chat.py +++ b/libs/core/langchain_core/messages/chat.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal +from typing import Any, Literal from langchain_core.messages.base import ( BaseMessage, @@ -18,7 +18,7 @@ class ChatMessage(BaseMessage): """The type of the message (used during serialization). Defaults to "chat".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"]. """ @@ -39,7 +39,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk): Defaults to "ChatMessageChunk".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"]. """ diff --git a/libs/core/langchain_core/messages/function.py b/libs/core/langchain_core/messages/function.py index 86b108d9ad822..448a720935ded 100644 --- a/libs/core/langchain_core/messages/function.py +++ b/libs/core/langchain_core/messages/function.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal +from typing import Any, Literal from langchain_core.messages.base import ( BaseMessage, @@ -26,7 +26,7 @@ class FunctionMessage(BaseMessage): """The type of the message (used for serialization). Defaults to "function".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] @@ -46,7 +46,7 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): Defaults to "FunctionMessageChunk".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] diff --git a/libs/core/langchain_core/messages/human.py b/libs/core/langchain_core/messages/human.py index 631d5052cb32a..96cbeabbe8a69 100644 --- a/libs/core/langchain_core/messages/human.py +++ b/libs/core/langchain_core/messages/human.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Union +from typing import Any, Literal, Union from langchain_core.messages.base import BaseMessage, BaseMessageChunk @@ -39,13 +39,13 @@ class HumanMessage(BaseMessage): """The type of the message (used for serialization). Defaults to "human".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] def __init__( - self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any ) -> None: """Pass in content as positional arg. @@ -70,7 +70,7 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk): Defaults to "HumanMessageChunk".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] diff --git a/libs/core/langchain_core/messages/modifier.py b/libs/core/langchain_core/messages/modifier.py index 042e23156c7ae..8a5fb6860d868 100644 --- a/libs/core/langchain_core/messages/modifier.py +++ b/libs/core/langchain_core/messages/modifier.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal +from typing import Any, Literal from langchain_core.messages.base import BaseMessage @@ -25,7 +25,7 @@ def __init__(self, id: str, **kwargs: Any) -> None: return super().__init__("", id=id, **kwargs) @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] diff --git a/libs/core/langchain_core/messages/system.py b/libs/core/langchain_core/messages/system.py index 7d94776a50cc5..a182198ad0310 100644 --- a/libs/core/langchain_core/messages/system.py +++ b/libs/core/langchain_core/messages/system.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Union +from typing import Any, Literal, Union from langchain_core.messages.base import BaseMessage, BaseMessageChunk @@ -33,13 +33,13 @@ class SystemMessage(BaseMessage): """The type of the message (used for serialization). Defaults to "system".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] def __init__( - self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any ) -> None: """Pass in content as positional arg. @@ -64,7 +64,7 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk): Defaults to "SystemMessageChunk".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index be1bc675b1a82..ec221083724d2 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Literal, Optional, Union from uuid import UUID from pydantic import Field, model_validator @@ -78,7 +78,7 @@ class ToolMessage(BaseMessage): """Currently inherited from BaseMessage, but not used.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"].""" return ["langchain", "schema", "messages"] @@ -123,7 +123,7 @@ def coerce_args(cls, values: dict) -> dict: return values def __init__( - self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any ) -> None: super().__init__(content=content, **kwargs) @@ -140,7 +140,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk): type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment] @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "messages"] @@ -187,7 +187,7 @@ class ToolCall(TypedDict): name: str """The name of the tool to be called.""" - args: Dict[str, Any] + args: dict[str, Any] """The arguments to the tool call.""" id: Optional[str] """An identifier associated with the tool call. @@ -198,7 +198,7 @@ class ToolCall(TypedDict): type: NotRequired[Literal["tool_call"]] -def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall: +def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall: return ToolCall(name=name, args=args, id=id, type="tool_call") @@ -276,8 +276,8 @@ def invalid_tool_call( def default_tool_parser( - raw_tool_calls: List[dict], -) -> Tuple[List[ToolCall], List[InvalidToolCall]]: + raw_tool_calls: list[dict], +) -> tuple[list[ToolCall], list[InvalidToolCall]]: """Best-effort parsing of tools.""" tool_calls = [] invalid_tool_calls = [] @@ -306,7 +306,7 @@ def default_tool_parser( return tool_calls, invalid_tool_calls -def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]: +def default_tool_chunk_parser(raw_tool_calls: list[dict]) -> list[ToolCallChunk]: """Best-effort parsing of tool chunks.""" tool_call_chunks = [] for tool_call in raw_tool_calls: diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index ce105f2049c5d..98cdb71f6da82 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -11,25 +11,21 @@ import inspect import json +from collections.abc import Iterable, Sequence from functools import partial from typing import ( TYPE_CHECKING, + Annotated, Any, Callable, - Dict, - Iterable, - List, Literal, Optional, - Sequence, - Tuple, Union, cast, overload, ) from pydantic import Discriminator, Field, Tag -from typing_extensions import Annotated from langchain_core.messages.ai import AIMessage, AIMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk @@ -198,7 +194,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage: MessageLikeRepresentation = Union[ - BaseMessage, List[str], Tuple[str, str], str, Dict[str, Any] + BaseMessage, list[str], tuple[str, str], str, dict[str, Any] ] diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 38c0823de23d9..fdb7c38fc2cd2 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -2,12 +2,11 @@ import json from json import JSONDecodeError -from typing import Any, Optional, TypeVar, Union +from typing import Annotated, Any, Optional, TypeVar, Union import jsonpatch # type: ignore[import] import pydantic from pydantic import SkipValidation -from typing_extensions import Annotated from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index 65ec406ec0a50..858ba86c79fa2 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -3,8 +3,9 @@ import re from abc import abstractmethod from collections import deque -from typing import AsyncIterator, Iterator, List, TypeVar, Union +from collections.abc import AsyncIterator, Iterator from typing import Optional as Optional +from typing import TypeVar, Union from langchain_core.messages import BaseMessage from langchain_core.output_parsers.transform import BaseTransformOutputParser @@ -29,7 +30,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]: yield buffer.popleft() -class ListOutputParser(BaseTransformOutputParser[List[str]]): +class ListOutputParser(BaseTransformOutputParser[list[str]]): """Parse the output of an LLM call to a list.""" @property diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py index 4324eac47d989..8e29b4075a2d2 100644 --- a/libs/core/langchain_core/output_parsers/openai_functions.py +++ b/libs/core/langchain_core/output_parsers/openai_functions.py @@ -1,6 +1,7 @@ import copy import json -from typing import Any, Dict, List, Optional, Type, Union +from types import GenericAlias +from typing import Any, Optional, Union import jsonpatch # type: ignore[import] from pydantic import BaseModel, model_validator @@ -20,7 +21,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): args_only: bool = True """Whether to only return the arguments to the function call.""" - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. Args: @@ -72,7 +73,7 @@ def _type(self) -> str: def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. Args: @@ -166,7 +167,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): key_name: str """The name of the key to return.""" - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. Args: @@ -223,7 +224,7 @@ class Dog(BaseModel): result = parser.parse_result([chat_generation]) """ - pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]] + pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]] """The pydantic schema to parse the output with. If multiple schemas are provided, then the function name will be used to @@ -232,7 +233,7 @@ class Dog(BaseModel): @model_validator(mode="before") @classmethod - def validate_schema(cls, values: Dict) -> Any: + def validate_schema(cls, values: dict) -> Any: """Validate the pydantic schema. Args: @@ -246,17 +247,19 @@ def validate_schema(cls, values: Dict) -> Any: """ schema = values["pydantic_schema"] if "args_only" not in values: - values["args_only"] = isinstance(schema, type) and issubclass( - schema, BaseModel + values["args_only"] = ( + isinstance(schema, type) + and not isinstance(schema, GenericAlias) + and issubclass(schema, BaseModel) ) - elif values["args_only"] and isinstance(schema, Dict): + elif values["args_only"] and isinstance(schema, dict): raise ValueError( "If multiple pydantic schemas are provided then args_only should be" " False." ) return values - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. Args: @@ -292,7 +295,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): attr_name: str """The name of the attribute to return.""" - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. Args: diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index 9a803c3d8ccd4..5c624f43cb635 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -1,10 +1,9 @@ import copy import json from json import JSONDecodeError -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Optional from pydantic import SkipValidation, ValidationError -from typing_extensions import Annotated from langchain_core.exceptions import OutputParserException from langchain_core.messages import AIMessage, InvalidToolCall @@ -17,12 +16,12 @@ def parse_tool_call( - raw_tool_call: Dict[str, Any], + raw_tool_call: dict[str, Any], *, partial: bool = False, strict: bool = False, return_id: bool = True, -) -> Optional[Dict[str, Any]]: +) -> Optional[dict[str, Any]]: """Parse a single tool call. Args: @@ -69,7 +68,7 @@ def parse_tool_call( def make_invalid_tool_call( - raw_tool_call: Dict[str, Any], + raw_tool_call: dict[str, Any], error_msg: Optional[str], ) -> InvalidToolCall: """Create an InvalidToolCall from a raw tool call. @@ -90,12 +89,12 @@ def make_invalid_tool_call( def parse_tool_calls( - raw_tool_calls: List[dict], + raw_tool_calls: list[dict], *, partial: bool = False, strict: bool = False, return_id: bool = True, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Parse a list of tool calls. Args: @@ -111,7 +110,7 @@ def parse_tool_calls( Raises: OutputParserException: If any of the tool calls are not valid JSON. """ - final_tools: List[Dict[str, Any]] = [] + final_tools: list[dict[str, Any]] = [] exceptions = [] for tool_call in raw_tool_calls: try: @@ -151,7 +150,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]): If no tool calls are found, None will be returned. """ - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a list of tool calls. Args: @@ -217,7 +216,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser): key_name: str """The type of tools to return.""" - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a list of tool calls. Args: @@ -254,12 +253,12 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An class PydanticToolsParser(JsonOutputToolsParser): """Parse tools from OpenAI response.""" - tools: Annotated[List[TypeBaseModel], SkipValidation()] + tools: Annotated[list[TypeBaseModel], SkipValidation()] """The tools to parse.""" # TODO: Support more granular streaming of objects. Currently only streams once all # Pydantic object fields are present. - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a list of Pydantic objects. Args: diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 8d48e98b2b349..081dc47094e2e 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -1,9 +1,8 @@ import json -from typing import Generic, List, Optional, Type +from typing import Annotated, Generic, Optional import pydantic from pydantic import SkipValidation -from typing_extensions import Annotated from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser @@ -18,7 +17,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): """Parse an output using a pydantic model.""" - pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore + pydantic_object: Annotated[type[TBaseModel], SkipValidation()] # type: ignore """The pydantic model to parse.""" def _parse_obj(self, obj: dict) -> TBaseModel: @@ -50,7 +49,7 @@ def _parser_exception( return OutputParserException(msg, llm_output=json_string) def parse_result( - self, result: List[Generation], *, partial: bool = False + self, result: list[Generation], *, partial: bool = False ) -> Optional[TBaseModel]: """Parse the result of an LLM call to a pydantic object. @@ -108,7 +107,7 @@ def _type(self) -> str: return "pydantic" @property - def OutputType(self) -> Type[TBaseModel]: + def OutputType(self) -> type[TBaseModel]: """Return the pydantic model.""" return self.pydantic_object diff --git a/libs/core/langchain_core/output_parsers/string.py b/libs/core/langchain_core/output_parsers/string.py index ef231b893c9b7..1d594dfb91986 100644 --- a/libs/core/langchain_core/output_parsers/string.py +++ b/libs/core/langchain_core/output_parsers/string.py @@ -1,4 +1,3 @@ -from typing import List from typing import Optional as Optional from langchain_core.output_parsers.transform import BaseTransformOutputParser @@ -13,7 +12,7 @@ def is_lc_serializable(cls) -> bool: return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "output_parser"] diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index b29088ae4f40c..be29273871032 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -1,10 +1,9 @@ from __future__ import annotations +from collections.abc import AsyncIterator, Iterator from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Iterator, Optional, Union, ) diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index b8355d97c3e08..d7249f8c0db60 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,7 +1,8 @@ import re import xml import xml.etree.ElementTree as ET -from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union +from collections.abc import AsyncIterator, Iterator +from typing import Any, Literal, Optional, Union from xml.etree.ElementTree import TreeBuilder from langchain_core.exceptions import OutputParserException @@ -57,7 +58,7 @@ def __init__(self, parser: Literal["defusedxml", "xml"]) -> None: _parser = None self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) self.xml_start_re = re.compile(r"<[a-zA-Z:_]") - self.current_path: List[str] = [] + self.current_path: list[str] = [] self.current_path_has_children = False self.buffer = "" self.xml_started = False @@ -140,7 +141,7 @@ def close(self) -> None: class XMLOutputParser(BaseTransformOutputParser): """Parse an output using xml format.""" - tags: Optional[List[str]] = None + tags: Optional[list[str]] = None encoding_matcher: re.Pattern = re.compile( r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL ) @@ -169,7 +170,7 @@ def get_format_instructions(self) -> str: """Return the format instructions for the XML output.""" return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) - def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: + def parse(self, text: str) -> dict[str, Union[str, list[Any]]]: """Parse the output of an LLM call. Args: @@ -234,13 +235,13 @@ async def _atransform( yield output streaming_parser.close() - def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]: + def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]: """Converts xml tree to python dictionary.""" if root.text and bool(re.search(r"\S", root.text)): # If root text contains any non-whitespace character it # returns {root.tag: root.text} return {root.tag: root.text} - result: Dict = {root.tag: []} + result: dict = {root.tag: []} for child in root: if len(child) == 0: result[root.tag].append({child.tag: child.text}) @@ -253,7 +254,7 @@ def _type(self) -> str: return "xml" -def nested_element(path: List[str], elem: ET.Element) -> Any: +def nested_element(path: list[str], elem: ET.Element) -> Any: """Get nested element from path. Args: diff --git a/libs/core/langchain_core/outputs/chat_result.py b/libs/core/langchain_core/outputs/chat_result.py index 511b13bae3c1c..1de553eab3606 100644 --- a/libs/core/langchain_core/outputs/chat_result.py +++ b/libs/core/langchain_core/outputs/chat_result.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -18,7 +18,7 @@ class ChatResult(BaseModel): for more information. """ - generations: List[ChatGeneration] + generations: list[ChatGeneration] """List of the chat generations. Generations is a list to allow for multiple candidate generations for a single diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 6b9f421890f8f..d8b19fcf23d7b 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -7,7 +7,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Literal, Sequence, cast +from collections.abc import Sequence +from typing import Literal, cast from typing_extensions import TypedDict diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 531f40a6770cb..e1c43157d594a 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -1,16 +1,16 @@ from __future__ import annotations import json +import typing from abc import ABC, abstractmethod +from collections.abc import Mapping from functools import cached_property from pathlib import Path from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Generic, - Mapping, Optional, TypeVar, Union, @@ -39,7 +39,7 @@ class BasePromptTemplate( - RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC + RunnableSerializable[dict, PromptValue], Generic[FormatOutputType], ABC ): """Base class for all prompt templates, returning a prompt.""" @@ -50,7 +50,7 @@ class BasePromptTemplate( """optional_variables: A list of the names of the variables for placeholder or MessagePlaceholder that are optional. These variables are auto inferred from the prompt and user need not provide them.""" - input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006 + input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006 """A dictionary of the types of the variables the prompt template expects. If not provided, all variables are assumed to be strings.""" output_parser: Optional[BaseOutputParser] = None @@ -60,7 +60,7 @@ class BasePromptTemplate( Partial variables populate the template so that you don't need to pass them in every time you call the prompt.""" - metadata: Optional[Dict[str, Any]] = None # noqa: UP006 + metadata: Optional[typing.Dict[str, Any]] = None # noqa: UP006 """Metadata to be used for tracing.""" tags: Optional[list[str]] = None """Tags to be used for tracing.""" diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index c5e5f1b279230..4c5dc492af07d 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -3,15 +3,13 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from pathlib import Path from typing import ( + Annotated, Any, - List, Literal, Optional, - Sequence, - Tuple, - Type, TypedDict, TypeVar, Union, @@ -25,7 +23,6 @@ SkipValidation, model_validator, ) -from typing_extensions import Annotated from langchain_core._api import deprecated from langchain_core.load import Serializable @@ -816,9 +813,9 @@ def pretty_print(self) -> None: MessageLikeRepresentation = Union[ MessageLike, - Tuple[ - Union[str, Type], - Union[str, List[dict], List[object]], + tuple[ + Union[str, type], + Union[str, list[dict], list[object]], ], str, ] @@ -1017,7 +1014,7 @@ def __init__( ), **kwargs, } - cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs) + cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs) @classmethod def get_lc_namespace(cls) -> list[str]: @@ -1083,7 +1080,7 @@ def validate_input_variables(cls, values: dict) -> Any: values["partial_variables"][message.variable_name] = [] optional_variables.add(message.variable_name) if message.variable_name not in input_types: - input_types[message.variable_name] = List[AnyMessage] + input_types[message.variable_name] = list[AnyMessage] if "partial_variables" in values: input_vars = input_vars - set(values["partial_variables"]) if optional_variables: diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py index 075eaac85182a..65a417046c64c 100644 --- a/libs/core/langchain_core/prompts/few_shot_with_templates.py +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -1,7 +1,7 @@ """Prompt template that contains few shot examples.""" from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from pydantic import ConfigDict, model_validator from typing_extensions import Self @@ -16,7 +16,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): """Prompt template that contains few shot examples.""" - examples: Optional[List[dict]] = None + examples: Optional[list[dict]] = None """Examples to format into the prompt. Either this or example_selector should be provided.""" @@ -43,13 +43,13 @@ class FewShotPromptWithTemplates(StringPromptTemplate): """Whether or not to try validating the template.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "few_shot_with_templates"] @model_validator(mode="before") @classmethod - def check_examples_and_selector(cls, values: Dict) -> Any: + def check_examples_and_selector(cls, values: dict) -> Any: """Check that one and only one of examples/example_selector are provided.""" examples = values.get("examples", None) example_selector = values.get("example_selector", None) @@ -93,7 +93,7 @@ def template_is_valid(self) -> Self: extra="forbid", ) - def _get_examples(self, **kwargs: Any) -> List[dict]: + def _get_examples(self, **kwargs: Any) -> list[dict]: if self.examples is not None: return self.examples elif self.example_selector is not None: @@ -101,7 +101,7 @@ def _get_examples(self, **kwargs: Any) -> List[dict]: else: raise ValueError - async def _aget_examples(self, **kwargs: Any) -> List[dict]: + async def _aget_examples(self, **kwargs: Any) -> list[dict]: if self.examples is not None: return self.examples elif self.example_selector is not None: diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index fdad76cc92bde..f28b2ecb1a50f 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from pydantic import Field @@ -33,7 +33,7 @@ def _prompt_type(self) -> str: return "image-prompt" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "image"] diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py index 046418be6117f..a4a1843786b17 100644 --- a/libs/core/langchain_core/prompts/loading.py +++ b/libs/core/langchain_core/prompts/loading.py @@ -3,7 +3,7 @@ import json import logging from pathlib import Path -from typing import Callable, Dict, Optional, Union +from typing import Callable, Optional, Union import yaml @@ -181,7 +181,7 @@ def _load_prompt_from_file( return load_prompt_from_config(config) -def _load_chat_prompt(config: Dict) -> ChatPromptTemplate: +def _load_chat_prompt(config: dict) -> ChatPromptTemplate: """Load chat prompt from config""" messages = config.pop("messages") @@ -194,7 +194,7 @@ def _load_chat_prompt(config: Dict) -> ChatPromptTemplate: return ChatPromptTemplate.from_template(template=template, **config) -type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = { +type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = { "prompt": _load_prompt, "few_shot": _load_few_shot_prompt, "chat": _load_chat_prompt, diff --git a/libs/core/langchain_core/prompts/pipeline.py b/libs/core/langchain_core/prompts/pipeline.py index c1ddc815f32e2..e25a0a7f72461 100644 --- a/libs/core/langchain_core/prompts/pipeline.py +++ b/libs/core/langchain_core/prompts/pipeline.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any from typing import Optional as Optional from pydantic import model_validator @@ -8,7 +8,7 @@ from langchain_core.prompts.chat import BaseChatPromptTemplate -def _get_inputs(inputs: dict, input_variables: List[str]) -> dict: +def _get_inputs(inputs: dict, input_variables: list[str]) -> dict: return {k: inputs[k] for k in input_variables} @@ -28,17 +28,17 @@ class PipelinePromptTemplate(BasePromptTemplate): final_prompt: BasePromptTemplate """The final prompt that is returned.""" - pipeline_prompts: List[Tuple[str, BasePromptTemplate]] + pipeline_prompts: list[tuple[str, BasePromptTemplate]] """A list of tuples, consisting of a string (`name`) and a Prompt Template.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "pipeline"] @model_validator(mode="before") @classmethod - def get_input_variables(cls, values: Dict) -> Any: + def get_input_variables(cls, values: dict) -> Any: """Get input variables.""" created_variables = set() all_variables = set() diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 9aea16cf75407..bc412cf727487 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -5,7 +5,7 @@ import warnings from abc import ABC from string import Formatter -from typing import Any, Callable, Dict +from typing import Any, Callable from pydantic import BaseModel, create_model @@ -139,7 +139,7 @@ def mustache_template_vars( return vars -Defs = Dict[str, "Defs"] +Defs = dict[str, "Defs"] def mustache_schema( diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 12114e80338b4..56450c626bae8 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -1,13 +1,8 @@ +from collections.abc import Iterator, Mapping, Sequence from typing import ( Any, Callable, - Dict, - Iterator, - List, - Mapping, Optional, - Sequence, - Type, Union, ) @@ -32,16 +27,16 @@ class StructuredPrompt(ChatPromptTemplate): """Structured prompt template for a language model.""" - schema_: Union[Dict, Type[BaseModel]] + schema_: Union[dict, type[BaseModel]] """Schema for the structured prompt.""" - structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict) + structured_output_kwargs: dict[str, Any] = Field(default_factory=dict) def __init__( self, messages: Sequence[MessageLikeRepresentation], - schema_: Optional[Union[Dict, Type[BaseModel]]] = None, + schema_: Optional[Union[dict, type[BaseModel]]] = None, *, - structured_output_kwargs: Optional[Dict[str, Any]] = None, + structured_output_kwargs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: schema_ = schema_ or kwargs.pop("schema") @@ -56,7 +51,7 @@ def __init__( ) @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. For example, if the class is `langchain.llms.openai.OpenAI`, then the @@ -68,7 +63,7 @@ def get_lc_namespace(cls) -> List[str]: def from_messages_and_schema( cls, messages: Sequence[MessageLikeRepresentation], - schema: Union[Dict, Type[BaseModel]], + schema: Union[dict, type[BaseModel]], **kwargs: Any, ) -> ChatPromptTemplate: """Create a chat prompt template from a variety of message formats. @@ -118,7 +113,7 @@ def __or__( Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], - ) -> RunnableSerializable[Dict, Other]: + ) -> RunnableSerializable[dict, Other]: return self.pipe(other) def pipe( @@ -130,7 +125,7 @@ def pipe( Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], name: Optional[str] = None, - ) -> RunnableSerializable[Dict, Other]: + ) -> RunnableSerializable[dict, Other]: """Pipe the structured prompt to a language model. Args: diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 3d7f2493d7afd..aa977e6233484 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -24,7 +24,7 @@ import warnings from abc import ABC, abstractmethod from inspect import signature -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, Optional from pydantic import ConfigDict from typing_extensions import TypedDict @@ -47,7 +47,7 @@ ) RetrieverInput = str -RetrieverOutput = List[Document] +RetrieverOutput = list[Document] RetrieverLike = Runnable[RetrieverInput, RetrieverOutput] RetrieverOutputLike = Runnable[Any, RetrieverOutput] diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3bcc4b37c5d20..1960979317c28 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -6,36 +6,37 @@ import inspect import threading from abc import ABC, abstractmethod +from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Awaitable, + Coroutine, + Iterator, + Mapping, + Sequence, +) from concurrent.futures import FIRST_COMPLETED, wait from contextvars import copy_context from functools import wraps from itertools import groupby, tee from operator import itemgetter +from types import GenericAlias from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - AsyncIterator, - Awaitable, Callable, - Coroutine, - Dict, Generic, - Iterator, - List, - Mapping, Optional, Protocol, - Sequence, - Type, TypeVar, Union, cast, + get_type_hints, overload, ) from pydantic import BaseModel, ConfigDict, Field, RootModel -from typing_extensions import Literal, get_args, get_type_hints +from typing_extensions import Literal, get_args from langchain_core._api import beta_decorator from langchain_core.load.serializable import ( @@ -340,7 +341,11 @@ def get_input_schema( """ root_type = self.InputType - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + if ( + inspect.isclass(root_type) + and not isinstance(root_type, GenericAlias) + and issubclass(root_type, BaseModel) + ): return root_type return create_model_v2( @@ -408,7 +413,11 @@ def get_output_schema( """ root_type = self.OutputType - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + if ( + inspect.isclass(root_type) + and not isinstance(root_type, GenericAlias) + and issubclass(root_type, BaseModel) + ): return root_type return create_model_v2( @@ -771,10 +780,10 @@ def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: # If there's only one input, don't bother with the executor if len(inputs) == 1: - return cast(List[Output], [invoke(inputs[0], configs[0])]) + return cast(list[Output], [invoke(inputs[0], configs[0])]) with get_executor_for_config(configs[0]) as executor: - return cast(List[Output], list(executor.map(invoke, inputs, configs))) + return cast(list[Output], list(executor.map(invoke, inputs, configs))) @overload def batch_as_completed( @@ -2024,7 +2033,7 @@ def _batch_with_config( for run_manager in run_managers: run_manager.on_chain_error(e) if return_exceptions: - return cast(List[Output], [e for _ in input]) + return cast(list[Output], [e for _ in input]) else: raise else: @@ -2036,7 +2045,7 @@ def _batch_with_config( else: run_manager.on_chain_end(out) if return_exceptions or first_exception is None: - return cast(List[Output], output) + return cast(list[Output], output) else: raise first_exception @@ -2099,7 +2108,7 @@ async def _abatch_with_config( *(run_manager.on_chain_error(e) for run_manager in run_managers) ) if return_exceptions: - return cast(List[Output], [e for _ in input]) + return cast(list[Output], [e for _ in input]) else: raise else: @@ -2113,7 +2122,7 @@ async def _abatch_with_config( coros.append(run_manager.on_chain_end(out)) await asyncio.gather(*coros) if return_exceptions or first_exception is None: - return cast(List[Output], output) + return cast(list[Output], output) else: raise first_exception @@ -3171,7 +3180,7 @@ def batch( for rm in run_managers: rm.on_chain_error(e) if return_exceptions: - return cast(List[Output], [e for _ in inputs]) + return cast(list[Output], [e for _ in inputs]) else: raise else: @@ -3183,7 +3192,7 @@ def batch( else: run_manager.on_chain_end(out) if return_exceptions or first_exception is None: - return cast(List[Output], inputs) + return cast(list[Output], inputs) else: raise first_exception @@ -3298,7 +3307,7 @@ async def abatch( except BaseException as e: await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) if return_exceptions: - return cast(List[Output], [e for _ in inputs]) + return cast(list[Output], [e for _ in inputs]) else: raise else: @@ -3312,7 +3321,7 @@ async def abatch( coros.append(run_manager.on_chain_end(out)) await asyncio.gather(*coros) if return_exceptions or first_exception is None: - return cast(List[Output], inputs) + return cast(list[Output], inputs) else: raise first_exception @@ -3420,7 +3429,7 @@ async def input_aiter() -> AsyncIterator[Input]: yield chunk -class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): +class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): """Runnable that runs a mapping of Runnables in parallel, and returns a mapping of their outputs. @@ -4071,7 +4080,11 @@ def get_input_schema( func = getattr(self, "_transform", None) or self._atransform module = getattr(func, "__module__", None) - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + if ( + inspect.isclass(root_type) + and not isinstance(root_type, GenericAlias) + and issubclass(root_type, BaseModel) + ): return root_type return create_model_v2( @@ -4106,7 +4119,11 @@ def get_output_schema( func = getattr(self, "_transform", None) or self._atransform module = getattr(func, "__module__", None) - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + if ( + inspect.isclass(root_type) + and not isinstance(root_type, GenericAlias) + and issubclass(root_type, BaseModel) + ): return root_type return create_model_v2( @@ -4369,7 +4386,7 @@ def get_input_schema( module = getattr(func, "__module__", None) return create_model_v2( self.get_name("Input"), - root=List[Any], + root=list[Any], # To create the schema, we need to provide the module # where the underlying function is defined. # This allows pydantic to resolve type annotations appropriately. @@ -4420,7 +4437,11 @@ def get_output_schema( func = getattr(self, "func", None) or self.afunc module = getattr(func, "__module__", None) - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + if ( + inspect.isclass(root_type) + and not isinstance(root_type, GenericAlias) + and issubclass(root_type, BaseModel) + ): return root_type return create_model_v2( @@ -4921,7 +4942,7 @@ async def input_aiter() -> AsyncIterator[Input]: yield chunk -class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): +class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): """Runnable that delegates calls to another Runnable with each element of the input sequence. @@ -4938,7 +4959,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): @property def InputType(self) -> Any: - return List[self.bound.InputType] # type: ignore[name-defined] + return list[self.bound.InputType] # type: ignore[name-defined] def get_input_schema( self, config: Optional[RunnableConfig] = None @@ -4946,7 +4967,7 @@ def get_input_schema( return create_model_v2( self.get_name("Input"), root=( - List[self.bound.get_input_schema(config)], # type: ignore + list[self.bound.get_input_schema(config)], # type: ignore None, ), # create model needs access to appropriate type annotations to be @@ -4961,7 +4982,7 @@ def get_input_schema( @property def OutputType(self) -> type[list[Output]]: - return List[self.bound.OutputType] # type: ignore[name-defined] + return list[self.bound.OutputType] # type: ignore[name-defined] def get_output_schema( self, config: Optional[RunnableConfig] = None @@ -4969,7 +4990,7 @@ def get_output_schema( schema = self.bound.get_output_schema(config) return create_model_v2( self.get_name("Output"), - root=List[schema], # type: ignore[valid-type] + root=list[schema], # type: ignore[valid-type] # create model needs access to appropriate type annotations to be # able to construct the pydantic model. # When we create the model, we pass information about the namespace @@ -5255,7 +5276,7 @@ def get_name( @property def InputType(self) -> type[Input]: return ( - cast(Type[Input], self.custom_input_type) + cast(type[Input], self.custom_input_type) if self.custom_input_type is not None else self.bound.InputType ) @@ -5263,7 +5284,7 @@ def InputType(self) -> type[Input]: @property def OutputType(self) -> type[Output]: return ( - cast(Type[Output], self.custom_output_type) + cast(type[Output], self.custom_output_type) if self.custom_output_type is not None else self.bound.OutputType ) @@ -5336,7 +5357,7 @@ def batch( ) -> list[Output]: if isinstance(config, list): configs = cast( - List[RunnableConfig], + list[RunnableConfig], [self._merge_configs(conf) for conf in config], ) else: @@ -5358,7 +5379,7 @@ async def abatch( ) -> list[Output]: if isinstance(config, list): configs = cast( - List[RunnableConfig], + list[RunnableConfig], [self._merge_configs(conf) for conf in config], ) else: @@ -5400,7 +5421,7 @@ def batch_as_completed( ) -> Iterator[tuple[int, Union[Output, Exception]]]: if isinstance(config, Sequence): configs = cast( - List[RunnableConfig], + list[RunnableConfig], [self._merge_configs(conf) for conf in config], ) else: @@ -5451,7 +5472,7 @@ async def abatch_as_completed( ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: if isinstance(config, Sequence): configs = cast( - List[RunnableConfig], + list[RunnableConfig], [self._merge_configs(conf) for conf in config], ) else: diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 268fc637f527f..b12fd416e0345 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -1,15 +1,8 @@ +from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence from typing import ( Any, - AsyncIterator, - Awaitable, Callable, - Iterator, - List, - Mapping, Optional, - Sequence, - Tuple, - Type, Union, cast, ) @@ -69,13 +62,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]): branch.invoke(None) # "goodbye" """ - branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] + branches: Sequence[tuple[Runnable[Input, bool], Runnable[Input, Output]]] default: Runnable[Input, Output] def __init__( self, *branches: Union[ - Tuple[ + tuple[ Union[ Runnable[Input, bool], Callable[[Input], bool], @@ -149,13 +142,13 @@ def is_lc_serializable(cls) -> bool: return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: runnables = ( [self.default] + [r for _, r in self.branches] @@ -172,7 +165,7 @@ def get_input_schema( return super().get_input_schema(config) @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: from langchain_core.beta.runnables.context import ( CONTEXT_CONFIG_PREFIX, CONTEXT_CONFIG_SUFFIX_SET, diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 69e58f194ae77..2e7efccc2e674 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -3,6 +3,7 @@ import asyncio import uuid import warnings +from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence from concurrent.futures import Executor, Future, ThreadPoolExecutor from contextlib import contextmanager from contextvars import ContextVar, copy_context @@ -10,14 +11,8 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, - Generator, - Iterable, - Iterator, - List, Optional, - Sequence, TypeVar, Union, cast, @@ -43,7 +38,7 @@ else: # Pydantic validates through typed dicts, but # the callbacks need forward refs updated - Callbacks = Optional[Union[List, Any]] + Callbacks = Optional[Union[list, Any]] class EmptyDict(TypedDict, total=False): diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index 7f0c36687a643..373431c1bdcc1 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -3,20 +3,16 @@ import enum import threading from abc import abstractmethod +from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import Mapping as Mapping from functools import wraps from typing import ( Any, - AsyncIterator, Callable, - Iterator, - List, Optional, - Sequence, - Type, Union, cast, ) -from typing import Mapping as Mapping from weakref import WeakValueDictionary from pydantic import BaseModel, ConfigDict @@ -176,10 +172,10 @@ def invoke( # If there's only one input, don't bother with the executor if len(inputs) == 1: - return cast(List[Output], [invoke(prepared[0], inputs[0])]) + return cast(list[Output], [invoke(prepared[0], inputs[0])]) with get_executor_for_config(configs[0]) as executor: - return cast(List[Output], list(executor.map(invoke, prepared, inputs))) + return cast(list[Output], list(executor.map(invoke, prepared, inputs))) async def abatch( self, @@ -562,7 +558,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: for v in list(self.alternatives.keys()) + [self.default_key] ), ) - _enums_for_spec[self.which] = cast(Type[StrEnum], which_enum) + _enums_for_spec[self.which] = cast(type[StrEnum], which_enum) return get_unique_config_specs( # which alternative [ @@ -694,7 +690,7 @@ def make_options_spec( spec.name or spec.id, ((v, v) for v in list(spec.options.keys())), ) - _enums_for_spec[spec] = cast(Type[StrEnum], enum) + _enums_for_spec[spec] = cast(type[StrEnum], enum) if isinstance(spec, ConfigurableFieldSingleOption): return ConfigurableFieldSpec( id=spec.id, diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index ef97aaf770c39..c97ca9451bc86 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -1,19 +1,13 @@ import asyncio import inspect import typing +from collections.abc import AsyncIterator, Iterator, Sequence from contextvars import copy_context from functools import wraps from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Dict, - Iterator, - List, Optional, - Sequence, - Tuple, - Type, Union, cast, ) @@ -96,7 +90,7 @@ def when_all_is_lost(inputs): """The Runnable to run first.""" fallbacks: Sequence[Runnable[Input, Output]] """A sequence of fallbacks to try.""" - exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) + exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,) """The exceptions on which fallbacks should be tried. Any exception that is not a subclass of these exceptions will be raised immediately. @@ -112,25 +106,25 @@ def when_all_is_lost(inputs): ) @property - def InputType(self) -> Type[Input]: + def InputType(self) -> type[Input]: return self.runnable.InputType @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> type[Output]: return self.runnable.OutputType def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: return self.runnable.get_input_schema(config) def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: return self.runnable.get_output_schema(config) @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: return get_unique_config_specs( spec for step in [self.runnable, *self.fallbacks] @@ -142,7 +136,7 @@ def is_lc_serializable(cls) -> bool: return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -252,12 +246,12 @@ async def ainvoke( def batch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: from langchain_core.callbacks.manager import CallbackManager if self.exception_key is not None and not all( @@ -296,9 +290,9 @@ def batch( for cm, input, config in zip(callback_managers, inputs, configs) ] - to_return: Dict[int, Any] = {} + to_return: dict[int, Any] = {} run_again = {i: input for i, input in enumerate(inputs)} - handled_exceptions: Dict[int, BaseException] = {} + handled_exceptions: dict[int, BaseException] = {} first_to_raise = None for runnable in self.runnables: outputs = runnable.batch( @@ -344,12 +338,12 @@ def batch( async def abatch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: from langchain_core.callbacks.manager import AsyncCallbackManager if self.exception_key is not None and not all( @@ -378,7 +372,7 @@ async def abatch( for config in configs ] # start the root runs, one per input - run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( cm.on_chain_start( None, @@ -392,7 +386,7 @@ async def abatch( to_return = {} run_again = {i: input for i, input in enumerate(inputs)} - handled_exceptions: Dict[int, BaseException] = {} + handled_exceptions: dict[int, BaseException] = {} first_to_raise = None for runnable in self.runnables: outputs = await runnable.abatch( diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index bcf64c95fc1db..baec005e5ac23 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -2,6 +2,7 @@ import inspect from collections import Counter +from collections.abc import Sequence from dataclasses import dataclass, field from enum import Enum from typing import ( @@ -11,7 +12,6 @@ NamedTuple, Optional, Protocol, - Sequence, TypedDict, Union, overload, diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index 46677213f81c3..5aa05c7ec20ae 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -3,7 +3,8 @@ import math import os -from typing import Any, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import Any from langchain_core.runnables.graph import Edge as LangEdge diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 1e05a5af610f1..9eb83b6a537da 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -1,7 +1,7 @@ import base64 import re from dataclasses import asdict -from typing import Dict, List, Optional +from typing import Optional from langchain_core.runnables.graph import ( CurveStyle, @@ -15,8 +15,8 @@ def draw_mermaid( - nodes: Dict[str, Node], - edges: List[Edge], + nodes: dict[str, Node], + edges: list[Edge], *, first_node: Optional[str] = None, last_node: Optional[str] = None, @@ -87,7 +87,7 @@ def draw_mermaid( mermaid_graph += f"\t{node_label}\n" # Group edges by their common prefixes - edge_groups: Dict[str, List[Edge]] = {} + edge_groups: dict[str, list[Edge]] = {} for edge in edges: src_parts = edge.source.split(":") tgt_parts = edge.target.split(":") @@ -98,7 +98,7 @@ def draw_mermaid( seen_subgraphs = set() - def add_subgraph(edges: List[Edge], prefix: str) -> None: + def add_subgraph(edges: list[Edge], prefix: str) -> None: nonlocal mermaid_graph self_loop = len(edges) == 1 and edges[0].source == edges[0].target if prefix and not self_loop: diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 3d405bebefa72..6b60c3b3d127c 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -1,13 +1,13 @@ from __future__ import annotations import inspect +from collections.abc import Sequence +from types import GenericAlias from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Optional, - Sequence, Union, ) @@ -31,7 +31,7 @@ from langchain_core.tracers.schemas import Run -MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] +MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] @@ -419,7 +419,11 @@ def get_output_schema( """ root_type = self.OutputType - if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + if ( + inspect.isclass(root_type) + and not isinstance(root_type, GenericAlias) + and issubclass(root_type, BaseModel) + ): return root_type return create_model_v2( diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 3295374f6a2e2..68ca7e1c93d4f 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -5,15 +5,11 @@ import asyncio import inspect import threading +from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Awaitable, Callable, - Dict, - Iterator, - Mapping, Optional, Union, cast, @@ -349,7 +345,7 @@ async def input_aiter() -> AsyncIterator[Other]: _graph_passthrough: RunnablePassthrough = RunnablePassthrough() -class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): +class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): """Runnable that assigns key-value pairs to Dict[str, Any] inputs. The `RunnableAssign` class takes input dictionaries and, through a @@ -564,7 +560,7 @@ def _transform( if filtered: yield filtered # yield map output - yield cast(Dict[str, Any], first_map_chunk_future.result()) + yield cast(dict[str, Any], first_map_chunk_future.result()) for chunk in map_output: yield chunk @@ -650,7 +646,7 @@ async def input_aiter() -> AsyncIterator[dict[str, Any]]: yield chunk -class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): +class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): """Runnable that picks keys from Dict[str, Any] inputs. RunnablePick class represents a Runnable that selectively picks keys from a diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 4c4e1970ca579..af93fac8753d6 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -1,11 +1,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Tuple, - Type, TypeVar, Union, cast, @@ -98,7 +94,7 @@ def foo(input) -> None: retryable_chain = chain.with_retry() """ - retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,) + retry_exception_types: tuple[type[BaseException], ...] = (Exception,) """The exception types to retry on. By default all exceptions are retried. In general you should only retry on exceptions that are likely to be @@ -115,13 +111,13 @@ def foo(input) -> None: """The maximum number of attempts to retry the Runnable.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property - def _kwargs_retrying(self) -> Dict[str, Any]: - kwargs: Dict[str, Any] = dict() + def _kwargs_retrying(self) -> dict[str, Any]: + kwargs: dict[str, Any] = dict() if self.max_attempt_number: kwargs["stop"] = stop_after_attempt(self.max_attempt_number) @@ -152,10 +148,10 @@ def _patch_config( def _patch_config_list( self, - config: List[RunnableConfig], - run_manager: List["T"], + config: list[RunnableConfig], + run_manager: list["T"], retry_state: RetryCallState, - ) -> List[RunnableConfig]: + ) -> list[RunnableConfig]: return [ self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager) ] @@ -208,17 +204,17 @@ async def ainvoke( def _batch( self, - inputs: List[Input], - run_manager: List["CallbackManagerForChainRun"], - config: List[RunnableConfig], + inputs: list[Input], + run_manager: list["CallbackManagerForChainRun"], + config: list[RunnableConfig], **kwargs: Any, - ) -> List[Union[Output, Exception]]: - results_map: Dict[int, Output] = {} + ) -> list[Union[Output, Exception]]: + results_map: dict[int, Output] = {} - def pending(iterable: List[U]) -> List[U]: + def pending(iterable: list[U]) -> list[U]: return [item for idx, item in enumerate(iterable) if idx not in results_map] - not_set: List[Output] = [] + not_set: list[Output] = [] result = not_set try: for attempt in self._sync_retrying(): @@ -250,9 +246,9 @@ def pending(iterable: List[U]) -> List[U]: attempt.retry_state.set_result(result) except RetryError as e: if result is not_set: - result = cast(List[Output], [e] * len(inputs)) + result = cast(list[Output], [e] * len(inputs)) - outputs: List[Union[Output, Exception]] = [] + outputs: list[Union[Output, Exception]] = [] for idx, _ in enumerate(inputs): if idx in results_map: outputs.append(results_map[idx]) @@ -262,29 +258,29 @@ def pending(iterable: List[U]) -> List[U]: def batch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Any, - ) -> List[Output]: + ) -> list[Output]: return self._batch_with_config( self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs ) async def _abatch( self, - inputs: List[Input], - run_manager: List["AsyncCallbackManagerForChainRun"], - config: List[RunnableConfig], + inputs: list[Input], + run_manager: list["AsyncCallbackManagerForChainRun"], + config: list[RunnableConfig], **kwargs: Any, - ) -> List[Union[Output, Exception]]: - results_map: Dict[int, Output] = {} + ) -> list[Union[Output, Exception]]: + results_map: dict[int, Output] = {} - def pending(iterable: List[U]) -> List[U]: + def pending(iterable: list[U]) -> list[U]: return [item for idx, item in enumerate(iterable) if idx not in results_map] - not_set: List[Output] = [] + not_set: list[Output] = [] result = not_set try: async for attempt in self._async_retrying(): @@ -316,9 +312,9 @@ def pending(iterable: List[U]) -> List[U]: attempt.retry_state.set_result(result) except RetryError as e: if result is not_set: - result = cast(List[Output], [e] * len(inputs)) + result = cast(list[Output], [e] * len(inputs)) - outputs: List[Union[Output, Exception]] = [] + outputs: list[Union[Output, Exception]] = [] for idx, _ in enumerate(inputs): if idx in results_map: outputs.append(results_map[idx]) @@ -328,12 +324,12 @@ def pending(iterable: List[U]) -> List[U]: async def abatch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Any, - ) -> List[Output]: + ) -> list[Output]: return await self._abatch_with_config( self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs ) diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index 8b0691e9e0c74..43b4761bbd530 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -1,12 +1,9 @@ from __future__ import annotations +from collections.abc import AsyncIterator, Iterator, Mapping from typing import ( Any, - AsyncIterator, Callable, - Iterator, - List, - Mapping, Optional, Union, cast, @@ -154,7 +151,7 @@ def invoke( configs = get_config_list(config, len(inputs)) with get_executor_for_config(configs[0]) as executor: return cast( - List[Output], + list[Output], list(executor.map(invoke, runnables, actual_inputs, configs)), ) diff --git a/libs/core/langchain_core/runnables/schema.py b/libs/core/langchain_core/runnables/schema.py index a13d8f6e4db41..228356c8de75e 100644 --- a/libs/core/langchain_core/runnables/schema.py +++ b/libs/core/langchain_core/runnables/schema.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, Literal, Sequence, Union +from collections.abc import Sequence +from typing import Any, Literal, Union from typing_extensions import NotRequired, TypedDict diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 4b1763500e889..09a6877fc7580 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -6,23 +6,24 @@ import asyncio import inspect import textwrap -from functools import lru_cache -from inspect import signature -from itertools import groupby -from typing import ( - Any, +from collections.abc import ( AsyncIterable, AsyncIterator, Awaitable, - Callable, Coroutine, - Dict, Iterable, Mapping, + Sequence, +) +from functools import lru_cache +from inspect import signature +from itertools import groupby +from typing import ( + Any, + Callable, NamedTuple, Optional, Protocol, - Sequence, TypeVar, Union, ) @@ -430,7 +431,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str: return "\n".join([lines[0]] + [spaces + line for line in lines[1:]]) -class AddableDict(Dict[str, Any]): +class AddableDict(dict[str, Any]): """ Dictionary that can be added to another dictionary. """ diff --git a/libs/core/langchain_core/stores.py b/libs/core/langchain_core/stores.py index 50b077b864251..127d3c2e04a15 100644 --- a/libs/core/langchain_core/stores.py +++ b/libs/core/langchain_core/stores.py @@ -7,16 +7,11 @@ """ from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterator, Sequence from typing import ( Any, - AsyncIterator, - Dict, Generic, - Iterator, - List, Optional, - Sequence, - Tuple, TypeVar, Union, ) @@ -84,7 +79,7 @@ def yield_keys(self, prefix=None): """ @abstractmethod - def mget(self, keys: Sequence[K]) -> List[Optional[V]]: + def mget(self, keys: Sequence[K]) -> list[Optional[V]]: """Get the values associated with the given keys. Args: @@ -95,7 +90,7 @@ def mget(self, keys: Sequence[K]) -> List[Optional[V]]: If a key is not found, the corresponding value will be None. """ - async def amget(self, keys: Sequence[K]) -> List[Optional[V]]: + async def amget(self, keys: Sequence[K]) -> list[Optional[V]]: """Async get the values associated with the given keys. Args: @@ -108,14 +103,14 @@ async def amget(self, keys: Sequence[K]) -> List[Optional[V]]: return await run_in_executor(None, self.mget, keys) @abstractmethod - def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: + def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None: """Set the values for the given keys. Args: key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs. """ - async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: + async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None: """Async set the values for the given keys. Args: @@ -184,9 +179,9 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]): def __init__(self) -> None: """Initialize an empty store.""" - self.store: Dict[str, V] = {} + self.store: dict[str, V] = {} - def mget(self, keys: Sequence[str]) -> List[Optional[V]]: + def mget(self, keys: Sequence[str]) -> list[Optional[V]]: """Get the values associated with the given keys. Args: @@ -198,7 +193,7 @@ def mget(self, keys: Sequence[str]) -> List[Optional[V]]: """ return [self.store.get(key) for key in keys] - async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: + async def amget(self, keys: Sequence[str]) -> list[Optional[V]]: """Async get the values associated with the given keys. Args: @@ -210,7 +205,7 @@ async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: """ return self.mget(keys) - def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: + def mset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None: """Set the values for the given keys. Args: @@ -222,7 +217,7 @@ def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: for key, value in key_value_pairs: self.store[key] = value - async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: + async def amset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None: """Async set the values for the given keys. Args: diff --git a/libs/core/langchain_core/structured_query.py b/libs/core/langchain_core/structured_query.py index 0e58c88dcc4fd..4214d367e40e7 100644 --- a/libs/core/langchain_core/structured_query.py +++ b/libs/core/langchain_core/structured_query.py @@ -3,8 +3,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from enum import Enum -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Union from pydantic import BaseModel diff --git a/libs/core/langchain_core/sys_info.py b/libs/core/langchain_core/sys_info.py index d612132ba7247..eacf6730fde17 100644 --- a/libs/core/langchain_core/sys_info.py +++ b/libs/core/langchain_core/sys_info.py @@ -2,10 +2,10 @@ for debugging purposes. """ -from typing import List, Sequence +from collections.abc import Sequence -def _get_sub_deps(packages: Sequence[str]) -> List[str]: +def _get_sub_deps(packages: Sequence[str]) -> list[str]: """Get any specified sub-dependencies.""" from importlib import metadata diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 38b54cb7e9adb..a3fe0aba53459 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -7,15 +7,15 @@ import uuid import warnings from abc import ABC, abstractmethod +from collections.abc import Sequence from contextvars import copy_context from inspect import signature from typing import ( + Annotated, Any, Callable, - Dict, Literal, Optional, - Sequence, TypeVar, Union, cast, @@ -36,7 +36,6 @@ ) from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import validate_arguments as validate_arguments_v1 -from typing_extensions import Annotated from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -324,7 +323,7 @@ class ToolException(Exception): pass -class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]): +class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]): """Interface LangChain tools must implement.""" def __init_subclass__(cls, **kwargs: Any) -> None: diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index 2fdab73bb9b02..b419a595559a5 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints +from typing import Any, Callable, Literal, Optional, Union, get_type_hints from pydantic import BaseModel, Field, create_model @@ -13,7 +13,7 @@ def tool( *args: Union[str, Callable, Runnable], return_direct: bool = False, - args_schema: Optional[Type] = None, + args_schema: Optional[type] = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, @@ -160,7 +160,7 @@ def invoke_wrapper( coroutine = ainvoke_wrapper func = invoke_wrapper - schema: Optional[Type[BaseModel]] = runnable.input_schema + schema: Optional[type[BaseModel]] = runnable.input_schema description = repr(runnable) elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func @@ -234,8 +234,8 @@ def _get_description_from_runnable(runnable: Runnable) -> str: def _get_schema_from_runnable_and_arg_types( runnable: Runnable, name: str, - arg_types: Optional[Dict[str, Type]] = None, -) -> Type[BaseModel]: + arg_types: Optional[dict[str, type]] = None, +) -> type[BaseModel]: """Infer args_schema for tool.""" if arg_types is None: try: @@ -252,11 +252,11 @@ def _get_schema_from_runnable_and_arg_types( def convert_runnable_to_tool( runnable: Runnable, - args_schema: Optional[Type[BaseModel]] = None, + args_schema: Optional[type[BaseModel]] = None, *, name: Optional[str] = None, description: Optional[str] = None, - arg_types: Optional[Dict[str, Type]] = None, + arg_types: Optional[dict[str, type]] = None, ) -> BaseTool: """Convert a Runnable into a BaseTool. diff --git a/libs/core/langchain_core/tools/render.py b/libs/core/langchain_core/tools/render.py index e8c05a064023b..d59aef42303c7 100644 --- a/libs/core/langchain_core/tools/render.py +++ b/libs/core/langchain_core/tools/render.py @@ -1,11 +1,11 @@ from __future__ import annotations from inspect import signature -from typing import Callable, List +from typing import Callable from langchain_core.tools.base import BaseTool -ToolsRenderer = Callable[[List[BaseTool]], str] +ToolsRenderer = Callable[[list[BaseTool]], str] def render_text_description(tools: list[BaseTool]) -> str: diff --git a/libs/core/langchain_core/tools/simple.py b/libs/core/langchain_core/tools/simple.py index cace7e82ec74d..6f0bdb516fc6c 100644 --- a/libs/core/langchain_core/tools/simple.py +++ b/libs/core/langchain_core/tools/simple.py @@ -1,9 +1,9 @@ from __future__ import annotations +from collections.abc import Awaitable from inspect import signature from typing import ( Any, - Awaitable, Callable, Optional, Union, diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index fb8c69e92bcc2..bf645265b457c 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -1,10 +1,11 @@ from __future__ import annotations import textwrap +from collections.abc import Awaitable from inspect import signature from typing import ( + Annotated, Any, - Awaitable, Callable, Literal, Optional, @@ -12,7 +13,6 @@ ) from pydantic import BaseModel, Field, SkipValidation -from typing_extensions import Annotated from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, diff --git a/libs/core/langchain_core/tracers/_streaming.py b/libs/core/langchain_core/tracers/_streaming.py index 7f9f65395c7d6..ca50213d88a75 100644 --- a/libs/core/langchain_core/tracers/_streaming.py +++ b/libs/core/langchain_core/tracers/_streaming.py @@ -1,7 +1,8 @@ """Internal tracers used for stream_log and astream events implementations.""" import abc -from typing import AsyncIterator, Iterator, TypeVar +from collections.abc import AsyncIterator, Iterator +from typing import TypeVar from uuid import UUID T = TypeVar("T") diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index 11bb8b3f1b827..a31e6d9edee47 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -5,11 +5,11 @@ import asyncio import logging from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import ( TYPE_CHECKING, Any, Optional, - Sequence, Union, ) from uuid import UUID diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index c252dad050c98..f2f05849d6265 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -1,11 +1,11 @@ from __future__ import annotations +from collections.abc import Generator from contextlib import contextmanager from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, - Generator, Optional, Union, cast, diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index 53bc09a3e54ad..f4d7bbdccbf55 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -6,14 +6,13 @@ import sys import traceback from abc import ABC, abstractmethod +from collections.abc import Coroutine, Sequence from datetime import datetime, timezone from typing import ( TYPE_CHECKING, Any, - Coroutine, Literal, Optional, - Sequence, Union, cast, ) diff --git a/libs/core/langchain_core/tracers/evaluation.py b/libs/core/langchain_core/tracers/evaluation.py index 687e5994ef3f4..c41fa2d7f5816 100644 --- a/libs/core/langchain_core/tracers/evaluation.py +++ b/libs/core/langchain_core/tracers/evaluation.py @@ -5,8 +5,9 @@ import logging import threading import weakref +from collections.abc import Sequence from concurrent.futures import Future, ThreadPoolExecutor, wait -from typing import Any, List, Optional, Sequence, Union, cast +from typing import Any, Optional, Union, cast from uuid import UUID import langsmith @@ -156,7 +157,7 @@ def _select_eval_results( if isinstance(results, EvaluationResult): results_ = [results] elif isinstance(results, dict) and "results" in results: - results_ = cast(List[EvaluationResult], results["results"]) + results_ = cast(list[EvaluationResult], results["results"]) else: raise TypeError( f"Invalid evaluation result type {type(results)}." diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 23739dcb7313b..3b2f512560459 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -4,14 +4,11 @@ import asyncio import logging +from collections.abc import AsyncIterator, Iterator, Sequence from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Iterator, - List, Optional, - Sequence, TypeVar, Union, cast, @@ -459,7 +456,7 @@ async def on_llm_end( output: Union[dict, BaseMessage] = {} if run_info["run_type"] == "chat_model": - generations = cast(List[List[ChatGenerationChunk]], response.generations) + generations = cast(list[list[ChatGenerationChunk]], response.generations) for gen in generations: if output != {}: break @@ -469,7 +466,7 @@ async def on_llm_end( event = "on_chat_model_end" elif run_info["run_type"] == "llm": - generations = cast(List[List[GenerationChunk]], response.generations) + generations = cast(list[list[GenerationChunk]], response.generations) output = { "generations": [ [ diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index a48f4233594bb..e6422b0682da0 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -4,13 +4,11 @@ import copy import threading from collections import defaultdict +from collections.abc import AsyncIterator, Iterator, Sequence from typing import ( Any, - AsyncIterator, - Iterator, Literal, Optional, - Sequence, TypeVar, Union, overload, diff --git a/libs/core/langchain_core/tracers/memory_stream.py b/libs/core/langchain_core/tracers/memory_stream.py index 8aa324737f69b..4a3d09a42b852 100644 --- a/libs/core/langchain_core/tracers/memory_stream.py +++ b/libs/core/langchain_core/tracers/memory_stream.py @@ -11,7 +11,8 @@ import asyncio from asyncio import AbstractEventLoop, Queue -from typing import AsyncIterator, Generic, TypeVar +from collections.abc import AsyncIterator +from typing import Generic, TypeVar T = TypeVar("T") diff --git a/libs/core/langchain_core/tracers/root_listeners.py b/libs/core/langchain_core/tracers/root_listeners.py index bfd73a17131ff..1530598fb7534 100644 --- a/libs/core/langchain_core/tracers/root_listeners.py +++ b/libs/core/langchain_core/tracers/root_listeners.py @@ -1,4 +1,5 @@ -from typing import Awaitable, Callable, Optional, Union +from collections.abc import Awaitable +from typing import Callable, Optional, Union from uuid import UUID from langchain_core.runnables.config import ( diff --git a/libs/core/langchain_core/tracers/run_collector.py b/libs/core/langchain_core/tracers/run_collector.py index 1032c6adf15f8..9001eac38d189 100644 --- a/libs/core/langchain_core/tracers/run_collector.py +++ b/libs/core/langchain_core/tracers/run_collector.py @@ -1,6 +1,6 @@ """A tracer that collects all nested runs in a list.""" -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from uuid import UUID from langchain_core.tracers.base import BaseTracer @@ -38,7 +38,7 @@ def __init__( self.example_id = ( UUID(example_id) if isinstance(example_id, str) else example_id ) - self.traced_runs: List[Run] = [] + self.traced_runs: list[Run] = [] def _persist_run(self, run: Run) -> None: """ diff --git a/libs/core/langchain_core/tracers/stdout.py b/libs/core/langchain_core/tracers/stdout.py index b9d2043082d16..22ace8bb70f98 100644 --- a/libs/core/langchain_core/tracers/stdout.py +++ b/libs/core/langchain_core/tracers/stdout.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable, List +from typing import Any, Callable from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.schemas import Run @@ -54,7 +54,7 @@ def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None: def _persist_run(self, run: Run) -> None: pass - def get_parents(self, run: Run) -> List[Run]: + def get_parents(self, run: Run) -> list[Run]: """Get the parents of a run. Args: diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index b2bc92699466b..9c42fc70e386a 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -5,23 +5,20 @@ """ from collections import deque -from contextlib import AbstractAsyncContextManager -from types import TracebackType -from typing import ( - Any, - AsyncContextManager, +from collections.abc import ( AsyncGenerator, AsyncIterable, AsyncIterator, Awaitable, + Iterator, +) +from contextlib import AbstractAsyncContextManager +from types import TracebackType +from typing import ( + Any, Callable, - Deque, Generic, - Iterator, - List, Optional, - Tuple, - Type, TypeVar, Union, cast, @@ -95,10 +92,10 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: async def tee_peer( iterator: AsyncIterator[T], # the buffer specific to this peer - buffer: Deque[T], + buffer: deque[T], # the buffers of all peers, including our own - peers: List[Deque[T]], - lock: AsyncContextManager[Any], + peers: list[deque[T]], + lock: AbstractAsyncContextManager[Any], ) -> AsyncGenerator[T, None]: """An individual iterator of a :py:func:`~.tee`. @@ -191,10 +188,10 @@ def __init__( iterable: AsyncIterator[T], n: int = 2, *, - lock: Optional[AsyncContextManager[Any]] = None, + lock: Optional[AbstractAsyncContextManager[Any]] = None, ): self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist - self._buffers: List[Deque[T]] = [deque() for _ in range(n)] + self._buffers: list[deque[T]] = [deque() for _ in range(n)] self._children = tuple( tee_peer( iterator=self._iterator, @@ -212,11 +209,11 @@ def __len__(self) -> int: def __getitem__(self, item: int) -> AsyncIterator[T]: ... @overload - def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: ... + def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ... def __getitem__( self, item: Union[int, slice] - ) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]: + ) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]: return self._children[item] def __iter__(self) -> Iterator[AsyncIterator[T]]: @@ -267,7 +264,7 @@ async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any] async def __aexit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: @@ -277,7 +274,7 @@ async def __aexit__( async def abatch_iterate( size: int, iterable: AsyncIterable[T] -) -> AsyncIterator[List[T]]: +) -> AsyncIterator[list[T]]: """Utility batching function for async iterables. Args: @@ -287,7 +284,7 @@ async def abatch_iterate( Returns: An async iterator over the batches. """ - batch: List[T] = [] + batch: list[T] = [] async for element in iterable: if len(batch) < size: batch.append(element) diff --git a/libs/core/langchain_core/utils/formatting.py b/libs/core/langchain_core/utils/formatting.py index b5b880f66b1fc..1e71cdf28ce7a 100644 --- a/libs/core/langchain_core/utils/formatting.py +++ b/libs/core/langchain_core/utils/formatting.py @@ -1,7 +1,8 @@ """Utilities for formatting strings.""" +from collections.abc import Mapping, Sequence from string import Formatter -from typing import Any, List, Mapping, Sequence +from typing import Any class StrictFormatter(Formatter): @@ -31,7 +32,7 @@ def vformat( return super().vformat(format_string, args, kwargs) def validate_input_variables( - self, format_string: str, input_variables: List[str] + self, format_string: str, input_variables: list[str] ) -> None: """Check that all input variables are used in the format string. diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 234281966cec6..6694b1f5bd11e 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -10,21 +10,17 @@ import uuid from typing import ( TYPE_CHECKING, + Annotated, Any, Callable, - Dict, - List, Literal, Optional, - Set, - Tuple, - Type, Union, cast, ) from pydantic import BaseModel -from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict +from typing_extensions import TypedDict, get_args, get_origin, is_typeddict from langchain_core._api import deprecated from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage @@ -201,7 +197,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript from pydantic.v1 import BaseModel model = cast( - Type[BaseModel], + type[BaseModel], _convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited), ) return convert_pydantic_to_openai_function(model) # type: ignore @@ -383,15 +379,15 @@ def convert_to_openai_function( "parameters": function, } elif isinstance(function, type) and is_basemodel_subclass(function): - oai_function = cast(Dict, convert_pydantic_to_openai_function(function)) + oai_function = cast(dict, convert_pydantic_to_openai_function(function)) elif is_typeddict(function): oai_function = cast( - Dict, _convert_typed_dict_to_openai_function(cast(Type, function)) + dict, _convert_typed_dict_to_openai_function(cast(type, function)) ) elif isinstance(function, BaseTool): - oai_function = cast(Dict, format_tool_to_openai_function(function)) + oai_function = cast(dict, format_tool_to_openai_function(function)) elif callable(function): - oai_function = cast(Dict, convert_python_function_to_openai_function(function)) + oai_function = cast(dict, convert_python_function_to_openai_function(function)) else: raise ValueError( f"Unsupported function\n\n{function}\n\nFunctions must be passed in" @@ -598,17 +594,17 @@ def _py_38_safe_origin(origin: type) -> type: ) origin_map: dict[type, Any] = { - dict: Dict, - list: List, - tuple: Tuple, - set: Set, + dict: dict, + list: list, + tuple: tuple, + set: set, collections.abc.Iterable: typing.Iterable, collections.abc.Mapping: typing.Mapping, collections.abc.Sequence: typing.Sequence, collections.abc.MutableMapping: typing.MutableMapping, **origin_union_type_map, } - return cast(Type, origin_map.get(origin, origin)) + return cast(type, origin_map.get(origin, origin)) def _recursive_set_additional_properties_false( diff --git a/libs/core/langchain_core/utils/html.py b/libs/core/langchain_core/utils/html.py index 9750515bd1f58..805cae3cc8ed8 100644 --- a/libs/core/langchain_core/utils/html.py +++ b/libs/core/langchain_core/utils/html.py @@ -1,6 +1,7 @@ import logging import re -from typing import List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional, Union from urllib.parse import urljoin, urlparse logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def find_all_links( raw_html: str, *, pattern: Union[str, re.Pattern, None] = None -) -> List[str]: +) -> list[str]: """Extract all links from a raw HTML string. Args: @@ -56,7 +57,7 @@ def extract_sub_links( prevent_outside: bool = True, exclude_prefixes: Sequence[str] = (), continue_on_failure: bool = False, -) -> List[str]: +) -> list[str]: """Extract all links from a raw HTML string and convert into absolute paths. Args: diff --git a/libs/core/langchain_core/utils/input.py b/libs/core/langchain_core/utils/input.py index 7edfcd7b6cb17..05e9bbcfcedc3 100644 --- a/libs/core/langchain_core/utils/input.py +++ b/libs/core/langchain_core/utils/input.py @@ -1,6 +1,6 @@ """Handle chained inputs.""" -from typing import Dict, List, Optional, TextIO +from typing import Optional, TextIO _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -12,8 +12,8 @@ def get_color_mapping( - items: List[str], excluded_colors: Optional[List] = None -) -> Dict[str, str]: + items: list[str], excluded_colors: Optional[list] = None +) -> dict[str, str]: """Get mapping for items to a support color. Args: diff --git a/libs/core/langchain_core/utils/iter.py b/libs/core/langchain_core/utils/iter.py index 9645b75597b18..f76ada9ddd02c 100644 --- a/libs/core/langchain_core/utils/iter.py +++ b/libs/core/langchain_core/utils/iter.py @@ -1,16 +1,11 @@ from collections import deque +from collections.abc import Generator, Iterable, Iterator +from contextlib import AbstractContextManager from itertools import islice from typing import ( Any, - ContextManager, - Deque, - Generator, Generic, - Iterable, - Iterator, - List, Optional, - Tuple, TypeVar, Union, overload, @@ -34,10 +29,10 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]: def tee_peer( iterator: Iterator[T], # the buffer specific to this peer - buffer: Deque[T], + buffer: deque[T], # the buffers of all peers, including our own - peers: List[Deque[T]], - lock: ContextManager[Any], + peers: list[deque[T]], + lock: AbstractContextManager[Any], ) -> Generator[T, None, None]: """An individual iterator of a :py:func:`~.tee`. @@ -130,7 +125,7 @@ def __init__( iterable: Iterator[T], n: int = 2, *, - lock: Optional[ContextManager[Any]] = None, + lock: Optional[AbstractContextManager[Any]] = None, ): """Create a new ``tee``. @@ -141,7 +136,7 @@ def __init__( Defaults to None. """ self._iterator = iter(iterable) - self._buffers: List[Deque[T]] = [deque() for _ in range(n)] + self._buffers: list[deque[T]] = [deque() for _ in range(n)] self._children = tuple( tee_peer( iterator=self._iterator, @@ -159,11 +154,11 @@ def __len__(self) -> int: def __getitem__(self, item: int) -> Iterator[T]: ... @overload - def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]: ... + def __getitem__(self, item: slice) -> tuple[Iterator[T], ...]: ... def __getitem__( self, item: Union[int, slice] - ) -> Union[Iterator[T], Tuple[Iterator[T], ...]]: + ) -> Union[Iterator[T], tuple[Iterator[T], ...]]: return self._children[item] def __iter__(self) -> Iterator[Iterator[T]]: @@ -185,7 +180,7 @@ def close(self) -> None: safetee = Tee -def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]: +def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[list[T]]: """Utility batching function. Args: diff --git a/libs/core/langchain_core/utils/json_schema.py b/libs/core/langchain_core/utils/json_schema.py index 172a741be66c5..be9642808c1a7 100644 --- a/libs/core/langchain_core/utils/json_schema.py +++ b/libs/core/langchain_core/utils/json_schema.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Sequence from copy import deepcopy -from typing import Any, Optional, Sequence +from typing import Any, Optional def _retrieve_ref(path: str, schema: dict) -> dict: diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 3eb742529ef8a..c517233a3659e 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -6,16 +6,12 @@ from __future__ import annotations import logging +from collections.abc import Iterator, Mapping, Sequence from types import MappingProxyType from typing import ( Any, - Dict, - Iterator, - List, Literal, - Mapping, Optional, - Sequence, Union, cast, ) @@ -25,7 +21,7 @@ logger = logging.getLogger(__name__) -Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]] +Scopes: TypeAlias = list[Union[Literal[False, 0], Mapping[str, Any]]] # Globals @@ -381,7 +377,7 @@ def _get_key( # Move into the scope try: # Try subscripting (Normal dictionaries) - scope = cast(Dict[str, Any], scope)[child] + scope = cast(dict[str, Any], scope)[child] except (TypeError, AttributeError): try: scope = getattr(scope, child) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index ac44ac1ebccb9..b7d03651667d3 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -7,12 +7,11 @@ import warnings from contextlib import nullcontext from functools import lru_cache, wraps +from types import GenericAlias from typing import ( Any, Callable, - Dict, Optional, - Type, TypeVar, Union, cast, @@ -56,13 +55,13 @@ def get_pydantic_major_version() -> int: from pydantic.fields import FieldInfo as FieldInfoV1 PydanticBaseModel = pydantic.BaseModel - TypeBaseModel = Type[BaseModel] + TypeBaseModel = type[BaseModel] elif PYDANTIC_MAJOR_VERSION == 2: from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment] # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore - TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore + TypeBaseModel = Union[type[BaseModel], type[pydantic.BaseModel]] # type: ignore else: raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}") @@ -99,7 +98,7 @@ def is_basemodel_subclass(cls: type) -> bool: * pydantic.v1.BaseModel in Pydantic 2.x """ # Before we can use issubclass on the cls we need to check if it is a class - if not inspect.isclass(cls): + if not inspect.isclass(cls) or isinstance(cls, GenericAlias): return False if PYDANTIC_MAJOR_VERSION == 1: @@ -436,12 +435,19 @@ def model_json_schema( if default_ is not NO_DEFAULT: base_class_attributes["root"] = default_ with warnings.catch_warnings(): - if isinstance(type_, type) and issubclass(type_, BaseModelV1): - warnings.filterwarnings( - action="ignore", category=PydanticDeprecationWarning - ) + try: + if ( + isinstance(type_, type) + and not isinstance(type_, GenericAlias) + and issubclass(type_, BaseModelV1) + ): + warnings.filterwarnings( + action="ignore", category=PydanticDeprecationWarning + ) + except TypeError: + pass custom_root_type = type(name, (RootModel,), base_class_attributes) - return cast(Type[BaseModel], custom_root_type) + return cast(type[BaseModel], custom_root_type) @lru_cache(maxsize=256) @@ -565,7 +571,7 @@ def create_model_v2( Returns: Type[BaseModel]: The created model. """ - field_definitions = cast(Dict[str, Any], field_definitions or {}) # type: ignore[no-redef] + field_definitions = cast(dict[str, Any], field_definitions or {}) # type: ignore[no-redef] if root: if field_definitions: diff --git a/libs/core/langchain_core/utils/strings.py b/libs/core/langchain_core/utils/strings.py index f54ddb4fbaaa1..e7a79761f9a01 100644 --- a/libs/core/langchain_core/utils/strings.py +++ b/libs/core/langchain_core/utils/strings.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any def stringify_value(val: Any) -> str: @@ -35,7 +35,7 @@ def stringify_dict(data: dict) -> str: return text -def comma_list(items: List[Any]) -> str: +def comma_list(items: list[Any]) -> str: """Convert a list to a comma-separated string. Args: diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index d38657cc91cd0..a90eba11b9d5e 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -6,8 +6,9 @@ import importlib import os import warnings +from collections.abc import Sequence from importlib.metadata import version -from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload +from typing import Any, Callable, Optional, Union, overload from packaging.version import parse from pydantic import SecretStr @@ -18,7 +19,7 @@ ) -def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: +def xor_args(*arg_groups: tuple[str, ...]) -> Callable: """Validate specified keyword args are mutually exclusive." Args: @@ -186,7 +187,7 @@ def check_package_version( ) -def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: +def get_pydantic_field_names(pydantic_cls: Any) -> set[str]: """Get field names, including aliases, for a pydantic class. Args: @@ -210,10 +211,10 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: def build_extra_kwargs( - extra_kwargs: Dict[str, Any], - values: Dict[str, Any], - all_required_field_names: Set[str], -) -> Dict[str, Any]: + extra_kwargs: dict[str, Any], + values: dict[str, Any], + all_required_field_names: set[str], +) -> dict[str, Any]: """Build extra kwargs from values and extra_kwargs. Args: diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index 5fca50cfb767f..aacf897ffb915 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -25,16 +25,14 @@ import math import warnings from abc import ABC, abstractmethod +from collections.abc import Collection, Iterable, Sequence from itertools import cycle from typing import ( TYPE_CHECKING, Any, Callable, ClassVar, - Collection, - Iterable, Optional, - Sequence, TypeVar, ) diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index d359f619685b6..0b0a7c08653f0 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -2,14 +2,13 @@ import json import uuid +from collections.abc import Iterator, Sequence from pathlib import Path from typing import ( TYPE_CHECKING, Any, Callable, - Iterator, Optional, - Sequence, ) from langchain_core._api import deprecated diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 89fe0149e8268..c744d2ef621a2 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -7,12 +7,12 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: import numpy as np - Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] + Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] logger = logging.getLogger(__name__) diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 96abaa5db89cc..3fd81376ac707 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -2,6 +2,9 @@ requires = [ "poetry-core>=1.0.0",] build-backend = "poetry.core.masonry.api" +[project] +requires-python = ">=3.9,<4.0" + [tool.poetry] name = "langchain-core" version = "0.3.1" diff --git a/libs/core/tests/unit_tests/_api/test_beta_decorator.py b/libs/core/tests/unit_tests/_api/test_beta_decorator.py index 1c9ab67700699..07480b683195d 100644 --- a/libs/core/tests/unit_tests/_api/test_beta_decorator.py +++ b/libs/core/tests/unit_tests/_api/test_beta_decorator.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import Any, Dict +from typing import Any import pytest from pydantic import BaseModel @@ -41,7 +41,7 @@ ), ], ) -def test_warn_beta(kwargs: Dict[str, Any], expected_message: str) -> None: +def test_warn_beta(kwargs: dict[str, Any], expected_message: str) -> None: """Test warn beta.""" with warnings.catch_warnings(record=True) as warning_list: warnings.simplefilter("always") diff --git a/libs/core/tests/unit_tests/_api/test_deprecation.py b/libs/core/tests/unit_tests/_api/test_deprecation.py index 16407e9d89bbc..726e8f9f9d300 100644 --- a/libs/core/tests/unit_tests/_api/test_deprecation.py +++ b/libs/core/tests/unit_tests/_api/test_deprecation.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import Any, Dict +from typing import Any import pytest from pydantic import BaseModel @@ -55,7 +55,7 @@ ), ], ) -def test_warn_deprecated(kwargs: Dict[str, Any], expected_message: str) -> None: +def test_warn_deprecated(kwargs: dict[str, Any], expected_message: str) -> None: """Test warn deprecated.""" with warnings.catch_warnings(record=True) as warning_list: warnings.simplefilter("always") diff --git a/libs/core/tests/unit_tests/caches/test_in_memory_cache.py b/libs/core/tests/unit_tests/caches/test_in_memory_cache.py index 67143c0ff9227..2fba0705a57fc 100644 --- a/libs/core/tests/unit_tests/caches/test_in_memory_cache.py +++ b/libs/core/tests/unit_tests/caches/test_in_memory_cache.py @@ -1,5 +1,3 @@ -from typing import Tuple - import pytest from langchain_core.caches import RETURN_VAL_TYPE, InMemoryCache @@ -12,7 +10,7 @@ def cache() -> InMemoryCache: return InMemoryCache() -def cache_item(item_id: int) -> Tuple[str, str, RETURN_VAL_TYPE]: +def cache_item(item_id: int) -> tuple[str, str, RETURN_VAL_TYPE]: """Generate a valid cache item.""" prompt = f"prompt{item_id}" llm_string = f"llm_string{item_id}" diff --git a/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py b/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py index 62b76a7a2d65f..cd13d1d46b26b 100644 --- a/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py +++ b/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py @@ -1,6 +1,6 @@ import sys import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Optional from uuid import UUID import pytest @@ -16,7 +16,7 @@ class AsyncCustomCallbackHandler(AsyncCallbackHandler): def __init__(self) -> None: - self.events: List[Any] = [] + self.events: list[Any] = [] async def on_custom_event( self, @@ -24,8 +24,8 @@ async def on_custom_event( data: Any, *, run_id: UUID, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: assert kwargs == {} @@ -120,7 +120,7 @@ def test_sync_callback_manager() -> None: class CustomCallbackManager(BaseCallbackHandler): def __init__(self) -> None: - self.events: List[Any] = [] + self.events: list[Any] = [] def on_custom_event( self, @@ -128,8 +128,8 @@ def on_custom_event( data: Any, *, run_id: UUID, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: assert kwargs == {} diff --git a/libs/core/tests/unit_tests/chat_history/test_chat_history.py b/libs/core/tests/unit_tests/chat_history/test_chat_history.py index e7d2b724f7efa..729f15381c175 100644 --- a/libs/core/tests/unit_tests/chat_history/test_chat_history.py +++ b/libs/core/tests/unit_tests/chat_history/test_chat_history.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, HumanMessage @@ -8,7 +8,7 @@ def test_add_message_implementation_only() -> None: """Test implementation of add_message only.""" class SampleChatHistory(BaseChatMessageHistory): - def __init__(self, *, store: List[BaseMessage]) -> None: + def __init__(self, *, store: list[BaseMessage]) -> None: self.store = store def add_message(self, message: BaseMessage) -> None: @@ -19,7 +19,7 @@ def clear(self) -> None: """Clear the store.""" raise NotImplementedError() - store: List[BaseMessage] = [] + store: list[BaseMessage] = [] chat_history = SampleChatHistory(store=store) chat_history.add_message(HumanMessage(content="Hello")) assert len(store) == 1 @@ -38,10 +38,10 @@ def clear(self) -> None: def test_bulk_message_implementation_only() -> None: """Test that SampleChatHistory works as expected.""" - store: List[BaseMessage] = [] + store: list[BaseMessage] = [] class BulkAddHistory(BaseChatMessageHistory): - def __init__(self, *, store: List[BaseMessage]) -> None: + def __init__(self, *, store: list[BaseMessage]) -> None: self.store = store def add_messages(self, message: Sequence[BaseMessage]) -> None: diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index d2f1207557091..53104b12b2afa 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,7 +1,7 @@ """Configuration for unit tests.""" +from collections.abc import Sequence from importlib import util -from typing import Dict, Sequence from uuid import UUID import pytest @@ -41,7 +41,7 @@ def test_something(): """ # Mapping from the name of a package to whether it is installed or not. # Used to avoid repeated calls to `util.find_spec` - required_pkgs_info: Dict[str, bool] = {} + required_pkgs_info: dict[str, bool] = {} only_extended = config.getoption("--only-extended") or False only_core = config.getoption("--only-core") or False diff --git a/libs/core/tests/unit_tests/document_loaders/test_base.py b/libs/core/tests/unit_tests/document_loaders/test_base.py index 484114766a0c4..4297b1654824f 100644 --- a/libs/core/tests/unit_tests/document_loaders/test_base.py +++ b/libs/core/tests/unit_tests/document_loaders/test_base.py @@ -1,6 +1,6 @@ """Test Base Schema of documents.""" -from typing import Iterator, List +from collections.abc import Iterator import pytest @@ -33,7 +33,7 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]: def test_default_lazy_load() -> None: class FakeLoader(BaseLoader): - def load(self) -> List[Document]: + def load(self) -> list[Document]: return [ Document(page_content="foo"), Document(page_content="bar"), diff --git a/libs/core/tests/unit_tests/example_selectors/test_base.py b/libs/core/tests/unit_tests/example_selectors/test_base.py index 5ab9ed7c2c022..54793627987e0 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_base.py +++ b/libs/core/tests/unit_tests/example_selectors/test_base.py @@ -1,16 +1,16 @@ -from typing import Dict, List, Optional +from typing import Optional from langchain_core.example_selectors import BaseExampleSelector class DummyExampleSelector(BaseExampleSelector): def __init__(self) -> None: - self.example: Optional[Dict[str, str]] = None + self.example: Optional[dict[str, str]] = None - def add_example(self, example: Dict[str, str]) -> None: + def add_example(self, example: dict[str, str]) -> None: self.example = example - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict]: return [input_variables] diff --git a/libs/core/tests/unit_tests/example_selectors/test_similarity.py b/libs/core/tests/unit_tests/example_selectors/test_similarity.py index 2cd50ca8dd2e1..5a5f40d197a25 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_similarity.py +++ b/libs/core/tests/unit_tests/example_selectors/test_similarity.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, List, Optional, cast +from collections.abc import Iterable +from typing import Any, Optional, cast from langchain_core.documents import Document from langchain_core.embeddings import Embeddings, FakeEmbeddings @@ -11,8 +12,8 @@ class DummyVectorStore(VectorStore): def __init__(self, init_arg: Optional[str] = None): - self.texts: List[str] = [] - self.metadatas: List[dict] = [] + self.texts: list[str] = [] + self.metadatas: list[dict] = [] self._embeddings: Optional[Embeddings] = None self.init_arg = init_arg @@ -23,9 +24,9 @@ def embeddings(self) -> Optional[Embeddings]: def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: self.texts.extend(texts) if metadatas: self.metadatas.extend(metadatas) @@ -33,7 +34,7 @@ def add_texts( def similarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: return [ Document( page_content=query, metadata={"query": query, "k": k, "other": "other"} @@ -47,7 +48,7 @@ def max_marginal_relevance_search( fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: return [ Document( page_content=query, @@ -58,9 +59,9 @@ def max_marginal_relevance_search( @classmethod def from_texts( cls, - texts: List[str], + texts: list[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, ) -> "DummyVectorStore": store = DummyVectorStore(**kwargs) diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py index 21efc2c176cd8..51ef0fe6c9d81 100644 --- a/libs/core/tests/unit_tests/fake/callbacks.py +++ b/libs/core/tests/unit_tests/fake/callbacks.py @@ -1,7 +1,7 @@ """A fake callback handler for testing purposes.""" from itertools import chain -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from uuid import UUID from pydantic import BaseModel @@ -16,7 +16,7 @@ class BaseFakeCallbackHandler(BaseModel): starts: int = 0 ends: int = 0 errors: int = 0 - errors_args: List[Any] = [] + errors_args: list[Any] = [] text: int = 0 ignore_llm_: bool = False ignore_chain_: bool = False @@ -265,8 +265,8 @@ def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 873ca5deec3be..4e1e836f0b949 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -1,7 +1,7 @@ """Tests for verifying that testing utility code works as expected.""" from itertools import cycle -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler @@ -149,18 +149,18 @@ async def test_callback_handlers() -> None: """Verify that model is implemented correctly with handlers working.""" class MyCustomAsyncHandler(AsyncCallbackHandler): - def __init__(self, store: List[str]) -> None: + def __init__(self, store: list[str]) -> None: self.store = store async def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: # Do nothing @@ -174,7 +174,7 @@ async def on_llm_new_token( chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: self.store.append(token) @@ -185,7 +185,7 @@ async def on_llm_new_token( ] ) model = GenericFakeChatModel(messages=infinite_cycle) - tokens: List[str] = [] + tokens: list[str] = [] # New model results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) assert results == [ diff --git a/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py b/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py index 9e567628138f1..6ddeefc62330b 100644 --- a/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py +++ b/libs/core/tests/unit_tests/indexing/test_in_memory_indexer.py @@ -1,6 +1,6 @@ """Test in memory indexer""" -from typing import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Generator import pytest from langchain_standard_tests.integration_tests.indexer import ( diff --git a/libs/core/tests/unit_tests/indexing/test_indexing.py b/libs/core/tests/unit_tests/indexing/test_indexing.py index d4bd4e1ec2db6..5bbbf49e392e1 100644 --- a/libs/core/tests/unit_tests/indexing/test_indexing.py +++ b/libs/core/tests/unit_tests/indexing/test_indexing.py @@ -1,10 +1,7 @@ +from collections.abc import AsyncIterator, Iterable, Iterator, Sequence from datetime import datetime from typing import ( Any, - AsyncIterator, - Iterable, - Iterator, - Sequence, ) from unittest.mock import AsyncMock, MagicMock, patch diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index f14c2a1b8d04d..499711d0c9fc5 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -1,7 +1,8 @@ """Test base chat model.""" import uuid -from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union +from collections.abc import AsyncIterator, Iterator +from typing import Any, Literal, Optional, Union import pytest @@ -130,8 +131,8 @@ async def test_astream_fallback_to_ainvoke() -> None: class ModelWithGenerate(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -158,8 +159,8 @@ async def test_astream_implementation_fallback_to_stream() -> None: class ModelWithSyncStream(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -168,8 +169,8 @@ def _generate( def _stream( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: @@ -203,8 +204,8 @@ async def test_astream_implementation_uses_astream() -> None: class ModelWithAsyncStream(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -213,8 +214,8 @@ def _generate( async def _astream( # type: ignore self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: @@ -279,8 +280,8 @@ async def test_async_pass_run_id() -> None: class NoStreamingModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -294,8 +295,8 @@ def _llm_type(self) -> str: class StreamingModel(NoStreamingModel): def _stream( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py index 99ab278addc23..14c50a854ba47 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py @@ -1,6 +1,6 @@ """Module tests interaction of chat model with caching abstraction..""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import pytest @@ -20,7 +20,7 @@ class InMemoryCache(BaseCache): def __init__(self) -> None: """Initialize with empty cache.""" - self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} + self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {} def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py index 3b61653a80b29..273bf9e0b8b28 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_base.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator, Iterator, List, Optional +from collections.abc import AsyncIterator, Iterator +from typing import Any, Optional import pytest @@ -128,8 +129,8 @@ async def test_astream_fallback_to_ainvoke() -> None: class ModelWithGenerate(BaseLLM): def _generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: @@ -154,8 +155,8 @@ async def test_astream_implementation_fallback_to_stream() -> None: class ModelWithSyncStream(BaseLLM): def _generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: @@ -165,7 +166,7 @@ def _generate( def _stream( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: @@ -191,8 +192,8 @@ async def test_astream_implementation_uses_astream() -> None: class ModelWithAsyncStream(BaseLLM): def _generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: @@ -202,7 +203,7 @@ def _generate( async def _astream( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: diff --git a/libs/core/tests/unit_tests/language_models/llms/test_cache.py b/libs/core/tests/unit_tests/language_models/llms/test_cache.py index 7e8bf003a97cd..4c5eb04cb9cce 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_cache.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.globals import set_llm_cache @@ -10,7 +10,7 @@ class InMemoryCache(BaseCache): def __init__(self) -> None: """Initialize with empty cache.""" - self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} + self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {} def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" @@ -62,7 +62,7 @@ class InMemoryCacheBad(BaseCache): def __init__(self) -> None: """Initialize with empty cache.""" - self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} + self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {} def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py index cd33b1d2cbc40..a8b03a801e6d8 100644 --- a/libs/core/tests/unit_tests/load/test_serializable.py +++ b/libs/core/tests/unit_tests/load/test_serializable.py @@ -1,5 +1,3 @@ -from typing import Dict - from pydantic import ConfigDict, Field from langchain_core.load import Serializable, dumpd, load @@ -56,7 +54,7 @@ def is_lc_serializable(cls) -> bool: return True @property - def lc_secrets(self) -> Dict[str, str]: + def lc_secrets(self) -> dict[str, str]: return {"secret": "MASKED_SECRET", "secret_2": "MASKED_SECRET_2"} foo = Foo( diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 63f8b3a7b7740..15539a19e3dc5 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1,5 +1,4 @@ import json -from typing import Dict, List, Type import pytest @@ -21,7 +20,7 @@ @pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage]) -def test_merge_message_runs_str(msg_cls: Type[BaseMessage]) -> None: +def test_merge_message_runs_str(msg_cls: type[BaseMessage]) -> None: messages = [msg_cls("foo"), msg_cls("bar"), msg_cls("baz")] messages_model_copy = [m.model_copy(deep=True) for m in messages] expected = [msg_cls("foo\nbar\nbaz")] @@ -32,7 +31,7 @@ def test_merge_message_runs_str(msg_cls: Type[BaseMessage]) -> None: @pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage]) def test_merge_message_runs_str_with_specified_separator( - msg_cls: Type[BaseMessage], + msg_cls: type[BaseMessage], ) -> None: messages = [msg_cls("foo"), msg_cls("bar"), msg_cls("baz")] messages_model_copy = [m.model_copy(deep=True) for m in messages] @@ -44,7 +43,7 @@ def test_merge_message_runs_str_with_specified_separator( @pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage]) def test_merge_message_runs_str_without_separator( - msg_cls: Type[BaseMessage], + msg_cls: type[BaseMessage], ) -> None: messages = [msg_cls("foo"), msg_cls("bar"), msg_cls("baz")] messages_model_copy = [m.model_copy(deep=True) for m in messages] @@ -127,7 +126,7 @@ def test_merge_messages_tool_messages() -> None: {"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]}, ], ) -def test_filter_message(filters: Dict) -> None: +def test_filter_message(filters: dict) -> None: messages = [ SystemMessage("foo", name="blah", id="1"), HumanMessage("bar", name="blur", id="2"), @@ -306,7 +305,7 @@ def test_trim_messages_allow_partial_text_splitter() -> None: AIMessage("This is a 4 token text.", id="fourth"), ] - def count_words(msgs: List[BaseMessage]) -> int: + def count_words(msgs: list[BaseMessage]) -> int: count = 0 for msg in msgs: if isinstance(msg.content, str): @@ -317,7 +316,7 @@ def count_words(msgs: List[BaseMessage]) -> int: ) return count - def _split_on_space(text: str) -> List[str]: + def _split_on_space(text: str) -> list[str]: splits = text.split(" ") return [s + " " for s in splits[:-1]] + splits[-1:] @@ -356,7 +355,7 @@ def test_trim_messages_bad_token_counter() -> None: trimmer.invoke([HumanMessage("foobar")]) -def dummy_token_counter(messages: List[BaseMessage]) -> int: +def dummy_token_counter(messages: list[BaseMessage]) -> int: # treat each message like it adds 3 default tokens at the beginning # of the message and at the end of the message. 3 + 4 + 3 = 10 tokens # per message. @@ -381,12 +380,12 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int: class FakeTokenCountingModel(FakeChatModel): - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: return dummy_token_counter(messages) def test_convert_to_messages() -> None: - message_like: List = [ + message_like: list = [ # BaseMessage SystemMessage("1"), HumanMessage([{"type": "image_url", "image_url": {"url": "2.1"}}], name="2.2"), diff --git a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py index 6e9eea83ba8b8..65e4f580862f7 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py +++ b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py @@ -1,6 +1,5 @@ """Module to test base parser implementations.""" -from typing import List from typing import Optional as Optional from langchain_core.exceptions import OutputParserException @@ -20,7 +19,7 @@ class StrInvertCase(BaseGenerationOutputParser[str]): """An example parser that inverts the case of the characters in the message.""" def parse_result( - self, result: List[Generation], *, partial: bool = False + self, result: list[Generation], *, partial: bool = False ) -> str: """Parse a list of model Generations into a specific format. @@ -65,7 +64,7 @@ def parse(self, text: str) -> str: raise NotImplementedError() def parse_result( - self, result: List[Generation], *, partial: bool = False + self, result: list[Generation], *, partial: bool = False ) -> str: """Parse a list of model Generations into a specific format. diff --git a/libs/core/tests/unit_tests/output_parsers/test_json.py b/libs/core/tests/unit_tests/output_parsers/test_json.py index 0cc9dd699f218..9753ff98ca385 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_json.py +++ b/libs/core/tests/unit_tests/output_parsers/test_json.py @@ -1,5 +1,6 @@ import json -from typing import Any, AsyncIterator, Iterator, Tuple +from collections.abc import AsyncIterator, Iterator +from typing import Any import pytest from pydantic import BaseModel @@ -245,7 +246,7 @@ def test_parse_json_with_python_dict() -> None: @pytest.mark.parametrize("json_strings", TEST_CASES_PARTIAL) -def test_parse_partial_json(json_strings: Tuple[str, str]) -> None: +def test_parse_partial_json(json_strings: tuple[str, str]) -> None: case, expected = json_strings parsed = parse_partial_json(case) assert parsed == json.loads(expected) diff --git a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py index ef3a00bc5f7ec..3f43edfa2aed8 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py @@ -1,4 +1,5 @@ -from typing import AsyncIterator, Iterable, List, TypeVar, cast +from collections.abc import AsyncIterator, Iterable +from typing import TypeVar, cast from langchain_core.output_parsers.list import ( CommaSeparatedListOutputParser, @@ -79,7 +80,7 @@ def test_numbered_list() -> None: (text2, ["apple", "banana", "cherry"]), (text3, []), ]: - expectedlist = [[a] for a in cast(List[str], expected)] + expectedlist = [[a] for a in cast(list[str], expected)] assert parser.parse(text) == expected assert add(parser.transform(t for t in text)) == (expected or None) assert list(parser.transform(t for t in text)) == expectedlist @@ -114,7 +115,7 @@ def test_markdown_list() -> None: (text2, ["apple", "banana", "cherry"]), (text3, []), ]: - expectedlist = [[a] for a in cast(List[str], expected)] + expectedlist = [[a] for a in cast(list[str], expected)] assert parser.parse(text) == expected assert add(parser.transform(t for t in text)) == (expected or None) assert list(parser.transform(t for t in text)) == expectedlist @@ -217,7 +218,7 @@ async def test_numbered_list_async() -> None: (text2, ["apple", "banana", "cherry"]), (text3, []), ]: - expectedlist = [[a] for a in cast(List[str], expected)] + expectedlist = [[a] for a in cast(list[str], expected)] assert await parser.aparse(text) == expected assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == ( expected or None @@ -260,7 +261,7 @@ async def test_markdown_list_async() -> None: (text2, ["apple", "banana", "cherry"]), (text3, []), ]: - expectedlist = [[a] for a in cast(List[str], expected)] + expectedlist = [[a] for a in cast(list[str], expected)] assert await parser.aparse(text) == expected assert await aadd(parser.atransform(aiter_from_iter(t for t in text))) == ( expected or None diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_functions.py b/libs/core/tests/unit_tests/output_parsers/test_openai_functions.py index 0e20295c15f8a..c0959620e420c 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_functions.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_functions.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict +from typing import Any import pytest from pydantic import BaseModel @@ -85,7 +85,7 @@ def test_json_output_function_parser() -> None: }, ], ) -def test_json_output_function_parser_strictness(config: Dict[str, Any]) -> None: +def test_json_output_function_parser_strictness(config: dict[str, Any]) -> None: """Test parsing with JSON strictness on and off.""" args = config["args"] diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index c961170dea079..63c2d8ea1eee8 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator, Iterator, List +from collections.abc import AsyncIterator, Iterator +from typing import Any import pytest from pydantic import BaseModel, Field @@ -483,7 +484,7 @@ class Person(BaseModel): class NameCollector(BaseModel): """record names of all people mentioned""" - names: List[str] = Field(..., description="all names mentioned") + names: list[str] = Field(..., description="all names mentioned") person: Person = Field(..., description="info about the main subject") diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index a57a129f1bb8a..60826c439e500 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,7 +1,7 @@ """Test XMLOutputParser""" import importlib -from typing import AsyncIterator, Iterable +from collections.abc import AsyncIterator, Iterable import pytest diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 7fc17d128d610..72056f6c5a22c 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,7 +1,7 @@ import base64 import tempfile from pathlib import Path -from typing import Any, List, Tuple, Union, cast +from typing import Any, Union, cast import pytest from pydantic import ValidationError @@ -35,7 +35,7 @@ @pytest.fixture -def messages() -> List[BaseMessagePromptTemplate]: +def messages() -> list[BaseMessagePromptTemplate]: """Create messages.""" system_message_prompt = SystemMessagePromptTemplate( prompt=PromptTemplate( @@ -72,7 +72,7 @@ def messages() -> List[BaseMessagePromptTemplate]: @pytest.fixture def chat_prompt_template( - messages: List[BaseMessagePromptTemplate], + messages: list[BaseMessagePromptTemplate], ) -> ChatPromptTemplate: """Create a chat prompt template.""" return ChatPromptTemplate( @@ -227,7 +227,7 @@ async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) -> def test_chat_prompt_template_from_messages( - messages: List[BaseMessagePromptTemplate], + messages: list[BaseMessagePromptTemplate], ) -> None: """Test creating a chat prompt template from messages.""" chat_prompt_template = ChatPromptTemplate.from_messages(messages) @@ -299,7 +299,7 @@ def test_chat_prompt_template_from_messages_mustache() -> None: def test_chat_prompt_template_with_messages( - messages: List[BaseMessagePromptTemplate], + messages: list[BaseMessagePromptTemplate], ) -> None: chat_prompt_template = ChatPromptTemplate.from_messages( messages + [HumanMessage(content="foo")] @@ -828,7 +828,7 @@ async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None: ("system", [{"text": "You are an AI assistant named {name}."}]), SystemMessagePromptTemplate.from_template("you are {foo}"), cast( - Tuple, + tuple, ( "human", [ diff --git a/libs/core/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py index 8c3cc523c23cf..27b0208285fe8 100644 --- a/libs/core/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -1,6 +1,7 @@ """Test few shot prompt template.""" -from typing import Any, Dict, List, Sequence, Tuple +from collections.abc import Sequence +from typing import Any import pytest @@ -25,7 +26,7 @@ @pytest.fixture() @pytest.mark.requires("jinja2") -def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]: +def example_jinja2_prompt() -> tuple[PromptTemplate, list[dict[str, str]]]: example_template = "{{ word }}: {{ antonym }}" examples = [ @@ -227,7 +228,7 @@ def test_partial() -> None: @pytest.mark.requires("jinja2") def test_prompt_jinja2_functionality( - example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]], + example_jinja2_prompt: tuple[PromptTemplate, list[dict[str, str]]], ) -> None: prefix = "Starting with {{ foo }}" suffix = "Ending with {{ bar }}" @@ -250,7 +251,7 @@ def test_prompt_jinja2_functionality( @pytest.mark.requires("jinja2") def test_prompt_jinja2_missing_input_variables( - example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]], + example_jinja2_prompt: tuple[PromptTemplate, list[dict[str, str]]], ) -> None: """Test error is raised when input variables are not provided.""" prefix = "Starting with {{ foo }}" @@ -297,7 +298,7 @@ def test_prompt_jinja2_missing_input_variables( @pytest.mark.requires("jinja2") def test_prompt_jinja2_extra_input_variables( - example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]], + example_jinja2_prompt: tuple[PromptTemplate, list[dict[str, str]]], ) -> None: """Test error is raised when there are too many input variables.""" prefix = "Starting with {{ foo }}" @@ -368,14 +369,14 @@ class AsIsSelector(BaseExampleSelector): This selector returns the examples as-is. """ - def __init__(self, examples: Sequence[Dict[str, str]]) -> None: + def __init__(self, examples: Sequence[dict[str, str]]) -> None: """Initializes the selector.""" self.examples = examples - def add_example(self, example: Dict[str, str]) -> Any: + def add_example(self, example: dict[str, str]) -> Any: raise NotImplementedError - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict]: return list(self.examples) @@ -463,17 +464,17 @@ class AsyncAsIsSelector(BaseExampleSelector): This selector returns the examples as-is. """ - def __init__(self, examples: Sequence[Dict[str, str]]) -> None: + def __init__(self, examples: Sequence[dict[str, str]]) -> None: """Initializes the selector.""" self.examples = examples - def add_example(self, example: Dict[str, str]) -> Any: + def add_example(self, example: dict[str, str]) -> Any: raise NotImplementedError - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict]: raise NotImplementedError - async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: + async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: return list(self.examples) diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index d7092aa94c630..cd9423de5a732 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -1,9 +1,9 @@ """Test loading functionality.""" import os +from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path -from typing import Iterator import pytest diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index bd73c93773c18..e0fb05f11a589 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -1,6 +1,6 @@ """Test functionality related to prompts.""" -from typing import Any, Dict, Union +from typing import Any, Union from unittest import mock import pydantic @@ -610,7 +610,7 @@ async def test_prompt_ainvoke_with_metadata() -> None: ) @pytest.mark.parametrize("template_format", ["f-string", "mustache"]) def test_prompt_falsy_vars( - template_format: str, value: Any, expected: Union[str, Dict[str, str]] + template_format: str, value: Any, expected: Union[str, dict[str, str]] ) -> None: # each line is value, f-string, mustache if template_format == "f-string": diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index a405b17f35c4e..8b9d706be500d 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -1,6 +1,6 @@ from functools import partial from inspect import isclass -from typing import Any, Dict, Type, Union, cast +from typing import Any, Union, cast from typing import Optional as Optional from pydantic import BaseModel @@ -14,12 +14,12 @@ def _fake_runnable( - input: Any, *, schema: Union[Dict, Type[BaseModel]], value: Any = 42, **_: Any -) -> Union[BaseModel, Dict]: + input: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_: Any +) -> Union[BaseModel, dict]: if isclass(schema) and is_basemodel_subclass(schema): return schema(name="yo", value=value) else: - params = cast(Dict, schema)["parameters"] + params = cast(dict, schema)["parameters"] return {k: 1 if k != "value" else value for k, v in params.items()} @@ -27,7 +27,7 @@ class FakeStructuredChatModel(FakeListChatModel): """Fake ChatModel for testing purposes.""" def with_structured_output( - self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any + self, schema: Union[dict, type[BaseModel]], **kwargs: Any ) -> Runnable: return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs)) diff --git a/libs/core/tests/unit_tests/pydantic_utils.py b/libs/core/tests/unit_tests/pydantic_utils.py index a97ca6c0b257f..9f78feb3f5e20 100644 --- a/libs/core/tests/unit_tests/pydantic_utils.py +++ b/libs/core/tests/unit_tests/pydantic_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -99,7 +99,7 @@ def _schema(obj: Any) -> dict: return schema_ -def _normalize_schema(obj: Any) -> Dict[str, Any]: +def _normalize_schema(obj: Any) -> dict[str, Any]: """Generate a schema and normalize it. This will collapse single element allOfs into $ref. diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index 23eac0cfd73a8..f5243f4e78f9f 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -1,7 +1,7 @@ import json import uuid from contextvars import copy_context -from typing import Any, Dict, cast +from typing import Any, cast import pytest @@ -26,7 +26,7 @@ def test_ensure_config() -> None: run_id = str(uuid.uuid4()) - arg: Dict = { + arg: dict = { "something": "else", "metadata": {"foo": "bar"}, "configurable": {"baz": "qux"}, diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index 24f7f7be3d5a2..8bd6e760677c2 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Optional import pytest from pydantic import ConfigDict, Field, model_validator @@ -21,7 +21,7 @@ class MyRunnable(RunnableSerializable[str, str]): @model_validator(mode="before") @classmethod - def my_error(cls, values: Dict[str, Any]) -> Any: + def my_error(cls, values: dict[str, Any]) -> Any: if "_my_hidden_property" in values: raise ValueError("Cannot set _my_hidden_property") return values diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index 090b7df51127e..92a53b2c11a61 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, NamedTuple, Union +from typing import Any, Callable, NamedTuple, Union import pytest @@ -331,7 +331,7 @@ def seq_naive_rag_scoped() -> Runnable: @pytest.mark.parametrize("runnable, cases", test_cases) async def test_context_runnables( - runnable: Union[Runnable, Callable[[], Runnable]], cases: List[_TestCase] + runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] ) -> None: runnable = runnable if isinstance(runnable, Runnable) else runnable() assert runnable.invoke(cases[0].input) == cases[0].output diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 017122374cdb3..424b61025ab55 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -1,14 +1,8 @@ -import sys +from collections.abc import AsyncIterator, Iterator, Sequence from typing import ( Any, - AsyncIterator, Callable, - Dict, - Iterator, - List, Optional, - Sequence, - Type, Union, ) @@ -98,8 +92,7 @@ async def test_fallbacks( assert await runnable.ainvoke("hello") == "bar" assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(await runnable.ainvoke("hello")) == list("bar") - if sys.version_info >= (3, 9): - assert dumps(runnable, pretty=True) == snapshot + assert dumps(runnable, pretty=True) == snapshot def _runnable(inputs: dict) -> str: @@ -316,8 +309,8 @@ class FakeStructuredOutputModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -326,14 +319,14 @@ def _generate( def bind_tools( self, - tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: return self.bind(tools=tools) def with_structured_output( - self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any - ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + self, schema: Union[dict, type[BaseModel]], **kwargs: Any + ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: return RunnableLambda(lambda x: {"foo": self.foo}) @property @@ -346,8 +339,8 @@ class FakeModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -356,7 +349,7 @@ def _generate( def bind_tools( self, - tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: return self.bind(tools=tools) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 377e7a72457f1..acc0a8d7792a5 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import pytest from pydantic import BaseModel @@ -30,7 +31,7 @@ def test_interfaces() -> None: def _get_get_session_history( *, - store: Optional[Dict[str, Any]] = None, + store: Optional[dict[str, Any]] = None, ) -> Callable[..., InMemoryChatMessageHistory]: chat_history_store = store if store is not None else {} @@ -49,7 +50,7 @@ def test_input_messages() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - store: Dict = {} + store: dict = {} get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = {"configurable": {"session_id": "1"}} @@ -78,7 +79,7 @@ async def test_input_messages_async() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - store: Dict = {} + store: dict = {} get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config = {"session_id": "1_async"} @@ -251,8 +252,8 @@ class LengthChatModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -522,7 +523,7 @@ def test_get_input_schema_input_messages() -> None: def test_using_custom_config_specs() -> None: """Test that we can configure which keys should be passed to the session factory.""" - def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]: + def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]: messages = input["messages"] return [ AIMessage( @@ -635,7 +636,7 @@ def get_session_history( async def test_using_custom_config_specs_async() -> None: """Test that we can configure which keys should be passed to the session factory.""" - def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]: + def _fake_llm(input: dict[str, Any]) -> list[BaseMessage]: messages = input["messages"] return [ AIMessage( @@ -748,7 +749,7 @@ def get_session_history( def test_ignore_session_id() -> None: """Test without config.""" - def _fake_llm(input: List[BaseMessage]) -> List[BaseMessage]: + def _fake_llm(input: list[BaseMessage]) -> list[BaseMessage]: return [ AIMessage( content="you said: " @@ -833,7 +834,7 @@ def test_get_output_messages_no_value_error() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - store: Dict = {} + store: dict = {} get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = { @@ -850,7 +851,7 @@ def test_get_output_messages_no_value_error() -> None: def test_get_output_messages_with_value_error() -> None: illegal_bool_message = False runnable = _RunnableLambdaWithRaiseError(lambda messages: illegal_bool_message) - store: Dict = {} + store: dict = {} get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = { diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index cb6d16dd1cc9e..e4cdc124b0189 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1,18 +1,13 @@ import sys import uuid import warnings +from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from functools import partial from operator import itemgetter from typing import ( Any, - AsyncIterator, - Awaitable, Callable, - Dict, - Iterator, - List, Optional, - Sequence, Union, cast, ) @@ -104,8 +99,8 @@ class FakeTracer(BaseTracer): def __init__(self) -> None: """Initialize the tracer.""" super().__init__() - self.runs: List[Run] = [] - self.uuids_map: Dict[UUID, UUID] = {} + self.runs: list[Run] = [] + self.uuids_map: dict[UUID, UUID] = {} self.uuids_generator = ( UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000) ) @@ -164,7 +159,7 @@ def _persist_run(self, run: Run) -> None: self.runs.append(self._copy_run(run)) - def flattened_runs(self) -> List[Run]: + def flattened_runs(self) -> list[Run]: q = [] + self.runs result = [] while q: @@ -175,7 +170,7 @@ def flattened_runs(self) -> List[Run]: return result @property - def run_ids(self) -> List[Optional[uuid.UUID]]: + def run_ids(self) -> list[Optional[uuid.UUID]]: runs = self.flattened_runs() uuids_map = {v: k for k, v in self.uuids_map.items()} return [uuids_map.get(r.id) for r in runs] @@ -208,10 +203,10 @@ def _get_relevant_documents( query: str, *, callbacks: Callbacks = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: return [Document(page_content="foo"), Document(page_content="bar")] async def _aget_relevant_documents( @@ -219,10 +214,10 @@ async def _aget_relevant_documents( query: str, *, callbacks: Callbacks = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: return [Document(page_content="foo"), Document(page_content="bar")] @@ -660,8 +655,8 @@ def foo(x: int) -> None: # Try specifying some RunnableLambda(foo).with_types( - output_type=List[int], # type: ignore[arg-type] - input_type=List[int], # type: ignore[arg-type] + output_type=list[int], # type: ignore[arg-type] + input_type=list[int], # type: ignore[arg-type] ) RunnableLambda(foo).with_types( output_type=Sequence[int], # type: ignore[arg-type] @@ -2577,8 +2572,7 @@ def test_combining_sequences( assert chain.first == prompt assert chain.middle == [chat] assert chain.last == parser - if sys.version_info >= (3, 9): - assert dumps(chain, pretty=True) == snapshot + assert dumps(chain, pretty=True) == snapshot prompt2 = ( SystemMessagePromptTemplate.from_template("You are a nicer assistant.") @@ -2586,7 +2580,7 @@ def test_combining_sequences( ) chat2 = FakeListChatModel(responses=["baz, qux"]) parser2 = CommaSeparatedListOutputParser() - input_formatter: RunnableLambda[List[str], Dict[str, Any]] = RunnableLambda( + input_formatter: RunnableLambda[list[str], dict[str, Any]] = RunnableLambda( lambda x: {"question": x[0] + x[1]} ) @@ -2596,8 +2590,7 @@ def test_combining_sequences( assert chain2.first == input_formatter assert chain2.middle == [prompt2, chat2] assert chain2.last == parser2 - if sys.version_info >= (3, 9): - assert dumps(chain2, pretty=True) == snapshot + assert dumps(chain2, pretty=True) == snapshot combined_chain = cast(RunnableSequence, chain | chain2) @@ -2610,8 +2603,7 @@ def test_combining_sequences( chat2, ] assert combined_chain.last == parser2 - if sys.version_info >= (3, 9): - assert dumps(combined_chain, pretty=True) == snapshot + assert dumps(combined_chain, pretty=True) == snapshot # Test invoke tracer = FakeTracer() @@ -2619,8 +2611,7 @@ def test_combining_sequences( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == ["baz", "qux"] - if sys.version_info >= (3, 9): - assert tracer.runs == snapshot + assert tracer.runs == snapshot @freeze_time("2023-01-01") @@ -2827,7 +2818,7 @@ async def test_higher_order_lambda_runnable( input={"question": lambda x: x["question"]}, ) - def router(input: Dict[str, Any]) -> Runnable: + def router(input: dict[str, Any]) -> Runnable: if input["key"] == "math": return itemgetter("input") | math_chain elif input["key"] == "english": @@ -2836,8 +2827,7 @@ def router(input: Dict[str, Any]) -> Runnable: raise ValueError(f"Unknown key: {input['key']}") chain: Runnable = input_map | router - if sys.version_info >= (3, 9): - assert dumps(chain, pretty=True) == snapshot + assert dumps(chain, pretty=True) == snapshot result = chain.invoke({"key": "math", "question": "2 + 2"}) assert result == "4" @@ -2877,7 +2867,7 @@ def router(input: Dict[str, Any]) -> Runnable: assert len(math_run.child_runs) == 3 # Test ainvoke - async def arouter(input: Dict[str, Any]) -> Runnable: + async def arouter(input: dict[str, Any]) -> Runnable: if input["key"] == "math": return itemgetter("input") | math_chain elif input["key"] == "english": @@ -3656,7 +3646,7 @@ async def test_runnable_sequence_atransform() -> None: assert "".join(chunks) == "foo-lish" -class FakeSplitIntoListParser(BaseOutputParser[List[str]]): +class FakeSplitIntoListParser(BaseOutputParser[list[str]]): """Parse the output of an LLM call to a comma-separated list.""" @classmethod @@ -3670,7 +3660,7 @@ def get_format_instructions(self) -> str: "eg: `foo, bar, baz`" ) - def parse(self, text: str) -> List[str]: + def parse(self, text: str) -> list[str]: """Parse the output of an LLM call.""" return text.strip().split(", ") @@ -3851,7 +3841,7 @@ def _lambda(x: int) -> Union[int, Runnable]: def test_runnable_lambda_stream() -> None: """Test that stream works for both normal functions & those returning Runnable.""" # Normal output should work - output: List[Any] = [chunk for chunk in RunnableLambda(range).stream(5)] + output: list[Any] = [chunk for chunk in RunnableLambda(range).stream(5)] assert output == [range(5)] # Runnable output should also work @@ -3905,7 +3895,7 @@ async def afunc(*args: Any, **kwargs: Any) -> Any: return afunc # Normal output should work - output: List[Any] = [ + output: list[Any] = [ chunk async for chunk in RunnableLambda( func=id, @@ -3983,9 +3973,9 @@ def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: def _batch( self, - inputs: List[str], - ) -> List: - outputs: List[Any] = [] + inputs: list[str], + ) -> list: + outputs: list[Any] = [] for input in inputs: if input.startswith(self.fail_starts_with): outputs.append(ValueError()) @@ -3995,12 +3985,12 @@ def _batch( def batch( self, - inputs: List[str], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[str], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: return self._batch_with_config( self._batch, inputs, @@ -4102,9 +4092,9 @@ def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: async def _abatch( self, - inputs: List[str], - ) -> List: - outputs: List[Any] = [] + inputs: list[str], + ) -> list: + outputs: list[Any] = [] for input in inputs: if input.startswith(self.fail_starts_with): outputs.append(ValueError()) @@ -4114,12 +4104,12 @@ async def _abatch( async def abatch( self, - inputs: List[str], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[str], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: return await self._abatch_with_config( self._abatch, inputs, @@ -5197,13 +5187,13 @@ def test_transform_of_runnable_lambda_with_dicts() -> None: async def test_atransform_of_runnable_lambda_with_dicts() -> None: - async def identity(x: Dict[str, str]) -> Dict[str, str]: + async def identity(x: dict[str, str]) -> dict[str, str]: """Return x.""" return x - runnable = RunnableLambda[Dict[str, str], Dict[str, str]](identity) + runnable = RunnableLambda[dict[str, str], dict[str, str]](identity) - async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + async def chunk_iterator() -> AsyncIterator[dict[str, str]]: yield {"foo": "a"} yield {"foo": "n"} @@ -5224,7 +5214,7 @@ def invoke( ) -> Output: return cast(Output, input) # type: ignore - runnable = CustomRunnable[Dict[str, str], Dict[str, str]]() + runnable = CustomRunnable[dict[str, str], dict[str, str]]() chunks = iter( [ {"foo": "a"}, @@ -5245,9 +5235,9 @@ def invoke( ) -> Output: return cast(Output, input) - runnable = CustomRunnable[Dict[str, str], Dict[str, str]]() + runnable = CustomRunnable[dict[str, str], dict[str, str]]() - async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + async def chunk_iterator() -> AsyncIterator[dict[str, str]]: yield {"foo": "a"} yield {"foo": "n"} @@ -5256,7 +5246,7 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: assert chunks == [{"foo": "n"}] # Test with addable dict - async def chunk_iterator_with_addable() -> AsyncIterator[Dict[str, str]]: + async def chunk_iterator_with_addable() -> AsyncIterator[dict[str, str]]: yield AddableDict({"foo": "a"}) yield AddableDict({"foo": "n"}) @@ -5278,7 +5268,7 @@ async def test_passthrough_atransform_with_dicts() -> None: """Test that default transform works with dicts.""" runnable = RunnablePassthrough(lambda x: x) - async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + async def chunk_iterator() -> AsyncIterator[dict[str, str]]: yield {"foo": "a"} yield {"foo": "n"} @@ -5364,7 +5354,7 @@ async def on_chain_error( *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run when chain errors.""" @@ -5389,7 +5379,7 @@ def test_pydantic_protected_namespaces() -> None: warnings.simplefilter("error") class CustomChatModel(RunnableSerializable): - model_kwargs: Dict[str, Any] = Field(default_factory=dict) + model_kwargs: dict[str, Any] = Field(default_factory=dict) def test_schema_for_prompt_and_chat_model() -> None: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index f8876c1bfba13..d82cfd6d9ddc9 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -1,8 +1,9 @@ """Module that contains tests for runnable.astream_events API.""" import sys +from collections.abc import AsyncIterator, Sequence from itertools import cycle -from typing import Any, AsyncIterator, Dict, List, Sequence, cast +from typing import Any, cast from typing import Optional as Optional import pytest @@ -34,22 +35,22 @@ from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk -def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: +def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: """Removes the run ids from events.""" for event in events: assert "parent_ids" in event, "Parent ids should be present in the event." assert event["parent_ids"] == [], "Parent ids should be empty." - return cast(List[StreamEvent], [{**event, "run_id": ""} for event in events]) + return cast(list[StreamEvent], [{**event, "run_id": ""} for event in events]) -async def _as_async_iterator(iterable: List) -> AsyncIterator: +async def _as_async_iterator(iterable: list) -> AsyncIterator: """Converts an iterable into an async iterator.""" for item in iterable: yield item -async def _collect_events(events: AsyncIterator[StreamEvent]) -> List[StreamEvent]: +async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEvent]: """Collect the events and remove the run ids.""" materialized_events = [event async for event in events] events_ = _with_nulled_run_id(materialized_events) @@ -58,7 +59,7 @@ async def _collect_events(events: AsyncIterator[StreamEvent]) -> List[StreamEven return events_ -def _assert_events_equal_allow_superset_metadata(events: List, expected: List) -> None: +def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -> None: """Assert that the events are equal.""" assert len(events) == len(expected) for i, (event, expected_event) in enumerate(zip(events, expected)): @@ -86,7 +87,7 @@ def foo(x: int) -> dict: return {"x": 5} @tool - def get_docs(x: int) -> List[Document]: + def get_docs(x: int) -> list[Document]: """Hello Doc""" return [Document(page_content="hello")] @@ -1172,11 +1173,11 @@ def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: class HardCodedRetriever(BaseRetriever): - documents: List[Document] + documents: list[Document] def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: return self.documents @@ -1265,7 +1266,7 @@ async def test_event_stream_with_retriever_and_formatter() -> None: ] ) - def format_docs(docs: List[Document]) -> str: + def format_docs(docs: list[Document]) -> str: """Format the docs.""" return ", ".join([doc.page_content for doc in docs]) @@ -1903,7 +1904,7 @@ def clear(self) -> None: # Here we use a global variable to store the chat message history. # This will make it easier to inspect it to see the underlying results. - store: Dict = {} + store: dict = {} def get_by_session_id(session_id: str) -> BaseChatMessageHistory: """Get a chat message history""" diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index fba80d291a49a..67383d28cde32 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -3,17 +3,12 @@ import asyncio import sys import uuid +from collections.abc import AsyncIterator, Iterable, Iterator, Sequence from functools import partial from itertools import cycle from typing import ( Any, - AsyncIterator, - Dict, - Iterable, - Iterator, - List, Optional, - Sequence, cast, ) @@ -55,7 +50,7 @@ from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk -def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: +def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: """Removes the run ids from events.""" for event in events: assert "run_id" in event, f"Event {event} does not have a run_id." @@ -68,12 +63,12 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: ), f"Event {event} parent_ids is not a list." return cast( - List[StreamEvent], + list[StreamEvent], [{**event, "run_id": "", "parent_ids": []} for event in events], ) -async def _as_async_iterator(iterable: List) -> AsyncIterator: +async def _as_async_iterator(iterable: list) -> AsyncIterator: """Converts an iterable into an async iterator.""" for item in iterable: yield item @@ -81,7 +76,7 @@ async def _as_async_iterator(iterable: List) -> AsyncIterator: async def _collect_events( events: AsyncIterator[StreamEvent], with_nulled_ids: bool = True -) -> List[StreamEvent]: +) -> list[StreamEvent]: """Collect the events and remove the run ids.""" materialized_events = [event async for event in events] @@ -102,7 +97,7 @@ def foo(x: int) -> dict: return {"x": 5} @tool - def get_docs(x: int) -> List[Document]: + def get_docs(x: int) -> list[Document]: """Hello Doc""" return [Document(page_content="hello")] @@ -1162,11 +1157,11 @@ def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: class HardCodedRetriever(BaseRetriever): - documents: List[Document] + documents: list[Document] def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: return self.documents @@ -1236,7 +1231,7 @@ async def test_event_stream_with_retriever_and_formatter() -> None: ] ) - def format_docs(docs: List[Document]) -> str: + def format_docs(docs: list[Document]) -> str: """Format the docs.""" return ", ".join([doc.page_content for doc in docs]) @@ -1860,7 +1855,7 @@ def clear(self) -> None: # Here we use a global variable to store the chat message history. # This will make it easier to inspect it to see the underlying results. - store: Dict = {} + store: dict = {} def get_by_session_id(session_id: str) -> BaseChatMessageHistory: """Get a chat message history""" diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index b958636a91309..bcaec72cf9c61 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -1,6 +1,7 @@ import json import sys -from typing import Any, AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Generator +from typing import Any from unittest.mock import MagicMock, patch import pytest diff --git a/libs/core/tests/unit_tests/runnables/test_utils.py b/libs/core/tests/unit_tests/runnables/test_utils.py index be6b35672497c..c07a33da3d335 100644 --- a/libs/core/tests/unit_tests/runnables/test_utils.py +++ b/libs/core/tests/unit_tests/runnables/test_utils.py @@ -1,5 +1,5 @@ import sys -from typing import Callable, Dict, Tuple +from typing import Callable import pytest @@ -47,7 +47,7 @@ def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) def test_nonlocals() -> None: agent = RunnableLambda(lambda x: x * 2) - def my_func(input: str, agent: Dict[str, str]) -> str: + def my_func(input: str, agent: dict[str, str]) -> str: return agent.get("agent_name", input) def my_func2(input: str) -> str: @@ -59,7 +59,7 @@ def my_func3(input: str) -> str: def my_func4(input: str) -> str: return global_agent.invoke(input) - def my_func5() -> Tuple[Callable[[str], str], RunnableLambda]: + def my_func5() -> tuple[Callable[[str], str], RunnableLambda]: global_agent = RunnableLambda(lambda x: x * 3) def my_func6(input: str) -> str: diff --git a/libs/core/tests/unit_tests/stores/test_in_memory.py b/libs/core/tests/unit_tests/stores/test_in_memory.py index 6c2346e39341d..3c5f810b1fc6a 100644 --- a/libs/core/tests/unit_tests/stores/test_in_memory.py +++ b/libs/core/tests/unit_tests/stores/test_in_memory.py @@ -1,5 +1,3 @@ -from typing import Tuple - import pytest from langchain_standard_tests.integration_tests.base_store import ( BaseStoreAsyncTests, @@ -16,7 +14,7 @@ def kv_store(self) -> InMemoryStore: return InMemoryStore() @pytest.fixture - def three_values(self) -> Tuple[str, str, str]: # type: ignore + def three_values(self) -> tuple[str, str, str]: # type: ignore return "value1", "value2", "value3" @@ -26,7 +24,7 @@ async def kv_store(self) -> InMemoryStore: return InMemoryStore() @pytest.fixture - def three_values(self) -> Tuple[str, str, str]: # type: ignore + def three_values(self) -> tuple[str, str, str]: # type: ignore return "value1", "value2", "value3" diff --git a/libs/core/tests/unit_tests/test_imports.py b/libs/core/tests/unit_tests/test_imports.py index 887811b5bd00d..ed336df0c3120 100644 --- a/libs/core/tests/unit_tests/test_imports.py +++ b/libs/core/tests/unit_tests/test_imports.py @@ -3,7 +3,6 @@ import importlib import subprocess from pathlib import Path -from typing import Tuple def test_importable_all() -> None: @@ -18,7 +17,7 @@ def test_importable_all() -> None: getattr(module, cls_) -def try_to_import(module_name: str) -> Tuple[int, str]: +def try_to_import(module_name: str) -> tuple[int, str]: """Try to import a module via subprocess.""" module = importlib.import_module("langchain_core." + module_name) all_ = getattr(module, "__all__", []) diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index c6c47396df08e..cc3e3551efb83 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,6 +1,6 @@ import unittest import uuid -from typing import List, Type, Union +from typing import Union import pytest @@ -447,7 +447,7 @@ def test_message_chunk_to_message() -> None: def test_tool_calls_merge() -> None: - chunks: List[dict] = [ + chunks: list[dict] = [ dict(content=""), dict( content="", @@ -790,7 +790,7 @@ def test_convert_to_messages() -> None: SystemMessage, ], ) -def test_message_name(MessageClass: Type) -> None: +def test_message_name(MessageClass: type) -> None: msg = MessageClass(content="foo", name="bar") assert msg.name == "bar" @@ -805,7 +805,7 @@ def test_message_name(MessageClass: Type) -> None: "MessageClass", [FunctionMessage, FunctionMessageChunk], ) -def test_message_name_function(MessageClass: Type) -> None: +def test_message_name_function(MessageClass: type) -> None: # functionmessage doesn't support name=None msg = MessageClass(name="foo", content="bar") assert msg.name == "foo" @@ -815,7 +815,7 @@ def test_message_name_function(MessageClass: Type) -> None: "MessageClass", [ChatMessage, ChatMessageChunk], ) -def test_message_name_chat(MessageClass: Type) -> None: +def test_message_name_chat(MessageClass: type) -> None: msg = MessageClass(content="foo", role="user", name="bar") assert msg.name == "bar" diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 83a5b5956d9ee..35ff717dd2e47 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -9,15 +9,12 @@ from enum import Enum from functools import partial from typing import ( + Annotated, Any, Callable, - Dict, Generic, - List, Literal, Optional, - Tuple, - Type, TypeVar, Union, ) @@ -25,7 +22,7 @@ import pytest from pydantic import BaseModel, Field, ValidationError from pydantic.v1 import BaseModel as BaseModelV1 -from typing_extensions import Annotated, TypedDict +from typing_extensions import TypedDict from langchain_core import tools from langchain_core.callbacks import ( @@ -91,7 +88,7 @@ class _MockSchemaV1(BaseModelV1): class _MockStructuredTool(BaseTool): name: str = "structured_api" - args_schema: Type[BaseModel] = _MockSchema + args_schema: type[BaseModel] = _MockSchema description: str = "A Structured Tool" def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: @@ -135,7 +132,7 @@ def test_forward_ref_annotated_base_tool_accepted() -> None: class _ForwardRefAnnotatedTool(BaseTool): name: str = "structured_api" - args_schema: "Type[BaseModel]" = _MockSchema + args_schema: "type[BaseModel]" = _MockSchema description: str = "A Structured Tool" def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: @@ -152,7 +149,7 @@ def test_subclass_annotated_base_tool_accepted() -> None: class _ForwardRefAnnotatedTool(BaseTool): name: str = "structured_api" - args_schema: Type[_MockSchema] = _MockSchema + args_schema: type[_MockSchema] = _MockSchema description: str = "A Structured Tool" def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: @@ -432,7 +429,7 @@ def foo(bar: int, baz: str) -> str: def test_structured_tool_from_function_docstring_complex_args() -> None: """Test that structured tools can be created from functions.""" - def foo(bar: int, baz: List[str]) -> str: + def foo(bar: int, baz: list[str]) -> str: """Docstring Args: @@ -856,8 +853,8 @@ class _RaiseNonValidationErrorTool(BaseTool): def _parse_input( self, - tool_input: Union[str, Dict], - ) -> Union[str, Dict[str, Any]]: + tool_input: Union[str, dict], + ) -> Union[str, dict[str, Any]]: raise NotImplementedError() def _run(self) -> str: @@ -918,8 +915,8 @@ class _RaiseNonValidationErrorTool(BaseTool): def _parse_input( self, - tool_input: Union[str, Dict], - ) -> Union[str, Dict[str, Any]]: + tool_input: Union[str, dict], + ) -> Union[str, dict[str, Any]]: raise NotImplementedError() def _run(self) -> str: @@ -937,7 +934,7 @@ def test_optional_subset_model_rewrite() -> None: class MyModel(BaseModel): a: Optional[str] = None b: str - c: Optional[List[Optional[str]]] = None + c: Optional[list[Optional[str]]] = None model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"]) @@ -1263,20 +1260,20 @@ def test_tool_call_input_tool_message_output() -> None: class _MockStructuredToolWithRawOutput(BaseTool): name: str = "structured_api" - args_schema: Type[BaseModel] = _MockSchema + args_schema: type[BaseModel] = _MockSchema description: str = "A Structured Tool" response_format: Literal["content_and_artifact"] = "content_and_artifact" def _run( self, arg1: int, arg2: bool, arg3: Optional[dict] = None - ) -> Tuple[str, dict]: + ) -> tuple[str, dict]: return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} @tool("structured_api", response_format="content_and_artifact") def _mock_structured_tool_with_artifact( arg1: int, arg2: bool, arg3: Optional[dict] = None -) -> Tuple[str, dict]: +) -> tuple[str, dict]: """A Structured Tool""" return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} @@ -1285,7 +1282,7 @@ def _mock_structured_tool_with_artifact( "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact] ) def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None: - tool_call: Dict = { + tool_call: dict = { "name": "structured_api", "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, "id": "123", @@ -1309,7 +1306,7 @@ def test_convert_from_runnable_dict() -> None: # Test with typed dict input class Args(TypedDict): a: int - b: List[int] + b: list[int] def f(x: Args) -> str: return str(x["a"] * max(x["b"])) @@ -1336,7 +1333,7 @@ def f(x: Args) -> str: assert as_tool.description == "test description" # Dict without typed input-- must supply schema - def g(x: Dict[str, Any]) -> str: + def g(x: dict[str, Any]) -> str: return str(x["a"] * max(x["b"])) # Specify via args_schema: @@ -1344,7 +1341,7 @@ class GSchema(BaseModel): """Apply a function to an integer and list of integers.""" a: int = Field(..., description="Integer") - b: List[int] = Field(..., description="List of ints") + b: list[int] = Field(..., description="List of ints") runnable = RunnableLambda(g) as_tool = runnable.as_tool(GSchema) @@ -1352,18 +1349,18 @@ class GSchema(BaseModel): # Specify via arg_types: runnable = RunnableLambda(g) - as_tool = runnable.as_tool(arg_types={"a": int, "b": List[int]}) + as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]}) result = as_tool.invoke({"a": 3, "b": [1, 2]}) assert result == "6" # Test with config - def h(x: Dict[str, Any]) -> str: + def h(x: dict[str, Any]) -> str: config = ensure_config() assert config["configurable"]["foo"] == "not-bar" return str(x["a"] * max(x["b"])) runnable = RunnableLambda(h) - as_tool = runnable.as_tool(arg_types={"a": int, "b": List[int]}) + as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]}) result = as_tool.invoke( {"a": 3, "b": [1, 2]}, config={"configurable": {"foo": "not-bar"}} ) @@ -1436,7 +1433,7 @@ class fooSchema(BaseModel): class InjectedToolWithSchema(BaseTool): name: str = "foo" description: str = "foo." - args_schema: Type[BaseModel] = fooSchema + args_schema: type[BaseModel] = fooSchema def _run(self, x: int, y: str) -> Any: return y @@ -1586,7 +1583,7 @@ class fooSchema(barSchema): class InheritedInjectedArgTool(BaseTool): name: str = "foo" description: str = "foo." - args_schema: Type[BaseModel] = fooSchema + args_schema: type[BaseModel] = fooSchema def _run(self, x: int, y: str) -> Any: return y @@ -1661,7 +1658,7 @@ def test_fn_injected_arg_with_schema(tool_: Callable) -> None: } -def generate_models() -> List[Any]: +def generate_models() -> list[Any]: """Generate a list of base models depending on the pydantic version.""" class FooProper(BaseModel): @@ -1671,7 +1668,7 @@ class FooProper(BaseModel): return [FooProper] -def generate_backwards_compatible_v1() -> List[Any]: +def generate_backwards_compatible_v1() -> list[Any]: """Generate a model with pydantic 2 from the v1 namespace.""" from pydantic.v1 import BaseModel as BaseModelV1 @@ -1692,7 +1689,7 @@ class FooV1Namespace(BaseModelV1): @pytest.mark.parametrize("pydantic_model", TEST_MODELS) def test_args_schema_as_pydantic(pydantic_model: Any) -> None: class SomeTool(BaseTool): - args_schema: Type[pydantic_model] = pydantic_model + args_schema: type[pydantic_model] = pydantic_model def _run(self, *args: Any, **kwargs: Any) -> str: return "foo" @@ -1752,7 +1749,7 @@ class SomeTool(BaseTool): # type ignoring here since we're allowing overriding a type # signature of pydantic.v1.BaseModel with pydantic.BaseModel # for pydantic 2! - args_schema: Type[BaseModel] = Foo # type: ignore[assignment] + args_schema: type[BaseModel] = Foo # type: ignore[assignment] def _run(self, *args: Any, **kwargs: Any) -> str: return "foo" @@ -1893,7 +1890,7 @@ class ModelA(BM2, Generic[A]): # type: ignore[no-redef] model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") class ModelB(ModelA[str]): - b: Annotated[ModelA[Dict[str, Any]], "foo"] + b: Annotated[ModelA[dict[str, Any]], "foo"] class Mixin: def foo(self) -> str: @@ -1902,11 +1899,11 @@ def foo(self) -> str: class ModelC(Mixin, ModelB): c: dict - expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict} + expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict} actual = _get_all_basemodel_annotations(ModelC) assert actual == expected - expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} + expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]} actual = _get_all_basemodel_annotations(ModelB) assert actual == expected @@ -1925,7 +1922,7 @@ class ModelD(ModelC, Generic[D]): expected = { "a": str, - "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict, "d": Union[str, int, None], } @@ -1934,7 +1931,7 @@ class ModelD(ModelC, Generic[D]): expected = { "a": str, - "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict, "d": Union[int, None], } @@ -1950,7 +1947,7 @@ class ModelA(BaseModel, Generic[A], extra="allow"): a: A class ModelB(ModelA[str]): - b: Annotated[ModelA[Dict[str, Any]], "foo"] + b: Annotated[ModelA[dict[str, Any]], "foo"] class Mixin: def foo(self) -> str: @@ -1959,11 +1956,11 @@ def foo(self) -> str: class ModelC(Mixin, ModelB): c: dict - expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict} + expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict} actual = _get_all_basemodel_annotations(ModelC) assert actual == expected - expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]} + expected = {"a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"]} actual = _get_all_basemodel_annotations(ModelB) assert actual == expected @@ -1982,7 +1979,7 @@ class ModelD(ModelC, Generic[D]): expected = { "a": str, - "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict, "d": Union[str, int, None], } @@ -1991,7 +1988,7 @@ class ModelD(ModelC, Generic[D]): expected = { "a": str, - "b": Annotated[ModelA[Dict[str, Any]], "foo"], + "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict, "d": Union[int, None], } @@ -2026,7 +2023,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None: from pydantic import ValidationError as ValidationErrorV2 class Foo(BaseModelV2): - x: List[int] = FieldV2( + x: list[int] = FieldV2( description="List of integers", min_length=10, max_length=15 ) diff --git a/libs/core/tests/unit_tests/tracers/test_langchain.py b/libs/core/tests/unit_tests/tracers/test_langchain.py index 1cff651916752..a8c897615ac3b 100644 --- a/libs/core/tests/unit_tests/tracers/test_langchain.py +++ b/libs/core/tests/unit_tests/tracers/test_langchain.py @@ -3,7 +3,7 @@ import unittest import unittest.mock import uuid -from typing import Any, Dict +from typing import Any from uuid import UUID import pytest @@ -102,7 +102,7 @@ class LangChainProjectNameTest(unittest.TestCase): class SetProperTracerProjectTestCase: def __init__( - self, test_name: str, envvars: Dict[str, str], expected_project_name: str + self, test_name: str, envvars: dict[str, str], expected_project_name: str ): self.test_name = test_name self.envvars = envvars diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py index cbdd4d125c7b1..451ab35bb678e 100644 --- a/libs/core/tests/unit_tests/tracers/test_memory_stream.py +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -1,8 +1,8 @@ import asyncio import math import time +from collections.abc import AsyncIterator from concurrent.futures import ThreadPoolExecutor -from typing import AsyncIterator from langchain_core.tracers.memory_stream import _MemoryStream diff --git a/libs/core/tests/unit_tests/utils/test_aiter.py b/libs/core/tests/unit_tests/utils/test_aiter.py index 3b035a89277ab..ca17283412d2c 100644 --- a/libs/core/tests/unit_tests/utils/test_aiter.py +++ b/libs/core/tests/unit_tests/utils/test_aiter.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, List +from collections.abc import AsyncIterator import pytest @@ -15,11 +15,11 @@ ], ) async def test_abatch_iterate( - input_size: int, input_iterable: List[str], expected_output: List[str] + input_size: int, input_iterable: list[str], expected_output: list[str] ) -> None: """Test batching function.""" - async def _to_async_iterable(iterable: List[str]) -> AsyncIterator[str]: + async def _to_async_iterable(iterable: list[str]) -> AsyncIterator[str]: for item in iterable: yield item diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index ec6ce542732bb..64d017b5e749e 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -1,20 +1,13 @@ # mypy: disable-error-code="annotation-unchecked" import sys +import typing +from collections.abc import Iterable, Mapping, MutableMapping, Sequence +from typing import Annotated as ExtensionsAnnotated from typing import ( Any, Callable, - Dict, - Iterable, - List, Literal, - Mapping, - MutableMapping, - MutableSet, Optional, - Sequence, - Set, - Tuple, - Type, Union, ) from typing import TypedDict as TypingTypedDict @@ -22,9 +15,6 @@ import pytest from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore from pydantic import Field as FieldV2Maybe # pydantic: ignore -from typing_extensions import ( - Annotated as ExtensionsAnnotated, -) from typing_extensions import ( TypedDict as ExtensionsTypedDict, ) @@ -47,7 +37,7 @@ @pytest.fixture() -def pydantic() -> Type[BaseModel]: +def pydantic() -> type[BaseModel]: class dummy_function(BaseModel): """dummy function""" @@ -102,7 +92,7 @@ class Schema(BaseModel): arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") class DummyFunction(BaseTool): - args_schema: Type[BaseModel] = Schema + args_schema: type[BaseModel] = Schema name: str = "dummy_function" description: str = "dummy function" @@ -127,7 +117,7 @@ class Schema(BaseModel): @pytest.fixture() -def dummy_pydantic() -> Type[BaseModel]: +def dummy_pydantic() -> type[BaseModel]: class dummy_function(BaseModel): """dummy function""" @@ -138,7 +128,7 @@ class dummy_function(BaseModel): @pytest.fixture() -def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]: +def dummy_pydantic_v2() -> type[BaseModelV2Maybe]: class dummy_function(BaseModelV2Maybe): """dummy function""" @@ -151,7 +141,7 @@ class dummy_function(BaseModelV2Maybe): @pytest.fixture() -def dummy_typing_typed_dict() -> Type: +def dummy_typing_typed_dict() -> type: class dummy_function(TypingTypedDict): """dummy function""" @@ -162,7 +152,7 @@ class dummy_function(TypingTypedDict): @pytest.fixture() -def dummy_typing_typed_dict_docstring() -> Type: +def dummy_typing_typed_dict_docstring() -> type: class dummy_function(TypingTypedDict): """dummy function @@ -178,7 +168,7 @@ class dummy_function(TypingTypedDict): @pytest.fixture() -def dummy_extensions_typed_dict() -> Type: +def dummy_extensions_typed_dict() -> type: class dummy_function(ExtensionsTypedDict): """dummy function""" @@ -189,7 +179,7 @@ class dummy_function(ExtensionsTypedDict): @pytest.fixture() -def dummy_extensions_typed_dict_docstring() -> Type: +def dummy_extensions_typed_dict_docstring() -> type: class dummy_function(ExtensionsTypedDict): """dummy function @@ -205,7 +195,7 @@ class dummy_function(ExtensionsTypedDict): @pytest.fixture() -def json_schema() -> Dict: +def json_schema() -> dict: return { "title": "dummy_function", "description": "dummy function", @@ -246,18 +236,18 @@ def dummy_function(cls, arg1: int, arg2: Literal["bar", "baz"]) -> None: def test_convert_to_openai_function( - pydantic: Type[BaseModel], + pydantic: type[BaseModel], function: Callable, dummy_structured_tool: StructuredTool, dummy_tool: BaseTool, - json_schema: Dict, + json_schema: dict, Annotated_function: Callable, - dummy_pydantic: Type[BaseModel], + dummy_pydantic: type[BaseModel], runnable: Runnable, - dummy_typing_typed_dict: Type, - dummy_typing_typed_dict_docstring: Type, - dummy_extensions_typed_dict: Type, - dummy_extensions_typed_dict_docstring: Type, + dummy_typing_typed_dict: type, + dummy_typing_typed_dict_docstring: type, + dummy_extensions_typed_dict: type, + dummy_extensions_typed_dict_docstring: type, ) -> None: expected = { "name": "dummy_function", @@ -436,7 +426,7 @@ def test_function_optional_param() -> None: def func5( a: Optional[str], b: str, - c: Optional[List[Optional[str]]], + c: Optional[list[Optional[str]]], ) -> None: """A test function""" pass @@ -544,7 +534,7 @@ def test__convert_typed_dict_to_openai_function( class SubTool(TypedDict): """Subtool docstring""" - args: Annotated[Dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore + args: Annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore class Tool(TypedDict): """Docstring @@ -555,18 +545,18 @@ class Tool(TypedDict): arg1: str arg2: Union[int, str, bool] - arg3: Optional[List[SubTool]] + arg3: Optional[list[SubTool]] arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722 arg5: Annotated[Optional[float], None] arg6: Annotated[ - Optional[Sequence[Mapping[str, Tuple[Iterable[Any], SubTool]]]], [] + Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], [] ] - arg7: Annotated[List[SubTool], ...] - arg8: Annotated[Tuple[SubTool], ...] + arg7: Annotated[list[SubTool], ...] + arg8: Annotated[tuple[SubTool], ...] arg9: Annotated[Sequence[SubTool], ...] arg10: Annotated[Iterable[SubTool], ...] - arg11: Annotated[Set[SubTool], ...] - arg12: Annotated[Dict[str, SubTool], ...] + arg11: Annotated[set[SubTool], ...] + arg12: Annotated[dict[str, SubTool], ...] arg13: Annotated[Mapping[str, SubTool], ...] arg14: Annotated[MutableMapping[str, SubTool], ...] arg15: Annotated[bool, False, "flag"] # noqa: F821 # type: ignore @@ -775,9 +765,9 @@ class Tool(TypedDict): @pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict]) -def test__convert_typed_dict_to_openai_function_fail(typed_dict: Type) -> None: +def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None: class Tool(typed_dict): - arg1: MutableSet # Pydantic 2 supports this, but pydantic v1 does not. + arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not. # Error should be raised since we're using v1 code path here with pytest.raises(TypeError): diff --git a/libs/core/tests/unit_tests/utils/test_iter.py b/libs/core/tests/unit_tests/utils/test_iter.py index d0866ea3fc091..2e8d547993aa1 100644 --- a/libs/core/tests/unit_tests/utils/test_iter.py +++ b/libs/core/tests/unit_tests/utils/test_iter.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from langchain_core.utils.iter import batch_iterate @@ -15,7 +13,7 @@ ], ) def test_batch_iterate( - input_size: int, input_iterable: List[str], expected_output: List[str] + input_size: int, input_iterable: list[str], expected_output: list[str] ) -> None: """Test batching function.""" assert list(batch_iterate(input_size, input_iterable)) == expected_output diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index cdb8d84fa94fe..8ba4beeab7674 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -1,6 +1,6 @@ """Test for some custom pydantic decorators.""" -from typing import Any, Dict, List, Optional +from typing import Any, Optional import pytest from pydantic import ConfigDict @@ -24,7 +24,7 @@ class Foo(BaseModel): y: int @pre_init - def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]: + def validator(cls, v: dict[str, Any]) -> dict[str, Any]: v["y"] = v["x"] + 1 return v @@ -45,7 +45,7 @@ class Foo(BaseModel): d: int = Field(default_factory=lambda: 3) @pre_init - def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]: + def validator(cls, v: dict[str, Any]) -> dict[str, Any]: assert v["a"] == 1 assert v["b"] is None assert v["c"] == 2 @@ -69,7 +69,7 @@ class Foo(BaseModel): ) @pre_init - def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]: + def validator(cls, v: dict[str, Any]) -> dict[str, Any]: v["z"] = v["x"] return v @@ -142,7 +142,7 @@ def test_with_field_metadata() -> None: from pydantic import Field as FieldV2 class Foo(BaseModelV2): - x: List[int] = FieldV2( + x: list[int] = FieldV2( description="List of integers", min_length=10, max_length=15 ) diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 5c435512c960e..2bd897db6063a 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -2,7 +2,7 @@ import re from contextlib import AbstractContextManager, nullcontext from copy import deepcopy -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Union from unittest.mock import patch import pytest @@ -32,9 +32,9 @@ ) def test_check_package_version( package: str, - check_kwargs: Dict[str, Optional[str]], + check_kwargs: dict[str, Optional[str]], actual_version: str, - expected: Optional[Tuple[Type[Exception], str]], + expected: Optional[tuple[type[Exception], str]], ) -> None: with patch("langchain_core.utils.utils.version", return_value=actual_version): if expected is None: diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py index 91983b1588749..a1797d777d1f0 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py +++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py @@ -7,7 +7,8 @@ from __future__ import annotations import uuid -from typing import Any, Iterable, Optional, Sequence +from collections.abc import Iterable, Sequence +from typing import Any, Optional from langchain_core.documents import Document from langchain_core.embeddings import Embeddings