diff --git a/libs/partners/ai21/tests/unit_tests/test_chat_models.py b/libs/partners/ai21/tests/unit_tests/test_chat_models.py index f95c73db90442..499d21bd6d9ef 100644 --- a/libs/partners/ai21/tests/unit_tests/test_chat_models.py +++ b/libs/partners/ai21/tests/unit_tests/test_chat_models.py @@ -1,5 +1,5 @@ """Test chat model integration.""" -from typing import List, Optional +from typing import List, Optional, cast from unittest.mock import Mock, call import pytest @@ -14,6 +14,8 @@ from langchain_core.messages import ( ChatMessage as LangChainChatMessage, ) +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain_ai21.chat_models import ( ChatAI21, @@ -236,3 +238,37 @@ def test_generate(mock_client_with_chat: Mock) -> None: ), ] ) + + +def test_api_key_is_secret_string() -> None: + llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") + assert isinstance(llm.api_key, SecretStr) + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("AI21_API_KEY", "secret-api-key") + llm = ChatAI21(model="j2-ultra") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secretstr() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") + assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key" diff --git a/libs/partners/ai21/tests/unit_tests/test_llms.py b/libs/partners/ai21/tests/unit_tests/test_llms.py index 2c47ec234acc8..1854df0e77bd2 100644 --- a/libs/partners/ai21/tests/unit_tests/test_llms.py +++ b/libs/partners/ai21/tests/unit_tests/test_llms.py @@ -1,4 +1,6 @@ """Test AI21 Chat API wrapper.""" + +from typing import cast from unittest.mock import Mock, call import pytest @@ -6,6 +8,8 @@ from ai21.models import ( Penalty, ) +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain_ai21 import AI21LLM from tests.unit_tests.conftest import ( @@ -106,3 +110,37 @@ def test_generate(mock_client_with_completion: Mock) -> None: ), ] ) + + +def test_api_key_is_secret_string() -> None: + llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") + assert isinstance(llm.api_key, SecretStr) + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("AI21_API_KEY", "secret-api-key") + llm = AI21LLM(model="j2-ultra") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secretstr() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") + assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"