diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
index 81a02b749..30fc759e0 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -73,6 +73,7 @@ def __init__(
         aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
         aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
         max_length: Optional[int] = 100,
+        truncate: Optional[bool] = True,
         **kwargs,
     ):
         """
@@ -85,6 +86,7 @@
         :param aws_region_name: The AWS region name.
         :param aws_profile_name: The AWS profile name.
         :param max_length: The maximum length of the generated text.
+        :param truncate: Whether to truncate the prompt or not.
         :param kwargs: Additional keyword arguments to be passed to the model.
         :raises ValueError: If the model name is empty or None.
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -95,11 +97,13 @@
             raise ValueError(msg)
         self.model = model
         self.max_length = max_length
+        self.truncate = truncate
         self.aws_access_key_id = aws_access_key_id
         self.aws_secret_access_key = aws_secret_access_key
         self.aws_session_token = aws_session_token
         self.aws_region_name = aws_region_name
         self.aws_profile_name = aws_profile_name
+        self.kwargs = kwargs
 
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             return secret.resolve_value() if secret else None
@@ -127,6 +131,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
         # Truncate prompt if prompt tokens > model_max_length-max_length
         # (max_length is the length of the generated text)
         # we use GPT2 tokenizer which will likely provide good token count approximation
+
         self.prompt_handler = DefaultPromptHandler(
             tokenizer="gpt2",
             model_max_length=model_max_length,
@@ -187,6 +192,9 @@ def invoke(self, *args, **kwargs):
             )
             raise ValueError(msg)
 
+        if self.truncate:
+            prompt = self._ensure_token_limit(prompt)
+
         body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
         try:
             if stream:
@@ -264,6 +272,8 @@ def to_dict(self) -> Dict[str, Any]:
             aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
             max_length=self.max_length,
+            truncate=self.truncate,
+            **self.kwargs,
         )
 
     @classmethod
diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py
index 10fc1eca8..5326d143b 100644
--- a/integrations/amazon_bedrock/tests/test_generator.py
+++ b/integrations/amazon_bedrock/tests/test_generator.py
@@ -19,10 +19,7 @@ def test_to_dict(mock_boto3_session):
     """
     Test that the to_dict method returns the correct dictionary without aws credentials
     """
-    generator = AmazonBedrockGenerator(
-        model="anthropic.claude-v2",
-        max_length=99,
-    )
+    generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10)
 
     expected_dict = {
         "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
@@ -34,6 +31,8 @@ def test_to_dict(mock_boto3_session):
             "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
             "model": "anthropic.claude-v2",
             "max_length": 99,
+            "truncate": False,
+            "temperature": 10,
         },
     }
 
@@ -193,6 +192,46 @@ def test_long_prompt_is_truncated(mock_boto3_session):
     assert prompt_after_resize == truncated_prompt_text
 
 
+def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
+    """
+    Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
+    """
+    long_prompt_text = "I am a tokenized prompt of length eight"
+
+    # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
+    max_length_generated_text = 3
+    total_model_max_length = 10
+
+    with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
+        generator = AmazonBedrockGenerator(
+            model="anthropic.claude-v2",
+            max_length=max_length_generated_text,
+            model_max_length=total_model_max_length,
+            truncate=False,
+        )
+
+    # Mock the _ensure_token_limit method to track if it is called
+    with patch.object(
+        generator, "_ensure_token_limit", wraps=generator._ensure_token_limit
+    ) as mock_ensure_token_limit:
+        # Mock the model adapter to avoid actual invocation
+        generator.model_adapter.prepare_body = MagicMock(return_value={})
+        generator.client = MagicMock()
+        generator.client.invoke_model = MagicMock(
+            return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
+        )
+        generator.model_adapter.get_responses = MagicMock(return_value=["response"])
+
+        # Invoke the generator
+        generator.invoke(prompt=long_prompt_text)
+
+        # Ensure _ensure_token_limit was not called
+        mock_ensure_token_limit.assert_not_called()
+
+        # Check the prompt passed to prepare_body
+        generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text)
+
+
 @pytest.mark.parametrize(
     "model, expected_model_adapter",
    [