Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IVFPQFastScan decode function #3312

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions faiss/IndexIVFPQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,28 @@ void IndexIVFPQFastScan::compute_LUT(
}
}

void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
const {
pq.decode(bytes, x, n);
size_t coarse_size = coarse_code_size();

#pragma omp parallel
Copy link
Contributor

@alexanderguzhva alexanderguzhva Mar 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parallel if (n > 1)
or, even better, if n is greater than a certain threshold
please check other sa_decode implementations for the appropriate threshold

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @alexanderguzhva, thanks for the suggestion. I checked other sa_decode. They also don't have the condition. Try to understand it more. #pragma omp parallel if (n > 1) means we only parallel execution when n is greater than 1. I'm wondering how we should choose the threshold in general. Do we have any benchmark to run or other methods?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no benchmarks, this is a magic constant. Typically, things like 100 or 1000 are used.

#pragma omp parallel if (n > 1000)

{
std::vector<float> residual(d);

#pragma omp for
for (idx_t i = 0; i < n; i++) {
const uint8_t* code = codes + i * (code_size + coarse_size);
int64_t list_no = decode_listno(code);
float* xi = x + i * d;
pq.decode(code + coarse_size, xi);
if (by_residual) {
quantizer->reconstruct(list_no, residual.data());
for (size_t j = 0; j < d; j++) {
xi[j] += residual[j];
}
}
}
}
}

} // namespace faiss
58 changes: 43 additions & 15 deletions tests/test_fast_scan_ivf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def sp(x):
b = btab[0]
dis_new = self.compute_dis_quant(codes, LUTq, biasq, a, b)

# print(a, b, dis_ref.sum())
avg_realtive_error = np.abs(dis_new - dis_ref).sum() / dis_ref.sum()
# print('a=', a, 'avg_relative_error=', avg_realtive_error)
self.assertLess(avg_realtive_error, 0.0005)

def test_no_residual_ip(self):
Expand Down Expand Up @@ -228,8 +226,6 @@ def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):

m3 = three_metrics(Da, Ia, Db, Ib)


# print(by_residual, metric, recall_at_1, recall_at_10, intersection_at_10)
ref_results = {
(True, 1): [0.985, 1.0, 9.872],
(True, 0): [ 0.987, 1.0, 9.914],
Expand Down Expand Up @@ -261,36 +257,75 @@ class TestEquivPQ(unittest.TestCase):

def test_equiv_pq(self):
ds = datasets.SyntheticDataset(32, 2000, 200, 4)
xq = ds.get_queries()

index = faiss.index_factory(32, "IVF1,PQ16x4np")
index.by_residual = False
# force coarse quantizer
index.quantizer.add(np.zeros((1, 32), dtype='float32'))
index.train(ds.get_train())
index.add(ds.get_database())
Dref, Iref = index.search(ds.get_queries(), 4)
Dref, Iref = index.search(xq, 4)

index_pq = faiss.index_factory(32, "PQ16x4np")
index_pq.pq = index.pq
index_pq.is_trained = True
index_pq.codes = faiss. downcast_InvertedLists(
index.invlists).codes.at(0)
index_pq.ntotal = index.ntotal
Dnew, Inew = index_pq.search(ds.get_queries(), 4)
Dnew, Inew = index_pq.search(xq, 4)

np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_equal(Dref, Dnew)

index_pq2 = faiss.IndexPQFastScan(index_pq)
index_pq2.implem = 12
Dref, Iref = index_pq2.search(ds.get_queries(), 4)
Dref, Iref = index_pq2.search(xq, 4)

index2 = faiss.IndexIVFPQFastScan(index)
index2.implem = 12
Dnew, Inew = index2.search(ds.get_queries(), 4)
Dnew, Inew = index2.search(xq, 4)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_equal(Dref, Dnew)

# test encode and decode

np.testing.assert_array_equal(
index_pq.sa_encode(xq),
index2.sa_encode(xq)
)

np.testing.assert_array_equal(
((index_pq.sa_decode(index_pq.sa_encode(xq)) - xq) ** 2).sum(1),
((index2.sa_decode(index2.sa_encode(xq)) - xq) ** 2).sum(1)
)

def test_equiv_pq_encode_decode(self):
ds = datasets.SyntheticDataset(32, 1000, 200, 10)
xq = ds.get_queries()

index_ivfpq = faiss.index_factory(ds.d, "IVF10,PQ8x4np")
index_ivfpq.train(ds.get_train())

index_ivfpqfs = faiss.IndexIVFPQFastScan(index_ivfpq)

np.testing.assert_array_equal(
index_ivfpq.sa_encode(xq),
index_ivfpqfs.sa_encode(xq)
)

np.testing.assert_array_equal(
index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)),
index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq))
)

np.testing.assert_array_equal(
((index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)) - xq) ** 2)
.sum(1),
((index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq)) - xq) ** 2)
.sum(1)
)


class TestIVFImplem12(unittest.TestCase):

Expand Down Expand Up @@ -463,7 +498,6 @@ def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
Dnew, Inew = index2.search(ds.get_queries(), 10)

m3 = three_metrics(Dref, Iref, Dnew, Inew)
# print((by_residual, metric, d), ":", m3)
ref_m3_tab = {
(True, 1, 32): (0.995, 1.0, 9.91),
(True, 0, 32): (0.99, 1.0, 9.91),
Expand Down Expand Up @@ -554,7 +588,6 @@ def subtest_accuracy(self, aq, st, by_residual, implem, metric_type='L2'):
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq

print(aq, st, by_residual, implem, metric_type, recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.051

def xx_test_accuracy(self):
Expand Down Expand Up @@ -599,7 +632,6 @@ def subtest_rescale_accuracy(self, aq, st, by_residual, implem):
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq

print(aq, st, by_residual, implem, recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.05

def xx_test_rescale_accuracy(self):
Expand All @@ -624,7 +656,6 @@ def subtest_from_ivfaq(self, implem):
nq = Iref.shape[0]
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq
print(recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.02

def test_from_ivfaq(self):
Expand Down Expand Up @@ -763,7 +794,6 @@ def subtest_accuracy(self, paq):
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq

print(paq, recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.05

def test_accuracy_PLSQ(self):
Expand Down Expand Up @@ -847,7 +877,6 @@ def do_test(self, metric=faiss.METRIC_L2):
# find a reasonable radius
D, I = index.search(ds.get_queries(), 10)
radius = np.median(D[:, -1])
# print("radius=", radius)
lims1, D1, I1 = index.range_search(ds.get_queries(), radius)

index2 = faiss.IndexIVFPQFastScan(index)
Expand All @@ -860,7 +889,6 @@ def do_test(self, metric=faiss.METRIC_L2):
for i in range(ds.nq):
ref = set(I1[lims1[i]: lims1[i + 1]])
new = set(I2[lims2[i]: lims2[i + 1]])
print(ref, new)
nmiss += len(ref - new)
nextra += len(new - ref)

Expand Down