Skip to content

Commit

Permalink
Introduced KNNVectorValues interface to iterate on different types of…
Browse files Browse the repository at this point in the history
… Vector values during indexing and search (opensearch-project#1897)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Aug 1, 2024
1 parent 27f3168 commit ec6451c
Show file tree
Hide file tree
Showing 19 changed files with 1,214 additions and 181 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
### Refactoring
* Introduce KNNVectorValues interface to iterate on different types of Vector values during indexing and search [#1897](https://github.com/opensearch-project/k-NN/pull/1897)
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
* Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913)
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ static NativeEngineFieldVectorsWriter<?> create(final FieldInfo fieldInfo, final
throw new IllegalStateException("Unsupported Vector encoding : " + fieldInfo.getVectorEncoding());
}

NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) {
private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) {
this.fieldInfo = fieldInfo;
this.infoStream = infoStream;
vectors = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public static String buildEngineFileSuffix(String fieldName, String extension) {
return String.format("_%s%s", fieldName, extension);
}

private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) {
public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) {
long totalLiveDocs;
if (binaryDocValues instanceof KNN80BinaryDocValues) {
totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.vectorvalues;

import lombok.ToString;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;

import java.io.IOException;

/**
* Concrete implementation of {@link KNNVectorValues} that returns byte[] as vector where binary vector is stored and
* provides an abstraction over {@link BinaryDocValues}, {@link ByteVectorValues}, {@link KnnFieldVectorsWriter} etc.
*/
@ToString(callSuper = true)
public class KNNBinaryVectorValues extends KNNVectorValues<byte[]> {
KNNBinaryVectorValues(KNNVectorValuesIterator vectorValuesIterator) {
super(vectorValuesIterator);
}

@Override
public byte[] getVector() throws IOException {
final byte[] vector = VectorValueExtractorStrategy.extractBinaryVector(vectorValuesIterator);
this.dimension = vector.length;
return vector;
}

/**
* Binary Vector values gets stored as byte[], hence for dimension of the binary vector we have to multiply the
* byte[] size with {@link Byte#SIZE}
* @return int
*/
@Override
public int dimension() {
return super.dimension() * Byte.SIZE;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.vectorvalues;

import lombok.ToString;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;

import java.io.IOException;

/**
* Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over
* {@link BinaryDocValues}, {@link ByteVectorValues}, {@link KnnFieldVectorsWriter} etc.
*/
@ToString(callSuper = true)
public class KNNByteVectorValues extends KNNVectorValues<byte[]> {
KNNByteVectorValues(KNNVectorValuesIterator vectorValuesIterator) {
super(vectorValuesIterator);
}

@Override
public byte[] getVector() throws IOException {
final byte[] vector = VectorValueExtractorStrategy.extractByteVector(vectorValuesIterator);
this.dimension = vector.length;
return vector;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.vectorvalues;

import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FloatVectorValues;

import java.io.IOException;

/**
* Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over
* {@link BinaryDocValues}, {@link FloatVectorValues}, {@link KnnFieldVectorsWriter} etc.
*/
public class KNNFloatVectorValues extends KNNVectorValues<float[]> {
KNNFloatVectorValues(final KNNVectorValuesIterator vectorValuesIterator) {
super(vectorValuesIterator);
}

@Override
public float[] getVector() throws IOException {
final float[] vector = VectorValueExtractorStrategy.extractFloatVector(vectorValuesIterator);
this.dimension = vector.length;
return vector;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.vectorvalues;

import lombok.ToString;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;

import java.io.IOException;

/**
* An abstract class to iterate over KNNVectors, as KNNVectors are stored as different representation like
* {@link BinaryDocValues}, {@link FloatVectorValues}, {@link ByteVectorValues}, {@link KnnFieldVectorsWriter} etc.
* @param <T>
*/
@ToString
public abstract class KNNVectorValues<T> {

protected final KNNVectorValuesIterator vectorValuesIterator;
protected int dimension;

protected KNNVectorValues(final KNNVectorValuesIterator vectorValuesIterator) {
this.vectorValuesIterator = vectorValuesIterator;
}

/**
* Return a vector reference. If you are adding this address in a List/Map ensure that you are copying the vector first.
* This is to ensure that we keep the heap and latency in check by reducing the copies of vectors.
*
* @return T an array of byte[], float[]
* @throws IOException if we are not able to get the vector
*/
public abstract T getVector() throws IOException;

/**
* Dimension of vector is returned. Do call getVector function first before calling this function otherwise you will get 0 value.
* @return int
*/
public int dimension() {
assert docId() != -1 && dimension != 0 : "Cannot get dimension before we retrieve a vector from KNNVectorValues";
return dimension;
}

/**
* Returns the total live docs for KNNVectorValues.
* @return long
*/
public long totalLiveDocs() {
return vectorValuesIterator.liveDocs();
}

/**
* Returns the current docId where the iterator is pointing to.
* @return int
*/
public int docId() {
return vectorValuesIterator.docId();
}

/**
* Advances to a specific docId. Ensure that the passed docId is greater than current docId where Iterator is
* pointing to, otherwise
* {@link IOException} will be thrown
* @return int
* @throws IOException if we are not able to move to the passed docId.
*/
public int advance(int docId) throws IOException {
return vectorValuesIterator.advance(docId);
}

/**
* Move to nextDocId.
* @return int
* @throws IOException if we cannot move to next docId
*/
public int nextDoc() throws IOException {
return vectorValuesIterator.nextDoc();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.vectorvalues;

import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.index.VectorDataType;

import java.util.Map;

/**
* A factory class that provides various methods to create the {@link KNNVectorValues}.
*/
public final class KNNVectorValuesFactory {

/**
* Returns a {@link KNNVectorValues} for the given {@link DocIdSetIterator} and {@link VectorDataType}
*
* @param vectorDataType {@link VectorDataType}
* @param docIdSetIterator {@link DocIdSetIterator}
* @return {@link KNNVectorValues} of type float[]
*/
public static <T> KNNVectorValues<T> getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) {
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator));
}

/**
* Returns a {@link KNNVectorValues} for the given {@link DocIdSetIterator} and a Map of docId and vectors.
*
* @param vectorDataType {@link VectorDataType}
* @param docIdWithFieldSet {@link DocsWithFieldSet}
* @return {@link KNNVectorValues} of type float[]
*/
public static <T> KNNVectorValues<T> getVectorValues(
final VectorDataType vectorDataType,
final DocsWithFieldSet docIdWithFieldSet,
final Map<Integer, T> vectors
) {
return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues<T>(docIdWithFieldSet, vectors));
}

@SuppressWarnings("unchecked")
private static <T> KNNVectorValues<T> getVectorValues(
final VectorDataType vectorDataType,
final KNNVectorValuesIterator knnVectorValuesIterator
) {
switch (vectorDataType) {
case FLOAT:
return (KNNVectorValues<T>) new KNNFloatVectorValues(knnVectorValuesIterator);
case BYTE:
return (KNNVectorValues<T>) new KNNByteVectorValues(knnVectorValuesIterator);
case BINARY:
return (KNNVectorValues<T>) new KNNBinaryVectorValues(knnVectorValuesIterator);
}
throw new IllegalArgumentException("Invalid Vector data type provided, hence cannot return VectorValues");
}
}
Loading

0 comments on commit ec6451c

Please sign in to comment.