#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : ConvRNN.py
@Time : 2020/03/09
@Author : jhhuang96
@Mail : [email protected]
@Version : 1.0
@Description: ConvRNN cells (ConvGRU and ConvLSTM)
'''

import torch
import torch.nn as nn


class CGRU_cell(nn.Module):
"""
ConvGRU Cell
"""
def __init__(self, shape, input_channels, filter_size, num_features):
super(CGRU_cell, self).__init__()
self.shape = shape
self.input_channels = input_channels
        # the input-to-state and state-to-state convolutions share the same kernel size
self.filter_size = filter_size
self.num_features = num_features
self.padding = (filter_size - 1) // 2
self.conv1 = nn.Sequential(
nn.Conv2d(self.input_channels + self.num_features,
2 * self.num_features, self.filter_size, 1,
self.padding),
nn.GroupNorm(2 * self.num_features // 32, 2 * self.num_features))
self.conv2 = nn.Sequential(
nn.Conv2d(self.input_channels + self.num_features,
self.num_features, self.filter_size, 1, self.padding),
            nn.GroupNorm(self.num_features // 32, self.num_features))

    def forward(self, inputs=None, hidden_state=None, seq_len=10):
# seq_len=10 for moving_mnist
        if hidden_state is None:
            # start from an all-zero hidden state (note: hard-coded to a CUDA device)
            htprev = torch.zeros(inputs.size(1), self.num_features,
                                 self.shape[0], self.shape[1]).cuda()
else:
htprev = hidden_state
output_inner = []
for index in range(seq_len):
if inputs is None:
x = torch.zeros(htprev.size(0), self.input_channels,
self.shape[0], self.shape[1]).cuda()
else:
x = inputs[index, ...]
            combined_1 = torch.cat((x, htprev), 1)  # concat [X_t, H_t-1] along channels
            gates = self.conv1(combined_1)  # W * [X_t, H_t-1] -> update/reset gate pre-activations
zgate, rgate = torch.split(gates, self.num_features, dim=1)
# zgate, rgate = gates.chunk(2, 1)
z = torch.sigmoid(zgate)
r = torch.sigmoid(rgate)
            combined_2 = torch.cat((x, r * htprev),
                                   1)  # candidate: h~ = tanh(W * [X_t, r * H_t-1])
ht = self.conv2(combined_2)
ht = torch.tanh(ht)
htnext = (1 - z) * htprev + z * ht
output_inner.append(htnext)
htprev = htnext
return torch.stack(output_inner), htnext
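
# Usage sketch (illustrative only, not part of the original file): CGRU_cell expects
# inputs of shape (seq_len, batch, input_channels, H, W) and returns the stacked
# per-step outputs of shape (seq_len, batch, num_features, H, W) plus the last
# hidden state. A num_features that is a multiple of 32 keeps the GroupNorm group
# counts valid, and the cell allocates its initial state with .cuda(), so a GPU
# is assumed.
#
#     cell = CGRU_cell(shape=(64, 64), input_channels=1,
#                      filter_size=5, num_features=64).cuda()
#     x = torch.randn(10, 2, 1, 64, 64).cuda()
#     outputs, last_h = cell(x, hidden_state=None, seq_len=10)
#     # outputs: (10, 2, 64, 64, 64), last_h: (2, 64, 64, 64)

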
class CLSTM_cell(nn.Module):
"""ConvLSTMCell
"""
def __init__(self, shape, input_channels, filter_size, num_features):
super(CLSTM_cell, self).__init__()
self.shape = shape # H, W
self.input_channels = input_channels
self.filter_size = filter_size
self.num_features = num_features
        # 'same' padding so the output keeps the input's spatial size
        self.padding = (filter_size - 1) // 2
self.conv = nn.Sequential(
nn.Conv2d(self.input_channels + self.num_features,
4 * self.num_features, self.filter_size, 1,
self.padding),
            nn.GroupNorm(4 * self.num_features // 32, 4 * self.num_features))

    def forward(self, inputs=None, hidden_state=None, seq_len=10):
# seq_len=10 for moving_mnist
        if hidden_state is None:
            # start from all-zero hidden and cell states (hard-coded to a CUDA device)
            hx = torch.zeros(inputs.size(1), self.num_features, self.shape[0],
                             self.shape[1]).cuda()
            cx = torch.zeros(inputs.size(1), self.num_features, self.shape[0],
                             self.shape[1]).cuda()
else:
hx, cx = hidden_state
output_inner = []
for index in range(seq_len):
if inputs is None:
x = torch.zeros(hx.size(0), self.input_channels, self.shape[0],
self.shape[1]).cuda()
else:
x = inputs[index, ...]
combined = torch.cat((x, hx), 1)
            gates = self.conv(combined)  # gates: (B, 4 * num_features, H, W)
            # split into the four gate pre-activations: input, forget, cell, output
            ingate, forgetgate, cellgate, outgate = torch.split(
                gates, self.num_features, dim=1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
output_inner.append(hy)
hx = hy
cx = cy
return torch.stack(output_inner), (hy, cy)
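

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration; the shapes below are
    # assumptions matching Moving MNIST-style input, not values from the original
    # repository). Both cells consume (seq_len, batch, channels, H, W) tensors and
    # allocate their initial states with .cuda(), so a CUDA device is required.
    if torch.cuda.is_available():
        seq_len, batch = 10, 2
        frames = torch.randn(seq_len, batch, 1, 64, 64).cuda()

        lstm = CLSTM_cell(shape=(64, 64), input_channels=1,
                          filter_size=5, num_features=64).cuda()
        out, (h, c) = lstm(frames, hidden_state=None, seq_len=seq_len)
        print(out.shape, h.shape, c.shape)
        # torch.Size([10, 2, 64, 64, 64]) torch.Size([2, 64, 64, 64]) torch.Size([2, 64, 64, 64])

        gru = CGRU_cell(shape=(64, 64), input_channels=1,
                        filter_size=5, num_features=64).cuda()
        out, last_h = gru(frames, hidden_state=None, seq_len=seq_len)
        print(out.shape, last_h.shape)
        # torch.Size([10, 2, 64, 64, 64]) torch.Size([2, 64, 64, 64])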