Skip to content

Commit

Permalink
[Obs AI Assistant] Refactor complete/chat
Browse files Browse the repository at this point in the history
  • Loading branch information
dgieselaar committed Apr 21, 2024
1 parent 4744879 commit d005850
Show file tree
Hide file tree
Showing 35 changed files with 1,224 additions and 860 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,17 @@ export type StreamingChatResponseEvent =
| ConversationCreateEvent
| ConversationUpdateEvent
| MessageAddEvent
| ChatCompletionErrorEvent;
| ChatCompletionErrorEvent
| TokenCountEvent;

export type StreamingChatResponseEventWithoutError = Exclude<
StreamingChatResponseEvent,
ChatCompletionErrorEvent
>;

export type ChatEvent = ChatCompletionChunkEvent | TokenCountEvent;
export type MessageOrChatEvent = ChatEvent | MessageAddEvent;

export enum ChatCompletionErrorCode {
InternalError = 'internalError',
NotFoundError = 'notFoundError',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export function createFunctionRequestMessage({
args,
}: {
name: string;
args: unknown;
args?: Record<string, any>;
}): MessageAddEvent {
return {
id: v4(),
Expand All @@ -28,6 +28,7 @@ export function createFunctionRequestMessage({
trigger: MessageRole.Assistant as const,
},
role: MessageRole.Assistant,
content: '',
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ export function createFunctionResponseError({
name: error.name,
message: error.message,
cause: error.cause,
stack: error.stack,
},
message: message || error.message,
},
data: {
stack: error.stack,
},
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,22 @@
* 2.0.
*/

import { concat, from, last, mergeMap, Observable, shareReplay, withLatestFrom } from 'rxjs';
import {
concat,
from,
last,
mergeMap,
Observable,
OperatorFunction,
shareReplay,
withLatestFrom,
} from 'rxjs';
import { withoutTokenCountEvents } from './without_token_count_events';
import {
ChatCompletionChunkEvent,
ChatEvent,
MessageAddEvent,
MessageOrChatEvent,
StreamingChatResponseEventType,
} from '../conversation_complete';
import {
Expand Down Expand Up @@ -42,18 +54,19 @@ function mergeWithEditedMessage(

export function emitWithConcatenatedMessage(
callback?: ConcatenateMessageCallback
): (
source$: Observable<ChatCompletionChunkEvent>
) => Observable<ChatCompletionChunkEvent | MessageAddEvent> {
return (source$: Observable<ChatCompletionChunkEvent>) => {
): OperatorFunction<ChatEvent, MessageOrChatEvent> {
return (source$) => {
const shared = source$.pipe(shareReplay());

const withoutTokenCount$ = shared.pipe(withoutTokenCountEvents());

const response$ = concat(
shared,
shared.pipe(
withoutTokenCountEvents(),
concatenateChatCompletionChunks(),
last(),
withLatestFrom(source$),
withLatestFrom(withoutTokenCount$),
mergeMap(([message, chunkEvent]) => {
return mergeWithEditedMessage(message, chunkEvent, callback);
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { filter, OperatorFunction } from 'rxjs';
import {
StreamingChatResponseEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../conversation_complete';

/**
 * RxJS operator that drops {@link TokenCountEvent}s from a stream of
 * streaming-chat response events, narrowing the emitted type accordingly.
 */
export function withoutTokenCountEvents<T extends StreamingChatResponseEvent>(): OperatorFunction<
  T,
  Exclude<T, TokenCountEvent>
> {
  // Type predicate so downstream operators see the narrowed event union.
  const isNotTokenCount = (event: T): event is Exclude<T, TokenCountEvent> =>
    event.type !== StreamingChatResponseEventType.TokenCount;

  return filter(isNotTokenCount);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { noop } from 'lodash';
import React from 'react';
import { Observable, of } from 'rxjs';
import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete';
import { ScreenContextActionDefinition } from '../common/types';
import { MessageRole, ScreenContextActionDefinition } from '../common/types';
import type { ObservabilityAIAssistantAPIClient } from './api';
import type {
ObservabilityAIAssistantChatService,
Expand All @@ -34,6 +34,13 @@ export const mockChatService: ObservabilityAIAssistantChatService = {
),
hasFunction: () => true,
hasRenderFunction: () => true,
getSystemMessage: () => ({
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.System,
content: 'System',
},
}),
};

export const mockService: ObservabilityAIAssistantService = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ describe('complete', () => {
'@timestamp': expect.any(String),
message: {
content: expect.any(String),
data: expect.any(String),
name: 'my_action',
role: MessageRole.User,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import {
StreamingChatResponseEventType,
type StreamingChatResponseEventWithoutError,
type StreamingChatResponseEvent,
TokenCountEvent,
} from '../../common/conversation_complete';
import {
FunctionRegistry,
Expand Down Expand Up @@ -163,13 +162,7 @@ export async function createChatService({

const subscription = toObservable(response)
.pipe(
map(
(line) =>
JSON.parse(line) as
| StreamingChatResponseEvent
| BufferFlushEvent
| TokenCountEvent
),
map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
filter(
(line): line is StreamingChatResponseEvent =>
line.type !== StreamingChatResponseEventType.BufferFlush &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { i18n } from '@kbn/i18n';
import { noop } from 'lodash';
import React from 'react';
import { Observable, of } from 'rxjs';
import { MessageRole } from '.';
import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete';
import type { ObservabilityAIAssistantAPIClient } from './api';
import type { ObservabilityAIAssistantChatService, ObservabilityAIAssistantService } from './types';
Expand All @@ -28,6 +29,13 @@ export const createStorybookChatService = (): ObservabilityAIAssistantChatServic
),
hasFunction: () => true,
hasRenderFunction: () => true,
getSystemMessage: () => ({
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.System,
content: 'System',
},
}),
});

export const createStorybookService = (): ObservabilityAIAssistantService => ({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import {
MessageAddEvent,
StreamingChatResponseEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../../common/conversation_complete';
import { ObservabilityAIAssistantScreenContext } from '../../common/types';
import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_chat_completion_chunks';
Expand Down Expand Up @@ -240,10 +239,7 @@ export class KibanaClient {
.split('\n')
.map((line) => line.trim())
.filter(Boolean)
.map(
(line) =>
JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent | TokenCountEvent
)
.map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent)
),
filter(
(line): line is ChatCompletionChunkEvent | ChatCompletionErrorEvent =>
Expand Down Expand Up @@ -330,13 +326,7 @@ export class KibanaClient {
.split('\n')
.map((line) => line.trim())
.filter(Boolean)
.map(
(line) =>
JSON.parse(line) as
| StreamingChatResponseEvent
| BufferFlushEvent
| TokenCountEvent
)
.map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent)
),
filter(
(event): event is MessageAddEvent | ConversationCreateEvent =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_
import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message';
import { RecallRanking, RecallRankingEventType } from '../analytics/recall_ranking';
import type { ObservabilityAIAssistantClient } from '../service/client';
import { ChatFn } from '../service/types';
import { FunctionCallChatFunction } from '../service/types';
import { parseSuggestionScores } from './parse_suggestion_scores';

const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000;
Expand Down Expand Up @@ -61,7 +61,7 @@ export function registerContextFunction({
required: ['queries', 'categories'],
} as const,
},
async ({ arguments: args, messages, connectorId, screenContexts, chat }, signal) => {
async ({ arguments: args, messages, screenContexts, chat }, signal) => {
const { analytics } = (await resources.context.core).coreStart;

const { queries, categories } = args;
Expand Down Expand Up @@ -118,7 +118,6 @@ export function registerContextFunction({
queries: queriesOrUserPrompt,
messages,
chat,
connectorId,
signal,
logger: resources.logger,
});
Expand Down Expand Up @@ -209,15 +208,13 @@ async function scoreSuggestions({
messages,
queries,
chat,
connectorId,
signal,
logger,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
messages: Message[];
queries: string[];
chat: ChatFn;
connectorId: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}) {
Expand Down Expand Up @@ -274,15 +271,12 @@ async function scoreSuggestions({
};

const response = await lastValueFrom(
(
await chat('score_suggestions', {
connectorId,
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
})
).pipe(concatenateChatCompletionChunks())
chat('score_suggestions', {
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
}).pipe(concatenateChatCompletionChunks())
);

const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
* 2.0.
*/
import datemath from '@elastic/datemath';
import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server';
import type { ElasticsearchClient, SavedObjectsClientContract } from '@kbn/core/server';
import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server';
import { castArray, chunk, groupBy, uniq } from 'lodash';
import { lastValueFrom, Observable } from 'rxjs';
import type { ObservabilityAIAssistantClient } from '../../service/client';
import { type ChatCompletionChunkEvent, type Message, MessageRole } from '../../../common';
import { lastValueFrom } from 'rxjs';
import { MessageRole, type Message } from '../../../common';
import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks';
import { FunctionCallChatFunction } from '../../service/types';

export async function getRelevantFieldNames({
index,
Expand All @@ -22,6 +22,7 @@ export async function getRelevantFieldNames({
savedObjectsClient,
chat,
messages,
signal,
}: {
index: string | string[];
start?: string;
Expand All @@ -30,13 +31,8 @@ export async function getRelevantFieldNames({
esClient: ElasticsearchClient;
savedObjectsClient: SavedObjectsClientContract;
messages: Message[];
chat: (
name: string,
{}: Pick<
Parameters<ObservabilityAIAssistantClient['chat']>[1],
'functionCall' | 'functions' | 'messages'
>
) => Promise<Observable<ChatCompletionChunkEvent>>;
chat: FunctionCallChatFunction;
signal: AbortSignal;
}): Promise<{ fields: string[] }> {
const dataViewsService = await dataViews.dataViewsServiceFactory(savedObjectsClient, esClient);

Expand Down Expand Up @@ -79,6 +75,7 @@ export async function getRelevantFieldNames({
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
const chunkResponse$ = (
await chat('get_relevent_dataset_names', {
signal,
messages: [
{
'@timestamp': new Date().toISOString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export function registerGetDatasetInfoFunction({
required: ['index'],
} as const,
},
async ({ arguments: { index }, messages, connectorId, chat }, signal) => {
async ({ arguments: { index }, messages, chat }, signal) => {
const coreContext = await resources.context.core;

const esClient = coreContext.elasticsearch.client.asCurrentUser;
Expand Down Expand Up @@ -83,18 +83,8 @@ export function registerGetDatasetInfoFunction({
esClient,
dataViews: await resources.plugins.dataViews.start(),
savedObjectsClient,
chat: (
operationName,
{ messages: nextMessages, functionCall, functions: nextFunctions }
) => {
return chat(operationName, {
messages: nextMessages,
functionCall,
functions: nextFunctions,
connectorId,
signal,
});
},
signal,
chat,
});

return {
Expand Down
Loading

0 comments on commit d005850

Please sign in to comment.