From 2a5dba959136c7cb9f23192c3b4cd71d2f23bbfe Mon Sep 17 00:00:00 2001 From: ZeyuTeng96 <96521059+ZeyuTeng96@users.noreply.github.com> Date: Fri, 12 Jan 2024 23:06:47 +0800 Subject: [PATCH] [semantic_indexing] fix bug of evaluate.py (#7843) --- examples/semantic_indexing/evaluate.py | 28 ++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/semantic_indexing/evaluate.py b/examples/semantic_indexing/evaluate.py index 7107c08961aa..bfa086e521c9 100644 --- a/examples/semantic_indexing/evaluate.py +++ b/examples/semantic_indexing/evaluate.py @@ -17,10 +17,23 @@ import numpy as np parser = argparse.ArgumentParser() -parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similat pair file") -parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file") parser.add_argument( - "--recall_num", type=int, default=10, help="Most similair number of doc recalled from corpus per query" + "--similar_text_pair", + type=str, + default="", + help="The full path of similat pair file", +) +parser.add_argument( + "--recall_result_file", + type=str, + default="", + help="The full path of recall result file", +) +parser.add_argument( + "--recall_num", + type=int, + default=10, + help="Most similair number of doc recalled from corpus per query", ) args = parser.parse_args() @@ -57,11 +70,6 @@ def recall(rs, N=10): with open(args.recall_result_file, "r", encoding="utf-8") as f: relevance_labels = [] for index, line in enumerate(f): - - if index % args.recall_num == 0 and index != 0: - rs.append(relevance_labels) - relevance_labels = [] - text, recalled_text, cosine_sim = line.rstrip().split("\t") if text == recalled_text: continue @@ -70,6 +78,10 @@ def recall(rs, N=10): else: relevance_labels.append(0) + if (index + 1) % args.recall_num == 0: + rs.append(relevance_labels) + relevance_labels = [] + recall_N = [] for topN in (10, 50): R = round(100 * recall(rs, N=topN), 3)