[Frontend] [Torch] [ONNX] GRU layer (apache#8781)
* GRU cell was implemented in common.py. GRU support was added on the pytorch frontend side

* update GRU in common.py and onnx frontend

* fix issue related to GRU accuracy in the pytorch and ONNX frontends

* small fixes and removal of excess code

* common GRU was additionally updated. The tuned pytorch GRU was significantly accelerated

* the ONNX frontend now uses the GRU cell from common.py; the previous implementation was removed

* small fixes in comments

* fixes after review. GRU test was implemented for pytorch frontend

* tests for RNN layers were unified for the pytorch frontend

Co-authored-by: Valery Chernov <[email protected]>
2 people authored and ylc committed Sep 29, 2021
1 parent 4c877c2 commit 4b35bff
Showing 5 changed files with 774 additions and 441 deletions.
84 changes: 84 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -658,6 +658,90 @@ def unbind(data, axis=0):
    return _expr.TupleWrapper(_expr.Tuple(ret), selections)


def gru_cell(
    input_seqs,
    hidden_state,
    w_inp,
    w_hid,
    b_inp=None,
    b_hid=None,
    rz_act=_op.sigmoid,
    n_act=_op.tanh,
    backwards=False,
    linear_before_reset=True,
):
"""
Common implementation of GRU cell for all frontends of TVM
TODO(vvchernov): currently it is used by pytorch and ONNX. Extend for other frontends
Parameters
----------
input_seqs : List[relay.Expr]
The sequence of input tensors
Input tensor should be 2d while issue #8412 is not resolved
Shape = (batch, feature_size)
hidden_state : relay.Expr
Hidden state. shape = (batch_size, hidden_size)
w_inp, w_hid : relay.Expr
weight matrices. wi shape = (3 * hidden_size, feature_size)
wh shape = (3 * hidden_size, hidden_size)
NOTE: wi = (w_ir|w_iz|w_in) for reset, update and new gates.
The order is important for correct GRU calculation!
b_inp, b_hid : relay.Expr
bias matrices. The same order of internal parts as for weights. shape = (3 * hidden_size)
r_act : relay.op
activation funtion for reset gate. it is sigmoid by default
z_act : relay.op
activation funtion for update gate. it is sigmoid by default
n_act : relay.op
activation funtion for new gate. it is tanh by default
backwards : bool
Flag for reverse pass of GRU
Returns
-------
result : List[relay.Expr], relay.Expr, relay.Expr
The sequence of computed result, final hidden and cell state
"""

    outputs_list = []
    for x_t in input_seqs if not backwards else reversed(input_seqs):
        xwt = _op.nn.dense(x_t, w_inp)
        if linear_before_reset:
            hwt = _op.nn.dense(hidden_state, w_hid)
            if b_inp is not None and b_hid is not None:
                xwt += b_inp
                hwt += b_hid
            i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
            h_r, h_z, h_n = _op.split(hwt, 3, axis=-1)
            r_gate = rz_act(i_r + h_r)
            z_gate = rz_act(i_z + h_z)
            n_gate = n_act(i_n + r_gate * h_n)
        else:
            i_r, i_z, i_n = _op.split(xwt, 3, axis=1)
            w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
            r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
            z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
            if b_inp is not None and b_hid is not None:
                b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
                b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
                r_gate += b_ir + b_hr
                z_gate += b_iz + b_hz
                i_n += b_in
                r_gate = rz_act(r_gate)
                z_gate = rz_act(z_gate)
                # the reset gate is activated before it scales the hidden-state contribution
                h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
            else:
                r_gate = rz_act(r_gate)
                z_gate = rz_act(z_gate)
                h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
            n_gate = n_act(i_n + h_n)

        hidden_state = (hidden_state - n_gate) * z_gate + n_gate

        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]

    return outputs_list, hidden_state


def lstm_cell(
    input_seqs,
    hidden_state,
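A minimal usage sketch of the new gru_cell helper, assuming illustrative shapes and variable names (batch=2, feature=4, hidden=8) and following the layout documented in the docstring above; it is not part of the commit itself:

import tvm
from tvm import relay
from tvm.relay.frontend.common import gru_cell

batch, feature, hidden, seq_len = 2, 4, 8, 5

# One 2D tensor per time step, as required while issue #8412 is open.
xs = [relay.var("x%d" % t, shape=(batch, feature), dtype="float32") for t in range(seq_len)]
h0 = relay.var("h0", shape=(batch, hidden), dtype="float32")
w_inp = relay.var("w_inp", shape=(3 * hidden, feature), dtype="float32")  # (w_ir | w_iz | w_in)
w_hid = relay.var("w_hid", shape=(3 * hidden, hidden), dtype="float32")  # (w_hr | w_hz | w_hn)

# Biases default to None; gate activations default to sigmoid and tanh.
outputs, h_final = gru_cell(xs, h0, w_inp, w_hid)

seq_out = relay.stack(outputs, axis=0)  # (seq_len, batch, hidden)
func = relay.Function(relay.analysis.free_vars(seq_out), seq_out)
print(tvm.IRModule.from_expr(func))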
149 changes: 72 additions & 77 deletions python/tvm/relay/frontend/onnx.py
@@ -47,6 +47,7 @@
    infer_value,
    new_var,
    unbind,
    gru_cell,
    lstm_cell,
)

@@ -2349,56 +2350,41 @@ class GRU(RNN):
"""Operator convert for GRU"""

    @classmethod
    def generate_gru(
        cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False
    def bidir_gru_cell(
        cls,
        input_seqs,
        weight_dicts,
        acts,
    ):
"""Create an unrolled gru loop.
See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
"""
h_list = []
seq_length = len(X_steps)
for i in range(seq_length):
step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
step = _op.squeeze(step, axis=[0])
current = _op.nn.dense(step, W)
cz, cr, ch = _op.split(current, 3, axis=1)
rz, rr, rh = _op.split(R, 3, axis=0)
z = cz + _op.nn.dense(H_t, rz)
r = cr + _op.nn.dense(H_t, rr)
if B is not None:
WB, RB = _op.split(B, 2)
wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
z += wbz + rbz
r += wbr + rbr
if linear_before_reset:
h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
else:
h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
else:
if linear_before_reset:
h = ch + (r * (_op.nn.dense(H_t, rh)))
else:
h = ch + _op.nn.dense((r * H_t), rh)

z = f_act(z)
r = f_act(r)
h = g_act(h)

H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t)
h_list.append(_op.expand_dims(H_t, axis=0))
Bidirectional GRU cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t = gru_cell(
input_seqs,
**weight_dicts[0],
rz_act=acts[0],
n_act=acts[1],
)

if backwards:
# Canonical view is hidden states from the first token not last
h_list = h_list[::-1]
reverse_outputs, rev_H_t = gru_cell(
input_seqs,
**weight_dicts[1],
rz_act=acts[2],
n_act=acts[3],
backwards=True,
)

# Concatenate outputs and add back in direction axis.
concatenated = _op.concatenate(h_list, 0)
output = _op.expand_dims(concatenated, axis=1)
H_t = _op.expand_dims(H_t, axis=0)
final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
)

return output, H_t
return (
_op.stack(final_outputs, axis=0),
_op.stack([fw_H_t, rev_H_t], axis=0),
)

    @classmethod
    def _impl_v7(cls, inputs, attr, params):
@@ -2416,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params):
        W_dtype = infer_type(Wp).checked_type.dtype

        if num_directions not in [1, 2]:
            raise NotImplementedError(
                f"Directions for GRUs should be either 1 or 2 got {num_directions}"
            )
            raise ValueError("num_directions must be either 1 or 2!")

        X_shape = infer_shape(X)
        hidden_size = infer_shape(Rp)[-1]
        batch_size = X_shape[1]

        # Initialize state if not provided.
        # Otherwise remove bidirectional axis.
        if Hp_0 is None:
            Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
        if Bp is None:
            Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype)

        if "activations" in attr:
            activations = attr["activations"]
@@ -2460,39 +2440,54 @@
        else:
            acts = [_op.sigmoid, _op.tanh] * 2

        result_output = []
        result_H = []
        # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
        X_steps = unbind(X, axis=0)

        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
        H_ts = _op.split(Hp_0, num_directions)
        Ws = _op.split(Wp, num_directions)
        Rs = _op.split(Rp, num_directions)
        Bs = _op.split(Bp, num_directions)

        if Bp is not None:
            Bs = _op.split(Bp, num_directions)

        weights_dicts = []
        for i in range(num_directions):
            H_t = _op.squeeze(H_ts[i], axis=[0])
            W = _op.squeeze(Ws[i], axis=[0])
            R = _op.squeeze(Rs[i], axis=[0])
            B = _op.squeeze(Bs[i], axis=[0])
            f_act, g_act = acts[i * 2 : (i + 1) * 2]
            output, H = GRU.generate_gru(
                X_steps=X_steps,
                H_t=H_t,
                W=W,
                R=R,
                B=B,
                linear_before_reset=linear_before_reset,
                f_act=f_act,
                g_act=g_act,
                W_dtype=W_dtype,
                backwards=i == 1,
            )
            weights_dict = {}

            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
            weights_dict["linear_before_reset"] = linear_before_reset

            # Weights permutation: ONNX packs the gates as (z | r | n), but gru_cell expects (r | z | n)
            matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3)
            weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
            matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
            weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
            if Bp is not None:
                Bi, Bh = _op.split(Bs[i], 2, -1)
                matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
                weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0)
                matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3)
                weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0)
            weights_dicts.append(weights_dict)

            result_output.append(output)
            result_H.append(H)
        if num_directions == 2:
            output, H = GRU.bidir_gru_cell(
                input_seqs=X_steps,
                weight_dicts=weights_dicts,
                acts=acts,
            )
        else:
            # outputs shape = [seqs_num, (batch_size, hidden_size)]
            outputs, H = gru_cell(
                input_seqs=X_steps,
                **weights_dicts[0],
                rz_act=acts[0],
                n_act=acts[1],
            )

        output = _op.concatenate(result_output, axis=1)
        H = _op.concatenate(result_H, axis=0)
            # output shape = (seqs_num, num_directions, batch_size, hidden_size)
            output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
            H = _op.expand_dims(H, axis=0)

        return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)

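To make the gate reordering in the converter above concrete, here is a small numpy sketch (illustrative only, not TVM code) of the permutation applied to each weight block before it is handed to gru_cell; the shapes are arbitrary:

import numpy as np

hidden_size, feature_size = 3, 2
# ONNX packs the GRU input weights as (W_z | W_r | W_n) along the first axis.
W_onnx = np.random.rand(3 * hidden_size, feature_size).astype("float32")
w_z, w_r, w_n = np.split(W_onnx, 3, axis=0)
# gru_cell expects (W_r | W_z | W_n), so the update and reset blocks are swapped.
W_for_gru_cell = np.concatenate([w_r, w_z, w_n], axis=0)

The converter applies the same reordering to the recurrence weights Rs[i] and to both halves of the bias vector Bs[i].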