From 7abe4c2b5ac84c0ea398a06934b7066336c8b18b Mon Sep 17 00:00:00 2001 From: Joe McElroy Date: Fri, 5 Apr 2024 22:37:36 +0100 Subject: [PATCH] [Search] [Playground] Switch off condensing question flow for first question (#180222) Two changes here: - the condensing question LLM flow was rewriting the question with quotes, which broke the JSON parsing of the query. This change escapes the quotes. - skipping the condensing question function when there's no chat history --- .../search_playground/server/routes.test.ts | 20 +++ .../search_playground/server/routes.ts | 21 +-- .../server/utils/conversational_chain.test.ts | 149 +++++++++++++++--- .../server/utils/conversational_chain.ts | 16 +- 4 files changed, 165 insertions(+), 41 deletions(-) create mode 100644 x-pack/plugins/search_playground/server/routes.test.ts diff --git a/x-pack/plugins/search_playground/server/routes.test.ts b/x-pack/plugins/search_playground/server/routes.test.ts new file mode 100644 index 00000000000000..21467ab221162a --- /dev/null +++ b/x-pack/plugins/search_playground/server/routes.test.ts @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { createRetriever } from './routes'; + +describe('createRetriever', () => { + test('works when the question has quotes', () => { + const esQuery = '{"query": {"match": {"text": "{query}"}}}'; + const question = 'How can I "do something" with quotes?'; + + const retriever = createRetriever(esQuery); + const result = retriever(question); + + expect(result).toEqual({ match: { text: 'How can I "do something" with quotes?' 
} }); + }); +}); diff --git a/x-pack/plugins/search_playground/server/routes.ts b/x-pack/plugins/search_playground/server/routes.ts index fe2631065c281c..17025f0101838e 100644 --- a/x-pack/plugins/search_playground/server/routes.ts +++ b/x-pack/plugins/search_playground/server/routes.ts @@ -17,6 +17,17 @@ import { Prompt } from '../common/prompt'; import { errorHandler } from './utils/error_handler'; import { APIRoutes } from './types'; +export function createRetriever(esQuery: string) { + return (question: string) => { + try { + const query = JSON.parse(esQuery.replace(/{query}/g, question.replace(/"/g, '\\"'))); + return query.query; + } catch (e) { + throw Error(e); + } + }; +} + export function defineRoutes({ log, router }: { log: Logger; router: IRouter }) { router.post( { @@ -76,15 +87,7 @@ export function defineRoutes({ log, router }: { log: Logger; router: IRouter }) model, rag: { index: data.indices, - retriever: (question: string) => { - try { - const query = JSON.parse(data.elasticsearchQuery.replace(/{query}/g, question)); - return query.query; - } catch (e) { - log.error('Failed to parse the Elasticsearch query', e); - throw Error(e); - } - }, + retriever: createRetriever(data.elasticsearchQuery), content_field: sourceFields, size: Number(data.docSize), }, diff --git a/x-pack/plugins/search_playground/server/utils/conversational_chain.test.ts b/x-pack/plugins/search_playground/server/utils/conversational_chain.test.ts index 16c7b546130654..c10bda602e67dd 100644 --- a/x-pack/plugins/search_playground/server/utils/conversational_chain.test.ts +++ b/x-pack/plugins/search_playground/server/utils/conversational_chain.test.ts @@ -10,9 +10,16 @@ import { createAssist as Assist } from './assist'; import { ConversationalChain } from './conversational_chain'; import { FakeListLLM } from 'langchain/llms/fake'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { Message } from 'ai'; describe('conversational chain', () => { - 
it('should be able to create a conversational chain', async () => { + const createTestChain = async ( + responses: string[], + chat: Message[], + expectedFinalAnswer: string, + expectedDocs: any, + expectedSearchRequest: any + ) => { const searchMock = jest.fn().mockImplementation(() => { return { hits: { @@ -41,7 +48,7 @@ describe('conversational chain', () => { }; const llm = new FakeListLLM({ - responses: ['question rewritten to work from home', 'the final answer'], + responses, }); const aiClient = Assist({ @@ -65,13 +72,7 @@ describe('conversational chain', () => { prompt: 'you are a QA bot', }); - const stream = await conversationalChain.stream(aiClient, [ - { - id: '1', - role: 'user', - content: 'what is the work from home policy?', - }, - ]); + const stream = await conversationalChain.stream(aiClient, chat); const streamToValue: string[] = await new Promise((resolve) => { const reader = stream.getReader(); @@ -94,26 +95,122 @@ describe('conversational chain', () => { const textValue = streamToValue .filter((v) => v[0] === '0') .reduce((acc, v) => acc + v.replace(/0:"(.*)"\n/, '$1'), ''); - expect(textValue).toEqual('the final answer'); + expect(textValue).toEqual(expectedFinalAnswer); const docValue = streamToValue .filter((v) => v[0] === '8') .reduce((acc, v) => acc + v.replace(/8:(.*)\n/, '$1'), ''); - expect(JSON.parse(docValue)).toEqual([ - { - documents: [ - { metadata: { id: '1', index: 'index' }, pageContent: 'value' }, - { metadata: { id: '1', index: 'website' }, pageContent: 'value2' }, - ], - type: 'retrieved_docs', - }, - ]); - expect(searchMock.mock.calls[0]).toEqual([ - { - index: 'index,website', - query: { query: { match: { field: 'question rewritten to work from home' } } }, - size: 3, - }, - ]); + expect(JSON.parse(docValue)).toEqual(expectedDocs); + expect(searchMock.mock.calls[0]).toEqual(expectedSearchRequest); + }; + + it('should be able to create a conversational chain', async () => { + await createTestChain( + ['the final answer'], + 
[ + { + id: '1', + role: 'user', + content: 'what is the work from home policy?', + }, + ], + 'the final answer', + [ + { + documents: [ + { metadata: { id: '1', index: 'index' }, pageContent: 'value' }, + { metadata: { id: '1', index: 'website' }, pageContent: 'value2' }, + ], + type: 'retrieved_docs', + }, + ], + [ + { + index: 'index,website', + query: { query: { match: { field: 'what is the work from home policy?' } } }, + size: 3, + }, + ] + ); + }); + + it('asking with chat history should re-write the question', async () => { + await createTestChain( + ['rewrite the question', 'the final answer'], + [ + { + id: '1', + role: 'user', + content: 'what is the work from home policy?', + }, + { + id: '2', + role: 'assistant', + content: 'the final answer', + }, + { + id: '3', + role: 'user', + content: 'what is the work from home policy?', + }, + ], + 'the final answer', + [ + { + documents: [ + { metadata: { id: '1', index: 'index' }, pageContent: 'value' }, + { metadata: { id: '1', index: 'website' }, pageContent: 'value2' }, + ], + type: 'retrieved_docs', + }, + ], + [ + { + index: 'index,website', + query: { query: { match: { field: 'rewrite the question' } } }, + size: 3, + }, + ] + ); + }); + + it('should cope with quotes in the query', async () => { + await createTestChain( + ['rewrite "the" question', 'the final answer'], + [ + { + id: '1', + role: 'user', + content: 'what is the work from home policy?', + }, + { + id: '2', + role: 'assistant', + content: 'the final answer', + }, + { + id: '3', + role: 'user', + content: 'what is the work from home policy?', + }, + ], + 'the final answer', + [ + { + documents: [ + { metadata: { id: '1', index: 'index' }, pageContent: 'value' }, + { metadata: { id: '1', index: 'website' }, pageContent: 'value2' }, + ], + type: 'retrieved_docs', + }, + ], + [ + { + index: 'index,website', + query: { query: { match: { field: 'rewrite "the" question' } } }, + size: 3, + }, + ] + ); }); }); diff --git 
a/x-pack/plugins/search_playground/server/utils/conversational_chain.ts b/x-pack/plugins/search_playground/server/utils/conversational_chain.ts index 9efeaa0d61b0a8..bcf4116ae74a33 100644 --- a/x-pack/plugins/search_playground/server/utils/conversational_chain.ts +++ b/x-pack/plugins/search_playground/server/utils/conversational_chain.ts @@ -75,7 +75,7 @@ class ConversationalChainFn { const question = messages[messages.length - 1]!.content; const retrievedDocs: Document[] = []; - let retrievalChain: Runnable = RunnableLambda.from((input) => ''); + let retrievalChain: Runnable = RunnableLambda.from(() => ''); if (this.options.rag) { const retriever = new ElasticsearchRetriever({ @@ -107,11 +107,15 @@ class ConversationalChainFn { retrievalChain = retriever.pipe(buildContext); } - const standaloneQuestionChain = RunnableSequence.from([ - condenseQuestionPrompt, - this.options.model, - new StringOutputParser(), - ]); + let standaloneQuestionChain: Runnable = RunnableLambda.from((input) => input.question); + + if (previousMessages.length > 0) { + standaloneQuestionChain = RunnableSequence.from([ + condenseQuestionPrompt, + this.options.model, + new StringOutputParser(), + ]); + } const prompt = ChatPromptTemplate.fromTemplate(this.options.prompt);