diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index f5c4d3131f..f0974f7e9f 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -188,7 +188,7 @@ private Weight getFilterWeight(IndexSearcher searcher) throws IOException { @Override public void visit(QueryVisitor visitor) { - + visitor.visitLeaf(this); } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 8b861b4301..a34a0f1eea 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -53,7 +53,7 @@ public class NativeEngineKnnVectorQuery extends Query { @Override public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException { final IndexReader reader = indexSearcher.getIndexReader(); - final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); + final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); List leafReaderContexts = reader.leaves(); List> perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 53873e15f6..4577a34d41 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -76,6 +76,8 @@ public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { @InjectMocks private NativeEngineKnnVectorQuery objectUnderTest; + private static ScoreMode scoreMode = ScoreMode.TOP_SCORES; + @Override public void setUp() throws Exception { super.setUp(); @@ -85,7 +87,7 @@ public void setUp() throws Exception { when(leaf2.reader()).thenReturn(leafReader2); when(searcher.getIndexReader()).thenReturn(reader); - when(knnQuery.createWeight(searcher, ScoreMode.COMPLETE, 1)).thenReturn(knnWeight); + when(knnQuery.createWeight(searcher, scoreMode, 1)).thenReturn(knnWeight); when(searcher.getTaskExecutor()).thenReturn(taskExecutor); when(taskExecutor.invokeAll(any())).thenAnswer(invocationOnMock -> { @@ -135,7 +137,7 @@ public void testMultiLeaf() { Query expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); // When - Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1); + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); // Then assertEquals(expected, actual.getQuery()); @@ -176,7 +178,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenCallRealMethod(); // When - Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1); + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); // Then mockedResultUtil.verify(() -> ResultUtil.reduceToTopK(any(), anyInt()), times(2)); @@ -199,7 +201,7 @@ public void testSingleLeaf() { Query expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); // When - Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1); + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); // Then assertEquals(expected, actual.getQuery()); @@ -214,7 +216,7 @@ public void testNoMatch() { when(knnQuery.getK()).thenReturn(4); // When - Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1); + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); // Then assertEquals(new MatchNoDocsQuery(), actual.getQuery()); @@ -260,7 +262,7 @@ public void testRescore() { try (MockedStatic mockedStaticNativeKnnVectorQuery = mockStatic(NativeEngineKnnVectorQuery.class)) { mockedStaticNativeKnnVectorQuery.when(() -> NativeEngineKnnVectorQuery.findSegmentStarts(any(), any())) .thenReturn(new int[] { 0, 4, 2 }); - Weight actual = objectUnderTest.createWeight(searcher, ScoreMode.COMPLETE, 1); + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); assertEquals(expected, actual.getQuery()); } }