Commit

improve tests
ZanSara committed Aug 31, 2023
1 parent 725fabe commit c1a7696
Showing 2 changed files with 19 additions and 17 deletions.
23 changes: 6 additions & 17 deletions test/preview/components/generators/openai/test_openai_helpers.py
@@ -1,31 +1,20 @@
-from unittest.mock import Mock
-
 import pytest
 
 from haystack.preview.components.generators.openai._helpers import enforce_token_limit
 
 
 @pytest.mark.unit
-def test_enforce_token_limit_above_limit(caplog):
-    tokenizer = Mock()
-    tokenizer.encode = lambda text: text.split()
-    tokenizer.decode = lambda tokens: " ".join(tokens)
-
-    assert enforce_token_limit("This is a test prompt.", tokenizer=tokenizer, max_tokens_limit=3) == "This is a"
+def test_enforce_token_limit_above_limit(caplog, mock_tokenizer):
+    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3)
+    assert prompt == "This is a"
     assert caplog.records[0].message == (
         "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token "
         "limit. Reduce the length of the prompt to prevent it from being cut off."
     )
 
 
 @pytest.mark.unit
-def test_enforce_token_limit_below_limit(caplog):
-    tokenizer = Mock()
-    tokenizer.encode = lambda text: text.split()
-    tokenizer.decode = lambda tokens: " ".join(tokens)
-
-    assert (
-        enforce_token_limit("This is a test prompt.", tokenizer=tokenizer, max_tokens_limit=1000)
-        == "This is a test prompt."
-    )
+def test_enforce_token_limit_below_limit(caplog, mock_tokenizer):
+    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100)
+    assert prompt == "This is a test prompt."
     assert not caplog.records
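For context, the assertions above pin down the helper's contract: encode the prompt, cut the token sequence at max_tokens_limit, decode it back, and emit a warning when anything was cut. A minimal sketch of such a helper, reconstructed from these tests alone (the actual implementation in haystack.preview's _helpers module may differ):

    # Hypothetical reconstruction inferred from the tests above; not the real source.
    import logging

    logger = logging.getLogger(__name__)


    def enforce_token_limit(prompt: str, tokenizer, max_tokens_limit: int) -> str:
        """Truncate prompt to at most max_tokens_limit tokens, warning if truncated."""
        tokens = tokenizer.encode(prompt)
        if len(tokens) <= max_tokens_limit:
            return prompt
        logger.warning(
            "The prompt has been truncated from %s tokens to %s tokens to fit within the max token "
            "limit. Reduce the length of the prompt to prevent it from being cut off.",
            len(tokens),
            max_tokens_limit,
        )
        return tokenizer.decode(tokens[:max_tokens_limit])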
13 changes: 13 additions & 0 deletions test/preview/conftest.py
@@ -0,0 +1,13 @@
+from unittest.mock import Mock
+import pytest
+
+
+@pytest.fixture()
+def mock_tokenizer():
+    """
+    Tokenizes the string by splitting on spaces.
+    """
+    tokenizer = Mock()
+    tokenizer.encode = lambda text: text.split()
+    tokenizer.decode = lambda tokens: " ".join(tokens)
+    return tokenizer
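Because the fixture now lives in test/preview/conftest.py, pytest discovers it automatically for every test module under test/preview/; tests request it by parameter name, with no import needed. A hypothetical extra test illustrating the fixture's round-trip behavior:

    # Hypothetical usage: mock_tokenizer is injected from conftest.py, not imported.
    def test_mock_tokenizer_roundtrip(mock_tokenizer):
        tokens = mock_tokenizer.encode("hello world")
        assert tokens == ["hello", "world"]
        assert mock_tokenizer.decode(tokens) == "hello world"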
