This repository has been archived by the owner on Oct 25, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_figures.py
63 lines (53 loc) · 2.53 KB
/
get_figures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
from get_data import *
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
def _elapsed_times(records, target_time, max_time):
    """Return the time-axis values for a list of records.

    The value for record ``i`` is its ``'input_time'`` measured from the first
    record's ``'output_time'``, minus the accumulated
    ``output_time - input_time - target_time`` of every earlier record. In
    effect each earlier record's input->output span is replaced by
    ``target_time`` (presumably to simulate a fixed evaluation cost -- confirm
    against the code that produces the records).

    The list stops right after (and includes) the first value exceeding
    ``max_time``. An empty ``records`` yields an empty list.
    """
    if not records:
        return []
    t0 = records[0]['output_time']
    elapsed = []
    for r in records:
        elapsed.append(r['input_time'] - t0)
        t0 += r['output_time'] - r['input_time'] - target_time
        if elapsed[-1] > max_time:
            break
    return elapsed


def plot_results(*methods, true_minimum=None, max_n_calls=np.inf, choice='x_error', x_mark='n', target_time=0,
                 max_time=1000, ax=None):
    """Draw a convergence plot comparing several optimisation runs.

    Each positional argument is a dict with a ``'name'`` entry (the legend
    label) and a ``'result'`` entry whose ``records`` attribute is a list of
    per-call dicts. For every call ``r`` the plotted y value is
    ``records[r['best']][choice]``, i.e. the ``choice`` metric of the record
    that ``r['best']`` indexes.

    Args:
        *methods: Run descriptions as described above.
        true_minimum: If not None, a dashed red horizontal line labelled
            "True minimum" is drawn at this y value.
        max_n_calls: Upper bound on the number of records plotted per run.
        choice: Record key plotted on the y axis (also used as the y label).
        x_mark: ``'n'`` plots against the 1-based call index; any other value
            plots against the elapsed time computed by ``_elapsed_times``.
        target_time: Time-axis mode only; passed to ``_elapsed_times``.
        max_time: Time-axis mode only; each curve stops at the first point
            whose elapsed time exceeds this value.
        ax: Axes to draw on. Defaults to ``plt.gca()``, as before.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    if ax is None:
        ax = plt.gca()
    ax.set_title("Convergence plot")
    if x_mark == 'n':
        ax.set_xlabel("Number of calls $n$")
    else:
        ax.set_xlabel("Time Consumption (seconds)")
    ax.set_ylabel(choice)
    ax.grid()
    # One colour per run, sampled from the upper 3/4 of the viridis colormap.
    colors = cm.viridis(np.linspace(0.25, 1.0, len(methods)))
    for result, color in zip(methods, colors):
        name = result['name']
        records = result['result'].records
        n_calls = int(np.min([len(records), max_n_calls]))
        mins = [records[r['best']][choice] for r in records[:n_calls]]
        if x_mark == 'n':
            xs = range(1, n_calls + 1)
        else:
            # Safe for empty runs: the helper returns [] instead of indexing
            # records[0].
            xs = _elapsed_times(records[:n_calls], target_time, max_time)
            # The time axis may stop early at max_time; keep y in step with it.
            mins = mins[:len(xs)]
        ax.plot(xs, mins, c=color, marker=".", markersize=12, lw=2, label=name)
    if true_minimum is not None:
        ax.axhline(true_minimum, linestyle="--",
                   color="r", lw=1,
                   label="True minimum")
    ax.legend(loc="best")
    return ax
if __name__ == '__main__':
    # Plot the first (alphabetically sorted) result file found in ./pkl.
    idx = 0
    names = sorted(os.listdir('pkl'))
    print(names[idx])
    # NOTE: pickle.load can execute arbitrary code; only open result files
    # produced by this project itself.
    with open(os.path.join('pkl', names[idx]), 'rb') as fl:
        data = pickle.load(fl)
    print(data['setting'])
    true_minimum = benchmarks[data['setting']['benchmark']]['y']

    # pyplot.show() does not accept an Axes (depending on the backend it is
    # rejected or misread as block/close). Draw each plot on a fresh figure so
    # plots never overlay on a shared Axes, then show it.
    plt.figure()
    plot_results(*data['data'], true_minimum=true_minimum, max_n_calls=100, choice='y_output')
    plt.show()

    plt.figure()
    plot_results(*data['data'], true_minimum=0., max_n_calls=100, choice='x_error')
    plt.show()

    # Time-axis variants (call plt.figure() before and plt.show() after each):
    # plot_results(*data['data'], true_minimum=true_minimum, choice='y_output', x_mark='time', target_time=2, max_time=1000)
    # plot_results(*data['data'], choice='x_error', x_mark='time', target_time=5, max_time=1000)
    # plot_results(data['data'][3], data['data'][0], choice='y_true', x_mark='time', target_time=1, max_time=200)
    # plot_results(data['data'][3], data['data'][0], choice='x_error', x_mark='time', target_time=1, max_time=200)