outer_dynamics.py
executable file · 128 lines (88 loc) · 3.76 KB
import jax.numpy as jnp
import haiku as hk
import json
import sys
from jax import grad, local_device_count
from jax import random, jit
from jax.tree_util import tree_map
from utils.data import tfds_mnist_mlp, load_model, get_tfds_loaders
from utils.data import printLog, save_data, one_hot, load_checkpoint
from utils.data import make_result_dir, split, plot_img, par_one_hot, unsplit
from utils.functions import to_complex, L2, to_complex_dict, get_activation
from utils.functions import div_param_dict, dict_zeros_like, cosine_dicts
from utils.functions import distance_dicts
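
# Sanity-check experiment (purpose inferred from the code below): compare the
# BPTT gradients of an equilibrium model against (i) a one-shot equilibrium-
# propagation (EP) estimate averaged over N nudging phases and (ii) a
# continual estimate accumulated by complex-valued oscillating "outer"
# dynamics, tracking cosine similarity, distance, and residual convergence
# across N_periods oscillation periods.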
# The first command-line argument is the path to the JSON file containing
# the hyperparameters.
if len(sys.argv) >= 2:
    with open(sys.argv[1]) as f:
        args = json.load(f)
else:
    sys.exit('usage: python outer_dynamics.py <config.json>')
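
# A minimal example config (field names taken from the reads below; the values
# are illustrative only, and 'archi' is whatever models.mlp.xp_mlp expects):
# {
#   "save": false, "model_type": "xp_mlp", "task": "toy", "plot": false,
#   "seed": 0, "batch_size": 32, "in_size": 784, "n_targets": 10,
#   "archi": {...}, "beta": 0.1, "noise": 0.0,
#   "T1": 100, "T2": 20, "N": 8
# }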
path = make_result_dir(args, sys.argv[0]) if args['save'] else None
if args['model_type'] == 'xp_mlp':
    from models import mlp
# Unpacking argument values
seed = args['seed']
key = random.PRNGKey(seed)
batch_size = args['batch_size']
in_size = args['in_size']
n_targets = args['n_targets']
archi = hk.data_structures.to_immutable_dict(args['archi'])
beta = args['beta']
noise = args['noise']
T1, T2, N = args['T1'], args['T2'], args['N']
activation = get_activation(args)
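
# Build a single batch: a real MNIST batch, or a random toy problem with
# normally distributed inputs (scaled by their max) and random integer labels.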
if args['task'] == 'mnist':
    train_loader, _, _ = tfds_mnist_mlp(batch_size, seed)
    x, y = next(iter(train_loader))
elif args['task'] == 'toy':
    x = random.normal(key, (batch_size, in_size))
    x = x / x.max()
    key, subkey = random.split(key)
    y = random.randint(subkey, (batch_size,), 0, n_targets)
    print(x)
    print(y)
model = mlp.xp_mlp(archi, activation, n_targets, seed, noise=noise)
params = model.init_params(in_size)
if args.get('load_model') is not None:
    name = args['load_model']
    params = load_model(params, name, path)

if args['plot']:
    mlp.plot_params(params, path)
h = model.batched_init_neurons(x)
y = one_hot(y, n_targets)
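
# Reference gradients via backprop through time (BPTT) over the T1-step
# relaxation; the auxiliary output carries the equilibrium state h0.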
bptt_grads, aux = grad(model.batched_loss,
                       has_aux=True)(params, x, h, T1, y)
_, h0 = aux
printLog('Convergence for bptt grads:', path_to_result=path)
printLog(model.batched_eval_conv(params, x, y, h0, 0.0, False).mean(), path_to_result=path)
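
# One-shot EP gradient estimate over N discrete nudging phases, starting from
# the equilibrium h0, compared against the BPTT gradients by cosine similarity.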
EP1, key = model.batched_ep_gradN(params, x, h0, N, 1, beta, y, key)
cos_ep1 = cosine_dicts(bptt_grads, EP1)
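
# Switch to the complex-valued oscillating dynamics: neurons, parameters, and
# inputs are cast to complex, and the nudging strength is rotated through N
# phases per period via cbeta = beta * exp(-2j*pi / N).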
h = to_complex(h) # start from scratch
cparams = to_complex_dict(params)
x = x.astype(jnp.complex64)
cbeta = beta * jnp.exp(- (2j * jnp.pi)/N)
grads = dict_zeros_like(params)
continual_cosine = [0.0]
continual_dist = [distance_dicts(bptt_grads, grads)]
printLog('Residual convergence after each period:', path_to_result=path)
N_periods = 40
mean_log_L2s = []
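
# Each iteration runs one full oscillation period (T2 steps through the N
# phases), accumulating one gradient estimate per phase into `grads`; after
# t+1 periods the accumulator is normalised by N*(t+1) before comparison.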
for t in range(N_periods):
    h, grads, key = model.batched_outer_fwd(cparams, x, h, T2, N, beta, y,
                                            grads, key)
    L2s = model.batched_eval_conv(params, x, y, h, cbeta, True).mean()
    mean_log_L2s.append(jnp.log10(L2s))
    printLog(jnp.log10(L2s), path_to_result=path)
    norm_grads = div_param_dict(grads, N * (t + 1))
    continual_cosine.append(cosine_dicts(bptt_grads, norm_grads))
    continual_dist.append(distance_dicts(bptt_grads, norm_grads))
printLog('Final cosine with grad:\t', continual_cosine[-1], path_to_result=path)
printLog('Final dist with grad:\t', continual_dist[-1], path_to_result=path)
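
# Bundle the traces (cosines, distances, residual convergence) and save them
# to the result directory.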
data = {'cont_cos': continual_cosine,
        'cont_dist': continual_dist,
        'cos_ep1': cos_ep1,
        'mean_log_L2s': mean_log_L2s,
        'path': path}
save_data(data, path, 'data')