Skip to content

Commit

Permalink
fixes after review. GRU test was implemented for pytorch frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
vvchernov committed Aug 23, 2021
1 parent 9727bf8 commit 17a40ba
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 58 deletions.
8 changes: 3 additions & 5 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def gru_cell(
):
"""
Common implementation of GRU cell for all frontends of TVM
TODO(vvchernov): currently it is used by pytorch. Extend for other frontends
TODO(vvchernov): currently it is used by pytorch and ONNX. Extend for other frontends
Parameters
----------
Expand Down Expand Up @@ -709,8 +709,7 @@ def gru_cell(
xwt = _op.nn.dense(x_t, w_inp)
if linear_before_reset:
hwt = _op.nn.dense(hidden_state, w_hid)
# TODO(vvchernov): It is assumed that both bias are or not
if b_inp is not None:
if b_inp is not None and b_hid is not None:
xwt += b_inp
hwt += b_hid
i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
Expand All @@ -723,8 +722,7 @@ def gru_cell(
w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
# TODO(vvchernov): It is assumed that both bias are or not
if b_inp is not None:
if b_inp is not None and b_hid is not None:
b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
r_gate += b_ir + b_hr
Expand Down
Loading

0 comments on commit 17a40ba

Please sign in to comment.