Skip to content

Commit

Permalink
[Search] [Playground] Switch off condensing question flow for first q…
Browse files Browse the repository at this point in the history
…uestion (#180222)

Two changes where:
- condensing question LLM flow is rewriting the question with quotes and
breaking the parsing. This escapes quotes.
- skipping the condensing question function when theres no chat history
  • Loading branch information
joemcelroy authored Apr 5, 2024
1 parent c73a83e commit 7abe4c2
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 41 deletions.
20 changes: 20 additions & 0 deletions x-pack/plugins/search_playground/server/routes.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { createRetriever } from './routes';

describe('createRetriever', () => {
test('works when the question has quotes', () => {
const esQuery = '{"query": {"match": {"text": "{query}"}}}';
const question = 'How can I "do something" with quotes?';

const retriever = createRetriever(esQuery);
const result = retriever(question);

expect(result).toEqual({ match: { text: 'How can I "do something" with quotes?' } });
});
});
21 changes: 12 additions & 9 deletions x-pack/plugins/search_playground/server/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ import { Prompt } from '../common/prompt';
import { errorHandler } from './utils/error_handler';
import { APIRoutes } from './types';

export function createRetriever(esQuery: string) {
return (question: string) => {
try {
const query = JSON.parse(esQuery.replace(/{query}/g, question.replace(/"/g, '\\"')));
return query.query;
} catch (e) {
throw Error(e);
}
};
}

export function defineRoutes({ log, router }: { log: Logger; router: IRouter }) {
router.post(
{
Expand Down Expand Up @@ -76,15 +87,7 @@ export function defineRoutes({ log, router }: { log: Logger; router: IRouter })
model,
rag: {
index: data.indices,
retriever: (question: string) => {
try {
const query = JSON.parse(data.elasticsearchQuery.replace(/{query}/g, question));
return query.query;
} catch (e) {
log.error('Failed to parse the Elasticsearch query', e);
throw Error(e);
}
},
retriever: createRetriever(data.elasticsearchQuery),
content_field: sourceFields,
size: Number(data.docSize),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ import { createAssist as Assist } from './assist';
import { ConversationalChain } from './conversational_chain';
import { FakeListLLM } from 'langchain/llms/fake';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Message } from 'ai';

describe('conversational chain', () => {
it('should be able to create a conversational chain', async () => {
const createTestChain = async (
responses: string[],
chat: Message[],
expectedFinalAnswer: string,
expectedDocs: any,
expectedSearchRequest: any
) => {
const searchMock = jest.fn().mockImplementation(() => {
return {
hits: {
Expand Down Expand Up @@ -41,7 +48,7 @@ describe('conversational chain', () => {
};

const llm = new FakeListLLM({
responses: ['question rewritten to work from home', 'the final answer'],
responses,
});

const aiClient = Assist({
Expand All @@ -65,13 +72,7 @@ describe('conversational chain', () => {
prompt: 'you are a QA bot',
});

const stream = await conversationalChain.stream(aiClient, [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
]);
const stream = await conversationalChain.stream(aiClient, chat);

const streamToValue: string[] = await new Promise((resolve) => {
const reader = stream.getReader();
Expand All @@ -94,26 +95,122 @@ describe('conversational chain', () => {
const textValue = streamToValue
.filter((v) => v[0] === '0')
.reduce((acc, v) => acc + v.replace(/0:"(.*)"\n/, '$1'), '');
expect(textValue).toEqual('the final answer');
expect(textValue).toEqual(expectedFinalAnswer);

const docValue = streamToValue
.filter((v) => v[0] === '8')
.reduce((acc, v) => acc + v.replace(/8:(.*)\n/, '$1'), '');
expect(JSON.parse(docValue)).toEqual([
{
documents: [
{ metadata: { id: '1', index: 'index' }, pageContent: 'value' },
{ metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
],
type: 'retrieved_docs',
},
]);
expect(searchMock.mock.calls[0]).toEqual([
{
index: 'index,website',
query: { query: { match: { field: 'question rewritten to work from home' } } },
size: 3,
},
]);
expect(JSON.parse(docValue)).toEqual(expectedDocs);
expect(searchMock.mock.calls[0]).toEqual(expectedSearchRequest);
};

it('should be able to create a conversational chain', async () => {
await createTestChain(
['the final answer'],
[
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
],
'the final answer',
[
{
documents: [
{ metadata: { id: '1', index: 'index' }, pageContent: 'value' },
{ metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
],
type: 'retrieved_docs',
},
],
[
{
index: 'index,website',
query: { query: { match: { field: 'what is the work from home policy?' } } },
size: 3,
},
]
);
});

it('asking with chat history should re-write the question', async () => {
await createTestChain(
['rewrite the question', 'the final answer'],
[
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
{
id: '2',
role: 'assistant',
content: 'the final answer',
},
{
id: '3',
role: 'user',
content: 'what is the work from home policy?',
},
],
'the final answer',
[
{
documents: [
{ metadata: { id: '1', index: 'index' }, pageContent: 'value' },
{ metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
],
type: 'retrieved_docs',
},
],
[
{
index: 'index,website',
query: { query: { match: { field: 'rewrite the question' } } },
size: 3,
},
]
);
});

it('should cope with quotes in the query', async () => {
await createTestChain(
['rewrite "the" question', 'the final answer'],
[
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
{
id: '2',
role: 'assistant',
content: 'the final answer',
},
{
id: '3',
role: 'user',
content: 'what is the work from home policy?',
},
],
'the final answer',
[
{
documents: [
{ metadata: { id: '1', index: 'index' }, pageContent: 'value' },
{ metadata: { id: '1', index: 'website' }, pageContent: 'value2' },
],
type: 'retrieved_docs',
},
],
[
{
index: 'index,website',
query: { query: { match: { field: 'rewrite "the" question' } } },
size: 3,
},
]
);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ConversationalChainFn {
const question = messages[messages.length - 1]!.content;
const retrievedDocs: Document[] = [];

let retrievalChain: Runnable = RunnableLambda.from((input) => '');
let retrievalChain: Runnable = RunnableLambda.from(() => '');

if (this.options.rag) {
const retriever = new ElasticsearchRetriever({
Expand Down Expand Up @@ -107,11 +107,15 @@ class ConversationalChainFn {
retrievalChain = retriever.pipe(buildContext);
}

const standaloneQuestionChain = RunnableSequence.from([
condenseQuestionPrompt,
this.options.model,
new StringOutputParser(),
]);
let standaloneQuestionChain: Runnable = RunnableLambda.from((input) => input.question);

if (previousMessages.length > 0) {
standaloneQuestionChain = RunnableSequence.from([
condenseQuestionPrompt,
this.options.model,
new StringOutputParser(),
]);
}

const prompt = ChatPromptTemplate.fromTemplate(this.options.prompt);

Expand Down

0 comments on commit 7abe4c2

Please sign in to comment.