[Frontend, pytorch] Vc/pytorch lstm #8447

Merged · 11 commits · Jul 20, 2021
294 changes: 294 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -41,6 +41,7 @@
from . import qnn_torch
from .common import AttrCvt, get_relay_op
from .common import infer_value as _infer_value
from .common import infer_shape as _infer_shape
from .common import infer_value_simulated as _infer_value_simulated
from .common import try_infer_value
from .pytorch_utils import is_version_greater_than
@@ -2329,6 +2330,298 @@ def flip(self, inputs, input_types):
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])

def lstm_cell(self, input_seqs, hidden, weights, has_proj=False):
if has_proj:
assert len(weights) == 5
else:
assert len(weights) == 4
outputs_list = []
        # Default activation types
f_act = _op.sigmoid
g_act = _op.tanh
h_act = _op.tanh

        # Input hidden and cell states
H_t = hidden[0] # (batch, hidden_size)
C_t = hidden[1] # (batch, hidden_size)
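        # Each time step applies the standard LSTM cell update (PyTorch gate order: i, f, g, o):
        #   i_t = sigmoid(Wi_i x_t + Wh_i h_(t-1) + b_i)
        #   f_t = sigmoid(Wi_f x_t + Wh_f h_(t-1) + b_f)
        #   g_t = tanh(Wi_g x_t + Wh_g h_(t-1) + b_g)
        #   o_t = sigmoid(Wi_o x_t + Wh_o h_(t-1) + b_o)
        #   c_t = f_t * c_(t-1) + i_t * g_t
        #   h_t = o_t * tanh(c_t)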
for x_t in input_seqs:
# x_t shape = (batch, feature size)
# gates shape = (batch, 4 * hidden_size)
gates = _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1])
# Add biases
if weights[2] is not None:
gates += weights[2]
if weights[3] is not None:
gates += weights[3]
i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size)

i = f_act(i)
f = f_act(f)
c = g_act(c)
o = f_act(o)

C = f * C_t + i * c
H = o * h_act(C)

if has_proj:
H = _op.nn.dense(H, weights[4])

H_t = H
C_t = C
outputs_list.append(H) # [seq_num, (batch, hidden_size)]
hidden_outputs = (H_t, C_t)

return (outputs_list, hidden_outputs)

def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False):
fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj)

        # reverse the input sequence for the backward direction
        seq_len = len(input_seq)
        rev_input_seq = input_seq[::-1]  # [seq_num, (batch, feature_size)]
rev_outputs = self.lstm_cell(rev_input_seq, hidden_pair[1], weights_pair[1], has_proj)

final_outputs = [] # [seq_num, (batch, 2 * hidden_size)]
for j in range(seq_len):
final_outputs.append(
_op.concatenate([fw_outputs[0][j], rev_outputs[0][seq_len - 1 - j]], -1)
)

return final_outputs, (fw_outputs[1], rev_outputs[1])

def lstm_layers(
self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
):
hidden_layers_num = len(hiddens)
assert len(weights) == hidden_layers_num

        # split the input sequence into individual time-step samples
input_seqs = self.unbind((input_data, 0), dtype) # [seq_num, (batch, feature_size)]
output_hiddens = []
for k in range(hidden_layers_num):
hiddens_input = hiddens[k]
weights_input = weights[k]

outputs = (
self.bidir_lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
if bidirectional
else self.lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
)

output_hiddens.append(outputs[1])
            # the next layer consumes this layer's outputs:
            # input_seqs shape = [seq_num, (batch, hidden_size)] or
            # [seq_num, (batch, 2 * hidden_size)] for bidirectional
            input_seqs = outputs[0]

            # TODO (vvchernov): in the pytorch implementation train is also checked,
            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
            # /aten/src/ATen/native/RNN.cpp#L1054
if dropout_p != 0 and k < hidden_layers_num - 1:
# for input in input_seqs:
# input = _op.dropout(input, dropout_p)
                raise NotImplementedError("Dropout for LSTM is not supported yet!")
final_hiddens = []
if bidirectional:
for i in range(hidden_layers_num):
final_hiddens.append(output_hiddens[i][0])
final_hiddens.append(output_hiddens[i][1])
else:
final_hiddens = output_hiddens

return _op.stack(input_seqs, 0), final_hiddens

def lstm(self, inputs, input_types):
"""
        Description of LSTM in PyTorch:
        https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        Native implementation for torch versions below 1.8.0 (projection is not supported):
        https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \
            src/ATen/native/RNN.cpp#L1396
        Native implementation for torch versions 1.8.0 and higher (projection is supported):
        https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
"""
# TODO (vvchernov): support dropout
assert len(inputs) == 9, "Input of size 9 is expected"
        # Unpack inputs; optional inputs that were not provided are None.
_X = inputs[0]
# _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)

hidden_states = inputs[1]
assert len(hidden_states) == 2, "lstm expects two hidden states"
h_0 = hidden_states[0]
c_0 = hidden_states[1]
# H0 shape (hidden_layers_num, batch, proj_size) if projection
# else (hidden_layers_num, batch, hidden_size)
# C0 shape (hidden_layers_num, batch, hidden_size)

_weights = inputs[2]
# If no projection
# Wi layer[0] shape (4 * hidden_size, feature_size)
# Wh layer[0] shape (4 * hidden_size, hidden_size)
# Bi layer[0] shape (4 * hidden_size)
# Bh layer[0] shape (4 * hidden_size)

# Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions)
# Wh layer[>0] shape (4 * hidden_size, hidden_size)
# Bi layer[>0] shape (4 * hidden_size)
# Bh layer[>0] shape (4 * hidden_size)

# If projection
# Wi layer[0] shape (4 * hidden_size, feature_size)
# Wh layer[0] shape (4 * hidden_size, proj_size)
# Bi layer[0] shape (4 * hidden_size)
# Bh layer[0] shape (4 * hidden_size)
# P layer[0] shape (proj_size, hidden_size)

# Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions)
# Wh layer[>0] shape (4 * hidden_size, proj_size)
# Bi layer[>0] shape (4 * hidden_size)
# Bh layer[>0] shape (4 * hidden_size)
# P layer[>0] shape (proj_size, hidden_size)

# Scalar inputs
has_biases = inputs[3]
num_layers = inputs[4]
dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout
# train = inputs[6]
bidirectional = inputs[7]
batch_first = inputs[8]

num_directions = 1
if bidirectional:
num_directions = 2

rsd = len(_weights) % num_layers
assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
rsd = (len(_weights) / num_layers) % num_directions
        assert (
            rsd == 0
        ), "The number of weights per layer must be a multiple of the number of directions!"
has_proj = False
proj_size = 0
weights_num = int(len(_weights) / num_layers / num_directions)
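        # Per layer and direction PyTorch packs (Wi, Wh, Bi, Bh) when has_biases is set,
        # or just (Wi, Wh) otherwise; an extra projection matrix P is appended when
        # proj_size > 0, giving 5 (or 3) weight tensors per layer and direction.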
if has_biases:
if weights_num == 5:
has_proj = True
proj_size = _infer_shape(_weights[4])[0]
else:
                assert weights_num == 4, "Expected 4 weight tensors (Wi, Wh, Bi, Bh) per layer"
else:
if weights_num == 3:
has_proj = True
proj_size = _infer_shape(_weights[2])[0]
else:
                assert weights_num == 2, "Expected 2 weight tensors (Wi, Wh) per layer"

weights = []
if has_biases:
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_weights = []
rev_weights = []
for j in range(weights_num):
fw_weights.append(_weights[i + j])
rev_weights.append(_weights[i + j + weights_num])
weights.append((fw_weights, rev_weights))
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), weights_num):
fw_weights = []
for j in range(weights_num):
fw_weights.append(_weights[i + j])
weights.append(fw_weights)
else:
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_weights = []
rev_weights = []
k = i + weights_num
if has_proj:
fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
rev_weights = [_weights[k], _weights[k + 1], None, None, _weights[k + 2]]
else:
fw_weights = [_weights[i], _weights[i + 1], None, None]
rev_weights = [_weights[k], _weights[k + 1], None, None]
weights.append((fw_weights, rev_weights))
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), weights_num):
if has_proj:
fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
else:
fw_weights = [_weights[i], _weights[i + 1], None, None]
weights.append(fw_weights)
assert (
len(weights) == num_layers
), "For stacked LSTM number of weights tuples should be the same as number of layers!"

X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
        # TODO (vvchernov): Which data type should be used: the input's or the weights'?
        # Alternatively, _infer_type(X).checked_type.dtype could be used
X_dtype = input_types[0]
X_shape = _infer_shape(X) # (seq_num, batch, feature_size)

        hidden_size = _infer_shape(_weights[0])[0] // 4  # Wi stacks the 4 gates along dim 0
batch_size = X_shape[1]

# Initialize hidden states if not provided.
layers_h = []
layers_c = []
hidden_layers_num = num_directions * num_layers
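        # one initial (h, c) state is needed per layer and per direction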
if h_0 is None:
if has_proj:
h_0 = _op.zeros((batch_size, proj_size), X_dtype)
else:
h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_h.append(h_0)
else:
layers_h = self.unbind((h_0, 0), X_dtype)
if c_0 is None:
c_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_c.append(c_0)
else:
layers_c = self.unbind((c_0, 0), X_dtype)

hiddens = []
for i in range(num_layers):
if bidirectional:
hiddens.append(
((layers_h[2 * i], layers_c[2 * i]), (layers_h[2 * i + 1], layers_c[2 * i + 1]))
)
else:
hiddens.append((layers_h[i], layers_c[i]))

outputs = self.lstm_layers(
X,
hiddens,
weights,
bidirectional,
dtype=X_dtype,
dropout_p=dropout_p,
has_proj=has_proj,
)

        # output shape = (seq_num, batch, hidden_size) or
        # (seq_num, batch, 2 * hidden_size) for bidirectional
output = outputs[0]

hy = []
cy = []
for hidden in outputs[1]:
hy.append(hidden[0])
cy.append(hidden[1])

if batch_first:
output = _op.transpose(output, (1, 0, 2))

return (output, _op.stack(hy, 0), _op.stack(cy, 0))

# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -2545,6 +2838,7 @@ def create_convert_map(self):
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
"aten::flip": self.flip,
"aten::lstm": self.lstm,
}

def update_convert_map(self, custom_map):
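As a quick way to exercise the new aten::lstm mapping, a conversion can be sketched as below. This is a minimal sketch, not part of the PR: the model configuration, shapes, and the "input" name are illustrative assumptions; relay.frontend.from_pytorch and torch.jit.trace are the existing entry points.

    import torch
    from tvm import relay

    # illustrative sizes, not taken from the PR
    seq_len, batch, feature_size, hidden_size = 5, 2, 16, 32
    model = torch.nn.LSTM(feature_size, hidden_size, num_layers=2, bidirectional=True).eval()
    inp = torch.randn(seq_len, batch, feature_size)

    # trace the module so the TorchScript graph contains an aten::lstm node
    script_module = torch.jit.trace(model, inp)

    # from_pytorch takes the traced module and a list of (input_name, shape) pairs
    mod, params = relay.frontend.from_pytorch(script_module, [("input", inp.shape)])
    print(mod)

If the conversion succeeds, the printed Relay module should contain the per-time-step dense, sigmoid, and tanh operations that the converter above unrolls from the LSTM.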