bench_fw - fixes & nits for oss (facebookresearch#3102)
Summary: Pull Request resolved: facebookresearch#3102

Reviewed By: pemazare

Differential Revision: D50426528

Pulled By: algoriddle

fbshipit-source-id: 886960b8b522318967fc5ec305666871b496cae8
algoriddle authored and facebook-github-bot committed Oct 20, 2023
1 parent 0a00d81 commit c3b9374
Showing 3 changed files with 49 additions and 95 deletions.
72 changes: 42 additions & 30 deletions benchs/bench_fw/benchmark.py
@@ -1,12 +1,13 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from contextlib import contextmanager
import json
import logging
import time
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from operator import itemgetter
from statistics import median, mean
from time import perf_counter
from typing import Any, List, Optional
from .descriptors import DatasetDescriptor, IndexDescriptor

@@ -26,6 +27,15 @@
logger = logging.getLogger(__name__)


@contextmanager
def timer(name) -> float:
logger.info(f"Measuring {name}")
t1 = t2 = perf_counter()
yield lambda: t2 - t1
t2 = perf_counter()
logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")


def refine_distances_knn(
D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
):
@@ -77,7 +87,7 @@ def range_search_pr_curve(
tbl = np.vstack(
[dist_ann, metric_score, cum_score, precision, recall, unique_key]
)
group_by_dist_max_cum_score = np.empty(len(dist_ann), np.bool)
group_by_dist_max_cum_score = np.empty(len(dist_ann), bool)
group_by_dist_max_cum_score[-1] = True
group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1]
tbl = tbl[:, group_by_dist_max_cum_score]
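
# Note: np.bool was deprecated in NumPy 1.20 and removed in 1.24, so the dtype is
# now spelled with the builtin bool (np.bool_ would work as well), i.e.
#     group_by_dist_max_cum_score = np.empty(len(dist_ann), dtype=bool)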
@@ -161,11 +171,13 @@ def optimizer(codec, search, cost_metric, perf_metric)
op.add_operating_point(key, perf, cost)


def distance_ratio_measure(R, D_GT, metric):
def distance_ratio_measure(I, R, D_GT, metric):
sum_of_R = np.sum(np.where(I >= 0, R, 0))
sum_of_D_GT = np.sum(np.where(I >= 0, D_GT, 0))
if metric == faiss.METRIC_INNER_PRODUCT:
return (np.sum(R) / np.sum(D_GT)).item()
return (sum_of_R / sum_of_D_GT).item()
elif metric == faiss.METRIC_L2:
return (np.sum(D_GT) / np.sum(R)).item()
return (sum_of_D_GT / sum_of_R).item()
else:
raise RuntimeError(f"unknown metric {metric}")
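
# The added I >= 0 mask excludes padded result slots: when a faiss index returns
# fewer than k neighbors it fills the remaining entries of I with -1, and the
# distances stored in those slots should not contribute to the ratio. Sketch of
# the masking (same logic as the new sum_of_R / sum_of_D_GT lines above):
#     valid = I >= 0
#     sum_of_R = np.sum(np.where(valid, R, 0))
#     sum_of_D_GT = np.sum(np.where(valid, D_GT, 0))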

@@ -188,7 +200,7 @@ def get_range_search_metric_function(range_metric, D, R):
assert R is not None
assert D.shape == R.shape
if isinstance(range_metric, list):
aradius, ascore = [], []
aradius, ascore, aradius_from, aradius_to = [], [], [], []
radius_to = 0
for rsd in range_metric:
assert isinstance(rsd, list)
@@ -212,6 +224,8 @@ def get_range_search_metric_function(range_metric, D, R):
)
aradius.append(real_radius)
ascore.append(score)
aradius_from.append(radius_from)
aradius_to.append(radius_to)

def sigmoid(x, a, b, c):
return a / (1 + np.exp(b * x - c))
@@ -229,6 +243,7 @@ def sigmoid(x, a, b, c):
cutoff,
lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
popt.tolist(),
list(zip(aradius, ascore, aradius_from, aradius_to, strict=True))
)
else:
# Assuming that the range_metric is a float,
@@ -244,7 +259,7 @@ def sigmoid(x, a, b, c):
f"range_search_metric_function {range_metric=} {real_range=}"
)
assert isinstance(real_range, float)
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), []
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), [], []
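
# Both branches now return a 4-tuple (radius, metric_function, coefficients,
# coefficients_training_data); the float branch adds an empty list so callers can
# unpack uniformly. In the list branch, popt presumably comes from a
# scipy.optimize.curve_fit(sigmoid, ...) call outside this hunk (an assumption),
# and the training data preserves the (radius, score, radius_from, radius_to)
# samples the fit was computed from. Note that zip(..., strict=True) requires
# Python 3.10 or newer.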


@dataclass
@@ -312,9 +327,9 @@ def range_search_reference(self, index_desc, range_metric):
assert len(range_metric) > 0
ri = len(range_metric[0]) - 1
m_radius = (
max(range_metric, key=itemgetter(ri))[ri]
max(rm[ri] for rm in range_metric)
if self.distance_metric_type == faiss.METRIC_L2
else min(range_metric, key=itemgetter(ri))[ri]
else min(rm[ri] for rm in range_metric)
)
else:
m_radius = range_metric
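
# Equivalent rewrite: max(rm[ri] for rm in range_metric) (and its min() mirror)
# takes the extreme of the ri-th element of each entry directly, so this helper
# no longer needs operator.itemgetter.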
@@ -329,13 +344,14 @@ def range_search_reference(self, index_desc, range_metric):
gt_radius,
range_search_metric_function,
coefficients,
coefficients_training_data,
) = get_range_search_metric_function(
range_metric,
D if not flat else None,
R if not flat else None,
)
logger.info("range_search_reference: end")
return gt_radius, range_search_metric_function, coefficients
return gt_radius, range_search_metric_function, coefficients, coefficients_training_data

def estimate_range(self, index_desc, parameters, range_scoring_radius):
D, I, R, P = self.knn_search(
@@ -397,16 +413,12 @@ def range_search(
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
logger.info("Timing range_search_preassigned")
faiss.cvar.indexIVF_stats.reset()
t0 = time.time()
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
t = time.time() - t0
with timer("range_search_preassigned") as t:
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
else:
logger.info("Timing range_search")
t0 = time.time()
lims, D, I = index.range_search(xq, radius)
t = time.time() - t0
with timer("range_search") as t:
lims, D, I = index.range_search(xq, radius)
if flat:
R = D
else:
@@ -415,7 +427,7 @@ def range_search(
lims, D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t,
"time": t(),
"radius": radius,
"count": lims[-1].item(),
"parameters": parameters,
@@ -560,16 +572,12 @@ def knn_search(
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
logger.info("Timing knn search_preassigned")
faiss.cvar.indexIVF_stats.reset()
t0 = time.time()
D, I = index.search_preassigned(xq, k, QI, QD)
t = time.time() - t0
with timer("knn search_preassigned") as t:
D, I = index.search_preassigned(xq, k, QI, QD)
else:
logger.info("Timing knn search")
t0 = time.time()
D, I = index.search(xq, k)
t = time.time() - t0
with timer("knn search") as t:
D, I = index.search(xq, k)
if flat or level > 0:
R = D
else:
@@ -578,7 +586,7 @@ def knn_search(
D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t,
"time": t(),
"parameters": parameters,
"index": index_desc.factory,
"level": level,
@@ -646,7 +654,7 @@ def experiment(parameters, cost_metric, perf_metric):
I, self.gt_knn_I
),
"distance_ratio": distance_ratio_measure(
R, self.gt_knn_D, self.distance_metric_type
I, R, self.gt_knn_D, self.distance_metric_type
),
}
results["experiments"][key] = metrics
@@ -691,8 +699,12 @@ def benchmark(self) -> str:
gt_radius,
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.range_search_reference(index_desc, range_metric)
results["metrics"][metric_key] = coefficients
results["metrics"][metric_key] = {
"coefficients": coefficients,
"training_data": coefficients_training_data,
}
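
# The stored metrics entry changes from a bare coefficient list to a dict holding
# both the fit and its samples; the serialized JSON then looks roughly like
# (hypothetical values):
#     {"coefficients": [a, b, c], "training_data": [[radius, score, r_from, r_to], ...]}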
gt_rsm = self.range_ground_truth(
gt_radius, range_search_metric_function
)
14 changes: 7 additions & 7 deletions benchs/bench_fw/benchmark_io.py
@@ -198,7 +198,7 @@ def write_file(
def get_dataset(self, dataset):
if dataset not in self.cached_ds:
self.cached_ds[dataset] = self.read_nparray(
os.path.join(self.path, dataset.name)
os.path.join(self.path, dataset.tablename)
)
return self.cached_ds[dataset]

@@ -207,9 +207,9 @@ def read_nparray(
filename: str,
):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading nparray from {fn}\n")
logger.info(f"Loading nparray from {fn}")
nparray = np.load(fn)
logger.info(f"Loaded nparray {nparray.shape} from {fn}\n")
logger.info(f"Loaded nparray {nparray.shape} from {fn}")
return nparray

def write_nparray(
@@ -218,7 +218,7 @@ def write_nparray(
filename: str,
):
fn = self.get_local_filename(filename)
logger.info(f"Saving nparray {nparray.shape} to {fn}\n")
logger.info(f"Saving nparray {nparray.shape} to {fn}")
np.save(fn, nparray)
self.upload_file_to_blobstore(filename)

@@ -227,10 +227,10 @@ def read_json(
filename: str,
):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading json {fn}\n")
logger.info(f"Loading json {fn}")
with open(fn, "r") as fp:
json_dict = json.load(fp)
logger.info(f"Loaded json {json_dict} from {fn}\n")
logger.info(f"Loaded json {json_dict} from {fn}")
return json_dict

def write_json(
@@ -240,7 +240,7 @@ def write_json(
overwrite: bool = False,
):
fn = self.get_local_filename(filename)
logger.info(f"Saving json {json_dict} to {fn}\n")
logger.info(f"Saving json {json_dict} to {fn}")
with open(fn, "w") as fp:
json.dump(json_dict, fp)
self.upload_file_to_blobstore(filename, overwrite=overwrite)
58 changes: 0 additions & 58 deletions build.sh

This file was deleted.
