feat: Added support for Anthropic Claude 3 messages API (#2278)
jsumners-nr authored Jun 14, 2024
1 parent 8f96c73 commit 7e3cab9
Showing 13 changed files with 315 additions and 6 deletions.
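For orientation, here is a minimal sketch of the kind of Bedrock call this instrumentation now recognizes. It is not part of the diff; the model id, region, and prompt are placeholder examples, and the Messages API body shape (anthropic_version, max_tokens, messages) is assumed from the test fixtures further down.

'use strict'
const { BedrockRuntimeClient, InvokeModelCommand } = require('@aws-sdk/client-bedrock-runtime')

async function main() {
  const client = new BedrockRuntimeClient({ region: 'us-east-1' })
  const response = await client.send(
    new InvokeModelCommand({
      modelId: 'anthropic.claude-3-haiku-20240307-v1:0', // example model id
      contentType: 'application/json',
      accept: 'application/json',
      // Claude 3 uses the Messages API body instead of a single prompt string.
      body: JSON.stringify({
        anthropic_version: 'bedrock-2023-05-31',
        max_tokens: 256,
        messages: [{ role: 'user', content: 'who are you' }]
      })
    })
  )
  console.log(JSON.parse(Buffer.from(response.body).toString('utf8')))
}

main().catch(console.error)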
ai-support.json: 4 additions & 0 deletions
@@ -15,6 +15,10 @@
{
"title": "Image",
"supported": false
},
{
"title": "Vision",
"supported": false
}
]
},
lib/llm-events/aws-bedrock/bedrock-command.js: 12 additions & 2 deletions
@@ -35,7 +35,7 @@ class BedrockCommand {
       result = this.#body.maxTokens
     } else if (this.isClaude() === true) {
       result = this.#body.max_tokens_to_sample
-    } else if (this.isCohere() === true) {
+    } else if (this.isClaude3() === true || this.isCohere() === true) {
       result = this.#body.max_tokens
     } else if (this.isLlama2() === true) {
       result = this.#body.max_gen_length
@@ -83,6 +83,11 @@ class BedrockCommand {
this.isLlama2() === true
) {
result = this.#body.prompt
} else if (this.isClaude3() === true) {
result = this.#body?.messages?.reduce((acc, curr) => {
acc += curr?.content ?? ''
return acc
}, '')
}
return result
}
@@ -96,6 +101,7 @@
result = this.#body.textGenerationConfig?.temperature
} else if (
this.isClaude() === true ||
this.isClaude3() === true ||
this.isAi21() === true ||
this.isCohere() === true ||
this.isLlama2() === true
@@ -110,7 +116,11 @@
   }
 
   isClaude() {
-    return this.#modelId.startsWith('anthropic.claude')
+    return this.#modelId.startsWith('anthropic.claude-v')
   }
 
+  isClaude3() {
+    return this.#modelId.startsWith('anthropic.claude-3')
+  }
+
   isCohere() {
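To make the prompt change above concrete, a small sketch of how the new getter flattens a Messages API body into a single prompt string. The body shape is assumed; content is a plain string here, as in the unit tests further down.

const body = {
  messages: [
    { role: 'user', content: 'who are you' },
    { role: 'assistant', content: ' and why do you ask' }
  ]
}

// Same reduce as the new isClaude3() branch of the prompt getter:
const prompt = body?.messages?.reduce((acc, curr) => {
  acc += curr?.content ?? ''
  return acc
}, '')

console.log(prompt) // 'who are you and why do you ask'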
lib/llm-events/aws-bedrock/bedrock-response.js: 12 additions & 1 deletion
@@ -30,10 +30,12 @@ class BedrockResponse {
#completions = []
#id

/* eslint-disable sonarjs/cognitive-complexity */
/**
* @param {object} params
* @param {AwsBedrockMiddlewareResponse} params.response
* @param {BedrockCommand} params.bedrockCommand
* @param params.isError
*/
constructor({ response, bedrockCommand, isError = false }) {
this.#innerResponse = isError ? response.$response : response.response
@@ -57,6 +59,14 @@
} else if (cmd.isClaude() === true) {
// TODO: can we make this thing give more than one completion?
body.completion && this.#completions.push(body.completion)
} else if (cmd.isClaude3() === true) {
if (body?.type === 'message_stop') {
// Streamed response
this.#completions = body.completions
} else {
this.#completions = body?.content?.map((c) => c.text)
}
this.#id = body.id
} else if (cmd.isCohere() === true) {
this.#completions = body.generations?.map((g) => g.text) ?? []
this.#id = body.id
@@ -66,6 +76,7 @@
this.#completions = body.results?.map((r) => r.outputText) ?? []
}
}
/* eslint-enable sonarjs/cognitive-complexity */

/**
* The prompt responses returned by the model.
@@ -92,7 +103,7 @@
     const cmd = this.#command
     if (cmd.isAi21() === true) {
       result = this.#parsedBody.completions?.[0]?.finishReason.reason
-    } else if (cmd.isClaude() === true) {
+    } else if (cmd.isClaude() === true || cmd.isClaude3() === true) {
       result = this.#parsedBody.stop_reason
     } else if (cmd.isCohere() === true) {
       result = this.#parsedBody.generations?.find((r) => r.finish_reason !== null)?.finish_reason
lib/llm-events/aws-bedrock/stream-handler.js: 28 additions & 0 deletions
@@ -105,6 +105,9 @@ class StreamHandler {
if (bedrockCommand.isClaude() === true) {
this.stopReasonKey = 'stop_reason'
this.generator = handleClaude
} else if (bedrockCommand.isClaude3() === true) {
this.stopReasonKey = 'stop_reason'
this.generator = handleClaude3
} else if (bedrockCommand.isCohere() === true) {
this.stopReasonKey = 'generations.0.finish_reason'
this.generator = handleCohere
@@ -207,6 +210,31 @@ async function* handleClaude() {
}
}

async function* handleClaude3() {
let currentBody = {}
let stopReason
const completions = []

try {
for await (const event of this.stream) {
yield event
const parsed = this.parseEvent(event)
this.updateHeaders(parsed)
currentBody = parsed
if (parsed.type === 'content_block_delta') {
completions.push(parsed.delta.text)
} else if (parsed.type === 'message_delta') {
stopReason = parsed.delta.stop_reason
}
}
} finally {
currentBody.completions = completions
currentBody.stop_reason = stopReason
this.response.output.body = currentBody
this.finish()
}
}
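In rough terms, and assuming the event shapes used by the claude3 test stub below, the generator accumulates text from content_block_delta events, records the stop reason from message_delta, and then decorates the final message_stop body so BedrockResponse can treat it like a non-streamed response. A hypothetical end state after consuming the streamed '42' fixture:

const finalBody = {
  type: 'message_stop',
  completions: ['42'],     // accumulated content_block_delta text
  stop_reason: 'endoftext' // captured from the message_delta event
  // (invocation metrics carried on the message_stop event omitted for brevity)
}
// BedrockResponse reads body.completions only when body.type === 'message_stop',
// which is how streamed and non-streamed Claude 3 responses converge.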

async function* handleCohere() {
let currentBody = {}
const generations = []
test/lib/aws-server-stubs/ai-server/index.js: 8 additions & 1 deletion
@@ -72,7 +72,7 @@ function handler(req, res) {
   // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
   const [, model] = /model\/(.+)\/invoke/.exec(req.url)
   let response
-  switch (model) {
+  switch (decodeURIComponent(model)) {
     case 'ai21.j2-mid-v1':
     case 'ai21.j2-ultra-v1': {
       response = responses.ai21.get(payload.prompt)
@@ -94,6 +94,13 @@
break
}

case 'anthropic.claude-3-haiku-20240307-v1:0':
case 'anthropic.claude-3-opus-20240229-v1:0':
case 'anthropic.claude-3-sonnet-20240229-v1:0': {
response = responses.claude3.get(payload?.messages?.[0]?.content)
break
}

case 'cohere.command-text-v14':
case 'cohere.command-light-text-v14': {
response = responses.cohere.get(payload.prompt)
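The switch now decodes the captured segment because the Claude 3 model ids end in a versioned suffix such as ':0', and a percent-encoded colon in the request path would no longer match the case labels. A quick illustration with a hypothetical request path:

const url = '/model/anthropic.claude-3-haiku-20240307-v1%3A0/invoke' // hypothetical path
const [, model] = /model\/(.+)\/invoke/.exec(url)

console.log(model)                     // 'anthropic.claude-3-haiku-20240307-v1%3A0'
console.log(decodeURIComponent(model)) // 'anthropic.claude-3-haiku-20240307-v1:0'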
test/lib/aws-server-stubs/ai-server/responses/claude3.js: 156 additions & 0 deletions (new file)
@@ -0,0 +1,156 @@
/*
* Copyright 2024 New Relic Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

'use strict'

const responses = new Map()
const { contentType, reqId } = require('./constants')

responses.set('text claude3 ultimate question', {
headers: {
'content-type': contentType,
'x-amzn-requestid': reqId,
'x-amzn-bedrock-invocation-latency': '926',
'x-amzn-bedrock-output-token-count': '36',
'x-amzn-bedrock-input-token-count': '14'
},
statusCode: 200,
body: {
id: 'msg_bdrk_019V7ABaw8ZZZYuRDSTWK7VE',
type: 'message',
role: 'assistant',
model: 'claude-3-haiku-20240307',
stop_sequence: null,
usage: { input_tokens: 30, output_tokens: 265 },
content: [
{
type: 'text',
text: '42'
}
],
stop_reason: 'endoftext'
}
})

responses.set('text claude3 ultimate question streamed', {
headers: {
'content-type': 'application/vnd.amazon.eventstream',
'x-amzn-requestid': reqId,
'x-amzn-bedrock-content-type': contentType
},
statusCode: 200,
// Please do not simplify the set of chunks. This set represents a minimal
// streaming response from the "Messages API". Such a stream is different from
// the other streamed responses, and we need an example of what a Messages API
// stream looks like.
chunks: [
{
body: {
type: 'message_start',
message: {
content: [],
id: 'msg_bdrk_sljfaofk',
model: 'claude-3-sonnet-20240229',
role: 'assistant',
stop_reason: null,
stop_sequence: null,
type: 'message',
usage: {
input_tokens: 30,
output_tokens: 1
}
}
},
headers: {
':event-type': { type: 'string', value: 'chunk' },
':content-type': { type: 'string', value: 'application/json' },
':message-type': { type: 'string', value: 'event' }
}
},
{
body: {
type: 'content_block_start',
index: 0,
content_block: { type: 'text', text: '' }
},
headers: {
':event-type': { type: 'string', value: 'chunk' },
':content-type': { type: 'string', value: 'application/json' },
':message-type': { type: 'string', value: 'event' }
}
},
{
body: {
type: 'content_block_delta',
index: 0,
delta: { type: 'text_delta', text: '42' }
},
headers: {
':event-type': { type: 'string', value: 'chunk' },
':content-type': { type: 'string', value: 'application/json' },
':message-type': { type: 'string', value: 'event' }
}
},
{
body: {
type: 'content_block_stop',
index: 0
},
headers: {
':event-type': { type: 'string', value: 'chunk' },
':content-type': { type: 'string', value: 'application/json' },
':message-type': { type: 'string', value: 'event' }
}
},
{
body: {
type: 'message_delta',
usage: { output_tokens: 1 },
delta: {
// The actual reason from the API will be `max_tokens` if the maximum
// allowed tokens have been reached. But our tests expect "endoftext".
stop_reason: 'endoftext',
stop_sequence: null
}
},
headers: {
':event-type': { type: 'string', value: 'chunk' },
':content-type': { type: 'string', value: 'application/json' },
':message-type': { type: 'string', value: 'event' }
}
},
{
body: {
type: 'message_stop',
['amazon-bedrock-invocationMetrics']: {
inputTokenCount: 8,
outputTokenCount: 4,
invocationLatency: 511,
firstByteLatency: 358
}
},
headers: {
':event-type': { type: 'string', value: 'chunk' },
':content-type': { type: 'string', value: 'application/json' },
':message-type': { type: 'string', value: 'event' }
}
}
]
})

responses.set('text claude3 ultimate question error', {
headers: {
'content-type': contentType,
'x-amzn-requestid': reqId,
'x-amzn-errortype': 'ValidationException:http://internal.amazon.com/coral/com.amazon.bedrock/'
},
statusCode: 400,
body: {
message:
'Malformed input request: 2 schema violations found, please reformat your input and try again.'
}
})

module.exports = responses
test/lib/aws-server-stubs/ai-server/responses/index.js: 2 additions & 0 deletions
@@ -8,13 +8,15 @@
const ai21 = require('./ai21')
const amazon = require('./amazon')
const claude = require('./claude')
const claude3 = require('./claude3')
const cohere = require('./cohere')
const llama2 = require('./llama2')

module.exports = {
ai21,
amazon,
claude,
claude3,
cohere,
llama2
}
test/unit/llm-events/aws-bedrock/bedrock-command.test.js: 33 additions & 0 deletions
@@ -23,6 +23,13 @@ const claude = {
}
}

const claude3 = {
modelId: 'anthropic.claude-3-haiku-20240307-v1:0',
body: {
messages: [{ content: 'who are you' }]
}
}

const cohere = {
modelId: 'cohere.command-text-v14',
body: {
@@ -75,6 +82,7 @@ tap.test('non-conforming command is handled gracefully', async (t) => {
for (const model of [
'Ai21',
'Claude',
'Claude3',
'Cohere',
'CohereEmbed',
'Llama2',
@@ -140,6 +148,31 @@ tap.test('claude complete command works', async (t) => {
t.equal(cmd.temperature, payload.body.temperature)
})

tap.test('claude3 minimal command works', async (t) => {
t.context.updatePayload(structuredClone(claude3))
const cmd = new BedrockCommand(t.context.input)
t.equal(cmd.isClaude3(), true)
t.equal(cmd.maxTokens, undefined)
t.equal(cmd.modelId, claude3.modelId)
t.equal(cmd.modelType, 'completion')
t.equal(cmd.prompt, claude3.body.messages[0].content)
t.equal(cmd.temperature, undefined)
})

tap.test('claude3 complete command works', async (t) => {
const payload = structuredClone(claude3)
payload.body.max_tokens = 25
payload.body.temperature = 0.5
t.context.updatePayload(payload)
const cmd = new BedrockCommand(t.context.input)
t.equal(cmd.isClaude3(), true)
t.equal(cmd.maxTokens, 25)
t.equal(cmd.modelId, payload.modelId)
t.equal(cmd.modelType, 'completion')
t.equal(cmd.prompt, payload.body.messages[0].content)
t.equal(cmd.temperature, payload.body.temperature)
})

tap.test('cohere minimal command works', async (t) => {
t.context.updatePayload(structuredClone(cohere))
const cmd = new BedrockCommand(t.context.input)
test/unit/llm-events/aws-bedrock/bedrock-response.test.js: 3 additions & 0 deletions
@@ -73,6 +73,9 @@ tap.beforeEach((t) => {
isClaude() {
return false
},
isClaude3() {
return false
},
isCohere() {
return false
},