Skip to content

Commit

Permalink
Fix IVFPQFastScan decode function (facebookresearch#3312)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3312

As [issue #3258](facebookresearch#3258) mentioned, IVFPQFastScan should produce the same decoding result as IVFPQ. However, the current result is not as expected.

This PR/Diff fixes the decoding function.

Reviewed By: mdouze

Differential Revision: D55264781

fbshipit-source-id: dfdae9eabceadfc5a3ebb851930d71ce3c1c654d
  • Loading branch information
junjieqi authored and abhinavdangeti committed Jul 12, 2024
1 parent 7f337bd commit b9220ff
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 17 deletions.
8 changes: 8 additions & 0 deletions faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,14 @@ struct IndexIVF : Index, IndexIVFInterface {

/* The standalone codec interface (except sa_decode that is specific) */
size_t sa_code_size() const override;

/** encode a set of vectors
* sa_encode will call encode_vector with include_listno=true
* @param n nb of vectors to encode
* @param x the vectors to encode
* @param bytes output array for the codes
*              (presumably n * sa_code_size() bytes -- TODO confirm)
*
* The function is declared void: it has no return value, the codes are
* delivered through the bytes output array.
*/
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;

IndexIVF();
Expand Down
23 changes: 21 additions & 2 deletions faiss/IndexIVFPQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,28 @@ void IndexIVFPQFastScan::compute_LUT(
}
}

// Decode n standalone codes into n vectors of dimension d, written
// contiguously to x (vector i starts at x + i * d).
//
// NOTE(review): this listing is a flattened diff without +/- markers. The
// first signature line (parameter named `bytes`) and the
// `pq.decode(bytes, x, n);` line look like the pre-change version that the
// commit removes (hunk header: -286,9 +286,28); the second signature
// (parameter named `codes`) and the remaining body look like the added
// version. Confirm against the repository before treating this as
// compilable code.
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
const {
// Presumably the removed implementation: decodes all n codes with the
// product quantizer only, without reading a list number.
pq.decode(bytes, x, n);
// Number of leading bytes per code that are passed over before the PQ part.
size_t coarse_size = coarse_code_size();

#pragma omp parallel if (n > 1)
{
// Declared inside the parallel region, so each thread gets its own buffer.
std::vector<float> residual(d);

#pragma omp for
for (idx_t i = 0; i < n; i++) {
// Each input code occupies code_size + coarse_size bytes.
const uint8_t* code = codes + i * (code_size + coarse_size);
// The list number is read from the start of the code.
int64_t list_no = decode_listno(code);
float* xi = x + i * d;
// The PQ part starts after the first coarse_size bytes.
pq.decode(code + coarse_size, xi);
if (by_residual) {
// Add the quantizer's reconstruction of list_no to the decoded vector.
quantizer->reconstruct(list_no, residual.data());
for (size_t j = 0; j < d; j++) {
xi[j] += residual[j];
}
}
}
}
}

} // namespace faiss
63 changes: 48 additions & 15 deletions tests/test_fast_scan_ivf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def sp(x):
b = btab[0]
dis_new = self.compute_dis_quant(codes, LUTq, biasq, a, b)

# print(a, b, dis_ref.sum())
avg_realtive_error = np.abs(dis_new - dis_ref).sum() / dis_ref.sum()
# print('a=', a, 'avg_relative_error=', avg_realtive_error)
self.assertLess(avg_realtive_error, 0.0005)

def test_no_residual_ip(self):
Expand Down Expand Up @@ -228,8 +226,6 @@ def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):

m3 = three_metrics(Da, Ia, Db, Ib)


