core,integrations[minor]: Dont error on fields in model_kwargs #27110

Merged
merged 11 commits on Oct 4, 2024
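A minimal sketch of the user-facing change, mirroring the updated tests below (assumes the langchain-community OpenAI wrapper with the openai package installed; api_key="foo" is a dummy value):

from langchain_community.llms import OpenAI

# Previously, passing a declared field inside model_kwargs raised a ValueError.
# With this change, the value is moved onto the field itself and a warning is emitted.
llm = OpenAI(model_kwargs={"model_name": "foo"}, api_key="foo")
assert llm.model_name == "foo"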
7 changes: 2 additions & 5 deletions libs/community/langchain_community/chat_models/snowflake.py
@@ -17,7 +17,7 @@
get_pydantic_field_names,
pre_init,
)
-from langchain_core.utils.utils import build_extra_kwargs
+from langchain_core.utils.utils import _build_model_kwargs
from pydantic import Field, SecretStr, model_validator

SUPPORTED_ROLES: List[str] = [
@@ -131,10 +131,7 @@ class ChatSnowflakeCortex(BaseChatModel):
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
-extra = values.get("model_kwargs", {})
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@pre_init
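The same validator rewrite repeats across the integration files below. A self-contained sketch of the pattern (MyModel is a hypothetical class; assumes Pydantic v2 and a langchain-core version that ships the private _build_model_kwargs helper):

from typing import Any, Dict

from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.utils import _build_model_kwargs
from pydantic import BaseModel, Field, model_validator


class MyModel(BaseModel):
    temperature: float = 0.7
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        # Collect unknown params into model_kwargs and lift declared fields out of it.
        all_required_field_names = get_pydantic_field_names(cls)
        return _build_model_kwargs(values, all_required_field_names)


# Unknown params are collected into model_kwargs (with a warning), while declared
# fields passed inside model_kwargs are moved back onto the model instead of erroring.
m = MyModel(top_p=0.9, model_kwargs={"temperature": 0.1})
assert m.model_kwargs == {"top_p": 0.9}
assert m.temperature == 0.1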
7 changes: 2 additions & 5 deletions libs/community/langchain_community/llms/anthropic.py
@@ -26,7 +26,7 @@
get_pydantic_field_names,
pre_init,
)
-from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
+from langchain_core.utils.utils import _build_model_kwargs, convert_to_secret_str
from pydantic import ConfigDict, Field, SecretStr, model_validator


@@ -69,11 +69,8 @@ class _AnthropicCommon(BaseLanguageModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict) -> Any:
-extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@pre_init
7 changes: 2 additions & 5 deletions libs/community/langchain_community/llms/llamacpp.py
@@ -8,7 +8,7 @@
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import get_pydantic_field_names, pre_init
-from langchain_core.utils.utils import build_extra_kwargs
+from langchain_core.utils.utils import _build_model_kwargs
from pydantic import Field, model_validator

logger = logging.getLogger(__name__)
@@ -199,10 +199,7 @@ def validate_environment(cls, values: Dict) -> Dict:
def build_model_kwargs(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
-extra = values.get("model_kwargs", {})
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@property
7 changes: 2 additions & 5 deletions libs/community/langchain_community/llms/openai.py
@@ -34,7 +34,7 @@
pre_init,
)
from langchain_core.utils.pydantic import get_fields
-from langchain_core.utils.utils import build_extra_kwargs
+from langchain_core.utils.utils import _build_model_kwargs
from pydantic import ConfigDict, Field, model_validator

from langchain_community.utils.openai import is_openai_v1
@@ -268,10 +268,7 @@ def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
-extra = values.get("model_kwargs", {})
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@pre_init
9 changes: 6 additions & 3 deletions libs/community/tests/unit_tests/chat_models/test_anthropic.py
@@ -33,9 +33,12 @@ def test_anthropic_model_kwargs() -> None:


@pytest.mark.requires("anthropic")
-def test_anthropic_invalid_model_kwargs() -> None:
-with pytest.raises(ValueError):
-ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
+def test_anthropic_fields_in_model_kwargs() -> None:
+"""Test that for backwards compatibility fields can be passed in as model_kwargs."""
+llm = ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
+assert llm.max_tokens_to_sample == 5
+llm = ChatAnthropic(model_kwargs={"max_tokens": 5})
+assert llm.max_tokens_to_sample == 5


@pytest.mark.requires("anthropic")
13 changes: 6 additions & 7 deletions libs/community/tests/unit_tests/llms/test_openai.py
@@ -26,13 +26,12 @@ def test_openai_model_kwargs() -> None:


@pytest.mark.requires("openai")
-def test_openai_invalid_model_kwargs() -> None:
-with pytest.raises(ValueError):
-OpenAI(model_kwargs={"model_name": "foo"})
-
-# Test that "model" cannot be specified in kwargs
-with pytest.raises(ValueError):
-OpenAI(model_kwargs={"model": "gpt-3.5-turbo-instruct"})
+def test_openai_fields_model_kwargs() -> None:
+"""Test that for backwards compatibility fields can be passed in as model_kwargs."""
+llm = OpenAI(model_kwargs={"model_name": "foo"}, api_key="foo")
+assert llm.model_name == "foo"
+llm = OpenAI(model_kwargs={"model": "foo"}, api_key="foo")
+assert llm.model_name == "foo"


@pytest.mark.requires("openai")
2 changes: 1 addition & 1 deletion libs/core/langchain_core/utils/__init__.py
@@ -32,6 +32,7 @@
)

__all__ = [
"build_extra_kwargs",
"StrictFormatter",
"check_package_version",
"convert_to_secret_str",
@@ -46,7 +47,6 @@
"raise_for_status_with_text",
"xor_args",
"try_load_from_hub",
"build_extra_kwargs",
"image",
"get_from_env",
"get_from_dict_or_env",
45 changes: 45 additions & 0 deletions libs/core/langchain_core/utils/utils.py
@@ -210,6 +210,51 @@ def get_pydantic_field_names(pydantic_cls: Any) -> set[str]:
return all_required_field_names


def _build_model_kwargs(
values: dict[str, Any],
all_required_field_names: set[str],
) -> dict[str, Any]:
"""Build "model_kwargs" param from Pydanitc constructor values.

Args:
values: All init args passed in by user.
all_required_field_names: All required field names for the pydantic class.

Returns:
Dict[str, Any]: The updated values, with extra params collected under "model_kwargs".

Raises:
ValueError: If a field is supplied both as a top-level value and inside model_kwargs.
"""
extra_kwargs = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended.""",
stacklevel=7,
)
extra_kwargs[field_name] = values.pop(field_name)

invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
if invalid_model_kwargs:
warnings.warn(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter.",
stacklevel=7,
)
for k in invalid_model_kwargs:
values[k] = extra_kwargs.pop(k)

values["model_kwargs"] = extra_kwargs
return values


# DON'T USE! Kept for backwards-compatibility but should never have been public.
def build_extra_kwargs(
extra_kwargs: dict[str, Any],
values: dict[str, Any],
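To make the logic above concrete, a small sketch of what the new helper does to a sample values dict (private API; the field-name set here is illustrative, not taken from any real model class):

from langchain_core.utils.utils import _build_model_kwargs

values = {"temperature": 0.2, "best_of": 3, "model_kwargs": {"max_tokens": 16}}
field_names = {"temperature", "max_tokens", "model_kwargs"}

result = _build_model_kwargs(values, field_names)
# "best_of" is not a declared field -> moved into model_kwargs (with a warning).
# "max_tokens" is a declared field -> lifted out of model_kwargs (with a warning),
# instead of raising a ValueError as build_extra_kwargs did.
assert result == {
    "temperature": 0.2,
    "max_tokens": 16,
    "model_kwargs": {"best_of": 3},
}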
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/utils/test_imports.py
@@ -17,8 +17,8 @@
"raise_for_status_with_text",
"xor_args",
"try_load_from_hub",
"build_extra_kwargs",
"image",
"build_extra_kwargs",
"get_from_dict_or_env",
"get_from_env",
"stringify_dict",
7 changes: 2 additions & 5 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -56,13 +56,13 @@
)
from langchain_core.tools import BaseTool
from langchain_core.utils import (
-build_extra_kwargs,
from_env,
get_pydantic_field_names,
secret_from_env,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.utils import _build_model_kwargs
from pydantic import (
BaseModel,
ConfigDict,
@@ -646,11 +646,8 @@ def _get_ls_params(
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict) -> Any:
-extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@model_validator(mode="after")
7 changes: 2 additions & 5 deletions libs/partners/anthropic/langchain_anthropic/llms.py
@@ -25,7 +25,7 @@
get_pydantic_field_names,
)
from langchain_core.utils.utils import (
-build_extra_kwargs,
+_build_model_kwargs,
from_env,
secret_from_env,
)
@@ -88,11 +88,8 @@ class _AnthropicCommon(BaseLanguageModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict) -> Any:
-extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@model_validator(mode="after")
9 changes: 6 additions & 3 deletions libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -61,9 +61,12 @@ def test_anthropic_model_kwargs() -> None:


@pytest.mark.requires("anthropic")
-def test_anthropic_invalid_model_kwargs() -> None:
-with pytest.raises(ValueError):
-ChatAnthropic(model="foo", model_kwargs={"max_tokens_to_sample": 5}) # type: ignore[call-arg]
+def test_anthropic_fields_in_model_kwargs() -> None:
+"""Test that for backwards compatibility fields can be passed in as model_kwargs."""
+llm = ChatAnthropic(model="foo", model_kwargs={"max_tokens_to_sample": 5}) # type: ignore[call-arg]
+assert llm.max_tokens == 5
+llm = ChatAnthropic(model="foo", model_kwargs={"max_tokens": 5}) # type: ignore[call-arg]
+assert llm.max_tokens == 5


@pytest.mark.requires("anthropic")
7 changes: 2 additions & 5 deletions libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -79,7 +79,7 @@
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
-from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
+from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
@@ -366,10 +366,7 @@ def is_lc_serializable(cls) -> bool:
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
-extra = values.get("model_kwargs", {})
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@model_validator(mode="after")
7 changes: 2 additions & 5 deletions libs/partners/fireworks/langchain_fireworks/llms.py
@@ -11,7 +11,7 @@
)
from langchain_core.language_models.llms import LLM
from langchain_core.utils import get_pydantic_field_names
-from langchain_core.utils.utils import build_extra_kwargs, secret_from_env
+from langchain_core.utils.utils import _build_model_kwargs, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator

from langchain_fireworks.version import __version__
@@ -93,10 +93,7 @@ class Fireworks(LLM):
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
-extra = values.get("model_kwargs", {})
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@property
7 changes: 2 additions & 5 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -90,7 +90,7 @@
TypeBaseModel,
is_basemodel_subclass,
)
-from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
+from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

@@ -477,10 +477,7 @@ class BaseChatOpenAI(BaseChatModel):
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
-extra = values.get("model_kwargs", {})
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@model_validator(mode="after")
7 changes: 2 additions & 5 deletions libs/partners/openai/langchain_openai/llms/base.py
@@ -27,7 +27,7 @@
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import get_pydantic_field_names
-from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
+from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

@@ -160,10 +160,7 @@ class BaseOpenAI(BaseLLM):
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
-extra = values.get("model_kwargs", {})
-values["model_kwargs"] = build_extra_kwargs(
-extra, values, all_required_field_names
-)
+values = _build_model_kwargs(values, all_required_field_names)
return values

@model_validator(mode="after")
9 changes: 6 additions & 3 deletions libs/partners/openai/tests/unit_tests/llms/test_base.py
@@ -30,9 +30,12 @@ def test_openai_model_kwargs() -> None:
assert llm.model_kwargs == {"foo": "bar"}


-def test_openai_invalid_model_kwargs() -> None:
-with pytest.raises(ValueError):
-OpenAI(model_kwargs={"model_name": "foo"})
+def test_openai_fields_in_model_kwargs() -> None:
+"""Test that for backwards compatibility fields can be passed in as model_kwargs."""
+llm = OpenAI(model_kwargs={"model_name": "foo"})
+assert llm.model_name == "foo"
+llm = OpenAI(model_kwargs={"model": "foo"})
+assert llm.model_name == "foo"


def test_openai_incorrect_field() -> None: