Commit

Update modeling_persimmon.py
Luodian authored Dec 30, 2023
1 parent 488c59a commit 366a09d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/otter_ai/models/fuyu/modeling_persimmon.py
@@ -187,7 +187,10 @@ def forward(self, hidden_states):
             1: recompute gelu_out in the bwd
             2: recompute gelu_in and gelu_out in the bwd
         """
-        hidden_states = fused_mlp_func(hidden_states, self.dense_h_to_4h.weight, self.dense_4h_to_h.weight, self.dense_h_to_4h.bias, self.dense_4h_to_h.bias, "sqrelu", True, False, 0, -1)
+        if hasattr(self.dense_h_to_4h, "bias") and hasattr(self.dense_4h_to_h, "bias"):
+            hidden_states = fused_mlp_func(hidden_states, self.dense_h_to_4h.weight, self.dense_4h_to_h.weight, self.dense_h_to_4h.bias, self.dense_4h_to_h.bias, "sqrelu", True, False, 0, -1)  # Thanks [Dongfu Jiang](https://jdf-prog.github.io/) for adding this exception!
+        else:
+            hidden_states = fused_mlp_func(hidden_states, self.dense_h_to_4h.weight, self.dense_4h_to_h.weight, None, None, "sqrelu", True, False, 0, -1)
         return hidden_states
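
The both-or-neither bias selection in this change can be read in isolation. The following is a minimal sketch, not code from the repository: `select_biases` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins for the two projection layers, so the example runs without PyTorch or the fused kernel.

```python
from types import SimpleNamespace


def select_biases(dense_h_to_4h, dense_4h_to_h):
    """Hypothetical helper mirroring the commit's branch.

    Returns both biases when both layers expose a `bias` attribute,
    otherwise (None, None).
    """
    if hasattr(dense_h_to_4h, "bias") and hasattr(dense_4h_to_h, "bias"):
        return dense_h_to_4h.bias, dense_4h_to_h.bias
    return None, None


# Stand-in layers (assumptions for illustration, not real nn.Module objects).
with_bias_1 = SimpleNamespace(weight="w1", bias="b1")
with_bias_2 = SimpleNamespace(weight="w2", bias="b2")
no_bias_attr = SimpleNamespace(weight="w3")        # attribute missing entirely
none_bias = SimpleNamespace(weight="w4", bias=None)  # attribute present, value None

both = select_biases(with_bias_1, with_bias_2)
missing = select_biases(with_bias_1, no_bias_attr)
explicit_none = select_biases(none_bias, none_bias)
```

One point worth checking against the actual layer classes: `torch.nn.Linear(..., bias=False)` still defines a `bias` attribute whose value is `None`, so for plain `nn.Linear` layers the `hasattr` test succeeds and `None` is forwarded through the first branch. The `else` branch only matters for layer types that omit the attribute altogether.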