# print(by_residual, metric, recall_at_1, recall_at_10, intersection_at_10)
ref_results = {
(True, 1): [0.985, 1.0, 9.872],
(True, 0): [ 0.987, 1.0, 9.914],
Expand Down Expand Up @@ -261,36 +257,80 @@ class TestEquivPQ(unittest.TestCase):

# Checks that a single-list IVF-PQ index built with by_residual = False
# returns the same search results as a flat PQ index that reuses its
# product quantizer and codes, that the two FastScan variants (implem 12)
# agree with each other, and that sa_encode / sa_decode of the flat PQ index
# and of the IVFPQFastScan index give equal arrays.
#
# NOTE(review): this listing is a flattened diff without +/- markers. Each
# `...search(ds.get_queries(), 4)` line is immediately followed by a
# `...search(xq, 4)` line assigning the same names; presumably the first of
# each pair is the removed version and only the xq form remains in the
# committed file -- confirm against the repository.
def test_equiv_pq(self):
ds = datasets.SyntheticDataset(32, 2000, 200, 4)
xq = ds.get_queries()

index = faiss.index_factory(32, "IVF1,PQ16x4np")
index.by_residual = False
# force coarse quantizer
index.quantizer.add(np.zeros((1, 32), dtype='float32'))
index.train(ds.get_train())
index.add(ds.get_database())
Dref, Iref = index.search(ds.get_queries(), 4)
Dref, Iref = index.search(xq, 4)

# Flat PQ index that shares the IVF index's product quantizer and takes
# its codes from inverted list 0.
index_pq = faiss.index_factory(32, "PQ16x4np")
index_pq.pq = index.pq
index_pq.is_trained = True
index_pq.codes = faiss. downcast_InvertedLists(
index.invlists).codes.at(0)
index_pq.ntotal = index.ntotal
Dnew, Inew = index_pq.search(ds.get_queries(), 4)
Dnew, Inew = index_pq.search(xq, 4)

# Distances and labels must be exactly equal, not merely close.
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_equal(Dref, Dnew)

# Same comparison between the two FastScan indexes: Dref / Iref are
# rebound to the IndexPQFastScan results here.
index_pq2 = faiss.IndexPQFastScan(index_pq)
index_pq2.implem = 12
Dref, Iref = index_pq2.search(ds.get_queries(), 4)
Dref, Iref = index_pq2.search(xq, 4)

index2 = faiss.IndexIVFPQFastScan(index)
index2.implem = 12
Dnew, Inew = index2.search(ds.get_queries(), 4)
Dnew, Inew = index2.search(xq, 4)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_equal(Dref, Dnew)

# test encode and decode

# Encoded arrays of the flat PQ index and the IVFPQFastScan index must be
# equal (presumably the single-list IVF adds no list-number bytes -- TODO
# confirm).
np.testing.assert_array_equal(
index_pq.sa_encode(xq),
index2.sa_encode(xq)
)

# An encode -> decode round trip must give equal reconstructions.
np.testing.assert_array_equal(
index_pq.sa_decode(index_pq.sa_encode(xq)),
index2.sa_decode(index2.sa_encode(xq))
)

# Per-query squared reconstruction error must be equal as well.
np.testing.assert_array_equal(
((index_pq.sa_decode(index_pq.sa_encode(xq)) - xq) ** 2).sum(1),
((index2.sa_decode(index2.sa_encode(xq)) - xq) ** 2).sum(1)
)

def test_equiv_pq_encode_decode(self):
    """Compare the standalone codec of an IVF-PQ index with its FastScan copy.

    A trained "IVF10,PQ8x4np" index and the IndexIVFPQFastScan built from
    it must yield equal sa_encode output, equal sa_decode(sa_encode(.))
    reconstructions, and equal per-query squared reconstruction errors on
    the synthetic query set.
    """
    dataset = datasets.SyntheticDataset(32, 1000, 200, 10)
    queries = dataset.get_queries()

    ref_index = faiss.index_factory(dataset.d, "IVF10,PQ8x4np")
    ref_index.train(dataset.get_train())
    fastscan_index = faiss.IndexIVFPQFastScan(ref_index)

    def roundtrip(index):
        # Encode the queries, then decode the resulting codes.
        return index.sa_decode(index.sa_encode(queries))

    def squared_error(index):
        # Per-query squared L2 distance between reconstruction and input.
        return ((roundtrip(index) - queries) ** 2).sum(1)

    ref_codes = ref_index.sa_encode(queries)
    fastscan_codes = fastscan_index.sa_encode(queries)
    np.testing.assert_array_equal(ref_codes, fastscan_codes)

    ref_decoded = roundtrip(ref_index)
    fastscan_decoded = roundtrip(fastscan_index)
    np.testing.assert_array_equal(ref_decoded, fastscan_decoded)

    ref_error = squared_error(ref_index)
    fastscan_error = squared_error(fastscan_index)
    np.testing.assert_array_equal(ref_error, fastscan_error)


class TestIVFImplem12(unittest.TestCase):

Expand Down Expand Up @@ -463,7 +503,6 @@ def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
Dnew, Inew = index2.search(ds.get_queries(), 10)

m3 = three_metrics(Dref, Iref, Dnew, Inew)
# print((by_residual, metric, d), ":", m3)
ref_m3_tab = {
(True, 1, 32): (0.995, 1.0, 9.91),
(True, 0, 32): (0.99, 1.0, 9.91),
Expand Down Expand Up @@ -554,7 +593,6 @@ def subtest_accuracy(self, aq, st, by_residual, implem, metric_type='L2'):
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq

print(aq, st, by_residual, implem, metric_type, recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.051

def xx_test_accuracy(self):
Expand Down Expand Up @@ -599,7 +637,6 @@ def subtest_rescale_accuracy(self, aq, st, by_residual, implem):
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq

print(aq, st, by_residual, implem, recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.05

def xx_test_rescale_accuracy(self):
Expand All @@ -624,7 +661,6 @@ def subtest_from_ivfaq(self, implem):
nq = Iref.shape[0]
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq
print(recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.02

def test_from_ivfaq(self):
Expand Down Expand Up @@ -763,7 +799,6 @@ def subtest_accuracy(self, paq):
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq

print(paq, recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.05

def test_accuracy_PLSQ(self):
Expand Down Expand Up @@ -847,7 +882,6 @@ def do_test(self, metric=faiss.METRIC_L2):
# find a reasonable radius
D, I = index.search(ds.get_queries(), 10)
radius = np.median(D[:, -1])
# print("radius=", radius)
lims1, D1, I1 = index.range_search(ds.get_queries(), radius)

index2 = faiss.IndexIVFPQFastScan(index)
Expand All @@ -860,7 +894,6 @@ def do_test(self, metric=faiss.METRIC_L2):
for i in range(ds.nq):
ref = set(I1[lims1[i]: lims1[i + 1]])
new = set(I2[lims2[i]: lims2[i + 1]])
print(ref, new)
nmiss += len(ref - new)
nextra += len(new - ref)

Expand Down

0 comments on commit b9220ff

Please sign in to comment.