Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.17] Removing code to cut search results of hybrid search in the priority queue #881

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,6 @@ private TopDocs topDocsPerQuery(int start, int howMany, PriorityQueue<ScoreDoc>

int size = howMany - start;
ScoreDoc[] results = new ScoreDoc[size];
// pq's pop() returns the 'least' element in the queue, therefore need
// to discard the first ones, until we reach the requested range.
for (int i = pq.size() - start - size; i > 0; i--) {
pq.pop();
}

// Get the requested results from pq.
populateResults(results, size, pq);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ public boolean searchWith(
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
// TODO remove this check after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved.
if (searchContext.from() != 0) {
throw new IllegalArgumentException("In the current OpenSearch version pagination is not supported with hybrid query");
}
Query hybridQuery = extractHybridQuery(searchContext, query);
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);

assertHitResultsFromQuery(1, searchResponseAsMap);
Expand All @@ -230,7 +231,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else if (!isSingleShard && hasPostFilterQuery) {
Expand All @@ -244,7 +246,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else {
Expand All @@ -258,7 +261,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(3, searchResponseAsMap);
}
Expand Down Expand Up @@ -319,7 +323,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);

assertHitResultsFromQuery(1, searchResponseAsMap);
Expand All @@ -334,7 +339,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else if (!isSingleShard && hasPostFilterQuery) {
Expand All @@ -348,7 +354,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(4, searchResponseAsMap);
} else {
Expand All @@ -362,7 +369,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(3, searchResponseAsMap);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,46 @@ public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenS
}
}

// TODO remove this test after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved.
@SneakyThrows
public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() {
try {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder();
hybridQueryBuilderOnlyTerm.add(matchQueryBuilder);

ResponseException exceptionNoNestedTypes = expectThrows(
ResponseException.class,
() -> search(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
hybridQueryBuilderOnlyTerm,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
null,
null,
null,
false,
null,
10
)

);

org.hamcrest.MatcherAssert.assertThat(
exceptionNoNestedTypes.getMessage(),
allOf(
containsString("In the current OpenSearch version pagination is not supported with hybrid query"),
containsString("illegal_argument_exception")
)
);
} finally {
wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE);
}
}

@SneakyThrows
private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ private void testPostFilterRangeQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
}
Expand Down Expand Up @@ -262,7 +263,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
// Case 2 A Query with a combination of hybrid query (Match Query, Term Query, Range Query), aggregation (Average stock price
Expand All @@ -278,7 +280,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
Map<String, Object> aggregations = getAggregations(searchResponseAsMap);
Expand All @@ -303,7 +306,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
// Case 4 A Query with a combination of hybrid query (Match Query, Range Query) and a post filter query (Bool Query with a should
Expand All @@ -324,7 +328,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
}
Expand Down Expand Up @@ -382,7 +387,8 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 4, 3, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);

Expand All @@ -399,7 +405,8 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ private void testSingleFieldSort_whenMultipleSubQueries_thenSuccessful(String in
null,
createSortBuilders(fieldSortOrderMap, false),
false,
null
null,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, true);
Expand Down Expand Up @@ -168,7 +169,8 @@ private void testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful(String
null,
createSortBuilders(fieldSortOrderMap, false),
false,
null
null,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, false);
Expand Down Expand Up @@ -200,7 +202,8 @@ public void testSingleFieldSort_whenTrackScoresIsEnabled_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
null
null,
0
)
);
} finally {
Expand Down Expand Up @@ -234,7 +237,8 @@ public void testSingleFieldSort_whenSortCriteriaIsByScoreAndField_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
null
null,
0
)
);
} finally {
Expand Down Expand Up @@ -312,7 +316,8 @@ private void testSearchAfter_whenSingleFieldSort_thenSuccessful(String indexName
null,
createSortBuilders(fieldSortOrderMap, false),
false,
searchAfter
searchAfter,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 3, 6);
assertStockValueWithSortOrderInHybridQueryResults(
Expand Down Expand Up @@ -348,7 +353,8 @@ private void testSearchAfter_whenMultipleFieldSort_thenSuccessful(String indexNa
null,
createSortBuilders(fieldSortOrderMap, false),
false,
searchAfter
searchAfter,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 5, 6);
assertStockValueWithSortOrderInHybridQueryResults(
Expand Down Expand Up @@ -381,7 +387,8 @@ private void testScoreSort_whenSingleFieldSort_thenSuccessful(String indexName)
null,
createSortBuilders(fieldSortOrderMap, false),
false,
null
null,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
assertScoreWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, 1.0);
Expand Down Expand Up @@ -415,7 +422,8 @@ public void testSort_whenSortFieldsSizeNotEqualToSearchAfterSize_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
searchAfter
searchAfter,
0
)
);
} finally {
Expand Down Expand Up @@ -450,7 +458,8 @@ public void testSearchAfter_whenAfterFieldIsNotPassed_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
searchAfter
searchAfter,
0
)
);
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ private void testSumAggsAndRangePostFilter() throws IOException {
rangeFilterQuery,
null,
false,
null
null,
0
);

Map<String, Object> aggregations = getAggregations(searchResponseAsMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ protected Map<String, Object> search(
Map<String, String> requestParams,
List<Object> aggs
) {
return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null, null, false, null);
return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null, null, false, null, 0);
}

@SneakyThrows
Expand All @@ -528,10 +528,11 @@ protected Map<String, Object> search(
QueryBuilder postFilterBuilder,
List<SortBuilder<?>> sortBuilders,
boolean trackScores,
List<Object> searchAfter
List<Object> searchAfter,
int from
) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();

builder.field("from", from);
if (queryBuilder != null) {
builder.field("query");
queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand Down
Loading