Skip to content

Commit

Permalink
fix: a small bug in StopWordsCriteria (#5316)
Browse files Browse the repository at this point in the history
  • Loading branch information
faaany authored Jul 13, 2023
1 parent 237d67d commit 9891bfe
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
12 changes: 10 additions & 2 deletions haystack/nodes/prompt/invocation_layer/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
class StopWordsCriteria(StoppingCriteria):
"""
Stops text generation if any one of the stop words is generated.
Note: When a stop word is encountered, the generation of new text is stopped.
However, if the stop word is in the prompt itself, it can stop generating new text
prematurely after the first token. This is particularly important for LLMs designed
for dialogue generation. For these models (for example, mosaicml/mpt-7b-chat),
the output includes both the new text and the original prompt. Therefore, it's important
to make sure your prompt has no stop words.
"""

def __init__(
Expand All @@ -37,10 +44,11 @@ def __init__(
device: Union[str, torch.device] = "cpu",
):
super().__init__()
self.stop_words = tokenizer(stop_words, add_special_tokens=False, return_tensors="pt").to(device)
encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
self.stop_words = encoded_stop_words.input_ids.to(device)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
stop_result = torch.isin(self.stop_words, input_ids[-1])
return any(all(stop_word) for stop_word in stop_result)


Expand Down
38 changes: 22 additions & 16 deletions test/prompt/invocation_layer/test_hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from unittest.mock import MagicMock, patch, Mock

import pytest
Expand Down Expand Up @@ -461,18 +462,19 @@ def test_stop_words_criteria_set(mock_pipeline, mock_get_task):


@pytest.mark.integration
@pytest.mark.parametrize("stop_words", [["good"], ["hello", "good"], ["hello", "good", "health"]])
def test_stop_words_single_token(stop_words):
@pytest.mark.parametrize("stop_words", [["good"], ["hello", "good"]])
def test_stop_words_single_token(stop_words: List[str]):
"""
Test that the stop words criteria is applied and works with single-token stop words
"""

# simple test with words not broken down into multiple tokens
default_model = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(default_model)
# each word is broken down into a single token
tokens = tokenizer.tokenize("good health wish")
assert len(tokens) == 3
for stop_word in stop_words:
# confirm we are dealing with single-token words
tokens = tokenizer.tokenize(stop_word)
assert len(tokens) == 1

layer = HFLocalInvocationLayer(model_name_or_path=default_model)
result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=stop_words)
Expand All @@ -483,21 +485,22 @@ def test_stop_words_single_token(stop_words):


@pytest.mark.integration
def test_stop_words_multiple_token():
@pytest.mark.parametrize(
"stop_words", [["unambiguously"], ["unambiguously", "unrelated"], ["unambiguously", "hearted"]]
)
def test_stop_words_multiple_token(stop_words: List[str]):
"""
Test that stop words criteria is used and that it works for multi-token words
"""
# complex test with words broken down into multiple tokens
default_model = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(default_model)
# single word unambiguously is broken down into 3 tokens
tokens = tokenizer.tokenize("unambiguously")
assert len(tokens) == 3
for stop_word in stop_words:
# confirm we are dealing with multi-token words
tokens = tokenizer.tokenize(stop_word)
assert len(tokens) > 1

layer = HFLocalInvocationLayer(model_name_or_path=default_model)
result = layer.invoke(
prompt="Generate a sentence `I wish you unambiguously good health`", stop_words=["unambiguously"]
)
result = layer.invoke(prompt="Generate a sentence `I wish you unambiguously good health`", stop_words=stop_words)
# the stop word is correctly stopped on and removed from the output
assert len(result) > 0
assert result[0].startswith("I wish you")
Expand All @@ -507,10 +510,13 @@ def test_stop_words_multiple_token():


@pytest.mark.integration
def test_stop_words_not_being_found():
# simple test with words not broken down into multiple tokens
@pytest.mark.parametrize("stop_words", [["Berlin"], ["Berlin", "Brandenburg"], ["Berlin", "Brandenburg", "Germany"]])
def test_stop_words_not_being_found(stop_words: List[str]):
"""
Test that generation proceeds normally when the stop words do not appear in the generated text
"""
layer = HFLocalInvocationLayer()
result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=["Berlin"])
result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=stop_words)
assert len(result) > 0
for word in "I wish you a good health".split():
assert word in result[0]
Expand Down

0 comments on commit 9891bfe

Please sign in to comment.