From f95e4d07de275fec379117c1b7317f7b83f425df Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 11 Mar 2024 13:09:05 +0100 Subject: [PATCH] feat: AmazonBedrockChatGenerator - migrate Anthropic chat models to use messaging API (#545) * Migrate Claude to messaging API --------- Co-authored-by: Paul Steppacher --- .github/workflows/amazon_bedrock.yml | 11 ++ .../amazon_bedrock/chat/adapters.py | 153 ++++++++++-------- .../amazon_bedrock/chat/chat_generator.py | 13 +- .../tests/test_chat_generator.py | 94 ++++++++--- 4 files changed, 171 insertions(+), 100 deletions(-) diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index 75f881a50..8b1651764 100644 --- a/.github/workflows/amazon_bedrock.yml +++ b/.github/workflows/amazon_bedrock.yml @@ -18,9 +18,14 @@ concurrency: group: amazon-bedrock-${{ github.head_ref }} cancel-in-progress: true +permissions: + id-token: write + contents: read + env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" + AWS_REGION: us-east-1 jobs: run: @@ -56,5 +61,11 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs + - name: AWS authentication + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 + with: + aws-region: ${{ env.AWS_REGION }} + role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }} + - name: Run tests run: hatch run cov diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 196a55743..cdb871f40 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -1,7 +1,7 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import Any, Callable, ClassVar, Dict, List from botocore.eventstream import EventStream from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk @@ -44,33 +44,37 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: :param response_body: The response body. :returns: The extracted responses. 
""" - return self._extract_messages_from_response(self.response_body_message_key(), response_body) + return self._extract_messages_from_response(response_body) - def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]: + def get_stream_responses( + self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: tokens: List[str] = [] + last_decoded_chunk: Dict[str, Any] = {} for event in stream: chunk = event.get("chunk") if chunk: - decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - token = self._extract_token_from_stream(decoded_chunk) - # take all the rest key/value pairs from the chunk, add them to the metadata - stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token} - stream_chunk = StreamingChunk(content=token, meta=stream_metadata) - # callback the stream handler with StreamingChunk - stream_handler(stream_chunk) + last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) + token = self._extract_token_from_stream(last_decoded_chunk) + stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only + stream_handler(stream_chunk) # callback the stream handler with StreamingChunk tokens.append(token) responses = ["".join(tokens).lstrip()] - return responses + return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] @staticmethod - def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None: + def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: """ Updates target_dict with values from updates_dict. Merges lists instead of overriding them. :param target_dict: The dictionary to update. :param updates_dict: The dictionary with updates. + :param allowed_params: The list of allowed params to use. """ for key, value in updates_dict.items(): + if key not in allowed_params: + logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") + continue if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): # Merge lists and remove duplicates target_dict[key] = sorted(set(target_dict[key] + value)) @@ -78,21 +82,24 @@ def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> # Override the value in target_dict target_dict[key] = value - def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + def _get_params( + self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] + ) -> Dict[str, Any]: """ Merges params from inference_kwargs with the default params and self.generation_kwargs. Uses a helper function to merge lists or override values as necessary. :param inference_kwargs: The inference kwargs to merge. :param default_params: The default params to start with. + :param allowed_params: The list of allowed params to use. :returns: The merged params. 
""" # Start with a copy of default_params kwargs = default_params.copy() # Update the default params with self.generation_kwargs and finally inference_kwargs - self._update_params(kwargs, self.generation_kwargs) - self._update_params(kwargs, inference_kwargs) + self._update_params(kwargs, self.generation_kwargs, allowed_params) + self._update_params(kwargs, inference_kwargs, allowed_params) return kwargs @@ -124,25 +131,14 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]: :returns: A dictionary containing the resized prompt and additional information. """ - def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]: + @abstractmethod + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """ Extracts the messages from the response body. - :param message_tag: The key for the message in the response body. :param response_body: The response body. :returns: The extracted ChatMessage list. """ - metadata = {k: v for (k, v) in response_body.items() if k != message_tag} - return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - - @abstractmethod - def response_body_message_key(self) -> str: - """ - Returns the key for the message in the response body. - Subclasses should override this method to return the correct message key - where the response is located. - - :returns: The key for the message in the response body. - """ @abstractmethod def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -159,8 +155,16 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): Model adapter for the Anthropic Claude chat model. """ - ANTHROPIC_USER_TOKEN = "\n\nHuman:" - ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:" + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + ALLOWED_PARAMS: ClassVar[List[str]] = [ + "anthropic_version", + "max_tokens", + "stop_sequences", + "temperature", + "top_p", + "top_k", + "system", + ] def __init__(self, generation_kwargs: Dict[str, Any]): """ @@ -183,7 +187,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]): self.prompt_handler = DefaultPromptHandler( tokenizer="gpt2", model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512, + max_length=self.generation_kwargs.get("max_tokens") or 512, ) def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: @@ -195,46 +199,33 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ :returns: The prepared body. 
""" default_params = { - "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512, - "stop_sequences": ["\n\nHuman:"], + "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", + "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required } # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences - params = self._get_params(inference_kwargs, default_params) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) + body = {**self.prepare_chat_messages(messages=messages), **params} return body - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: """ Prepares the chat messages for the Anthropic Claude request. :param messages: The chat messages to prepare. :returns: The prepared chat messages as a string. """ - conversation = [] - for index, message in enumerate(messages): - if message.is_from(ChatRole.USER): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.ASSISTANT): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.FUNCTION): - error_message = "Anthropic does not support function calls." - raise ValueError(error_message) - elif message.is_from(ChatRole.SYSTEM) and index == 0: - # Until we transition to the new chat message format system messages will be ignored - # see https://docs.anthropic.com/claude/reference/messages_post for more details - logger.warning( - "System messages are not fully supported by the current version of Claude and will be ignored." - ) - else: - invalid_role = f"Invalid role {message.role} for message {message.content}" - raise ValueError(invalid_role) - - prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " - return self._ensure_token_limit(prepared_prompt) + body: Dict[str, Any] = {} + system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None + body["messages"] = [ + self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT) + ] + if system: + body["system"] = system + return body def check_prompt(self, prompt: str) -> Dict[str, Any]: """ @@ -245,13 +236,20 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]: """ return self.prompt_handler(prompt) - def response_body_message_key(self) -> str: + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """ - Returns the key for the message in the response body for Anthropic Claude i.e. "completion". + Extracts the messages from the response body. - :returns: The key for the message in the response body. + :param response_body: The response body. + :return: The extracted ChatMessage list. 
""" - return "completion" + messages: List[ChatMessage] = [] + if response_body.get("type") == "message": + for content in response_body["content"]: + if content.get("type") == "text": + meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} + messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) + return messages def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: """ @@ -260,7 +258,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: :param chunk: The streaming chunk. :returns: The extracted token. """ - return chunk.get("completion", "") + if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": + return chunk.get("delta", {}).get("text", "") + return "" + + def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: + """ + Convert a ChatMessage to a dictionary with the content and role fields. + :param m: The ChatMessage to convert. + :return: The dictionary with the content and role fields. + """ + return {"content": [{"type": "text", "text": m.content}], "role": m.role.value} class MetaLlama2ChatAdapter(BedrockModelChatAdapter): @@ -268,6 +276,9 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter): Model adapter for the Meta Llama 2 models. """ + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] + chat_template = ( "{% if messages[0]['role'] == 'system' %}" "{% set loop_messages = messages[1:] %}" @@ -327,11 +338,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ """ default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - # combine stop words with default stop sequences, remove stop_words as MetaLlama2 does not support it - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) - if stop_sequences: - inference_kwargs["stop_sequences"] = stop_sequences - params = self._get_params(inference_kwargs, default_params) + # no support for stop words in Meta Llama 2 + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {"prompt": self.prepare_chat_messages(messages=messages), **params} return body @@ -357,13 +365,16 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]: """ return self.prompt_handler(prompt) - def response_body_message_key(self) -> str: + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """ - Returns the key for the message in the response body for Meta Llama 2 i.e. "generation". + Extracts the messages from the response body. - :returns: The key for the message in the response body. + :param response_body: The response body. + :return: The extracted ChatMessage list. 
""" - return "generation" + message_tag = "generation" + metadata = {k: v for (k, v) in response_body.items() if k != message_tag} + return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: """ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index bea6924f6..5279dc001 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -25,20 +25,21 @@ class AmazonBedrockChatGenerator: """ `AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs. - For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the - 'anthropic.claude-v2' model name. + For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the + 'anthropic.claude-3-sonnet-20240229-v1:0' model name. ```python from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack.dataclasses import ChatMessage from haystack.components.generators.utils import print_streaming_chunk - messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"), + messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"), ChatMessage.from_user("What's Natural Language Processing?")] - client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk) - client.run(messages, generation_kwargs={"max_tokens_to_sample": 512}) + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + streaming_callback=print_streaming_chunk) + client.run(messages, generation_kwargs={"max_tokens": 512}) ``` @@ -154,7 +155,7 @@ def invoke(self, *args, **kwargs): msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." 
raise ValueError(msg) - body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs) + body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs}) try: if self.streaming_callback: response = self.client.invoke_model_with_response_stream( diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 9ba4d5534..6e0356d42 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -2,7 +2,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( @@ -11,7 +11,8 @@ MetaLlama2ChatAdapter, ) -clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] def test_to_dict(mock_boto3_session): @@ -24,7 +25,7 @@ def test_to_dict(mock_boto3_session): streaming_callback=print_streaming_chunk, ) expected_dict = { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -47,7 +48,7 @@ def test_from_dict(mock_boto3_session): """ generator = AmazonBedrockChatGenerator.from_dict( { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -146,9 +147,9 @@ def test_prepare_body_with_default_params(self) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={}) prompt = "Hello, how are you?" expected_body = { - "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", - "max_tokens_to_sample": 512, - "stop_sequences": ["\n\nHuman:"], + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 512, + "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], } body = layer.prepare_body([ChatMessage.from_user(prompt)]) @@ -159,12 +160,13 @@ def test_prepare_body_with_custom_inference_params(self) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) prompt = "Hello, how are you?" 
expected_body = { - "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", - "max_tokens_to_sample": 69, - "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"], + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 512, + "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], + "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, - "top_p": 0.8, "top_k": 5, + "top_p": 0.8, } body = layer.prepare_body( @@ -173,17 +175,14 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body - @pytest.mark.integration - def test_get_responses(self) -> None: - adapter = AnthropicClaudeChatAdapter(generation_kwargs={}) - response_body = {"completion": "This is a single response."} - expected_response = "This is a single response." - response_message = adapter.get_responses(response_body) - # assert that the type of each item in the list is a ChatMessage - for message in response_message: - assert isinstance(message, ChatMessage) - assert response_message == [ChatMessage.from_assistant(expected_response)] +@pytest.fixture +def chat_messages(): + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + return messages class TestMetaLlama2ChatAdapter: @@ -207,13 +206,13 @@ def test_prepare_body_with_custom_inference_params(self) -> None: generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} ) prompt = "Hello, how are you?" + + # expected body is different because stop_sequences and top_k are not supported by MetaLlama2 expected_body = { "prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 69, - "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, "top_p": 0.8, - "top_k": 5, } body = layer.prepare_body( @@ -238,3 +237,52 @@ def test_get_responses(self) -> None: assert isinstance(message, ChatMessage) assert response_message == [ChatMessage.from_assistant(expected_response)] + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + @pytest.mark.integration + def test_default_inference_params(self, model_name, chat_messages): + + client = AmazonBedrockChatGenerator(model=model_name) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + @pytest.mark.integration + def test_default_inference_with_streaming(self, model_name, chat_messages): + streaming_callback_called = False + paris_found_in_response = False + + def streaming_callback(chunk: StreamingChunk): + nonlocal streaming_callback_called, paris_found_in_response + streaming_callback_called = True + assert isinstance(chunk, StreamingChunk) + assert chunk.content is not None + if not paris_found_in_response: + paris_found_in_response = "paris" in chunk.content.lower() + + client = AmazonBedrockChatGenerator(model=model_name, 
streaming_callback=streaming_callback) + response = client.run(chat_messages) + + assert streaming_callback_called, "Streaming callback was not called" + assert paris_found_in_response, "The streaming callback response did not contain 'paris'" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata"
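
A minimal sketch of the request-body change for reviewers, using only names that appear in this patch; the model behavior follows the updated unit tests, and the prompt strings and parameter values below are illustrative only:

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import (
    AnthropicClaudeChatAdapter,
)

# After this migration, the Claude adapter builds an Anthropic messages-API body
# (anthropic_version / max_tokens / system / messages) instead of a single
# "\n\nHuman: ... \n\nAssistant:" prompt string.
adapter = AnthropicClaudeChatAdapter(generation_kwargs={})
body = adapter.prepare_body(
    [
        ChatMessage.from_system("Be super brief in your responses."),
        ChatMessage.from_user("What's the capital of France?"),
    ],
    temperature=0.5,  # in ALLOWED_PARAMS; params outside the list are dropped with a warning
)

# Expected shape, per the unit tests in this patch (key order irrelevant):
# {
#     "anthropic_version": "bedrock-2023-05-31",
#     "max_tokens": 512,
#     "system": "Be super brief in your responses.",
#     "messages": [
#         {"role": "user", "content": [{"type": "text", "text": "What's the capital of France?"}]},
#     ],
#     "temperature": 0.5,
# }
```

Callers that previously passed `max_tokens_to_sample` in `generation_kwargs` should now pass `max_tokens`, as the updated `AmazonBedrockChatGenerator` docstring example reflects.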