Merge pull request #171 from enricoros/llms-rework
Llms rework
enricoros authored Oct 12, 2023
2 parents 19361ac + 85e97e9 commit 09d38eb
Showing 28 changed files with 501 additions and 507 deletions.
6 changes: 4 additions & 2 deletions .env.example
@@ -4,8 +4,7 @@ OPENAI_API_KEY=
OPENAI_API_ORG_ID=
# [Optional] Set the backend host for the OpenAI API, to enable platforms such as Helicone (UI > this > api.openai.com)
OPENAI_API_HOST=

-# [Optional, Helicone] Helicone API key: https://www.helicone.ai/keys
+# [Optional, Helicone] Helicone API key: https://www.helicone.ai/keys - NOTE: only for OpenAI APIs at the moment
HELICONE_API_KEY=

# [Optional] Azure OpenAI Service credentials for the server-side (if set, both must be set)
@@ -16,6 +15,9 @@ AZURE_OPENAI_API_KEY=
ANTHROPIC_API_KEY=
ANTHROPIC_API_HOST=

+# [Optional] OpenRouter
+OPENROUTER_API_KEY=
+
# [Optional] Enables ElevenLabs credentials on the server side - for optional text-to-speech
ELEVENLABS_API_KEY=
ELEVENLABS_API_HOST=
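The optional host override above is what lets a proxy such as Helicone sit between this app and OpenAI. A minimal sketch of how such an override is typically consumed server-side (the helper name is illustrative; the repo's actual resolution logic is not part of this diff):

// sketch: resolve the OpenAI API host, letting a configured proxy take over
function resolveOpenAIHost(): string {
  // fall back to the official endpoint when OPENAI_API_HOST is unset or empty
  return process.env.OPENAI_API_HOST || 'https://api.openai.com';
}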
1 change: 1 addition & 0 deletions next.config.js
@@ -8,6 +8,7 @@ let nextConfig = {
HAS_SERVER_KEY_AZURE_OPENAI: !!process.env.AZURE_OPENAI_API_KEY && !!process.env.AZURE_OPENAI_API_ENDPOINT,
HAS_SERVER_KEY_ELEVENLABS: !!process.env.ELEVENLABS_API_KEY,
HAS_SERVER_KEY_OPENAI: !!process.env.OPENAI_API_KEY,
+HAS_SERVER_KEY_OPENROUTER: !!process.env.OPENROUTER_API_KEY,
HAS_SERVER_KEY_PRODIA: !!process.env.PRODIA_API_KEY,
},
webpack(config, { isServer, dev }) {
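These HAS_SERVER_KEY_* values are inlined into the client bundle at build time through the Next.js env block, so the UI can tell that a key exists server-side without the key itself ever reaching the browser. A minimal sketch of consuming the new flag (the exported constant is illustrative, not part of this diff):

// sketch: client-side gate derived from the build-time flag
export const hasServerKeyOpenRouter: boolean = !!process.env.HAS_SERVER_KEY_OPENROUTER;
// a setup panel could then hide the API-key input when the server already holds one:
//   {hasServerKeyOpenRouter ? <ServerKeyBadge /> : <ApiKeyInput />}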
49 changes: 27 additions & 22 deletions src/apps/models-modal/ModelsSourceSelector.tsx
@@ -1,7 +1,7 @@
import * as React from 'react';
import { shallow } from 'zustand/shallow';

-import { Box, Button, IconButton, ListItemDecorator, MenuItem, Option, Select, Typography } from '@mui/joy';
+import { Badge, Box, Button, IconButton, ListItemDecorator, MenuItem, Option, Select, Typography } from '@mui/joy';
import AddIcon from '@mui/icons-material/Add';
import CloudDoneOutlinedIcon from '@mui/icons-material/CloudDoneOutlined';
import CloudOutlinedIcon from '@mui/icons-material/CloudOutlined';
@@ -10,23 +10,25 @@ import DeleteOutlineIcon from '@mui/icons-material/DeleteOutline';

import { DModelSourceId, useModelsStore } from '~/modules/llms/store-llms';
import { IModelVendor, ModelVendorId } from '~/modules/llms/vendors/IModelVendor';
+import { ModelVendorOpenAI } from '~/modules/llms/vendors/openai/openai.vendor';
import { createModelSourceForVendor, findAllVendors, findVendorById } from '~/modules/llms/vendors/vendor.registry';
-import { hasServerKeyOpenAI } from '~/modules/llms/vendors/openai/openai.vendor';

import { CloseableMenu } from '~/common/components/CloseableMenu';
import { ConfirmationModal } from '~/common/components/ConfirmationModal';
import { hideOnDesktop, hideOnMobile } from '~/common/theme';


function locationIcon(vendor?: IModelVendor | null) {
-  if (vendor && vendor.id === 'openai' && hasServerKeyOpenAI)
+  if (vendor && vendor.id === 'openai' && ModelVendorOpenAI.hasServerKey)
return <CloudDoneOutlinedIcon />;
return !vendor ? null : vendor.location === 'local' ? <ComputerIcon /> : <CloudOutlinedIcon />;
}

-function vendorIcon(vendor?: IModelVendor | null) {
+function vendorIcon(vendor: IModelVendor | null, greenMark: boolean) {
const Icon = !vendor ? null : vendor.Icon;
-  return Icon ? <Icon /> : null;
+  return (greenMark && Icon)
+    ? <Badge color='primary' size='sm' badgeContent=''><Icon /></Badge>
+    : Icon ? <Icon /> : null;
}


@@ -73,23 +75,26 @@ export function ModelsSourceSelector(props: {


// vendor list items
-const vendorItems = React.useMemo(() => findAllVendors().filter(v => !!v.instanceLimit).map(vendor => {
-  const sourceCount = modelSources.filter(source => source.vId === vendor.id).length;
-  const enabled = vendor.instanceLimit > sourceCount;
-  return {
-    vendor,
-    enabled,
-    sourceCount,
-    component: (
-      <MenuItem key={vendor.id} disabled={!enabled} onClick={() => handleAddSourceFromVendor(vendor.id)}>
-        <ListItemDecorator>
-          {vendorIcon(vendor)}
-        </ListItemDecorator>
-        {vendor.name}{/*{sourceCount > 0 && ` (added)`}*/}
-      </MenuItem>
-    ),
-  };
-}), [handleAddSourceFromVendor, modelSources]);
+const vendorItems = React.useMemo(() => findAllVendors()
+  .filter(v => !!v.instanceLimit)
+  .map(vendor => {
+      const sourceCount = modelSources.filter(source => source.vId === vendor.id).length;
+      const enabled = vendor.instanceLimit > sourceCount;
+      return {
+        vendor,
+        enabled,
+        sourceCount,
+        component: (
+          <MenuItem key={vendor.id} disabled={!enabled} onClick={() => handleAddSourceFromVendor(vendor.id)}>
+            <ListItemDecorator>
+              {vendorIcon(vendor, !!vendor.hasServerKey)}
+            </ListItemDecorator>
+            {vendor.name}{/*{sourceCount > 0 && ` (added)`}*/}
+          </MenuItem>
+        ),
+      };
+    },
+  ), [handleAddSourceFromVendor, modelSources]);


// source items
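The new greenMark argument wraps a vendor's icon in a small primary Badge when its key is already configured server-side, and the call site reads the flag straight off the vendor descriptor (!!vendor.hasServerKey). A hedged sketch of what such a descriptor might now carry; only ModelVendorOpenAI.hasServerKey is confirmed by this diff, the rest of the shape is assumed from usage:

// sketch: vendor descriptor carrying the server-key flag (shape assumed from usage above)
const ModelVendorOpenAI = {
  id: 'openai' as const,
  name: 'OpenAI',
  location: 'cloud' as const,                         // drives locationIcon()
  instanceLimit: 2,                                   // illustrative; gates the add-vendor menu item
  hasServerKey: !!process.env.HAS_SERVER_KEY_OPENAI,  // shows the green Badge dot
};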
4 changes: 4 additions & 0 deletions src/common/types/env.d.ts
@@ -18,6 +18,9 @@ declare namespace NodeJS {
ANTHROPIC_API_KEY: string;
ANTHROPIC_API_HOST: string;

+// LLM: OpenRouter
+OPENROUTER_API_KEY: string;
+
// Helicone
HELICONE_API_KEY: string;

@@ -43,6 +46,7 @@
HAS_SERVER_KEY_AZURE_OPENAI?: boolean;
HAS_SERVER_KEY_ELEVENLABS: boolean;
HAS_SERVER_KEY_OPENAI?: boolean;
+HAS_SERVER_KEY_OPENROUTER?: boolean;
HAS_SERVER_KEY_PRODIA: boolean;

}
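Declaring these members on ProcessEnv gives typed, cast-free access on the server. A trivial sketch:

// with the declaration above, this typechecks without a cast;
// the value may still be '' (or undefined) at runtime if the variable is unset
const openRouterKey: string = process.env.OPENROUTER_API_KEY;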
34 changes: 18 additions & 16 deletions src/modules/llms/store-llms.ts
@@ -8,7 +8,7 @@ import { ModelVendorId } from './vendors/IModelVendor';
/**
* Large Language Model - description and configuration (data object, stored)
*/
-export interface DLLM<TLLMOptions = unknown, TModelSource = DModelSource> {
+export interface DLLM<TSourceSetup = unknown, TLLMOptions = unknown> {
id: DLLMId;
label: string;
created: number | 0;
@@ -20,10 +20,10 @@ export interface DLLM<TLLMOptions = unknown, TModelSource = DModelSource> {

// llm -> source
sId: DModelSourceId;
-  _source: TModelSource;
+  _source: DModelSource<TSourceSetup>;

// llm-specific
-  options: Partial<{ llmRef: string } & TLLMOptions>;
+  options: { llmRef: string } & Partial<TLLMOptions>;
}

export type DLLMId = string;
@@ -37,15 +37,15 @@ export const LLM_IF_OAI_Complete = 'oai-complete';
/**
* Model Server - configured to be a unique origin of models (data object, stored)
*/
-export interface DModelSource<TModelSetup = unknown> {
+export interface DModelSource<TSourceSetup = unknown> {
id: DModelSourceId;
label: string;

// source -> vendor
vId: ModelVendorId;

// source-specific
-  setup: Partial<TModelSetup>;
+  setup: Partial<TSourceSetup>;
}

export type DModelSourceId = string;
@@ -65,11 +65,11 @@ interface ModelsActions {
addLLMs: (llms: DLLM[]) => void;
removeLLM: (id: DLLMId) => void;
updateLLM: (id: DLLMId, partial: Partial<DLLM>) => void;
-  updateLLMOptions: <T>(id: DLLMId, partialOptions: Partial<T>) => void;
+  updateLLMOptions: <TLLMOptions>(id: DLLMId, partialOptions: Partial<TLLMOptions>) => void;

addSource: (source: DModelSource) => void;
removeSource: (id: DModelSourceId) => void;
-  updateSourceSetup: <T>(id: DModelSourceId, partialSetup: Partial<T>) => void;
+  updateSourceSetup: <TSourceSetup>(id: DModelSourceId, partialSetup: Partial<TSourceSetup>) => void;

setChatLLMId: (id: DLLMId | null) => void;
setFastLLMId: (id: DLLMId | null) => void;
@@ -124,7 +124,7 @@ export const useModelsStore = create<ModelsData & ModelsActions>()(
),
})),

-  updateLLMOptions: <T>(id: DLLMId, partialOptions: Partial<T>) =>
+  updateLLMOptions: <TLLMOptions>(id: DLLMId, partialOptions: Partial<TLLMOptions>) =>
set(state => ({
llms: state.llms.map((llm: DLLM): DLLM =>
llm.id === id
@@ -149,7 +149,7 @@
};
}),

-  updateSourceSetup: <T>(id: DModelSourceId, partialSetup: Partial<T>) =>
+  updateSourceSetup: <TSourceSetup>(id: DModelSourceId, partialSetup: Partial<TSourceSetup>) =>
set(state => ({
sources: state.sources.map((source: DModelSource): DModelSource =>
source.id === id
@@ -191,11 +191,11 @@ const defaultChatSuffixPreference = ['gpt-4-0613', 'gpt-4', 'gpt-4-32k', 'gpt-3.
const defaultFastSuffixPreference = ['gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k-0613', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo'];
const defaultFuncSuffixPreference = ['gpt-3.5-turbo-0613', 'gpt-4-0613'];

-export function findLLMOrThrow<TLLMOptions>(llmId: DLLMId): DLLM<TLLMOptions> {
+export function findLLMOrThrow<TSourceSetup, TLLMOptions>(llmId: DLLMId): DLLM<TSourceSetup, TLLMOptions> {
const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId);
if (!llm) throw new Error(`LLM ${llmId} not found`);
if (!llm._source) throw new Error(`LLM ${llmId} has no source`);
-  return llm as DLLM<TLLMOptions>;
+  return llm as DLLM<TSourceSetup, TLLMOptions>;
}

function findLlmIdBySuffix(llms: DLLM[], suffixes: string[], fallbackToFirst: boolean): DLLMId | null {
@@ -235,19 +235,21 @@ export function useChatLLM() {
/**
* Source-specific read/write - great time saver
*/
-export function useSourceSetup<T>(sourceId: DModelSourceId, normalizer: (partialSetup?: Partial<T>) => T) {
+export function useSourceSetup<TSourceSetup, TAccess>(sourceId: DModelSourceId, getAccess: (partialSetup?: Partial<TSourceSetup>) => TAccess) {
// invalidate when the setup changes
const { updateSourceSetup, ...rest } = useModelsStore(state => {
-    const source = state.sources.find(source => source.id === sourceId) ?? null;
+    const source: DModelSource<TSourceSetup> | null = state.sources.find(source => source.id === sourceId) ?? null;
+    const sourceLLMs = source ? state.llms.filter(llm => llm._source === source) : [];
return {
source,
-      sourceLLMs: source ? state.llms.filter(llm => llm._source === source) : [],
-      normSetup: normalizer(source?.setup as Partial<T> | undefined),
+      sourceLLMs,
+      sourceHasLLMs: !!sourceLLMs.length,
+      access: getAccess(source?.setup),
updateSourceSetup: state.updateSourceSetup,
};
}, shallow);

// convenience function for this source
-  const updateSetup = (partialSetup: Partial<T>) => updateSourceSetup<T>(sourceId, partialSetup);
+  const updateSetup = (partialSetup: Partial<TSourceSetup>) => updateSourceSetup<TSourceSetup>(sourceId, partialSetup);
return { ...rest, updateSetup };
}
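The hook's second parameter changed from a setup normalizer into a getAccess mapper, and the returned normSetup became access (plus a new sourceHasLLMs convenience flag). A hedged usage sketch follows; the source id is illustrative, and the access shape borrows the Anthropic schema that appears later in this diff:

import { useSourceSetup } from '~/modules/llms/store-llms';

// sketch: inside a vendor settings component that receives { sourceId }
const { source, sourceHasLLMs, access, updateSetup } = useSourceSetup(
  props.sourceId,
  (partialSetup) => ({
    dialect: 'anthropic' as const,
    anthropicKey: partialSetup?.anthropicKey ?? '',
    anthropicHost: partialSetup?.anthropicHost ?? '',
  }),
);
// writes merge into the stored setup; access is re-derived on the next render:
//   updateSetup({ anthropicKey: newKey });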
6 changes: 3 additions & 3 deletions src/modules/llms/transports/chatGenerate.ts
@@ -25,10 +25,10 @@ export interface VChatMessageOrFunctionCallOut extends VChatMessageOut {

export async function callChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise<VChatMessageOut> {
const { llm, vendor } = findVendorForLlmOrThrow(llmId);
-  return await vendor.callChat(llm, messages, maxTokens);
+  return await vendor.callChatGenerate(llm, messages, maxTokens);
}

-export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], forceFunctionName?: string, maxTokens?: number): Promise<VChatMessageOrFunctionCallOut> {
+export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], forceFunctionName: string | null, maxTokens?: number): Promise<VChatMessageOrFunctionCallOut> {
const { llm, vendor } = findVendorForLlmOrThrow(llmId);
return await vendor.callChatWithFunctions(llm, messages, functions, forceFunctionName, maxTokens);
return await vendor.callChatGenerateWF(llm, messages, functions, forceFunctionName, maxTokens);
}
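Beyond the renames (vendor callChat becomes callChatGenerate, callChatWithFunctions becomes callChatGenerateWF), forceFunctionName is now a required string | null rather than an optional parameter, so call sites must pass null explicitly. A quick usage sketch; the llm id, message, and function schema are illustrative:

import { callChatGenerate, callChatGenerateWithFunctions } from '~/modules/llms/transports/chatGenerate';

async function demo() {
  const reply = await callChatGenerate(
    'openai-gpt-4-0613',                                       // illustrative DLLMId
    [{ role: 'user', content: 'Summarize the llms rework.' }],
    1024,                                                      // optional maxTokens
  );

  // the function-calling variant must now receive forceFunctionName, even if null
  const fnReply = await callChatGenerateWithFunctions(
    'openai-gpt-4-0613',
    [{ role: 'user', content: 'Extract the release date.' }],
    [{ name: 'set_date', description: 'Stores a date', parameters: { type: 'object', properties: {} } }],
    null,
  );
  console.log(reply.content, fnReply);
}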
35 changes: 21 additions & 14 deletions src/modules/llms/transports/server/anthropic.router.ts
@@ -6,22 +6,30 @@ import { fetchJsonOrTRPCError } from '~/modules/trpc/trpc.serverutils';

import { LLM_IF_OAI_Chat } from '../../store-llms';

-import { chatGenerateOutputSchema, historySchema, modelSchema } from './openai.router';
+import { openAIChatGenerateOutputSchema, openAIHistorySchema, openAIModelSchema } from './openai.router';
import { listModelsOutputSchema, ModelDescriptionSchema } from './server.common';

import { AnthropicWire } from './anthropic.wiretypes';


// Input Schemas

-const anthropicAccessSchema = z.object({
+export const anthropicAccessSchema = z.object({
dialect: z.literal('anthropic'),
anthropicKey: z.string().trim(),
anthropicHost: z.string().trim(),
});
+export type AnthropicAccessSchema = z.infer<typeof anthropicAccessSchema>;

-const anthropicChatGenerateSchema = z.object({ access: anthropicAccessSchema, model: modelSchema, history: historySchema });

-const anthropicListModelsSchema = z.object({ access: anthropicAccessSchema });
+const listModelsInputSchema = z.object({
+  access: anthropicAccessSchema,
+});

+const chatGenerateInputSchema = z.object({
+  access: anthropicAccessSchema,
+  model: openAIModelSchema, history: openAIHistorySchema,
+});


export const llmAnthropicRouter = createTRPCRouter({
@@ -32,14 +40,14 @@ export const llmAnthropicRouter = createTRPCRouter({
* some details on the models, as the API docs are scarce: https://docs.anthropic.com/claude/reference/selecting-a-model
*/
listModels: publicProcedure
-    .input(anthropicListModelsSchema)
+    .input(listModelsInputSchema)
.output(listModelsOutputSchema)
.query(() => ({ models: hardcodedAnthropicModels })),

/* Anthropic: Chat generation */
chatGenerate: publicProcedure
-    .input(anthropicChatGenerateSchema)
-    .output(chatGenerateOutputSchema)
+    .input(chatGenerateInputSchema)
+    .output(openAIChatGenerateOutputSchema)
.mutation(async ({ input }) => {

const { access, model, history } = input;
@@ -50,7 +58,7 @@

const wireCompletions = await anthropicPOST<AnthropicWire.Complete.Response, AnthropicWire.Complete.Request>(
access,
-        anthropicCompletionRequest(model, history, false),
+        anthropicChatCompletionPayload(model, history, false),
'/v1/complete',
);

@@ -118,18 +126,17 @@ const hardcodedAnthropicModels: ModelDescriptionSchema[] = [
},
];

-type AccessSchema = z.infer<typeof anthropicAccessSchema>;
-type ModelSchema = z.infer<typeof modelSchema>;
-type HistorySchema = z.infer<typeof historySchema>;
+type ModelSchema = z.infer<typeof openAIModelSchema>;
+type HistorySchema = z.infer<typeof openAIHistorySchema>;

-async function anthropicPOST<TOut, TPostBody>(access: AccessSchema, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
+async function anthropicPOST<TOut, TPostBody>(access: AnthropicAccessSchema, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
const { headers, url } = anthropicAccess(access, apiPath);
return await fetchJsonOrTRPCError<TOut, TPostBody>(url, 'POST', headers, body, 'Anthropic');
}

const DEFAULT_ANTHROPIC_HOST = 'api.anthropic.com';

-export function anthropicAccess(access: AccessSchema, apiPath: string): { headers: HeadersInit, url: string } {
+export function anthropicAccess(access: AnthropicAccessSchema, apiPath: string): { headers: HeadersInit, url: string } {
// API version
const apiVersion = '2023-06-01';

Expand Down Expand Up @@ -158,7 +165,7 @@ export function anthropicAccess(access: AccessSchema, apiPath: string): { header
};
}

-export function anthropicCompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): AnthropicWire.Complete.Request {
+export function anthropicChatCompletionPayload(model: ModelSchema, history: HistorySchema, stream: boolean): AnthropicWire.Complete.Request {
// encode the prompt for Claude models
const prompt = history.map(({ role, content }) => {
return role === 'assistant' ? `\n\nAssistant: ${content}` : `\n\nHuman: ${content}`;
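For reference, the renamed anthropicChatCompletionPayload flattens the chat history into Claude's legacy text-completion format. A worked example of the mapping shown above (the trailing '\n\nAssistant:' cue is assumed to be appended past the truncated hunk):

// sketch: what the history mapping yields for a short conversation
const history = [
  { role: 'user', content: 'Hello' },
  { role: 'assistant', content: 'Hi! How can I help?' },
  { role: 'user', content: 'Name three colors.' },
];
// per the map() in anthropicChatCompletionPayload, joined this becomes:
//   '\n\nHuman: Hello\n\nAssistant: Hi! How can I help?\n\nHuman: Name three colors.'
// the /v1/complete endpoint then needs a final '\n\nAssistant:' to cue the reply
// (assumed to be appended by the builder beyond the lines shown here)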

1 comment on commit 09d38eb


@vercel vercel bot commented on 09d38eb Oct 12, 2023


Successfully deployed to the following URLs:

big-agi – ./

big-agi-git-main-enricoros.vercel.app
big-agi-enricoros.vercel.app
get.big-agi.com
