Skip to content

Commit

Permalink
fix: ivy.trapz and jnp.trapezoid frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed May 17, 2024
1 parent 6d25d95 commit e3c906a
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def trapz(
axis: int = -1,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.trapz(y, x=x, dx=dx, axis=axis)
return jnp.trapezoid(y, x=x, dx=dx, axis=axis)


@with_unsupported_dtypes(
Expand Down
5 changes: 3 additions & 2 deletions ivy/functional/backends/tensorflow/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

# local
import ivy
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from ivy import promote_types_of_inputs
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from ivy.utils.exceptions import IvyNotImplementedException
from . import backend_version


Expand Down Expand Up @@ -771,8 +772,8 @@ def trapz(
axis: int = -1,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
pass
# TODO: Implement purely in tensorflow
raise IvyNotImplementedException()


@with_unsupported_dtypes({"2.15.0 and below": ("complex",)}, backend_version)
Expand Down
6 changes: 5 additions & 1 deletion ivy/functional/frontends/jax/numpy/mathematical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,10 +903,14 @@ def trace(a, offset=0, axis1=0, axis2=1, out=None):


@to_ivy_arrays_and_back
def trapz(y, x=None, dx=1.0, axis=-1, out=None):
def trapezoid(y, x=None, dx=1.0, axis=-1, out=None):
return ivy.trapz(y, x=x, dx=dx, axis=axis, out=out)


def trapz(y, x=None, dx=1.0, axis=-1, out=None):
return trapezoid(y, x=x, dx=dx, axis=axis, out=out)


@to_ivy_arrays_and_back
def trunc(x):
return ivy.trunc(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3314,11 +3314,11 @@ def test_jax_trace(


@handle_frontend_test(
fn_tree="jax.numpy.trapz",
fn_tree="jax.numpy.trapezoid",
dtype_x_axis_rand_either=_either_x_dx(),
test_with_out=st.just(False),
)
def test_jax_trapz(
def test_jax_trapezoid(
*,
dtype_x_axis_rand_either,
on_device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1980,6 +1980,7 @@ def test_tanh(*, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_
),
rand_either=_either_x_dx(),
test_gradients=st.just(False),
ground_truth_backend="numpy",
)
def test_trapz(
dtype_values_axis, rand_either, test_flags, backend_fw, fn_name, on_device
Expand Down

0 comments on commit e3c906a

Please sign in to comment.