Commit
feat: AmazonBedrockChatGenerator - migrate Anthropic chat models to use messaging API (#545)

* Migrate Claude to messaging API
---------

Co-authored-by: Paul Steppacher <[email protected]>
vblagoje and steppi91 authored Mar 11, 2024
1 parent 4f01032 commit f95e4d0
Showing 4 changed files with 171 additions and 100 deletions.
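
At its core, the change swaps Claude's legacy text-completion request body (a single prompt string built from "\n\nHuman:"/"\n\nAssistant:" turns) for the Bedrock Anthropic messages-API body. A minimal before/after sketch, reconstructed from the diff below; the prompt text and values are illustrative:

```python
# Legacy text-completion body (removed by this commit)
legacy_body = {
    "prompt": "\n\nHuman: What's NLP?\n\nAssistant: ",
    "max_tokens_to_sample": 512,
    "stop_sequences": ["\n\nHuman:"],
}

# Messages-API body (introduced by this commit)
messages_body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 512,  # required by the messages API
    "system": "You are a helpful assistant",  # optional top-level system prompt
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "What's NLP?"}]},
    ],
}
```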
11 changes: 11 additions & 0 deletions .github/workflows/amazon_bedrock.yml
@@ -18,9 +18,14 @@ concurrency:
group: amazon-bedrock-${{ github.head_ref }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

env:
PYTHONUNBUFFERED: "1"
FORCE_COLOR: "1"
AWS_REGION: us-east-1

jobs:
run:
@@ -56,5 +61,11 @@ jobs:
if: matrix.python-version == '3.9' && runner.os == 'Linux'
run: hatch run docs

- name: AWS authentication
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502
with:
aws-region: ${{ env.AWS_REGION }}
role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }}

- name: Run tests
run: hatch run cov
@@ -1,7 +1,7 @@
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List
from typing import Any, Callable, ClassVar, Dict, List

from botocore.eventstream import EventStream
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
@@ -44,55 +44,62 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
:param response_body: The response body.
:returns: The extracted responses.
"""
return self._extract_messages_from_response(self.response_body_message_key(), response_body)
return self._extract_messages_from_response(response_body)

def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
def get_stream_responses(
self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]
) -> List[ChatMessage]:
tokens: List[str] = []
last_decoded_chunk: Dict[str, Any] = {}
for event in stream:
chunk = event.get("chunk")
if chunk:
decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
token = self._extract_token_from_stream(decoded_chunk)
# take all the rest key/value pairs from the chunk, add them to the metadata
stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token}
stream_chunk = StreamingChunk(content=token, meta=stream_metadata)
# callback the stream handler with StreamingChunk
stream_handler(stream_chunk)
last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
token = self._extract_token_from_stream(last_decoded_chunk)
stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only
stream_handler(stream_chunk) # callback the stream handler with StreamingChunk
tokens.append(token)
responses = ["".join(tokens).lstrip()]
return responses
return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses]

@staticmethod
def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None:
def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None:
"""
Updates target_dict with values from updates_dict. Merges lists instead of overriding them.
:param target_dict: The dictionary to update.
:param updates_dict: The dictionary with updates.
:param allowed_params: The list of allowed params to use.
"""
for key, value in updates_dict.items():
if key not in allowed_params:
logger.warning(f"Parameter '{key}' is not allowed and will be ignored.")
continue
if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list):
# Merge lists and remove duplicates
target_dict[key] = sorted(set(target_dict[key] + value))
else:
# Override the value in target_dict
target_dict[key] = value
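
As a usage sketch, the merge semantics can be mirrored with a small stand-in (a hypothetical helper written for illustration, not the library import):

```python
def update_params(target: dict, updates: dict, allowed: list) -> None:
    """Mirror of _update_params: skip disallowed keys, merge lists, override scalars."""
    for key, value in updates.items():
        if key not in allowed:
            continue  # the adapter also logs a warning for the ignored key
        if key in target and isinstance(target[key], list) and isinstance(value, list):
            target[key] = sorted(set(target[key] + value))  # merge and deduplicate
        else:
            target[key] = value

params = {"stop_sequences": ["\n\nHuman:"], "max_tokens": 512}
updates = {"stop_sequences": ["Observation:"], "temperature": 0.7, "foo": 1}
allowed = ["anthropic_version", "max_tokens", "stop_sequences", "temperature", "top_p", "top_k"]
update_params(params, updates, allowed)
# params == {"stop_sequences": ["\n\nHuman:", "Observation:"], "max_tokens": 512, "temperature": 0.7}
```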

def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]:
def _get_params(
self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str]
) -> Dict[str, Any]:
"""
Merges params from inference_kwargs with the default params and self.generation_kwargs.
Uses a helper function to merge lists or override values as necessary.
:param inference_kwargs: The inference kwargs to merge.
:param default_params: The default params to start with.
:param allowed_params: The list of allowed params to use.
:returns: The merged params.
"""
# Start with a copy of default_params
kwargs = default_params.copy()

# Update the default params with self.generation_kwargs and finally inference_kwargs
self._update_params(kwargs, self.generation_kwargs)
self._update_params(kwargs, inference_kwargs)
self._update_params(kwargs, self.generation_kwargs, allowed_params)
self._update_params(kwargs, inference_kwargs, allowed_params)

return kwargs

@@ -124,25 +131,14 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
:returns: A dictionary containing the resized prompt and additional information.
"""

def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
@abstractmethod
def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Extracts the messages from the response body.
:param message_tag: The key for the message in the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]

@abstractmethod
def response_body_message_key(self) -> str:
"""
Returns the key for the message in the response body.
Subclasses should override this method to return the correct message key - where the response is located.
:returns: The key for the message in the response body.
"""

@abstractmethod
def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -159,8 +155,16 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
Model adapter for the Anthropic Claude chat model.
"""

ANTHROPIC_USER_TOKEN = "\n\nHuman:"
ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:"
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
ALLOWED_PARAMS: ClassVar[List[str]] = [
"anthropic_version",
"max_tokens",
"stop_sequences",
"temperature",
"top_p",
"top_k",
"system",
]

def __init__(self, generation_kwargs: Dict[str, Any]):
"""
@@ -183,7 +187,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512,
max_length=self.generation_kwargs.get("max_tokens") or 512,
)

def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
@@ -195,46 +199,33 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
:returns: The prepared body.
"""
default_params = {
"max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512,
"stop_sequences": ["\n\nHuman:"],
"anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31",
"max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required
}

# combine stop words with default stop sequences, remove stop_words as Anthropic does not support it
stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", [])
if stop_sequences:
inference_kwargs["stop_sequences"] = stop_sequences
params = self._get_params(inference_kwargs, default_params)
body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
body = {**self.prepare_chat_messages(messages=messages), **params}
return body

def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
"""
Prepares the chat messages for the Anthropic Claude request.
:param messages: The chat messages to prepare.
:returns: The prepared chat messages as a dictionary.
"""
conversation = []
for index, message in enumerate(messages):
if message.is_from(ChatRole.USER):
conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}")
elif message.is_from(ChatRole.ASSISTANT):
conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}")
elif message.is_from(ChatRole.FUNCTION):
error_message = "Anthropic does not support function calls."
raise ValueError(error_message)
elif message.is_from(ChatRole.SYSTEM) and index == 0:
# Until we transition to the new chat message format system messages will be ignored
# see https://docs.anthropic.com/claude/reference/messages_post for more details
logger.warning(
"System messages are not fully supported by the current version of Claude and will be ignored."
)
else:
invalid_role = f"Invalid role {message.role} for message {message.content}"
raise ValueError(invalid_role)

prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
return self._ensure_token_limit(prepared_prompt)
body: Dict[str, Any] = {}
system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
body["messages"] = [
self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
]
if system:
body["system"] = system
return body
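
For illustration, a system-plus-user conversation maps to a messages-API fragment like the one sketched here (shape inferred from `_to_anthropic_message`; the text is illustrative):

```python
from haystack.dataclasses import ChatMessage

messages = [
    ChatMessage.from_system("Answer concisely."),
    ChatMessage.from_user("What's NLP?"),
]
# prepare_chat_messages(messages) returns roughly:
# {
#     "system": "Answer concisely.",
#     "messages": [{"role": "user", "content": [{"type": "text", "text": "What's NLP?"}]}],
# }
```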

def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
@@ -245,13 +236,20 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
return self.prompt_handler(prompt)

def response_body_message_key(self) -> str:
def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Returns the key for the message in the response body for Anthropic Claude i.e. "completion".
Extracts the messages from the response body.
:returns: The key for the message in the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
return "completion"
messages: List[ChatMessage] = []
if response_body.get("type") == "message":
for content in response_body["content"]:
if content.get("type") == "text":
meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]}
messages.append(ChatMessage.from_assistant(content["text"], meta=meta))
return messages
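
For orientation, the messages-API response body consumed above looks roughly like this (a sketch; "stop_reason" and "usage" are assumed stand-ins for the extra keys that land in meta):

```python
response_body = {
    "type": "message",
    "role": "assistant",
    "content": [{"type": "text", "text": "NLP is the field of ..."}],
    "stop_reason": "end_turn",  # assumed extra key, kept in meta
    "usage": {"input_tokens": 25, "output_tokens": 60},  # assumed extra key, kept in meta
}
# _extract_messages_from_response(response_body) yields one assistant ChatMessage
# whose meta excludes "type", "content", and "role".
```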

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
@@ -260,14 +258,27 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
:param chunk: The streaming chunk.
:returns: The extracted token.
"""
return chunk.get("completion", "")
if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
return chunk.get("delta", {}).get("text", "")
return ""

def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
"""
Convert a ChatMessage to a dictionary with the content and role fields.
:param m: The ChatMessage to convert.
:return: The dictionary with the content and role fields.
"""
return {"content": [{"type": "text", "text": m.content}], "role": m.role.value}


class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
"""
Model adapter for the Meta Llama 2 models.
"""

# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"]

chat_template = (
"{% if messages[0]['role'] == 'system' %}"
"{% set loop_messages = messages[1:] %}"
@@ -327,11 +338,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
"""
default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512}

# combine stop words with default stop sequences, remove stop_words as MetaLlama2 does not support it
stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", [])
if stop_sequences:
inference_kwargs["stop_sequences"] = stop_sequences
params = self._get_params(inference_kwargs, default_params)
# no support for stop words in Meta Llama 2
params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
return body

@@ -357,13 +365,16 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
"""
return self.prompt_handler(prompt)

def response_body_message_key(self) -> str:
def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
"""
Returns the key for the message in the response body for Meta Llama 2 i.e. "generation".
Extracts the messages from the response body.
:returns: The key for the message in the response body.
:param response_body: The response body.
:returns: The extracted ChatMessage list.
"""
return "generation"
message_tag = "generation"
metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
"""
@@ -25,20 +25,21 @@ class AmazonBedrockChatGenerator:
"""
`AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs.
For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the
'anthropic.claude-v2' model name.
For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the
'anthropic.claude-3-sonnet-20240229-v1:0' model name.
```python
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.components.generators.utils import print_streaming_chunk
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"),
ChatMessage.from_user("What's Natural Language Processing?")]
client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk)
client.run(messages, generation_kwargs={"max_tokens_to_sample": 512})
client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0",
streaming_callback=print_streaming_chunk)
client.run(messages, generation_kwargs={"max_tokens": 512})
```
@@ -154,7 +155,7 @@ def invoke(self, *args, **kwargs):
msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
raise ValueError(msg)

body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs)
body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs})
try:
if self.streaming_callback:
response = self.client.invoke_model_with_response_stream(