diff --git a/src/cloudflare/internal/ai-api.ts b/src/cloudflare/internal/ai-api.ts index 3541358c257f..909978743f3f 100644 --- a/src/cloudflare/internal/ai-api.ts +++ b/src/cloudflare/internal/ai-api.ts @@ -18,7 +18,8 @@ export type SessionOptions = { // Deprecated, do not use this } export type AiOptions = { - debug?: boolean; + gatewayId?: string; + prefix?: string; extraHeaders?: object; /* @@ -61,11 +62,12 @@ export class Ai { this.options = options; this.lastRequestId = ""; + // This removes some unwanted options from getting sent in the body + const cleanedOptions = (({ prefix, extraHeaders, sessionOptions, ...object }) => object)(this.options || {}); + const body = JSON.stringify({ inputs, - options: { - debug: this.options?.debug, - }, + options: cleanedOptions }); const fetchOptions = { diff --git a/src/cloudflare/internal/test/ai/ai-api-test.js b/src/cloudflare/internal/test/ai/ai-api-test.js index 9d6bcd37b773..41e9fec4e7e2 100644 --- a/src/cloudflare/internal/test/ai/ai-api-test.js +++ b/src/cloudflare/internal/test/ai/ai-api-test.js @@ -66,5 +66,26 @@ export const tests = { assert.equal(err.name, 'InferenceUpstreamError') assert.equal(err.message, 'Unknown error') } + + { + // Test raw input + const resp = await env.ai.run('rawInputs', {prompt: 'test'}) + + assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {} }); + } + + { + // Test gateway option + const resp = await env.ai.run('rawInputs', {prompt: 'test'}, {gatewayId: 'my-gateway'}) + + assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {gatewayId: 'my-gateway'} }); + } + + { + // Test unwanted options not getting sent upstream + const resp = await env.ai.run('rawInputs', {prompt: 'test'}, {extraHeaders: 'test', prefix: 'another', example: 123, gatewayId: 'my-gateway'}) + + assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {example: 123, gatewayId: 'my-gateway'} }); + } }, } diff --git a/src/cloudflare/internal/test/ai/ai-mock.js b/src/cloudflare/internal/test/ai/ai-mock.js index 7b7e1cee197d..592e0389e29f 100644 --- a/src/cloudflare/internal/test/ai/ai-mock.js +++ b/src/cloudflare/internal/test/ai/ai-mock.js @@ -21,6 +21,12 @@ export default { }) } + if (modelName === 'rawInputs') { + return Response.json(data, { + headers: respHeaders + }) + } + if (modelName === 'inputErrorModel') { return Response.json({ internalCode: 1001, diff --git a/types/defines/ai.d.ts b/types/defines/ai.d.ts index 73d9df516a46..26d242de3e30 100644 --- a/types/defines/ai.d.ts +++ b/types/defines/ai.d.ts @@ -139,6 +139,7 @@ export declare abstract class BaseAiTranslation { postProcessedOutputs: AiTranslationOutput; } export type AiOptions = { + gatewayId?: string; prefix?: string; extraHeaders?: object; };