
feat: made truncation optional for BedrockGenerator (#833)
* Add truncate parameter to init method

* Fix serialization bug for BedrockGenerator

* Add a test to check truncation functionality
Amnah199 authored Jul 1, 2024
1 parent 2d93ea3 commit 268b487
Showing 2 changed files with 53 additions and 4 deletions.
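For orientation before the diff: this is how the new flag is used from calling code. A minimal sketch, assuming the `amazon-bedrock-haystack` integration is installed and AWS credentials are available through the usual environment variables:

```python
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

# truncate=False sends the full prompt to the model unchanged;
# the default (truncate=True) keeps the previous behavior of trimming
# the prompt to fit the model's context window.
generator = AmazonBedrockGenerator(
    model="anthropic.claude-v2",
    max_length=100,
    truncate=False,
)

result = generator.run("Briefly explain prompt truncation.")
print(result["replies"][0])
```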
10 changes: 10 additions & 0 deletions integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -75,6 +75,7 @@ def __init__(
         aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False),  # noqa: B008
         aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
         max_length: Optional[int] = 100,
+        truncate: Optional[bool] = True,
         **kwargs,
     ):
         """
@@ -87,6 +88,7 @@ def __init__(
         :param aws_region_name: The AWS region name.
         :param aws_profile_name: The AWS profile name.
         :param max_length: The maximum length of the generated text.
+        :param truncate: Whether to truncate the prompt or not.
         :param kwargs: Additional keyword arguments to be passed to the model.
         :raises ValueError: If the model name is empty or None.
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -97,11 +99,13 @@ def __init__(
             raise ValueError(msg)
         self.model = model
         self.max_length = max_length
+        self.truncate = truncate
         self.aws_access_key_id = aws_access_key_id
         self.aws_secret_access_key = aws_secret_access_key
         self.aws_session_token = aws_session_token
         self.aws_region_name = aws_region_name
         self.aws_profile_name = aws_profile_name
+        self.kwargs = kwargs

         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             return secret.resolve_value() if secret else None
@@ -129,6 +133,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
         # Truncate prompt if prompt tokens > model_max_length-max_length
         # (max_length is the length of the generated text)
         # we use GPT2 tokenizer which will likely provide good token count approximation
+
         self.prompt_handler = DefaultPromptHandler(
             tokenizer="gpt2",
             model_max_length=model_max_length,
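As background on the truncation being made optional here: the handler counts prompt tokens with the GPT-2 tokenizer and clips the prompt so that prompt tokens plus `max_length` generated tokens fit within `model_max_length`. A simplified sketch of that idea (an illustration, not the actual `DefaultPromptHandler` implementation):

```python
from transformers import AutoTokenizer

def truncate_prompt(prompt: str, model_max_length: int, max_length: int) -> str:
    """Clip prompt so prompt tokens + max_length generated tokens <= model_max_length."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    budget = model_max_length - max_length  # tokens reserved for the prompt itself
    tokens = tokenizer.tokenize(prompt)
    if len(tokens) <= budget:
        return prompt
    return tokenizer.convert_tokens_to_string(tokens[:budget])

# An 8-token prompt with a 10-token window and 3 generated tokens keeps only 7 tokens.
print(truncate_prompt("I am a tokenized prompt of length eight", model_max_length=10, max_length=3))
```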
@@ -189,6 +194,9 @@ def invoke(self, *args, **kwargs):
             )
             raise ValueError(msg)

+        if self.truncate:
+            prompt = self._ensure_token_limit(prompt)
+
         body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
         try:
             if stream:
@@ -266,6 +274,8 @@ def to_dict(self) -> Dict[str, Any]:
             aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
             max_length=self.max_length,
+            truncate=self.truncate,
+            **self.kwargs,
         )

     @classmethod
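The serialization bug mentioned in the commit message is addressed by storing the extra model kwargs on the instance (`self.kwargs = kwargs` above) and spreading them back out in `to_dict`, so a serialize/deserialize round trip preserves parameters such as `temperature`. A toy illustration of the pattern (hypothetical class, not the Haystack API):

```python
from typing import Any, Dict

class ToyGenerator:
    """Minimal stand-in showing why **kwargs must be stored for round-trip serialization."""

    def __init__(self, model: str, truncate: bool = True, **kwargs):
        self.model = model
        self.truncate = truncate
        self.kwargs = kwargs  # without this, kwargs like temperature would be lost

    def to_dict(self) -> Dict[str, Any]:
        return {"model": self.model, "truncate": self.truncate, **self.kwargs}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyGenerator":
        return cls(**data)

g = ToyGenerator("anthropic.claude-v2", truncate=False, temperature=10)
assert ToyGenerator.from_dict(g.to_dict()).to_dict() == g.to_dict()
```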
47 changes: 43 additions & 4 deletions integrations/amazon_bedrock/tests/test_generator.py
@@ -20,10 +20,7 @@ def test_to_dict(mock_boto3_session):
     """
     Test that the to_dict method returns the correct dictionary without aws credentials
     """
-    generator = AmazonBedrockGenerator(
-        model="anthropic.claude-v2",
-        max_length=99,
-    )
+    generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10)

     expected_dict = {
         "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
@@ -35,6 +32,8 @@ def test_to_dict(mock_boto3_session):
             "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
             "model": "anthropic.claude-v2",
             "max_length": 99,
+            "truncate": False,
+            "temperature": 10,
         },
     }

@@ -194,6 +193,46 @@ def test_long_prompt_is_truncated(mock_boto3_session):
     assert prompt_after_resize == truncated_prompt_text


+def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
+    """
+    Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
+    """
+    long_prompt_text = "I am a tokenized prompt of length eight"
+
+    # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
+    max_length_generated_text = 3
+    total_model_max_length = 10
+
+    with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
+        generator = AmazonBedrockGenerator(
+            model="anthropic.claude-v2",
+            max_length=max_length_generated_text,
+            model_max_length=total_model_max_length,
+            truncate=False,
+        )
+
+    # Mock the _ensure_token_limit method to track if it is called
+    with patch.object(
+        generator, "_ensure_token_limit", wraps=generator._ensure_token_limit
+    ) as mock_ensure_token_limit:
+        # Mock the model adapter to avoid an actual model invocation
+        generator.model_adapter.prepare_body = MagicMock(return_value={})
+        generator.client = MagicMock()
+        generator.client.invoke_model = MagicMock(
+            return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
+        )
+        generator.model_adapter.get_responses = MagicMock(return_value=["response"])
+
+        # Invoke the generator
+        generator.invoke(prompt=long_prompt_text)
+
+        # Ensure _ensure_token_limit was not called
+        mock_ensure_token_limit.assert_not_called()
+
+        # Check that the full, untruncated prompt was passed to prepare_body
+        generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text)
+
+
 @pytest.mark.parametrize(
     "model, expected_model_adapter",
     [
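The `wraps=` trick in the new test keeps the real `_ensure_token_limit` implementation while recording calls, so the test can assert the method was never invoked. The same pattern in isolation, using a hypothetical `Greeter` class for illustration:

```python
from unittest.mock import patch

class Greeter:
    def normalize(self, text: str) -> str:
        return text.strip()

    def greet(self, name: str, tidy: bool = True) -> str:
        if tidy:
            name = self.normalize(name)
        return f"Hello, {name}!"

greeter = Greeter()
# wraps= delegates to the real normalize() while counting calls.
with patch.object(greeter, "normalize", wraps=greeter.normalize) as mock_normalize:
    assert greeter.greet("  Ada  ", tidy=False) == "Hello,   Ada  !"
    mock_normalize.assert_not_called()  # tidy=False skipped normalization
```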
