-
Notifications
You must be signed in to change notification settings - Fork 0
/
find_best.py
32 lines (27 loc) · 946 Bytes
/
find_best.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys
import os
import pandas as pd
import numpy as np
# Select the highest-reward architecture recorded in an NNI RL-controller log
# and write it out as a bash snippet that appends the architecture to
# $fixed_arc one layer per line.
#
# Usage: python find_best.py <nni_experiment_id>

# NOTE(review): the log path is pinned to a single experiment. The experiment
# id given on the command line is required but does not influence the path;
# an earlier variant built it as '~/nni/experiments/' + nni_id +
# '/log/RL_controller.log'. Confirm the intended behaviour before changing it.
LOG_PATH = '/home/yujwang/nni/experiments/okBoeRjw/log/RL_controller.log'
REWARD_LINES_PATH = "tmp.txt"
ARCS_SCRIPT_PATH = "./scripts/arcs.sh"
NUM_LAYERS = 12


def extract_reward_lines(log_path, out_path):
    """Copy every line of *log_path* that contains 'reward' into *out_path*.

    Pure-Python equivalent of ``grep 'reward' log_path > out_path``; unlike
    the shell pipeline it raises (e.g. FileNotFoundError) when the log cannot
    be read instead of silently producing an empty file.
    """
    with open(log_path, encoding="utf-8", errors="replace") as log, \
            open(out_path, "w", encoding="utf-8") as out:
        out.writelines(line for line in log if "reward" in line)


def parse_arc(arc_str):
    """Return the integers of a whitespace-separated string as a list."""
    # split() without an argument tolerates repeated/leading/trailing blanks,
    # which split(' ') would turn into empty tokens that int() rejects.
    return [int(token) for token in arc_str.split()]


def split_layers(arc, num_layers=NUM_LAYERS):
    """Cut the flat *arc* list into consecutive per-layer chunks.

    Layer ``k`` (0-based) receives ``k + 2`` entries. If *arc* is shorter than
    the total required, trailing chunks are short or empty (slice semantics).
    """
    layers = []
    start = 0
    for layer_id in range(num_layers):
        end = start + layer_id + 2
        layers.append(arc[start:end])
        start = end
    return layers


def format_arc_line(layer):
    """Return the bash statement that appends *layer* to ``$fixed_arc``."""
    # Join explicitly rather than relying on NumPy's array str(), which pads
    # columns and may wrap long rows.
    return "fixed_arc=\"$fixed_arc {0}\"".format(" ".join(str(v) for v in layer))


def main(argv):
    """Run the extraction and write ./scripts/arcs.sh.

    Raises IndexError when the experiment id argument is missing.
    """
    nni_id = argv[1]  # noqa: F841 -- required argument, see NOTE on LOG_PATH

    extract_reward_lines(LOG_PATH, REWARD_LINES_PATH)
    data = pd.read_csv(REWARD_LINES_PATH, header=None, sep='\t')
    data.columns = ['arc', 'arc_num', 'reward', 'reward_num']

    # .loc replaces DataFrame.ix, which was removed from pandas.
    max_reward = data.loc[data['reward_num'].idxmax()]
    print("max_reward", max_reward)

    # Build all lines before touching the output file so a parse error does
    # not leave a half-written script behind.
    arc = parse_arc(max_reward['arc_num'])
    lines = [format_arc_line(layer) for layer in split_layers(arc)]

    with open(ARCS_SCRIPT_PATH, 'w') as fw:
        fw.write("#!/bin/bash\n")
        for out_str in lines:
            print(out_str)
            fw.write(out_str + '\n')


if __name__ == "__main__":
    main(sys.argv)