Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mask API key for baidu qianfan #14281

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
Expand Down Expand Up @@ -88,8 +89,8 @@ class QianfanChatEndpoint(BaseChatModel):

client: Any

qianfan_ak: Optional[str] = None
qianfan_sk: Optional[str] = None
qianfan_ak: Optional[SecretStr] = None
qianfan_sk: Optional[SecretStr] = None

streaming: Optional[bool] = False
"""Whether to stream the results or not."""
Expand Down Expand Up @@ -118,19 +119,23 @@ class QianfanChatEndpoint(BaseChatModel):

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values["qianfan_ak"] = get_from_dict_or_env(
values,
"qianfan_ak",
"QIANFAN_AK",
values["qianfan_ak"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_ak",
"QIANFAN_AK",
)
)
values["qianfan_sk"] = get_from_dict_or_env(
values,
"qianfan_sk",
"QIANFAN_SK",
values["qianfan_sk"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_sk",
"QIANFAN_SK",
)
)
params = {
"ak": values["qianfan_ak"],
"sk": values["qianfan_sk"],
"ak": values["qianfan_ak"].get_secret_value(),
"sk": values["qianfan_sk"].get_secret_value(),
"model": values["model"],
"stream": values["streaming"],
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import cast

from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain.chat_models.baidu_qianfan_endpoint import (
QianfanChatEndpoint,
)


def test_qianfan_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("QIANFAN_AK", "test-api-key")
monkeypatch.setenv("QIANFAN_SK", "test-secret-key")

chat = QianfanChatEndpoint()
print(chat.qianfan_ak, end="")
captured = capsys.readouterr()
assert captured.out == "**********"

print(chat.qianfan_sk, end="")
captured = capsys.readouterr()
assert captured.out == "**********"


def test_qianfan_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
chat = QianfanChatEndpoint(
qianfan_ak="test-api-key",
qianfan_sk="test-secret-key",
)
print(chat.qianfan_ak, end="")
captured = capsys.readouterr()
assert captured.out == "**********"

print(chat.qianfan_sk, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_uses_actual_secret_value_from_secret_str() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
chat = QianfanChatEndpoint(
qianfan_ak="test-api-key",
qianfan_sk="test-secret-key",
)
assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key"
assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"
Loading