
Commit 9727bf8: small fixes in comments

vvchernov committed Aug 20, 2021
1 parent: 600a66e

Showing 1 changed file with 6 additions and 6 deletions: python/tvm/relay/frontend/pytorch.py
@@ -2345,7 +2345,7 @@ def bidir_gru_cell(

     def gru_layers(self, input_data, layer_weights_dicts, bidirectional, dropout_p=0.0):
         """
-        Methods iterates layers for Stacked LSTM
+        Methods iterates layers for Stacked GRU
         """
         layers_num = len(layer_weights_dicts)
         # split input sequence to samples set
@@ -2368,7 +2368,7 @@ def gru_layers(self, input_data, layer_weights_dicts, bidirectional, dropout_p=0.0):
             if dropout_p != 0 and i < layers_num - 1:
                 # for input in input_seqs:
                 #     input = _op.dropout(input, dropout_p)
-                raise NotImplementedError("Dropout for LSTM has not been supported yet!")
+                raise NotImplementedError("Dropout for GRU has not been supported yet!")
 
         return _op.stack(input_seqs, 0), _op.stack(output_hiddens, 0)

@@ -2447,7 +2447,7 @@ def gru(self, inputs, input_types):
names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of LSTM weights"
assert rsd == 0, "got an incorrect number of GRU weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
@@ -2457,7 +2457,7 @@ def gru(self, inputs, input_types):
                     layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
                     k += 1
             else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
+                assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
                 for i in range(0, len(_weights), weights_num):
                     fw_tensors = [layers_h[k], *_weights[i : i + 4]]
                     fw_weights_dict = dict(zip(names, fw_tensors))
@@ -2467,7 +2467,7 @@ def gru(self, inputs, input_types):
names = ["hidden_state", "w_inp", "w_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of LSTM weights"
assert rsd == 0, "got an incorrect number of GRU weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
@@ -2477,7 +2477,7 @@ def gru(self, inputs, input_types):
                     layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
                     k += 1
             else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
+                assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
                 for i in range(0, len(_weights), weights_num):
                     fw_tensors = [layers_h[k], *_weights[i : i + 2]]
                     fw_weights_dict = dict(zip(names, fw_tensors))