-
Notifications
You must be signed in to change notification settings - Fork 12
/
experiment.py
126 lines (113 loc) · 4.4 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import numpy as np
import scipy.io
from datetime import datetime
import utils
import time
import logging
# Module-level logger. getLogger() is called without a name, so this is the
# root logger. NOTE(review): its handlers are presumably installed by
# utils.create_logger (called from Experiment.__init__) — confirm.
logger = logging.getLogger()


class Experiment(object):
    """
    Experiment. For easy saving and loading.

    Registered components are written to / read from one ``.mat`` file each
    (via scipy.io) inside ``<dump_path>/<name>``. A component is either a
    single value holder exposing ``name``, ``get_value()`` and ``set_value()``,
    or an object whose ``params`` attribute is an iterable of such holders.
    """

    def __init__(self, name, dump_path, create_logger=True):
        """
        Initialize the experiment.

        Args:
            name: experiment name; also the sub-directory created under
                ``dump_path``.
            dump_path: root directory for experiment outputs.
            create_logger: if True, call ``utils.create_logger`` on
                ``<experiment_path>/experiment.log`` and keep the returned
                object as ``self.log_formatter``.
        """
        self.name = name
        self.start_time = datetime.now()
        self.dump_path = dump_path
        self.experiment_path = os.path.join(dump_path, name)
        self.components = {}
        # Always defined, so reset_time() also works without a logger.
        self.log_formatter = None
        try:
            os.makedirs(self.experiment_path)
        except OSError:
            # The directory may already exist (or be created concurrently by
            # another run); only propagate genuine failures.
            if not os.path.isdir(self.experiment_path):
                raise
        self.logs_path = os.path.join(self.experiment_path, "experiment.log")
        if create_logger:
            self.log_formatter = utils.create_logger(self.logs_path)

    def reset_time(self):
        """
        Reset start time (for logs).

        Sets ``log_formatter.start_time`` to ``time.time()``. Does nothing if
        no logger was created (``create_logger=False``).
        """
        log_formatter = getattr(self, 'log_formatter', None)
        if log_formatter is not None:
            log_formatter.start_time = time.time()

    def add_component(self, component):
        """
        Add a new component to the network experiment.

        Components are keyed by ``component.name``.

        Raises:
            Exception: if a component with the same name is already registered.
        """
        if component.name in self.components:
            raise Exception("%s is already a component of this network!"
                            % component.name)
        self.components[component.name] = component

    @staticmethod
    def _component_filename(model_name, name):
        """Return the .mat file name of component `name`, prefixed by `model_name` if set."""
        if model_name:
            return "%s_%s.mat" % (model_name, name)
        return "%s.mat" % name

    @staticmethod
    def _restore_value(shared, saved, skip_invalid):
        """Set `shared` from the loaded array `saved` after checking that the sizes match."""
        current = shared.get_value()
        if saved.size != current.size:
            shape_message = 'Invalid component shape for %s: expected %s, but found %s' % (
                shared, current.shape, saved.shape)
            if skip_invalid:
                logger.warning(shape_message + ' - Skipping...')
                return
            raise Exception(shape_message)
        # Restore the in-memory shape: loadmat returns arrays that are at least
        # 2-D, so e.g. a saved (n,) vector comes back as (1, n).
        shared.set_value(np.reshape(saved, current.shape).astype(np.float32))

    def dump(self, message, model_name=""):
        """
        Write components values.

        Each component is saved to its own ``.mat`` file in ``experiment_path``
        (prefixed by ``model_name`` if given). Single-value components are
        stored under ``component.name``; components with ``params`` store one
        entry per ``param.name``. ``message`` is logged once all files are
        written.
        """
        for name, component in self.components.items():
            component_path = os.path.join(
                self.experiment_path, self._component_filename(model_name, name))
            if hasattr(component, 'params'):
                component_values = {
                    param.name: param.get_value()
                    for param in component.params
                }
            else:
                component_values = {component.name: component.get_value()}
            scipy.io.savemat(component_path, component_values)
        logger.info(message)

    def load(self, model_name="", experiment_name="", skip_invalid=False):
        """
        Load components values.

        Args:
            model_name: optional file-name prefix, as used by ``dump``.
            experiment_name: if given, read from
                ``<dump_path>/<experiment_name>`` instead of this experiment's
                own directory.
            skip_invalid: if True, a saved value whose size differs from the
                current value is logged and skipped; otherwise an Exception is
                raised. Applies to single-value components and to params.

        Loaded arrays are reshaped to the current value's shape and cast to
        float32 before ``set_value``.
        """
        if experiment_name:
            experiment_path = os.path.join(self.dump_path, experiment_name)
        else:
            experiment_path = self.experiment_path
        for name, component in self.components.items():
            logger.info('Reloading %s...', name)
            component_path = os.path.join(
                experiment_path, self._component_filename(model_name, name))
            component_values = scipy.io.loadmat(component_path)
            if hasattr(component, 'params'):
                for param in component.params:
                    self._restore_value(
                        param, component_values[param.name], skip_invalid)
            else:
                self._restore_value(
                    component, component_values[component.name], skip_invalid)
class Sequential(object):
    """
    Create sequential networks.

    ``link`` applies the modules in order. ``params`` concatenates the
    ``params`` of every module.
    """

    def __init__(self, *modules):
        """
        Initialize a sequential network from the given modules, in order.
        """
        self.modules = list(modules)

    @property
    def params(self):
        """
        Return the parameters of all objects in the sequence, as a flat list.

        Built in a single pass: ``sum(lists, [])`` is quadratic and only works
        when every ``module.params`` is a list. Any iterable is accepted here.
        """
        return [param for module in self.modules for param in module.params]

    def add_module(self, module):
        """
        Append a module to the sequential network.
        """
        self.modules.append(module)

    def link(self, input):
        """
        Propagate the input through the network and return the output.

        Each module's ``link`` receives the previous module's output. With no
        modules, the input is returned unchanged.
        """
        output = input
        for module in self.modules:
            output = module.link(output)
        return output