From 0351a865b189d23a8e5c18c5591adc129499d0a5 Mon Sep 17 00:00:00 2001 From: Gabriel Massadas Date: Thu, 16 May 2024 15:42:23 +0100 Subject: [PATCH] Add support for setting gatewayId in AI binding --- src/cloudflare/internal/ai-api.ts | 10 +++++---- .../internal/test/ai/ai-api-test.js | 21 +++++++++++++++++++ src/cloudflare/internal/test/ai/ai-mock.js | 6 ++++++ types/defines/ai.d.ts | 1 + 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/cloudflare/internal/ai-api.ts b/src/cloudflare/internal/ai-api.ts index 3541358c257..a83b3fd72f4 100644 --- a/src/cloudflare/internal/ai-api.ts +++ b/src/cloudflare/internal/ai-api.ts @@ -18,7 +18,8 @@ export type SessionOptions = { // Deprecated, do not use this } export type AiOptions = { - debug?: boolean; + gatewayId?: string; + prefix?: string; extraHeaders?: object; /* @@ -61,11 +62,12 @@ export class Ai { this.options = options; this.lastRequestId = ""; + // This removes some unwanted options from getting sent in the body + const cleanedOptions = (({ prefix, extraHeaders, sessionOptions, ...object }): object => object)(this.options || {}); + const body = JSON.stringify({ inputs, - options: { - debug: this.options?.debug, - }, + options: cleanedOptions }); const fetchOptions = { diff --git a/src/cloudflare/internal/test/ai/ai-api-test.js b/src/cloudflare/internal/test/ai/ai-api-test.js index 9d6bcd37b77..65d4a92b810 100644 --- a/src/cloudflare/internal/test/ai/ai-api-test.js +++ b/src/cloudflare/internal/test/ai/ai-api-test.js @@ -66,5 +66,26 @@ export const tests = { assert.equal(err.name, 'InferenceUpstreamError') assert.equal(err.message, 'Unknown error') } + + { + // Test raw input + const resp = await env.ai.run('rawInputs', {prompt: 'test'}) + + assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {} }); + } + + { + // Test gateway option + const resp = await env.ai.run('rawInputs', {prompt: 'test'}, {gatewayId: 'my-gateway'}) + + assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {gatewayId: 'my-gateway'} }); + } + + { + // Test unwanted options not getting sent upstream + const resp = await env.ai.run('rawInputs', {prompt: 'test'}, {extraHeaders: 'test', example: 123, gatewayId: 'my-gateway'}) + + assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {example: 123, gatewayId: 'my-gateway'} }); + } }, } diff --git a/src/cloudflare/internal/test/ai/ai-mock.js b/src/cloudflare/internal/test/ai/ai-mock.js index 7b7e1cee197..592e0389e29 100644 --- a/src/cloudflare/internal/test/ai/ai-mock.js +++ b/src/cloudflare/internal/test/ai/ai-mock.js @@ -21,6 +21,12 @@ export default { }) } + if (modelName === 'rawInputs') { + return Response.json(data, { + headers: respHeaders + }) + } + if (modelName === 'inputErrorModel') { return Response.json({ internalCode: 1001, diff --git a/types/defines/ai.d.ts b/types/defines/ai.d.ts index 73d9df516a4..26d242de3e3 100644 --- a/types/defines/ai.d.ts +++ b/types/defines/ai.d.ts @@ -139,6 +139,7 @@ export declare abstract class BaseAiTranslation { postProcessedOutputs: AiTranslationOutput; } export type AiOptions = { + gatewayId?: string; prefix?: string; extraHeaders?: object; };