Skip to content

Commit

Permalink
[8.14] [Obs AI Assistant] Refactor ObservabilityAIAssistantClient (#1…
Browse files Browse the repository at this point in the history
…81255) (#182237)

# Backport

This will backport the following commits from `main` to `8.14`:
- [[Obs AI Assistant] Refactor ObservabilityAIAssistantClient
(#181255)](#181255)

<!--- Backport version: 7.3.2 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT {commits} BACKPORT-->
  • Loading branch information
dgieselaar authored May 1, 2024
1 parent 793d051 commit 95e97f1
Show file tree
Hide file tree
Showing 37 changed files with 1,271 additions and 880 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,17 @@ export type StreamingChatResponseEvent =
| ConversationCreateEvent
| ConversationUpdateEvent
| MessageAddEvent
| ChatCompletionErrorEvent;
| ChatCompletionErrorEvent
| TokenCountEvent;

export type StreamingChatResponseEventWithoutError = Exclude<
StreamingChatResponseEvent,
ChatCompletionErrorEvent
>;

export type ChatEvent = ChatCompletionChunkEvent | TokenCountEvent;
export type MessageOrChatEvent = ChatEvent | MessageAddEvent;

export enum ChatCompletionErrorCode {
InternalError = 'internalError',
NotFoundError = 'notFoundError',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export function createFunctionRequestMessage({
args,
}: {
name: string;
args: unknown;
args?: Record<string, any>;
}): MessageAddEvent {
return {
id: v4(),
Expand All @@ -28,6 +28,7 @@ export function createFunctionRequestMessage({
trigger: MessageRole.Assistant as const,
},
role: MessageRole.Assistant,
content: '',
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ export function createFunctionResponseError({
name: error.name,
message: error.message,
cause: error.cause,
stack: error.stack,
},
message: message || error.message,
},
data: {
stack: error.stack,
},
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,20 @@
* 2.0.
*/

import { concat, from, last, mergeMap, Observable, shareReplay, withLatestFrom } from 'rxjs';
import {
concat,
from,
last,
mergeMap,
Observable,
OperatorFunction,
shareReplay,
withLatestFrom,
} from 'rxjs';
import { withoutTokenCountEvents } from './without_token_count_events';
import {
ChatCompletionChunkEvent,
ChatEvent,
MessageAddEvent,
StreamingChatResponseEventType,
} from '../conversation_complete';
Expand Down Expand Up @@ -40,20 +51,21 @@ function mergeWithEditedMessage(
);
}

export function emitWithConcatenatedMessage(
export function emitWithConcatenatedMessage<T extends ChatEvent>(
callback?: ConcatenateMessageCallback
): (
source$: Observable<ChatCompletionChunkEvent>
) => Observable<ChatCompletionChunkEvent | MessageAddEvent> {
return (source$: Observable<ChatCompletionChunkEvent>) => {
): OperatorFunction<T, T | MessageAddEvent> {
return (source$) => {
const shared = source$.pipe(shareReplay());

const withoutTokenCount$ = shared.pipe(withoutTokenCountEvents());

const response$ = concat(
shared,
shared.pipe(
withoutTokenCountEvents(),
concatenateChatCompletionChunks(),
last(),
withLatestFrom(source$),
withLatestFrom(withoutTokenCount$),
mergeMap(([message, chunkEvent]) => {
return mergeWithEditedMessage(message, chunkEvent, callback);
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { filter, OperatorFunction } from 'rxjs';
import {
StreamingChatResponseEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../conversation_complete';

export function withoutTokenCountEvents<T extends StreamingChatResponseEvent>(): OperatorFunction<
T,
Exclude<T, TokenCountEvent>
> {
return filter(
(event): event is Exclude<T, TokenCountEvent> =>
event.type !== StreamingChatResponseEventType.TokenCount
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export const mockChatService: ObservabilityAIAssistantChatService = {
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.System,
content: '',
content: 'System',
},
}),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ describe('complete', () => {
'@timestamp': expect.any(String),
message: {
content: expect.any(String),
data: expect.any(String),
name: 'my_action',
role: MessageRole.User,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import {
StreamingChatResponseEventType,
type StreamingChatResponseEventWithoutError,
type StreamingChatResponseEvent,
TokenCountEvent,
} from '../../common/conversation_complete';
import {
FunctionRegistry,
Expand Down Expand Up @@ -163,13 +162,7 @@ export async function createChatService({

const subscription = toObservable(response)
.pipe(
map(
(line) =>
JSON.parse(line) as
| StreamingChatResponseEvent
| BufferFlushEvent
| TokenCountEvent
),
map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
filter(
(line): line is StreamingChatResponseEvent =>
line.type !== StreamingChatResponseEventType.BufferFlush &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export const createStorybookChatService = (): ObservabilityAIAssistantChatServic
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.System,
content: '',
content: 'System',
},
}),
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_
import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message';
import { RecallRanking, RecallRankingEventType } from '../analytics/recall_ranking';
import type { ObservabilityAIAssistantClient } from '../service/client';
import { ChatFn } from '../service/types';
import { FunctionCallChatFunction } from '../service/types';
import { parseSuggestionScores } from './parse_suggestion_scores';

const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000;
Expand Down Expand Up @@ -61,7 +61,7 @@ export function registerContextFunction({
required: ['queries', 'categories'],
} as const,
},
async ({ arguments: args, messages, connectorId, screenContexts, chat }, signal) => {
async ({ arguments: args, messages, screenContexts, chat }, signal) => {
const { analytics } = (await resources.context.core).coreStart;

const { queries, categories } = args;
Expand Down Expand Up @@ -118,7 +118,6 @@ export function registerContextFunction({
queries: queriesOrUserPrompt,
messages,
chat,
connectorId,
signal,
logger: resources.logger,
});
Expand Down Expand Up @@ -209,15 +208,13 @@ async function scoreSuggestions({
messages,
queries,
chat,
connectorId,
signal,
logger,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
messages: Message[];
queries: string[];
chat: ChatFn;
connectorId: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}) {
Expand Down Expand Up @@ -274,15 +271,12 @@ async function scoreSuggestions({
};

const response = await lastValueFrom(
(
await chat('score_suggestions', {
connectorId,
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
})
).pipe(concatenateChatCompletionChunks())
chat('score_suggestions', {
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
}).pipe(concatenateChatCompletionChunks())
);

const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
* 2.0.
*/
import datemath from '@elastic/datemath';
import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server';
import type { ElasticsearchClient, SavedObjectsClientContract } from '@kbn/core/server';
import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server';
import { castArray, chunk, groupBy, uniq } from 'lodash';
import { lastValueFrom, Observable } from 'rxjs';
import type { ObservabilityAIAssistantClient } from '../../service/client';
import { type ChatCompletionChunkEvent, type Message, MessageRole } from '../../../common';
import { lastValueFrom } from 'rxjs';
import { MessageRole, type Message } from '../../../common';
import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks';
import { FunctionCallChatFunction } from '../../service/types';

export async function getRelevantFieldNames({
index,
Expand All @@ -22,6 +22,7 @@ export async function getRelevantFieldNames({
savedObjectsClient,
chat,
messages,
signal,
}: {
index: string | string[];
start?: string;
Expand All @@ -30,13 +31,8 @@ export async function getRelevantFieldNames({
esClient: ElasticsearchClient;
savedObjectsClient: SavedObjectsClientContract;
messages: Message[];
chat: (
name: string,
{}: Pick<
Parameters<ObservabilityAIAssistantClient['chat']>[1],
'functionCall' | 'functions' | 'messages'
>
) => Promise<Observable<ChatCompletionChunkEvent>>;
chat: FunctionCallChatFunction;
signal: AbortSignal;
}): Promise<{ fields: string[] }> {
const dataViewsService = await dataViews.dataViewsServiceFactory(savedObjectsClient, esClient);

Expand Down Expand Up @@ -79,6 +75,7 @@ export async function getRelevantFieldNames({
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
const chunkResponse$ = (
await chat('get_relevent_dataset_names', {
signal,
messages: [
{
'@timestamp': new Date().toISOString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export function registerGetDatasetInfoFunction({
required: ['index'],
} as const,
},
async ({ arguments: { index }, messages, connectorId, chat }, signal) => {
async ({ arguments: { index }, messages, chat }, signal) => {
const coreContext = await resources.context.core;

const esClient = coreContext.elasticsearch.client.asCurrentUser;
Expand Down Expand Up @@ -83,18 +83,8 @@ export function registerGetDatasetInfoFunction({
esClient,
dataViews: await resources.plugins.dataViews.start(),
savedObjectsClient,
chat: (
operationName,
{ messages: nextMessages, functionCall, functions: nextFunctions }
) => {
return chat(operationName, {
messages: nextMessages,
functionCall,
functions: nextFunctions,
connectorId,
signal,
});
},
signal,
chat,
});

return {
Expand Down
Loading

0 comments on commit 95e97f1

Please sign in to comment.