
Commit

GRU cell in ONNX frontend is now used from common.py; the previous implementation was removed.
vvchernov committed Aug 19, 2021
1 parent 36de7b2 commit 600a66e
Showing 1 changed file with 72 additions and 78 deletions.
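The new converter delegates the per-step recurrence to the shared gru_cell helper in python/tvm/relay/frontend/common.py. Below is a minimal sketch of the calling convention as it can be inferred from this diff; the variable names are invented for illustration and the authoritative parameter defaults live in common.py.

from tvm import relay
from tvm.relay import op as _op
from tvm.relay.frontend.common import gru_cell

batch, inp, hid, seq_len = 2, 3, 4, 5

# Per-step inputs and initial state as Relay expressions (hypothetical names).
X_steps = [relay.var("x%d" % t, shape=(batch, inp)) for t in range(seq_len)]
H_0 = relay.var("h0", shape=(batch, hid))

# Gate weights and biases, already permuted to the r-z-n order gru_cell expects.
W_rzn = relay.var("w_inp", shape=(3 * hid, inp))
R_rzn = relay.var("w_hid", shape=(3 * hid, hid))
Wb_rzn = relay.var("b_inp", shape=(3 * hid,))
Rb_rzn = relay.var("b_hid", shape=(3 * hid,))

outputs, final_h = gru_cell(
    input_seqs=X_steps,
    hidden_state=H_0,
    w_inp=W_rzn,
    w_hid=R_rzn,
    b_inp=Wb_rzn,
    b_hid=Rb_rzn,
    rz_act=_op.sigmoid,      # activation for the reset/update gates
    n_act=_op.tanh,          # activation for the candidate state
    backwards=False,
    linear_before_reset=1,
)
# outputs: list of seq_len per-step hidden states; final_h: state after the last step.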
150 changes: 72 additions & 78 deletions python/tvm/relay/frontend/onnx.py
@@ -47,6 +47,7 @@
infer_value,
new_var,
unbind,
gru_cell,
lstm_cell,
)

@@ -2349,57 +2350,41 @@ class GRU(RNN):
"""Operator convert for GRU"""

@classmethod
def generate_gru(
cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False
def bidir_gru_cell(
cls,
input_seqs,
weight_dicts,
acts,
):
"""Create an unrolled gru loop.
See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
"""
h_list = []
seq_length = len(X_steps)
for i in range(seq_length):
step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
step = _op.squeeze(step, axis=[0])
current = _op.nn.dense(step, W)
cz, cr, ch = _op.split(current, 3, axis=1)
rz, rr, rh = _op.split(R, 3, axis=0)
z = cz + _op.nn.dense(H_t, rz)
r = cr + _op.nn.dense(H_t, rr)
if B is not None:
WB, RB = _op.split(B, 2)
wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
z += wbz + rbz
r += wbr + rbr
r = f_act(r)
if linear_before_reset:
h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
else:
h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
else:
r = f_act(r)
if linear_before_reset:
h = ch + (r * (_op.nn.dense(H_t, rh)))
else:
h = ch + _op.nn.dense((r * H_t), rh)

z = f_act(z)
h = g_act(h)

H_t = (H_t - h) * z + h
h_list.append(_op.expand_dims(H_t, axis=0))
Bidirectional GRU cell: runs forward and backward gru_cell passes and stacks their outputs along the direction axis.
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t = gru_cell(
input_seqs,
**weight_dicts[0],
rz_act=acts[0],
n_act=acts[1],
)

if backwards:
# Canonical view is hidden states from the first token not last
h_list = h_list[::-1]
reverse_outputs, rev_H_t = gru_cell(
input_seqs,
**weight_dicts[1],
rz_act=acts[2],
n_act=acts[3],
backwards=True,
)

# Concatenate outputs and add back in direction axis.
concatenated = _op.concatenate(h_list, 0)
output = _op.expand_dims(concatenated, axis=1)
H_t = _op.expand_dims(H_t, axis=0)
final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
)

return output, H_t
return (
_op.stack(final_outputs, axis=0),
_op.stack([fw_H_t, rev_H_t], axis=0),
)

@classmethod
def _impl_v7(cls, inputs, attr, params):
@@ -2417,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params):
W_dtype = infer_type(Wp).checked_type.dtype

if num_directions not in [1, 2]:
raise NotImplementedError(
f"Directions for GRUs should be either 1 or 2 got {num_directions}"
)
raise ValueError("num_directions must be either 1 or 2!")

X_shape = infer_shape(X)
hidden_size = infer_shape(Rp)[-1]
batch_size = X_shape[1]

# Initialize state if not provided.
# Otherwise remove bidirectional axis.
if Hp_0 is None:
Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
if Bp is None:
Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype)

if "activations" in attr:
activations = attr["activations"]
@@ -2461,39 +2440,54 @@ def _impl_v7(cls, inputs, attr, params):
else:
acts = [_op.sigmoid, _op.tanh] * 2

result_output = []
result_H = []
# TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
X_steps = unbind(X, axis=0)

X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
H_ts = _op.split(Hp_0, num_directions)
Ws = _op.split(Wp, num_directions)
Rs = _op.split(Rp, num_directions)
Bs = _op.split(Bp, num_directions)

if Bp is not None:
Bs = _op.split(Bp, num_directions)

weights_dicts = []
for i in range(num_directions):
H_t = _op.squeeze(H_ts[i], axis=[0])
W = _op.squeeze(Ws[i], axis=[0])
R = _op.squeeze(Rs[i], axis=[0])
B = _op.squeeze(Bs[i], axis=[0])
f_act, g_act = acts[i * 2 : (i + 1) * 2]
output, H = GRU.generate_gru(
X_steps=X_steps,
H_t=H_t,
W=W,
R=R,
B=B,
linear_before_reset=linear_before_reset,
f_act=f_act,
g_act=g_act,
W_dtype=W_dtype,
backwards=i == 1,
)
weights_dict = {}

weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
weights_dict["linear_before_reset"] = linear_before_reset

# Weights permutation: ONNX packs GRU gates as z-r-n(h), gru_cell expects r-z-n order
matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3)
weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
if Bp is not None:
Bi, Bh = _op.split(Bs[i], 2, -1)
matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0)
matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3)
weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0)
weights_dicts.append(weights_dict)

result_output.append(output)
result_H.append(H)
if num_directions == 2:
output, H = GRU.bidir_gru_cell(
input_seqs=X_steps,
weight_dicts=weights_dicts,
acts=acts,
)
else:
# outputs is a list of seqs_num tensors, each of shape (batch_size, hidden_size)
outputs, H = gru_cell(
input_seqs=X_steps,
**weights_dicts[0],
rz_act=acts[0],
n_act=acts[1],
)

output = _op.concatenate(result_output, axis=1)
H = _op.concatenate(result_H, axis=0)
# output shape = (seqs_num, num_directions, batch_size, hidden_size)
output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
H = _op.expand_dims(H, axis=0)

return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)
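The weights permutation in _impl_v7 above can be sanity-checked in isolation. This is only an illustration of the reordering (ONNX packs the GRU gates as z | r | h along the first axis, while gru_cell expects r | z | n), with made-up sizes, not part of the change:

import numpy as np

hidden_size, input_size = 4, 3
w_onnx = np.arange(3 * hidden_size * input_size, dtype="float32").reshape(3 * hidden_size, input_size)

# Same steps as the converter: split into the three ONNX gates,
# then concatenate them back in the order gru_cell expects.
w_z, w_r, w_n = np.split(w_onnx, 3, axis=0)
w_for_gru_cell = np.concatenate([w_r, w_z, w_n], axis=0)
assert w_for_gru_cell.shape == (3 * hidden_size, input_size)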


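A quick way to exercise the rewritten converter end to end is to import a tiny ONNX GRU through relay.frontend.from_onnx. The model below is a hypothetical test case (names, shapes and random weights are invented) and is not part of this commit:

import numpy as np
from onnx import TensorProto, helper, numpy_helper
from tvm import relay

seq_len, batch, inp, hid, dirs = 5, 2, 3, 4, 1

gru = helper.make_node(
    "GRU", inputs=["X", "W", "R"], outputs=["Y", "Y_h"],
    hidden_size=hid, linear_before_reset=1,
)
graph = helper.make_graph(
    [gru], "gru_test",
    inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_len, batch, inp])],
    outputs=[
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [seq_len, dirs, batch, hid]),
        helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [dirs, batch, hid]),
    ],
    initializer=[
        numpy_helper.from_array(np.random.rand(dirs, 3 * hid, inp).astype("float32"), "W"),
        numpy_helper.from_array(np.random.rand(dirs, 3 * hid, hid).astype("float32"), "R"),
    ],
)
model = helper.make_model(graph)

# Conversion goes through GRU._impl_v7 shown in the diff above.
mod, params = relay.frontend.from_onnx(model, shape={"X": [seq_len, batch, inp]})
print(mod)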