fix: enable passing max_length for text2text-generation task (#5420)
* bug fix

* add unit test

* reformatting

* add release note

* add release note

* Update releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml

Co-authored-by: bogdankostic <[email protected]>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <[email protected]>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <[email protected]>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <[email protected]>

* Update test/prompt/invocation_layer/test_hugging_face.py

Co-authored-by: bogdankostic <[email protected]>

* bug fix

---------

Co-authored-by: bogdankostic <[email protected]>
faaany and bogdankostic authored Aug 2, 2023
1 parent 40a2e9b commit 73fa796
Showing 3 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -266,7 +266,7 @@ def invoke(self, *args, **kwargs):
         if is_text_generation:
             model_input_kwargs["max_new_tokens"] = model_input_kwargs.pop("max_length", self.max_length)
         else:
-            model_input_kwargs["max_length"] = self.max_length
+            model_input_kwargs["max_length"] = model_input_kwargs.pop("max_length", self.max_length)
 
         if stream:
             stream_handler: TokenStreamingHandler = stream_handler or DefaultTokenStreamingHandler()
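
The one-line change makes the text2text-generation branch mirror the text-generation branch above it: a `max_length` supplied per call now takes precedence over the layer's default instead of being ignored. A minimal standalone sketch of that precedence rule (illustrative only; `resolve_max_length` and `DEFAULT_MAX_LENGTH` are stand-ins, not Haystack code):

```python
# Illustrative sketch of the precedence rule introduced by this commit; not Haystack's actual code.
DEFAULT_MAX_LENGTH = 100  # stands in for self.max_length on the invocation layer


def resolve_max_length(model_input_kwargs: dict) -> dict:
    # A per-call "max_length" (e.g. taken from generation_kwargs) wins; otherwise fall back
    # to the layer default. Before this fix, text2text-generation models always used the default.
    model_input_kwargs["max_length"] = model_input_kwargs.pop("max_length", DEFAULT_MAX_LENGTH)
    return model_input_kwargs


assert resolve_max_length({})["max_length"] == 100                    # no per-call value: default applies
assert resolve_max_length({"max_length": 200})["max_length"] == 200   # per-call value overrides the default
```
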
4 changes: 4 additions & 0 deletions releasenotes/notes/enable-set-max-length-during-runtime-097d65e537bf800b.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Enable setting the `max_length` value when running PromptNodes using local HF text2text-generation models.
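
For context, a hedged usage sketch of what the release note describes. The call shape mirrors the new test below; the model name is only an illustrative local text2text-generation checkpoint, and the import path and constructor argument are assumed from the changed module's location:

```python
from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer

# "google/flan-t5-base" is just an illustrative text2text-generation model.
layer = HFLocalInvocationLayer(model_name_or_path="google/flan-t5-base")

# With this fix, the max_length passed at invoke time reaches the underlying HF pipeline
# instead of being overridden by the layer's default.
output = layer.invoke(prompt="What does 42 mean?", generation_kwargs={"max_length": 200})
print(output)
```
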
17 changes: 17 additions & 0 deletions test/prompt/invocation_layer/test_hugging_face.py
@@ -523,6 +523,23 @@ def test_generation_kwargs_from_invoke():
     mock_call.assert_called_with(the_question, {}, {"do_sample": True, "top_p": 0.9, "max_length": 100}, {})
 
 
+@pytest.mark.unit
+def test_max_length_from_invoke(mock_auto_tokenizer, mock_pipeline, mock_get_task):
+    """
+    Test that max_length passed to invoke are passed to the underlying HF model
+    """
+    query = "What does 42 mean?"
+    # test that generation_kwargs are passed to the underlying HF model
+    layer = HFLocalInvocationLayer()
+    layer.invoke(prompt=query, generation_kwargs={"max_length": 200})
+    # find the call to pipeline invocation, and check that the kwargs are correct
+    assert any((call.kwargs == {"max_length": 200}) and (query in call.args) for call in mock_pipeline.mock_calls)
+
+    layer = HFLocalInvocationLayer()
+    layer.invoke(prompt=query, generation_kwargs=GenerationConfig(max_length=235))
+    assert any((call.kwargs == {"max_length": 235}) and (query in call.args) for call in mock_pipeline.mock_calls)
+
+
 @pytest.mark.unit
 def test_ensure_token_limit_positive_mock(mock_pipeline, mock_get_task, mock_auto_tokenizer):
     # prompt of length 5 + max_length of 3 = 8, which is less than model_max_length of 10, so no resize
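
The second assertion exercises passing a transformers `GenerationConfig` instead of a plain dict. A small hedged sketch of how such a config can be reduced to the same keyword arguments (this assumes `GenerationConfig.to_diff_dict()`, which keeps only fields that differ from the defaults):

```python
from transformers import GenerationConfig

config = GenerationConfig(max_length=235)
# Only non-default generation settings are kept, so max_length=235 shows up here,
# matching the kwargs the mocked pipeline is expected to receive in the test above.
print(config.to_diff_dict())
```
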
