diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index 20e7eb034..86543585f 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -24,6 +24,6 @@ jobs: PERSONAL_ACCESS_TOKEN: ${{ secrets.CLA_TOKEN }} with: path-to-signatures: 'signatures/version1/cla.json' - path-to-document: 'https://github.com/latitude-dev/latitude/blob/main/cla.md' + path-to-document: 'https://github.com/latitude-dev/latitude-llm/blob/main/cla.md' branch: 'signatures' allowlist: geclos,cesr,csansoon,andresgutgon,samulatitude,cesr,ferranrodriguez diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index c4494323b..d4e7a40af 100644 --- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -9,6 +9,7 @@ import { TextStreamPart, } from 'ai' +import { ProviderLog } from './browser' import { Config } from './services/ai' export const LATITUDE_EVENT = 'latitudeEventsChannel' @@ -43,19 +44,27 @@ export type ChainStepTextResponse = { text: string usage: CompletionTokenUsage toolCalls: ToolCall[] + documentLogUuid: string + providerLog: undefined } export type ChainStepObjectResponse = { object: any text: string usage: CompletionTokenUsage + documentLogUuid: string + providerLog: undefined } -export type ChainTextResponse = ChainStepTextResponse & { - documentLogUuid: string +export type ChainTextResponse = Omit<ChainStepTextResponse, 'providerLog'> & { + providerLog: ProviderLog } -export type ChainObjectResponse = ChainStepObjectResponse & { - documentLogUuid: string +export type ChainObjectResponse = Omit< + ChainStepObjectResponse, + 'providerLog' +> & { + providerLog: ProviderLog } +export type ChainStepResponse = ChainStepTextResponse | ChainStepObjectResponse export type ChainCallResponse = ChainTextResponse | ChainObjectResponse export enum LogSources { @@ -89,7 +98,7 @@ type LatitudeEventData = } | { type: ChainEventTypes.StepComplete - response: ChainCallResponse + response: ChainStepResponse } | { type: ChainEventTypes.Complete diff --git 
a/packages/core/src/services/ai/index.ts b/packages/core/src/services/ai/index.ts index aa8ecd4ef..d575833b6 100644 --- a/packages/core/src/services/ai/index.ts +++ b/packages/core/src/services/ai/index.ts @@ -13,7 +13,12 @@ import { import { JSONSchema7 } from 'json-schema' import { v4 } from 'uuid' -import { LogSources, ProviderApiKey, Workspace } from '../../browser' +import { + LogSources, + ProviderApiKey, + ProviderLog, + Workspace, +} from '../../browser' import { cache } from '../../cache' import { publisher } from '../../events/publisher' import { createProviderLog } from '../providerLogs/create' @@ -49,7 +54,6 @@ export async function ai({ config, documentLogUuid, source, - onFinish, schema, output, transactionalLogs = false, @@ -64,97 +68,50 @@ export async function ai({ schema?: JSONSchema7 output?: 'object' | 'array' | 'no-schema' transactionalLogs?: boolean - onFinish?: FinishCallback }) { await checkDefaultProviderUsage({ provider: apiProvider, workspace }) const startTime = Date.now() - const { - provider, - token: apiKey, - id: providerId, - provider: providerType, - } = apiProvider + const { provider, token: apiKey } = apiProvider const model = config.model const m = createProvider({ provider, apiKey, config })(model) - const commonOptions = { model: m, prompt, messages: messages as CoreMessage[], } - - const createFinishHandler = (isStructured: boolean) => async (event: any) => { - const commonData = { - uuid: v4(), - source, - generatedAt: new Date(), - documentLogUuid, - providerId, - providerType, - model, - config, - messages, - toolCalls: event.toolCalls?.map((t: any) => ({ - id: t.toolCallId, - name: t.toolName, - arguments: t.args, - })), - usage: event.usage, - duration: Date.now() - startTime, - } - - const payload = { - type: 'aiProviderCallCompleted' as 'aiProviderCallCompleted', - data: { - ...commonData, - responseText: event.text, - responseObject: isStructured ? 
event.object : undefined, - }, - } - - publisher.publishLater({ - type: payload.type, - data: { - ...payload.data, - workspaceId: apiProvider.workspaceId, - }, - }) - - let providerLogUuid - if (transactionalLogs) { - const providerLog = await createProviderLog(payload.data).then((r) => - r.unwrap(), - ) - providerLogUuid = providerLog.uuid - } else { - const queues = await setupJobs() - queues.defaultQueue.jobs.enqueueCreateProviderLogJob(payload.data) - } - - onFinish?.({ ...event, providerLogUuid }) - } + const { onFinish, providerLog } = createFinishHandler({ + isStructured: !!schema && !!output, + startTime, + apiProvider, + source, + documentLogUuid, + messages, + config, + transactionalLogs, + }) if (schema && output) { const result = await streamObject({ ...commonOptions, schema: jsonSchema(schema), - // @ts-expect-error - output is vale but depending on the type of schema + // @ts-expect-error - output is valid but depending on the type of schema // there might be a mismatch (e.g you pass an object schema but the - // output is "array"). Not really an issue we need to defend atm + // output is "array"). Not really an issue we need to defend atm. 
output, - onFinish: createFinishHandler(true), + onFinish, }) return { fullStream: result.fullStream, object: result.object, usage: result.usage, + providerLog, } } else { const result = await streamText({ ...commonOptions, - onFinish: createFinishHandler(false), + onFinish, }) return { @@ -162,6 +119,7 @@ export async function ai({ text: result.text, usage: result.usage, toolCalls: result.toolCalls, + providerLog, } } } @@ -185,6 +143,83 @@ const checkDefaultProviderUsage = async ({ } } +const createFinishHandler = ({ + isStructured, + startTime, + apiProvider, + source, + messages, + config, + transactionalLogs, + documentLogUuid, +}: { + isStructured: boolean + startTime: number + apiProvider: ProviderApiKey + source: LogSources + messages: Message[] + config: PartialConfig + transactionalLogs: boolean + documentLogUuid?: string +}) => { + let resolveProviderLog: (value?: ProviderLog) => void + const providerLogPromise = new Promise((resolve) => { + resolveProviderLog = resolve + }) + + return { + providerLog: providerLogPromise, + onFinish: async (event: any) => { + const commonData = { + uuid: v4(), + source, + generatedAt: new Date(), + documentLogUuid, + providerId: apiProvider.id, + providerType: apiProvider.provider, + model: config.model, + config, + messages, + toolCalls: event.toolCalls?.map((t: any) => ({ + id: t.toolCallId, + name: t.toolName, + arguments: t.args, + })), + usage: event.usage, + duration: Date.now() - startTime, + } + + const payload = { + type: 'aiProviderCallCompleted' as 'aiProviderCallCompleted', + data: { + ...commonData, + responseText: event.text, + responseObject: isStructured ? 
event.object : undefined, + }, + } + + publisher.publishLater({ + type: payload.type, + data: { + ...payload.data, + workspaceId: apiProvider.workspaceId, + }, + }) + + if (transactionalLogs) { + const providerLog = await createProviderLog(payload.data).then((r) => + r.unwrap(), + ) + resolveProviderLog(providerLog) + } else { + const queues = await setupJobs() + queues.defaultQueue.jobs.enqueueCreateProviderLogJob(payload.data) + resolveProviderLog() + } + }, + } +} + export { estimateCost } from './estimateCost' export { validateConfig } from './helpers' export type { PartialConfig, Config } from './helpers' diff --git a/packages/core/src/services/chains/run.test.ts b/packages/core/src/services/chains/run.test.ts index c22de6a25..bee346ec8 100644 --- a/packages/core/src/services/chains/run.test.ts +++ b/packages/core/src/services/chains/run.test.ts @@ -26,6 +26,10 @@ describe('runChain', () => { text: Promise.resolve(text), usage: Promise.resolve({ totalTokens }), toolCalls: Promise.resolve([]), + providerLog: Promise.resolve({ + provider: 'openai', + model: 'gpt-3.5-turbo', + }), fullStream: new ReadableStream({ start(controller) { controller.enqueue({ type: 'text', text }) @@ -50,8 +54,8 @@ describe('runChain', () => { it('runs a chain without schema override', async () => { const mockAiResponse = createMockAiResponse('AI response', 10) - vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any) + vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any) vi.mocked(mockChain.step!).mockResolvedValue({ completed: true, conversation: { @@ -76,12 +80,14 @@ describe('runChain', () => { if (!result.ok) return const response = await result.value.response - expect(response).toEqual({ - documentLogUuid: expect.any(String), - text: 'AI response', - usage: { totalTokens: 10 }, - toolCalls: [], - }) + expect(response).toEqual( + expect.objectContaining({ + documentLogUuid: expect.any(String), + text: 'AI response', + usage: { totalTokens: 10 }, + toolCalls: 
[], + }), + ) expect(aiModule.ai).toHaveBeenCalledWith( expect.objectContaining({ @@ -104,6 +110,10 @@ describe('runChain', () => { const mockAiResponse = { object: Promise.resolve({ name: 'John', age: 30 }), usage: Promise.resolve({ totalTokens: 15 }), + providerLog: Promise.resolve({ + provider: 'openai', + model: 'gpt-3.5-turbo', + }), fullStream: new ReadableStream({ start(controller) { controller.enqueue({ @@ -145,12 +155,14 @@ describe('runChain', () => { if (!result.ok) return const response = await result.value.response - expect(response).toEqual({ - documentLogUuid: expect.any(String), - object: { name: 'John', age: 30 }, - text: '{"name":"John","age":30}', - usage: { totalTokens: 15 }, - }) + expect(response).toEqual( + expect.objectContaining({ + documentLogUuid: expect.any(String), + object: { name: 'John', age: 30 }, + text: '{"name":"John","age":30}', + usage: { totalTokens: 15 }, + }), + ) expect(aiModule.ai).toHaveBeenCalledWith( expect.objectContaining({ @@ -238,12 +250,14 @@ describe('runChain', () => { if (!result.ok) return const response = await result.value.response - expect(response).toEqual({ - documentLogUuid: expect.any(String), - text: 'AI response 2', - usage: { totalTokens: 15 }, - toolCalls: [], - }) + expect(response).toEqual( + expect.objectContaining({ + documentLogUuid: expect.any(String), + text: 'AI response 2', + usage: { totalTokens: 15 }, + toolCalls: [], + }), + ) expect(aiModule.ai).toHaveBeenCalledTimes(2) }) @@ -280,12 +294,14 @@ describe('runChain', () => { if (!result.ok) return const response = await result.value.response - expect(response).toEqual({ - documentLogUuid: expect.any(String), - text: 'AI response', - usage: { totalTokens: 10 }, - toolCalls: [], - }) + expect(response).toEqual( + expect.objectContaining({ + documentLogUuid: expect.any(String), + text: 'AI response', + usage: { totalTokens: 10 }, + toolCalls: [], + }), + ) expect(aiModule.ai).toHaveBeenCalledWith( expect.objectContaining({ @@ -312,6 +328,10 
@@ describe('runChain', () => { const mockAiResponse = { object: Promise.resolve({ name: 'John', age: 30 }), usage: Promise.resolve({ totalTokens: 15 }), + providerLog: Promise.resolve({ + provider: 'openai', + model: 'gpt-3.5-turbo', + }), fullStream: new ReadableStream({ start(controller) { controller.enqueue({ @@ -353,12 +373,14 @@ describe('runChain', () => { if (!result.ok) return const response = await result.value.response - expect(response).toEqual({ - documentLogUuid: expect.any(String), - object: { name: 'John', age: 30 }, - text: '{"name":"John","age":30}', - usage: { totalTokens: 15 }, - }) + expect(response).toEqual( + expect.objectContaining({ + documentLogUuid: expect.any(String), + object: { name: 'John', age: 30 }, + text: '{"name":"John","age":30}', + usage: { totalTokens: 15 }, + }), + ) expect(aiModule.ai).toHaveBeenCalledWith( expect.objectContaining({ @@ -385,6 +407,10 @@ describe('runChain', () => { { name: 'John', age: 30 }, { name: 'Jane', age: 25 }, ]), + providerLog: Promise.resolve({ + provider: 'openai', + model: 'gpt-3.5-turbo', + }), usage: Promise.resolve({ totalTokens: 20 }), fullStream: new ReadableStream({ start(controller) { @@ -430,15 +456,17 @@ describe('runChain', () => { if (!result.ok) return const response = await result.value.response - expect(response).toEqual({ - documentLogUuid: expect.any(String), - object: [ - { name: 'John', age: 30 }, - { name: 'Jane', age: 25 }, - ], - text: '[{"name":"John","age":30},{"name":"Jane","age":25}]', - usage: { totalTokens: 20 }, - }) + expect(response).toEqual( + expect.objectContaining({ + documentLogUuid: expect.any(String), + object: [ + { name: 'John', age: 30 }, + { name: 'Jane', age: 25 }, + ], + text: '[{"name":"John","age":30},{"name":"Jane","age":25}]', + usage: { totalTokens: 20 }, + }), + ) expect(aiModule.ai).toHaveBeenCalledWith( expect.objectContaining({ @@ -482,12 +510,14 @@ describe('runChain', () => { if (!result.ok) return const response = await result.value.response 
- expect(response).toEqual({ - documentLogUuid: expect.any(String), - text: 'AI response without schema', - usage: { totalTokens: 10 }, - toolCalls: [], - }) + expect(response).toEqual( + expect.objectContaining({ + documentLogUuid: expect.any(String), + text: 'AI response without schema', + usage: { totalTokens: 10 }, + toolCalls: [], + }), + ) expect(aiModule.ai).toHaveBeenCalledWith( expect.objectContaining({ diff --git a/packages/core/src/services/chains/run.ts b/packages/core/src/services/chains/run.ts index ff9577438..e9c894d0b 100644 --- a/packages/core/src/services/chains/run.ts +++ b/packages/core/src/services/chains/run.ts @@ -10,12 +10,18 @@ import { ChainEvent, ChainEventTypes, ChainObjectResponse, + ChainStepResponse, ChainTextResponse, LogSources, ProviderData, StreamEventTypes, } from '../../constants' -import { NotFoundError, Result, UnprocessableEntityError } from '../../lib' +import { + LatitudeError, + NotFoundError, + Result, + UnprocessableEntityError, +} from '../../lib' import { streamToGenerator } from '../../lib/streamToGenerator' import { ai, Config, validateConfig } from '../ai' @@ -126,13 +132,21 @@ async function iterate({ await streamAIResult(controller, aiResult) - const response = await createChainResponse(aiResult, documentLogUuid) + const response = await createChainResponse({ + aiResult, + documentLogUuid, + step, + }) if (step.completed) { - await handleCompletedChain(controller, step, response) - return response + await handleCompletedChain( + controller, + step, + response as ChainCallResponse, + ) + return response as ChainCallResponse } else { - publishStepCompleteEvent(controller, response) + publishStepCompleteEvent(controller, response as ChainStepResponse) return iterate({ workspace, @@ -201,28 +215,46 @@ async function streamAIResult( } } -async function createChainResponse( - result: Awaited>, - documentLogUuid: string, -): Promise { - if (result.object) { +async function createChainResponse({ + step, + aiResult, + 
documentLogUuid, +}: { + step: { completed: boolean } + aiResult: Awaited<ReturnType<typeof ai>> + documentLogUuid: string +}): Promise<ChainStepResponse | ChainCallResponse> { + const providerLog = await aiResult.providerLog + if (step.completed && !providerLog) { + throw new LatitudeError( + 'The response completed but the provider log was not created!', + ) + } + + if (aiResult.object) { return { - text: objectToString(await result.object), - object: await result.object, - usage: await result.usage, + text: objectToString(await aiResult.object), + object: await aiResult.object, + usage: await aiResult.usage, documentLogUuid, - } + ...(providerLog ? { providerLog } : {}), + } as typeof providerLog extends undefined + ? ChainStepResponse + : ChainCallResponse } else { return { documentLogUuid, - text: await result.text, - usage: await result.usage, - toolCalls: (await result.toolCalls).map((t) => ({ + ...(providerLog ? { providerLog } : {}), + text: await aiResult.text, + usage: await aiResult.usage, + toolCalls: (await aiResult.toolCalls).map((t) => ({ id: t.toolCallId, name: t.toolName, arguments: t.args, })), - } + } as typeof providerLog extends undefined + ? 
ChainStepResponse + : ChainCallResponse } } @@ -264,7 +296,7 @@ async function handleCompletedChain( function publishStepCompleteEvent( controller: ReadableStreamDefaultController, - response: ChainCallResponse, + response: ChainStepResponse, ) { enqueueChainEvent(controller, { event: StreamEventTypes.Latitude, diff --git a/packages/core/src/services/commits/runDocumentAtCommit.test.ts b/packages/core/src/services/commits/runDocumentAtCommit.test.ts index 502bd08b9..294cdc416 100644 --- a/packages/core/src/services/commits/runDocumentAtCommit.test.ts +++ b/packages/core/src/services/commits/runDocumentAtCommit.test.ts @@ -163,6 +163,7 @@ This is a test document documentLogUuid: expect.any(String), text: 'Fake AI generated text', toolCalls: [], + providerLog: { uuid: 'fake-provider-log-uuid' }, usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, }, }, @@ -204,6 +205,7 @@ This is a test document text: 'Fake AI generated text', toolCalls: [], usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 }, + providerLog: { uuid: 'fake-provider-log-uuid' }, documentLogUuid: expect.any(String), }, }, @@ -242,6 +244,7 @@ const mocks = { return { text: Promise.resolve('Fake AI generated text'), + providerLog: Promise.resolve({ uuid: 'fake-provider-log-uuid' }), usage: Promise.resolve({ promptTokens: 0, completionTokens: 0, diff --git a/packages/core/src/services/documentLogs/addMessages/index.ts b/packages/core/src/services/documentLogs/addMessages/index.ts index d310880da..db51005d8 100644 --- a/packages/core/src/services/documentLogs/addMessages/index.ts +++ b/packages/core/src/services/documentLogs/addMessages/index.ts @@ -5,6 +5,8 @@ import { ChainCallResponse, ChainEvent, ChainEventTypes, + ChainObjectResponse, + ChainTextResponse, DocumentLog, LogSources, objectToString, @@ -124,6 +126,7 @@ async function streamMessageResponse({ object: await result.object, text: await result.text, usage: await result.usage, + providerLog: await result.providerLog, 
toolCalls: result.toolCalls ? (await result.toolCalls).map((t) => ({ id: t.toolCallId, @@ -131,7 +134,8 @@ async function streamMessageResponse({ arguments: t.args, })) : [], - } + } as ChainCallResponse + enqueueChainEvent(controller, { event: StreamEventTypes.Latitude, data: { @@ -140,13 +144,17 @@ async function streamMessageResponse({ messages: [ { role: MessageRole.assistant, - toolCalls: response.toolCalls, - content: response.text || objectToString(response.object), + toolCalls: (response as ChainTextResponse).toolCalls, + content: + response.text || + objectToString((response as ChainObjectResponse).object), }, ], response: { ...response, - text: response.text || objectToString(response.object), + text: + response.text || + objectToString((response as ChainObjectResponse).object), }, }, }) @@ -155,7 +163,9 @@ async function streamMessageResponse({ return { ...response, - text: response.text || objectToString(response.object), + text: + response.text || + objectToString((response as ChainObjectResponse).object), } } catch (e) { const error = e as Error diff --git a/packages/core/src/services/evaluations/run.ts b/packages/core/src/services/evaluations/run.ts index 8551f3222..e454daea3 100644 --- a/packages/core/src/services/evaluations/run.ts +++ b/packages/core/src/services/evaluations/run.ts @@ -113,7 +113,7 @@ export const runEvaluation = async ( data: { evaluationId: evaluation.id, documentLogUuid: documentLog.uuid, - providerLogUuid: lastProviderLog.uuid, + providerLogUuid: response.providerLog.uuid, response, }, }) diff --git a/packages/core/src/services/providerLogs/create.ts b/packages/core/src/services/providerLogs/create.ts index a6856a236..7d608472a 100644 --- a/packages/core/src/services/providerLogs/create.ts +++ b/packages/core/src/services/providerLogs/create.ts @@ -60,6 +60,7 @@ export async function createProviderLog( estimateCost({ provider: providerType, model, usage }) * TO_MILLICENTS_FACTOR, ) + const inserts = await trx 
.insert(providerLogs) .values({