Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MB-60202 - IDMap2 Selector #12

Merged
merged 2 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions c_api/IndexIVF_c_ex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "macros_impl.h"

using faiss::IndexIVF;
using faiss::SearchParametersIVF;

int faiss_IndexIVF_set_direct_map(FaissIndexIVF* index, int direct_map_type) {
try {
Expand All @@ -20,3 +21,14 @@ int faiss_IndexIVF_set_direct_map(FaissIndexIVF* index, int direct_map_type) {
}
CATCH_AND_HANDLE
}

int faiss_SearchParametersIVF_new_with_sel(
FaissSearchParametersIVF** p_sp,
FaissIDSelector* sel) {
try {
SearchParametersIVF* sp = new SearchParametersIVF;
sp->sel = reinterpret_cast<faiss::IDSelector*>(sel);
*p_sp = reinterpret_cast<FaissSearchParametersIVF*>(sp);
}
CATCH_AND_HANDLE
}
4 changes: 4 additions & 0 deletions c_api/IndexIVF_c_ex.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ int faiss_IndexIVF_set_direct_map(
FaissIndexIVF* index,
int direct_map_type);

int faiss_SearchParametersIVF_new_with_sel(
FaissSearchParametersIVF** p_sp,
FaissIDSelector* sel);

#ifdef __cplusplus
}
#endif
Expand Down
63 changes: 42 additions & 21 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/WorkerThread.h>

Expand Down Expand Up @@ -71,6 +70,27 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
this->ntotal = index->ntotal;
}

namespace {

/// RAII object to reset the IDSelector in the params object
struct ScopedSelChange {
SearchParameters* params = nullptr;
IDSelector* old_sel = nullptr;

void set(SearchParameters* params, IDSelector* new_sel) {
this->params = params;
old_sel = params->sel;
params->sel = new_sel;
}
~ScopedSelChange() {
if (params) {
params->sel = old_sel;
}
}
};

} // namespace

template <typename IndexT>
void IndexIDMapTemplate<IndexT>::search(
idx_t n,
Expand All @@ -79,9 +99,26 @@ void IndexIDMapTemplate<IndexT>::search(
typename IndexT::distance_t* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
index->search(n, x, k, distances, labels);
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
ScopedSelChange sel_change;

if (params && params->sel) {
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);

if (!idtrans) {
/*
FAISS_THROW_IF_NOT_MSG(
idtrans,
"IndexIDMap requires an IDSelectorTranslated on input");
*/
metonymic-smokey marked this conversation as resolved.
Show resolved Hide resolved
// then make an idtrans and force it into the SearchParameters
// (hence the const_cast)
auto params_non_const = const_cast<SearchParameters*>(params);
this_idtrans.sel = params->sel;
sel_change.set(params_non_const, &this_idtrans);
}
}
index->search(n, x, k, distances, labels, params);
idx_t* li = labels;
#pragma omp parallel for
for (idx_t i = 0; i < n * k; i++) {
Expand All @@ -106,26 +143,10 @@ void IndexIDMapTemplate<IndexT>::range_search(
}
}

namespace {

struct IDTranslatedSelector : IDSelector {
const std::vector<int64_t>& id_map;
const IDSelector& sel;
IDTranslatedSelector(
const std::vector<int64_t>& id_map,
const IDSelector& sel)
: id_map(id_map), sel(sel) {}
bool is_member(idx_t id) const override {
return sel.is_member(id_map[id]);
}
};

} // namespace

template <typename IndexT>
size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
// remove in sub-index first
IDTranslatedSelector sel2(id_map, sel);
IDSelectorTranslated sel2(id_map, &sel);
size_t nremove = index->remove_ids(sel2);

int64_t j = 0;
Expand Down
22 changes: 22 additions & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <faiss/Index.h>
#include <faiss/IndexBinary.h>
#include <faiss/impl/IDSelector.h>

#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
using IndexIDMap2 = IndexIDMap2Template<Index>;
using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;

// IDSelector that translates the ids using an IDMap
struct IDSelectorTranslated : IDSelector {
const std::vector<int64_t>& id_map;
const IDSelector* sel;

IDSelectorTranslated(
const std::vector<int64_t>& id_map,
const IDSelector* sel)
: id_map(id_map), sel(sel) {}

IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel)
: id_map(index_idmap.id_map), sel(sel) {}

IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel)
: id_map(index_idmap.id_map), sel(sel) {}

