-
Notifications
You must be signed in to change notification settings - Fork 7
/
trans_imdb_rank.py
38 lines (30 loc) · 1.01 KB
/
trans_imdb_rank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import json
import numpy as np
import torch
import os
# Convert per-token score files into linearly decaying rank "guide" weights.
# For every example and every score row, the first L tokens (L = sum of that
# example's input mask) are ranked by descending score and given weights
# 1, 1 - 1/L, 1 - 2/L, ...; all other positions get 0. Each input
# ``*_score.npy`` is written back as ``*_rank.npy`` next to it.

SCORE_FILES = [
    'npy_folder/imdb_bert_512_score.npy',
    'npy_folder/imdb_distilbert_512_score.npy',
]
CACHED_FEATURES = 'imdb/cached_train_512_lower'


def score_row_to_guide(row, length):
    """Return float32 rank weights for one row of per-token scores.

    The first ``length`` entries of ``row`` are ranked by descending score.
    The entry ranked ``k`` (0-based) gets weight ``1 - k / length``, so the
    top-scoring entry gets 1.0. Positions at or beyond ``length`` stay 0.

    Args:
        row: 1-D array of scores for one sequence.
        length: number of leading positions to rank (presumably the count of
            non-padding tokens -- TODO confirm against the feature cache).

    Returns:
        float32 array with the same length as ``row``.
    """
    guide = np.zeros(len(row), dtype=np.float32)
    if length <= 0:
        return guide  # nothing to rank (e.g. an all-zero mask)
    order = np.argsort(-row[:length])
    # One vectorized scatter instead of a per-token Python loop; float64
    # weights are cast to float32 exactly as the element-wise assignment did.
    guide[order] = 1 - np.arange(len(order)) / length
    return guide


def scores_to_rank(imdb_score, length):
    """Apply ``score_row_to_guide`` to every row of one example's scores.

    Returns a 2-D float32 array with one guide row per input row.
    """
    return np.array([score_row_to_guide(row, length) for row in imdb_score])


def rank_path(score_path):
    """Return the output path: the last 'score' in the file name becomes 'rank'.

    Only the file-name component is rewritten, so a directory that happens to
    contain 'score' is left alone.

    Raises:
        ValueError: if the file name has no 'score', because reusing the input
            path would overwrite the score file.
    """
    name_start = len(score_path) - len(os.path.basename(score_path))
    idx = score_path.rfind('score')
    if idx < name_start:
        raise ValueError(f"{score_path!r} has no 'score' in its file name")
    return score_path[:idx] + 'rank' + score_path[idx + len('score'):]


lengths = None
for filename in SCORE_FILES:
    if not os.path.exists(filename):
        continue
    imdb_scores = np.load(filename)
    if lengths is None:
        # Load the (shared) cached features once, and only when a score file
        # actually exists. NOTE(review): torch >= 2.6 defaults torch.load to
        # weights_only=True, which may refuse a pickled dataset -- confirm.
        features_and_dataset = torch.load(CACHED_FEATURES)
        dataset = features_and_dataset["dataset"]
        # tensors[1] is presumably the attention mask -- TODO confirm.
        all_input_mask = dataset.tensors[1]
        lengths = [int(n) for n in all_input_mask.sum(dim=-1)]
    if len(lengths) != len(imdb_scores):
        raise ValueError(
            f"{filename}: {len(imdb_scores)} score entries but "
            f"{len(lengths)} cached examples"
        )
    rank = np.array(
        [scores_to_rank(score, n) for score, n in zip(imdb_scores, lengths)]
    )
    np.save(rank_path(filename), rank)