Skip to content

Commit

Permalink
partner[ai21]: masking of the api key for ai21 models (#20257)
Browse files Browse the repository at this point in the history
**Description:** Masking of the API key for AI21 models
**Issue:** Fixes #12165 for AI21
**Dependencies:** None

Note: This fix came in originally through #12418 but was possibly missed
in the refactor to the AI21 partner package


---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
2 people authored and hinthornw committed Apr 26, 2024
1 parent a925710 commit 21a4969
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 1 deletion.
38 changes: 37 additions & 1 deletion libs/partners/ai21/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test chat model integration."""
from typing import List, Optional
from typing import List, Optional, cast
from unittest.mock import Mock, call

import pytest
Expand All @@ -14,6 +14,8 @@
from langchain_core.messages import (
ChatMessage as LangChainChatMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_ai21.chat_models import (
ChatAI21,
Expand Down Expand Up @@ -236,3 +238,37 @@ def test_generate(mock_client_with_chat: Mock) -> None:
),
]
)


def test_api_key_is_secret_string() -> None:
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key")
assert isinstance(llm.api_key, SecretStr)


def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
llm = ChatAI21(model="j2-ultra")
print(llm.api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key")
print(llm.api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key")
assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"
38 changes: 38 additions & 0 deletions libs/partners/ai21/tests/unit_tests/test_llms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Test AI21 Chat API wrapper."""

from typing import cast
from unittest.mock import Mock, call

import pytest
from ai21 import MissingApiKeyError
from ai21.models import (
Penalty,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_ai21 import AI21LLM
from tests.unit_tests.conftest import (
Expand Down Expand Up @@ -106,3 +110,37 @@ def test_generate(mock_client_with_completion: Mock) -> None:
),
]
)


def test_api_key_is_secret_string() -> None:
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key")
assert isinstance(llm.api_key, SecretStr)


def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
llm = AI21LLM(model="j2-ultra")
print(llm.api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key")
print(llm.api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key")
assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"

0 comments on commit 21a4969

Please sign in to comment.