Skip to content

Commit

Permalink
Pass empty QueryCollectorContext in case of hybrid query to improve l…
Browse files Browse the repository at this point in the history
…atencies by 20% (#731)

* Pass empty QueryCollectorContext in case of hybrid query

Signed-off-by: Martin Gaievski <[email protected]>
(cherry picked from commit 2c556d2)
  • Loading branch information
martin-gaievski authored and github-actions[bot] committed May 6, 2024
1 parent d7b72a6 commit b7fc313
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x)
### Features
### Enhancements
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
### Bug Fixes
- Fix multi node "no such index" error in text chunking processor ([#713](https://github.com/opensearch-project/neural-search/pull/713))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
Expand All @@ -19,8 +21,10 @@
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.ConcurrentQueryPhaseSearcher;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.search.query.QueryPhaseSearcherWrapper;

import lombok.extern.log4j.Log4j2;
Expand All @@ -36,6 +40,14 @@
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {

private final QueryPhaseSearcher defaultQueryPhaseSearcherWithEmptyCollectorContext;
private final QueryPhaseSearcher concurrentQueryPhaseSearcherWithEmptyCollectorContext;

public HybridQueryPhaseSearcher() {
this.defaultQueryPhaseSearcherWithEmptyCollectorContext = new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext();
this.concurrentQueryPhaseSearcherWithEmptyCollectorContext = new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext();
}

public boolean searchWith(
final SearchContext searchContext,
final ContextIndexSearcher searcher,
Expand All @@ -49,10 +61,17 @@ public boolean searchWith(
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
Query hybridQuery = extractHybridQuery(searchContext, query);
return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
}

private QueryPhaseSearcher getQueryPhaseSearcher(final SearchContext searchContext) {
return searchContext.shouldUseConcurrentSearch()
? concurrentQueryPhaseSearcherWithEmptyCollectorContext
: defaultQueryPhaseSearcherWithEmptyCollectorContext;
}

private static boolean isWrappedHybridQuery(final Query query) {
return query instanceof BooleanQuery
&& ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
Expand Down Expand Up @@ -132,4 +151,60 @@ public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext);
return new HybridAggregationProcessor(coreAggProcessor);
}

/**
* Class that inherits ConcurrentQueryPhaseSearcher implementation but calls its search with only
* empty query collector context
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext extends ConcurrentQueryPhaseSearcher {

@Override
protected boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
return searchWithCollector(
searchContext,
searcher,
query,
collectors,
QueryCollectorContext.EMPTY_CONTEXT,
hasFilterCollector,
hasTimeout
);
}
}

/**
* Class that inherits DefaultQueryPhaseSearcher implementation but calls its search with only
* empty query collector context
*/
@NoArgsConstructor(access = AccessLevel.PACKAGE)
final class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext extends QueryPhase.DefaultQueryPhaseSearcher {

@Override
protected boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
return searchWithCollector(
searchContext,
searcher,
query,
collectors,
QueryCollectorContext.EMPTY_CONTEXT,
hasFilterCollector,
hasTimeout
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.concurrent.ExecutorService;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
Expand All @@ -39,6 +37,8 @@
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
Expand All @@ -58,8 +58,6 @@
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
Expand All @@ -78,6 +76,7 @@
import lombok.SneakyThrows;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;

public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
private static final String VECTOR_FIELD_NAME = "vectorField";
Expand All @@ -88,13 +87,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
private static final String TEST_DOC_TEXT4 = "This is really nice place to be";
private static final String QUERY_TEXT1 = "hello";
private static final String QUERY_TEXT2 = "randomkeyword";
private static final String QUERY_TEXT3 = "place";
private static final Index dummyIndex = new Index("dummy", "dummy");
private static final String MODEL_ID = "mfgfgdsfgfdgsde";
private static final int K = 10;
private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder();
private static final UUID INDEX_UUID = UUID.randomUUID();
private static final String TEST_INDEX = "index";

@SneakyThrows
public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
Expand Down Expand Up @@ -306,20 +299,22 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);

CollectorManager<? extends Collector, ReduceableSearchResult> collectorManager = HybridCollectorManager
.createHybridCollectorManager(searchContext);
Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
queryCollectorManagers.put(HybridCollectorManager.class, collectorManager);
when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext);

assertNotNull(querySearchResult.topDocs());
TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
TopDocs topDocs = topDocsAndMaxScore.topDocs;
assertEquals(1, topDocs.totalHits.value);
assertEquals(0, topDocs.totalHits.value);
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
assertNotNull(scoreDocs);
assertEquals(1, scoreDocs.length);
ScoreDoc scoreDoc = scoreDocs[0];
assertNotNull(scoreDoc);
int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue());
assertEquals(docId1, actualDocId);
assertTrue(scoreDoc.score > 0.0f);
assertEquals(0, scoreDocs.length);

releaseResources(directory, w, reader);
}
Expand All @@ -340,13 +335,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
ft.setOmitNorms(random().nextBoolean());
ft.freeze();
int docId1 = RandomizedTest.randomInt();
int docId2 = RandomizedTest.randomInt();
int docId3 = RandomizedTest.randomInt();
int docId4 = RandomizedTest.randomInt();
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft));
w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
Expand Down Expand Up @@ -395,18 +384,22 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);

