diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py index 08f4f42c88ebf..cdb9340436cf6 100644 --- a/ivy/functional/backends/jax/activations.py +++ b/ivy/functional/backends/jax/activations.py @@ -37,7 +37,7 @@ def leaky_relu( def relu( x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: - return jnp.maximum(x, 0) + return jax.nn.relu(x) def sigmoid(