Skip to content

Commit

Permalink
Add support for setting gatewayId in AI binding
Browse files Browse the repository at this point in the history
  • Loading branch information
G4brym committed May 16, 2024
1 parent ad4fc23 commit 3682634
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/cloudflare/internal/ai-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ export type SessionOptions = { // Deprecated, do not use this
}

export type AiOptions = {
debug?: boolean;
gatewayId?: string;

prefix?: string;
extraHeaders?: object;
/*
Expand Down Expand Up @@ -61,11 +62,12 @@ export class Ai {
this.options = options;
this.lastRequestId = "";

// This removes some unwanted options from getting sent in the body
const cleanedOptions = (({ prefix, extraHeaders, sessionOptions, ...object }) => object)(this.options || {});

const body = JSON.stringify({
inputs,
options: {
debug: this.options?.debug,
},
options: cleanedOptions
});

const fetchOptions = {
Expand Down
21 changes: 21 additions & 0 deletions src/cloudflare/internal/test/ai/ai-api-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,26 @@ export const tests = {
assert.equal(err.name, 'InferenceUpstreamError')
assert.equal(err.message, 'Unknown error')
}

{
// Test raw input
const resp = await env.ai.run('rawInputs', {prompt: 'test'})

assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {} });
}

{
// Test gateway option
const resp = await env.ai.run('rawInputs', {prompt: 'test'}, {gatewayId: 'my-gateway'})

assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {gatewayId: 'my-gateway'} });
}

{
// Test unwanted options not getting sent upstream
const resp = await env.ai.run('rawInputs', {prompt: 'test'}, {extraHeaders: 'test', prefix: 'another', example: 123, gatewayId: 'my-gateway'})

assert.deepStrictEqual(resp, { inputs: {prompt: 'test'}, options: {example: 123, gatewayId: 'my-gateway'} });
}
},
}
6 changes: 6 additions & 0 deletions src/cloudflare/internal/test/ai/ai-mock.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ export default {
})
}

if (modelName === 'rawInputs') {
return Response.json(data, {
headers: respHeaders
})
}

if (modelName === 'inputErrorModel') {
return Response.json({
internalCode: 1001,
Expand Down
1 change: 1 addition & 0 deletions types/defines/ai.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ export declare abstract class BaseAiTranslation {
postProcessedOutputs: AiTranslationOutput;
}
export type AiOptions = {
gatewayId?: string;
prefix?: string;
extraHeaders?: object;
};
Expand Down

0 comments on commit 3682634

Please sign in to comment.