Skip to content

Commit

Permalink
Copy hidden gate in GRU op before applying activation
Browse files Browse the repository at this point in the history
This is a workaround needed because `tanh_in_place` is very slow with
non-contigous inputs. See #192.
  • Loading branch information
robertknight committed May 20, 2024
1 parent f355be2 commit 53d9e41
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/ops/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use rten_tensor::{Tensor, TensorView};
use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB};
use crate::ops::{
add_in_place, mul_in_place, sigmoid_in_place, tanh_in_place, InputList, IntoOpResult, OpError,
Operator, Output,
add_in_place, mul_in_place, sigmoid_in_place, tanh, tanh_in_place, InputList, IntoOpResult,
OpError, Operator, Output,
};
use crate::tensor_pool::{AutoReturn, TensorPool};

Expand Down Expand Up @@ -264,6 +264,9 @@ pub fn gru(
update_reset_gates.as_dyn_mut(),
hidden_scratch_reset_update_gates.as_dyn(),
);

// nb. This is slower than it should be because it falls back to
// the slow path for non-contiguous tensors.
sigmoid_in_place(update_reset_gates.as_dyn_mut());

// Combine inputs for hidden gate and apply activation.
Expand All @@ -274,11 +277,13 @@ pub fn gru(

let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE)));
add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn());
tanh_in_place(hidden_gate.as_dyn_mut());

// Copy the hidden gate because `tanh_in_place` is slow with
// non-contiguous tensors.
let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool);

// Compute next hidden state
let mut hidden_item = hidden.slice_mut::<2, _>([dir]);
let hidden_gate = gates.slice::<2, _>((.., gate_range(HIDDEN_GATE)));
let update_gate = gates.slice::<2, _>((.., gate_range(UPDATE_GATE)));

for (hidden, update, hidden_gate) in zip3(
Expand Down

0 comments on commit 53d9e41

Please sign in to comment.