CollectorManager<? extends Collector, ReduceableSearchResult> collectorManager = HybridCollectorManager
.createHybridCollectorManager(searchContext);
Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
queryCollectorManagers.put(HybridCollectorManager.class, collectorManager);
when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext);

assertNotNull(querySearchResult.topDocs());
TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
TopDocs topDocs = topDocsAndMaxScore.topDocs;
assertEquals(4, topDocs.totalHits.value);
assertEquals(0, topDocs.totalHits.value);
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
assertNotNull(scoreDocs);
assertEquals(4, scoreDocs.length);
List<Integer> expectedIds = List.of(0, 1, 2, 3);
List<Integer> actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList());
assertEquals(expectedIds, actualDocIds);
assertEquals(0, scoreDocs.length);

releaseResources(directory, w, reader);
}
Expand Down Expand Up @@ -705,18 +698,22 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then

when(searchContext.query()).thenReturn(query);

CollectorManager<? extends Collector, ReduceableSearchResult> collectorManager = HybridCollectorManager
.createHybridCollectorManager(searchContext);
Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
queryCollectorManagers.put(HybridCollectorManager.class, collectorManager);
when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext);

assertNotNull(querySearchResult.topDocs());
TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
TopDocs topDocs = topDocsAndMaxScore.topDocs;
assertTrue(topDocs.totalHits.value > 0);
assertEquals(0, topDocs.totalHits.value);
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
assertNotNull(scoreDocs);
assertEquals(1, scoreDocs.length);
ScoreDoc scoreDoc = scoreDocs[0];
assertTrue(scoreDoc.score > 0);
assertEquals(0, scoreDoc.doc);
assertEquals(0, scoreDocs.length);

releaseResources(directory, w, reader);
}
Expand Down Expand Up @@ -979,18 +976,22 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the
when(searchContext.query()).thenReturn(query);
when(searchContext.aliasFilter()).thenReturn(termFilter);

CollectorManager<? extends Collector, ReduceableSearchResult> collectorManager = HybridCollectorManager
.createHybridCollectorManager(searchContext);
Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
queryCollectorManagers.put(HybridCollectorManager.class, collectorManager);
when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);
hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext);

assertNotNull(querySearchResult.topDocs());
TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
TopDocs topDocs = topDocsAndMaxScore.topDocs;
assertTrue(topDocs.totalHits.value > 0);
assertEquals(0, topDocs.totalHits.value);
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
assertNotNull(scoreDocs);
assertEquals(1, scoreDocs.length);
ScoreDoc scoreDoc = scoreDocs[0];
assertTrue(scoreDoc.score > 0);
assertEquals(0, scoreDoc.doc);
assertEquals(0, scoreDocs.length);

releaseResources(directory, w, reader);
}
Expand Down

0 comments on commit b7fc313

Please sign in to comment.