diff --git a/ivy/functional/frontends/jax/array.py b/ivy/functional/frontends/jax/array.py index 5a0c5826d23e7..5dcf3c67e2954 100644 --- a/ivy/functional/frontends/jax/array.py +++ b/ivy/functional/frontends/jax/array.py @@ -45,6 +45,10 @@ def at(self): def T(self): return self.ivy_array.T + @property + def ndim(self): + return self.ivy_array.ndim + # Instance Methods # # ---------------- # @@ -146,6 +150,16 @@ def any(self, *, axis=None, out=None, keepdims=False, where=None): self._ivy_array, axis=axis, keepdims=keepdims, out=out, where=where ) + def reshape(self, *args, order="C"): + if not isinstance(args[0], int): + if len(args) > 1: + raise TypeError( + "Shapes must be 1D sequences of concrete values of integer type," + f" got {args}." + ) + args = args[0] + return jax_frontend.numpy.reshape(self, tuple(args), order) + def __add__(self, other): return jax_frontend.numpy.add(self, other) @@ -264,8 +278,7 @@ def __setitem__(self, idx, val): ) def __iter__(self): - ndim = len(self.shape) - if ndim == 0: + if self.ndim == 0: raise TypeError("iteration over a 0-d Array not supported") for i in range(self.shape[0]): yield self[i] diff --git a/ivy/functional/frontends/jax/numpy/indexing.py b/ivy/functional/frontends/jax/numpy/indexing.py index 492f66c112a34..2471c98f36994 100644 --- a/ivy/functional/frontends/jax/numpy/indexing.py +++ b/ivy/functional/frontends/jax/numpy/indexing.py @@ -1,11 +1,14 @@ # global import inspect +import abc # local import ivy from ivy.functional.frontends.jax.func_wrapper import ( to_ivy_arrays_and_back, ) +from .creation import linspace, arange, array +from .manipulations import transpose, concatenate, expand_dims @to_ivy_arrays_and_back @@ -96,3 +99,101 @@ def indices(dimensions, dtype=int, sparse=False): else: grid = ivy.meshgrid(*[ivy.arange(dim) for dim in dimensions], indexing="ij") return ivy.stack(grid, axis=0).astype(dtype) + + +def _make_1d_grid_from_slice(s): + step = 1 if s.step is None else s.step + start = 0 if s.start is None else s.start + if s.step is not None and ivy.is_complex_dtype(s.step): + newobj = linspace(start, s.stop, int(abs(step))) + else: + newobj = arange(start, s.stop, step) + return newobj + + +class _AxisConcat(abc.ABC): + axis: int + ndmin: int + trans1d: int + + def __getitem__(self, key): + key_tup = key if isinstance(key, tuple) else (key,) + + params = [self.axis, self.ndmin, self.trans1d, -1] + + directive = key_tup[0] + if isinstance(directive, str): + key_tup = key_tup[1:] + # check two special cases: matrix directives + if directive == "r": + params[-1] = 0 + elif directive == "c": + params[-1] = 1 + else: + vec = directive.split(",") + k = len(vec) + if k < 4: + vec += params[k:] + else: + # ignore everything after the first three comma-separated ints + vec = vec[:3] + [params[-1]] + try: + params = list(map(int, vec)) + except ValueError as err: + raise ValueError( + f"could not understand directive {directive!r}" + ) from err + + axis, ndmin, trans1d, matrix = params + + output = [] + for item in key_tup: + if isinstance(item, slice): + newobj = _make_1d_grid_from_slice(item) + item_ndim = 0 + elif isinstance(item, str): + raise ValueError("string directive must be placed at the beginning") + else: + newobj = array(item, copy=False) + item_ndim = newobj.ndim + + newobj = array(newobj, copy=False, ndmin=ndmin) + + if trans1d != -1 and ndmin - item_ndim > 0: + shape_obj = tuple(range(ndmin)) + # Calculate number of left shifts, with overflow protection by mod + num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin + shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts]) + + newobj = transpose(newobj, shape_obj) + + output.append(newobj) + + res = concatenate(tuple(output), axis=axis) + + if matrix != -1 and res.ndim == 1: + # insert 2nd dim at axis 0 or 1 + res = expand_dims(res, matrix) + + return res + + def __len__(self) -> int: + return 0 + + +class RClass(_AxisConcat): + axis = 0 + ndmin = 1 + trans1d = -1 + + +r_ = RClass() + + +class CClass(_AxisConcat): + axis = -1 + ndmin = 2 + trans1d = 0 + + +c_ = CClass() diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py index 27c5d0b6fb153..abe04e7a68884 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py @@ -8,6 +8,9 @@ from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( _get_castable_dtype, ) +from ivy_tests.test_ivy.test_frontends.test_jax.test_numpy.test_manipulations import ( + _get_input_and_reshape, +) CLASS_TREE = "ivy.functional.frontends.jax.numpy.ndarray" @@ -55,6 +58,24 @@ def test_jax_array_dtype( assert x.dtype == dtype[0] +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ), +) +def test_jax_array_ndim( + dtype_x, + backend_fw, +): + dtype, data = dtype_x + with update_backend(backend_fw) as ivy_backend: + jax_frontend = ivy_backend.utils.dynamic_import.import_module( + "ivy.functional.frontends.jax" + ) + x = jax_frontend.Array(data[0]) + assert x.ndim == data[0].ndim + + @given( dtype_x_shape=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False), @@ -2341,6 +2362,50 @@ def test_jax_array_searchsorted( ) + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="jax.numpy.array", + method_name="reshape", + dtype_and_x_shape=_get_input_and_reshape(), + order=st.sampled_from(["C", "F"]), + input=st.booleans(), +) +def test_jax_array_reshape( + dtype_and_x_shape, + order, + input, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x, shape = dtype_and_x_shape + if input: + method_flags.num_positional_args = len(shape) + kwargs = {f"{i}": shape[i] for i in range(len(shape))} + else: + kwargs = {"shape": shape} + method_flags.num_positional_args = 1 + kwargs["order"] = order + helpers.test_frontend_method( + backend_to_test=backend_fw, + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np=kwargs, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + ) + + # repeat @st.composite def _repeat_helper(draw): @@ -2435,3 +2500,4 @@ def test_jax_repeat( method_flags=method_flags, on_device=on_device, ) + diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py index a92c8209e85a3..4f1595d2088a8 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py @@ -1,12 +1,17 @@ # global from hypothesis import strategies as st, assume import numpy as np -from jax.numpy import tril, triu +from jax.numpy import tril, triu, r_, c_ # local import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.helpers import handle_frontend_test, update_backend +from ...test_numpy.test_indexing_routines.test_inserting_data_into_arrays import ( + _helper_r_, + _helper_c_, +) +import ivy.functional.frontends.jax.numpy as jnp_frontend # diagonal @@ -459,3 +464,20 @@ def test_jax_numpy_indices( dtype=dtype[0], sparse=sparse, ) + + +@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_r_()) # dummy fn_tree +def test_jax_numpy_r_(inputs, backend_fw): + inputs, *_ = inputs + ret_gt = r_.__getitem__(tuple(inputs)) + with update_backend(backend_fw): + ret = jnp_frontend.r_.__getitem__(tuple(inputs)) + assert np.allclose(ret.ivy_array, ret_gt) + + +@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_c_()) # dummy fn_tree +def test_jax_numpy_c_(inputs, backend_fw): + ret_gt = c_.__getitem__(tuple(inputs)) + with update_backend(backend_fw): + ret = jnp_frontend.c_.__getitem__(tuple(inputs)) + assert np.allclose(ret.ivy_array, ret_gt) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py index be9465442fcf2..6f97243e26bee 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py @@ -5,7 +5,7 @@ # local import ivy_tests.test_ivy.helpers as helpers import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.helpers import handle_frontend_test, update_backend import ivy.functional.frontends.numpy as np_frontend @@ -148,10 +148,11 @@ def test_numpy_fill_diagonal( @handle_frontend_test(fn_tree="numpy.add", inputs=_helper_r_()) # dummy fn_tree -def test_numpy_r_(inputs): +def test_numpy_r_(inputs, backend_fw): inputs, elems_in_last_dim, dim = inputs ret_gt = np.r_.__getitem__(tuple(inputs)) - ret = np_frontend.r_.__getitem__(tuple(inputs)) + with update_backend(backend_fw): + ret = np_frontend.r_.__getitem__(tuple(inputs)) if isinstance(inputs[0], str) and inputs[0] in ["r", "c"]: ret = ret._data else: @@ -160,9 +161,10 @@ def test_numpy_r_(inputs): @handle_frontend_test(fn_tree="numpy.add", inputs=_helper_c_()) # dummy fn_tree -def test_numpy_c_(inputs): +def test_numpy_c_(inputs, backend_fw): ret_gt = np.c_.__getitem__(tuple(inputs)) - ret = np_frontend.c_.__getitem__(tuple(inputs)) + with update_backend(backend_fw): + ret = np_frontend.c_.__getitem__(tuple(inputs)) if isinstance(inputs[0], str) and inputs[0] in ["r", "c"]: ret = ret._data else: