From c1a7696f801181178d51153aa3b623ceb1fc1b96 Mon Sep 17 00:00:00 2001
From: ZanSara
Date: Thu, 31 Aug 2023 12:16:43 +0200
Subject: [PATCH] improve tests

---
 .../generators/openai/test_openai_helpers.py  | 23 +++++--------------
 test/preview/conftest.py                      | 13 +++++++++++
 2 files changed, 19 insertions(+), 17 deletions(-)
 create mode 100644 test/preview/conftest.py

diff --git a/test/preview/components/generators/openai/test_openai_helpers.py b/test/preview/components/generators/openai/test_openai_helpers.py
index a0003fa323..23a66117d1 100644
--- a/test/preview/components/generators/openai/test_openai_helpers.py
+++ b/test/preview/components/generators/openai/test_openai_helpers.py
@@ -1,17 +1,12 @@
-from unittest.mock import Mock
-
 import pytest
 
 from haystack.preview.components.generators.openai._helpers import enforce_token_limit
 
 
 @pytest.mark.unit
-def test_enforce_token_limit_above_limit(caplog):
-    tokenizer = Mock()
-    tokenizer.encode = lambda text: text.split()
-    tokenizer.decode = lambda tokens: " ".join(tokens)
-
-    assert enforce_token_limit("This is a test prompt.", tokenizer=tokenizer, max_tokens_limit=3) == "This is a"
+def test_enforce_token_limit_above_limit(caplog, mock_tokenizer):
+    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3)
+    assert prompt == "This is a"
     assert caplog.records[0].message == (
         "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token "
         "limit. Reduce the length of the prompt to prevent it from being cut off."
@@ -19,13 +14,7 @@ def test_enforce_token_limit_above_limit(caplog):
 
 
 @pytest.mark.unit
-def test_enforce_token_limit_below_limit(caplog):
-    tokenizer = Mock()
-    tokenizer.encode = lambda text: text.split()
-    tokenizer.decode = lambda tokens: " ".join(tokens)
-
-    assert (
-        enforce_token_limit("This is a test prompt.", tokenizer=tokenizer, max_tokens_limit=1000)
-        == "This is a test prompt."
-    )
+def test_enforce_token_limit_below_limit(caplog, mock_tokenizer):
+    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100)
+    assert prompt == "This is a test prompt."
     assert not caplog.records
diff --git a/test/preview/conftest.py b/test/preview/conftest.py
new file mode 100644
index 0000000000..b8abfa41a6
--- /dev/null
+++ b/test/preview/conftest.py
@@ -0,0 +1,13 @@
+from unittest.mock import Mock
+import pytest
+
+
+@pytest.fixture()
+def mock_tokenizer():
+    """
+    Tokenizes the string by splitting on spaces.
+    """
+    tokenizer = Mock()
+    tokenizer.encode = lambda text: text.split()
+    tokenizer.decode = lambda tokens: " ".join(tokens)
+    return tokenizer