-
Notifications
You must be signed in to change notification settings - Fork 14
/
plotting.py
145 lines (109 loc) · 4.11 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
__author__ = 'oliver'
import matplotlib
# matplotlib.use('Agg') # Or any other X11 back-end
import matplotlib.pyplot as pyplot
from matplotlib import colors
import argparse
import sys
from numpy import genfromtxt, linspace
from scipy.interpolate import Akima1DInterpolator
import os
import six
xmin = 20000  # NOTE(review): not referenced anywhere else in this file

# Palette built at import time: every (name, value) pair from matplotlib's
# named-colour table, followed by each entry of the colour converter's map
# with its value normalised to a hex string.
colors_ = list(six.iteritems(colors.cnames))
for name, rgb in six.iteritems(colors.ColorConverter.colors):
    colors_.append((name, colors.rgb2hex(rgb)))

# Keep only the colour values, then thin them: start at index 2, stop before
# the final entry, and take every second colour in between.
# NOTE(review): plot() defines its own local palette, so this module-level
# hex_ is not used by plot() in this file.
hex_ = [value for _, value in colors_][2:-1:2]
def _load_series(res_files, column_num, smoothing):
    """Read every result file and return the series that can be plotted.

    Returns a list of ``(x, y, name)`` tuples in sorted filename order, where
    ``x`` is column 0 of the file, ``y`` is column ``column_num`` and ``name``
    is the name of the directory containing the file.  Files with fewer than
    two rows, or without the requested column, are reported and skipped
    instead of aborting the whole plot.
    """
    series = []
    for filename in sorted(res_files):
        name = os.path.basename(os.path.dirname(filename))
        data = genfromtxt(filename, delimiter=' ')
        # genfromtxt returns a 1-D array for a single-line file; make it 2-D so
        # the row/column indexing below is uniform.
        if data.ndim == 1:
            data = data.reshape(1, -1)
        if data.shape[0] < 2:
            print('Failed to create plot for {}.\nIs there only 1 epoch?'.format(name))
            continue
        try:
            x = data[:, 0]
            y = data[:, column_num]
        except IndexError as e:
            # str(e) works on both Python 2 and 3 (e.message is Python 2 only).
            print("EXCEPTION: " + str(e))
            print('Failed to create plot for {}.\nColumn {} is missing.'.format(name, column_num))
            continue
        if smoothing:
            # Resample onto 1000 evenly spaced x values with an Akima spline.
            x_smooth = linspace(x.min(), x.max(), 1000)
            y = Akima1DInterpolator(x, y)(x_smooth)
            x = x_smooth
        series.append((x, y, name))
    return series


def plot(column, metric, smoothing, work_dir):
    """Plot one metric for every run under ``work_dir`` and save it as a PNG.

    Every immediate sub-directory of ``work_dir`` containing a space-delimited
    ``plot.txt`` contributes one line: column 0 is the x axis (iterations) and
    column ``column`` is the y axis.  The figure is written to
    ``<work_dir>/<metric>.png``.  If no file yields a plottable series, a
    message is printed and nothing is saved.

    Args:
        column: Index of the metric column in each ``plot.txt`` row; negative
            values count from the end of the row.
        metric: Metric name, used for the title, y-axis label and file name.
        smoothing: If truthy, each series is resampled with an Akima spline.
        work_dir: Directory whose sub-directories hold ``plot.txt`` files.
    """
    pretty_colors = ['#FC474C', '#8DE047', '#FFDD50', '#53A3D7']
    res_files = [
        os.path.join(work_dir, entry, 'plot.txt')
        for entry in os.listdir(work_dir)
        if os.path.exists(os.path.join(work_dir, entry, 'plot.txt'))
    ]
    series = _load_series(res_files, column, smoothing)
    if not series:
        # Without data the axis limits would be nonsense; don't save a figure.
        print('No plottable plot.txt files found under {}'.format(work_dir))
        return

    max_x = max_y = -9999
    for x, y, _name in series:
        if x.max() > max_x:
            max_x = x.max()
        if y.max() > max_y:
            max_y = y.max()

    BUFFER = 0  # fractional head-room above the data (was 0.25 at one point)
    fig = pyplot.figure(figsize=(6, 6))
    try:
        axes = pyplot.gca()
        pyplot.grid()
        axes.set_ylim([0, max_y + BUFFER * max_y])
        axes.set_xlim([1, max_x + BUFFER * max_x])
        pyplot.xlabel('Iterations')
        pyplot.ylabel('{}'.format(metric.upper()))
        pyplot.title(metric)
        for idx, (x, y, name) in enumerate(series):
            # Cycle the palette so more than four runs do not raise IndexError.
            color = pretty_colors[idx % len(pretty_colors)]
            pyplot.plot(x, y, linewidth=2, label=name, color=color)
        pyplot.legend(loc='upper right', shadow=True, fontsize='medium')
        pyplot.savefig(os.path.join(work_dir, '{}.png'.format(metric)))
    finally:
        # Release the figure: the 'all' mode calls plot() several times.
        pyplot.close(fig)
    print("Plotted {} series".format(len(series)))
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-w', dest='work_dir', type=str, required=True,
                            help='directory whose sub-directories contain plot.txt')
    arg_parser.add_argument('-p', dest='plot_type', type=str,
                            help='one of: loss, WER, CER, accu, all')
    # '--smotthing' is the original (misspelled) flag; it is kept as an alias
    # so existing command lines keep working.
    arg_parser.add_argument('-s', '--smoothing', '--smotthing', dest='smoothing',
                            type=int, default=0)
    if not len(sys.argv) > 1:
        arg_parser.print_help()
        sys.exit(0)
    args = arg_parser.parse_args()

    # (column, metric) pairs plotted for each -p value, in plotting order.
    # NOTE(review): the column indices used by 'all' differ from the
    # single-metric choices (e.g. loss is -4 alone but -7 under 'all'); one of
    # the two mappings is probably stale -- confirm against plot.txt's layout.
    plots_by_type = {
        'loss': [(-4, 'loss')],
        'WER': [(-3, 'WER')],
        'CER': [(-2, 'CER')],
        'accu': [(-1, 'accu')],
        'all': [(-7, 'loss'), (-2, 'WER'), (-1, 'CER'), (-3, 'accu'),
                (-4, 'CER_train')],
    }
    if args.plot_type not in plots_by_type:
        # format() rather than '+' so a missing -p (None) doesn't raise TypeError.
        print("{} metric not supported".format(args.plot_type))
    else:
        for column, metric in plots_by_type[args.plot_type]:
            plot(column, metric, args.smoothing, args.work_dir)