Skip to content

Commit

Permalink
Working Faiss filtering POC with latest Faiss changes picked from PR: f…
Browse files Browse the repository at this point in the history
…acebookresearch/faiss#2848

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed May 6, 2023
1 parent ed4462e commit d171d26
Show file tree
Hide file tree
Showing 16 changed files with 169 additions and 109 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
url = https://github.com/nmslib/nmslib.git
[submodule "jni/external/faiss"]
path = jni/external/faiss
url = https://github.com/facebookresearch/faiss.git
url = https://github.com/navneet1v/faiss.git
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ dependencies {
api group: 'com.google.guava', name: 'guava', version:'30.0-jre'
api group: 'commons-lang', name: 'commons-lang', version: '2.6'
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.12.22'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.14.3'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.12.22'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.14.3'
testFixturesImplementation "org.opensearch:common-utils:${version}"
}

Expand Down
2 changes: 1 addition & 1 deletion gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionSha256Sum=c5a643cf80162e665cc228f7b16f343fef868e47d3a4836f62e18b7e17ac018a
distributionSha256Sum=6147605a23b4eff6c334927a86ff3508cb5d6722cd624c97ded4c2e8640f1f87
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
Expand Down
6 changes: 6 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ namespace knn_jni {
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ);

// Execute a query against the index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);

Expand Down
19 changes: 8 additions & 11 deletions jni/include/org_opensearch_knn_jni_FaissService.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 62 additions & 5 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "faiss/IndexHNSW.h"
#include "faiss/IndexIVFFlat.h"
#include "faiss/MetaIndexes.h"
#include "faiss/Index.h"
#include "faiss/impl/IDSelector.h"

#include <algorithm>
#include <jni.h>
Expand All @@ -33,7 +35,10 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
const std::unordered_map<std::string, jobject>& parametersCpp, faiss::Index * index);

// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x);
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Create the SearchParams based on the Index Type
faiss::SearchParameters* buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector *idSelector);

void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {
Expand Down Expand Up @@ -181,12 +186,17 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}

auto *indexReader = reinterpret_cast<faiss::Index*>(indexPointerJ);
auto *indexReader = reinterpret_cast<faiss::IndexIDMap *>(indexPointerJ);

if (indexReader == nullptr) {
throw std::runtime_error("Invalid pointer to index");
Expand All @@ -195,11 +205,35 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU
// The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from
// the query point
std::vector<float> dis(kJ);
std::vector<faiss::Index::idx_t> ids(kJ);
std::vector<faiss::idx_t> ids(kJ);
float* rawQueryvector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

faiss::SearchParameters *searchParameters = nullptr;
// create the filterSearch params if the filterIdsJ is not a null pointer
if(filterIdsJ != nullptr) {
int *filteredVector = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = env->GetArrayLength(filterIdsJ);

// convert the int array to faiss::idx_t type vector
std::vector<faiss::idx_t> convertedFilterIds(filterIdsLength);
for (int i = 0; i < filterIdsLength; i++) {
convertedFilterIds[i] = filteredVector[i];
}

// now create the array selector
// We should IDSelectorBitmap for better performance and space, as it uses 1 bit per docId.
faiss::IDSelectorArray originalIdSelector(filterIdsLength, convertedFilterIds.data());

// now create the IDSelectorTranslated which is for the IndexIDMap indices.
faiss::IDSelectorTranslated idSelectorTranslated(indexReader->id_map, originalIdSelector);

// create the search params.
searchParameters = buildSearchParams(indexReader, &idSelectorTranslated);
//params.sel = &idSelectorTranslated;
}

try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data());
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
throw;
Expand Down Expand Up @@ -227,6 +261,29 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniU
return results;
}

/**
* This works only for HNSW and IVF algo
* @param indexReader
* @param idSelector
* @return
*/
faiss::SearchParameters* buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector *idSelector) {
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader);
if(hnswReader) {
faiss::SearchParametersHNSW hnswParams;
hnswParams.sel = idSelector;
return &hnswParams;
}

auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader);
if(ivfReader) {
faiss::SearchParametersIVF ivfParams;
ivfParams.sel = idSelector;
return &ivfParams;
}
return nullptr;
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
Expand Down Expand Up @@ -344,7 +401,7 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
}
}

void InternalTrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x) {
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexIVF*>(index)) {
if (indexIvf->quantizer_trains_alone == 2) {
InternalTrainIndex(indexIvf->quantizer, n, x);
Expand Down
13 changes: 13 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
return nullptr;
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex_1WithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIds) {

try {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIds);

} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;

}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
{
try {
Expand Down
13 changes: 5 additions & 8 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,17 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true");
}
Weight filterWeight = getFilterWeight(searcher);
if(filterQuery != null && filterWeight != null) {
if (filterQuery != null && filterWeight != null) {
return new KNNWeight(this, boost, filterWeight);
}
return new KNNWeight(this, boost);
}

private Weight getFilterWeight(IndexSearcher searcher) throws IOException {
Weight filterWeight = null;
if(this.getFilterQuery() != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(this.getFilterQuery(), BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(this.getField()), BooleanClause.Occur.FILTER)
.build();
if (this.getFilterQuery() != null) {
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(this.getFilterQuery(), BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(this.getField()), BooleanClause.Occur.FILTER)
.build();
Query rewritten = searcher.rewrite(booleanQuery);
return searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
}
Expand Down
14 changes: 11 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,17 @@ public static Query create(CreateQueryRequest createQueryRequest) {
}

private static Query getFilterQuery(CreateQueryRequest createQueryRequest) {
if(createQueryRequest.getFilter().isPresent()) {
final QueryShardContext queryShardContext = createQueryRequest.getContext().orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
log.debug(String.format("Creating k-NN query with filter for index [%s], field [%s] and k [%d]", createQueryRequest.getIndexName(), createQueryRequest.fieldName, createQueryRequest.k));
if (createQueryRequest.getFilter().isPresent()) {
final QueryShardContext queryShardContext = createQueryRequest.getContext()
.orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
log.debug(
String.format(
"Creating k-NN query with filter for index [%s], field [%s] and k [%d]",
createQueryRequest.getIndexName(),
createQueryRequest.fieldName,
createQueryRequest.k
)
);
try {
return createQueryRequest.getFilter().get().toQuery(queryShardContext);
} catch (IOException e) {
Expand Down
13 changes: 0 additions & 13 deletions src/main/java/org/opensearch/knn/index/query/KNNScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,13 @@ public class KNNScorer extends Scorer {
private final Map<Integer, Float> scores;
private final float boost;

// This should be removed.
private BitSet filteredDocsBitSet;

public KNNScorer(Weight weight, DocIdSetIterator docIdsIter, Map<Integer, Float> scores, float boost) {
super(weight);
this.docIdsIter = docIdsIter;
this.scores = scores;
this.boost = boost;
}

public KNNScorer setFilteredDocsBitSet(BitSet filteredDocsBitSet) {
this.filteredDocsBitSet = filteredDocsBitSet;
return this;
}

@Override
public DocIdSetIterator iterator() {
return docIdsIter;
Expand All @@ -60,11 +52,6 @@ public float score() {
assert docID() != DocIdSetIterator.NO_MORE_DOCS;
Float score = scores.get(docID());

log.info("Current DocId is : " + docID());
if(filteredDocsBitSet != null) {
log.info("Doc Id present in filtered bit set: " + filteredDocsBitSet.get(docID()));
}

if (score == null) throw new RuntimeException("Null score for the docID: " + docID());
return score;
}
Expand Down
Loading

0 comments on commit d171d26

Please sign in to comment.