lint fixes

vvchernov committed Jul 13, 2021 (1 parent 197f49f, commit 883c261)
Showing 1 changed file with 23 additions and 14 deletions.

python/tvm/relay/frontend/pytorch.py: 23 additions & 14 deletions
@@ -41,7 +41,6 @@
 from . import qnn_torch
 from .common import AttrCvt, get_relay_op
 from .common import infer_value as _infer_value
-from .common import infer_type as _infer_type
 from .common import infer_shape as _infer_shape
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import try_infer_value
@@ -2359,7 +2358,6 @@ def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False)
         fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj)
 
         rev_input_seq = []
-        _op.reverse_sequence
         seq_len = len(input_seq)
         for i in range(seq_len):
             rev_input_seq.append(input_seq[seq_len - 1 - i])  # [seq_num, (batch, hidden_size)]
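The deleted `_op.reverse_sequence` was a bare attribute access with no effect, which is what the linter flagged; the actual reversal happens at graph-build time over a Python list of per-step tensors. A minimal standalone sketch (plain integers stand in for the (batch, hidden_size) tensors):

```python
# Reversal as done in bidir_lstm_cell, on a plain Python list.
input_seq = [0, 1, 2, 3]

rev_input_seq = []
seq_len = len(input_seq)
for i in range(seq_len):
    rev_input_seq.append(input_seq[seq_len - 1 - i])

# The loop is equivalent to slice reversal; both run in Python,
# not inside the relay graph.
assert rev_input_seq == input_seq[::-1]
print(rev_input_seq)  # [3, 2, 1, 0]
```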
@@ -2374,13 +2372,13 @@ def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False)
         return final_outputs, (fw_outputs[1], rev_outputs[1])
 
     def lstm_layers(
-        self, input, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
+        self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
     ):
         hidden_layers_num = len(hiddens)
         assert len(weights) == hidden_layers_num
 
         # split input sequence to samples set
-        input_seqs = self.unbind((input, 0), dtype)  # [seq_num, (batch, feature_size)]
+        input_seqs = self.unbind((input_data, 0), dtype)  # [seq_num, (batch, feature_size)]
         output_hiddens = []
         for k in range(hidden_layers_num):
             hiddens_input = hiddens[k]
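Renaming `input` to `input_data` clears pylint's redefined-builtin warning (W0622): a parameter named `input` shadows the builtin of the same name for the whole function body. A hypothetical illustration, not taken from the diff:

```python
# pylint flags W0622 (redefined-builtin) here, because the parameter
# shadows builtins.input for the entire function body.
def read_twice(input):
    # Inside this scope, `input` is the parameter; the builtin
    # input() can no longer be reached under its usual name.
    return input * 2

print(read_twice(21))  # 42
```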
@@ -2393,10 +2391,13 @@ def lstm_layers(
             )
 
             output_hiddens.append(outputs[1])
-            # input_seqs shape = [seq_num, (batch, feature_size)] or [seq_num, (batch, 2*feature_size)] for bidirectional
+            # input_seqs shape = [seq_num, (batch, feature_size)] or
+            # [seq_num, (batch, 2*feature_size)] for bidirectional
             input_seqs = outputs[0]
 
-            # TODO (vvchernov): in pytorch implementation train is also checked (see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/src/ATen/native/RNN.cpp#L1054)
+            # TODO (vvchernov): in pytorch implementation train is also checked
+            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+            # /aten/src/ATen/native/RNN.cpp#L1054
             if dropout_p != 0 and k < hidden_layers_num - 1:
                 # for input in input_seqs:
                 #     input = _op.dropout(input, dropout_p)
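This loop is the standard stacked-RNN pattern: the full output sequence of layer k becomes the input sequence of layer k + 1, and (matching PyTorch semantics) dropout would apply between layers only, never after the last one. A runnable sketch under those assumptions, with the recurrent state threading omitted and `apply_dropout` as an identity stand-in for `_op.dropout`:

```python
def apply_dropout(x, p):
    # Identity stand-in for _op.dropout so the sketch runs without TVM.
    return x

def run_stacked(input_seqs, layers, dropout_p=0.0):
    """Output sequence of layer k feeds layer k + 1; dropout applies
    between layers but not after the last one."""
    for k, layer in enumerate(layers):
        input_seqs = layer(input_seqs)  # stand-in for one LSTM layer pass
        if dropout_p != 0 and k < len(layers) - 1:
            input_seqs = [apply_dropout(x, dropout_p) for x in input_seqs]
    return input_seqs

# Toy usage: each "layer" transforms the whole sequence.
layers = [lambda seq: [2 * x for x in seq], lambda seq: [x + 1 for x in seq]]
print(run_stacked([1.0, 2.0], layers, dropout_p=0.5))  # [3.0, 5.0]
```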
@@ -2412,9 +2413,14 @@ def lstm_layers(
         return _op.stack(input_seqs, 0), final_hiddens
 
     def lstm(self, inputs, input_types):
-        # Description of LSTM in pytorch: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
-        # https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/src/ATen/native/RNN.cpp#L1396 (projection is unsupported) and dependencies were used
-        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483 (projection is supported) and dependencies were used
+        """
+        Description of LSTM in pytorch: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
+        Native implementation for torch version less than 1.8.0 (projection is unsupported):
+        https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \
+        src/ATen/native/RNN.cpp#L1396
+        Native implementation for torch version from 1.8.0 and higher (projection is supported):
+        https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
+        """
         # TODO (vvchernov): support dropout
         assert len(inputs) == 9, "Input of size 9 is expected"
         # Unpack inputs, note that if optional and not provided then value will be None.
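For orientation (hedged: the names follow the surrounding converter code, and the layout matches the aten::lstm schema rather than anything shown in this hunk), the nine inputs guarded by the assert unpack roughly as:

```python
# Illustrative stand-in values; only the positions matter.
inputs = [
    "X",               # inputs[0] -> _X, the input tensor
    ["h_0", "c_0"],    # inputs[1] -> hidden_states
    ["w_ih", "w_hh"],  # inputs[2] -> _weights, flat per-layer weight list
    True,              # inputs[3] -> has_biases
    1,                 # inputs[4] -> num_layers
    0.0,               # inputs[5] -> dropout_p
    False,             # inputs[6] -> train
    False,             # inputs[7] -> bidirectional
    False,             # inputs[8] -> batch_first
]
assert len(inputs) == 9, "Input of size 9 is expected"
```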
@@ -2425,7 +2431,8 @@ def lstm(self, inputs, input_types):
         assert len(hidden_states) == 2, "lstm expects two hidden states"
         h_0 = hidden_states[0]
         c_0 = hidden_states[1]
-        # H0 shape (hidden_layers_num, batch, proj_size) if projection else (hidden_layers_num, batch, hidden_size)
+        # H0 shape (hidden_layers_num, batch, proj_size) if projection
+        # else (hidden_layers_num, batch, hidden_size)
         # C0 shape (hidden_layers_num, batch, hidden_size)
 
         _weights = inputs[2]
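The reflowed shape comment can be checked against PyTorch directly (requires torch >= 1.8.0, where proj_size was introduced):

```python
import torch

lstm = torch.nn.LSTM(input_size=4, hidden_size=8, num_layers=2, proj_size=3)
x = torch.randn(5, 1, 4)  # (seq, batch, feature), batch_first=False
output, (h_n, c_n) = lstm(x)

print(h_n.shape)  # torch.Size([2, 1, 3]) -> (layers, batch, proj_size)
print(c_n.shape)  # torch.Size([2, 1, 8]) -> (layers, batch, hidden_size)
```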
@@ -2514,7 +2521,7 @@ def lstm(self, inputs, input_types):
                     fw_weights = []
                     rev_weights = []
                     for j in range(weights_num + 2):
-                        if j == 2 or j == 3:
+                        if j in (2, 3):
                             fw_weights.append(None)
                             rev_weights.append(None)
                         else:
@@ -2526,7 +2533,7 @@ def lstm(self, inputs, input_types):
                 for i in range(0, len(_weights), weights_num):
                     fw_weights = []
                     for j in range(weights_num + 2):
-                        if j == 2 or j == 3:
+                        if j in (2, 3):
                             fw_weights.append(None)
                         else:
                             fw_weights.append(_weights[i + j])
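Both hunks are the same pylint cleanup (R1714, consider-using-in) with identical semantics; the None entries at positions 2 and 3 appear to reserve slots for the absent bias tensors when has_biases is false, though the diff itself does not say so. The equivalence is easy to verify:

```python
# `j in (2, 3)` is the idiomatic membership test pylint suggests in
# place of a chain of equality comparisons; the results are identical.
for j in range(6):
    assert (j == 2 or j == 3) == (j in (2, 3))
print("equivalent for all j")
```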
@@ -2536,7 +2543,8 @@ def lstm(self, inputs, input_types):
         ), "For stacked LSTM number of weights tuples should be the same as number of layers!"
 
         X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
-        # TODO (vvchernov): Which data type should be used? from input or weights (use _weights[0])? Also _infer_type(X).checked_type.dtype can be used
+        # TODO (vvchernov): Which data type should be used? from input or weights?
+        # Instead of it _infer_type(X).checked_type.dtype can be used
         X_dtype = input_types[0]
         X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
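The transpose above this TODO is the usual batch-first normalization to (seq_num, batch, feature_size). A numpy sketch of the same axis permutation (numpy standing in for relay's `_op.transpose`):

```python
import numpy as np

batch, seq, feature = 2, 5, 4
x_batch_first = np.zeros((batch, seq, feature))

x = np.transpose(x_batch_first, (1, 0, 2))  # axes (1, 0, 2)
print(x.shape)  # (5, 2, 4) -> (seq_num, batch, feature_size)
```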

@@ -2582,7 +2590,8 @@ def lstm(self, inputs, input_types):
             has_proj=has_proj,
         )
 
-        # output shape = (seq_num, batch, hidden_size) or (seq_num, batch, 2*feature_size) for bidirectional
+        # output shape = (seq_num, batch, hidden_size) or
+        # (seq_num, batch, 2*feature_size) for bidirectional
         output = outputs[0]
 
         hy = []