From 9659df1a108f50ed0d577dc0bc7b9ba283654f54 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 2 May 2024 17:10:52 -0700 Subject: [PATCH 1/2] Pass empty QueryCollectorContext in case of hybrid query Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + ...earcherWithEmptyQueryCollectorContext.java | 41 ++++++ ...earcherWithEmptyQueryCollectorContext.java | 41 ++++++ .../query/HybridQueryPhaseSearcher.java | 18 ++- ...erWithEmptyQueryCollectorContextTests.java | 128 ++++++++++++++++++ ...erWithEmptyQueryCollectorContextTests.java | 128 ++++++++++++++++++ .../query/HybridQueryPhaseSearcherTests.java | 103 ++++++++------ 7 files changed, 419 insertions(+), 41 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 36c6be493..479bf1877 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x) ### Features ### Enhancements +- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731)) ### Bug Fixes - Fix multi node "no such index" error in text chunking processor ([#713](https://github.com/opensearch-project/neural-search/pull/713)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java b/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java new file mode 100644 index 000000000..71843d5f2 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import org.apache.lucene.search.Query; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.ConcurrentQueryPhaseSearcher; +import org.opensearch.search.query.QueryCollectorContext; + +import java.io.IOException; +import java.util.LinkedList; + +/** + * Class that inherits ConcurrentQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ +public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext extends ConcurrentQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java b/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java new file mode 100644 index 000000000..179f81e7f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import org.apache.lucene.search.Query; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QueryPhase; + +import java.io.IOException; +import java.util.LinkedList; + +/** + * Class that inherits DefaultQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ +public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext extends QueryPhase.DefaultQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index b97134f8f..7b96ebff2 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -21,6 +21,7 @@ import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; +import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.search.query.QueryPhaseSearcherWrapper; import lombok.extern.log4j.Log4j2; @@ -36,6 +37,14 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { + private final QueryPhaseSearcher defaultQueryPhaseSearcherWithEmptyCollectorContext; + private final QueryPhaseSearcher concurrentQueryPhaseSearcherWithEmptyCollectorContext; + + public HybridQueryPhaseSearcher() { + this.defaultQueryPhaseSearcherWithEmptyCollectorContext = new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext(); + this.concurrentQueryPhaseSearcherWithEmptyCollectorContext = new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext(); + } + public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, @@ -49,10 +58,17 @@ public boolean searchWith( return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } else { Query hybridQuery = extractHybridQuery(searchContext, query); - return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext); + return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } } + private QueryPhaseSearcher getQueryPhaseSearcher(final SearchContext searchContext) { + return searchContext.shouldUseConcurrentSearch() + ? concurrentQueryPhaseSearcherWithEmptyCollectorContext + : defaultQueryPhaseSearcherWithEmptyCollectorContext; + } + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java new file mode 100644 index 000000000..5ad641be2 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.action.OriginalIndices; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; + +import java.io.IOException; +import java.util.LinkedList; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase { + private static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String QUERY_TEXT1 = "hello"; + private static final Index dummyIndex = new Index("dummy", "dummy"); + + @SneakyThrows + public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { + ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy( + new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext() + ); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(3); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext); + queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertTrue(querySearchResult.hasConsumedTopDocs()); + + releaseResources(directory, w, reader); + } + + private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java new file mode 100644 index 000000000..51572000c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.action.OriginalIndices; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; + +import java.io.IOException; +import java.util.LinkedList; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase { + private static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String QUERY_TEXT1 = "hello"; + private static final Index dummyIndex = new Index("dummy", "dummy"); + + @SneakyThrows + public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { + DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy( + new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext() + ); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(3); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext); + queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertTrue(querySearchResult.hasConsumedTopDocs()); + + releaseResources(directory, w, reader); + } + + private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index a938b2111..b606aac3e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -20,13 +20,11 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; import java.io.IOException; -import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedList; -import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; -import java.util.stream.Collectors; +import java.util.concurrent.ExecutorService; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -39,6 +37,8 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -58,8 +58,6 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.MatchAllQueryBuilder; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; @@ -78,6 +76,7 @@ import lombok.SneakyThrows; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String VECTOR_FIELD_NAME = "vectorField"; @@ -88,13 +87,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT4 = "This is really nice place to be"; private static final String QUERY_TEXT1 = "hello"; private static final String QUERY_TEXT2 = "randomkeyword"; - private static final String QUERY_TEXT3 = "place"; private static final Index dummyIndex = new Index("dummy", "dummy"); - private static final String MODEL_ID = "mfgfgdsfgfdgsde"; - private static final int K = 10; - private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); - private static final UUID INDEX_UUID = UUID.randomUUID(); - private static final String TEST_INDEX = "index"; @SneakyThrows public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { @@ -306,20 +299,22 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertEquals(1, topDocs.totalHits.value); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertNotNull(scoreDoc); - int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); - assertEquals(docId1, actualDocId); - assertTrue(scoreDoc.score > 0.0f); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -340,13 +335,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes ft.setOmitNorms(random().nextBoolean()); ft.freeze(); int docId1 = RandomizedTest.randomInt(); - int docId2 = RandomizedTest.randomInt(); - int docId3 = RandomizedTest.randomInt(); - int docId4 = RandomizedTest.randomInt(); w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); w.commit(); IndexReader reader = DirectoryReader.open(w); @@ -395,18 +384,22 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertEquals(4, topDocs.totalHits.value); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(4, scoreDocs.length); - List expectedIds = List.of(0, 1, 2, 3); - List actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList()); - assertEquals(expectedIds, actualDocIds); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -705,18 +698,22 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertTrue(topDocs.totalHits.value > 0); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertTrue(scoreDoc.score > 0); - assertEquals(0, scoreDoc.doc); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -979,18 +976,22 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the when(searchContext.query()).thenReturn(query); when(searchContext.aliasFilter()).thenReturn(termFilter); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertTrue(topDocs.totalHits.value > 0); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertTrue(scoreDoc.score > 0); - assertEquals(0, scoreDoc.doc); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -1038,4 +1039,26 @@ private static IndexMetadata getIndexMetadata() { .build(); return indexMetadata; } + + private static ContextIndexSearcher newContextSearcher(IndexReader reader, ExecutorService executor) throws IOException { + SearchContext searchContext = mock(SearchContext.class); + IndexShard indexShard = mock(IndexShard.class); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(executor != null); + if (executor != null) { + when(searchContext.getTargetMaxSliceCount()).thenReturn(randomIntBetween(0, 2)); + } else { + when(searchContext.getTargetMaxSliceCount()).thenThrow(IllegalStateException.class); + } + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor, + searchContext + ); + } } From 22577219ff57edcc3c4887735ab072447fe344f5 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 3 May 2024 17:10:18 -0700 Subject: [PATCH 2/2] Refactor custom qps classes into private static classes Signed-off-by: Martin Gaievski --- ...earcherWithEmptyQueryCollectorContext.java | 41 ------ ...earcherWithEmptyQueryCollectorContext.java | 41 ------ .../query/HybridQueryPhaseSearcher.java | 59 ++++++++ ...erWithEmptyQueryCollectorContextTests.java | 128 ------------------ ...erWithEmptyQueryCollectorContextTests.java | 128 ------------------ .../query/HybridQueryPhaseSearcherTests.java | 22 --- 6 files changed, 59 insertions(+), 360 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java delete mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java delete mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java b/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java deleted file mode 100644 index 71843d5f2..000000000 --- a/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.search.query; - -import org.apache.lucene.search.Query; -import org.opensearch.search.internal.ContextIndexSearcher; -import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.query.ConcurrentQueryPhaseSearcher; -import org.opensearch.search.query.QueryCollectorContext; - -import java.io.IOException; -import java.util.LinkedList; - -/** - * Class that inherits ConcurrentQueryPhaseSearcher implementation but calls its search with only - * empty query collector context - */ -public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext extends ConcurrentQueryPhaseSearcher { - - @Override - protected boolean searchWithCollector( - SearchContext searchContext, - ContextIndexSearcher searcher, - Query query, - LinkedList collectors, - boolean hasFilterCollector, - boolean hasTimeout - ) throws IOException { - return searchWithCollector( - searchContext, - searcher, - query, - collectors, - QueryCollectorContext.EMPTY_CONTEXT, - hasFilterCollector, - hasTimeout - ); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java b/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java deleted file mode 100644 index 179f81e7f..000000000 --- a/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.search.query; - -import org.apache.lucene.search.Query; -import org.opensearch.search.internal.ContextIndexSearcher; -import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.query.QueryCollectorContext; -import org.opensearch.search.query.QueryPhase; - -import java.io.IOException; -import java.util.LinkedList; - -/** - * Class that inherits DefaultQueryPhaseSearcher implementation but calls its search with only - * empty query collector context - */ -public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext extends QueryPhase.DefaultQueryPhaseSearcher { - - @Override - protected boolean searchWithCollector( - SearchContext searchContext, - ContextIndexSearcher searcher, - Query query, - LinkedList collectors, - boolean hasFilterCollector, - boolean hasTimeout - ) throws IOException { - return searchWithCollector( - searchContext, - searcher, - query, - collectors, - QueryCollectorContext.EMPTY_CONTEXT, - hasFilterCollector, - hasTimeout - ); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 7b96ebff2..53248f88c 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -10,6 +10,8 @@ import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; @@ -19,6 +21,7 @@ import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.ConcurrentQueryPhaseSearcher; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; import org.opensearch.search.query.QueryPhaseSearcher; @@ -148,4 +151,60 @@ public AggregationProcessor aggregationProcessor(SearchContext searchContext) { AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); return new HybridAggregationProcessor(coreAggProcessor); } + + /** + * Class that inherits ConcurrentQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ + @NoArgsConstructor(access = AccessLevel.PRIVATE) + final class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext extends ConcurrentQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } + } + + /** + * Class that inherits DefaultQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ + @NoArgsConstructor(access = AccessLevel.PACKAGE) + final class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext extends QueryPhase.DefaultQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java deleted file mode 100644 index 5ad641be2..000000000 --- a/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.search.query; - -import com.carrotsearch.randomizedtesting.RandomizedTest; -import lombok.SneakyThrows; -import org.apache.lucene.document.FieldType; -import org.apache.lucene.document.TextField; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.store.Directory; -import org.apache.lucene.tests.analysis.MockAnalyzer; -import org.opensearch.action.OriginalIndices; -import org.opensearch.core.index.Index; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.index.mapper.TextFieldMapper; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.index.shard.IndexShard; -import org.opensearch.neuralsearch.query.HybridQueryBuilder; -import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import org.opensearch.search.SearchShardTarget; -import org.opensearch.search.internal.ContextIndexSearcher; -import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.query.QueryCollectorContext; -import org.opensearch.search.query.QuerySearchResult; - -import java.io.IOException; -import java.util.LinkedList; - -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; - -public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase { - private static final String TEXT_FIELD_NAME = "field"; - private static final String TEST_DOC_TEXT1 = "Hello world"; - private static final String QUERY_TEXT1 = "hello"; - private static final Index dummyIndex = new Index("dummy", "dummy"); - - @SneakyThrows - public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { - ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy( - new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext() - ); - QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - when(mockQueryShardContext.index()).thenReturn(dummyIndex); - MapperService mapperService = createMapperService(); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); - when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - - Directory directory = newDirectory(); - IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); - FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); - ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); - ft.setOmitNorms(random().nextBoolean()); - ft.freeze(); - int docId1 = RandomizedTest.randomInt(); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); - w.commit(); - - IndexReader reader = DirectoryReader.open(w); - SearchContext searchContext = mock(SearchContext.class); - - ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( - reader, - IndexSearcher.getDefaultSimilarity(), - IndexSearcher.getDefaultQueryCache(), - IndexSearcher.getDefaultQueryCachingPolicy(), - true, - null, - searchContext - ); - - ShardId shardId = new ShardId(dummyIndex, 1); - SearchShardTarget shardTarget = new SearchShardTarget( - randomAlphaOfLength(10), - shardId, - randomAlphaOfLength(10), - OriginalIndices.NONE - ); - when(searchContext.shardTarget()).thenReturn(shardTarget); - when(searchContext.searcher()).thenReturn(contextIndexSearcher); - when(searchContext.size()).thenReturn(3); - when(searchContext.numberOfShards()).thenReturn(1); - when(searchContext.searcher()).thenReturn(contextIndexSearcher); - IndexShard indexShard = mock(IndexShard.class); - when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); - when(searchContext.indexShard()).thenReturn(indexShard); - QuerySearchResult querySearchResult = new QuerySearchResult(); - when(searchContext.queryResult()).thenReturn(querySearchResult); - when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); - when(searchContext.mapperService()).thenReturn(mapperService); - - LinkedList collectors = new LinkedList<>(); - boolean hasFilterCollector = randomBoolean(); - boolean hasTimeout = randomBoolean(); - - HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); - - TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); - queryBuilder.add(termSubQuery); - - Query query = queryBuilder.toQuery(mockQueryShardContext); - when(searchContext.query()).thenReturn(query); - queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext); - queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); - - assertTrue(querySearchResult.hasConsumedTopDocs()); - - releaseResources(directory, w, reader); - } - - private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { - w.close(); - reader.close(); - directory.close(); - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java deleted file mode 100644 index 51572000c..000000000 --- a/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.neuralsearch.search.query; - -import com.carrotsearch.randomizedtesting.RandomizedTest; -import lombok.SneakyThrows; -import org.apache.lucene.document.FieldType; -import org.apache.lucene.document.TextField; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexWriter; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.store.Directory; -import org.apache.lucene.tests.analysis.MockAnalyzer; -import org.opensearch.action.OriginalIndices; -import org.opensearch.core.index.Index; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.index.mapper.TextFieldMapper; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.index.shard.IndexShard; -import org.opensearch.neuralsearch.query.HybridQueryBuilder; -import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import org.opensearch.search.SearchShardTarget; -import org.opensearch.search.internal.ContextIndexSearcher; -import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.query.QueryCollectorContext; -import org.opensearch.search.query.QuerySearchResult; - -import java.io.IOException; -import java.util.LinkedList; - -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; - -public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase { - private static final String TEXT_FIELD_NAME = "field"; - private static final String TEST_DOC_TEXT1 = "Hello world"; - private static final String QUERY_TEXT1 = "hello"; - private static final Index dummyIndex = new Index("dummy", "dummy"); - - @SneakyThrows - public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { - DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy( - new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext() - ); - QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - when(mockQueryShardContext.index()).thenReturn(dummyIndex); - MapperService mapperService = createMapperService(); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); - when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - - Directory directory = newDirectory(); - IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); - FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); - ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); - ft.setOmitNorms(random().nextBoolean()); - ft.freeze(); - int docId1 = RandomizedTest.randomInt(); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); - w.commit(); - - IndexReader reader = DirectoryReader.open(w); - SearchContext searchContext = mock(SearchContext.class); - - ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( - reader, - IndexSearcher.getDefaultSimilarity(), - IndexSearcher.getDefaultQueryCache(), - IndexSearcher.getDefaultQueryCachingPolicy(), - true, - null, - searchContext - ); - - ShardId shardId = new ShardId(dummyIndex, 1); - SearchShardTarget shardTarget = new SearchShardTarget( - randomAlphaOfLength(10), - shardId, - randomAlphaOfLength(10), - OriginalIndices.NONE - ); - when(searchContext.shardTarget()).thenReturn(shardTarget); - when(searchContext.searcher()).thenReturn(contextIndexSearcher); - when(searchContext.size()).thenReturn(3); - when(searchContext.numberOfShards()).thenReturn(1); - when(searchContext.searcher()).thenReturn(contextIndexSearcher); - IndexShard indexShard = mock(IndexShard.class); - when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); - when(searchContext.indexShard()).thenReturn(indexShard); - QuerySearchResult querySearchResult = new QuerySearchResult(); - when(searchContext.queryResult()).thenReturn(querySearchResult); - when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); - when(searchContext.mapperService()).thenReturn(mapperService); - - LinkedList collectors = new LinkedList<>(); - boolean hasFilterCollector = randomBoolean(); - boolean hasTimeout = randomBoolean(); - - HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); - - TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); - queryBuilder.add(termSubQuery); - - Query query = queryBuilder.toQuery(mockQueryShardContext); - when(searchContext.query()).thenReturn(query); - queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext); - queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); - - assertTrue(querySearchResult.hasConsumedTopDocs()); - - releaseResources(directory, w, reader); - } - - private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { - w.close(); - reader.close(); - directory.close(); - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index b606aac3e..e790ffb77 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -1039,26 +1039,4 @@ private static IndexMetadata getIndexMetadata() { .build(); return indexMetadata; } - - private static ContextIndexSearcher newContextSearcher(IndexReader reader, ExecutorService executor) throws IOException { - SearchContext searchContext = mock(SearchContext.class); - IndexShard indexShard = mock(IndexShard.class); - when(searchContext.indexShard()).thenReturn(indexShard); - when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); - when(searchContext.shouldUseConcurrentSearch()).thenReturn(executor != null); - if (executor != null) { - when(searchContext.getTargetMaxSliceCount()).thenReturn(randomIntBetween(0, 2)); - } else { - when(searchContext.getTargetMaxSliceCount()).thenThrow(IllegalStateException.class); - } - return new ContextIndexSearcher( - reader, - IndexSearcher.getDefaultSimilarity(), - IndexSearcher.getDefaultQueryCache(), - IndexSearcher.getDefaultQueryCachingPolicy(), - true, - executor, - searchContext - ); - } }