Skip to content

Commit

Permalink
Remove tool usage for non current turns when looking up message histo…
Browse files Browse the repository at this point in the history
…ry (#2174)

* Remove tool usage for non current turns when looking up message history

* remove that.
  • Loading branch information
sobolk authored Oct 31, 2024
1 parent 37dd87c commit 613bca9
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .changeset/long-toys-suffer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@aws-amplify/ai-constructs': patch
---

Remove tool usage for non current turns when looking up message history
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,257 @@ void describe('Conversation message history retriever', () => {
},
],
},
{
name: 'Removes tool usage from non-current turns',
mockListResponseMessages: [
{
id: 'someNonCurrentMessageId1',
conversationId: event.conversationId,
role: 'user',
content: [
{
text: 'nonCurrentMessage1',
},
],
},
{
id: 'someNonCurrentMessageId2',
associatedUserMessageId: 'someNonCurrentMessageId1',
conversationId: event.conversationId,
role: 'assistant',
content: [
{
text: 'nonCurrentMessage2',
},
{
toolUse: {
name: 'testToolUse1',
toolUseId: 'testToolUseId1',
input: undefined,
},
},
],
},
{
id: 'someNonCurrentMessageId3',
conversationId: event.conversationId,
role: 'user',
content: [
{
toolResult: {
status: 'success',
toolUseId: 'testToolUseId1',
content: undefined,
},
},
],
},
{
id: 'someNonCurrentMessageId4',
associatedUserMessageId: 'someNonCurrentMessageId3',
conversationId: event.conversationId,
role: 'assistant',
content: [
{
text: 'nonCurrentMessage3',
},
{
toolUse: {
name: 'testToolUse2',
toolUseId: 'testToolUseId2',
input: undefined,
},
},
],
},
{
id: 'someNonCurrentMessageId5',
conversationId: event.conversationId,
role: 'user',
content: [
{
toolResult: {
status: 'success',
toolUseId: 'testToolUseId2',
content: undefined,
},
},
],
},
{
id: 'someNonCurrentMessageId5',
associatedUserMessageId: 'someNonCurrentMessageId5',
conversationId: event.conversationId,
role: 'assistant',
content: [
{
text: 'nonCurrentMessage4',
},
],
},
// Current turn with multiple tool use.
{
id: 'someCurrentMessageId1',
conversationId: event.conversationId,
role: 'user',
content: [
{
text: 'currentMessage1',
},
],
},
{
id: 'someCurrentMessageId2',
associatedUserMessageId: 'someCurrentMessageId1',
conversationId: event.conversationId,
role: 'assistant',
content: [
{
text: 'currentMessage2',
},
{
toolUse: {
name: 'testToolUse3',
toolUseId: 'testToolUseId3',
input: undefined,
},
},
],
},
{
id: 'someCurrentMessageId3',
conversationId: event.conversationId,
role: 'user',
content: [
{
toolResult: {
status: 'success',
toolUseId: 'testToolUseId3',
content: undefined,
},
},
],
},
{
id: 'someCurrentMessageId4',
associatedUserMessageId: 'someCurrentMessageId3',
conversationId: event.conversationId,
role: 'assistant',
content: [
{
text: 'currentMessage3',
},
{
toolUse: {
name: 'testToolUse4',
toolUseId: 'testToolUseId4',
input: undefined,
},
},
],
},
{
id: event.currentMessageId,
conversationId: event.conversationId,
role: 'user',
content: [
{
toolResult: {
status: 'success',
toolUseId: 'testToolUseId2',
content: undefined,
},
},
],
},
],
expectedMessages: [
{
role: 'user',
content: [
{
text: 'nonCurrentMessage1',
},
],
},
{
role: 'assistant',
content: [
{
text: 'nonCurrentMessage2',
},
{
text: 'nonCurrentMessage3',
},
{
text: 'nonCurrentMessage4',
},
],
},
{
role: 'user',
content: [
{
text: 'currentMessage1',
},
],
},
{
role: 'assistant',
content: [
{
text: 'currentMessage2',
},
{
toolUse: {
name: 'testToolUse3',
toolUseId: 'testToolUseId3',
input: undefined,
},
},
],
},
{
role: 'user',
content: [
{
toolResult: {
status: 'success',
toolUseId: 'testToolUseId3',
content: undefined,
},
},
],
},
{
role: 'assistant',
content: [
{
text: 'currentMessage3',
},
{
toolUse: {
name: 'testToolUse4',
toolUseId: 'testToolUseId4',
input: undefined,
},
},
],
},
{
role: 'user',
content: [
{
toolResult: {
status: 'success',
toolUseId: 'testToolUseId2',
content: undefined,
},
},
],
},
],
},
];

