Skip to content

Commit

Permalink
Support Selector for IDMap (facebookresearch#2848)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#2848

Add selector support for IDMap wrapped indices.
Caveat: this requires to wrap the IDSelector with another selector. Since the params are const, the const is casted away.

This is a problem if the same params are used from multiple execution threads with different selectors. However, this seems rare enough to take the risk.

Reviewed By: alexanderguzhva

Differential Revision: D45598823

fbshipit-source-id: ec23465c13f1f8273a6a46f9aa869ccae2cdb79c
  • Loading branch information
mdouze authored and Thejas-bhat committed Sep 26, 2023
1 parent c86bc03 commit eabe599
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 30 deletions.
63 changes: 42 additions & 21 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/WorkerThread.h>

Expand Down Expand Up @@ -71,6 +70,27 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
this->ntotal = index->ntotal;
}

namespace {

/// RAII object to reset the IDSelector in the params object
struct ScopedSelChange {
SearchParameters* params = nullptr;
IDSelector* old_sel = nullptr;

void set(SearchParameters* params, IDSelector* new_sel) {
this->params = params;
old_sel = params->sel;
params->sel = new_sel;
}
~ScopedSelChange() {
if (params) {
params->sel = old_sel;
}
}
};

} // namespace

template <typename IndexT>
void IndexIDMapTemplate<IndexT>::search(
idx_t n,
Expand All @@ -79,9 +99,26 @@ void IndexIDMapTemplate<IndexT>::search(
typename IndexT::distance_t* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
index->search(n, x, k, distances, labels);
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
ScopedSelChange sel_change;

if (params && params->sel) {
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);

if (!idtrans) {
/*
FAISS_THROW_IF_NOT_MSG(
idtrans,
"IndexIDMap requires an IDSelectorTranslated on input");
*/
// then make an idtrans and force it into the SearchParameters
// (hence the const_cast)
auto params_non_const = const_cast<SearchParameters*>(params);
this_idtrans.sel = params->sel;
sel_change.set(params_non_const, &this_idtrans);
}
}
index->search(n, x, k, distances, labels, params);
idx_t* li = labels;
#pragma omp parallel for
for (idx_t i = 0; i < n * k; i++) {
Expand All @@ -106,26 +143,10 @@ void IndexIDMapTemplate<IndexT>::range_search(
}
}

namespace {

struct IDTranslatedSelector : IDSelector {
const std::vector<int64_t>& id_map;
const IDSelector& sel;
IDTranslatedSelector(
const std::vector<int64_t>& id_map,
const IDSelector& sel)
: id_map(id_map), sel(sel) {}
bool is_member(idx_t id) const override {
return sel.is_member(id_map[id]);
}
};

} // namespace

template <typename IndexT>
size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
// remove in sub-index first
IDTranslatedSelector sel2(id_map, sel);
IDSelectorTranslated sel2(id_map, &sel);
size_t nremove = index->remove_ids(sel2);

int64_t j = 0;
Expand Down
22 changes: 22 additions & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <faiss/Index.h>
#include <faiss/IndexBinary.h>
#include <faiss/impl/IDSelector.h>

#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
using IndexIDMap2 = IndexIDMap2Template<Index>;
using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;

// IDSelector that translates the ids using an IDMap
struct IDSelectorTranslated : IDSelector {
const std::vector<int64_t>& id_map;
const IDSelector* sel;

IDSelectorTranslated(
const std::vector<int64_t>& id_map,
const IDSelector* sel)
: id_map(id_map), sel(sel) {}

IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel)
: id_map(index_idmap.id_map), sel(sel) {}

IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel)
: id_map(index_idmap.id_map), sel(sel) {}

bool is_member(idx_t id) const override {
return sel->is_member(id_map[id]);
}
};

} // namespace faiss
1 change: 1 addition & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def replacement_function(*args):
add_ref_in_constructor(IDSelectorAnd, slice(2))
add_ref_in_constructor(IDSelectorOr, slice(2))
add_ref_in_constructor(IDSelectorXOr, slice(2))
add_ref_in_constructor(IDSelectorTranslated, slice(2))

# seems really marginal...
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0)
Expand Down
12 changes: 7 additions & 5 deletions faiss/python/swigfaiss.swig
Original file line number Diff line number Diff line change
Expand Up @@ -494,11 +494,6 @@ void gpu_sync_all_devices()
%template(IndexBinaryReplicas) faiss::IndexReplicasTemplate<faiss::IndexBinary>;

%include <faiss/MetaIndexes.h>
%include <faiss/IndexIDMap.h>
%template(IndexIDMap) faiss::IndexIDMapTemplate<faiss::Index>;
%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate<faiss::IndexBinary>;
%template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;

%include <faiss/IndexRowwiseMinMax.h>

Expand All @@ -513,6 +508,13 @@ void gpu_sync_all_devices()
%include <faiss/impl/AuxIndexStructures.h>
%include <faiss/impl/IDSelector.h>

%include <faiss/IndexIDMap.h>
%template(IndexIDMap) faiss::IndexIDMapTemplate<faiss::Index>;
%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate<faiss::IndexBinary>;
%template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;


%include <faiss/utils/approx_topk/mode.h>

#ifdef GPU_WRAPPER
Expand Down
43 changes: 39 additions & 4 deletions tests/test_search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,17 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
sel = faiss.IDSelectorNot(faiss.IDSelectorBatch(inverse_subset))
elif id_selector_type == "or":
sel = faiss.IDSelectorOr(
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(rhs_subset)
)
elif id_selector_type == "and":
sel = faiss.IDSelectorAnd(
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(rhs_subset)
)
elif id_selector_type == "xor":
sel = faiss.IDSelectorXOr(
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(rhs_subset)
)
else:
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_Flat_id_bitmap(self):

def test_Flat_id_not(self):
self.do_test_id_selector("Flat", id_selector_type="not")

def test_Flat_id_or(self):
self.do_test_id_selector("Flat", id_selector_type="or")

Expand Down Expand Up @@ -220,6 +220,41 @@ def do_test_id_selector_weak(self, index_key):
def test_HSNW(self):
self.do_test_id_selector_weak("HNSW")

def test_idmap(self):
ds = datasets.SyntheticDataset(32, 100, 100, 20)
rs = np.random.RandomState(123)
ids = rs.choice(10000, size=100, replace=False)
mask = ids % 2 == 0
index = faiss.index_factory(ds.d, "IDMap,SQ8")
index.train(ds.get_train())

# ref result
index.add_with_ids(ds.get_database()[mask], ids[mask])
Dref, Iref = index.search(ds.get_queries(), 10)

# with selector
index.reset()
index.add_with_ids(ds.get_database(), ids)

valid_ids = ids[mask]
sel = faiss.IDSelectorTranslated(
index, faiss.IDSelectorBatch(valid_ids))

Dnew, Inew = index.search(
ds.get_queries(), 10,
params=faiss.SearchParameters(sel=sel)
)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)

# let the IDMap::search add the translation...
Dnew, Inew = index.search(
ds.get_queries(), 10,
params=faiss.SearchParameters(sel=faiss.IDSelectorBatch(valid_ids))
)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)


class TestSearchParams(unittest.TestCase):

Expand Down

0 comments on commit eabe599

Please sign in to comment.