diff --git a/examples/semantic_indexing/evaluate.py b/examples/semantic_indexing/evaluate.py index 7107c08961aa..bfa086e521c9 100644 --- a/examples/semantic_indexing/evaluate.py +++ b/examples/semantic_indexing/evaluate.py @@ -17,10 +17,23 @@ import numpy as np parser = argparse.ArgumentParser() -parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similat pair file") -parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file") parser.add_argument( - "--recall_num", type=int, default=10, help="Most similair number of doc recalled from corpus per query" + "--similar_text_pair", + type=str, + default="", + help="The full path of similat pair file", +) +parser.add_argument( + "--recall_result_file", + type=str, + default="", + help="The full path of recall result file", +) +parser.add_argument( + "--recall_num", + type=int, + default=10, + help="Most similair number of doc recalled from corpus per query", ) args = parser.parse_args() @@ -57,11 +70,6 @@ def recall(rs, N=10): with open(args.recall_result_file, "r", encoding="utf-8") as f: relevance_labels = [] for index, line in enumerate(f): - - if index % args.recall_num == 0 and index != 0: - rs.append(relevance_labels) - relevance_labels = [] - text, recalled_text, cosine_sim = line.rstrip().split("\t") if text == recalled_text: continue @@ -70,6 +78,10 @@ def recall(rs, N=10): else: relevance_labels.append(0) + if (index + 1) % args.recall_num == 0: + rs.append(relevance_labels) + relevance_labels = [] + recall_N = [] for topN in (10, 50): R = round(100 * recall(rs, N=topN), 3)