feat: made truncation optional for BedrockGenerator #833

Merged · 4 commits · Jul 1, 2024
integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py
@@ -73,6 +73,7 @@ def __init__(
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
max_length: Optional[int] = 100,
truncate: Optional[bool] = True,
**kwargs,
):
"""
@@ -85,6 +86,7 @@ def __init__(
:param aws_region_name: The AWS region name.
:param aws_profile_name: The AWS profile name.
:param max_length: The maximum length of the generated text.
:param truncate: Whether to truncate the prompt so that it, plus the generated text, fits within the model's maximum token limit.
:param kwargs: Additional keyword arguments to be passed to the model.
:raises ValueError: If the model name is empty or None.
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -95,11 +97,13 @@
raise ValueError(msg)
self.model = model
self.max_length = max_length
self.truncate = truncate
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.kwargs = kwargs

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None
@@ -127,6 +131,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
# Truncate prompt if prompt tokens > model_max_length-max_length
# (max_length is the length of the generated text)
# we use the GPT2 tokenizer, which should provide a good approximation of the token count

self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
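
For context, the handler configured above enforces a token budget of model_max_length minus max_length on the prompt. A minimal sketch of that idea, assuming the transformers library is available (this is an illustration, not the integration's actual DefaultPromptHandler; truncate_prompt is a hypothetical helper):

from transformers import AutoTokenizer

def truncate_prompt(prompt: str, model_max_length: int, max_length: int) -> str:
    # Reserve max_length tokens for the generated text; the prompt may use the rest.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    budget = model_max_length - max_length
    token_ids = tokenizer.encode(prompt)
    if len(token_ids) <= budget:
        return prompt
    # Clip the prompt to the remaining budget and decode back to text.
    return tokenizer.decode(token_ids[:budget])
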
@@ -187,6 +192,9 @@ def invoke(self, *args, **kwargs):
)
raise ValueError(msg)

if self.truncate:
prompt = self._ensure_token_limit(prompt)

body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
try:
if stream:
@@ -264,6 +272,8 @@ def to_dict(self) -> Dict[str, Any]:
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
max_length=self.max_length,
truncate=self.truncate,
**self.kwargs,
)

@classmethod
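Taken together, the generator-side changes mean the flag is set once at construction time and consulted on every invocation. A minimal usage sketch, assuming AWS credentials are resolvable from the environment (the prompt string is illustrative):

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

# With truncate=False the prompt is sent to Bedrock unmodified, even when
# prompt tokens + max_length would exceed the model's context window.
generator = AmazonBedrockGenerator(
    model="anthropic.claude-v2",
    max_length=100,
    truncate=False,
)
result = generator.run("Summarize the quarterly report in two sentences.")
print(result["replies"][0])

With truncate=True (the default), behavior is unchanged from before this PR: the prompt is clipped by _ensure_token_limit before the request body is prepared.
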
47 changes: 43 additions & 4 deletions integrations/amazon_bedrock/tests/test_generator.py
@@ -19,10 +19,7 @@ def test_to_dict(mock_boto3_session):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=99,
)
generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10)

expected_dict = {
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
@@ -34,6 +31,8 @@
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
"model": "anthropic.claude-v2",
"max_length": 99,
"truncate": False,
"temperature": 10,
},
}

@@ -193,6 +192,46 @@ def test_long_prompt_is_truncated(mock_boto3_session):
assert prompt_after_resize == truncated_prompt_text


def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
"""
Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
"""
long_prompt_text = "I am a tokenized prompt of length eight"

# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10

with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
generator = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=max_length_generated_text,
model_max_length=total_model_max_length,
truncate=False,
)

# Mock the _ensure_token_limit method to track if it is called
with patch.object(
generator, "_ensure_token_limit", wraps=generator._ensure_token_limit
) as mock_ensure_token_limit:
# Mock the model adapter to avoid actual invocation
generator.model_adapter.prepare_body = MagicMock(return_value={})
generator.client = MagicMock()
generator.client.invoke_model = MagicMock(
return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
)
generator.model_adapter.get_responses = MagicMock(return_value=["response"])

# Invoke the generator
generator.invoke(prompt=long_prompt_text)

# Ensure _ensure_token_limit was not called
mock_ensure_token_limit.assert_not_called()

# Check the prompt passed to prepare_body
generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text)


@pytest.mark.parametrize(
"model, expected_model_adapter",
[
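
Finally, since to_dict now emits truncate and any extra kwargs, a round-trip sketch (assuming the same environment-based credentials as in the tests, and that from_dict is the classmethod hinted at the end of the generator diff):

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10)
data = generator.to_dict()

# The new flag and the pass-through kwarg both land in init_parameters.
assert data["init_parameters"]["truncate"] is False
assert data["init_parameters"]["temperature"] == 10

# Reconstructing from the dictionary preserves the flag.
restored = AmazonBedrockGenerator.from_dict(data)
assert restored.truncate is False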