✨ feat: Add Ai21Labs model support in Bedrock #3808

Closed · wants to merge 4 commits
4 changes: 1 addition & 3 deletions src/config/modelProviders/bedrock.ts
@@ -250,15 +250,13 @@ const Bedrock: ModelProviderCard = {
tokens: 4000,
},
*/
/*
// TODO: Not support for now
{
description: 'The latest Foundation Model from AI21 Labs, Jamba-Instruct offers an impressive 256K context window and delivers the best value per price on core text generation, summarization, and question answering tasks for the enterprise.',
displayName: 'Jamba-Instruct',
enabled: true,
id: 'ai21.jamba-instruct-v1:0',
tokens: 256_000,
},
*/
/*
// Cohere Command (Text) and AI21 Labs Jurassic-2 (Text) don't support chat with the Converse API
{
54 changes: 54 additions & 0 deletions src/libs/agent-runtime/bedrock/index.ts
@@ -12,6 +12,7 @@ import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
import {
AWSBedrockAi21Stream,
AWSBedrockClaudeStream,
AWSBedrockLlamaStream,
createBedrockStream,
@@ -46,11 +47,64 @@ export class LobeBedrockAI implements LobeRuntimeAI {
}

async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
if (payload.model.startsWith('ai21')) return this.invokeAi21Model(payload, options);

if (payload.model.startsWith('meta')) return this.invokeLlamaModel(payload, options);

return this.invokeClaudeModel(payload, options);
}

private invokeAi21Model = async (
payload: ChatStreamPayload,
options?: ChatCompetitionOptions,
): Promise<Response> => {
const { frequency_penalty, max_tokens, messages, model, presence_penalty, temperature, top_p } = payload;
const command = new InvokeModelWithResponseStreamCommand({
accept: 'application/json',
body: JSON.stringify({
frequency_penalty,
max_tokens: max_tokens || 4096,
messages,
presence_penalty,
temperature,
top_p,
}),
contentType: 'application/json',
modelId: model,
});

try {
// Ask the AI21 (Jamba) model for a streaming chat completion given the prompt
const res = await this.client.send(command);

const stream = createBedrockStream(res);

const [prod, debug] = stream.tee();

if (process.env.DEBUG_BEDROCK_CHAT_COMPLETION === '1') {
debugStream(debug).catch(console.error);
}
// Respond with the stream
return StreamingResponse(AWSBedrockAi21Stream(prod, options?.callback), {
headers: options?.headers,
});
} catch (e) {
const err = e as Error & { $metadata: any };

throw AgentRuntimeError.chat({
error: {
body: err.$metadata,
message: err.message,
region: this.region,
type: err.name,
},
errorType: AgentRuntimeErrorType.ProviderBizError,
provider: ModelProvider.Bedrock,
region: this.region,
});
}
};

private invokeClaudeModel = async (
payload: ChatStreamPayload,
options?: ChatCompetitionOptions,
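For reviewers who want to smoke-test the new routing locally, here is a minimal sketch. The constructor option names and required payload fields are assumptions based on the existing Bedrock runtime (the exact credential options may differ); the model id is the Jamba entry enabled in bedrock.ts above.

import { LobeBedrockAI } from '@/libs/agent-runtime/bedrock';

// Any model id starting with 'ai21' is now dispatched to invokeAi21Model,
// which sends an OpenAI-style JSON body via InvokeModelWithResponseStreamCommand.
const runtime = new LobeBedrockAI({
  accessKeyId: process.env.AWS_ACCESS_KEY_ID!, // assumed option names
  region: 'us-east-1',
  secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY!,
});

const response = await runtime.chat({
  messages: [{ content: 'Say hello in one sentence.', role: 'user' }],
  model: 'ai21.jamba-instruct-v1:0',
  temperature: 0.7,
});
// `response` is a streaming Response whose body carries the SSE frames
// produced by AWSBedrockAi21Stream.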
191 changes: 191 additions & 0 deletions src/libs/agent-runtime/utils/streams/bedrock/ai21.test.ts
@@ -0,0 +1,191 @@
import { describe, expect, it, vi } from 'vitest';
import * as uuidModule from '@/utils/uuid';
import { transformAi21Stream, AWSBedrockAi21Stream } from './ai21';

// Define the BedrockAi21StreamChunk type in the test file
interface BedrockAi21StreamChunk {
'amazon-bedrock-invocationMetrics'?: {
inputTokenCount: number;
outputTokenCount: number;
invocationLatency: number;
firstByteLatency: number;
};
id?: string;
choices: {
index?: number;
delta: {
content: string;
};
finish_reason?: string | null;
stop_reason?: string | null;
}[];
usage?: {
prompt_tokens: number;
total_tokens: number;
completion_tokens: number;
};
meta?: {
requestDurationMillis: number;
};
}

describe('AI21 Stream', () => {
describe('transformAi21Stream', () => {
it('should transform text response chunks', () => {
const chunk: BedrockAi21StreamChunk = {
id: "chat-ae86a1e555f04e5cbddb86cc6a98ce5e",
choices: [{
index: 0,
delta: { content: "Hello world!" }
}]
};
const stack = { id: 'chat_test-id' };

const result = transformAi21Stream(chunk, stack);

expect(result).toEqual({
data: "Hello world!",
id: 'chat_test-id',
type: 'text'
});
});

it('should handle stop reason with content', () => {
const chunk: BedrockAi21StreamChunk = {
id: "chat-ae86a1e555f04e5cbddb86cc6a98ce5e",
choices: [{
index: 0,
delta: { content: "Final words." },
finish_reason: "stop",
stop_reason: "<|eom|>"
}]
};
const stack = { id: 'chat_test-id' };

const result = transformAi21Stream(chunk, stack);

expect(result).toEqual({
data: "Final words.",
id: 'chat_test-id',
type: 'text'
});
});

it('should handle empty content', () => {
const chunk: BedrockAi21StreamChunk = {
id: "chat-ae86a1e555f04e5cbddb86cc6a98ce5e",
choices: [{
index: 0,
delta: { content: "" }
}]
};
const stack = { id: 'chat_test-id' };

const result = transformAi21Stream(chunk, stack);

expect(result).toEqual({
data: "",
id: 'chat_test-id',
type: 'text'
});
});

it('should remove amazon-bedrock-invocationMetrics', () => {
const chunk: BedrockAi21StreamChunk = {
id: "chat-ae86a1e555f04e5cbddb86cc6a98ce5e",
choices: [{
index: 0,
delta: { content: "Hello" }
}],
"amazon-bedrock-invocationMetrics": {
inputTokenCount: 63,
outputTokenCount: 263,
invocationLatency: 5330,
firstByteLatency: 122
}
};
const stack = { id: 'chat_test-id' };

const result = transformAi21Stream(chunk, stack);

expect(result).toEqual({
data: "Hello",
id: 'chat_test-id',
type: 'text'
});
expect(chunk['amazon-bedrock-invocationMetrics']).toBeUndefined();
});
});

describe('AWSBedrockAi21Stream', () => {
it('should transform Bedrock AI21 stream to protocol stream', async () => {
vi.spyOn(uuidModule, 'nanoid').mockReturnValueOnce('test-id');
const mockBedrockStream = new ReadableStream({
start(controller) {
controller.enqueue({
id: "chat-ae86a1e555f04e5cbddb86cc6a98ce5e",
choices: [{
index: 0,
delta: { content: "Hello" }
}]
});
controller.enqueue({
id: "chat-ae86a1e555f04e5cbddb86cc6a98ce5e",
choices: [{
index: 0,
delta: { content: " world!" }
}]
});
controller.enqueue({
id: "chat-ae86a1e555f04e5cbddb86cc6a98ce5e",
choices: [{
index: 0,
delta: { content: " Final words." },
finish_reason: "stop",
stop_reason: "<|eom|>"
}]
});
controller.close();
},
});

const onStartMock = vi.fn();
const onTextMock = vi.fn();
const onTokenMock = vi.fn();
const onCompletionMock = vi.fn();

const protocolStream = AWSBedrockAi21Stream(mockBedrockStream, {
onStart: onStartMock,
onText: onTextMock,
onToken: onTokenMock,
onCompletion: onCompletionMock,
});

const decoder = new TextDecoder();
const chunks: string[] = [];

for await (const chunk of protocolStream as unknown as AsyncIterable<Uint8Array>) {
chunks.push(decoder.decode(chunk, { stream: true }));
}

expect(chunks).toEqual([
'id: chat_test-id\n',
'event: text\n',
'data: "Hello"\n\n',
'id: chat_test-id\n',
'event: text\n',
'data: " world!"\n\n',
'id: chat_test-id\n',
'event: text\n',
'data: " Final words."\n\n',
]);

expect(onStartMock).toHaveBeenCalledTimes(1);
expect(onTextMock).toHaveBeenNthCalledWith(1, '"Hello"');
expect(onTextMock).toHaveBeenNthCalledWith(2, '" world!"');
expect(onTextMock).toHaveBeenNthCalledWith(3, '" Final words."');
expect(onTokenMock).toHaveBeenCalledTimes(3);
expect(onCompletionMock).toHaveBeenCalledTimes(1);
});
});
});
83 changes: 83 additions & 0 deletions src/libs/agent-runtime/utils/streams/bedrock/ai21.ts
@@ -0,0 +1,83 @@
import { InvokeModelWithResponseStreamResponse } from '@aws-sdk/client-bedrock-runtime';

import { nanoid } from '@/utils/uuid';

import { ChatStreamCallbacks } from '../../../types';
import {
StreamProtocolChunk,
StreamStack,
createCallbacksTransformer,
createSSEProtocolTransformer,
} from '../protocol';
import { createBedrockStream } from './common';

interface AmazonBedrockInvocationMetrics {
firstByteLatency: number;
inputTokenCount: number;
invocationLatency: number;
outputTokenCount: number;
}

// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jamba.html
// ai21_chunk: {"id":"chat-ae86a1e555f04e5cbddb86cc6a98ce5e","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":"stop","stop_reason":"<|eom|>"}],"usage":{"prompt_tokens":144,"total_tokens":158,"completion_tokens":14},"meta":{"requestDurationMillis":146},"amazon-bedrock-invocationMetrics":{"inputTokenCount":63,"outputTokenCount":263,"invocationLatency":5330,"firstByteLatency":122}}
interface BedrockAi21StreamChunk {
'amazon-bedrock-invocationMetrics'?: AmazonBedrockInvocationMetrics;
'choices': {
'delta': {
'content': string;
};
'finish_reason'?: null | 'stop' | string;
'index'?: number;
'stop_reason'?: null | string;
}[];
'id'?: string;
'meta'?: {
'requestDurationMillis': number;
};
'usage'?: {
'completion_tokens': number;
'prompt_tokens': number;
'total_tokens': number;
};
}

export const transformAi21Stream = (
chunk: BedrockAi21StreamChunk,
stack: StreamStack,
): StreamProtocolChunk => {
// remove 'amazon-bedrock-invocationMetrics' from chunk
delete chunk['amazon-bedrock-invocationMetrics'];

if (!chunk.choices || chunk.choices.length === 0) {
return { data: chunk, id: stack.id, type: 'data' };
}

const item = chunk.choices[0];

if (typeof item.delta?.content === 'string') {
return { data: item.delta.content, id: stack.id, type: 'text' };
}

if (item.finish_reason) {
return { data: item.finish_reason, id: stack.id, type: 'stop' };
}

return {
data: { delta: item.delta, id: stack.id, index: item.index },
id: stack.id,
type: 'data',
};
};

export const AWSBedrockAi21Stream = (
res: InvokeModelWithResponseStreamResponse | ReadableStream,
cb?: ChatStreamCallbacks,
): ReadableStream<string> => {
const streamStack: StreamStack = { id: 'chat_' + nanoid() };

const stream = res instanceof ReadableStream ? res : createBedrockStream(res);

return stream
.pipeThrough(createSSEProtocolTransformer(transformAi21Stream, streamStack))
.pipeThrough(createCallbacksTransformer(cb));
};
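As a reference when reading the tests above (not part of the diff): createSSEProtocolTransformer serializes each transformed chunk into an id/event/data frame, so a single Jamba text delta comes out on the wire roughly as:

id: chat_<nanoid>
event: text
data: "Hello"

A finish_reason arriving in a chunk without content would instead produce an `event: stop` frame carrying the reason string; when content and finish_reason arrive together, the content wins, as the "stop reason with content" test asserts.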
1 change: 1 addition & 0 deletions src/libs/agent-runtime/utils/streams/bedrock/index.ts
@@ -1,3 +1,4 @@
export * from './ai21';
export * from './claude';
export * from './common';
export * from './llama';