From 268b487a2e8633acecc917e51746eafb2040a9a6 Mon Sep 17 00:00:00 2001
From: Amna Mubashar
Date: Tue, 2 Jul 2024 01:45:31 +0200
Subject: [PATCH] feat: made truncation optional for BedrockGenerator (#833)

* Added truncate parameter to init method

* Fixed serialization bug for BedrockGenerator

* Added a test to check truncation functionality
---
 .../generators/amazon_bedrock/generator.py    | 10 ++++
 .../amazon_bedrock/tests/test_generator.py    | 47 +++++++++++++++++--
 2 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
index b93ba1d3f..32d1de629 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -75,6 +75,7 @@ def __init__(
         aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
         aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
         max_length: Optional[int] = 100,
+        truncate: Optional[bool] = True,
         **kwargs,
     ):
         """
@@ -87,6 +88,7 @@ def __init__(
         :param aws_region_name: The AWS region name.
         :param aws_profile_name: The AWS profile name.
         :param max_length: The maximum length of the generated text.
+        :param truncate: Whether to truncate the prompt or not.
         :param kwargs: Additional keyword arguments to be passed to the model.
         :raises ValueError: If the model name is empty or None.
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -97,11 +99,13 @@ def __init__(
             raise ValueError(msg)
         self.model = model
         self.max_length = max_length
+        self.truncate = truncate
         self.aws_access_key_id = aws_access_key_id
         self.aws_secret_access_key = aws_secret_access_key
         self.aws_session_token = aws_session_token
         self.aws_region_name = aws_region_name
         self.aws_profile_name = aws_profile_name
+        self.kwargs = kwargs
 
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             return secret.resolve_value() if secret else None
@@ -129,6 +133,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
         # Truncate prompt if prompt tokens > model_max_length-max_length
         # (max_length is the length of the generated text)
         # we use GPT2 tokenizer which will likely provide good token count approximation
+
         self.prompt_handler = DefaultPromptHandler(
             tokenizer="gpt2",
             model_max_length=model_max_length,
@@ -189,6 +194,9 @@ def invoke(self, *args, **kwargs):
             )
             raise ValueError(msg)
 
+        if self.truncate:
+            prompt = self._ensure_token_limit(prompt)
+
         body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
         try:
             if stream:
@@ -266,6 +274,8 @@ def to_dict(self) -> Dict[str, Any]:
             aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
             max_length=self.max_length,
+            truncate=self.truncate,
+            **self.kwargs,
         )
 
     @classmethod
diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py
index e603c8853..65463caae 100644
--- a/integrations/amazon_bedrock/tests/test_generator.py
+++ b/integrations/amazon_bedrock/tests/test_generator.py
@@ -20,10 +20,7 @@ def test_to_dict(mock_boto3_session):
     """
     Test that the to_dict method returns the correct dictionary without aws credentials
     """
-    generator = AmazonBedrockGenerator(
-        model="anthropic.claude-v2",
-        max_length=99,
-    )
+    generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10)
 
     expected_dict = {
         "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
@@ -35,6 +32,8 @@ def test_to_dict(mock_boto3_session):
             "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
             "model": "anthropic.claude-v2",
             "max_length": 99,
+            "truncate": False,
+            "temperature": 10,
         },
     }
 
@@ -194,6 +193,46 @@ def test_long_prompt_is_truncated(mock_boto3_session):
     assert prompt_after_resize == truncated_prompt_text
 
 
+def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
+    """
+    Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
+    """
+    long_prompt_text = "I am a tokenized prompt of length eight"
+
+    # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
+    max_length_generated_text = 3
+    total_model_max_length = 10
+
+    with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
+        generator = AmazonBedrockGenerator(
+            model="anthropic.claude-v2",
+            max_length=max_length_generated_text,
+            model_max_length=total_model_max_length,
+            truncate=False,
+        )
+
+    # Mock the _ensure_token_limit method to track whether it is called
+    with patch.object(
+        generator, "_ensure_token_limit", wraps=generator._ensure_token_limit
+    ) as mock_ensure_token_limit:
+        # Mock the model adapter and client to avoid an actual Bedrock invocation
+        generator.model_adapter.prepare_body = MagicMock(return_value={})
+        generator.client = MagicMock()
+        generator.client.invoke_model = MagicMock(
+            return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
+        )
+        generator.model_adapter.get_responses = MagicMock(return_value=["response"])
+
+        # Invoke the generator
+        generator.invoke(prompt=long_prompt_text)
+
+        # Ensure _ensure_token_limit was not called
+        mock_ensure_token_limit.assert_not_called()
+
+        # Check that the full, untruncated prompt was passed to prepare_body
+        generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text)
+
+
 @pytest.mark.parametrize(
     "model, expected_model_adapter",
     [
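
Usage note: a minimal sketch of how the new truncate flag might be used, assuming the import path implied by the `type` string in `to_dict` above and the standard Haystack 2.x generator `run` API; the model name and prompt are illustrative only.

    from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

    # With truncate=False, prompts longer than model_max_length - max_length are
    # passed to Bedrock unmodified instead of being shortened by _ensure_token_limit
    # (the default, truncate=True, preserves the previous truncating behavior).
    generator = AmazonBedrockGenerator(
        model="anthropic.claude-v2",
        max_length=100,
        truncate=False,
    )
    result = generator.run(prompt="What is the capital of France?")
    print(result["replies"])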