Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add jax frontends #7708

Merged
merged 1 commit into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions ivy/functional/frontends/jax/numpy/name_space_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings


# local
import ivy
from ivy.functional.frontends.jax.func_wrapper import (
Expand Down Expand Up @@ -72,6 +71,15 @@ def argsort(a, axis=-1, kind="stable", order=None):
return ivy.argsort(a, axis=axis)


@to_ivy_arrays_and_back
def asarray(
a,
dtype=None,
order=None,
):
return ivy.asarray(a, dtype=dtype)


@to_ivy_arrays_and_back
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
return ivy.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
Expand Down Expand Up @@ -510,7 +518,6 @@ def multiply(x1, x2):

alltrue = all


sometrue = any


Expand Down Expand Up @@ -591,3 +598,41 @@ def expand_dims(a, axis):
@to_ivy_arrays_and_back
def eye(N, M=None, k=0, dtype=None):
return ivy.eye(N, M, k=k, dtype=dtype)


@to_ivy_arrays_and_back
def stack(arrays, axis=0, out=None, dtype=None):
if dtype:
return ivy.astype(
ivy.stack(arrays, axis=axis, out=out), ivy.as_ivy_dtype(dtype)
)
return ivy.stack(arrays, axis=axis, out=out)


@to_ivy_arrays_and_back
def take(
a,
indices,
axis=None,
out=None,
mode=None,
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
):
return ivy.take_along_axis(a, indices, axis, out=out)


@to_ivy_arrays_and_back
def zeros_like(a, dtype=None, shape=None):
if shape:
return ivy.zeros(shape, dtype=dtype)
return ivy.zeros_like(a, dtype=dtype)


@to_ivy_arrays_and_back
def negative(
x,
/,
):
return ivy.negative(x)
2 changes: 1 addition & 1 deletion ivy/functional/ivy/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3619,7 +3619,7 @@ def negative(
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Returns a new array with the positive value of each element in ``x``.
"""Returns a new array with the negative value of each element in ``x``.

Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3265,3 +3265,191 @@ def test_jax_numpy_eye(
k=k,
dtype=dtypes[0],
)


# asarray
@handle_frontend_test(
fn_tree="jax.numpy.asarray",
dtype_and_a=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=1,
min_num_dims=0,
max_num_dims=5,
min_dim_size=1,
max_dim_size=5,
),
)
def test_jax_numpy_asarray(
dtype_and_a,
as_variable,
num_positional_args,
native_array,
frontend,
fn_tree,
on_device,
):
dtype, a = dtype_and_a
helpers.test_frontend_function(
input_dtypes=dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
a=a,
dtype=dtype[0],
)


# take
@handle_frontend_test(
fn_tree="jax.numpy.take",
dtype_indices_axis=helpers.array_indices_axis(
array_dtypes=helpers.get_dtypes("numeric"),
indices_dtypes=helpers.get_dtypes("integer"),
min_num_dims=1,
max_num_dims=5,
min_dim_size=1,
max_dim_size=10,
indices_same_dims=True,
),
)
def test_jax_numpy_take(
*,
dtype_indices_axis,
as_variable,
with_out,
num_positional_args,
native_array,
on_device,
fn_tree,
frontend,
):
input_dtypes, value, indices, axis, _ = dtype_indices_axis
helpers.test_frontend_function(
input_dtypes=input_dtypes,
as_variable_flags=as_variable,
with_out=with_out,
num_positional_args=3,
native_array_flags=native_array,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
a=value,
indices=indices,
axis=axis,
)


# zeros_like
@handle_frontend_test(
fn_tree="jax.numpy.zeros_like",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
),
shape=helpers.get_shape(
allow_none=True,
min_num_dims=1,
max_num_dims=5,
min_dim_size=1,
max_dim_size=10,
),
dtype=helpers.get_dtypes("valid", full=False),
)
def test_numpy_zeros_like(
dtype_and_x,
dtype,
shape,
as_variable,
num_positional_args,
native_array,
frontend,
fn_tree,
on_device,
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
a=x[0],
dtype=dtype[0],
shape=shape,
)


# stack
@handle_frontend_test(
fn_tree="jax.numpy.stack",
dtype_values_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("float"),
num_arrays=st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays"),
shape=helpers.get_shape(min_num_dims=1),
shared_dtype=True,
valid_axis=True,
allow_neg_axes=True,
force_int_axis=True,
),
dtype=helpers.get_dtypes("valid", full=False),
)
def test_jax_numpy_stack(
dtype_values_axis,
dtype,
as_variable,
with_out,
num_positional_args,
native_array,
on_device,
fn_tree,
frontend,
):
input_dtype, values, axis = dtype_values_axis
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=with_out,
num_positional_args=num_positional_args,
native_array_flags=native_array,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
arrays=values,
axis=axis,
)


@handle_frontend_test(
fn_tree="jax.numpy.negative",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1
),
)
def test_jax_numpy_negative(
dtype_and_x,
as_variable,
with_out,
num_positional_args,
native_array,
frontend,
fn_tree,
on_device,
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
x=x[0],
)