for (const testCase of testCases) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import { ConversationMessage, ConversationTurnEvent } from './types';
import {
ConversationMessage,
ConversationMessageContentBlock,
ConversationTurnEvent,
} from './types';
import { GraphqlRequestExecutor } from './graphql_request_executor';

export type ConversationHistoryMessageItem = ConversationMessage & {
Expand Down Expand Up @@ -137,7 +141,7 @@ export class ConversationMessageHistoryRetriever {
});

// Reconcile history and inject aiContext
return messages.reduce((acc, current) => {
const orderedMessages = messages.reduce((acc, current) => {
// Bedrock expects that message history is user->assistant->user->assistant->... and so on.
// The chronological order doesn't assure this ordering if there were any concurrent messages sent.
// Therefore, conversation is ordered by user's messages only and corresponding assistant messages are inserted
Expand Down Expand Up @@ -176,6 +180,81 @@ export class ConversationMessageHistoryRetriever {
}
return acc;
}, [] as Array<ConversationMessage>);

// Remove tool usage from non-current turn and squash messages.
return this.squashNonCurrentTurns(orderedMessages);
};

/**
* This function removes tool usage from non-current turns.
* The tool usage and result blocks don't matter after a turn is completed,
* but do cost extra tokens to process.
* The algorithm is as follows:
* 1. Find where current turn begins. I.e. last user message that isn't tool block.
* 2. Remove toolUse and toolResult blocks before current turn.
* 3. Squash continuous sequences of messages that belong to same 'message.role'.
*/
private squashNonCurrentTurns = (messages: Array<ConversationMessage>) => {
const isNonToolBlockPredicate = (
contentBlock: ConversationMessageContentBlock
) => !contentBlock.toolUse && !contentBlock.toolResult;

// find where current turn begins. I.e. last user message that is not related to tools
const lastNonToolUseUserMessageIndex = messages.findLastIndex((message) => {
return (
message.role === 'user' && message.content.find(isNonToolBlockPredicate)
);
});

// No non-current turns, don't transform.
if (lastNonToolUseUserMessageIndex <= 0) {
return messages;
}

const squashedMessages: Array<ConversationMessage> = [];

// Define a "buffer". I.e. a message we keep around and squash content on.
let currentSquashedMessage: ConversationMessage | undefined = undefined;
// Process messages before current turn begins
// Remove tool usage blocks.
// Combine content for consecutive message that have same role.
for (let i = 0; i < lastNonToolUseUserMessageIndex; i++) {
const currentMessage = messages[i];
const currentMessageRole = currentMessage.role;
const currentMessageNonToolContent = currentMessage.content.filter(
isNonToolBlockPredicate
);
if (currentMessageNonToolContent.length === 0) {
// Tool only message. Nothing to squash, skip;
continue;
}

if (!currentSquashedMessage) {
// Nothing squashed yet, initialize the buffer.
currentSquashedMessage = {
role: currentMessageRole,
content: currentMessageNonToolContent,
};
} else if (currentSquashedMessage.role === currentMessageRole) {
// if role is same append content.
currentSquashedMessage.content.push(...currentMessageNonToolContent);
} else {
// if role flips push current squashed message and re-initialize the buffer.
squashedMessages.push(currentSquashedMessage);
currentSquashedMessage = {
role: currentMessageRole,
content: currentMessageNonToolContent,
};
}
}
// flush the last buffer.
if (currentSquashedMessage) {
squashedMessages.push(currentSquashedMessage);
}

// Append current turn as is.
squashedMessages.push(...messages.slice(lastNonToolUseUserMessageIndex));
return squashedMessages;
};

private getCurrentMessage =
Expand Down

0 comments on commit 613bca9

Please sign in to comment.