Commit
Improve token limit tests for OpenAI PromptNode layer (#5351)
vblagoje authored Jul 17, 2023
1 parent 35b2c99 commit adfabdd
Showing 1 changed file with 75 additions and 20 deletions.
95 changes: 75 additions & 20 deletions test/prompt/invocation_layer/test_openai.py
@@ -6,28 +6,36 @@
 from haystack.nodes.prompt.invocation_layer import OpenAIInvocationLayer


+@pytest.fixture
+def load_openai_tokenizer():
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer") as mock_load_openai_tokenizer:
+        yield mock_load_openai_tokenizer
+
+
+@pytest.fixture()
+def mock_open_ai_request():
+    with patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request") as mock_openai_request:
+        yield mock_openai_request
+
+
 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
-def test_default_api_base(mock_request):
-    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
-        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")
+def test_default_api_base(mock_open_ai_request, load_openai_tokenizer):
+    invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")
     assert invocation_layer.api_base == "https://api.openai.com/v1"
     assert invocation_layer.url == "https://api.openai.com/v1/completions"

     invocation_layer.invoke(prompt="dummy_prompt")
-    assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions"
+    assert mock_open_ai_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions"


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
-def test_custom_api_base(mock_request):
-    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
-        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
+def test_custom_api_base(mock_open_ai_request, load_openai_tokenizer):
+    invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
     assert invocation_layer.api_base == "https://fake_api_base.com"
     assert invocation_layer.url == "https://fake_api_base.com/completions"

     invocation_layer.invoke(prompt="dummy_prompt")
-    assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions"
+    assert mock_open_ai_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions"


 @pytest.mark.unit
@@ -42,26 +50,73 @@ def test_openai_token_limit_warning(mock_openai_tokenizer, caplog):


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
-def test_no_openai_organization(mock_request):
-    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
-        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")
+@pytest.mark.parametrize(
+    "model_name,max_tokens_limit",
+    [
+        ("text-davinci-003", 4097),
+        ("gpt-3.5-turbo", 4096),
+        ("gpt-3.5-turbo-16k", 16384),
+        ("gpt-4-32k", 32768),
+        ("gpt-4", 8192),
+    ],
+)
+def test_openai_token_limit_warning_not_triggered(caplog, mock_openai_tokenizer, model_name, max_tokens_limit):
+    layer = OpenAIInvocationLayer(
+        model_name_or_path=model_name, api_key="fake_api_key", api_base="https://fake_api_base.com", max_length=256
+    )
+
+    assert layer.max_tokens_limit == max_tokens_limit
+
+    # the warning is not triggered because max_length is 256, our prompt is 11 tokens, and we have big context window
+    _ = layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.")
+    assert not caplog.text
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize(
+    "model_name,max_tokens_limit",
+    [
+        ("text-davinci-003", 4097),
+        ("gpt-3.5-turbo", 4096),
+        ("gpt-3.5-turbo-16k", 16384),
+        ("gpt-4-32k", 32768),
+        ("gpt-4", 8192),
+    ],
+)
+def test_openai_token_limit_warning_is_triggered(caplog, mock_openai_tokenizer, model_name, max_tokens_limit):
+    layer = OpenAIInvocationLayer(
+        model_name_or_path=model_name,
+        api_key="fake_api_key",
+        api_base="https://fake_api_base.com",
+        max_length=int(max_tokens_limit) - 1,
+    )
+
+    assert layer.max_tokens_limit == max_tokens_limit
+
+    # the warning is triggered because max_length is one token smaller than context window and our prompt has 11 tokens
+    _ = layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.")
+
+    # since we are truncating the prompt of 11 tokens, we should see a warning that only 1 token is left
+    assert "The prompt has been truncated from 11 tokens to 1 tokens" in caplog.text
+
+
+@pytest.mark.unit
+def test_no_openai_organization(mock_open_ai_request, load_openai_tokenizer):
+    invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key")

     assert invocation_layer.openai_organization is None
     assert "OpenAI-Organization" not in invocation_layer.headers

     invocation_layer.invoke(prompt="dummy_prompt")
-    assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"]
+    assert "OpenAI-Organization" not in mock_open_ai_request.call_args.kwargs["headers"]


 @pytest.mark.unit
-@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
-def test_openai_organization(mock_request):
-    with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
-        invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", openai_organization="fake_organization")
+def test_openai_organization(mock_open_ai_request, load_openai_tokenizer):
+    invocation_layer = OpenAIInvocationLayer(api_key="fake_api_key", openai_organization="fake_organization")

     assert invocation_layer.openai_organization == "fake_organization"
     assert invocation_layer.headers["OpenAI-Organization"] == "fake_organization"

     invocation_layer.invoke(prompt="dummy_prompt")
-    assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization"
+    assert mock_open_ai_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization"
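The two parametrized tests added above rest on a simple token budget: a prompt only has to be truncated when its token count exceeds the model's context window (max_tokens_limit) minus max_length, the room reserved for the completion. As a short, self-contained illustration of that arithmetic for the gpt-3.5-turbo row, using only numbers that appear in the tests (plain Python, no Haystack imports):

# Token-budget arithmetic behind the "not triggered" / "is triggered" tests above.
max_tokens_limit = 4096   # context window asserted for gpt-3.5-turbo in the tests
prompt_tokens = 11        # prompt length the tests' mock tokenizer is assumed to report

# "not triggered" case: max_length=256 leaves 3840 tokens for the prompt,
# so an 11-token prompt fits and no truncation warning is expected.
assert prompt_tokens <= max_tokens_limit - 256

# "is triggered" case: max_length=max_tokens_limit-1 leaves room for exactly one
# prompt token, so the 11-token prompt must be cut down to 1 token, which is what
# the asserted warning message reports.
room_for_prompt = max_tokens_limit - (max_tokens_limit - 1)
assert room_for_prompt == 1 and prompt_tokens > room_for_prompt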

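The new tests also depend on a mock_openai_tokenizer fixture that is defined elsewhere in test_openai.py and therefore does not appear in this diff. As a rough orientation only, and under two assumptions not confirmed by the diff (that the invocation layer obtains its tokenizer via load_openai_tokenizer and uses tiktoken-style encode/decode on it, and that the test prompt is meant to count as exactly 11 tokens), such a fixture might be sketched like this:

# Hypothetical sketch only; the real fixture lives elsewhere in test_openai.py.
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def mock_openai_tokenizer():
    tokenizer = MagicMock()
    # Pretend every prompt encodes to exactly 11 tokens, matching the warning
    # message asserted in test_openai_token_limit_warning_is_triggered.
    tokenizer.encode.side_effect = lambda text: list(range(11))
    tokenizer.decode.side_effect = lambda tokens: " ".join(str(t) for t in tokens)
    with patch(
        "haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer",
        return_value=tokenizer,
    ):
        yield tokenizer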