From 9a06287a83b64ed63a12bbda23762f345237eb8c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 18 Aug 2021 22:51:58 +0300 Subject: [PATCH] GRU cell in ONNX frontend was used from common.py. previous implementation was removed --- python/tvm/relay/frontend/onnx.py | 149 ++++++++++++++---------------- 1 file changed, 71 insertions(+), 78 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index fe3dcb6c07921..7d38dc6463043 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -47,6 +47,7 @@ infer_value, new_var, unbind, + gru_cell, lstm_cell, ) @@ -2349,57 +2350,41 @@ class GRU(RNN): """Operator convert for GRU""" @classmethod - def generate_gru( - cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False + def bidir_gru_cell( + cls, + input_seqs, + weight_dicts, + acts, ): - """Create an unrolled gru loop. - - See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math. """ - h_list = [] - seq_length = len(X_steps) - for i in range(seq_length): - step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)] - step = _op.squeeze(step, axis=[0]) - current = _op.nn.dense(step, W) - cz, cr, ch = _op.split(current, 3, axis=1) - rz, rr, rh = _op.split(R, 3, axis=0) - z = cz + _op.nn.dense(H_t, rz) - r = cr + _op.nn.dense(H_t, rr) - if B is not None: - WB, RB = _op.split(B, 2) - wbz, wbr, wbh = _op.split(WB, 3, axis=-1) - rbz, rbr, rbh = _op.split(RB, 3, axis=-1) - z += wbz + rbz - r += wbr + rbr - r = f_act(r) - if linear_before_reset: - h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh - else: - h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh - else: - r = f_act(r) - if linear_before_reset: - h = ch + (r * (_op.nn.dense(H_t, rh))) - else: - h = ch + _op.nn.dense((r * H_t), rh) - - z = f_act(z) - h = g_act(h) - - H_t = (H_t - h) * z + h - h_list.append(_op.expand_dims(H_t, axis=0)) + Bidirectional GRU cell + """ + seq_len = len(input_seqs) + forward_outputs, fw_H_t = gru_cell( + input_seqs, + **weight_dicts[0], + rz_act=acts[0], + n_act=acts[1], + ) - if backwards: - # Canonical view is hidden states from the first token not last - h_list = h_list[::-1] + reverse_outputs, rev_H_t = gru_cell( + input_seqs, + **weight_dicts[1], + rz_act=acts[2], + n_act=acts[3], + backwards=True, + ) - # Concatenate outputs and add back in direction axis. - concatenated = _op.concatenate(h_list, 0) - output = _op.expand_dims(concatenated, axis=1) - H_t = _op.expand_dims(H_t, axis=0) + final_outputs = [] + for i in range(seq_len): + final_outputs.append( + _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0) + ) - return output, H_t + return ( + _op.stack(final_outputs, axis=0), + _op.stack([fw_H_t, rev_H_t], axis=0), + ) @classmethod def _impl_v7(cls, inputs, attr, params): @@ -2417,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params): W_dtype = infer_type(Wp).checked_type.dtype if num_directions not in [1, 2]: - raise NotImplementedError( - f"Directions for GRUs should be either 1 or 2 got {num_directions}" - ) + raise ValueError("num_directions must be either 1 or 2!") X_shape = infer_shape(X) hidden_size = infer_shape(Rp)[-1] batch_size = X_shape[1] - # Initialize state if not provided. - # Otherwise remove bidirectional axis. if Hp_0 is None: Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype) - if Bp is None: - Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype) if "activations" in attr: activations = attr["activations"] @@ -2461,39 +2440,53 @@ def _impl_v7(cls, inputs, attr, params): else: acts = [_op.sigmoid, _op.tanh] * 2 - result_output = [] - result_H = [] + # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved + X_steps = unbind(X, axis=0) - X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0) H_ts = _op.split(Hp_0, num_directions) Ws = _op.split(Wp, num_directions) Rs = _op.split(Rp, num_directions) - Bs = _op.split(Bp, num_directions) + if Bp is not None: + Bs = _op.split(Bp, num_directions) + + weights_dicts = [] for i in range(num_directions): - H_t = _op.squeeze(H_ts[i], axis=[0]) - W = _op.squeeze(Ws[i], axis=[0]) - R = _op.squeeze(Rs[i], axis=[0]) - B = _op.squeeze(Bs[i], axis=[0]) - f_act, g_act = acts[i * 2 : (i + 1) * 2] - output, H = GRU.generate_gru( - X_steps=X_steps, - H_t=H_t, - W=W, - R=R, - B=B, - linear_before_reset=linear_before_reset, - f_act=f_act, - g_act=g_act, - W_dtype=W_dtype, - backwards=i == 1, - ) + weights_dict = {} + + weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0]) + + # Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o + matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3) + weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0) + matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3) + weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0) + if Bp is not None: + Bi, Bh = _op.split(Bs[i], 2, -1) + matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3) + weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0) + matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3) + weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0) + weights_dicts.append(weights_dict) - result_output.append(output) - result_H.append(H) + if num_directions == 2: + output, H = GRU.bidir_gru_cell( + input_seqs=X_steps, + weight_dicts=weights_dicts, + acts=acts, + ) + else: + # outputs shape = [seqs_num, (batch_size, hidden_size)] + outputs, H = gru_cell( + input_seqs=X_steps, + **weights_dicts[0], + rz_act=acts[0], + n_act=acts[1], + ) - output = _op.concatenate(result_output, axis=1) - H = _op.concatenate(result_H, axis=0) + # output shape = (seqs_num, num_directions, batch_size, hidden_size) + output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1) + H = _op.expand_dims(H, axis=0) return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)