diff --git a/release-notes/opensearch-knn.release-notes-2.17.0.0.md b/release-notes/opensearch-knn.release-notes-2.17.0.0.md index fda218566..0e4fb008a 100644 --- a/release-notes/opensearch-knn.release-notes-2.17.0.0.md +++ b/release-notes/opensearch-knn.release-notes-2.17.0.0.md @@ -21,7 +21,6 @@ Compatible with OpenSearch 2.17.0 * Fix memory overflow caused by cache behavior [#2015](https://github.com/opensearch-project/k-NN/pull/2015) * Use correct type for binary vector in ivf training [#2086](https://github.com/opensearch-project/k-NN/pull/2086) * Switch MINGW32 to MINGW64 [#2090](https://github.com/opensearch-project/k-NN/pull/2090) -* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133) ### Infrastructure * Parallelize make to reduce build time [#2006] (https://github.com/opensearch-project/k-NN/pull/2006) ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 23cd2a4de..dba0926ff 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -25,10 +25,11 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; -import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -36,10 +37,8 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.function.Supplier; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; -import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues; /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. @@ -48,11 +47,15 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); + private static final String FLUSH_OPERATION = "flush"; + private static final String MERGE_OPERATION = "merge"; + private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; + private final QuantizationService quantizationService = QuantizationService.getInstance(); public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { this.segmentWriteState = segmentWriteState; @@ -81,27 +84,14 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); for (final NativeEngineFieldVectorsWriter field : fields) { - final FieldInfo fieldInfo = field.getFieldInfo(); - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - int totalLiveDocs = field.getVectors().size(); - if (totalLiveDocs > 0) { - final Supplier> knnVectorValuesSupplier = () -> getVectorValues( - vectorDataType, - field.getDocsWithField(), - field.getVectors() - ); - final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - - StopWatch stopWatch = new StopWatch().start(); - writer.flushIndex(knnVectorValues, totalLiveDocs); - long time_in_millis = stopWatch.stop().totalTime().millis(); - KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); - log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); - } else { - log.debug("[Flush] No live docs for field {}", fieldInfo.getName()); - } + trainAndIndex( + field.getFieldInfo(), + (vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter), + NativeIndexWriter::flushIndex, + field, + KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS, + FLUSH_OPERATION + ); } } @@ -110,29 +100,15 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - final Supplier> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge( - vectorDataType, + // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs + trainAndIndex( fieldInfo, - mergeState + this::getKNNVectorValuesForMerge, + NativeIndexWriter::mergeIndex, + mergeState, + KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS, + MERGE_OPERATION ); - int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get()); - if (totalLiveDocs == 0) { - log.debug("[Merge] No live docs for field {}", fieldInfo.getName()); - return; - } - - final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - - StopWatch stopWatch = new StopWatch().start(); - - writer.mergeIndex(knnVectorValues, totalLiveDocs); - - long time_in_millis = stopWatch.stop().totalTime().millis(); - KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); - log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); } /** @@ -181,6 +157,18 @@ public long ramBytesUsed() { .sum(); } + /** + * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer. + * + * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. + * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors. + * @param The type of vectors being processed. + * @return The {@link KNNVectorValues} associated with the field. + */ + private KNNVectorValues getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter field) { + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + } + /** * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. * @@ -195,41 +183,89 @@ private KNNVectorValues getKNNVectorValuesForMerge( final VectorDataType vectorDataType, final FieldInfo fieldInfo, final MergeState mergeState - ) { - try { - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedFloats); - case BYTE: - ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedBytes); - default: - throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); - } - } catch (final IOException e) { - log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e); - throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e); + ) throws IOException { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + case BYTE: + ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); } } - private QuantizationState train( + /** + * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}. + * + * @param The type of vectors being processed. + */ + @FunctionalInterface + private interface IndexOperation { + void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException; + } + + /** + * Functional interface representing a method that retrieves {@link KNNVectorValues} based on + * the vector data type, field information, and the merge state. + * + * @param The type of the data representing the vector (e.g., {@link VectorDataType}). + * @param The metadata about the field. + * @param The state of the merge operation. + * @param The result of the retrieval, typically {@link KNNVectorValues}. + */ + @FunctionalInterface + private interface VectorValuesRetriever { + Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException; + } + + /** + * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values + * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed. + * + * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. + * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type, + * field information, and additional context (e.g., merge state or field writer). + * @param indexOperation A functional interface that performs the indexing operation using the retrieved + * {@link KNNVectorValues}. + * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). + * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information + * @param The type of vectors being processed. + * @param The type of the context needed for retrieving the vector values. + * @throws IOException If an I/O error occurs during the processing. + */ + private void trainAndIndex( final FieldInfo fieldInfo, - final Supplier> knnVectorValuesSupplier, - final int totalLiveDocs + final VectorValuesRetriever> vectorValuesRetriever, + final IndexOperation indexOperation, + final C VectorProcessingContext, + final KNNGraphValue graphBuildTime, + final String operationName ) throws IOException { - - final QuantizationService quantizationService = QuantizationService.getInstance(); - final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + KNNVectorValues knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); + QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; + // Count the docIds + int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext)); if (quantizationParams != null && totalLiveDocs > 0) { initQuantizationStateWriterIfNecessary(); - KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } + NativeIndexWriter writer = (quantizationParams != null) + ? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState) + : NativeIndexWriter.getWriter(fieldInfo, segmentWriteState); + + knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); - return quantizationState; + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + indexOperation.buildAndWrite(writer, knnVectorValues, totalLiveDocs); + long time_in_millis = stopWatch.totalTime().millis(); + graphBuildTime.incrementBy(time_in_millis); + log.warn("Graph build took " + time_in_millis + " ms for " + operationName); } /** diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java deleted file mode 100644 index dbb564908..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ /dev/null @@ -1,299 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.KNN990Codec; - -import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import lombok.RequiredArgsConstructor; -import lombok.SneakyThrows; -import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.index.DocsWithFieldSet; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.VectorEncoding; -import org.mockito.Mock; -import org.mockito.MockedConstruction; -import org.mockito.MockedStatic; -import org.mockito.MockitoAnnotations; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; -import org.opensearch.knn.index.quantizationservice.QuantizationService; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; -import org.opensearch.knn.index.vectorvalues.TestVectorValues; -import org.opensearch.knn.plugin.stats.KNNGraphValue; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; -import org.opensearch.test.OpenSearchTestCase; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.stream.IntStream; - -import static com.carrotsearch.randomizedtesting.RandomizedTest.$; -import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockConstruction; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -@RequiredArgsConstructor -public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCase { - - @Mock - private FlatVectorsWriter flatVectorsWriter; - @Mock - private SegmentWriteState segmentWriteState; - @Mock - private QuantizationParams quantizationParams; - @Mock - private QuantizationState quantizationState; - @Mock - private QuantizationService quantizationService; - @Mock - private NativeIndexWriter nativeIndexWriter; - - private NativeEngines990KnnVectorsWriter objectUnderTest; - - private final String description; - private final List> vectorsPerField; - - @Override - public void setUp() throws Exception { - super.setUp(); - MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); - } - - @ParametersFactory - public static Collection data() { - return Arrays.asList( - $$( - $("Single field", List.of(Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }))), - $("Single field, no total live docs", List.of()), - $( - "Multi Field", - List.of( - Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }), - Map.of( - 0, - new float[] { 1, 2, 3, 4 }, - 1, - new float[] { 2, 3, 4, 5 }, - 2, - new float[] { 3, 4, 5, 6 }, - 3, - new float[] { 4, 5, 6, 7 } - ) - ) - ) - ) - ); - } - - @SneakyThrows - public void testFlush() { - // Given - List> expectedVectorValues = new ArrayList<>(); - IntStream.range(0, vectorsPerField.size()).forEach(i -> { - final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - new ArrayList<>(vectorsPerField.get(i).values()) - ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); - expectedVectorValues.add(knnVectorValues); - - }); - - try ( - MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); - MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); - MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); - MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); - MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( - KNN990QuantizationStateWriter.class - ); - ) { - quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); - IntStream.range(0, vectorsPerField.size()).forEach(i -> { - final FieldInfo fieldInfo = fieldInfo( - i, - VectorEncoding.FLOAT32, - Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") - ); - - NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); - - try { - objectUnderTest.addField(fieldInfo); - } catch (Exception e) { - throw new RuntimeException(e); - } - - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); - knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); - - when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); - }); - - doAnswer(answer -> { - Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion - return null; - }).when(nativeIndexWriter).flushIndex(any(), anyInt()); - - // When - objectUnderTest.flush(5, null); - - // Then - verify(flatVectorsWriter).flush(5, null); - if (vectorsPerField.size() > 0) { - assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); - assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); - } - - IntStream.range(0, vectorsPerField.size()).forEach(i -> { - try { - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - - knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(expectedVectorValues.size()) - ); - } - } - - @SneakyThrows - public void testFlush_WithQuantization() { - // Given - List> expectedVectorValues = new ArrayList<>(); - IntStream.range(0, vectorsPerField.size()).forEach(i -> { - final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - new ArrayList<>(vectorsPerField.get(i).values()) - ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); - expectedVectorValues.add(knnVectorValues); - - }); - - try ( - MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); - MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); - MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); - MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); - MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( - KNN990QuantizationStateWriter.class - ); - ) { - quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); - - IntStream.range(0, vectorsPerField.size()).forEach(i -> { - final FieldInfo fieldInfo = fieldInfo( - i, - VectorEncoding.FLOAT32, - Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") - ); - - NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); - - try { - objectUnderTest.addField(fieldInfo); - } catch (Exception e) { - throw new RuntimeException(e); - } - - DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); - knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); - - when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); - try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); - } catch (Exception e) { - throw new RuntimeException(e); - } - - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); - }); - doAnswer(answer -> { - Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion - return null; - }).when(nativeIndexWriter).flushIndex(any(), anyInt()); - - // When - objectUnderTest.flush(5, null); - - // Then - verify(flatVectorsWriter).flush(5, null); - if (vectorsPerField.size() > 0) { - verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); - assertTrue(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0L); - } else { - assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); - } - - IntStream.range(0, vectorsPerField.size()).forEach(i -> { - try { - verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - - knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(expectedVectorValues.size() * 2) - ); - } - } - - private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { - FieldInfo fieldInfo = mock(FieldInfo.class); - when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); - when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); - when(fieldInfo.attributes()).thenReturn(attributes); - attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); - return fieldInfo; - } - - private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { - NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); - DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); - vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); - when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); - when(fieldVectorsWriter.getVectors()).thenReturn(vectors); - when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); - return fieldVectorsWriter; - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java deleted file mode 100644 index 41940c4d4..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.KNN990Codec; - -import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import lombok.RequiredArgsConstructor; -import lombok.SneakyThrows; -import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.index.DocsWithFieldSet; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.MergeState; -import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.VectorEncoding; -import org.mockito.Mock; -import org.mockito.MockedConstruction; -import org.mockito.MockedStatic; -import org.mockito.MockitoAnnotations; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; -import org.opensearch.knn.index.quantizationservice.QuantizationService; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; -import org.opensearch.knn.index.vectorvalues.TestVectorValues; -import org.opensearch.knn.plugin.stats.KNNGraphValue; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; -import org.opensearch.test.OpenSearchTestCase; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Map; - -import static com.carrotsearch.randomizedtesting.RandomizedTest.$; -import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockConstruction; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -@RequiredArgsConstructor -public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCase { - - @Mock - private FlatVectorsWriter flatVectorsWriter; - @Mock - private SegmentWriteState segmentWriteState; - @Mock - private QuantizationParams quantizationParams; - @Mock - private QuantizationState quantizationState; - @Mock - private QuantizationService quantizationService; - @Mock - private NativeIndexWriter nativeIndexWriter; - @Mock - private FloatVectorValues floatVectorValues; - @Mock - private MergeState mergeState; - - private NativeEngines990KnnVectorsWriter objectUnderTest; - - private final String description; - private final Map mergedVectors; - - @Override - public void setUp() throws Exception { - super.setUp(); - MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); - } - - @ParametersFactory - public static Collection data() { - return Arrays.asList( - $$( - $("Merge one field", Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 })), - $("Merge, no live docs", Map.of()) - ) - ); - } - - @SneakyThrows - public void testMerge() { - // Given - final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - new ArrayList<>(mergedVectors.values()) - ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); - - try ( - MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); - MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); - MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); - MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); - MockedStatic mergedVectorValuesMockedStatic = mockStatic( - KnnVectorsWriter.MergedVectorValues.class - ); - MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( - KNN990QuantizationStateWriter.class - ); - ) { - quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); - final FieldInfo fieldInfo = fieldInfo( - 0, - VectorEncoding.FLOAT32, - Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") - ); - - NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); - - mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) - .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); - - when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) - .thenReturn(nativeIndexWriter); - doAnswer(answer -> { - Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion - return null; - }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); - - // When - objectUnderTest.mergeOneField(fieldInfo, mergeState); - - // Then - verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); - assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); - if (!mergedVectors.isEmpty()) { - verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); - assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); - knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), - times(2) - ); - } else { - verifyNoInteractions(nativeIndexWriter); - } - } - } - - @SneakyThrows - public void testMerge_WithQuantization() { - // Given - final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - new ArrayList<>(mergedVectors.values()) - ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); - - try ( - MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); - MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); - MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); - MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); - MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( - KNN990QuantizationStateWriter.class - ); - MockedStatic mergedVectorValuesMockedStatic = mockStatic( - KnnVectorsWriter.MergedVectorValues.class - ); - ) { - quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); - - final FieldInfo fieldInfo = fieldInfo( - 0, - VectorEncoding.FLOAT32, - Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") - ); - - NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); - - mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) - .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); - - when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); - try { - when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); - } catch (Exception e) { - throw new RuntimeException(e); - } - - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); - doAnswer(answer -> { - Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion - return null; - }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); - - // When - objectUnderTest.mergeOneField(fieldInfo, mergeState); - - // Then - verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); - if (!mergedVectors.isEmpty()) { - verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); - verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); - verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); - assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); - knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), - times(3) - ); - } else { - assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); - verifyNoInteractions(nativeIndexWriter); - } - - } - } - - private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { - FieldInfo fieldInfo = mock(FieldInfo.class); - when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); - when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); - when(fieldInfo.attributes()).thenReturn(attributes); - attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); - return fieldInfo; - } - - private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { - NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); - DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); - vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); - when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); - when(fieldVectorsWriter.getVectors()).thenReturn(vectors); - when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); - return fieldVectorsWriter; - } -} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java index 0f15d5240..3bf79b004 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java @@ -184,11 +184,7 @@ public static class PreDefinedFloatVectorValues extends FloatVectorValues { public PreDefinedFloatVectorValues(final List vectors) { super(); this.count = vectors.size(); - if (!vectors.isEmpty()) { - this.dimension = vectors.get(0).length; - } else { - this.dimension = 0; - } + this.dimension = vectors.get(0).length; this.vectors = vectors; this.current = -1; vector = new float[dimension];