diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index a8100bc89c..1705453e7c 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -28,6 +28,13 @@ class StopWordsCriteria(StoppingCriteria):
     """
     Stops text generation if any one of the stop words is generated.
+
+    Note: When a stop word is encountered, the generation of new text is stopped.
+    However, if the stop word is in the prompt itself, it can stop generating new text
+    prematurely after the first token. This is particularly important for LLMs designed
+    for dialogue generation, for example mosaicml/mpt-7b-chat, whose output includes
+    both the new text and the original prompt. Therefore, make sure your prompt
+    contains no stop words.
     """
 
     def __init__(
@@ -37,10 +44,11 @@ def __init__(
         device: Union[str, torch.device] = "cpu",
     ):
         super().__init__()
-        self.stop_words = tokenizer(stop_words, add_special_tokens=False, return_tensors="pt").to(device)
+        encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+        self.stop_words = encoded_stop_words.input_ids.to(device)
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
+        stop_result = torch.isin(self.stop_words, input_ids[-1])
         return any(all(stop_word) for stop_word in stop_result)
diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py
index f5d6a6b078..2ff1f52e19 100644
--- a/test/prompt/invocation_layer/test_hugging_face.py
+++ b/test/prompt/invocation_layer/test_hugging_face.py
@@ -1,3 +1,4 @@
+from typing import List
 from unittest.mock import MagicMock, patch, Mock
 
 import pytest
@@ -461,8 +462,8 @@ def test_stop_words_criteria_set(mock_pipeline, mock_get_task):
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("stop_words", [["good"], ["hello", "good"], ["hello", "good", "health"]])
-def test_stop_words_single_token(stop_words):
+@pytest.mark.parametrize("stop_words", [["good"], ["hello", "good"]])
+def test_stop_words_single_token(stop_words: List[str]):
     """
     Test that stop words criteria is used and that it works with single token stop words
     """
@@ -470,9 +471,10 @@ def test_stop_words_single_token(stop_words):
     # simple test with words not broken down into multiple tokens
     default_model = "google/flan-t5-base"
     tokenizer = AutoTokenizer.from_pretrained(default_model)
-    # each word is broken down into a single token
-    tokens = tokenizer.tokenize("good health wish")
-    assert len(tokens) == 3
+    for stop_word in stop_words:
+        # confirm we are dealing with single-token words
+        tokens = tokenizer.tokenize(stop_word)
+        assert len(tokens) == 1
     layer = HFLocalInvocationLayer(model_name_or_path=default_model)
     result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=stop_words)
 
@@ -483,21 +485,22 @@
 
 
 @pytest.mark.integration
-def test_stop_words_multiple_token():
+@pytest.mark.parametrize(
+    "stop_words", [["unambiguously"], ["unambiguously", "unrelated"], ["unambiguously", "hearted"]]
+)
+def test_stop_words_multiple_token(stop_words: List[str]):
     """
     Test that stop words criteria is used and that it works for multi-token words
     """
-    # complex test with words broken down into multiple tokens
     default_model = "google/flan-t5-base"
     tokenizer = AutoTokenizer.from_pretrained(default_model)
-    # single word unambiguously is broken down into 3 tokens
-    tokens = tokenizer.tokenize("unambiguously")
-    assert len(tokens) == 3
+    for stop_word in stop_words:
+        # confirm we are dealing with multi-token words
+        tokens = tokenizer.tokenize(stop_word)
+        assert len(tokens) > 1
     layer = HFLocalInvocationLayer(model_name_or_path=default_model)
-    result = layer.invoke(
-        prompt="Generate a sentence `I wish you unambiguously good health`", stop_words=["unambiguously"]
-    )
+    result = layer.invoke(prompt="Generate a sentence `I wish you unambiguously good health`", stop_words=stop_words)
 
     # yet the stop word is correctly stopped on and removed
     assert len(result) > 0
     assert result[0].startswith("I wish you")
@@ -507,10 +510,13 @@
 
 
 @pytest.mark.integration
-def test_stop_words_not_being_found():
-    # simple test with words not broken down into multiple tokens
+@pytest.mark.parametrize("stop_words", [["Berlin"], ["Berlin", "Brandenburg"], ["Berlin", "Brandenburg", "Germany"]])
+def test_stop_words_not_being_found(stop_words: List[str]):
+    """
+    Test that generation is unaffected when the stop words never appear in the generated text
+    """
     layer = HFLocalInvocationLayer()
-    result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=["Berlin"])
+    result = layer.invoke(prompt="Generate a sentence `I wish you a good health`", stop_words=stop_words)
     assert len(result) > 0
     for word in "I wish you a good health".split():
         assert word in result[0]