CallChatWithFunctions - functions support, incl. OpenAI Implementation
May be rough around the edges, but should not create issues.
The implementation is defensive and extensively validates the return types, as the OpenAI API is brittle and can easily misbehave.
enricoros committed Jun 28, 2023
1 parent 87d9309 commit 2d4c0e9
Showing 8 changed files with 193 additions and 74 deletions.
4 changes: 2 additions & 2 deletions pages/api/openai/stream-chat.ts
@@ -1,7 +1,7 @@
import { NextRequest, NextResponse } from 'next/server';
import { createParser } from 'eventsource-parser';

import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAICompletionRequest } from '~/modules/llms/openai/openai.router';
import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAIChatCompletionRequest } from '~/modules/llms/openai/openai.router';
import { OpenAI } from '~/modules/llms/openai/openai.types';


@@ -31,7 +31,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C

// prepare request objects
const { headers, url } = openAIAccess(access, '/v1/chat/completions');
const body: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, true);
const body: OpenAI.Wire.ChatCompletion.Request = openAIChatCompletionRequest(model, history, null, true);

// perform the request
upstreamResponse = await fetch(url, { headers, method: 'POST', body: JSON.stringify(body), signal });
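For context, a minimal sketch of what the renamed helper builds for this streaming path; the model id and message are placeholders, and the model/history shapes are assumed from how the router uses them (they are only partially visible in this diff):

import { openAIChatCompletionRequest } from '~/modules/llms/openai/openai.router';

// Sketch: streaming request without functions — the `null` third argument
// simply omits the `functions` / `function_call` fields from the wire body.
const body = openAIChatCompletionRequest(
  { id: 'gpt-3.5-turbo-0613', temperature: 0.5 },   // placeholder model
  [{ role: 'user', content: 'Hello' }],             // placeholder history
  null,                                             // no functions on the streaming path
  true,                                             // stream
);
// body ≈ { model: 'gpt-3.5-turbo-0613', messages: [...], temperature: 0.5, stream: true, n: 1 }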
50 changes: 41 additions & 9 deletions src/modules/llms/llm.client.ts
@@ -1,17 +1,49 @@
import { DLLMId } from '~/modules/llms/llm.types';
import { findVendorById } from '~/modules/llms/vendor.registry';
import { useModelsStore } from '~/modules/llms/store-llms';

import { DLLM, DLLMId } from './llm.types';
import { OpenAI } from './openai/openai.types';
import { findVendorById } from './vendor.registry';
import { useModelsStore } from './store-llms';


export type ModelVendorCallChatFn = (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) => Promise<VChatMessageOut>;
export type ModelVendorCallChatWithFunctionsFn = (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) => Promise<VChatMessageOrFunctionCallOut>;

export interface VChatMessageIn {
role: 'assistant' | 'system' | 'user'; // | 'function';
content: string;
//name?: string; // when role: 'function'
}

export type VChatFunctionIn = OpenAI.Wire.ChatCompletion.RequestFunctionDef;

export interface VChatMessageOut {
role: 'assistant' | 'system' | 'user';
content: string;
finish_reason: 'stop' | 'length' | null;
}

export interface VChatFunctionCallOut {
function_name: string;
function_arguments: object | null;
}

export type VChatMessageOrFunctionCallOut = VChatMessageOut | VChatFunctionCallOut;

export async function callChatGenerate(llmId: DLLMId, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {

// get the vendor

export async function callChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise<VChatMessageOut> {
const { llm, vendor } = getLLMAndVendorOrThrow(llmId);
return await vendor.callChat(llm, messages, maxTokens);
}

export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number): Promise<VChatMessageOrFunctionCallOut> {
const { llm, vendor } = getLLMAndVendorOrThrow(llmId);
return await vendor.callChatWithFunctions(llm, messages, functions, maxTokens);
}


function getLLMAndVendorOrThrow(llmId: string) {
const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId);
const vendor = findVendorById(llm?._source.vId);
if (!llm || !vendor) throw new Error(`callChat: Vendor not found for LLM ${llmId}`);

// go for it
return await vendor.callChat(llm, messages, maxTokens);
return { llm, vendor };
}
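A usage sketch of the new entry point above, assuming a hypothetical get_weather function definition (its parameters follow the usual OpenAI JSON-schema convention, which this diff only partially shows) and an llmId taken from the models store; the 'function_name' in result check discriminates the VChatMessageOrFunctionCallOut union:

import { DLLMId } from '~/modules/llms/llm.types';
import { callChatGenerateWithFunctions, VChatFunctionIn } from '~/modules/llms/llm.client';

// Hypothetical function definition — name, description and parameters are illustrative.
const getWeather: VChatFunctionIn = {
  name: 'get_weather',
  description: 'Get the current weather for a city',
  parameters: {
    type: 'object',
    properties: { city: { type: 'string', description: 'City name' } },
    required: ['city'],
  },
};

async function askWeather(llmId: DLLMId) {
  const result = await callChatGenerateWithFunctions(
    llmId,
    [{ role: 'user', content: 'What is the weather in Paris?' }],
    [getWeather],
  );
  // Narrow the union: a function call carries a name and parsed arguments, not content.
  if ('function_name' in result)
    console.log('function call:', result.function_name, result.function_arguments);
  else
    console.log('assistant message:', result.content, result.finish_reason);
}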
8 changes: 3 additions & 5 deletions src/modules/llms/llm.types.ts
@@ -1,12 +1,11 @@
import type React from 'react';

import type { LLMOptionsOpenAI, SourceSetupOpenAI } from './openai/openai.vendor';
import type { OpenAI } from './openai/openai.types';
import type { ModelVendorCallChatFn, ModelVendorCallChatWithFunctionsFn } from './llm.client';
import type { SourceSetupLocalAI } from './localai/localai.vendor';


export type DLLMId = string;
// export type DLLMTags = 'stream' | 'chat';
export type DLLMOptions = LLMOptionsOpenAI; //DLLMValuesOpenAI | DLLMVaLocalAIDLLMValues;
export type DModelSourceId = string;
export type DModelSourceSetup = SourceSetupOpenAI | SourceSetupLocalAI;
@@ -60,6 +59,5 @@ export interface ModelVendor {

// functions
callChat: ModelVendorCallChatFn;
}

type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number) => Promise<OpenAI.API.Chat.Response>;
callChatWithFunctions: ModelVendorCallChatWithFunctionsFn;
}
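With callChatWithFunctions now required on the interface, every vendor has to provide both entry points. A sketch of how the OpenAI vendor is presumably wired (openai.vendor.tsx is not part of this diff, so this is an assumption; Partial<> is used only to keep the sketch short), re-using the client functions shown below in openai.client.ts:

import type { ModelVendor } from '../llm.types';
import { callChat, callChatWithFunctions } from './openai.client';

// Assumed wiring — the real vendor object also carries its id, name and setup components.
export const ModelVendorOpenAI: Partial<ModelVendor> = {
  // functions
  callChat,
  callChatWithFunctions,
};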
3 changes: 2 additions & 1 deletion src/modules/llms/localai/localai.vendor.tsx
@@ -17,7 +17,8 @@ export const ModelVendorLocalAI: ModelVendor = {
LLMOptionsComponent: () => <>No LocalAI Options</>,

// functions
callChat: () => Promise.reject(new Error('LocalAI is not implemented')),
callChat: () => Promise.reject(new Error('LocalAI chat is not implemented')),
callChatWithFunctions: () => Promise.reject(new Error('LocalAI chatWithFunctions is not implemented')),
};


27 changes: 19 additions & 8 deletions src/modules/llms/openai/openai.client.ts
@@ -1,7 +1,7 @@
import { apiAsync } from '~/modules/trpc/trpc.client';

import { DLLM } from '../llm.types';
import { OpenAI } from './openai.types';
import type { DLLM } from '../llm.types';
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../llm.client';
import { normalizeOAISetup, SourceSetupOpenAI } from './openai.vendor';


@@ -10,25 +10,36 @@ export const hasServerKeyOpenAI = !!process.env.HAS_SERVER_KEY_OPENAI;
export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.startsWith('sk-') && apiKey.length > 40;


export const callChat = async (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) =>
callChatOverloaded<VChatMessageOut>(llm, messages, null, maxTokens);

export const callChatWithFunctions = async (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) =>
callChatOverloaded<VChatMessageOrFunctionCallOut>(llm, messages, functions, maxTokens);


/**
* This function either returns the LLM response, or throws a descriptive error string
* This function either returns the LLM message, or function calls, or throws a descriptive error string
*/
export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
async function callChatOverloaded<TOut extends VChatMessageOrFunctionCallOut>(llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, maxTokens?: number): Promise<TOut> {
// access params (source)
const partialSetup = llm._source.setup as Partial<SourceSetupOpenAI>;
const sourceSetupOpenAI = normalizeOAISetup(partialSetup);

// model params (llm)
const openaiLlmRef = llm.options.llmRef!;
const modelTemperature = llm.options.llmTemperature || 0.5;
// const maxTokens = llm.options.llmResponseTokens || 1024; // <- note: this would be for chat answers, not programmatic chat calls

try {
return await apiAsync.openai.chatGenerate.mutate({
return await apiAsync.openai.chatGenerateWithFunctions.mutate({
access: sourceSetupOpenAI,
model: { id: openaiLlmRef, temperature: modelTemperature, ...(maxTokens && { maxTokens }) },
model: {
id: openaiLlmRef,
temperature: modelTemperature,
...(maxTokens && { maxTokens }),
},
functions: functions ?? undefined,
history: messages,
});
}) as TOut;
// errorMessage = `issue fetching: ${response.status} · ${response.statusText}${errorPayload ? ' · ' + JSON.stringify(errorPayload) : ''}`;
} catch (error: any) {
const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Fetch Error';
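For reference, a rough sketch of the mutate input assembled above; it has to satisfy chatGenerateSchema from openai.router.ts (key, model id and message are placeholders, and the access fields not shown in this diff stay elided here too):

// Hypothetical chatGenerateWithFunctions input — shape only, values are placeholders.
const exampleInput = {
  access: { oaiKey: 'sk-...', oaiOrg: '' /* ...remaining access fields elided in this diff */ },
  model: { id: 'gpt-4-0613', temperature: 0.5, maxTokens: 256 },
  functions: undefined,                                   // present only for function-calling requests
  history: [{ role: 'user' as const, content: 'Hello' }],
};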
135 changes: 106 additions & 29 deletions src/modules/llms/openai/openai.router.ts
@@ -10,6 +10,8 @@ import { OpenAI } from './openai.types';
// console.warn('OPENAI_API_KEY has not been provided in this deployment environment. Will need client-supplied keys, which is not recommended.');


// Input Schemas

const accessSchema = z.object({
oaiKey: z.string().trim(),
oaiOrg: z.string().trim(),
@@ -29,7 +31,7 @@ const historySchema = z.array(z.object({
content: z.string(),
}));

/*const functionsSchema = z.array(z.object({
const functionsSchema = z.array(z.object({
name: z.string(),
description: z.string().optional(),
parameters: z.object({
@@ -41,46 +43,59 @@ const historySchema = z.array(z.object({
})),
required: z.array(z.string()).optional(),
}).optional(),
}));*/
}));

export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema });
export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema, functions: functionsSchema.optional() });
export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;

export const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });
const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });


// Output Schemas

const chatGenerateWithFunctionsOutputSchema = z.union([
z.object({
role: z.enum(['assistant', 'system', 'user']),
content: z.string(),
finish_reason: z.union([z.enum(['stop', 'length']), z.null()]),
}),
z.object({
function_name: z.string(),
function_arguments: z.record(z.any()),
}),
]);




export const openAIRouter = createTRPCRouter({

/**
* Chat-based message generation
*/
chatGenerate: publicProcedure
chatGenerateWithFunctions: publicProcedure
.input(chatGenerateSchema)
.mutation(async ({ input }): Promise<OpenAI.API.Chat.Response> => {

const { access, model, history } = input;
const requestBody: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, false);
let wireCompletions: OpenAI.Wire.ChatCompletion.Response;

// try {
wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Request, OpenAI.Wire.ChatCompletion.Response>(access, requestBody, '/v1/chat/completions');
// } catch (error: any) {
// // NOTE: disabled on 2023-06-19: show all errors, 429 is not that common now, and could explain issues
// // don't log 429 errors on the server-side, they are expected
// if (!error || !(typeof error.startsWith === 'function') || !error.startsWith('Error: 429 · Too Many Requests'))
// console.error('api/openai/chat error:', error);
// throw error;
// }
.output(chatGenerateWithFunctionsOutputSchema)
.mutation(async ({ input }) => {

const { access, model, history, functions } = input;
const isFunctionsCall = !!functions && functions.length > 0;

const wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Request, OpenAI.Wire.ChatCompletion.Response>(
access,
openAIChatCompletionRequest(model, history, isFunctionsCall ? functions : null, false),
'/v1/chat/completions',
);

// expect a single output
if (wireCompletions?.choices?.length !== 1)
throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}` });
const { message, finish_reason } = wireCompletions.choices[0];

const singleChoice = wireCompletions.choices[0];
return {
role: singleChoice.message.role,
content: singleChoice.message.content,
finish_reason: singleChoice.finish_reason,
};
// check for a function output
return finish_reason === 'function_call'
? parseChatGenerateFCOutput(isFunctionsCall, message as OpenAI.Wire.ChatCompletion.ResponseFunctionCall)
: parseChatGenerateOutput(message as OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason);
}),

/**
@@ -147,6 +162,7 @@ export const openAIRouter = createTRPCRouter({
type AccessSchema = z.infer<typeof accessSchema>;
type ModelSchema = z.infer<typeof modelSchema>;
type HistorySchema = z.infer<typeof historySchema>;
type FunctionsSchema = z.infer<typeof functionsSchema>;

async function openaiGET<TOut>(access: AccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
const { headers, url } = openAIAccess(access, apiPath);
@@ -171,7 +187,11 @@ async function openaiPOST<TBody, TOut>(access: AccessSchema, body: TBody, apiPat
: `[Issue] ${response.statusText}`,
});
}
return await response.json() as TOut;
try {
return await response.json();
} catch (error: any) {
throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] ${error?.message || error}` });
}
}

export function openAIAccess(access: AccessSchema, apiPath: string): { headers: HeadersInit, url: string } {
@@ -203,14 +223,71 @@ export function openAIAccess(access: AccessSchema, apiPath: string): { headers:
};
}

export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
export function openAIChatCompletionRequest(model: ModelSchema, history: HistorySchema, functions: FunctionsSchema | null, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
return {
model: model.id,
messages: history,
// ...(functions && { functions: functions, function_call: 'auto', }),
...(functions && { functions: functions, function_call: 'auto' }),
...(model.temperature && { temperature: model.temperature }),
...(model.maxTokens && { max_tokens: model.maxTokens }),
stream,
n: 1,
};
}

function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAI.Wire.ChatCompletion.ResponseFunctionCall) {
// NOTE: Defensive: we run extensive validation because the API is not well tested and documented at the moment
if (!isFunctionsCall)
throw new TRPCError({
code: 'INTERNAL_SERVER_ERROR',
message: `[OpenAI Issue] Received a function call without a function call request`,
});

// parse the function call
const fcMessage = message as any as OpenAI.Wire.ChatCompletion.ResponseFunctionCall;
if (fcMessage.content !== null)
throw new TRPCError({
code: 'INTERNAL_SERVER_ERROR',
message: `[OpenAI Issue] Expected a function call, got a message`,
});

// got a function call, so parse it
const fc = fcMessage.function_call;
if (!fc || !fc.name || !fc.arguments)
throw new TRPCError({
code: 'INTERNAL_SERVER_ERROR',
message: `[OpenAI Issue] Issue with the function call, missing name or arguments`,
});

// decode the function call
const fcName = fc.name;
let fcArgs: object;
try {
fcArgs = JSON.parse(fc.arguments);
} catch (error: any) {
throw new TRPCError({
code: 'INTERNAL_SERVER_ERROR',
message: `[OpenAI Issue] Issue with the function call, arguments are not valid JSON`,
});
}

return {
function_name: fcName,
function_arguments: fcArgs,
};
}

function parseChatGenerateOutput(message: OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason: 'stop' | 'length' | null) {
// validate the message
if (message.content === null)
throw new TRPCError({
code: 'INTERNAL_SERVER_ERROR',
message: `[OpenAI Issue] Expected a message, got a null message`,
});

return {
role: message.role,
content: message.content,
finish_reason: finish_reason,
};
}
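To make the defensive parsing above concrete, a hypothetical wire-level message with a function call (finish_reason === 'function_call') and the value the router would then return; field values are illustrative:

// Hypothetical assistant message from /v1/chat/completions when the model calls a function:
const wireMessage = {
  role: 'assistant',
  content: null,                                            // must be null for a function call
  function_call: { name: 'get_weather', arguments: '{"city":"Paris"}' },
};

// parseChatGenerateFCOutput(true, wireMessage as any) would then produce:
const parsed = {
  function_name: 'get_weather',
  function_arguments: { city: 'Paris' },                    // JSON.parse of the arguments string
};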

1 comment on commit 2d4c0e9

@vercel vercel bot commented on 2d4c0e9 Jun 28, 2023

Successfully deployed to the following URLs:

big-agi – ./

big-agi-enricoros.vercel.app
get.big-agi.com
big-agi-git-main-enricoros.vercel.app
