planet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import numpy as np
from typing import Any, List, Tuple, Optional
TensorType = Any


def atanh(x):
    return 0.5 * torch.log((1 + x) / (1 - x))


class TanhBijector(torch.distributions.Transform):
    def __init__(self):
        super().__init__()
        self.bijective = True
        self.domain = torch.distributions.constraints.real
        self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)

    @property
    def sign(self):
        return 1.

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y: torch.Tensor):
        y = torch.where(
            (torch.abs(y) <= 1.),
            torch.clamp(y, -0.99999997, 0.99999997),
            y)
        y = atanh(y)
        return y

    def log_abs_det_jacobian(self, x, y):
        return 2. * (np.log(2) - x - F.softplus(-2. * x))
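

# Illustrative sketch, not part of the original file: the bijector above is
# typically used to squash an unbounded Normal into (-1, 1). ActionDecoder
# below builds its policy distribution in exactly this way.
def _example_tanh_squashed_normal():
    base = td.Normal(torch.zeros(3), torch.ones(3))
    squashed = td.TransformedDistribution(base, [TanhBijector()])
    action = squashed.rsample()              # every component lies in (-1, 1)
    return action, squashed.log_prob(action)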


class Reshape(nn.Module):
    def __init__(self, shape: List):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class Encoder(nn.Module):
    """As mentioned in the paper, a VAE-style encoder is used to compute the
    state posterior needed for parameter learning in RSSM. The training
    objective is a variational bound on the data likelihood. Losses are
    written only for observations; the reward losses follow the same pattern.
    """

    def __init__(self,
                 depth: int = 32,
                 input_channels: Optional[int] = 3):
        """Initializes the parameters of the Encoder.

        Args:
            depth (int): Number of channels in the first convolution layer.
            input_channels (int): Number of channels in the input observation.
        """
        super(Encoder, self).__init__()
        self.depth = depth
        self.input_channels = input_channels
        self.encoder = nn.Sequential(
            nn.Conv2d(self.input_channels, self.depth, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(self.depth, self.depth * 2, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(self.depth * 2, self.depth * 4, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(self.depth * 4, self.depth * 8, 4, stride=2),
            nn.ReLU(),
        )

    def forward(self, x):
        """Flattens the input observation [batch, horizon, 3, 64, 64] into
        shape [batch * horizon, 3, 64, 64] before feeding it to the
        convolutional encoder, then restores the leading dimensions.
        """
        orig_shape = x.shape
        x = x.reshape(-1, *x.shape[-3:])
        x = self.encoder(x)
        x = x.reshape(*orig_shape[:-3], -1)
        return x
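

# Shape sketch (illustrative, not part of the original file): with the default
# depth of 32 and 3x64x64 observations, the four stride-2 convolutions reduce
# each frame to 256 channels of 2x2, i.e. a 1024-dim embedding per frame.
def _example_encoder_shapes():
    encoder = Encoder(depth=32, input_channels=3)
    obs = torch.zeros(2, 5, 3, 64, 64)    # [batch, horizon, C, H, W]
    embed = encoder(obs)                  # -> [2, 5, 1024]
    return embed.shape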


class Decoder(nn.Module):
    """Takes latent features from the RSSM and decodes them back into images.
    It is mainly used for computing the reconstruction loss.
    """

    def __init__(self,
                 input_size: int,
                 depth: int = 32,
                 shape: Tuple[int] = (3, 64, 64)):
        super(Decoder, self).__init__()
        self.depth = depth
        self.shape = shape
        self.decoder = nn.Sequential(
            nn.Linear(input_size, 32 * self.depth),
            Reshape([-1, 32 * self.depth, 1, 1]),
            nn.ConvTranspose2d(32 * self.depth, 4 * self.depth, 5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(4 * self.depth, 2 * self.depth, 5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * self.depth, self.depth, 6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(self.depth, self.shape[0], 6, stride=2),
        )

    def forward(self, x):
        orig_shape = x.shape
        x = self.decoder(x)
        reshape_size = orig_shape[:-1] + self.shape
        mean = x.view(*reshape_size)
        return td.Independent(td.Normal(mean, 1), len(self.shape))
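

# Illustrative sketch, not part of the original file: the decoder returns a
# diagonal Normal over pixels, so the reconstruction loss is simply the
# negative log-probability of the observed frames under that distribution.
def _example_reconstruction_loss():
    stoch, deter = 30, 200
    decoder = Decoder(input_size=stoch + deter, depth=32)
    features = torch.zeros(2, 5, stoch + deter)   # [batch, horizon, feature]
    obs = torch.zeros(2, 5, 3, 64, 64)
    image_dist = decoder(features)                # Independent Normal over pixels
    return -image_dist.log_prob(obs).mean()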


class ActionDecoder(nn.Module):
    """ActionDecoder is the policy module in Dreamer.

    It outputs a distribution parameterized by mean and std, later to be
    transformed by a custom TanhBijector.
    """

    def __init__(self,
                 input_size: int,
                 action_size: int,
                 layers: int,
                 units: int,
                 dist: str = "tanh_normal",
                 min_std: float = 1e-4,
                 init_std: float = 5.0,
                 mean_scale: float = 5.0):
        super(ActionDecoder, self).__init__()
        self.layrs = layers
        self.units = units
        self.dist = dist
        self.act = nn.ReLU
        self.min_std = min_std
        self.init_std = init_std
        self.mean_scale = mean_scale
        self.action_size = action_size
        self.layers = []
        self.softplus = nn.Softplus()

        # MLP construction
        cur_size = input_size
        for _ in range(self.layrs):
            self.layers.extend([nn.Linear(cur_size, self.units), self.act()])
            cur_size = self.units
        self.layers.append(nn.Linear(cur_size, 2 * action_size))
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        raw_init_std = np.log(np.exp(self.init_std) - 1)
        x = self.model(x)
        mean, std = torch.chunk(x, 2, dim=-1)
        mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
        std = self.softplus(std + raw_init_std) + self.min_std
        dist = td.Normal(mean, std)
        transforms = [TanhBijector()]
        dist = td.transformed_distribution.TransformedDistribution(
            dist, transforms)
        dist = td.Independent(dist, 1)
        return dist
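

# Illustrative sketch, not part of the original file: sampling a squashed,
# reparameterized action from the policy head. The sizes here are arbitrary.
def _example_policy_sample():
    actor = ActionDecoder(input_size=230, action_size=6, layers=4, units=400)
    feat = torch.zeros(8, 230)            # [batch, stoch + deter]
    dist = actor(feat)
    action = dist.rsample()               # [8, 6], each entry in (-1, 1)
    return action, dist.log_prob(action)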


class DenseDecoder(nn.Module):
    """Fully connected network that outputs a distribution, used for
    computing log_prob in the Dreamer loss.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 layers: int,
                 units: int,
                 dist: str = "normal"):
        """Initializes the FC network.

        Args:
            input_size (int): Input size of the network
            output_size (int): Output size of the network
            layers (int): Number of layers in the network
            units (int): Size of the hidden layers
            dist (str): Output distribution, parameterized by the FC output
                logits.
        """
        super().__init__()
        self.layrs = layers
        self.units = units
        self.act = nn.ELU
        self.dist = dist
        self.input_size = input_size
        self.output_size = output_size
        self.layers = []

        cur_size = input_size
        for _ in range(self.layrs):
            self.layers.extend([nn.Linear(cur_size, self.units), self.act()])
            cur_size = units
        self.layers.append(nn.Linear(cur_size, output_size))
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        x = self.model(x)
        if self.output_size == 1:
            x = torch.squeeze(x)
        if self.dist == "normal":
            output_dist = td.Normal(x, 1)
        elif self.dist == "binary":
            output_dist = td.Bernoulli(logits=x)
        else:
            raise NotImplementedError("Distribution type not implemented!")
        return td.Independent(output_dist, 0)
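

# Illustrative sketch, not part of the original file: the reward and value
# heads are DenseDecoders with output_size=1, so their training signal is the
# log-probability of the observed reward under a unit-variance Normal.
def _example_reward_head():
    reward_model = DenseDecoder(input_size=230, output_size=1, layers=2, units=400)
    feat = torch.zeros(16, 230)
    reward_dist = reward_model(feat)          # Normal with fixed std of 1
    rewards = torch.zeros(16)
    return -reward_dist.log_prob(rewards).mean()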


class RSSM(nn.Module):
    """RSSM is the core recurrent part of the PlaNet module. It consists of
    two networks, one (obs) to calculate posterior beliefs and states and
    the second (img) to calculate prior beliefs and states. The prior network
    takes in the previous state and action, while the posterior network takes
    in the previous state, action, and a latent embedding of the most recent
    observation.
    """

    def __init__(self,
                 action_size: int,
                 embed_size: int,
                 stoch: int = 30,
                 deter: int = 200,
                 hidden: int = 200):
        """Initializes the RSSM.

        Args:
            action_size (int): Action space size
            embed_size (int): Size of the ConvEncoder embedding
            stoch (int): Size of the stochastic hidden state
            deter (int): Size of the deterministic hidden state
            hidden (int): General size of hidden layers
        """
        super().__init__()
        self.stoch_size = stoch
        self.deter_size = deter
        self.hidden_size = hidden
        self.act = nn.ELU

        self.obs1 = nn.Linear(embed_size + deter, hidden)
        self.obs2 = nn.Linear(hidden, 2 * stoch)

        self.cell = nn.GRUCell(self.hidden_size, hidden_size=self.deter_size)
        self.img1 = nn.Linear(stoch + action_size, hidden)
        self.img2 = nn.Linear(deter, hidden)
        self.img3 = nn.Linear(hidden, 2 * stoch)

        self.softplus = nn.Softplus

    def get_initial_state(self, batch_size: int, device) -> List[TensorType]:
        """Returns the initial state for the RSSM, which consists of mean,
        std for the stochastic state, the sampled stochastic hidden state
        (from mean, std), and the deterministic hidden state, which is
        pushed through the GRUCell.

        Args:
            batch_size (int): Batch size for initial state

        Returns:
            List of tensors
        """
        return [
            torch.zeros(batch_size, self.stoch_size).to(device),
            torch.zeros(batch_size, self.stoch_size).to(device),
            torch.zeros(batch_size, self.stoch_size).to(device),
            torch.zeros(batch_size, self.deter_size).to(device),
        ]

    def observe(self,
                embed: TensorType,
                action: TensorType,
                state: List[TensorType] = None
                ) -> Tuple[List[TensorType], List[TensorType]]:
        """Returns the corresponding states from the embedding from ConvEncoder
        and actions. This is accomplished by rolling out the RNN from the
        starting state through each index of embed and action, saving all
        intermediate states in between.

        Args:
            embed (TensorType): ConvEncoder embedding
            action (TensorType): Actions
            state (List[TensorType]): Initial state before rollout

        Returns:
            Posterior states and prior states (both List[TensorType])
        """
        if state is None:
            state = self.get_initial_state(action.size()[0], action.device)

        if embed.dim() <= 2:
            embed = torch.unsqueeze(embed, 1)
        if action.dim() <= 2:
            action = torch.unsqueeze(action, 1)

        embed = embed.permute(1, 0, 2)
        action = action.permute(1, 0, 2)

        priors = [[] for i in range(len(state))]
        posts = [[] for i in range(len(state))]
        last = (state, state)
        for index in range(len(action)):
            # Tuple of post and prior
            last = self.obs_step(last[0], action[index], embed[index])
            [o.append(s) for s, o in zip(last[0], posts)]
            [o.append(s) for s, o in zip(last[1], priors)]

        prior = [torch.stack(x, dim=0) for x in priors]
        post = [torch.stack(x, dim=0) for x in posts]
        prior = [e.permute(1, 0, 2) for e in prior]
        post = [e.permute(1, 0, 2) for e in post]
        return post, prior

    def imagine(self, action: TensorType,
                state: List[TensorType] = None) -> List[TensorType]:
        """Imagines the trajectory starting from state through a list of actions.
        Similar to observe(), requires rolling out the RNN for each timestep.

        Args:
            action (TensorType): Actions
            state (List[TensorType]): Starting state before rollout

        Returns:
            Prior states
        """
        if state is None:
            state = self.get_initial_state(action.size()[0], action.device)

        action = action.permute(1, 0, 2)

        indices = range(len(action))
        priors = [[] for _ in range(len(state))]
        last = state
        for index in indices:
            last = self.img_step(last, action[index])
            [o.append(s) for s, o in zip(last, priors)]

        prior = [torch.stack(x, dim=0) for x in priors]
        prior = [e.permute(1, 0, 2) for e in prior]
        return prior

    def obs_step(
            self, prev_state: TensorType, prev_action: TensorType,
            embed: TensorType) -> Tuple[List[TensorType], List[TensorType]]:
        """Runs through the posterior model and returns the posterior state.

        Args:
            prev_state (TensorType): The previous state
            prev_action (TensorType): The previous action
            embed (TensorType): Embedding from ConvEncoder

        Returns:
            Post and prior state
        """
        prior = self.img_step(prev_state, prev_action)
        x = torch.cat([prior[3], embed], dim=-1)
        x = self.obs1(x)
        x = self.act()(x)
        x = self.obs2(x)
        mean, std = torch.chunk(x, 2, dim=-1)
        std = self.softplus()(std) + 0.1
        stoch = self.get_dist(mean, std).rsample()
        post = [mean, std, stoch, prior[3]]
        return post, prior

    def img_step(self, prev_state: TensorType,
                 prev_action: TensorType) -> List[TensorType]:
        """Runs through the prior model and returns the prior state.

        Args:
            prev_state (TensorType): The previous state
            prev_action (TensorType): The previous action

        Returns:
            Prior state
        """
        x = torch.cat([prev_state[2], prev_action], dim=-1)
        x = self.img1(x)
        x = self.act()(x)
        deter = self.cell(x, prev_state[3])
        x = deter
        x = self.img2(x)
        x = self.act()(x)
        x = self.img3(x)
        mean, std = torch.chunk(x, 2, dim=-1)
        std = self.softplus()(std) + 0.1
        stoch = self.get_dist(mean, std).rsample()
        return [mean, std, stoch, deter]

    def get_feature(self, state: List[TensorType]) -> TensorType:
        # Constructs feature for input to reward, decoder, actor, critic
        return torch.cat([state[2], state[3]], dim=-1)

    def get_dist(self, mean: TensorType, std: TensorType) -> TensorType:
        return td.Normal(mean, std)
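

# Illustrative sketch, not part of the original file: rolling the RSSM over a
# short sequence. Posterior and prior are each returned as
# [mean, std, stoch, deter] lists with shapes [batch, horizon, size].
def _example_rssm_rollout():
    batch, horizon, action_size, embed_size = 4, 10, 6, 1024
    rssm = RSSM(action_size=action_size, embed_size=embed_size)
    embed = torch.zeros(batch, horizon, embed_size)
    actions = torch.zeros(batch, horizon, action_size)
    post, prior = rssm.observe(embed, actions)
    features = rssm.get_feature(post)     # [batch, horizon, stoch + deter]
    return features.shape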


# Dreamer Model
class PLANet(nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__()
        self.depth = model_config["depth_size"]
        self.deter_size = model_config["deter_size"]
        self.stoch_size = model_config["stoch_size"]
        self.hidden_size = model_config["hidden_size"]
        self.action_size = action_space.shape[0]

        self.encoder = Encoder(self.depth)
        self.decoder = Decoder(
            self.stoch_size + self.deter_size, depth=self.depth)
        self.reward = DenseDecoder(self.stoch_size + self.deter_size, 1, 2,
                                   self.hidden_size)
        self.dynamics = RSSM(
            self.action_size,
            32 * self.depth,
            stoch=self.stoch_size,
            deter=self.deter_size)
        self.actor = ActionDecoder(self.stoch_size + self.deter_size,
                                   self.action_size, 4, self.hidden_size)
        self.value = DenseDecoder(self.stoch_size + self.deter_size, 1, 3,
                                  self.hidden_size)
        self.state = None

    def policy(self, obs: TensorType, state: List[TensorType], explore=True
               ) -> Tuple[TensorType, List[float], List[TensorType]]:
        """Returns the action. Runs through the encoder, recurrent model,
        and policy to obtain the action.
        """
        if state is None:
            self.state = self.get_initial_state(obs.shape[0], obs.device)
        else:
            self.state = state
        post = self.state[:4]
        action = self.state[4]

        embed = self.encoder(obs)
        post, _ = self.dynamics.obs_step(post, action, embed)
        feat = self.dynamics.get_feature(post)

        action_dist = self.actor(feat)
        if explore:
            action = action_dist.sample()
        else:
            # Approximate the mode of the squashed Gaussian by averaging samples
            samples = [action_dist.sample() for _ in range(1000)]
            action = torch.mean(torch.stack(samples), dim=0)
        if action.ndim == 1:
            action = action.unsqueeze(0)
        logp = action_dist.log_prob(action)

        self.state = post + [action]
        return action, logp, self.state

    def imagine_ahead(self, state: List[TensorType],
                      horizon: int) -> TensorType:
        """Given a batch of states, rolls out the prior for `horizon` steps
        and returns the imagined features.
        """
        start = []
        for s in state:
            s = s.contiguous().detach()
            shpe = [-1] + list(s.size())[2:]
            start.append(s.view(*shpe))

        def next_state(state):
            feature = self.dynamics.get_feature(state).detach()
            action = self.actor(feature).rsample()
            next_state = self.dynamics.img_step(state, action)
            return next_state

        last = start
        outputs = [[] for i in range(len(start))]
        for _ in range(horizon):
            last = next_state(last)
            [o.append(s) for s, o in zip(last, outputs)]
        outputs = [torch.stack(x, dim=0) for x in outputs]

        imag_feat = self.dynamics.get_feature(outputs)
        return imag_feat

    def get_initial_state(self, batch_size: int, device) -> List[TensorType]:
        self.state = self.dynamics.get_initial_state(batch_size, device) + [
            torch.zeros(batch_size, self.action_size).to(device)
        ]
        return self.state

    def value_function(self) -> TensorType:
        return None
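

# Illustrative sketch, not part of the original file: wiring the full model
# together for a single policy step. The config keys mirror the ones read in
# PLANet.__init__; the action space only needs a `.shape` attribute here.
def _example_planet_policy_step():
    from types import SimpleNamespace
    config = {"depth_size": 32, "deter_size": 200,
              "stoch_size": 30, "hidden_size": 400}
    action_space = SimpleNamespace(shape=(6,))
    model = PLANet(None, action_space, None, config, "planet")
    obs = torch.zeros(1, 3, 64, 64)
    state = model.get_initial_state(1, obs.device)
    action, logp, state = model.policy(obs, state)
    return action.shape, logp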


class FreezeParameters:
    """Context manager that temporarily disables gradients for the given
    parameters and restores their original requires_grad flags on exit.
    """

    def __init__(self, parameters):
        self.parameters = parameters
        self.param_states = [p.requires_grad for p in self.parameters]

    def __enter__(self):
        for param in self.parameters:
            param.requires_grad = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        for i, param in enumerate(self.parameters):
            param.requires_grad = self.param_states[i]
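

# Illustrative sketch, not part of the original file: FreezeParameters is used
# as a context manager so that an actor update can backpropagate through
# imagined rollouts without also accumulating gradients in the world model or
# the value head.
def _example_frozen_actor_update(model: PLANet, post, horizon: int = 15):
    frozen = list(model.dynamics.parameters()) + list(model.value.parameters())
    with FreezeParameters(frozen):
        imag_feat = model.imagine_ahead(post, horizon)
        # Maximize the predicted value of imagined states
        actor_loss = -model.value(imag_feat).mean.mean()
    return actor_loss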