Skip to content

Commit

Permalink
[semantic_indexing] fix bug of evaluate.py (#7843)
Browse files: browse the repository at this point in the history
  • Loading branch information
ZeyuTeng96 committed Jan 12, 2024
1 parent b44f888 commit 2a5dba9
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions examples/semantic_indexing/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,23 @@
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similat pair file")
parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file")
parser.add_argument(
"--recall_num", type=int, default=10, help="Most similair number of doc recalled from corpus per query"
"--similar_text_pair",
type=str,
default="",
help="The full path of similat pair file",
)
parser.add_argument(
"--recall_result_file",
type=str,
default="",
help="The full path of recall result file",
)
parser.add_argument(
"--recall_num",
type=int,
default=10,
help="Most similair number of doc recalled from corpus per query",
)
args = parser.parse_args()

Expand Down Expand Up @@ -57,11 +70,6 @@ def recall(rs, N=10):
with open(args.recall_result_file, "r", encoding="utf-8") as f:
relevance_labels = []
for index, line in enumerate(f):

if index % args.recall_num == 0 and index != 0:
rs.append(relevance_labels)
relevance_labels = []

text, recalled_text, cosine_sim = line.rstrip().split("\t")
if text == recalled_text:
continue
Expand All @@ -70,6 +78,10 @@ def recall(rs, N=10):
else:
relevance_labels.append(0)

if (index + 1) % args.recall_num == 0:
rs.append(relevance_labels)
relevance_labels = []

recall_N = []
for topN in (10, 50):
R = round(100 * recall(rs, N=topN), 3)
Expand Down

0 comments on commit 2a5dba9

Please sign in to comment.