bool is_member(idx_t id) const override {
return sel->is_member(id_map[id]);
}
};

} // namespace faiss
1 change: 1 addition & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def replacement_function(*args):
add_ref_in_constructor(IDSelectorAnd, slice(2))
add_ref_in_constructor(IDSelectorOr, slice(2))
add_ref_in_constructor(IDSelectorXOr, slice(2))
add_ref_in_constructor(IDSelectorTranslated, slice(2))

# seems really marginal...
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0)
Expand Down
12 changes: 7 additions & 5 deletions faiss/python/swigfaiss.swig
Original file line number Diff line number Diff line change
Expand Up @@ -494,11 +494,6 @@ void gpu_sync_all_devices()
%template(IndexBinaryReplicas) faiss::IndexReplicasTemplate<faiss::IndexBinary>;

%include <faiss/MetaIndexes.h>
%include <faiss/IndexIDMap.h>
%template(IndexIDMap) faiss::IndexIDMapTemplate<faiss::Index>;
%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate<faiss::IndexBinary>;
%template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;

%include <faiss/IndexRowwiseMinMax.h>

Expand All @@ -513,6 +508,13 @@ void gpu_sync_all_devices()
%include <faiss/impl/AuxIndexStructures.h>
%include <faiss/impl/IDSelector.h>

%include <faiss/IndexIDMap.h>
%template(IndexIDMap) faiss::IndexIDMapTemplate<faiss::Index>;
%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate<faiss::IndexBinary>;
%template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;


%include <faiss/utils/approx_topk/mode.h>

#ifdef GPU_WRAPPER
Expand Down
43 changes: 39 additions & 4 deletions tests/test_search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,17 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
sel = faiss.IDSelectorNot(faiss.IDSelectorBatch(inverse_subset))
elif id_selector_type == "or":
sel = faiss.IDSelectorOr(
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(rhs_subset)
)
elif id_selector_type == "and":
sel = faiss.IDSelectorAnd(
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(rhs_subset)
)
elif id_selector_type == "xor":
sel = faiss.IDSelectorXOr(
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(lhs_subset),
faiss.IDSelectorBatch(rhs_subset)
)
else:
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_Flat_id_bitmap(self):

def test_Flat_id_not(self):
self.do_test_id_selector("Flat", id_selector_type="not")

def test_Flat_id_or(self):
self.do_test_id_selector("Flat", id_selector_type="or")

Expand Down Expand Up @@ -220,6 +220,41 @@ def do_test_id_selector_weak(self, index_key):
def test_HSNW(self):
self.do_test_id_selector_weak("HNSW")

def test_idmap(self):
ds = datasets.SyntheticDataset(32, 100, 100, 20)
rs = np.random.RandomState(123)
ids = rs.choice(10000, size=100, replace=False)
mask = ids % 2 == 0
index = faiss.index_factory(ds.d, "IDMap,SQ8")
index.train(ds.get_train())

# ref result
index.add_with_ids(ds.get_database()[mask], ids[mask])
Dref, Iref = index.search(ds.get_queries(), 10)

# with selector
index.reset()
index.add_with_ids(ds.get_database(), ids)

valid_ids = ids[mask]
sel = faiss.IDSelectorTranslated(
index, faiss.IDSelectorBatch(valid_ids))

Dnew, Inew = index.search(
ds.get_queries(), 10,
params=faiss.SearchParameters(sel=sel)
)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)

# let the IDMap::search add the translation...
Dnew, Inew = index.search(
ds.get_queries(), 10,
params=faiss.SearchParameters(sel=faiss.IDSelectorBatch(valid_ids))
)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)


class TestSearchParams(unittest.TestCase):

Expand Down