From 140b84cb377b9630909e775090a19c2b7dca3c4a Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Sun, 27 Aug 2023 09:14:14 +0530 Subject: [PATCH 01/55] Fix(install): Fix apple silicon install (#22635) --- install_dependencies.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/install_dependencies.sh b/install_dependencies.sh index 69c234b644ddb..40115a8571594 100755 --- a/install_dependencies.sh +++ b/install_dependencies.sh @@ -1,7 +1,7 @@ pip install -r requirements/requirements.txt if [[ $(arch) == 'arm64' ]]; then - pip install -r requirements/optional_m1_1.txt - pip install -r requirements/optional_m1_2.txt + pip install -r requirements/optional_apple_silicon_1.txt + pip install -r requirements/optional_apple_silicon_2.txt else pip install -r requirements/optional.txt fi \ No newline at end of file From 75557f1fdc217078368686a9638cde60d2c67eed Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Sun, 27 Aug 2023 09:17:20 +0530 Subject: [PATCH 02/55] update(docs): Update docs to new apple silicon requirements --- docs/overview/contributing/setting_up.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/overview/contributing/setting_up.rst b/docs/overview/contributing/setting_up.rst index 5734960449454..15f922ae1c999 100644 --- a/docs/overview/contributing/setting_up.rst +++ b/docs/overview/contributing/setting_up.rst @@ -150,8 +150,8 @@ Using miniconda .. code-block:: none - pip install -r requirements/optional_m1_1.txt - pip install -r requirements/optional_m1_2.txt + pip install -r requirements/optional_apple_silicon_1.txt + pip install -r requirements/optional_apple_silicon_2.txt Using venv ********** From 301fda43ce821121402e4ee464104a6795fb6ad8 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Sun, 27 Aug 2023 08:06:34 +0000 Subject: [PATCH 03/55] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../frontends/jax/numpy/indexing.py | 210 +++++++++--------- ivy/functional/frontends/sklearn/base.py | 2 - .../frontends/sklearn/datasets/__init__.py | 2 +- .../sklearn/datasets/_samples_generator.py | 38 +++- .../test_frontends/test_sklearn/conftest.py | 2 +- 5 files changed, 136 insertions(+), 118 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/indexing.py b/ivy/functional/frontends/jax/numpy/indexing.py index 40422234dcbf1..5bda70cfc5d36 100644 --- a/ivy/functional/frontends/jax/numpy/indexing.py +++ b/ivy/functional/frontends/jax/numpy/indexing.py @@ -11,111 +11,6 @@ from .manipulations import transpose, concatenate, expand_dims -@to_ivy_arrays_and_back -def diagonal(a, offset=0, axis1=0, axis2=1): - return ivy.diagonal(a, offset=offset, axis1=axis1, axis2=axis2) - - -@to_ivy_arrays_and_back -def diag(v, k=0): - return ivy.diag(v, k=k) - - -@to_ivy_arrays_and_back -def diag_indices(n, ndim=2): - idx = ivy.arange(n, dtype=int) - return (idx,) * ndim - - -# take_along_axis -@to_ivy_arrays_and_back -def take_along_axis(arr, indices, axis, mode="fill"): - return ivy.take_along_axis(arr, indices, axis, mode=mode) - - -@to_ivy_arrays_and_back -def tril_indices(n, k=0, m=None): - return ivy.tril_indices(n, m, k) - - -@to_ivy_arrays_and_back -def triu_indices(n, k=0, m=None): - return ivy.triu_indices(n, m, k) - - -@to_ivy_arrays_and_back -def triu_indices_from(arr, k=0): - return ivy.triu_indices(arr.shape[-2], arr.shape[-1], k) - - -@to_ivy_arrays_and_back 
-def tril_indices_from(arr, k=0): - return ivy.tril_indices(arr.shape[-2], arr.shape[-1], k) - - -# unravel_index -@to_ivy_arrays_and_back -def unravel_index(indices, shape): - ret = [x.astype(indices.dtype) for x in ivy.unravel_index(indices, shape)] - return tuple(ret) - - -@to_ivy_arrays_and_back -def mask_indices(n, mask_func, k=0): - mask_func_obj = inspect.unwrap(mask_func) - mask_func_name = mask_func_obj.__name__ - try: - ivy_mask_func_obj = getattr(ivy.functional.frontends.jax.numpy, mask_func_name) - a = ivy.ones((n, n)) - mask = ivy_mask_func_obj(a, k=k) - indices = ivy.argwhere(mask.ivy_array) - return indices[:, 0], indices[:, 1] - except AttributeError as e: - print(f"Attribute error: {e}") - - -@to_ivy_arrays_and_back -def diag_indices_from(arr): - print(arr) - n = arr.shape[0] - ndim = ivy.get_num_dims(arr) - if not all(arr.shape[i] == n for i in range(ndim)): - raise ValueError("All dimensions of input must be of equal length") - idx = ivy.arange(n, dtype=int) - return (idx,) * ndim - - -@to_ivy_arrays_and_back -def indices(dimensions, dtype=int, sparse=False): - if sparse: - return tuple( - ivy.arange(dim) - .expand_dims( - axis=[j for j in range(len(dimensions)) if i != j], - ) - .astype(dtype) - for i, dim in enumerate(dimensions) - ) - else: - grid = ivy.meshgrid(*[ivy.arange(dim) for dim in dimensions], indexing="ij") - return ivy.stack(grid, axis=0).astype(dtype) - - -@to_ivy_arrays_and_back -def choose(arr, choices, out=None, mode="raise"): - return ivy.choose(arr, choices, out=out, mode=mode) - - -def _make_1d_grid_from_slice(s): - step = 1 if s.step is None else s.step - start = 0 if s.start is None else s.start - if s.step is not None and ivy.is_complex_dtype(s.step): - newobj = linspace(start, s.stop, int(abs(step))) - else: - newobj = arange(start, s.stop, step) - return newobj - - class _AxisConcat(abc.ABC): axis: int ndmin: int @@ -202,6 +97,16 @@ class CClass(_AxisConcat): # --------------- # +def _make_1d_grid_from_slice(s): + step = 1 if s.step is None else s.step + start = 0 if s.start is None else s.start + if s.step is not None and ivy.is_complex_dtype(s.step): + newobj = linspace(start, s.stop, int(abs(step))) + else: + newobj = arange(start, s.stop, step) + return newobj + + def _make_1d_grid_from_slice(s): step = 1 if s.step is None else s.step start = 0 if s.start is None else s.start @@ -216,6 +121,16 @@ def _make_1d_grid_from_slice(s): # ------------ # +@to_ivy_arrays_and_back +def choose(arr, choices, out=None, mode="raise"): + return ivy.choose(arr, choices, out=out, mode=mode) + + +@to_ivy_arrays_and_back +def diag(v, k=0): + return ivy.diag(v, k=k) + + @to_ivy_arrays_and_back def diag(v, k=0): return ivy.diag(v, k=k) @@ -227,6 +142,12 @@ def diag_indices(n, ndim=2): return (idx,) * ndim +@to_ivy_arrays_and_back +def diag_indices(n, ndim=2): + idx = ivy.arange(n, dtype=int) + return (idx,) * ndim + + @to_ivy_arrays_and_back def diag_indices_from(arr): print(arr) @@ -238,6 +159,22 @@ def diag_indices_from(arr): return (idx,) * ndim +@to_ivy_arrays_and_back +def diag_indices_from(arr): + print(arr) + n = arr.shape[0] + ndim = ivy.get_num_dims(arr) + if not all(arr.shape[i] == n for i in range(ndim)): + raise ValueError("All dimensions of input must be of equal length") + idx = ivy.arange(n, dtype=int) + return (idx,) * ndim + + +@to_ivy_arrays_and_back +def diagonal(a, offset=0, axis1=0, axis2=1): + return ivy.diagonal(a, offset=offset, axis1=axis1, axis2=axis2) + + @to_ivy_arrays_and_back def diagonal(a, offset=0, axis1=0, axis2=1): return 
ivy.diagonal(a, offset=offset, axis1=axis1, axis2=axis2) @@ -259,6 +196,36 @@ def indices(dimensions, dtype=int, sparse=False): return ivy.stack(grid, axis=0).astype(dtype) +@to_ivy_arrays_and_back +def indices(dimensions, dtype=int, sparse=False): + if sparse: + return tuple( + ivy.arange(dim) + .expand_dims( + axis=[j for j in range(len(dimensions)) if i != j], + ) + .astype(dtype) + for i, dim in enumerate(dimensions) + ) + else: + grid = ivy.meshgrid(*[ivy.arange(dim) for dim in dimensions], indexing="ij") + return ivy.stack(grid, axis=0).astype(dtype) + + +@to_ivy_arrays_and_back +def mask_indices(n, mask_func, k=0): + mask_func_obj = inspect.unwrap(mask_func) + mask_func_name = mask_func_obj.__name__ + try: + ivy_mask_func_obj = getattr(ivy.functional.frontends.jax.numpy, mask_func_name) + a = ivy.ones((n, n)) + mask = ivy_mask_func_obj(a, k=k) + indices = ivy.argwhere(mask.ivy_array) + return indices[:, 0], indices[:, 1] + except AttributeError as e: + print(f"Attribute error: {e}") + + @to_ivy_arrays_and_back def mask_indices(n, mask_func, k=0): mask_func_obj = inspect.unwrap(mask_func) @@ -279,6 +246,17 @@ def take_along_axis(arr, indices, axis, mode="fill"): return ivy.take_along_axis(arr, indices, axis, mode=mode) +# take_along_axis +@to_ivy_arrays_and_back +def take_along_axis(arr, indices, axis, mode="fill"): + return ivy.take_along_axis(arr, indices, axis, mode=mode) + + +@to_ivy_arrays_and_back +def tril_indices(n, k=0, m=None): + return ivy.tril_indices(n, m, k) + + @to_ivy_arrays_and_back def tril_indices(n, k=0, m=None): return ivy.tril_indices(n, m, k) @@ -289,6 +267,16 @@ def tril_indices_from(arr, k=0): return ivy.tril_indices(arr.shape[-2], arr.shape[-1], k) +@to_ivy_arrays_and_back +def tril_indices_from(arr, k=0): + return ivy.tril_indices(arr.shape[-2], arr.shape[-1], k) + + +@to_ivy_arrays_and_back +def triu_indices(n, k=0, m=None): + return ivy.triu_indices(n, m, k) + + @to_ivy_arrays_and_back def triu_indices(n, k=0, m=None): return ivy.triu_indices(n, m, k) @@ -299,6 +287,18 @@ def triu_indices_from(arr, k=0): return ivy.triu_indices(arr.shape[-2], arr.shape[-1], k) +@to_ivy_arrays_and_back +def triu_indices_from(arr, k=0): + return ivy.triu_indices(arr.shape[-2], arr.shape[-1], k) + + +# unravel_index +@to_ivy_arrays_and_back +def unravel_index(indices, shape): + ret = [x.astype(indices.dtype) for x in ivy.unravel_index(indices, shape)] + return tuple(ret) + + # unravel_index @to_ivy_arrays_and_back def unravel_index(indices, shape): diff --git a/ivy/functional/frontends/sklearn/base.py b/ivy/functional/frontends/sklearn/base.py index 71f55b61441f2..2e8efd6078ee3 100644 --- a/ivy/functional/frontends/sklearn/base.py +++ b/ivy/functional/frontends/sklearn/base.py @@ -7,7 +7,6 @@ def set_params(self, **params): class ClassifierMixin: - def score(self, X, y, sample_weight=None): raise NotImplementedError @@ -19,7 +18,6 @@ def predict(self, X): class RegressorMixin: - def score(self, X, y, sample_weight=None): raise NotImplementedError diff --git a/ivy/functional/frontends/sklearn/datasets/__init__.py b/ivy/functional/frontends/sklearn/datasets/__init__.py index d8a44bb8532d3..be9aa4cb1e0b3 100644 --- a/ivy/functional/frontends/sklearn/datasets/__init__.py +++ b/ivy/functional/frontends/sklearn/datasets/__init__.py @@ -1,2 +1,2 @@ from . 
import _samples_generator -from ._samples_generator import * \ No newline at end of file +from ._samples_generator import * diff --git a/ivy/functional/frontends/sklearn/datasets/_samples_generator.py b/ivy/functional/frontends/sklearn/datasets/_samples_generator.py index 8bb776a5c0b91..865f097588ef1 100644 --- a/ivy/functional/frontends/sklearn/datasets/_samples_generator.py +++ b/ivy/functional/frontends/sklearn/datasets/_samples_generator.py @@ -2,7 +2,9 @@ import numbers -def make_circles(n_samples=100, *, shuffle=True, noise=None, random_state=None, factor=0.8): +def make_circles( + n_samples=100, *, shuffle=True, noise=None, random_state=None, factor=0.8 +): # numbers.Integral also includes bool if isinstance(n_samples, numbers.Integral): n_samples_out = n_samples // 2 @@ -10,12 +12,30 @@ def make_circles(n_samples=100, *, shuffle=True, noise=None, random_state=None, elif isinstance(n_samples, tuple): n_samples_out, n_samples_in = n_samples - outer_circ_x = ivy.cos(ivy.linspace(0, 2 * ivy.pi, num=n_samples_out, endpoint=False)) - outer_circ_y = ivy.sin(ivy.linspace(0, 2 * ivy.pi, num=n_samples_out, endpoint=False)) - inner_circ_x = ivy.cos(ivy.linspace(0, 2 * ivy.pi, num=n_samples_in, endpoint=False)) * factor - inner_circ_y = ivy.sin(ivy.linspace(0, 2 * ivy.pi, num=n_samples_in, endpoint=False)) * factor - X = ivy.concat([ivy.stack([outer_circ_x, outer_circ_y], axis=1), - ivy.stack([inner_circ_x, inner_circ_y], axis=1)], axis=0) - y = ivy.concat([ivy.zeros(n_samples_out, dtype=ivy.int32), - ivy.ones(n_samples_in, dtype=ivy.int32)], axis=0) + outer_circ_x = ivy.cos( + ivy.linspace(0, 2 * ivy.pi, num=n_samples_out, endpoint=False) + ) + outer_circ_y = ivy.sin( + ivy.linspace(0, 2 * ivy.pi, num=n_samples_out, endpoint=False) + ) + inner_circ_x = ( + ivy.cos(ivy.linspace(0, 2 * ivy.pi, num=n_samples_in, endpoint=False)) * factor + ) + inner_circ_y = ( + ivy.sin(ivy.linspace(0, 2 * ivy.pi, num=n_samples_in, endpoint=False)) * factor + ) + X = ivy.concat( + [ + ivy.stack([outer_circ_x, outer_circ_y], axis=1), + ivy.stack([inner_circ_x, inner_circ_y], axis=1), + ], + axis=0, + ) + y = ivy.concat( + [ + ivy.zeros(n_samples_out, dtype=ivy.int32), + ivy.ones(n_samples_in, dtype=ivy.int32), + ], + axis=0, + ) return X, y diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/conftest.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/conftest.py index 2e8a2ccccd45c..7bc3febeaa816 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/conftest.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/conftest.py @@ -3,4 +3,4 @@ @pytest.fixture(scope="session") def frontend(): - return "sklearn" \ No newline at end of file + return "sklearn" From cb60510fad141f394d4d451875442a1e6e67fb52 Mon Sep 17 00:00:00 2001 From: RickSanchezStoic <57310695+RickSanchezStoic@users.noreply.github.com> Date: Sun, 27 Aug 2023 14:36:17 +0530 Subject: [PATCH 04/55] Stateful changes 2 (#22461) Co-authored-by: Rishabh Kumar --- ivy/data_classes/container/base.py | 35 ++++++++++++++---------------- ivy/stateful/module.py | 6 +---- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/ivy/data_classes/container/base.py b/ivy/data_classes/container/base.py index 986bc18c420b7..2d86aa4e9954e 100644 --- a/ivy/data_classes/container/base.py +++ b/ivy/data_classes/container/base.py @@ -69,6 +69,7 @@ def __init__( types_to_iteratively_nest=None, alphabetical_keys=True, dynamic_backend=None, + build_callable=False, **kwargs, ): """ @@ -113,6 +114,9 @@ def __init__( rebuild_child_containers Whether to 
rebuild container found in dict_in with these constructor params. Default is ``False``, in which case the original container are kept as are. + build_callable + Whether to treat functions encountered at leaf nodes as further instructions + to build the container types_to_iteratively_nest The data types to nest iteratively in the dict structure, each type must be iterable. Default is ``None``. @@ -156,6 +160,7 @@ def __init__( default_key_color=default_key_color, keyword_color_dict=keyword_color_dict, rebuild_child_containers=rebuild_child_containers, + build_callable=build_callable, types_to_iteratively_nest=types_to_iteratively_nest, alphabetical_keys=alphabetical_keys, ) @@ -565,7 +570,15 @@ def cont_diff( all_keys_present = sum(keys_present) == num_containers if all_keys_present: res = ivy.Container.cont_diff( - *[cont[key] for cont in containers], + *[ + ( + cont[key]() + if cont.cont_config["build_callable"] + and callable(cont[key]) + else cont[key] + ) + for cont in containers + ], mode=mode, diff_keys=diff_keys, detect_key_diffs=detect_key_diffs, @@ -809,7 +822,6 @@ def cont_identical( to_apply=True, partial=False, key_chain="", - build_callable=False, assert_and_assign=False, ): """ @@ -841,9 +853,6 @@ def cont_identical( Default is ``False``. key_chain Chain of keys for this dict entry (Default value = '') - build_callable - if true, the leaf nodes which are callables are assumed to be called to - build further nested layers assert_and_assign if true, then the container being compared with is updated with the value in the container being compared to given that the strucutres are congruent @@ -864,9 +873,8 @@ def cont_identical( for key in keys: if not min([key in cont for cont in containers]): return False - if build_callable: - # call the callable if encountered - for cont in containers: + for cont in containers: + if cont.cont_config["build_callable"]: cont[key] = cont[key]() if callable(cont[key]) else cont[key] values = [cont[key] for cont in containers] value_0 = values[0] @@ -905,7 +913,6 @@ def cont_identical( to_apply, partial, this_key_chain, - build_callable=build_callable, assert_and_assign=assert_and_assign, ) if not ret: @@ -983,7 +990,6 @@ def cont_identical_structure( to_apply=True, partial=False, key_chain="", - build_callable=False, assert_and_assign=False, ): """ @@ -1010,9 +1016,6 @@ def cont_identical_structure( Default is ``False``. key_chain Chain of keys for this dict entry (Default value = '') - build_callable - if true, the leaf nodes which are callables are assumed to be called to - build further nested layers assert_and_assign if true, then the container being compared with is updated with the value in the container being compared to given that the strucutres are congruent @@ -1030,7 +1033,6 @@ def cont_identical_structure( to_apply, partial, key_chain, - build_callable=build_callable, assert_and_assign=assert_and_assign, ) @@ -1042,7 +1044,6 @@ def cont_assert_identical_structure( key_chains=None, to_apply=True, partial=False, - build_callable=False, assert_and_assign=False, ): """ @@ -1067,9 +1068,6 @@ def cont_assert_identical_structure( partial Whether to also check for partially complete sub-containers. Default is ``False``. 
- build_callable - if true, the leaf nodes which are callables are assumed to be called to - build further nested layers assert_and_assign if true, then the container being compared with is updated with the value in the container being compared to given that the strucutres are congruent @@ -1082,7 +1080,6 @@ def cont_assert_identical_structure( key_chains, to_apply, partial, - build_callable=build_callable, assert_and_assign=assert_and_assign, ), "Containers did not have identical structure:\n\n{}".format( diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index c5408c37f25e0..1ad57a745c022 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -5,10 +5,6 @@ import os import abc import copy -import subprocess -import sys - -subprocess.check_call([sys.executable, "-m", "pip", "install", "dill"]) import dill from typing import Optional, Tuple, Dict @@ -743,6 +739,7 @@ def build( ), dynamic_backend=dynamic_backend, ) + created_n_found.cont_config["build_callable"] = True if ivy.exists(v_from_constructor): if self._with_partial_v: if v_from_constructor: @@ -757,7 +754,6 @@ def build( ivy.Container.cont_assert_identical_structure( [created_n_found, v_from_constructor], - build_callable=True, assert_and_assign=True, ) From ff312df4cafb0301a64e7586ce288f602e6c5ec1 Mon Sep 17 00:00:00 2001 From: Sarthak Tyagi <72240174+dark-developer15@users.noreply.github.com> Date: Sun, 27 Aug 2023 15:31:37 +0530 Subject: [PATCH 05/55] fix: typo in some fft function names Co-authored-by: Mahmoud Ashraf --- .../test_numpy/test_fft/test_discrete_fourier_transform.py | 6 +++--- ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py index 588413e661ba7..1e4557bf521d3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py @@ -16,7 +16,7 @@ fn_tree="numpy.fft.ifft", dtype_and_x=x_and_ifft(), ) -def test_numpy_iftt(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device): +def test_numpy_ifft(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device): input_dtype, x, dim, norm, n = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, @@ -39,7 +39,7 @@ def test_numpy_iftt(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_d available_dtypes=helpers.get_dtypes("float"), shape=(4,), array_api_dtypes=True ), ) -def test_numpy_ifttshift( +def test_numpy_ifftshift( dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device ): input_dtype, arr = dtype_and_x @@ -92,7 +92,7 @@ def test_numpy_fft( available_dtypes=helpers.get_dtypes("float"), shape=(4,), array_api_dtypes=True ), ) -def test_numpy_fttshift( +def test_numpy_fftshift( dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device ): input_dtype, arr = dtype_and_x diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py index 98a1b26d91a73..8cc081e98f9c9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py @@ -57,7 +57,7 @@ def test_paddle_fft( force_int_axis=True, ), ) -def test_paddle_fttshift( +def test_paddle_fftshift( dtype_x_axis, frontend, 
test_flags, fn_tree, on_device, backend_fw ): input_dtype, x, axes = dtype_x_axis From eb59a762b8561e0df31f36d5301658fa97c9c52a Mon Sep 17 00:00:00 2001 From: Daniel4078 <45633544+Daniel4078@users.noreply.github.com> Date: Sun, 27 Aug 2023 18:23:27 +0800 Subject: [PATCH 06/55] Add testing for precise and non precise modes (#21781) --- .../test_ivy/helpers/function_testing.py | 56 +++++++++++++------ .../test_ivy/helpers/test_parameter_flags.py | 35 +++++++++++- ivy_tests/test_ivy/helpers/testing_helpers.py | 22 ++++++++ 3 files changed, 96 insertions(+), 17 deletions(-) diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index 4948234eff685..02315785cbf26 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -118,7 +118,7 @@ def test_function( test_flags FunctionTestFlags object that stores all testing flags, including: num_positional_args, with_out, instance_method, as_variable, - native_arrays, container, gradient + native_arrays, container, gradient, precision_mode fw current backend (framework). fn_name @@ -161,11 +161,13 @@ def test_function( >>> native_array_flags = False >>> container_flags = False >>> instance_method = False + >>> precision_mode = False >>> test_flags = FunctionTestFlags(num_positional_args, with_out, instance_method, as_variable, native_arrays, container_flags, + precision_mode, none) >>> fw = "torch" >>> fn_name = "abs" @@ -179,11 +181,13 @@ def test_function( >>> native_array_flags = [True, False] >>> container_flags = [False, False] >>> instance_method = False + >>> precision_mode = False >>> test_flags = FunctionTestFlags(num_positional_args, with_out, instance_method, as_variable, native_arrays, container_flags, + precision_mode, none) >>> fw = "numpy" >>> fn_name = "add" @@ -299,6 +303,7 @@ def test_function( target_fn, *args, test_compile=test_flags.test_compile, + precision_mode=test_flags.precision_mode, **kwargs, ) @@ -324,6 +329,7 @@ def test_function( ) = get_ret_and_flattened_np_array( backend_to_test, instance.__getattribute__(fn_name), + precision_mode=test_flags.precision_mode, *args, **kwargs, out=out, @@ -335,6 +341,7 @@ def test_function( ) = get_ret_and_flattened_np_array( backend_to_test, ivy_backend.__dict__[fn_name], + precision_mode=test_flags.precision_mode, *args, **kwargs, out=out, @@ -407,6 +414,7 @@ def test_function( gt_backend.__dict__[fn_name], *args, test_compile=test_flags.test_compile, + precision_mode=test_flags.precision_mode, **kwargs, ) assert gt_backend.nested_map( @@ -430,6 +438,7 @@ def test_function( gt_backend.__dict__[fn_name], *args, test_compile=test_flags.test_compile, + precision_mode=test_flags.precision_mode, **kwargs, out=out_from_gt, ) @@ -530,6 +539,10 @@ def test_frontend_function( ---------- input_dtypes data types of the input arguments in order. + test_flags + FunctionTestFlags object that stores all testing flags, including: + num_positional_args, with_out, instance_method, as_variable, + native_arrays, container, gradient, precision_mode all_aliases a list of strings containing all aliases for that function in the current frontend with their full namespaces. 
@@ -655,6 +668,7 @@ def test_frontend_function( create_frontend_array if test_flags.test_compile else None ), as_ivy_arrays=(not test_flags.generate_frontend_arrays), + precision_mode=test_flags.precision_mode, **kwargs_for_test, ) @@ -763,6 +777,7 @@ def test_frontend_function( ret_ = get_frontend_ret( frontend_fn=frontend_fn, backend=backend_to_test, + precision_mode=test_flags.precision_mode, test_compile=test_flags.test_compile, frontend_array_function=( create_frontend_array if test_flags.test_compile else None @@ -781,6 +796,7 @@ def test_frontend_function( ret_ = get_frontend_ret( frontend_fn=frontend_fn, backend=backend_to_test, + precision_mode=test_flags.precision_mode, test_compile=test_flags.test_compile, frontend_array_function=( create_frontend_array if test_flags.test_compile else None @@ -948,12 +964,13 @@ def _grad_fn(all_args): )(*args, **kwargs) return ivy_backend.nested_map(ret, ivy_backend.mean, include_derived=True) - _, grads = ivy_backend.execute_with_gradients( - _grad_fn, - [args, kwargs, 0], - xs_grad_idxs=xs_grad_idxs, - ret_grad_idxs=ret_grad_idxs, - ) + with ivy_backend.PreciseMode(test_flags.precision_mode): + _, grads = ivy_backend.execute_with_gradients( + _grad_fn, + [args, kwargs, 0], + xs_grad_idxs=xs_grad_idxs, + ret_grad_idxs=ret_grad_idxs, + ) grads_np_flat = flatten_and_to_np(backend=backend_to_test, ret=grads) with BackendHandler.update_backend(ground_truth_backend) as gt_backend: @@ -991,12 +1008,13 @@ def _gt_grad_fn(all_args): )(*args, **kwargs) return gt_backend.nested_map(ret, gt_backend.mean, include_derived=True) - _, grads_from_gt = gt_backend.execute_with_gradients( - _gt_grad_fn, - [args, kwargs, 1], - xs_grad_idxs=xs_grad_idxs, - ret_grad_idxs=ret_grad_idxs, - ) + with gt_backend.PreciseMode(test_flags.precision_mode): + _, grads_from_gt = gt_backend.execute_with_gradients( + _gt_grad_fn, + [args, kwargs, 1], + xs_grad_idxs=xs_grad_idxs, + ret_grad_idxs=ret_grad_idxs, + ) grads_np_from_gt_flat = flatten_and_to_np( backend=ground_truth_backend, ret=grads_from_gt ) @@ -1275,6 +1293,7 @@ def test_method( ins.__getattribute__(method_name), *args_method, test_compile=test_compile, + precision_mode=method_flags.precision_mode, **kwargs_method, ) if isinstance(ret, ivy_backend.Array): @@ -1328,6 +1347,7 @@ def test_method( ins_gt.__getattribute__(method_name), *args_gt_method, test_compile=test_compile, + precision_mode=method_flags.precision_mode, **kwargs_gt_method, ) assert gt_backend.nested_map( @@ -1657,6 +1677,7 @@ def test_frontend_method( ret, ret_np_flat = get_ret_and_flattened_np_array( backend_to_test, ins.__getattribute__(frontend_method_data.method_name), + precision_mode=method_flags.precision_mode, *args_method, test_compile=method_flags.test_compile, **kwargs_method, @@ -1953,7 +1974,7 @@ def flatten_frontend_to_np(*, backend: str, ret, frontend_array_fn=None): def get_ret_and_flattened_np_array( - backend_to_test: str, fn, *args, test_compile: bool = False, **kwargs + backend_to_test: str, fn, *args, test_compile=False, precision_mode=False, **kwargs ): """ Run func with args and kwargs. 
@@ -1964,7 +1985,8 @@ def get_ret_and_flattened_np_array( backend_to_test, fn, test_compile=test_compile, args=args, kwargs=kwargs ) with BackendHandler.update_backend(backend_to_test) as ivy_backend: - ret = fn(*args, **kwargs) + with ivy_backend.PreciseMode(precision_mode): + ret = fn(*args, **kwargs) def map_fn(x): if _is_frontend_array(x): @@ -1983,6 +2005,7 @@ def get_frontend_ret( *args, frontend_array_function=None, as_ivy_arrays=True, + precision_mode=False, test_compile: bool = False, **kwargs, ): @@ -1994,7 +2017,8 @@ def get_frontend_ret( args, kwargs = ivy_backend.nested_map( (args, kwargs), _frontend_array_to_ivy, include_derived={tuple: True} ) - ret = frontend_fn(*args, **kwargs) + with ivy_backend.PreciseMode(precision_mode): + ret = frontend_fn(*args, **kwargs) if test_compile and frontend_array_function is not None: if as_ivy_arrays: ret = ivy_backend.nested_map( diff --git a/ivy_tests/test_ivy/helpers/test_parameter_flags.py b/ivy_tests/test_ivy/helpers/test_parameter_flags.py index 86362f6a22fe2..be0395e883262 100644 --- a/ivy_tests/test_ivy/helpers/test_parameter_flags.py +++ b/ivy_tests/test_ivy/helpers/test_parameter_flags.py @@ -36,6 +36,7 @@ def _as_varaible_strategy(draw): BuiltWithOutStrategy = st.booleans() BuiltCompileStrategy = st.just(False) BuiltFrontendArrayStrategy = st.booleans() +BuiltPrecisionModeStrategy = st.booleans() flags_mapping = { @@ -47,6 +48,7 @@ def _as_varaible_strategy(draw): "with_out": "BuiltWithOutStrategy", "inplace": "BuiltInplace", "test_compile": "BuiltCompileStrategy", + "precision_mode": "BuiltPrecisionModeStrategy", } @@ -80,6 +82,7 @@ def __init__( container, test_gradients, test_compile, + precision_mode, ): self.ground_truth_backend = ground_truth_backend self.num_positional_args = num_positional_args @@ -90,6 +93,7 @@ def __init__( self.as_variable = as_variable self.test_gradients = test_gradients self.test_compile = test_compile + self.precision_mode = precision_mode def apply_flags(self, args_to_iterate, input_dtypes, offset, *, backend, on_device): ret = [] @@ -116,6 +120,7 @@ def __str__(self): f"as_variable={self.as_variable}. " f"test_gradients={self.test_gradients}. " f"test_compile={self.test_compile}. " + f"precision_mode={self.precision_mode}. " ) def __repr__(self): @@ -135,6 +140,7 @@ def function_flags( as_variable, native_arrays, container_flags, + precision_mode, ): return draw( st.builds( @@ -148,6 +154,7 @@ def function_flags( as_variable=as_variable, native_arrays=native_arrays, container=container_flags, + precision_mode=precision_mode, ) ) @@ -162,6 +169,7 @@ def __init__( native_arrays, test_compile, generate_frontend_arrays, + precision_mode, ): self.num_positional_args = num_positional_args self.with_out = with_out @@ -170,6 +178,7 @@ def __init__( self.as_variable = as_variable self.test_compile = test_compile self.generate_frontend_arrays = generate_frontend_arrays + self.precision_mode = precision_mode def apply_flags(self, args_to_iterate, input_dtypes, offset, *, backend, on_device): ret = [] @@ -192,6 +201,7 @@ def __str__(self): f"as_variable={self.as_variable}. " f"test_compile={self.test_compile}. " f"generate_frontend_arrays={self.generate_frontend_arrays}. " + f"precision_mode={self.precision_mode}. 
" ) def __repr__(self): @@ -209,6 +219,7 @@ def frontend_function_flags( native_arrays, test_compile, generate_frontend_arrays, + precision_mode, ): return draw( st.builds( @@ -220,6 +231,7 @@ def frontend_function_flags( native_arrays=native_arrays, test_compile=test_compile, generate_frontend_arrays=generate_frontend_arrays, + precision_mode=precision_mode, ) ) @@ -230,10 +242,12 @@ def __init__( num_positional_args, as_variable, native_arrays, + precision_mode, ): self.num_positional_args = num_positional_args self.native_arrays = native_arrays self.as_variable = as_variable + self.precision_mode = precision_mode def apply_flags(self, args_to_iterate, input_dtypes, offset, *, backend, on_device): ret = [] @@ -252,6 +266,7 @@ def __str__(self): f"num_positional_args={self.num_positional_args}. " f"native_arrays={self.native_arrays}. " f"as_variable={self.as_variable}. " + f"precision_mode={self.precision_mode}. " ) def __repr__(self): @@ -265,6 +280,7 @@ def init_method_flags( num_positional_args, as_variable, native_arrays, + precision_mode, ): return draw( st.builds( @@ -272,6 +288,7 @@ def init_method_flags( num_positional_args=num_positional_args, as_variable=as_variable, native_arrays=native_arrays, + precision_mode=precision_mode, ) ) @@ -283,11 +300,13 @@ def __init__( as_variable, native_arrays, container_flags, + precision_mode, ): self.num_positional_args = num_positional_args self.native_arrays = native_arrays self.as_variable = as_variable self.container = container_flags + self.precision_mode = precision_mode def apply_flags(self, args_to_iterate, input_dtypes, offset, *, backend, on_device): ret = [] @@ -309,6 +328,7 @@ def __str__(self): f"native_arrays={self.native_arrays}. " f"as_variable={self.as_variable}. " f"container_flags={self.container}. " + f"precision_mode={self.precision_mode}. " ) def __repr__(self): @@ -323,6 +343,7 @@ def method_flags( as_variable, native_arrays, container_flags, + precision_mode, ): return draw( st.builds( @@ -331,15 +352,24 @@ def method_flags( as_variable=as_variable, native_arrays=native_arrays, container_flags=container_flags, + precision_mode=precision_mode, ) ) class FrontendMethodTestFlags(TestFlags): - def __init__(self, num_positional_args, as_variable, native_arrays, test_compile): + def __init__( + self, + num_positional_args, + as_variable, + native_arrays, + precision_mode, + test_compile, + ): self.num_positional_args = num_positional_args self.native_arrays = native_arrays self.as_variable = as_variable + self.precision_mode = precision_mode self.test_compile = test_compile def apply_flags(self, args_to_iterate, input_dtypes, offset, *, backend, on_device): @@ -359,6 +389,7 @@ def __str__(self): f"num_positional_args={self.num_positional_args}. " f"native_arrays={self.native_arrays}. " f"as_variable={self.as_variable}. " + f"precision_mode={self.precision_mode}. " f"test_compile={self.test_compile}." 
) @@ -373,6 +404,7 @@ def frontend_method_flags( num_positional_args, as_variable, native_arrays, + precision_mode, test_compile, ): return draw( @@ -381,6 +413,7 @@ def frontend_method_flags( num_positional_args=num_positional_args, as_variable=as_variable, native_arrays=native_arrays, + precision_mode=precision_mode, test_compile=test_compile, ) ) diff --git a/ivy_tests/test_ivy/helpers/testing_helpers.py b/ivy_tests/test_ivy/helpers/testing_helpers.py index 6698e2de59d2f..670f6106ac880 100644 --- a/ivy_tests/test_ivy/helpers/testing_helpers.py +++ b/ivy_tests/test_ivy/helpers/testing_helpers.py @@ -24,6 +24,7 @@ BuiltInplaceStrategy, BuiltCompileStrategy, BuiltFrontendArrayStrategy, + BuiltPrecisionModeStrategy, ) from ivy_tests.test_ivy.helpers.structs import FrontendMethodData from ivy_tests.test_ivy.helpers.available_frameworks import available_frameworks @@ -38,6 +39,7 @@ "instance_method", "test_gradients", "test_compile", + "precision_mode", ) cmd_line_args_lists = ( "as_variable", @@ -286,6 +288,7 @@ def handle_test( test_with_out=BuiltWithOutStrategy, test_gradients=BuiltGradientStrategy, test_compile=BuiltCompileStrategy, + precision_mode=BuiltPrecisionModeStrategy, as_variable_flags=BuiltAsVariableStrategy, native_array_flags=BuiltNativeArrayStrategy, container_flags=BuiltContainerStrategy, @@ -323,6 +326,10 @@ def handle_test( A search strategy that generates a boolean to graph compile and test the function + precision_mode + A search strategy that generates a boolean to switch between two different + precision modes supported by numpy and (torch, jax) and test the function + as_variable_flags A search strategy that generates a list of boolean flags for array inputs to be passed as a Variable array @@ -356,6 +363,7 @@ def handle_test( as_variable=as_variable_flags, native_arrays=native_array_flags, container_flags=container_flags, + precision_mode=precision_mode, ) def test_wrapper(test_fn): @@ -418,6 +426,7 @@ def handle_frontend_test( native_array_flags=BuiltNativeArrayStrategy, test_compile=BuiltCompileStrategy, generate_frontend_arrays=BuiltFrontendArrayStrategy, + precision_mode=BuiltPrecisionModeStrategy, **_given_kwargs, ): """ @@ -444,6 +453,10 @@ def handle_frontend_test( A search strategy that generates a boolean to test the function with an `out` parameter + precision_mode + A search strategy that generates a boolean to switch between two different + precision modes supported by numpy and (torch, jax) and test the function + as_variable_flags A search strategy that generates a list of boolean flags for array inputs to be passed as a Variable array @@ -479,6 +492,7 @@ def handle_frontend_test( native_arrays=native_array_flags, test_compile=test_compile, generate_frontend_arrays=generate_frontend_arrays, + precision_mode=precision_mode, ) def test_wrapper(test_fn): @@ -550,6 +564,7 @@ def handle_method( ground_truth_backend: str = "tensorflow", test_gradients=BuiltGradientStrategy, test_compile=BuiltCompileStrategy, + precision_mode=BuiltPrecisionModeStrategy, init_num_positional_args=None, init_native_arrays=BuiltNativeArrayStrategy, init_as_variable_flags=BuiltAsVariableStrategy, @@ -572,6 +587,7 @@ def handle_method( ground_truth_backend The framework to assert test results are equal to """ + # need to fill up the docstring is_method_tree_provided = method_tree is not None if is_method_tree_provided: method_tree = "ivy." 
+ method_tree @@ -580,6 +596,7 @@ def handle_method( "ground_truth_backend": st.just(ground_truth_backend), "test_gradients": test_gradients, "test_compile": test_compile, + "precision_mode": precision_mode, } if is_hypothesis_test and is_method_tree_provided: @@ -594,6 +611,7 @@ def handle_method( num_positional_args=init_num_positional_args, as_variable=init_as_variable_flags, native_arrays=init_native_arrays, + precision_mode=precision_mode, ) if method_num_positional_args is None: @@ -606,6 +624,7 @@ def handle_method( as_variable=method_as_variable_flags, native_arrays=method_native_arrays, container_flags=method_container_flags, + precision_mode=precision_mode, ) def test_wrapper(test_fn): @@ -666,6 +685,7 @@ def handle_frontend_method( init_native_arrays=BuiltNativeArrayStrategy, init_as_variable_flags=BuiltAsVariableStrategy, test_compile=BuiltCompileStrategy, + precision_mode=BuiltPrecisionModeStrategy, method_num_positional_args=None, method_native_arrays=BuiltNativeArrayStrategy, method_as_variable_flags=BuiltAsVariableStrategy, @@ -725,6 +745,7 @@ def test_wrapper(test_fn): as_variable=init_as_variable_flags, native_arrays=init_native_arrays, test_compile=test_compile, + precision_mode=precision_mode, ) method_flags = pf.frontend_method_flags( @@ -732,6 +753,7 @@ def test_wrapper(test_fn): as_variable=method_as_variable_flags, native_arrays=method_native_arrays, test_compile=test_compile, + precision_mode=precision_mode, ) ivy_init_modules = str(ivy_init_module) framework_init_modules = str(framework_init_module) From 8e4eb629db5c440d344da337aeb7a58058a65292 Mon Sep 17 00:00:00 2001 From: Bhushan Srivastava <59949692+he11owthere@users.noreply.github.com> Date: Sun, 27 Aug 2023 18:05:21 +0530 Subject: [PATCH 07/55] Added remainder_ method to the Paddle frontend --- .../frontends/paddle/tensor/tensor.py | 7 ++++ .../test_paddle/test_tensor/test_tensor.py | 39 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 65dc025fb6ecf..579f8d4bdc08e 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -234,6 +234,13 @@ def isinf(self, name=None): def square(self, name=None): return paddle_frontend.Tensor(ivy.square(self._ivy_array)) + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def remainder_(self, y, name=None): + self.ivy_array = paddle_frontend.Tensor( + ivy.remainder(self._ivy_array, _to_ivy_array(y)) + ).ivy_array + return self + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def cholesky(self, upper=False, name=None): return paddle_frontend.Tensor(ivy.cholesky(self._ivy_array, upper=upper)) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 726a8c053462b..4758934edb971 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -1033,6 +1033,45 @@ def test_paddle_tensor_square( ) +# remainder_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="remainder_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + ), +) +def test_paddle_tensor_remainder_( + dtype_and_x, + frontend_method_data, + 
init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "value": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "y": x[1], + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # cholesky @handle_frontend_method( class_tree=CLASS_TREE, From 570e1a93d319b308ed9ce4cb72103045cdc6ae11 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Sun, 27 Aug 2023 15:51:37 +0100 Subject: [PATCH 08/55] add CrossValidator base class and KFold class with essential method to be implemented --- .../sklearn/model_selection/__init__.py | 4 +++ .../sklearn/model_selection/_split.py | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 ivy/functional/frontends/sklearn/model_selection/__init__.py create mode 100644 ivy/functional/frontends/sklearn/model_selection/_split.py diff --git a/ivy/functional/frontends/sklearn/model_selection/__init__.py b/ivy/functional/frontends/sklearn/model_selection/__init__.py new file mode 100644 index 0000000000000..e35652aca2942 --- /dev/null +++ b/ivy/functional/frontends/sklearn/model_selection/__init__.py @@ -0,0 +1,4 @@ +from ._split import ( + BaseCrossValidator, + KFold, +) \ No newline at end of file diff --git a/ivy/functional/frontends/sklearn/model_selection/_split.py b/ivy/functional/frontends/sklearn/model_selection/_split.py new file mode 100644 index 0000000000000..d0948efa72634 --- /dev/null +++ b/ivy/functional/frontends/sklearn/model_selection/_split.py @@ -0,0 +1,31 @@ +class BaseCrossValidator: + def split(self, X, y=None, groups=None): + raise NotImplementedError + + def _iter_test_masks(self, X=None, y=None, groups=None): + raise NotImplementedError + + def _iter_test_indices(self, X=None, y=None, groups=None): + raise NotImplementedError + + +class KFold(BaseCrossValidator): + def __init__( + self, + n_splits=5, + *, + shuffle=False, + random_state=None, + ): + self.n_splits = n_splits + self.shuffle = shuffle + self.random_state = random_state + + def split(self, X, y=None, groups=None): + raise NotImplementedError + + def _iter_test_masks(self, X=None, y=None, groups=None): + raise NotImplementedError + + def _iter_test_indices(self, X=None, y=None, groups=None): + raise NotImplementedError From 610413d04faeb601f4b7953f6375bb89f379f73d Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Sun, 27 Aug 2023 21:06:52 +0100 Subject: [PATCH 09/55] feat(frontend): add `test_train_split` function with the test stratify not implemented yet --- .../sklearn/model_selection/__init__.py | 6 +-- .../sklearn/model_selection/_split.py | 50 +++++++++++++++++++ .../test_model_selection/__init__.py | 0 .../test_model_selection/test_split.py | 42 ++++++++++++++++ 4 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/__init__.py create mode 100644 ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py diff --git a/ivy/functional/frontends/sklearn/model_selection/__init__.py b/ivy/functional/frontends/sklearn/model_selection/__init__.py index e35652aca2942..e16bc6cd46684 100644 --- a/ivy/functional/frontends/sklearn/model_selection/__init__.py +++ 
b/ivy/functional/frontends/sklearn/model_selection/__init__.py @@ -1,4 +1,2 @@ -from ._split import ( - BaseCrossValidator, - KFold, -) \ No newline at end of file +from . import _split +from ._split import * diff --git a/ivy/functional/frontends/sklearn/model_selection/_split.py b/ivy/functional/frontends/sklearn/model_selection/_split.py index d0948efa72634..b331da8ab40b2 100644 --- a/ivy/functional/frontends/sklearn/model_selection/_split.py +++ b/ivy/functional/frontends/sklearn/model_selection/_split.py @@ -1,3 +1,6 @@ +import ivy +from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back + class BaseCrossValidator: def split(self, X, y=None, groups=None): raise NotImplementedError @@ -29,3 +32,50 @@ def _iter_test_masks(self, X=None, y=None, groups=None): def _iter_test_indices(self, X=None, y=None, groups=None): raise NotImplementedError + + +@to_ivy_arrays_and_back +def train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None): + # TODO: Make it concise + # TODO: implement stratify + if stratify is not None: + raise NotImplementedError + n_arrays = len(arrays) + if n_arrays == 0: + raise ValueError("At least one array required as input") + if test_size is None and train_size is None: + test_size = 0.25 + n_samples = arrays[0].shape[0] + test_size_type, train_size_type = type(test_size), type(train_size) + if "f" in str(test_size_type): + n_test = ivy.ceil(test_size * n_samples) + elif "i" in str(test_size_type): + n_test = float(test_size) + else: + n_test = 0 + + if "f" in str(train_size_type): + n_train = ivy.floor(train_size * n_samples) + elif "i" in str(train_size_type): + n_train = float(train_size) + else: + n_train = 0 + + if train_size is None: + n_train = n_samples - n_test + elif test_size is None: + n_test = n_samples - n_train + + n_train, n_test = int(n_train), int(n_test) + indices = ivy.arange(0, n_train + n_test) + if shuffle: + if random_state is not None: + ivy.seed(random_state) + indices = ivy.shuffle(indices) + train_indices = indices[:n_train] + test_indices = indices[n_train:] + output = [] + for array in arrays: + output.append(ivy.gather(array, train_indices, axis=0)) + output.append(ivy.gather(array, test_indices, axis=0)) + return tuple(output) diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/__init__.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py new file mode 100644 index 0000000000000..80778d6830621 --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py @@ -0,0 +1,42 @@ +from hypothesis import strategies as st + +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +@handle_frontend_test( + fn_tree="sklearn.model_selection.train_test_split", + arrays_and_dtypes=helpers.dtype_and_values( + num_arrays=helpers.ints(min_value=2, max_value=4), + shape=helpers.lists( + x=helpers.ints(min_value=2, max_value=5), + min_size=2, + max_size=3, + )), + shuffle=st.booleans(), +) +def test_sklearn_test_train_split( + arrays_and_dtypes, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, + shuffle, +): + dtypes, values = arrays_and_dtypes + kw = {} + for i, x_ in enumerate(values): + 
kw["x{}".format(i)] = x_ + test_flags.num_positional_args = len(values) + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + frontend=frontend, + on_device=on_device, + test_values=False, + **kw, + shuffle=shuffle, + ) From a88c50df69cad78f595bd73ee4bb6839d7e7d243 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Sun, 27 Aug 2023 21:19:25 +0100 Subject: [PATCH 10/55] refactor(frontend): simplify train_size and test_size derivation in test_train_split function --- .../sklearn/model_selection/_split.py | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/ivy/functional/frontends/sklearn/model_selection/_split.py b/ivy/functional/frontends/sklearn/model_selection/_split.py index b331da8ab40b2..9738f472a6f97 100644 --- a/ivy/functional/frontends/sklearn/model_selection/_split.py +++ b/ivy/functional/frontends/sklearn/model_selection/_split.py @@ -40,27 +40,15 @@ def train_test_split(*arrays, test_size=None, train_size=None, random_state=None # TODO: implement stratify if stratify is not None: raise NotImplementedError - n_arrays = len(arrays) - if n_arrays == 0: + if len(arrays) == 0: raise ValueError("At least one array required as input") if test_size is None and train_size is None: test_size = 0.25 n_samples = arrays[0].shape[0] - test_size_type, train_size_type = type(test_size), type(train_size) - if "f" in str(test_size_type): - n_test = ivy.ceil(test_size * n_samples) - elif "i" in str(test_size_type): - n_test = float(test_size) - else: - n_test = 0 - - if "f" in str(train_size_type): - n_train = ivy.floor(train_size * n_samples) - elif "i" in str(train_size_type): - n_train = float(train_size) - else: - n_train = 0 - + n_train = ivy.floor(train_size * n_samples) if isinstance(train_size, float) \ + else float(train_size) if isinstance(train_size, int) else None + n_test = ivy.ceil(test_size * n_samples) if isinstance(test_size, float) \ + else float(test_size) if isinstance(test_size, int) else None if train_size is None: n_train = n_samples - n_test elif test_size is None: From 9d5f27eb00a56dc3e2226cc20b7de3e3c858cd02 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Mon, 28 Aug 2023 02:59:27 +0530 Subject: [PATCH 11/55] Added Categorical to jax frontend (#22146) Co-authored-by: Saeed Ashraf --- ivy/functional/frontends/jax/random.py | 37 ++++++++++++- .../test_frontends/test_jax/test_random.py | 55 +++++++++++++++++++ 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/jax/random.py b/ivy/functional/frontends/jax/random.py index fec1047b83ec5..4bd88d7b02aa1 100644 --- a/ivy/functional/frontends/jax/random.py +++ b/ivy/functional/frontends/jax/random.py @@ -27,6 +27,9 @@ def _get_seed(key): def PRNGKey(seed): return ivy.array([0, seed % 4294967295 - (seed // 4294967295)], dtype=ivy.int64) +def _remove_axis(shape, axis): + return shape[:axis] + shape[axis + 1 :] + @handle_jax_dtype @to_ivy_arrays_and_back @@ -371,8 +374,6 @@ def uniform(key, shape=(), dtype=None, minval=0.0, maxval=1.0): low=minval, high=maxval, shape=shape, dtype=dtype, seed=ivy.to_scalar(key[1]) ) - -@handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( { @@ -383,9 +384,41 @@ def uniform(key, shape=(), dtype=None, minval=0.0, maxval=1.0): }, "jax", ) + +def categorical(key, logits, axis, shape=None): + _get_seed(key) + logits_arr = ivy.asarray(logits) + + if axis >= 0: + axis -= len(logits_arr.shape) 
+ batch_shape = tuple(_remove_axis(logits_arr.shape, axis)) + + if shape is None: + shape = batch_shape + else: + shape = tuple(shape) + if shape != batch_shape: + raise ValueError( ++ f"Shape {shape} is not compatible with reference shape {batch_shape}" + ) + + shape_prefix = shape[: len(shape) - len(batch_shape)] + logits_shape = list(shape[len(shape) - len(batch_shape) :]) + logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) + + gumbel_noise = gumbel(key, ivy.array(logits_shape), logits_arr.dtype) + expanded_logits = ivy.expand_dims(logits_arr, axis=axis) + noisy_logits = gumbel_noise + expanded_logits + + # Use Ivy's argmax to get indices + indices = ivy.argmax(noisy_logits, axis=axis) + + return indices + def weibull_min(key, scale, concentration, shape=(), dtype="float64"): seed = _get_seed(key) uniform_x = ivy.random_uniform(seed=seed, shape=shape, dtype=dtype) x = 1 - uniform_x weibull = x ** (concentration - 1) * -ivy.log(x / scale) return weibull + diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py index aa63058977a9b..bec5396f8dbab 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py @@ -1518,3 +1518,58 @@ def call(): for u, v in zip(ret_np, ret_from_np): assert u.dtype == v.dtype assert u.shape == v.shape + + +@pytest.mark.xfail +@handle_frontend_test( + fn_tree="jax.random.categorical", + dtype_key=helpers.dtype_and_values( + available_dtypes=["uint32"], + min_value=0, + max_value=2000, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=2, + ), + shape=helpers.get_shape( + min_dim_size=1, max_num_dims=6, max_dim_size=6, min_num_dims=1, allow_none=False + ), + dtype=helpers.get_dtypes("float", full=False), +) +def test_jax_categorical( + *, + dtype_key, + shape, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + + input_dtype,key = dtype_key + + def call(): + return helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + key=key[0], + shape=shape, + dtype=dtype[0], + ) + ret = call() + if not ivy.exists(ret): + return + ret_np, ret_from_np = ret + ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) + ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) + for u, v in zip(ret_np, ret_from_np): + assert u.dtype == v.dtype + assert u.shape == v.shape From d58265d02bfc495fe82962e499e16b1b91741df2 Mon Sep 17 00:00:00 2001 From: Humza Tareen Date: Mon, 28 Aug 2023 06:23:55 +0500 Subject: [PATCH 12/55] updated func wrapper imports for scipy frontend (#22622) --- ivy/functional/frontends/scipy/fft/fft.py | 5 ++--- ivy/functional/frontends/scipy/linalg/linalg.py | 4 +--- ivy/functional/frontends/scipy/spatial/distance.py | 4 +--- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/ivy/functional/frontends/scipy/fft/fft.py b/ivy/functional/frontends/scipy/fft/fft.py index eda0b67063221..3ae72a2439351 100644 --- a/ivy/functional/frontends/scipy/fft/fft.py +++ b/ivy/functional/frontends/scipy/fft/fft.py @@ -1,8 +1,7 @@ # global +from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back + import ivy -from ivy.functional.frontends.scipy.func_wrapper import ( - to_ivy_arrays_and_back, -) # dct diff --git a/ivy/functional/frontends/scipy/linalg/linalg.py 
b/ivy/functional/frontends/scipy/linalg/linalg.py index 587d7e23febde..f81f214043312 100644 --- a/ivy/functional/frontends/scipy/linalg/linalg.py +++ b/ivy/functional/frontends/scipy/linalg/linalg.py @@ -1,8 +1,6 @@ # global import ivy -from ivy.functional.frontends.scipy.func_wrapper import ( - to_ivy_arrays_and_back, -) +from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back # --- Helpers --- # diff --git a/ivy/functional/frontends/scipy/spatial/distance.py b/ivy/functional/frontends/scipy/spatial/distance.py index 282665ed13246..eb38e843c167f 100644 --- a/ivy/functional/frontends/scipy/spatial/distance.py +++ b/ivy/functional/frontends/scipy/spatial/distance.py @@ -1,8 +1,6 @@ # global import ivy -from ivy.functional.frontends.scipy.func_wrapper import ( - to_ivy_arrays_and_back, -) +from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back import ivy.functional.frontends.scipy as sc_frontend From 645ec2099efc1d96b12b423872554708104a82c8 Mon Sep 17 00:00:00 2001 From: Vera <121622878+VeraChristina@users.noreply.github.com> Date: Mon, 28 Aug 2023 02:55:22 +0100 Subject: [PATCH 13/55] added instance method isfinite to pytorch tensor (pytorch frontend) (#22010) --- .../backends/tensorflow/elementwise.py | 1 + ivy/functional/frontends/torch/tensor.py | 3 ++ .../test_frontends/test_torch/test_tensor.py | 33 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/ivy/functional/backends/tensorflow/elementwise.py b/ivy/functional/backends/tensorflow/elementwise.py index 81d1850f626a9..6d2875755729a 100644 --- a/ivy/functional/backends/tensorflow/elementwise.py +++ b/ivy/functional/backends/tensorflow/elementwise.py @@ -337,6 +337,7 @@ def greater_equal( return tf.math.greater_equal(x1, x2) +@with_unsupported_dtypes({"2.13.0 and below": ("bool",)}, backend_version) def isfinite( x: Union[tf.Tensor, tf.Variable], /, diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 09ec33d2d5fd3..2819ffba27b74 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1127,6 +1127,9 @@ def fix_(self): self.ivy_array = self.fix().ivy_array return self + def isfinite(self): + return torch_frontend.isfinite(self._ivy_array) + def isinf(self): return torch_frontend.isinf(self._ivy_array) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index e45dfb5acf803..ddadd153bf3d7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -10019,6 +10019,39 @@ def test_torch_tensor_norm( ) +# isfinite +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="isfinite", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_torch_tensor_isfinite( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # isinf @handle_frontend_method( class_tree=CLASS_TREE, From 
fd1192230f466e1d8f3761ec77bf1b100b335299 Mon Sep 17 00:00:00 2001 From: umairjavaid Date: Mon, 28 Aug 2023 07:09:27 +0500 Subject: [PATCH 14/55] Revert "added instance method isfinite to pytorch tensor (pytorch frontend)" (#22673) --- .../backends/tensorflow/elementwise.py | 1 - ivy/functional/frontends/torch/tensor.py | 3 -- .../test_frontends/test_torch/test_tensor.py | 33 ------------------- 3 files changed, 37 deletions(-) diff --git a/ivy/functional/backends/tensorflow/elementwise.py b/ivy/functional/backends/tensorflow/elementwise.py index 6d2875755729a..81d1850f626a9 100644 --- a/ivy/functional/backends/tensorflow/elementwise.py +++ b/ivy/functional/backends/tensorflow/elementwise.py @@ -337,7 +337,6 @@ def greater_equal( return tf.math.greater_equal(x1, x2) -@with_unsupported_dtypes({"2.13.0 and below": ("bool",)}, backend_version) def isfinite( x: Union[tf.Tensor, tf.Variable], /, diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 2819ffba27b74..09ec33d2d5fd3 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1127,9 +1127,6 @@ def fix_(self): self.ivy_array = self.fix().ivy_array return self - def isfinite(self): - return torch_frontend.isfinite(self._ivy_array) - def isinf(self): return torch_frontend.isinf(self._ivy_array) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index ddadd153bf3d7..e45dfb5acf803 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -10019,39 +10019,6 @@ def test_torch_tensor_norm( ) -# isfinite -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="isfinite", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_torch_tensor_isfinite( - dtype_and_x, - frontend_method_data, - init_flags, - method_flags, - frontend, - on_device, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, - ) - - # isinf @handle_frontend_method( class_tree=CLASS_TREE, From 5e2a3288f5a29da00c6bc3b2df31c63f08563c51 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 28 Aug 2023 03:49:26 +0100 Subject: [PATCH 15/55] feat(frontend): setup sklearn metrics submodule and implement rudimentary accuracy_score with test (yet to fix) --- .../frontends/sklearn/metrics/__init__.py | 2 + .../sklearn/metrics/_classification.py | 11 +++++ .../test_sklearn/test_metrics/__init__.py | 0 .../test_metrics/test_classification.py | 40 +++++++++++++++++++ 4 files changed, 53 insertions(+) create mode 100644 ivy/functional/frontends/sklearn/metrics/__init__.py create mode 100644 ivy/functional/frontends/sklearn/metrics/_classification.py create mode 100644 ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/__init__.py create mode 100644 ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py diff --git a/ivy/functional/frontends/sklearn/metrics/__init__.py b/ivy/functional/frontends/sklearn/metrics/__init__.py new file 
mode 100644 index 0000000000000..b13c572d7a449 --- /dev/null +++ b/ivy/functional/frontends/sklearn/metrics/__init__.py @@ -0,0 +1,2 @@ +from . import _classification +from ._classification import * diff --git a/ivy/functional/frontends/sklearn/metrics/_classification.py b/ivy/functional/frontends/sklearn/metrics/_classification.py new file mode 100644 index 0000000000000..3fa7d7d4ec6f9 --- /dev/null +++ b/ivy/functional/frontends/sklearn/metrics/_classification.py @@ -0,0 +1,11 @@ +import ivy + + +def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None): + # TODO: implement sample_weight + # TODO: multi-class + ret = ivy.equal(y_true, y_pred) + ret = ret.sum() + if normalize: + ret = ret / y_true.shape[0] + return ret diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/__init__.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py new file mode 100644 index 0000000000000..85a278a632d3a --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py @@ -0,0 +1,40 @@ +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +@handle_frontend_test( + fn_tree="sklearn.metrics.accuracy_score", + arrays_and_dtypes=helpers.dtype_and_values( + num_arrays=2, + min_value=-2, + max_value=2, + shared_dtype=True, + shape=helpers.lists( + x=helpers.ints(min_value=2, max_value=5), + min_size=2, + max_size=3, + )), + normalize=st.booleans(), +) +def test_sklearn_accuracy_score( + arrays_and_dtypes, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, + normalize, +): + dtypes, values = arrays_and_dtypes + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + frontend=frontend, + on_device=on_device, + y_true=values[0], + y_pred=values[1], + normalize=normalize, + sample_weight=None, + ) From 8bddfcfb66eca54104e6666f4a7995053c9f1658 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 28 Aug 2023 03:51:45 +0100 Subject: [PATCH 16/55] fix(frontend): add hypothesis strategies import --- .../test_sklearn/test_metrics/test_classification.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py index 85a278a632d3a..00b1d112444df 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py @@ -1,3 +1,5 @@ +from hypothesis import strategies as st + import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_frontend_test From 17ea87fd09082869297c5c56816168f4b7688036 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Mon, 28 Aug 2023 04:14:07 +0100 Subject: [PATCH 17/55] test(frontend): limit data generation in test_sklearn_accuracy_score to 1D, handle dtype mismatches and a todo note to specify arrays generated only for classification and not regression bcz sklearn classification metrics can't handle a mix of binary and continuous 
targets --- .../frontends/sklearn/metrics/_classification.py | 10 ++++++---- .../test_sklearn/test_metrics/test_classification.py | 8 +++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ivy/functional/frontends/sklearn/metrics/_classification.py b/ivy/functional/frontends/sklearn/metrics/_classification.py index 3fa7d7d4ec6f9..5a39a326ca28f 100644 --- a/ivy/functional/frontends/sklearn/metrics/_classification.py +++ b/ivy/functional/frontends/sklearn/metrics/_classification.py @@ -1,11 +1,13 @@ import ivy +from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back -def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None): +@to_ivy_arrays_and_back +def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None): # TODO: implement sample_weight - # TODO: multi-class - ret = ivy.equal(y_true, y_pred) - ret = ret.sum() + ret = ivy.equal(y_true, y_pred).astype('int64') + ret = ret.sum().astype('int64') if normalize: ret = ret / y_true.shape[0] + ret = ret.astype('float64') return ret diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py index 00b1d112444df..b63ee606e8c99 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py @@ -7,15 +7,12 @@ @handle_frontend_test( fn_tree="sklearn.metrics.accuracy_score", arrays_and_dtypes=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_integer"), num_arrays=2, min_value=-2, max_value=2, shared_dtype=True, - shape=helpers.lists( - x=helpers.ints(min_value=2, max_value=5), - min_size=2, - max_size=3, - )), + shape=(helpers.ints(min_value=2, max_value=5))), normalize=st.booleans(), ) def test_sklearn_accuracy_score( @@ -27,6 +24,7 @@ def test_sklearn_accuracy_score( backend_fw, normalize, ): + # todo: limit array generation to classification instead of regression (contrinuous values) dtypes, values = arrays_and_dtypes helpers.test_frontend_function( input_dtypes=dtypes, From 3855f090d91f6f48e163dc2df48cb236025b55ea Mon Sep 17 00:00:00 2001 From: akshatvishu <33392262+akshatvishu@users.noreply.github.com> Date: Mon, 28 Aug 2023 08:50:44 +0530 Subject: [PATCH 18/55] Remove incorrect 'tnp' import from `creation.py` at tensorflow backend (#22371) From 6861ad62ce23ebcff508250101d6c89076d19844 Mon Sep 17 00:00:00 2001 From: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com> Date: Mon, 28 Aug 2023 06:36:10 +0300 Subject: [PATCH 19/55] `ivy.experimental.sparsify_tensor` (#22028) Co-authored-by: Anwaar Khalid --- .../array/experimental/elementwise.py | 44 +++++++++ .../container/experimental/elementwise.py | 93 +++++++++++++++++++ .../ivy/experimental/elementwise.py | 53 +++++++++++ .../test_core/test_elementwise.py | 46 +++++++++ 4 files changed, 236 insertions(+) diff --git a/ivy/data_classes/array/experimental/elementwise.py b/ivy/data_classes/array/experimental/elementwise.py index 24c6e51cc6a37..403f1f601325e 100644 --- a/ivy/data_classes/array/experimental/elementwise.py +++ b/ivy/data_classes/array/experimental/elementwise.py @@ -1020,3 +1020,47 @@ def digamma( ivy.array([-0.7549271 0.92278427 0.9988394]) """ return ivy.digamma(self._data, out=out) + + def sparsify_tensor( + self: ivy.Array, + card: int, + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array class method variant of 
ivy.sparsify_tensor. This method simply wraps + the function, and so the docstring for ivy.sparsify_tensor also applies to this + method with minimal changes. + + Parameters + ---------- + self : array + The tensor to sparsify. + card : int + The number of values to keep. + out : array, optional + Optional output array, for writing the result to. + + Returns + ------- + ret : array + The sparsified tensor. + + Examples + -------- + >>> x = ivy.arange(100) + >>> x = ivy.reshape(x, (10, 10)) + >>> x.sparsify_tensor(10) + ivy.array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]]) + """ + return ivy.sparsify_tensor(self._data, card, out=out) diff --git a/ivy/data_classes/container/experimental/elementwise.py b/ivy/data_classes/container/experimental/elementwise.py index c883069e72b7b..f11d85f81d7a6 100644 --- a/ivy/data_classes/container/experimental/elementwise.py +++ b/ivy/data_classes/container/experimental/elementwise.py @@ -2975,3 +2975,96 @@ def digamma( map_sequences=map_sequences, out=out, ) + + @staticmethod + def static_sparsify_tensor( + x: Union[ivy.Container, ivy.Array, ivy.NativeArray], + card: Union[int, ivy.Container], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.sparsify_tensor. This method simply + wraps the function, and so the docstring for ivy.sparsify_tensor also applies to + this method with minimal changes. + + Parameters + ---------- + x + Input container containing input arrays. + card + The number of values to keep in each tensor. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + Alternate output container in which to place the result. + The default is None. 
+ + Returns + ------- + ret + container including the sparsified tensor computed element-wise + Examples + -------- + >>> x = ivy.Container( + a=ivy.reshape(ivy.arange(100), (10, 10)), + b=ivy.reshape(ivy.arange(100), (10, 10)), + ) + >>> ivy.Container.static_sparsify_tensor(x, 10) + { + a: ( shape=[10, 10]), + b: ( shape=[10, 10]) + } + """ + return ContainerBase.cont_multi_map_in_function( + "sparsify_tensor", + x, + card, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def sparsify_tensor( + self: Union[ivy.Container, ivy.Array, ivy.NativeArray], + card: Union[int, ivy.Container], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of ivy.sparsify_tensor. + + This method simply wraps the function, and so the docstring for + ivy.sparsify_tensor also applies to this method with minimal + changes. + """ + return self.static_sparsify_tensor( + self, + card, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) diff --git a/ivy/functional/ivy/experimental/elementwise.py b/ivy/functional/ivy/experimental/elementwise.py index 925ac7eeaa706..613fe70acc6b2 100644 --- a/ivy/functional/ivy/experimental/elementwise.py +++ b/ivy/functional/ivy/experimental/elementwise.py @@ -1335,3 +1335,56 @@ def digamma( ivy.array([-0.7549271 0.92278427 0.9988394]) """ return ivy.current_backend(x).digamma(x, out=out) + + +@handle_exceptions +@handle_nestable +@inputs_to_ivy_arrays +@handle_array_function +def sparsify_tensor( + x: Union[ivy.Array, ivy.NativeArray], + card: int, + /, + *, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """ + Zeros out all elements in the tensor except `card` elements with maximum absolute + values. + + Parameters + ---------- + x + Tensor to be sparsified + card + Desired number of non-zero elements in the tensor + out + Optional output array for writing the result to. 
+ + Returns + ------- + ivy.array of shape tensor.shape + + Examples + -------- + >>> x = ivy.arange(100) + >>> x = ivy.reshape(x, (10, 10)) + >>> sparsify_tensor(x, 10) + ivy.array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]]) + """ + if card >= ivy.prod(ivy.array(x.shape)): + return ivy.inplace_update(out, x) if ivy.exists(out) else x + _shape = ivy.shape(x) + x = ivy.reshape(ivy.sort(ivy.abs(x)), (-1,)) + tensor = ivy.concat([ivy.zeros(len(x) - card, dtype=x.dtype), x[-card:]], axis=0) + + return ivy.reshape(tensor, _shape, out=out) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py index 67c8ad603ccc6..79224648053ec 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py @@ -822,3 +822,49 @@ def test_digamma( on_device=on_device, x=x[0], ) + + +@st.composite +def _sparsify_tensor_stg(draw): + dtype, tensor, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ret_shape=True, + min_num_dims=1, + min_dim_size=1, + min_value=10, + ) + ) + + size = 1 + for dim in shape: + size *= dim + + card = draw(st.integers(min_value=1, max_value=size)) + + return dtype, tensor[0], card + + +# sparsify_tensor +@handle_test( + fn_tree="functional.ivy.experimental.sparsify_tensor", + tensor_data=_sparsify_tensor_stg(), +) +def test_sparsify_tensor( + tensor_data, + test_flags, + on_device, + fn_name, + backend_fw, +): + dtype, tensor, card = tensor_data + + helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + on_device=on_device, + fn_name=fn_name, + input_dtypes=dtype, + tensor=tensor, + card=card, + ) From 6cbc2c6071f3ba3dbea27d7e4a18fa7d2e64ac8d Mon Sep 17 00:00:00 2001 From: Abdullah Sabry Date: Mon, 28 Aug 2023 07:03:26 +0300 Subject: [PATCH 20/55] Added argmin as instance methods to JAX fontend (#22662) --- ivy/functional/frontends/jax/array.py | 18 ++++++++ .../test_frontends/test_jax/test_array.py | 42 +++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/ivy/functional/frontends/jax/array.py b/ivy/functional/frontends/jax/array.py index 9d57a7993edd5..2ac9f3dc96b84 100644 --- a/ivy/functional/frontends/jax/array.py +++ b/ivy/functional/frontends/jax/array.py @@ -3,6 +3,7 @@ # local import ivy import ivy.functional.frontends.jax as jax_frontend +from ivy.func_wrapper import with_unsupported_dtypes class Array: @@ -73,6 +74,7 @@ def astype(self, dtype): f"Dtype {self.dtype} is not castable to {dtype}" ) + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") def argmax( self, /, @@ -88,6 +90,22 @@ def argmax( keepdims=keepdims, ) + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def argmin( + self, + /, + *, + axis=None, + out=None, + keepdims=False, + ): + return jax_frontend.numpy.argmin( + self, + axis=axis, + out=out, + keepdims=keepdims, + ) + def conj(self, /): return jax_frontend.numpy.conj(self._ivy_array) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py 
b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py index 81093a9cc92a5..c9477a03b4f12 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py @@ -347,6 +347,48 @@ def test_jax_array_argmax( ) +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="jax.numpy.array", + method_name="argmin", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + force_int_axis=True, + min_num_dims=1, + valid_axis=True, + ), + keepdims=st.booleans(), +) +def test_jax_array_argmin( + dtype_and_x, + keepdims, + on_device, + frontend, + frontend_method_data, + init_flags, + method_flags, + backend_fw, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", From 1600ad2f246cc8653e8f1f0602c42a5a45d5c63e Mon Sep 17 00:00:00 2001 From: theRealBird <75845929+theRealBird@users.noreply.github.com> Date: Mon, 28 Aug 2023 13:13:18 +0530 Subject: [PATCH 21/55] add missing _to_device import for set_item JAX (#22677) --- ivy/functional/backends/jax/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/jax/general.py b/ivy/functional/backends/jax/general.py index 18bec4cb2e2f9..31090c6843608 100644 --- a/ivy/functional/backends/jax/general.py +++ b/ivy/functional/backends/jax/general.py @@ -15,7 +15,7 @@ # local import ivy from ivy.func_wrapper import with_unsupported_dtypes -from ivy.functional.backends.jax.device import _to_array +from ivy.functional.backends.jax.device import _to_array, _to_device from ivy.functional.ivy.general import _broadcast_to from ivy.functional.backends.jax import JaxArray, NativeArray from . 
import backend_version From 8d0185356aae091b7f0ee6410ecfe1e806d29a04 Mon Sep 17 00:00:00 2001 From: Bhushan Srivastava <59949692+he11owthere@users.noreply.github.com> Date: Mon, 28 Aug 2023 13:15:01 +0530 Subject: [PATCH 22/55] Added subtract_ method to the Paddle frontend (#22467) Co-authored-by: Samsam Lee <106169847+jieunboy0516@users.noreply.github.com> --- .../frontends/paddle/tensor/tensor.py | 1291 +++++++++-------- .../test_paddle/test_tensor/test_tensor.py | 33 + 2 files changed, 682 insertions(+), 642 deletions(-) diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 579f8d4bdc08e..81c7dd423e829 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -1,642 +1,649 @@ -# local -import ivy -import ivy.functional.frontends.paddle as paddle_frontend -from ivy.func_wrapper import ( - with_supported_dtypes, - with_unsupported_dtypes, -) -from ivy.functional.frontends.paddle.func_wrapper import _to_ivy_array - - -class Tensor: - def __init__(self, array, dtype=None, place="cpu", stop_gradient=True): - self._ivy_array = ( - ivy.array(array, dtype=dtype, device=place) - if not isinstance(array, ivy.Array) - else array - ) - self._dtype = dtype - self._place = place - self._stop_gradient = stop_gradient - - def __repr__(self): - return ( - str(self._ivy_array.__repr__()) - .replace("ivy.array", "ivy.frontends.paddle.Tensor") - .replace("dev", "place") - ) - - # Properties # - # ---------- # - - @property - def ivy_array(self): - return self._ivy_array - - @property - def place(self): - return self.ivy_array.device - - @property - def dtype(self): - return self._ivy_array.dtype - - @property - def shape(self): - return self._ivy_array.shape - - @property - def ndim(self): - return self.dim() - - # Setters # - # --------# - - @ivy_array.setter - def ivy_array(self, array): - self._ivy_array = ( - ivy.array(array) if not isinstance(array, ivy.Array) else array - ) - - # Special Methods # - # -------------------# - - def __getitem__(self, item): - ivy_args = ivy.nested_map([self, item], _to_ivy_array) - ret = ivy.get_item(*ivy_args) - return paddle_frontend.Tensor(ret) - - def __setitem__(self, item, value): - raise ivy.utils.exceptions.IvyException( - "ivy.functional.frontends.paddle.Tensor object doesn't support assignment" - ) - - def __iter__(self): - if self.ndim == 0: - raise TypeError("iteration over a 0-d tensor not supported") - for i in range(self.shape[0]): - yield self[i] - - # Instance Methods # - # ---------------- # - - def reshape(self, *args, shape=None): - if args and shape: - raise TypeError("reshape() got multiple values for argument 'shape'") - if shape is not None: - return paddle_frontend.reshape(self._ivy_array, shape) - if args: - if isinstance(args[0], (tuple, list)): - shape = args[0] - return paddle_frontend.reshape(self._ivy_array, shape) - else: - return paddle_frontend.reshape(self._ivy_array, args) - return paddle_frontend.reshape(self._ivy_array) - - def dim(self): - return self.ivy_array.ndim - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def abs(self): - return paddle_frontend.abs(self) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def acosh(self, name=None): - return paddle_frontend.Tensor(ivy.acosh(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def ceil(self): - return 
paddle_frontend.ceil(self) - - @with_unsupported_dtypes({"2.5.1 and below": ("complex", "int8")}, "paddle") - def numel(self): - return paddle_frontend.numel(self) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16",)}, "paddle") - def asinh(self, name=None): - return paddle_frontend.Tensor(ivy.asinh(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def asin(self, name=None): - return paddle_frontend.Tensor(ivy.asin(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cosh(self, name=None): - return paddle_frontend.Tensor(ivy.cosh(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def log(self, name=None): - return paddle_frontend.Tensor(ivy.log(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sin(self, name=None): - return paddle_frontend.Tensor(ivy.sin(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sinh(self, name=None): - return paddle_frontend.Tensor(ivy.sinh(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def argmax(self, axis=None, keepdim=False, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.argmax(self._ivy_array, axis=axis, keepdims=keepdim, dtype=dtype) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "uint16")}, "paddle") - def unsqueeze(self, axis=None, name=None): - return paddle_frontend.Tensor(ivy.expand_dims(self._ivy_array, axis=axis)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sqrt(self, name=None): - return paddle_frontend.Tensor(ivy.sqrt(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sqrt_(self, name=None): - self.ivy_array = self.sqrt().ivy_array - return self - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cos(self, name=None): - return paddle_frontend.Tensor(ivy.cos(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def exp(self, name=None): - return paddle_frontend.Tensor(ivy.exp(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def exp_(self, name=None): - self.ivy_array = self.exp().ivy_array - return self - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def erf(self, name=None): - return paddle_frontend.Tensor(ivy.erf(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def subtract(self, y, name=None): - return paddle_frontend.Tensor(ivy.subtract(self._ivy_array, _to_ivy_array(y))) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def log10(self, name=None): - return paddle_frontend.Tensor(ivy.log10(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def argsort(self, axis=-1, descending=False, name=None): - return paddle_frontend.Tensor( - ivy.argsort(self._ivy_array, axis=axis, descending=descending) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def floor(self, name=None): - return paddle_frontend.Tensor(ivy.floor(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def 
floor_(self): - self.ivy_array = self.floor().ivy_array - return self - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def clip(self, min=None, max=None, name=None): - ivy.utils.assertions.check_all_or_any_fn( - min, - max, - fn=ivy.exists, - type="any", - limit=[1, 2], - message="at most one of min or max can be None", - ) - if min is None: - ret = ivy.minimum(self._ivy_array, max) - elif max is None: - ret = ivy.maximum(self._ivy_array, min) - else: - ret = ivy.clip(self._ivy_array, min, max) - return paddle_frontend.Tensor(ret) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def tanh(self, name=None): - return paddle_frontend.Tensor(ivy.tanh(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def add_(self, name=None): - return paddle_frontend.Tensor(ivy.add(self._ivy_array)) - - @with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, - "paddle", - ) - def isinf(self, name=None): - return paddle_frontend.Tensor(ivy.isinf(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def square(self, name=None): - return paddle_frontend.Tensor(ivy.square(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def remainder_(self, y, name=None): - self.ivy_array = paddle_frontend.Tensor( - ivy.remainder(self._ivy_array, _to_ivy_array(y)) - ).ivy_array - return self - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cholesky(self, upper=False, name=None): - return paddle_frontend.Tensor(ivy.cholesky(self._ivy_array, upper=upper)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def multiply(self, y, name=None): - return paddle_frontend.multiply(self, y) - - @with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, - "paddle", - ) - def isfinite(self, name=None): - return paddle_frontend.Tensor(ivy.isfinite(self._ivy_array)) - - @with_supported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle") - def all(self, axis=None, keepdim=False, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.all(self.ivy_array, axis=axis, keepdims=keepdim, dtype=dtype) - ) - - @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def allclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - return paddle_frontend.Tensor( - ivy.allclose( - self._ivy_array, other, rtol=rtol, atol=atol, equal_nan=equal_nan - ) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def sort(self, axis=-1, descending=False, name=None): - return paddle_frontend.Tensor( - ivy.sort(self._ivy_array, axis=axis, descending=descending) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def log1p(self, name=None): - return ivy.log1p(self._ivy_array) - - @with_supported_dtypes( - { - "2.4.2 and below": ( - "bool", - "uint8", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", - ) - def bitwise_and(self, y, out=None, name=None): - return paddle_frontend.bitwise_and(self, y) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_or(self, y, out=None, name=None): - return 
paddle_frontend.logical_or(self, y, out=out) - - @with_supported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, - "paddle", - ) - def bitwise_xor(self, y, out=None, name=None): - return paddle_frontend.bitwise_xor(self, y) - - @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def any(self, axis=None, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.any(self._ivy_array, axis=axis, keepdims=keepdim) - ) - - @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") - def astype(self, dtype): - return paddle_frontend.Tensor(ivy.astype(self._ivy_array, dtype)) - - @with_supported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, - "paddle", - ) - def bitwise_not(self, out=None, name=None): - return paddle_frontend.Tensor(ivy.bitwise_invert(self._ivy_array, out=out)) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", - ) - def bitwise_or(self, y, out=None, name=None): - return paddle_frontend.bitwise_or(self, y, out=out) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_xor(self, y, out=None, name=None): - return paddle_frontend.logical_xor(self, y, out=out) - - @with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, - "paddle", - ) - def isnan(self, name=None): - return paddle_frontend.isnan(self) - - @with_unsupported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - "int8", - "int16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def greater_than(self, y, name=None): - return paddle_frontend.greater_than(self, y) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def rsqrt(self, name=None): - return paddle_frontend.Tensor(ivy.reciprocal(ivy.sqrt(self._ivy_array))) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def reciprocal(self, name=None): - return paddle_frontend.reciprocal(self) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_and(self, y, out=None, name=None): - return paddle_frontend.logical_and(self, y, out=out) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def divide(self, y, name=None): - return paddle_frontend.divide(self, y) - - @with_unsupported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - "int8", - "int16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def less_than(self, y, name=None): - return paddle_frontend.less_than(self, y) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def cumprod(self, dim=None, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.cumprod(self._ivy_array, axis=dim, dtype=dtype) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def cumsum(self, axis=None, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.cumsum(self._ivy_array, axis=axis, dtype=dtype) - ) - - @with_supported_dtypes( - {"2.5.1 and below": ("complex64", "complex128", "float32", "float64")}, - "paddle", - ) - def angle(self, name=None): - return paddle_frontend.Tensor(ivy.angle(self._ivy_array)) - - @with_unsupported_dtypes( - { - "2.5.1 and below": 
( - "uint8", - "int8", - "int16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def equal(self, y, name=None): - return paddle_frontend.equal(self, y) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def rad2deg(self, name=None): - return paddle_frontend.Tensor(ivy.rad2deg(self._ivy_array)) - - @with_unsupported_dtypes( - { - "2.5.1 and below": ( - "uint8", - "int8", - "int16", - "float16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def equal_all(self, y, name=None): - return paddle_frontend.Tensor( - ivy.array_equal(self._ivy_array, _to_ivy_array(y)) - ) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def maximum(self, other, name=None): - return ivy.maximum(self._ivy_array, other) - - @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") - def fmax(self, y, name=None): - return paddle_frontend.Tensor(ivy.fmax(self._ivy_array, _to_ivy_array(y))) - - @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") - def fmin(self, y, name=None): - return paddle_frontend.Tensor(ivy.fmin(self._ivy_array, _to_ivy_array(y))) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def minimum(self, y, name=None): - return paddle_frontend.Tensor(ivy.minimum(self._ivy_array, _to_ivy_array(y))) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def max(self, axis=None, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.max(self._ivy_array, axis=axis, keepdims=keepdim) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def deg2rad(self, name=None): - return paddle_frontend.Tensor(ivy.deg2rad(self._ivy_array)) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64", "bool")}, "paddle" - ) - def rot90(self, k=1, axes=(0, 1), name=None): - return paddle_frontend.Tensor(ivy.rot90(self._ivy_array, k=k, axes=axes)) - - @with_supported_dtypes( - {"2.5.1 and below": ("complex64", "complex128")}, - "paddle", - ) - def imag(self, name=None): - return paddle_frontend.imag(self) - - def is_tensor(self): - return paddle_frontend.is_tensor(self._ivy_array) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "float32", - "float64", - ) - }, - "paddle", - ) - def isclose(self, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - return paddle_frontend.isclose( - self, y, rtol=rtol, atol=atol, equal_nan=equal_nan - ) - - @with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") - def floor_divide(self, y, name=None): - return paddle_frontend.Tensor( - ivy.floor_divide(self._ivy_array, _to_ivy_array(y)) - ) - - # cond - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cond(self, p=None, name=None): - return paddle_frontend.cond(self, p=p, name=name) - - @with_unsupported_dtypes({"2.4.2 and below": ("int16", "float16")}, "paddle") - def conj(self, name=None): - return paddle_frontend.Tensor(ivy.conj(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def log2(self, name=None): - return paddle_frontend.Tensor(ivy.log2(self._ivy_array)) - - @with_unsupported_dtypes( - {"2.4.2 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def neg(self, name=None): - return paddle_frontend.neg(self) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", 
- "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_not(self, out=None, name=None): - return paddle_frontend.Tensor(ivy.logical_not(self.ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def sign(self, name=None): - return ivy.sign(self._ivy_array) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def var(self, axis=None, unbiased=True, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.var( - self._ivy_array, axis=axis, correction=int(unbiased), keepdims=keepdim - ) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def sgn(self, name=None): - return paddle_frontend.Tensor(ivy.sign(self._ivy_array, np_variant=True)) - - def tolist(self): - return paddle_frontend.Tensor(ivy.to_list(self._ivy_array)) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, - "paddle", - ) - def min(self, axis=None, keepdim=False, name=None): - return ivy.min(self._ivy_array, axis=axis, keepdims=keepdim) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def atan(self, name=None): - return ivy.atan(self._ivy_array) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def atanh(self, name=None): - return ivy.atanh(self._ivy_array) - - @with_unsupported_dtypes({"2.4.2 and below": ("float32", "float64")}, "paddle") - def std(self, axis=None, unbiased=True, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.std(self._ivy_array, axis=axis, keepdims=keepdim) - ) - - @with_supported_dtypes( - {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, "paddle" - ) - def trunc(self, name=None): - return paddle_frontend.Tensor(ivy.trunc(self._ivy_array)) +# local +import ivy +import ivy.functional.frontends.paddle as paddle_frontend +from ivy.func_wrapper import ( + with_supported_dtypes, + with_unsupported_dtypes, +) +from ivy.functional.frontends.paddle.func_wrapper import _to_ivy_array + + +class Tensor: + def __init__(self, array, dtype=None, place="cpu", stop_gradient=True): + self._ivy_array = ( + ivy.array(array, dtype=dtype, device=place) + if not isinstance(array, ivy.Array) + else array + ) + self._dtype = dtype + self._place = place + self._stop_gradient = stop_gradient + + def __repr__(self): + return ( + str(self._ivy_array.__repr__()) + .replace("ivy.array", "ivy.frontends.paddle.Tensor") + .replace("dev", "place") + ) + + # Properties # + # ---------- # + + @property + def ivy_array(self): + return self._ivy_array + + @property + def place(self): + return self.ivy_array.device + + @property + def dtype(self): + return self._ivy_array.dtype + + @property + def shape(self): + return self._ivy_array.shape + + @property + def ndim(self): + return self.dim() + + # Setters # + # --------# + + @ivy_array.setter + def ivy_array(self, array): + self._ivy_array = ( + ivy.array(array) if not isinstance(array, ivy.Array) else array + ) + + # Special Methods # + # -------------------# + + def __getitem__(self, item): + ivy_args = ivy.nested_map([self, item], _to_ivy_array) + ret = ivy.get_item(*ivy_args) + return paddle_frontend.Tensor(ret) + + def __setitem__(self, item, value): + raise ivy.utils.exceptions.IvyException( + "ivy.functional.frontends.paddle.Tensor object doesn't support assignment" + ) + + def __iter__(self): + if self.ndim == 0: + raise TypeError("iteration over a 0-d tensor not supported") + for i in range(self.shape[0]): 
+ yield self[i] + + # Instance Methods # + # ---------------- # + + def reshape(self, *args, shape=None): + if args and shape: + raise TypeError("reshape() got multiple values for argument 'shape'") + if shape is not None: + return paddle_frontend.reshape(self._ivy_array, shape) + if args: + if isinstance(args[0], (tuple, list)): + shape = args[0] + return paddle_frontend.reshape(self._ivy_array, shape) + else: + return paddle_frontend.reshape(self._ivy_array, args) + return paddle_frontend.reshape(self._ivy_array) + + def dim(self): + return self.ivy_array.ndim + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def abs(self): + return paddle_frontend.abs(self) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def acosh(self, name=None): + return paddle_frontend.Tensor(ivy.acosh(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def ceil(self): + return paddle_frontend.ceil(self) + + @with_unsupported_dtypes({"2.5.1 and below": ("complex", "int8")}, "paddle") + def numel(self): + return paddle_frontend.numel(self) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16",)}, "paddle") + def asinh(self, name=None): + return paddle_frontend.Tensor(ivy.asinh(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def asin(self, name=None): + return paddle_frontend.Tensor(ivy.asin(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cosh(self, name=None): + return paddle_frontend.Tensor(ivy.cosh(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def log(self, name=None): + return paddle_frontend.Tensor(ivy.log(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sin(self, name=None): + return paddle_frontend.Tensor(ivy.sin(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sinh(self, name=None): + return paddle_frontend.Tensor(ivy.sinh(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def argmax(self, axis=None, keepdim=False, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.argmax(self._ivy_array, axis=axis, keepdims=keepdim, dtype=dtype) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "uint16")}, "paddle") + def unsqueeze(self, axis=None, name=None): + return paddle_frontend.Tensor(ivy.expand_dims(self._ivy_array, axis=axis)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sqrt(self, name=None): + return paddle_frontend.Tensor(ivy.sqrt(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sqrt_(self, name=None): + self.ivy_array = self.sqrt().ivy_array + return self + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cos(self, name=None): + return paddle_frontend.Tensor(ivy.cos(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def exp(self, name=None): + return paddle_frontend.Tensor(ivy.exp(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def exp_(self, name=None): + self.ivy_array = self.exp().ivy_array + return self + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, 
"paddle") + def erf(self, name=None): + return paddle_frontend.Tensor(ivy.erf(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def subtract(self, y, name=None): + return paddle_frontend.Tensor(ivy.subtract(self._ivy_array, _to_ivy_array(y))) + + @with_unsupported_dtypes( + {"2.5.1 and below": ("float16", "uint8", "int8", "bool")}, "paddle" + ) + def subtract_(self, y, name=None): + self.ivy_array = self.subtract(y).ivy_array + return self + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def log10(self, name=None): + return paddle_frontend.Tensor(ivy.log10(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def argsort(self, axis=-1, descending=False, name=None): + return paddle_frontend.Tensor( + ivy.argsort(self._ivy_array, axis=axis, descending=descending) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def floor(self, name=None): + return paddle_frontend.Tensor(ivy.floor(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def floor_(self): + self.ivy_array = self.floor().ivy_array + return self + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def clip(self, min=None, max=None, name=None): + ivy.utils.assertions.check_all_or_any_fn( + min, + max, + fn=ivy.exists, + type="any", + limit=[1, 2], + message="at most one of min or max can be None", + ) + if min is None: + ret = ivy.minimum(self._ivy_array, max) + elif max is None: + ret = ivy.maximum(self._ivy_array, min) + else: + ret = ivy.clip(self._ivy_array, min, max) + return paddle_frontend.Tensor(ret) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def tanh(self, name=None): + return paddle_frontend.Tensor(ivy.tanh(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def add_(self, name=None): + return paddle_frontend.Tensor(ivy.add(self._ivy_array)) + + @with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, + "paddle", + ) + def isinf(self, name=None): + return paddle_frontend.Tensor(ivy.isinf(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def square(self, name=None): + return paddle_frontend.Tensor(ivy.square(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def remainder_(self, y, name=None): + self.ivy_array = paddle_frontend.Tensor( + ivy.remainder(self._ivy_array, _to_ivy_array(y)) + ).ivy_array + return self + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cholesky(self, upper=False, name=None): + return paddle_frontend.Tensor(ivy.cholesky(self._ivy_array, upper=upper)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def multiply(self, y, name=None): + return paddle_frontend.multiply(self, y) + + @with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, + "paddle", + ) + def isfinite(self, name=None): + return paddle_frontend.Tensor(ivy.isfinite(self._ivy_array)) + + @with_supported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle") + def all(self, axis=None, keepdim=False, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.all(self.ivy_array, axis=axis, 
keepdims=keepdim, dtype=dtype) + ) + + @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def allclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + return paddle_frontend.Tensor( + ivy.allclose( + self._ivy_array, other, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def sort(self, axis=-1, descending=False, name=None): + return paddle_frontend.Tensor( + ivy.sort(self._ivy_array, axis=axis, descending=descending) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def log1p(self, name=None): + return ivy.log1p(self._ivy_array) + + @with_supported_dtypes( + { + "2.4.2 and below": ( + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", + ) + def bitwise_and(self, y, out=None, name=None): + return paddle_frontend.bitwise_and(self, y) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_or(self, y, out=None, name=None): + return paddle_frontend.logical_or(self, y, out=out) + + @with_supported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, + "paddle", + ) + def bitwise_xor(self, y, out=None, name=None): + return paddle_frontend.bitwise_xor(self, y) + + @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def any(self, axis=None, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.any(self._ivy_array, axis=axis, keepdims=keepdim) + ) + + @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") + def astype(self, dtype): + return paddle_frontend.Tensor(ivy.astype(self._ivy_array, dtype)) + + @with_supported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, + "paddle", + ) + def bitwise_not(self, out=None, name=None): + return paddle_frontend.Tensor(ivy.bitwise_invert(self._ivy_array, out=out)) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", + ) + def bitwise_or(self, y, out=None, name=None): + return paddle_frontend.bitwise_or(self, y, out=out) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_xor(self, y, out=None, name=None): + return paddle_frontend.logical_xor(self, y, out=out) + + @with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, + "paddle", + ) + def isnan(self, name=None): + return paddle_frontend.isnan(self) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + "int16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def greater_than(self, y, name=None): + return paddle_frontend.greater_than(self, y) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def rsqrt(self, name=None): + return paddle_frontend.Tensor(ivy.reciprocal(ivy.sqrt(self._ivy_array))) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def reciprocal(self, name=None): + return paddle_frontend.reciprocal(self) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_and(self, y, out=None, name=None): 
+ return paddle_frontend.logical_and(self, y, out=out) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def divide(self, y, name=None): + return paddle_frontend.divide(self, y) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + "int16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def less_than(self, y, name=None): + return paddle_frontend.less_than(self, y) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def cumprod(self, dim=None, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.cumprod(self._ivy_array, axis=dim, dtype=dtype) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def cumsum(self, axis=None, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.cumsum(self._ivy_array, axis=axis, dtype=dtype) + ) + + @with_supported_dtypes( + {"2.5.1 and below": ("complex64", "complex128", "float32", "float64")}, + "paddle", + ) + def angle(self, name=None): + return paddle_frontend.Tensor(ivy.angle(self._ivy_array)) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "uint8", + "int8", + "int16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def equal(self, y, name=None): + return paddle_frontend.equal(self, y) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def rad2deg(self, name=None): + return paddle_frontend.Tensor(ivy.rad2deg(self._ivy_array)) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "uint8", + "int8", + "int16", + "float16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def equal_all(self, y, name=None): + return paddle_frontend.Tensor( + ivy.array_equal(self._ivy_array, _to_ivy_array(y)) + ) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def maximum(self, other, name=None): + return ivy.maximum(self._ivy_array, other) + + @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") + def fmax(self, y, name=None): + return paddle_frontend.Tensor(ivy.fmax(self._ivy_array, _to_ivy_array(y))) + + @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") + def fmin(self, y, name=None): + return paddle_frontend.Tensor(ivy.fmin(self._ivy_array, _to_ivy_array(y))) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def minimum(self, y, name=None): + return paddle_frontend.Tensor(ivy.minimum(self._ivy_array, _to_ivy_array(y))) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def max(self, axis=None, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.max(self._ivy_array, axis=axis, keepdims=keepdim) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def deg2rad(self, name=None): + return paddle_frontend.Tensor(ivy.deg2rad(self._ivy_array)) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64", "bool")}, "paddle" + ) + def rot90(self, k=1, axes=(0, 1), name=None): + return paddle_frontend.Tensor(ivy.rot90(self._ivy_array, k=k, axes=axes)) + + @with_supported_dtypes( + {"2.5.1 and below": ("complex64", "complex128")}, + "paddle", + ) + def imag(self, name=None): + return paddle_frontend.imag(self) + + def is_tensor(self): + return paddle_frontend.is_tensor(self._ivy_array) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "float32", + "float64", + ) + }, + 
"paddle", + ) + def isclose(self, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + return paddle_frontend.isclose( + self, y, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + + @with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") + def floor_divide(self, y, name=None): + return paddle_frontend.Tensor( + ivy.floor_divide(self._ivy_array, _to_ivy_array(y)) + ) + + # cond + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cond(self, p=None, name=None): + return paddle_frontend.cond(self, p=p, name=name) + + @with_unsupported_dtypes({"2.4.2 and below": ("int16", "float16")}, "paddle") + def conj(self, name=None): + return paddle_frontend.Tensor(ivy.conj(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def log2(self, name=None): + return paddle_frontend.Tensor(ivy.log2(self._ivy_array)) + + @with_unsupported_dtypes( + {"2.4.2 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def neg(self, name=None): + return paddle_frontend.neg(self) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_not(self, out=None, name=None): + return paddle_frontend.Tensor(ivy.logical_not(self.ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def sign(self, name=None): + return ivy.sign(self._ivy_array) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def var(self, axis=None, unbiased=True, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.var( + self._ivy_array, axis=axis, correction=int(unbiased), keepdims=keepdim + ) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def sgn(self, name=None): + return paddle_frontend.Tensor(ivy.sign(self._ivy_array, np_variant=True)) + + def tolist(self): + return paddle_frontend.Tensor(ivy.to_list(self._ivy_array)) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, + "paddle", + ) + def min(self, axis=None, keepdim=False, name=None): + return ivy.min(self._ivy_array, axis=axis, keepdims=keepdim) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def atan(self, name=None): + return ivy.atan(self._ivy_array) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def atanh(self, name=None): + return ivy.atanh(self._ivy_array) + + @with_unsupported_dtypes({"2.4.2 and below": ("float32", "float64")}, "paddle") + def std(self, axis=None, unbiased=True, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.std(self._ivy_array, axis=axis, keepdims=keepdim) + ) + + @with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, "paddle" + ) + def trunc(self, name=None): + return paddle_frontend.Tensor(ivy.trunc(self._ivy_array)) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 4758934edb971..be9b7139c9745 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -1506,6 +1506,39 @@ def test_paddle_tensor_subtract( ) +# subtract_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="subtract_", + 
dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + ), +) +def test_paddle_tensor_subtract_( + dtypes_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtypes_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"y": x[1]}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # bitwise_xor @handle_frontend_method( class_tree=CLASS_TREE, From 60839be68c82f2da65d91ce3a1634235a7bc4468 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Mon, 28 Aug 2023 08:06:52 +0000 Subject: [PATCH 23/55] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/jax/random.py | 89 +- .../frontends/paddle/tensor/tensor.py | 1298 ++++++++--------- .../sklearn/metrics/_classification.py | 6 +- .../sklearn/model_selection/_split.py | 22 +- .../test_frontends/test_jax/test_random.py | 6 +- .../test_metrics/test_classification.py | 3 +- .../test_model_selection/test_split.py | 3 +- 7 files changed, 722 insertions(+), 705 deletions(-) diff --git a/ivy/functional/frontends/jax/random.py b/ivy/functional/frontends/jax/random.py index 4bd88d7b02aa1..60586fa572d99 100644 --- a/ivy/functional/frontends/jax/random.py +++ b/ivy/functional/frontends/jax/random.py @@ -19,6 +19,10 @@ def _get_seed(key): return ivy.to_scalar(int("".join(map(str, [key1, key2])))) +def _remove_axis(shape, axis): + return shape[:axis] + shape[axis + 1 :] + + # --- Main --- # # ------------ # @@ -27,9 +31,6 @@ def _get_seed(key): def PRNGKey(seed): return ivy.array([0, seed % 4294967295 - (seed // 4294967295)], dtype=ivy.int64) -def _remove_axis(shape, axis): - return shape[:axis] + shape[axis + 1 :] - @handle_jax_dtype @to_ivy_arrays_and_back @@ -79,6 +80,47 @@ def beta(key, a, b, shape=None, dtype=None): return ivy.beta(a, b, shape=shape, dtype=dtype, seed=seed) +@to_ivy_arrays_and_back +@with_unsupported_dtypes( + { + "0.4.14 and below": ( + "float16", + "bfloat16", + ) + }, + "jax", +) +def categorical(key, logits, axis, shape=None): + _get_seed(key) + logits_arr = ivy.asarray(logits) + + if axis >= 0: + axis -= len(logits_arr.shape) + batch_shape = tuple(_remove_axis(logits_arr.shape, axis)) + + if shape is None: + shape = batch_shape + else: + shape = tuple(shape) + if shape != batch_shape: + raise ValueError( + +f"Shape {shape} is not compatible with reference shape {batch_shape}" + ) + + shape_prefix = shape[: len(shape) - len(batch_shape)] + logits_shape = list(shape[len(shape) - len(batch_shape) :]) + logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) + + gumbel_noise = gumbel(key, ivy.array(logits_shape), logits_arr.dtype) + expanded_logits = ivy.expand_dims(logits_arr, axis=axis) + noisy_logits = gumbel_noise + expanded_logits + + # Use Ivy's argmax to get indices + indices = ivy.argmax(noisy_logits, axis=axis) + + return indices + + @handle_jax_dtype @to_ivy_arrays_and_back def cauchy(key, shape=(), dtype="float64"): @@ -374,46 +416,6 @@ def uniform(key, shape=(), dtype=None, minval=0.0, maxval=1.0): low=minval, high=maxval, shape=shape, dtype=dtype, seed=ivy.to_scalar(key[1]) ) -@to_ivy_arrays_and_back -@with_unsupported_dtypes( - { 
- "0.4.14 and below": ( - "float16", - "bfloat16", - ) - }, - "jax", -) - -def categorical(key, logits, axis, shape=None): - _get_seed(key) - logits_arr = ivy.asarray(logits) - - if axis >= 0: - axis -= len(logits_arr.shape) - batch_shape = tuple(_remove_axis(logits_arr.shape, axis)) - - if shape is None: - shape = batch_shape - else: - shape = tuple(shape) - if shape != batch_shape: - raise ValueError( -+ f"Shape {shape} is not compatible with reference shape {batch_shape}" - ) - - shape_prefix = shape[: len(shape) - len(batch_shape)] - logits_shape = list(shape[len(shape) - len(batch_shape) :]) - logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) - - gumbel_noise = gumbel(key, ivy.array(logits_shape), logits_arr.dtype) - expanded_logits = ivy.expand_dims(logits_arr, axis=axis) - noisy_logits = gumbel_noise + expanded_logits - - # Use Ivy's argmax to get indices - indices = ivy.argmax(noisy_logits, axis=axis) - - return indices def weibull_min(key, scale, concentration, shape=(), dtype="float64"): seed = _get_seed(key) @@ -421,4 +423,3 @@ def weibull_min(key, scale, concentration, shape=(), dtype="float64"): x = 1 - uniform_x weibull = x ** (concentration - 1) * -ivy.log(x / scale) return weibull - diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 81c7dd423e829..3037eacf755ab 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -1,649 +1,649 @@ -# local -import ivy -import ivy.functional.frontends.paddle as paddle_frontend -from ivy.func_wrapper import ( - with_supported_dtypes, - with_unsupported_dtypes, -) -from ivy.functional.frontends.paddle.func_wrapper import _to_ivy_array - - -class Tensor: - def __init__(self, array, dtype=None, place="cpu", stop_gradient=True): - self._ivy_array = ( - ivy.array(array, dtype=dtype, device=place) - if not isinstance(array, ivy.Array) - else array - ) - self._dtype = dtype - self._place = place - self._stop_gradient = stop_gradient - - def __repr__(self): - return ( - str(self._ivy_array.__repr__()) - .replace("ivy.array", "ivy.frontends.paddle.Tensor") - .replace("dev", "place") - ) - - # Properties # - # ---------- # - - @property - def ivy_array(self): - return self._ivy_array - - @property - def place(self): - return self.ivy_array.device - - @property - def dtype(self): - return self._ivy_array.dtype - - @property - def shape(self): - return self._ivy_array.shape - - @property - def ndim(self): - return self.dim() - - # Setters # - # --------# - - @ivy_array.setter - def ivy_array(self, array): - self._ivy_array = ( - ivy.array(array) if not isinstance(array, ivy.Array) else array - ) - - # Special Methods # - # -------------------# - - def __getitem__(self, item): - ivy_args = ivy.nested_map([self, item], _to_ivy_array) - ret = ivy.get_item(*ivy_args) - return paddle_frontend.Tensor(ret) - - def __setitem__(self, item, value): - raise ivy.utils.exceptions.IvyException( - "ivy.functional.frontends.paddle.Tensor object doesn't support assignment" - ) - - def __iter__(self): - if self.ndim == 0: - raise TypeError("iteration over a 0-d tensor not supported") - for i in range(self.shape[0]): - yield self[i] - - # Instance Methods # - # ---------------- # - - def reshape(self, *args, shape=None): - if args and shape: - raise TypeError("reshape() got multiple values for argument 'shape'") - if shape is not None: - return paddle_frontend.reshape(self._ivy_array, shape) - if args: - if 
isinstance(args[0], (tuple, list)): - shape = args[0] - return paddle_frontend.reshape(self._ivy_array, shape) - else: - return paddle_frontend.reshape(self._ivy_array, args) - return paddle_frontend.reshape(self._ivy_array) - - def dim(self): - return self.ivy_array.ndim - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def abs(self): - return paddle_frontend.abs(self) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def acosh(self, name=None): - return paddle_frontend.Tensor(ivy.acosh(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def ceil(self): - return paddle_frontend.ceil(self) - - @with_unsupported_dtypes({"2.5.1 and below": ("complex", "int8")}, "paddle") - def numel(self): - return paddle_frontend.numel(self) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16",)}, "paddle") - def asinh(self, name=None): - return paddle_frontend.Tensor(ivy.asinh(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def asin(self, name=None): - return paddle_frontend.Tensor(ivy.asin(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cosh(self, name=None): - return paddle_frontend.Tensor(ivy.cosh(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def log(self, name=None): - return paddle_frontend.Tensor(ivy.log(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sin(self, name=None): - return paddle_frontend.Tensor(ivy.sin(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sinh(self, name=None): - return paddle_frontend.Tensor(ivy.sinh(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def argmax(self, axis=None, keepdim=False, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.argmax(self._ivy_array, axis=axis, keepdims=keepdim, dtype=dtype) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "uint16")}, "paddle") - def unsqueeze(self, axis=None, name=None): - return paddle_frontend.Tensor(ivy.expand_dims(self._ivy_array, axis=axis)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sqrt(self, name=None): - return paddle_frontend.Tensor(ivy.sqrt(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def sqrt_(self, name=None): - self.ivy_array = self.sqrt().ivy_array - return self - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cos(self, name=None): - return paddle_frontend.Tensor(ivy.cos(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def exp(self, name=None): - return paddle_frontend.Tensor(ivy.exp(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def exp_(self, name=None): - self.ivy_array = self.exp().ivy_array - return self - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def erf(self, name=None): - return paddle_frontend.Tensor(ivy.erf(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def subtract(self, y, name=None): - return paddle_frontend.Tensor(ivy.subtract(self._ivy_array, _to_ivy_array(y))) - 
- @with_unsupported_dtypes( - {"2.5.1 and below": ("float16", "uint8", "int8", "bool")}, "paddle" - ) - def subtract_(self, y, name=None): - self.ivy_array = self.subtract(y).ivy_array - return self - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def log10(self, name=None): - return paddle_frontend.Tensor(ivy.log10(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def argsort(self, axis=-1, descending=False, name=None): - return paddle_frontend.Tensor( - ivy.argsort(self._ivy_array, axis=axis, descending=descending) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def floor(self, name=None): - return paddle_frontend.Tensor(ivy.floor(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def floor_(self): - self.ivy_array = self.floor().ivy_array - return self - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def clip(self, min=None, max=None, name=None): - ivy.utils.assertions.check_all_or_any_fn( - min, - max, - fn=ivy.exists, - type="any", - limit=[1, 2], - message="at most one of min or max can be None", - ) - if min is None: - ret = ivy.minimum(self._ivy_array, max) - elif max is None: - ret = ivy.maximum(self._ivy_array, min) - else: - ret = ivy.clip(self._ivy_array, min, max) - return paddle_frontend.Tensor(ret) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def tanh(self, name=None): - return paddle_frontend.Tensor(ivy.tanh(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def add_(self, name=None): - return paddle_frontend.Tensor(ivy.add(self._ivy_array)) - - @with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, - "paddle", - ) - def isinf(self, name=None): - return paddle_frontend.Tensor(ivy.isinf(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def square(self, name=None): - return paddle_frontend.Tensor(ivy.square(self._ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def remainder_(self, y, name=None): - self.ivy_array = paddle_frontend.Tensor( - ivy.remainder(self._ivy_array, _to_ivy_array(y)) - ).ivy_array - return self - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cholesky(self, upper=False, name=None): - return paddle_frontend.Tensor(ivy.cholesky(self._ivy_array, upper=upper)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def multiply(self, y, name=None): - return paddle_frontend.multiply(self, y) - - @with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, - "paddle", - ) - def isfinite(self, name=None): - return paddle_frontend.Tensor(ivy.isfinite(self._ivy_array)) - - @with_supported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle") - def all(self, axis=None, keepdim=False, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.all(self.ivy_array, axis=axis, keepdims=keepdim, dtype=dtype) - ) - - @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def allclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - return paddle_frontend.Tensor( - ivy.allclose( - self._ivy_array, other, rtol=rtol, atol=atol, 
equal_nan=equal_nan - ) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def sort(self, axis=-1, descending=False, name=None): - return paddle_frontend.Tensor( - ivy.sort(self._ivy_array, axis=axis, descending=descending) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def log1p(self, name=None): - return ivy.log1p(self._ivy_array) - - @with_supported_dtypes( - { - "2.4.2 and below": ( - "bool", - "uint8", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", - ) - def bitwise_and(self, y, out=None, name=None): - return paddle_frontend.bitwise_and(self, y) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_or(self, y, out=None, name=None): - return paddle_frontend.logical_or(self, y, out=out) - - @with_supported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, - "paddle", - ) - def bitwise_xor(self, y, out=None, name=None): - return paddle_frontend.bitwise_xor(self, y) - - @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def any(self, axis=None, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.any(self._ivy_array, axis=axis, keepdims=keepdim) - ) - - @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") - def astype(self, dtype): - return paddle_frontend.Tensor(ivy.astype(self._ivy_array, dtype)) - - @with_supported_dtypes( - {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, - "paddle", - ) - def bitwise_not(self, out=None, name=None): - return paddle_frontend.Tensor(ivy.bitwise_invert(self._ivy_array, out=out)) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - ) - }, - "paddle", - ) - def bitwise_or(self, y, out=None, name=None): - return paddle_frontend.bitwise_or(self, y, out=out) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_xor(self, y, out=None, name=None): - return paddle_frontend.logical_xor(self, y, out=out) - - @with_supported_dtypes( - {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, - "paddle", - ) - def isnan(self, name=None): - return paddle_frontend.isnan(self) - - @with_unsupported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - "int8", - "int16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def greater_than(self, y, name=None): - return paddle_frontend.greater_than(self, y) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def rsqrt(self, name=None): - return paddle_frontend.Tensor(ivy.reciprocal(ivy.sqrt(self._ivy_array))) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def reciprocal(self, name=None): - return paddle_frontend.reciprocal(self) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_and(self, y, out=None, name=None): - return paddle_frontend.logical_and(self, y, out=out) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def divide(self, y, name=None): - return paddle_frontend.divide(self, y) - - @with_unsupported_dtypes( - { - "2.5.1 and below": ( - "bool", - "uint8", - 
"int8", - "int16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def less_than(self, y, name=None): - return paddle_frontend.less_than(self, y) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def cumprod(self, dim=None, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.cumprod(self._ivy_array, axis=dim, dtype=dtype) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def cumsum(self, axis=None, dtype=None, name=None): - return paddle_frontend.Tensor( - ivy.cumsum(self._ivy_array, axis=axis, dtype=dtype) - ) - - @with_supported_dtypes( - {"2.5.1 and below": ("complex64", "complex128", "float32", "float64")}, - "paddle", - ) - def angle(self, name=None): - return paddle_frontend.Tensor(ivy.angle(self._ivy_array)) - - @with_unsupported_dtypes( - { - "2.5.1 and below": ( - "uint8", - "int8", - "int16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def equal(self, y, name=None): - return paddle_frontend.equal(self, y) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def rad2deg(self, name=None): - return paddle_frontend.Tensor(ivy.rad2deg(self._ivy_array)) - - @with_unsupported_dtypes( - { - "2.5.1 and below": ( - "uint8", - "int8", - "int16", - "float16", - "complex64", - "complex128", - ) - }, - "paddle", - ) - def equal_all(self, y, name=None): - return paddle_frontend.Tensor( - ivy.array_equal(self._ivy_array, _to_ivy_array(y)) - ) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def maximum(self, other, name=None): - return ivy.maximum(self._ivy_array, other) - - @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") - def fmax(self, y, name=None): - return paddle_frontend.Tensor(ivy.fmax(self._ivy_array, _to_ivy_array(y))) - - @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") - def fmin(self, y, name=None): - return paddle_frontend.Tensor(ivy.fmin(self._ivy_array, _to_ivy_array(y))) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def minimum(self, y, name=None): - return paddle_frontend.Tensor(ivy.minimum(self._ivy_array, _to_ivy_array(y))) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def max(self, axis=None, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.max(self._ivy_array, axis=axis, keepdims=keepdim) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def deg2rad(self, name=None): - return paddle_frontend.Tensor(ivy.deg2rad(self._ivy_array)) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64", "bool")}, "paddle" - ) - def rot90(self, k=1, axes=(0, 1), name=None): - return paddle_frontend.Tensor(ivy.rot90(self._ivy_array, k=k, axes=axes)) - - @with_supported_dtypes( - {"2.5.1 and below": ("complex64", "complex128")}, - "paddle", - ) - def imag(self, name=None): - return paddle_frontend.imag(self) - - def is_tensor(self): - return paddle_frontend.is_tensor(self._ivy_array) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "float32", - "float64", - ) - }, - "paddle", - ) - def isclose(self, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): - return paddle_frontend.isclose( - self, y, rtol=rtol, atol=atol, equal_nan=equal_nan - ) - - @with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") - def floor_divide(self, y, name=None): 
- return paddle_frontend.Tensor( - ivy.floor_divide(self._ivy_array, _to_ivy_array(y)) - ) - - # cond - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def cond(self, p=None, name=None): - return paddle_frontend.cond(self, p=p, name=name) - - @with_unsupported_dtypes({"2.4.2 and below": ("int16", "float16")}, "paddle") - def conj(self, name=None): - return paddle_frontend.Tensor(ivy.conj(self._ivy_array)) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def log2(self, name=None): - return paddle_frontend.Tensor(ivy.log2(self._ivy_array)) - - @with_unsupported_dtypes( - {"2.4.2 and below": ("float32", "float64", "int32", "int64")}, "paddle" - ) - def neg(self, name=None): - return paddle_frontend.neg(self) - - @with_supported_dtypes( - { - "2.5.1 and below": ( - "bool", - "int8", - "int16", - "int32", - "int64", - "float32", - "float64", - ) - }, - "paddle", - ) - def logical_not(self, out=None, name=None): - return paddle_frontend.Tensor(ivy.logical_not(self.ivy_array)) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def sign(self, name=None): - return ivy.sign(self._ivy_array) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def var(self, axis=None, unbiased=True, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.var( - self._ivy_array, axis=axis, correction=int(unbiased), keepdims=keepdim - ) - ) - - @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") - def sgn(self, name=None): - return paddle_frontend.Tensor(ivy.sign(self._ivy_array, np_variant=True)) - - def tolist(self): - return paddle_frontend.Tensor(ivy.to_list(self._ivy_array)) - - @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, - "paddle", - ) - def min(self, axis=None, keepdim=False, name=None): - return ivy.min(self._ivy_array, axis=axis, keepdims=keepdim) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def atan(self, name=None): - return ivy.atan(self._ivy_array) - - @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") - def atanh(self, name=None): - return ivy.atanh(self._ivy_array) - - @with_unsupported_dtypes({"2.4.2 and below": ("float32", "float64")}, "paddle") - def std(self, axis=None, unbiased=True, keepdim=False, name=None): - return paddle_frontend.Tensor( - ivy.std(self._ivy_array, axis=axis, keepdims=keepdim) - ) - - @with_supported_dtypes( - {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, "paddle" - ) - def trunc(self, name=None): - return paddle_frontend.Tensor(ivy.trunc(self._ivy_array)) +# local +import ivy +import ivy.functional.frontends.paddle as paddle_frontend +from ivy.func_wrapper import ( + with_supported_dtypes, + with_unsupported_dtypes, +) +from ivy.functional.frontends.paddle.func_wrapper import _to_ivy_array + + +class Tensor: + def __init__(self, array, dtype=None, place="cpu", stop_gradient=True): + self._ivy_array = ( + ivy.array(array, dtype=dtype, device=place) + if not isinstance(array, ivy.Array) + else array + ) + self._dtype = dtype + self._place = place + self._stop_gradient = stop_gradient + + def __repr__(self): + return ( + str(self._ivy_array.__repr__()) + .replace("ivy.array", "ivy.frontends.paddle.Tensor") + .replace("dev", "place") + ) + + # Properties # + # ---------- # + + @property + def ivy_array(self): + return self._ivy_array + + @property + def place(self): + return 
self.ivy_array.device + + @property + def dtype(self): + return self._ivy_array.dtype + + @property + def shape(self): + return self._ivy_array.shape + + @property + def ndim(self): + return self.dim() + + # Setters # + # --------# + + @ivy_array.setter + def ivy_array(self, array): + self._ivy_array = ( + ivy.array(array) if not isinstance(array, ivy.Array) else array + ) + + # Special Methods # + # -------------------# + + def __getitem__(self, item): + ivy_args = ivy.nested_map([self, item], _to_ivy_array) + ret = ivy.get_item(*ivy_args) + return paddle_frontend.Tensor(ret) + + def __setitem__(self, item, value): + raise ivy.utils.exceptions.IvyException( + "ivy.functional.frontends.paddle.Tensor object doesn't support assignment" + ) + + def __iter__(self): + if self.ndim == 0: + raise TypeError("iteration over a 0-d tensor not supported") + for i in range(self.shape[0]): + yield self[i] + + # Instance Methods # + # ---------------- # + + def reshape(self, *args, shape=None): + if args and shape: + raise TypeError("reshape() got multiple values for argument 'shape'") + if shape is not None: + return paddle_frontend.reshape(self._ivy_array, shape) + if args: + if isinstance(args[0], (tuple, list)): + shape = args[0] + return paddle_frontend.reshape(self._ivy_array, shape) + else: + return paddle_frontend.reshape(self._ivy_array, args) + return paddle_frontend.reshape(self._ivy_array) + + def dim(self): + return self.ivy_array.ndim + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def abs(self): + return paddle_frontend.abs(self) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def acosh(self, name=None): + return paddle_frontend.Tensor(ivy.acosh(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def ceil(self): + return paddle_frontend.ceil(self) + + @with_unsupported_dtypes({"2.5.1 and below": ("complex", "int8")}, "paddle") + def numel(self): + return paddle_frontend.numel(self) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16",)}, "paddle") + def asinh(self, name=None): + return paddle_frontend.Tensor(ivy.asinh(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def asin(self, name=None): + return paddle_frontend.Tensor(ivy.asin(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cosh(self, name=None): + return paddle_frontend.Tensor(ivy.cosh(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def log(self, name=None): + return paddle_frontend.Tensor(ivy.log(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sin(self, name=None): + return paddle_frontend.Tensor(ivy.sin(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sinh(self, name=None): + return paddle_frontend.Tensor(ivy.sinh(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def argmax(self, axis=None, keepdim=False, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.argmax(self._ivy_array, axis=axis, keepdims=keepdim, dtype=dtype) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "uint16")}, "paddle") + def unsqueeze(self, axis=None, name=None): + return paddle_frontend.Tensor(ivy.expand_dims(self._ivy_array, axis=axis)) + + 
@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sqrt(self, name=None): + return paddle_frontend.Tensor(ivy.sqrt(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def sqrt_(self, name=None): + self.ivy_array = self.sqrt().ivy_array + return self + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cos(self, name=None): + return paddle_frontend.Tensor(ivy.cos(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def exp(self, name=None): + return paddle_frontend.Tensor(ivy.exp(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def exp_(self, name=None): + self.ivy_array = self.exp().ivy_array + return self + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def erf(self, name=None): + return paddle_frontend.Tensor(ivy.erf(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def subtract(self, y, name=None): + return paddle_frontend.Tensor(ivy.subtract(self._ivy_array, _to_ivy_array(y))) + + @with_unsupported_dtypes( + {"2.5.1 and below": ("float16", "uint8", "int8", "bool")}, "paddle" + ) + def subtract_(self, y, name=None): + self.ivy_array = self.subtract(y).ivy_array + return self + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def log10(self, name=None): + return paddle_frontend.Tensor(ivy.log10(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def argsort(self, axis=-1, descending=False, name=None): + return paddle_frontend.Tensor( + ivy.argsort(self._ivy_array, axis=axis, descending=descending) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def floor(self, name=None): + return paddle_frontend.Tensor(ivy.floor(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def floor_(self): + self.ivy_array = self.floor().ivy_array + return self + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def clip(self, min=None, max=None, name=None): + ivy.utils.assertions.check_all_or_any_fn( + min, + max, + fn=ivy.exists, + type="any", + limit=[1, 2], + message="at most one of min or max can be None", + ) + if min is None: + ret = ivy.minimum(self._ivy_array, max) + elif max is None: + ret = ivy.maximum(self._ivy_array, min) + else: + ret = ivy.clip(self._ivy_array, min, max) + return paddle_frontend.Tensor(ret) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def tanh(self, name=None): + return paddle_frontend.Tensor(ivy.tanh(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def add_(self, name=None): + return paddle_frontend.Tensor(ivy.add(self._ivy_array)) + + @with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, + "paddle", + ) + def isinf(self, name=None): + return paddle_frontend.Tensor(ivy.isinf(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def square(self, name=None): + return paddle_frontend.Tensor(ivy.square(self._ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def remainder_(self, y, name=None): + 
self.ivy_array = paddle_frontend.Tensor( + ivy.remainder(self._ivy_array, _to_ivy_array(y)) + ).ivy_array + return self + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cholesky(self, upper=False, name=None): + return paddle_frontend.Tensor(ivy.cholesky(self._ivy_array, upper=upper)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def multiply(self, y, name=None): + return paddle_frontend.multiply(self, y) + + @with_supported_dtypes( + {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, + "paddle", + ) + def isfinite(self, name=None): + return paddle_frontend.Tensor(ivy.isfinite(self._ivy_array)) + + @with_supported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle") + def all(self, axis=None, keepdim=False, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.all(self.ivy_array, axis=axis, keepdims=keepdim, dtype=dtype) + ) + + @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def allclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + return paddle_frontend.Tensor( + ivy.allclose( + self._ivy_array, other, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def sort(self, axis=-1, descending=False, name=None): + return paddle_frontend.Tensor( + ivy.sort(self._ivy_array, axis=axis, descending=descending) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def log1p(self, name=None): + return ivy.log1p(self._ivy_array) + + @with_supported_dtypes( + { + "2.4.2 and below": ( + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", + ) + def bitwise_and(self, y, out=None, name=None): + return paddle_frontend.bitwise_and(self, y) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_or(self, y, out=None, name=None): + return paddle_frontend.logical_or(self, y, out=out) + + @with_supported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, + "paddle", + ) + def bitwise_xor(self, y, out=None, name=None): + return paddle_frontend.bitwise_xor(self, y) + + @with_supported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def any(self, axis=None, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.any(self._ivy_array, axis=axis, keepdims=keepdim) + ) + + @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") + def astype(self, dtype): + return paddle_frontend.Tensor(ivy.astype(self._ivy_array, dtype)) + + @with_supported_dtypes( + {"2.5.1 and below": ("bool", "uint8", "int8", "int16", "int32", "int64")}, + "paddle", + ) + def bitwise_not(self, out=None, name=None): + return paddle_frontend.Tensor(ivy.bitwise_invert(self._ivy_array, out=out)) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + ) + }, + "paddle", + ) + def bitwise_or(self, y, out=None, name=None): + return paddle_frontend.bitwise_or(self, y, out=out) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_xor(self, y, out=None, name=None): + return paddle_frontend.logical_xor(self, y, out=out) + + @with_supported_dtypes( + {"2.5.1 and below": 
("float16", "float32", "float64", "int32", "int64")}, + "paddle", + ) + def isnan(self, name=None): + return paddle_frontend.isnan(self) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + "int16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def greater_than(self, y, name=None): + return paddle_frontend.greater_than(self, y) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def rsqrt(self, name=None): + return paddle_frontend.Tensor(ivy.reciprocal(ivy.sqrt(self._ivy_array))) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def reciprocal(self, name=None): + return paddle_frontend.reciprocal(self) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_and(self, y, out=None, name=None): + return paddle_frontend.logical_and(self, y, out=out) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def divide(self, y, name=None): + return paddle_frontend.divide(self, y) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "bool", + "uint8", + "int8", + "int16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def less_than(self, y, name=None): + return paddle_frontend.less_than(self, y) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def cumprod(self, dim=None, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.cumprod(self._ivy_array, axis=dim, dtype=dtype) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def cumsum(self, axis=None, dtype=None, name=None): + return paddle_frontend.Tensor( + ivy.cumsum(self._ivy_array, axis=axis, dtype=dtype) + ) + + @with_supported_dtypes( + {"2.5.1 and below": ("complex64", "complex128", "float32", "float64")}, + "paddle", + ) + def angle(self, name=None): + return paddle_frontend.Tensor(ivy.angle(self._ivy_array)) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "uint8", + "int8", + "int16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def equal(self, y, name=None): + return paddle_frontend.equal(self, y) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def rad2deg(self, name=None): + return paddle_frontend.Tensor(ivy.rad2deg(self._ivy_array)) + + @with_unsupported_dtypes( + { + "2.5.1 and below": ( + "uint8", + "int8", + "int16", + "float16", + "complex64", + "complex128", + ) + }, + "paddle", + ) + def equal_all(self, y, name=None): + return paddle_frontend.Tensor( + ivy.array_equal(self._ivy_array, _to_ivy_array(y)) + ) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def maximum(self, other, name=None): + return ivy.maximum(self._ivy_array, other) + + @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") + def fmax(self, y, name=None): + return paddle_frontend.Tensor(ivy.fmax(self._ivy_array, _to_ivy_array(y))) + + @with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") + def fmin(self, y, name=None): + return paddle_frontend.Tensor(ivy.fmin(self._ivy_array, _to_ivy_array(y))) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def minimum(self, y, name=None): + return paddle_frontend.Tensor(ivy.minimum(self._ivy_array, _to_ivy_array(y))) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", 
"int32", "int64")}, "paddle" + ) + def max(self, axis=None, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.max(self._ivy_array, axis=axis, keepdims=keepdim) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def deg2rad(self, name=None): + return paddle_frontend.Tensor(ivy.deg2rad(self._ivy_array)) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64", "bool")}, "paddle" + ) + def rot90(self, k=1, axes=(0, 1), name=None): + return paddle_frontend.Tensor(ivy.rot90(self._ivy_array, k=k, axes=axes)) + + @with_supported_dtypes( + {"2.5.1 and below": ("complex64", "complex128")}, + "paddle", + ) + def imag(self, name=None): + return paddle_frontend.imag(self) + + def is_tensor(self): + return paddle_frontend.is_tensor(self._ivy_array) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "float32", + "float64", + ) + }, + "paddle", + ) + def isclose(self, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + return paddle_frontend.isclose( + self, y, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + + @with_supported_dtypes({"2.5.1 and below": ("int32", "int64")}, "paddle") + def floor_divide(self, y, name=None): + return paddle_frontend.Tensor( + ivy.floor_divide(self._ivy_array, _to_ivy_array(y)) + ) + + # cond + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def cond(self, p=None, name=None): + return paddle_frontend.cond(self, p=p, name=name) + + @with_unsupported_dtypes({"2.4.2 and below": ("int16", "float16")}, "paddle") + def conj(self, name=None): + return paddle_frontend.Tensor(ivy.conj(self._ivy_array)) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def log2(self, name=None): + return paddle_frontend.Tensor(ivy.log2(self._ivy_array)) + + @with_unsupported_dtypes( + {"2.4.2 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def neg(self, name=None): + return paddle_frontend.neg(self) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", + ) + def logical_not(self, out=None, name=None): + return paddle_frontend.Tensor(ivy.logical_not(self.ivy_array)) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def sign(self, name=None): + return ivy.sign(self._ivy_array) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def var(self, axis=None, unbiased=True, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.var( + self._ivy_array, axis=axis, correction=int(unbiased), keepdims=keepdim + ) + ) + + @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") + def sgn(self, name=None): + return paddle_frontend.Tensor(ivy.sign(self._ivy_array, np_variant=True)) + + def tolist(self): + return paddle_frontend.Tensor(ivy.to_list(self._ivy_array)) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, + "paddle", + ) + def min(self, axis=None, keepdim=False, name=None): + return ivy.min(self._ivy_array, axis=axis, keepdims=keepdim) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def atan(self, name=None): + return ivy.atan(self._ivy_array) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def atanh(self, name=None): + return ivy.atanh(self._ivy_array) + + @with_unsupported_dtypes({"2.4.2 and below": 
("float32", "float64")}, "paddle") + def std(self, axis=None, unbiased=True, keepdim=False, name=None): + return paddle_frontend.Tensor( + ivy.std(self._ivy_array, axis=axis, keepdims=keepdim) + ) + + @with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, "paddle" + ) + def trunc(self, name=None): + return paddle_frontend.Tensor(ivy.trunc(self._ivy_array)) diff --git a/ivy/functional/frontends/sklearn/metrics/_classification.py b/ivy/functional/frontends/sklearn/metrics/_classification.py index 5a39a326ca28f..4221f711711d3 100644 --- a/ivy/functional/frontends/sklearn/metrics/_classification.py +++ b/ivy/functional/frontends/sklearn/metrics/_classification.py @@ -5,9 +5,9 @@ @to_ivy_arrays_and_back def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None): # TODO: implement sample_weight - ret = ivy.equal(y_true, y_pred).astype('int64') - ret = ret.sum().astype('int64') + ret = ivy.equal(y_true, y_pred).astype("int64") + ret = ret.sum().astype("int64") if normalize: ret = ret / y_true.shape[0] - ret = ret.astype('float64') + ret = ret.astype("float64") return ret diff --git a/ivy/functional/frontends/sklearn/model_selection/_split.py b/ivy/functional/frontends/sklearn/model_selection/_split.py index 9738f472a6f97..777451d83ef8f 100644 --- a/ivy/functional/frontends/sklearn/model_selection/_split.py +++ b/ivy/functional/frontends/sklearn/model_selection/_split.py @@ -1,6 +1,7 @@ import ivy from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back + class BaseCrossValidator: def split(self, X, y=None, groups=None): raise NotImplementedError @@ -35,7 +36,14 @@ def _iter_test_indices(self, X=None, y=None, groups=None): @to_ivy_arrays_and_back -def train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None): +def train_test_split( + *arrays, + test_size=None, + train_size=None, + random_state=None, + shuffle=True, + stratify=None, +): # TODO: Make it concise # TODO: implement stratify if stratify is not None: @@ -45,17 +53,23 @@ def train_test_split(*arrays, test_size=None, train_size=None, random_state=None if test_size is None and train_size is None: test_size = 0.25 n_samples = arrays[0].shape[0] - n_train = ivy.floor(train_size * n_samples) if isinstance(train_size, float) \ + n_train = ( + ivy.floor(train_size * n_samples) + if isinstance(train_size, float) else float(train_size) if isinstance(train_size, int) else None - n_test = ivy.ceil(test_size * n_samples) if isinstance(test_size, float) \ + ) + n_test = ( + ivy.ceil(test_size * n_samples) + if isinstance(test_size, float) else float(test_size) if isinstance(test_size, int) else None + ) if train_size is None: n_train = n_samples - n_test elif test_size is None: n_test = n_samples - n_train n_train, n_test = int(n_train), int(n_test) - indices = ivy.arange(0, n_train + n_test) + indices = ivy.arange(0, n_train + n_test) if shuffle: if random_state is not None: ivy.seed(random_state) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py index bec5396f8dbab..6d80a44cbbab3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py @@ -1548,9 +1548,8 @@ def test_jax_categorical( test_flags, backend_fw, ): - - input_dtype,key = dtype_key - + input_dtype, key = dtype_key + def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, @@ -1564,6 +1563,7 @@ def 
call(): shape=shape, dtype=dtype[0], ) + ret = call() if not ivy.exists(ret): return diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py index b63ee606e8c99..dce0938654d28 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py @@ -12,7 +12,8 @@ min_value=-2, max_value=2, shared_dtype=True, - shape=(helpers.ints(min_value=2, max_value=5))), + shape=(helpers.ints(min_value=2, max_value=5)), + ), normalize=st.booleans(), ) def test_sklearn_accuracy_score( diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py index 80778d6830621..ac7963359b9b9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py @@ -12,7 +12,8 @@ x=helpers.ints(min_value=2, max_value=5), min_size=2, max_size=3, - )), + ), + ), shuffle=st.booleans(), ) def test_sklearn_test_train_split( From 99d62305b9974ba2b326163b14cd6da8802626b3 Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Mon, 28 Aug 2023 10:13:48 +0100 Subject: [PATCH 24/55] fix(ivy): handles torch.as_tensor error for np arrays with negative strides --- ivy/functional/backends/torch/creation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/creation.py b/ivy/functional/backends/torch/creation.py index b04111dfe3e3a..7c5a41482e44c 100644 --- a/ivy/functional/backends/torch/creation.py +++ b/ivy/functional/backends/torch/creation.py @@ -127,7 +127,13 @@ def asarray( if contain_tensor: ret = _stack_tensors(obj, dtype).to(device) return ret.clone().detach() if copy else ret - ret = torch.as_tensor(obj, dtype=dtype, device=device) + try: + ret = torch.as_tensor(obj, dtype=dtype, device=device) + except ValueError as e: + if "At least one stride in the given numpy array is negative" in str(e): + ret = torch.as_tensor(obj.copy(), dtype=dtype, device=device) + else: + raise return ret.clone().detach() if copy else ret From f755b809539c32be2431d6e1f1a04cb05bf4a850 Mon Sep 17 00:00:00 2001 From: Anwaar Khalid Date: Mon, 28 Aug 2023 14:49:48 +0530 Subject: [PATCH 25/55] Update FactorizedTensor (#22680) - replaced scipy.optimize.brentq with _bisection_root_finder to avoid a scipy dependency - rename FactorizedTensor to factorized_tensor --- ivy/__init__.py | 2 +- ivy/data_classes/__init__.py | 2 +- .../__init__.py | 0 .../base.py | 0 .../cp_tensor.py | 0 .../tucker_tensor.py | 24 ++++++++++++++++--- 6 files changed, 23 insertions(+), 5 deletions(-) rename ivy/data_classes/{FactorizedTensor => factorized_tensor}/__init__.py (100%) rename ivy/data_classes/{FactorizedTensor => factorized_tensor}/base.py (100%) rename ivy/data_classes/{FactorizedTensor => factorized_tensor}/cp_tensor.py (100%) rename ivy/data_classes/{FactorizedTensor => factorized_tensor}/tucker_tensor.py (94%) diff --git a/ivy/__init__.py b/ivy/__init__.py index 18392580d5ffd..7af6d26de224b 100644 --- a/ivy/__init__.py +++ b/ivy/__init__.py @@ -758,7 +758,7 @@ class Node(str): add_ivy_container_instance_methods, ) from .data_classes.nested_array import NestedArray -from .data_classes.FactorizedTensor import 
TuckerTensor, CPTensor +from .data_classes.factorized_tensor import TuckerTensor, CPTensor from ivy.utils.backend import ( current_backend, compiled_backends, diff --git a/ivy/data_classes/__init__.py b/ivy/data_classes/__init__.py index f08707536d46a..b79d5b4fca897 100644 --- a/ivy/data_classes/__init__.py +++ b/ivy/data_classes/__init__.py @@ -1,4 +1,4 @@ from . import array from . import container from . import nested_array -from . import FactorizedTensor +from . import factorized_tensor diff --git a/ivy/data_classes/FactorizedTensor/__init__.py b/ivy/data_classes/factorized_tensor/__init__.py similarity index 100% rename from ivy/data_classes/FactorizedTensor/__init__.py rename to ivy/data_classes/factorized_tensor/__init__.py diff --git a/ivy/data_classes/FactorizedTensor/base.py b/ivy/data_classes/factorized_tensor/base.py similarity index 100% rename from ivy/data_classes/FactorizedTensor/base.py rename to ivy/data_classes/factorized_tensor/base.py diff --git a/ivy/data_classes/FactorizedTensor/cp_tensor.py b/ivy/data_classes/factorized_tensor/cp_tensor.py similarity index 100% rename from ivy/data_classes/FactorizedTensor/cp_tensor.py rename to ivy/data_classes/factorized_tensor/cp_tensor.py diff --git a/ivy/data_classes/FactorizedTensor/tucker_tensor.py b/ivy/data_classes/factorized_tensor/tucker_tensor.py similarity index 94% rename from ivy/data_classes/FactorizedTensor/tucker_tensor.py rename to ivy/data_classes/factorized_tensor/tucker_tensor.py index f0f212a937482..51f155ae58b88 100644 --- a/ivy/data_classes/FactorizedTensor/tucker_tensor.py +++ b/ivy/data_classes/factorized_tensor/tucker_tensor.py @@ -7,6 +7,25 @@ import warnings +def _bisection_root_finder(fun, a, b, tol=1e-6, max_iter=100): + if fun(a) * fun(b) >= 0: + raise ValueError( + "Function values at the interval endpoints must have opposite signs" + ) + + for _ in range(max_iter): + c = (a + b) / 2 + if fun(c) == 0 or (b - a) / 2 < tol: + return c + + if fun(c) * fun(a) < 0: + b = c + else: + a = c + + raise RuntimeError("Bisection algorithm did not converge") + + class TuckerTensor(FactorizedTensor): def __init__(self, tucker_tensor): super().__init__() @@ -235,8 +254,6 @@ def tucker_mode_dot( def validate_tucker_rank( tensor_shape, rank="same", rounding="round", fixed_modes=None ): - from scipy.optimize import brentq - if rounding == "ceil": rounding_fun = ivy.ceil elif rounding == "floor": @@ -288,7 +305,8 @@ def validate_tucker_rank( + n_fixed_params * x - rank * n_param_tensor ) - fraction_param = brentq(fun, 0.0, max(rank, 1.0)) + # fraction_param = brentq(fun, 0.0, max(rank, 1.0)) + fraction_param = _bisection_root_finder(fun, 0.0, max(rank, 1.0)) rank = [max(int(rounding_fun(s * fraction_param)), 1) for s in tensor_shape] if fixed_modes is not None: From 5439856f919618626b3a59a273abf81a068405c8 Mon Sep 17 00:00:00 2001 From: RickSanchezStoic <57310695+RickSanchezStoic@users.noreply.github.com> Date: Mon, 28 Aug 2023 15:05:00 +0530 Subject: [PATCH 26/55] adds eval mode to stateful module (#22469) adds eval and training methods to Module, removes training as an argument to function calls throughout stateful, adds test_train_eval --- ivy/stateful/layers.py | 12 ++++++------ ivy/stateful/module.py | 17 +++++++++++++++++ ivy/stateful/norms.py | 15 +++++---------- .../test_ivy/test_stateful/test_modules.py | 17 +++++++++++++++++ 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/ivy/stateful/layers.py b/ivy/stateful/layers.py index 65b445b102ff9..5efc92331e6d5 100644 --- a/ivy/stateful/layers.py +++ 
b/ivy/stateful/layers.py @@ -147,8 +147,7 @@ def __init__( """ self._prob = prob self._scale = scale - self.training = training - Module.__init__(self, device=None, v=None, dtype=dtype) + Module.__init__(self, device=None, v=None, dtype=dtype, training=training) def _create_variables(self, device, dtype=None): """ @@ -210,6 +209,7 @@ def __init__( v=None, build_mode="on_init", dtype=None, + training=True, ): """ Multi Head Attention layer. @@ -258,6 +258,8 @@ def __init__( dtype the desired data type of the internal variables to be created if not provided. Default is ``None``. + training + If True, dropout is used, otherwise dropout is not activated. """ # proj @@ -285,6 +287,7 @@ def __init__( build_mode=build_mode, with_partial_v=True, dtype=dtype, + training=training, ) def _create_variables(self, device, dtype=None): @@ -371,7 +374,6 @@ def _forward( is_causal=False, return_attention_weights=False, average_attention_weights=True, - training=False, ): """ Perform forward pass of the MultiHeadAttention layer. @@ -396,8 +398,6 @@ def _forward( If true, indicates that the returned ``attention_weights`` should be averaged across heads. Otherwise, ``attention_weights`` are provided separately per head. Note that this flag only has an effect when ``return_attention_weights=True``. Default: ``True`` (i.e. average weights across heads) - training - If True, dropout is used, otherwise dropout is not activated. Returns ------- @@ -432,7 +432,7 @@ def _forward( return_attention_weights=return_attention_weights, average_attention_weights=average_attention_weights, dropout=self._dropout_rate, - training=training, + training=self.training, ) diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index 1ad57a745c022..dc27a344b0a06 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -61,6 +61,7 @@ def __init__( devices=None, dtype=None, dynamic_backend=None, + training=True, **kwargs, ): """ @@ -98,6 +99,9 @@ def __init__( is raised during the compiled forward pass. Default is ``True``. with_partial_v Whether to allow partial specification of variables. Default is ``False``. + training + specifies whether the module is in training or evaluation mode. Default is + ``True``. devices devices on which to distribute the module's variables 'cuda:0', 'cuda:1', 'cpu' etc. (Default value = None) @@ -145,6 +149,7 @@ def __init__( self._target = None self._lazy_compiled = False self._dynamic_backend = dynamic_backend + self.training = training if build_mode != "on_init": return if hasattr(Module, "_init_var"): @@ -813,6 +818,18 @@ def register_buffer(self, var_name, value): """Set the buffer at any place within the class.""" self._set_buffers({var_name: value}) + def eval(self): + # disables training mode for child modules + self.train(mode=False) + + def train(self, mode: bool = True): + # enables/disables training mode + self.training = mode + for module in self.v: + module = getattr(self, module, None) + if isinstance(module, ivy.Module): + module.train(mode=mode) + def __repr__(self): return object.__repr__(self) diff --git a/ivy/stateful/norms.py b/ivy/stateful/norms.py index 790a5b60aaef1..65560176ddedb 100644 --- a/ivy/stateful/norms.py +++ b/ivy/stateful/norms.py @@ -103,6 +103,7 @@ def __init__( device=None, v=None, dtype=None, + training=True, ): """ Class for applying Layer Normalization over a mini-batch of inputs. 
@@ -151,7 +152,7 @@ def __init__( self._bias_init = Zeros() self._running_mean_init = Zeros() self._running_var_init = Ones() - Module.__init__(self, device=device, v=v, dtype=dtype) + Module.__init__(self, device=device, v=v, dtype=dtype, training=training) def _create_variables(self, device, dtype=None): """Create internal variables for the layer.""" @@ -172,11 +173,7 @@ def _create_variables(self, device, dtype=None): } return {} - def _forward( - self, - inputs, - training: bool = False, - ): + def _forward(self, inputs): """ Perform forward pass of the BatchNorm layer. @@ -184,8 +181,6 @@ def _forward( ---------- inputs Inputs to process of shape N,C,*. - training - Determine the current phase (training/inference) Returns ------- @@ -199,11 +194,11 @@ def _forward( eps=self._epsilon, momentum=self._momentum, data_format=self.data_format, - training=training, + training=self.training, scale=self.v.w if self._affine else None, offset=self.v.b if self._affine else None, ) - if self._track_running_stats and training: + if self._track_running_stats and self.training: self.v.running_mean = running_mean self.v.running_var = running_var diff --git a/ivy_tests/test_ivy/test_stateful/test_modules.py b/ivy_tests/test_ivy/test_stateful/test_modules.py index c530042f37c1a..7ce71eacd3879 100644 --- a/ivy_tests/test_ivy/test_stateful/test_modules.py +++ b/ivy_tests/test_ivy/test_stateful/test_modules.py @@ -1089,3 +1089,20 @@ def test_get_buffers(buffer): module.register_buffer(key, item[key]) assert module.buffers == buffers + + +class ModuleWithTrainEval(ivy.Module): + def __init__(self): + super().__init__() + + def _forward(): + pass + + +@given(mode=st.booleans()) +def test_train_eval(mode): + cls = ModuleWithTrainEval() + cls.train(mode) + assert mode == cls.training + cls.eval() + assert False == cls.training From 9f88a7dfb61252c8ba0fb0c0cc6bff560fc66357 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Mon, 28 Aug 2023 16:12:36 +0530 Subject: [PATCH 27/55] fix(docs) : replaced optional_m1 by optional_apple_silicon in the docs (#22683) --- docs/overview/contributing/setting_up.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/overview/contributing/setting_up.rst b/docs/overview/contributing/setting_up.rst index 15f922ae1c999..6987e5e046ee1 100644 --- a/docs/overview/contributing/setting_up.rst +++ b/docs/overview/contributing/setting_up.rst @@ -146,7 +146,7 @@ Using miniconda pip install -r requirements/optional.txt - b. On M1 Mac, you will need to use the optional_m1_1 and optional_m1_2 requirements files. To install dependencies. + b. On M1 Mac, you will need to use the optional_apple_silicon_1 and optional_apple_silicon_2 requirements files. To install dependencies. .. code-block:: none @@ -224,12 +224,12 @@ This is a builtin package and doesn't require explicit installation. PS: If the link gets expired at some point in the future, check http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/?C=M;O=D for a valid one. - b. On M1 Mac, you will need to use the optional_m1_1 and optional_m1_2 requirements files. To install dependencies. + b. On M1 Mac, you will need to use the optional_apple_silicon_1 and optional_apple_silicon_2 requirements files. To install dependencies. .. 
code-block:: none - pip install -r requirements/optional_m1_1.txt - pip install -r requirements/optional_m1_2.txt + pip install -r requirements/optional_apple_silicon_1.txt + pip install -r requirements/optional_apple_silicon_2.txt #. Installing array API testing dependencies. @@ -349,7 +349,7 @@ If Docker's latest version causes an error, try using an earlier version by visi **Important Note** -When setting up on an M1 Mac, you would have to update the Dockerfile to install libraries from :code:`requirements/optional_m1_1.txt` and :code:`requirements/optional_m1_2.txt` instead of :code:`requirements/optional.txt`. +When setting up on an M1 Mac, you would have to update the Dockerfile to install libraries from :code:`requirements/optional_apple_silicon_1.txt` and :code:`requirements/optional_apple_silicon_2.txt` instead of :code:`requirements/optional.txt`. **Video** From 5596a74e7339ef426b37b574902e29546131940c Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Mon, 28 Aug 2023 12:14:57 +0100 Subject: [PATCH 28/55] Fixes the paddle backend of ivy.unique_all (#22686) --- ivy/functional/backends/paddle/set.py | 56 +++++++++++++++++---------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/ivy/functional/backends/paddle/set.py b/ivy/functional/backends/paddle/set.py index 133a092b7f696..64e8ffdae8c15 100644 --- a/ivy/functional/backends/paddle/set.py +++ b/ivy/functional/backends/paddle/set.py @@ -3,15 +3,13 @@ from typing import Tuple, Optional from collections import namedtuple import ivy.functional.backends.paddle as paddle_backend -from ivy.func_wrapper import with_unsupported_device_and_dtypes +from ivy.func_wrapper import with_unsupported_device_and_dtypes, with_unsupported_dtypes # local from . 
import backend_version -@with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("complex",)}}, backend_version -) +@with_unsupported_dtypes({"2.5.1 and below": ("complex",)}, backend_version) def unique_all( x: paddle.Tensor, /, @@ -24,32 +22,45 @@ def unique_all( ["values", "indices", "inverse_indices", "counts"], ) - if x.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]: + if x.dtype in [ + paddle.int8, + paddle.int16, + paddle.uint8, + paddle.float16, + paddle.bool, + ]: x, x_dtype = x.cast("float32"), x.dtype else: x_dtype = x.dtype if axis is not None: axis = axis % x.ndim - values, indices, inverse_indices, counts = paddle.unique( + values, inverse_indices, counts = paddle.unique( x, - return_index=True, + return_index=False, # which occurences of the unique values are picked is inconsistent in some cases, so calculate the indices manually below return_counts=True, return_inverse=True, axis=axis, ) - nan_count = paddle.sum(paddle.isnan(x)) - if nan_count.item() > 0: - nan = paddle.to_tensor([float("nan")] * nan_count.item(), dtype=values.dtype) - values = paddle.concat((values, nan)) - nan_idx = paddle.nonzero(paddle.isnan(x).astype(float).flatten()).flatten() - indices = paddle.concat((indices, nan_idx)) - inverse_indices = paddle.put_along_axis( - arr=inverse_indices, indices=nan_idx, values=values.shape, axis=0 - ) - counts = paddle.concat( - (counts, paddle.ones(shape=nan_count, dtype=counts.dtype)) - ) + unique_nan = paddle.isnan(values) + idx_dtype = inverse_indices.dtype + if paddle.any(unique_nan): + nan_index = paddle.where(paddle.isnan(x)) + non_nan_index = [ + x.tolist().index(val) for val in values if not paddle.isnan(val) + ] + indices = values.clone().to(idx_dtype) + indices[unique_nan] = nan_index[0] + inverse_indices[paddle.isnan(x)] = paddle.where(unique_nan)[0][0] + counts[unique_nan] = 1 + indices[~unique_nan] = paddle.to_tensor(non_nan_index, dtype=idx_dtype) + else: + decimals = paddle.arange(inverse_indices.numel()) / inverse_indices.numel() + inv_sorted = (inverse_indices.astype(decimals.dtype) + decimals).argsort() + tot_counts = paddle.concat( + (paddle.zeros((1,), dtype=counts.dtype), counts.cumsum(axis=0)) + )[:-1] + indices = inv_sorted[tot_counts].astype(idx_dtype) if not by_value: sort_idx = paddle.argsort(indices) @@ -59,7 +70,12 @@ def unique_all( values_ = paddle.moveaxis(values, axis, 0) values_ = paddle.reshape(values_, (values_.shape[0], -1)) sort_idx = paddle.to_tensor( - [i[0] for i in sorted(list(enumerate(values_)), key=lambda x: tuple(x[1]))] + [ + i[0] + for i in sorted( + list(enumerate(values_.numpy().tolist())), key=lambda x: tuple(x[1]) + ) + ] ) values = paddle.gather(values, sort_idx, axis=axis) counts = paddle.gather(counts, sort_idx) From c68d91849c6eee682ce64bfe2a688814b2a892e6 Mon Sep 17 00:00:00 2001 From: Shreyansh Bardia <104841983+ShreyanshBardia@users.noreply.github.com> Date: Mon, 28 Aug 2023 17:29:10 +0530 Subject: [PATCH 29/55] added __array__ method to torch and tensorflow frontends (#22602) Co-authored-by: @AnnaTz --- .../frontends/numpy/ndarray/ndarray.py | 9 +-- ivy/functional/frontends/tensorflow/tensor.py | 6 +- ivy/functional/frontends/torch/tensor.py | 11 +++ .../test_numpy/test_ndarray/test_ndarray.py | 5 +- .../test_tensorflow/test_tensor.py | 42 +++++------- .../test_frontends/test_torch/test_tensor.py | 67 ++++++++++++++++++- 6 files changed, 103 insertions(+), 37 deletions(-) diff --git a/ivy/functional/frontends/numpy/ndarray/ndarray.py 
b/ivy/functional/frontends/numpy/ndarray/ndarray.py index 15486d5b14f94..db99f9bf688f0 100644 --- a/ivy/functional/frontends/numpy/ndarray/ndarray.py +++ b/ivy/functional/frontends/numpy/ndarray/ndarray.py @@ -571,14 +571,11 @@ def __abs__(self): def __array__(self, dtype=None, /): if not dtype: - return self - return np_frontend.array(self, dtype=dtype) + return ivy.to_numpy(self.ivy_array) + return ivy.to_numpy(self.ivy_array).astype(dtype) def __array_wrap__(self, array, context=None, /): - if context is None: - return np_frontend.array(array) - else: - return np_frontend.asarray(self) + return np_frontend.array(array) def __getitem__(self, key, /): ivy_args = ivy.nested_map([self, key], _to_ivy_array) diff --git a/ivy/functional/frontends/tensorflow/tensor.py b/ivy/functional/frontends/tensorflow/tensor.py index 8def8c906e831..9cd980aa35503 100644 --- a/ivy/functional/frontends/tensorflow/tensor.py +++ b/ivy/functional/frontends/tensorflow/tensor.py @@ -4,7 +4,7 @@ import ivy from ivy import with_unsupported_dtypes import ivy.functional.frontends.tensorflow as tf_frontend -from ivy.functional.frontends.tensorflow.func_wrapper import to_ivy_dtype, _to_ivy_array +from ivy.functional.frontends.tensorflow.func_wrapper import _to_ivy_array from ivy.functional.frontends.numpy.creation_routines.from_existing_data import array @@ -86,7 +86,9 @@ def __and__(self, y, name="and"): return self.__rand__(y) def __array__(self, dtype=None, name="array"): - return array(ivy.asarray(self.ivy_array, dtype=to_ivy_dtype(dtype))) + if not dtype: + return ivy.to_numpy(self.ivy_array) + return ivy.to_numpy(self.ivy_array).astype(dtype) def __bool__(self, name="bool"): temp = ivy.squeeze(self.ivy_array, axis=None) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 09ec33d2d5fd3..53b313b347ab8 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1313,6 +1313,17 @@ def __invert__(self): def __and__(self, other): return torch_frontend.bitwise_and(self, other) + def __array__(self, dtype=None): + if dtype is None: + return ivy.to_numpy(self.ivy_array) + else: + return ivy.to_numpy(self.ivy_array).astype(dtype, copy=False) + + def __array_wrap__(self, array): + if array.dtype == bool: + array = array.astype("uint8") + return torch_frontend.tensor(array) + # Method aliases absolute, absolute_ = abs, abs_ clip, clip_ = clamp, clamp_ diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py index 30be2eb149a9d..482eae0ee5897 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py @@ -3075,7 +3075,7 @@ def test_numpy___array__( }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "dtype": input_dtypes[0], + "dtype": np.dtype(input_dtypes[0]), }, init_flags=init_flags, method_flags=method_flags, @@ -3091,8 +3091,7 @@ def test_numpy___array__( init_tree="numpy.array", method_name="__array_wrap__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, ), ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py index 5357e0f131a6a..7f4da701acf2a 100644 --- 
a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py @@ -1,18 +1,18 @@ # global from hypothesis import strategies as st, given, assume import numpy as np +import tensorflow as tf # local import ivy import ivy_tests.test_ivy.helpers as helpers from ivy.functional.backends.tensorflow.general import _check_query -from ivy_tests.test_ivy.helpers import handle_frontend_method +from ivy_tests.test_ivy.helpers import handle_frontend_method, BackendHandler from ivy_tests.test_ivy.test_frontends.test_tensorflow.test_raw_ops import ( _pow_helper_shared_dtype, ) from ivy.functional.frontends.tensorflow import EagerTensor - CLASS_TREE = "ivy.functional.frontends.tensorflow.EagerTensor" @@ -1253,37 +1253,29 @@ def test_tensorflow__rmatmul__( class_tree=CLASS_TREE, init_tree="tensorflow.constant", method_name="__array__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), - dtype=helpers.get_dtypes("float", full=False), + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid", full=False), ) def test_tensorflow__array__( dtype_and_x, dtype, frontend, - frontend_method_data, - init_flags, - method_flags, backend_fw, - on_device, ): input_dtype, x = dtype_and_x - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "value": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dtype": np.dtype(dtype[0]), - }, - frontend=frontend, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - on_device=on_device, + dtype[0] = np.dtype(dtype[0]) + ret_gt = tf.constant(x[0]).__array__(dtype[0]) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + local_importer = ivy_backend.utils.dynamic_import + function_module = local_importer.import_module( + "ivy.functional.frontends.tensorflow" + ) + ret = function_module.constant(x[0]).__array__(dtype[0]) + helpers.value_test( + ret_np_flat=ret.ravel(), + ret_np_from_gt_flat=ret_gt.ravel(), + ground_truth_backend="tensorflow", + backend=backend_fw, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index e45dfb5acf803..ba83b4dea4975 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -32,7 +32,7 @@ _get_dtype_input_and_mat_vec, ) from ivy.functional.frontends.torch import Tensor -from ivy_tests.test_ivy.helpers import handle_frontend_method +from ivy_tests.test_ivy.helpers import handle_frontend_method, BackendHandler from ivy_tests.test_ivy.test_functional.test_core.test_searching import ( _broadcastable_trio, ) @@ -12437,3 +12437,68 @@ def test_torch_triu_( frontend=frontend, on_device=on_device, ) + + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="__array__", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid", full=False), +) +def test_torch__array__( + dtype_and_x, + dtype, + frontend, + backend_fw, +): + input_dtype, x = dtype_and_x + if x[0].dtype == "bfloat16": + return + dtype[0] = np.dtype(dtype[0]) + ret_gt = torch.tensor(x[0]).__array__(dtype[0]) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + local_importer = 
ivy_backend.utils.dynamic_import + function_module = local_importer.import_module("ivy.functional.frontends.torch") + ret = function_module.tensor(x[0]).__array__(dtype[0]) + + helpers.value_test( + ret_np_flat=ret.ravel(), + ret_np_from_gt_flat=ret_gt.ravel(), + ground_truth_backend="torch", + backend=backend_fw, + ) + + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="__array_wrap__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + ), +) +def test_torch___array_wrap__( + dtype_and_x, + backend_fw, + frontend, +): + input_dtypes, x = dtype_and_x + if x[1].dtype == "bfloat16": + return + if x[0].dtype == "bfloat16": + ret_gt = torch.tensor(x[0].tolist(), dtype=torch.bfloat16).__array_wrap__(x[1]) + else: + ret_gt = torch.tensor(x[0]).__array_wrap__(x[1]) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + local_importer = ivy_backend.utils.dynamic_import + function_module = local_importer.import_module("ivy.functional.frontends.torch") + ret = function_module.tensor(x[0]).__array_wrap__(x[1]) + assert isinstance(ret, function_module.Tensor) + helpers.value_test( + ret_np_flat=np.array(ret.ivy_array).ravel(), + ret_np_from_gt_flat=ret_gt.numpy().ravel(), + ground_truth_backend="torch", + backend=backend_fw, + ) From b11a9d12745bf33fe20fce8dce7ca8fdf2380d25 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Mon, 28 Aug 2023 17:35:38 +0530 Subject: [PATCH 30/55] Add "_" before composite functions tests (#22569) --- .../test_experimental/test_nn/test_layers.py | 30 +++++++-------- .../test_functional/test_nn/test_layers.py | 38 +++++++++---------- .../test_ivy/test_stateful/test_layers.py | 4 +- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py index b70dacaaa6670..e83868b1639c0 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py @@ -325,7 +325,7 @@ def test_avg_pool3d( @st.composite -def valid_dct(draw): +def _valid_dct(draw): dtype, x = draw( helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), @@ -354,7 +354,7 @@ def valid_dct(draw): @handle_test( fn_tree="dct", - dtype_x_and_args=valid_dct(), + dtype_x_and_args=_valid_dct(), test_gradients=st.just(False), ) def test_dct(*, dtype_x_and_args, test_flags, backend_fw, fn_name, on_device): @@ -377,7 +377,7 @@ def test_dct(*, dtype_x_and_args, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="idct", - dtype_x_and_args=valid_dct(), + dtype_x_and_args=_valid_dct(), test_gradients=st.just(False), ) def test_idct(dtype_x_and_args, test_flags, backend_fw, fn_name, on_device): @@ -585,7 +585,7 @@ def test_interpolate( @st.composite -def x_and_fft(draw): +def _x_and_fft(draw): min_fft_points = 2 dtype = draw(helpers.get_dtypes("valid", full=False)) x_dim = draw( @@ -613,7 +613,7 @@ def x_and_fft(draw): @handle_test( fn_tree="functional.ivy.experimental.fft", - d_x_d_n_n=x_and_fft(), + d_x_d_n_n=_x_and_fft(), ground_truth_backend="jax", test_gradients=st.just(False), ) @@ -793,7 +793,7 @@ def test_dropout3d( @st.composite -def x_and_ifft(draw): +def _x_and_ifft(draw): min_fft_points = 2 dtype = draw(helpers.get_dtypes("complex")) x_dim = draw( @@ -817,7 +817,7 @@ def 
x_and_ifft(draw): @handle_test( fn_tree="functional.ivy.experimental.ifft", - d_x_d_n_n=x_and_ifft(), + d_x_d_n_n=_x_and_ifft(), test_gradients=st.just(False), ) def test_ifft(*, d_x_d_n_n, test_flags, backend_fw, fn_name): @@ -865,8 +865,8 @@ def test_embedding( @handle_test( fn_tree="dft", - d_xfft_axis_n_length=x_and_fft(), - d_xifft_axis_n_length=x_and_ifft(), + d_xfft_axis_n_length=_x_and_fft(), + d_xifft_axis_n_length=_x_and_ifft(), inverse=st.booleans(), onesided=st.booleans(), ) @@ -1089,7 +1089,7 @@ def test_reduce_window(*, all_args, test_flags, backend_fw, fn_name, on_device): @st.composite -def x_and_fft2(draw): +def _x_and_fft2(draw): min_fft2_points = 2 dtype = draw(helpers.get_dtypes("float_and_complex", full=False)) x_dim = draw( @@ -1120,7 +1120,7 @@ def x_and_fft2(draw): @handle_test( fn_tree="functional.ivy.experimental.fft2", - d_x_d_s_n=x_and_fft2(), + d_x_d_s_n=_x_and_fft2(), ground_truth_backend="numpy", container_flags=st.just([False]), test_gradients=st.just(False), @@ -1143,7 +1143,7 @@ def test_fft2(*, d_x_d_s_n, test_flags, backend_fw, fn_name, on_device): @st.composite -def x_and_ifftn(draw): +def _x_and_ifftn(draw): min_fft_points = 2 dtype = draw(helpers.get_dtypes("complex")) x_dim = draw( @@ -1180,7 +1180,7 @@ def x_and_ifftn(draw): @handle_test( fn_tree="functional.ivy.experimental.ifftn", - d_x_d_s_n=x_and_ifftn(), + d_x_d_s_n=_x_and_ifftn(), ground_truth_backend="numpy", test_gradients=st.just(False), ) @@ -1207,7 +1207,7 @@ def test_ifftn( @st.composite -def x_and_rfftn(draw): +def _x_and_rfftn(draw): min_rfftn_points = 2 dtype = draw(helpers.get_dtypes("float")) x_dim = draw( @@ -1242,7 +1242,7 @@ def x_and_rfftn(draw): @handle_test( fn_tree="functional.ivy.experimental.rfftn", - d_x_d_s_n=x_and_rfftn(), + d_x_d_s_n=_x_and_rfftn(), ground_truth_backend="numpy", test_gradients=st.just(False), ) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py index c5b9f4391b066..e0bbfa29038bb 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py @@ -15,7 +15,7 @@ # Linear # # -------# @st.composite -def x_and_linear(draw): +def _x_and_linear(draw): mixed_fn_compos = draw(st.booleans()) is_torch_backend = ivy.current_backend_str() == "torch" dtype = draw( @@ -64,7 +64,7 @@ def x_and_linear(draw): # linear @handle_test( fn_tree="functional.ivy.linear", - dtype_x_weight_bias=x_and_linear(), + dtype_x_weight_bias=_x_and_linear(), ) def test_linear(*, dtype_x_weight_bias, test_flags, backend_fw, fn_name, on_device): dtype, x, weight, bias = dtype_x_weight_bias @@ -166,7 +166,7 @@ def test_dropout( @st.composite -def x_and_scaled_attention(draw, dtypes): +def _x_and_scaled_attention(draw, dtypes): dtype = draw(dtypes) num_queries = draw(helpers.ints(min_value=2, max_value=4)) num_keys = draw(helpers.ints(min_value=2, max_value=4)) @@ -223,7 +223,7 @@ def x_and_scaled_attention(draw, dtypes): # scaled_dot_product_attention @handle_test( fn_tree="functional.ivy.scaled_dot_product_attention", - dtype_q_k_v_mask=x_and_scaled_attention( + dtype_q_k_v_mask=_x_and_scaled_attention( dtypes=helpers.get_dtypes("float", full=False), ), scale=st.floats(min_value=0.1, max_value=1), @@ -511,7 +511,7 @@ def test_multi_head_attention( @st.composite -def x_and_filters( +def _x_and_filters( draw, dim: int = 2, transpose: bool = False, @@ -686,7 +686,7 @@ def _assume_tf_dilation_gt_1(backend_fw, on_device, dilations): # conv1d 
@handle_test( fn_tree="functional.ivy.conv1d", - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=1, bias=True, filter_format=st.sampled_from(["channel_last", "channel_first"]), @@ -731,7 +731,7 @@ def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): # conv1d_transpose @handle_test( fn_tree="functional.ivy.conv1d_transpose", - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=1, transpose=True, bias=True, @@ -775,7 +775,7 @@ def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic # conv2d @handle_test( fn_tree="functional.ivy.conv2d", - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=2, bias=True, filter_format=st.sampled_from(["channel_last", "channel_first"]), @@ -820,7 +820,7 @@ def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): # conv2d_transpose @handle_test( fn_tree="functional.ivy.conv2d_transpose", - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=2, transpose=True, bias=True, @@ -865,7 +865,7 @@ def test_conv2d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic # depthwise_conv2d @handle_test( fn_tree="functional.ivy.depthwise_conv2d", - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=2, depthwise=True, ), @@ -898,7 +898,7 @@ def test_depthwise_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic # conv3d @handle_test( fn_tree="functional.ivy.conv3d", - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=3, bias=True, filter_format=st.sampled_from(["channel_last", "channel_first"]), @@ -942,7 +942,7 @@ def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): # conv3d_transpose @handle_test( fn_tree="functional.ivy.conv3d_transpose", - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=3, transpose=True, bias=True, @@ -986,7 +986,7 @@ def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic @handle_test( fn_tree="functional.ivy.conv_general_dilated", dims=st.shared(st.integers(1, 3), key="dims"), - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=st.shared(st.integers(1, 3), key="dims"), general=True, bias=True, @@ -1035,7 +1035,7 @@ def test_conv_general_dilated( @handle_test( fn_tree="functional.ivy.conv_general_transpose", dims=st.shared(st.integers(1, 3), key="dims"), - x_f_d_df=x_and_filters( + x_f_d_df=_x_and_filters( dim=st.shared(st.integers(1, 3), key="dims"), general=True, transpose=True, @@ -1083,7 +1083,7 @@ def test_conv_general_transpose( # filter_format not in conv_general_transpose # output_shape not in conv_general_dilated @st.composite -def x_and_filters_and_transpose( +def _x_and_filters_and_transpose( draw, dim: int = 2, general=False, @@ -1094,7 +1094,7 @@ def x_and_filters_and_transpose( if not transpose: filter_format = st.sampled_from(["channel_last", "channel_first"]) all_args = draw( - x_and_filters( + _x_and_filters( dim=dim, general=general, bias=bias, @@ -1150,7 +1150,7 @@ def x_and_filters_and_transpose( @handle_test( fn_tree="functional.ivy.conv", dims=st.shared(st.integers(1, 3), key="dims"), - x_f_d_df_tr=x_and_filters_and_transpose( + x_f_d_df_tr=_x_and_filters_and_transpose( dim=st.shared(st.integers(1, 3), key="dims"), general=True, bias=True, @@ -1209,7 +1209,7 @@ def test_conv(*, dims, x_f_d_df_tr, test_flags, backend_fw, fn_name, on_device): @st.composite -def x_and_lstm(draw, dtypes): +def _x_and_lstm(draw, dtypes): dtype = draw(dtypes) batch_shape = (1,) @@ -1273,7 +1273,7 @@ def x_and_lstm(draw, dtypes): # lstm @handle_test( fn_tree="functional.ivy.lstm_update", - 
dtype_lstm=x_and_lstm( + dtype_lstm=_x_and_lstm( dtypes=helpers.get_dtypes("numeric"), ), test_with_out=st.just(False), diff --git a/ivy_tests/test_ivy/test_stateful/test_layers.py b/ivy_tests/test_ivy/test_stateful/test_layers.py index 327b948a51f26..91c6334671b57 100644 --- a/ivy_tests/test_ivy/test_stateful/test_layers.py +++ b/ivy_tests/test_ivy/test_stateful/test_layers.py @@ -183,7 +183,7 @@ def test_dropout_layer( # Attention # # ----------# @st.composite -def x_and_mha(draw): +def _x_and_mha(draw): dtype = draw( helpers.get_dtypes("float", full=False).filter(lambda x: x != ["float16"]) ) @@ -254,7 +254,7 @@ def x_and_mha(draw): # multi_head_attention @handle_method( method_tree="MultiHeadAttention.__call__", - dtype_mha=x_and_mha(), + dtype_mha=_x_and_mha(), init_with_v=st.booleans(), method_with_v=st.booleans(), method_num_positional_args=helpers.num_positional_args( From 4efbf3a26f3a24f9890db58d8dbad9c9c3394d80 Mon Sep 17 00:00:00 2001 From: Zaeem Ansari <99063526+zaeemansari70@users.noreply.github.com> Date: Mon, 28 Aug 2023 19:08:30 +0500 Subject: [PATCH 31/55] Error Handling (docs) : Added a new section in the docs for the common errors for the developers and the contributers while working with the functional or the experimental API --- docs/overview/contributing/error_handling.rst | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 docs/overview/contributing/error_handling.rst diff --git a/docs/overview/contributing/error_handling.rst b/docs/overview/contributing/error_handling.rst new file mode 100644 index 0000000000000..44f4df41fee1b --- /dev/null +++ b/docs/overview/contributing/error_handling.rst @@ -0,0 +1,112 @@ +Error Handling +============== + +.. _`discord`: https://discord.gg/sXyFF8tDtm +.. _`pycharm channel`: https://discord.com/channels/799879767196958751/942114831039856730 +.. _`docker channel`: https://discord.com/channels/799879767196958751/942114744691740772 +.. _`pre-commit channel`: https://discord.com/channels/799879767196958751/982725464110034944 +.. _`pip packages channel`: https://discord.com/channels/799879767196958751/942114789642080317 +.. _`ivy tests channel`: https://discord.com/channels/799879767196958751/982738436383445073 + +This section, "Error Handling" aims to assist you in navigating through some common errors you might encounter while working with the Ivy's Functional API. We'll go through some common errors which you might encounter while working as a contributor or a developer. + +#. This is the case where we pass in a dtype to `torch` which is not actually supported by the torch's native framework itself. The function which was + + .. code-block:: python + + E RuntimeError: "logaddexp2_cpu" not implemented for 'Half' + E Falsifying example: test_logaddexp2( + E backend_fw='torch', + E on_device='cpu', + E dtype_and_x=(['float16', 'float16'], + E [array([-1.], dtype=float16), array([-1.], dtype=float16)]), + E test_flags=FunctionTestFlags( + E ground_truth_backend='tensorflow', + E num_positional_args=2, + E with_out=False, + E instance_method=False, + E test_gradients=False, + E test_compile=None, + E as_variable=[False], + E native_arrays=[False], + E container=[False], + E ), + E fn_name='logaddexp2', + E ) + E + E You can reproduce this example by temporarily adding @reproduce_failure('6.82.4', b'AXicY2BkAAMoBaaR2WAAAACVAAY=') as a decorator on your test case + +#. 
This is the case where the value from the ground-truth backend(tensorflow) does not match the value of the backend(jax) we are testing for this case. + + .. code-block:: python + + E AssertionError: the results from backend jax and ground truth framework tensorflow do not match + E 0.25830078125!=0.258544921875 + E + E + E Falsifying example: test_acosh( + E backend_fw='jax', + E on_device='cpu', + E dtype_and_x=(['float16'], [array(4., dtype=float16)]), + E test_flags=FunctionTestFlags( + E ground_truth_backend='tensorflow', + E num_positional_args=1, + E with_out=False, + E instance_method=False, + E test_gradients=True, + E test_compile=None, + E as_variable=[False], + E native_arrays=[False], + E container=[False], + E ), + E fn_name='acosh', + E ) + E + E You can reproduce this example by temporarily adding @reproduce_failure('6.82.4', b'AXicY2BAABYQwQgiAABDAAY=') as a decorator on your test case + +#. This is a similar assertion as stated in point 2 but with torch and ground-truth tensorflow not matching but the matrices are quite different so there should be an issue in the backends rather than a numerical instability here: + + .. code-block:: python + + E AssertionError: the results from backend torch and ground truth framework tensorflow do not match + E [[1.41421356 1.41421356 1.41421356] + E [1.41421356 1.41421356 1.41421356] + E [1.41421356 inf 1.41421356]]!=[[1.41421356e+000 1.41421356e+000 1.41421356e+000] + E [1.41421356e+000 1.41421356e+000 1.41421356e+000] + E [1.41421356e+000 1.34078079e+154 1.41421356e+000]] + E + E + E Falsifying example: test_abs( + E backend_fw='torch', + E on_device='cpu', + E dtype_and_x=(['complex128'], + E [array([[-1.-1.00000000e+000j, -1.-1.00000000e+000j, -1.-1.00000000e+000j], + E [-1.-1.00000000e+000j, -1.-1.00000000e+000j, -1.-1.00000000e+000j], + E [-1.-1.00000000e+000j, -1.-1.34078079e+154j, -1.-1.00000000e+000j]])]), + E fn_name='abs', + E test_flags=FunctionTestFlags( + E ground_truth_backend='tensorflow', + E num_positional_args=1, + E with_out=False, + E instance_method=False, + E test_gradients=False, + E test_compile=None, + E as_variable=[False], + E native_arrays=[False], + E container=[False], + E ), + E ) + E + E You can reproduce this example by temporarily adding @reproduce_failure('6.82.4', b'AXicY2ZkYAIiBiBgZIAAxqHEXsAAB7jUQAAAMtEAzQ==') as a decorator on your test case + + +**Note** + +This section is specifically targeted towards dealing with the Ivy Functional API and the Ivy Experimental API. + +**Round Up** + +This should have hopefully given you an understanding of how to deal with common errors while working with the the functional API. + +If you have any questions, please feel free to reach out on `discord`_ in the `ivy tests channel`_, `pycharm channel`_, `docker channel`_, `pre-commit channel`_, `pip packages channel`_ depending on the question! 
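Mismatches like the ``acosh`` example in point 2 often trace back to the limited precision of ``float16`` rather than to a real bug in either backend. A small standalone snippet (purely illustrative, using plain numpy and not part of the Ivy test suite) shows how the low-precision result drifts from the high-precision one:

.. code-block:: python

    import numpy as np

    x = np.float16(4.0)
    full = np.arccosh(x.astype(np.float64))  # computed in float64
    half = np.arccosh(x)  # computed in float16, so fewer significant bits survive
    print(full, float(half), abs(full - float(half)))  # only the leading digits agree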
+ From abc4cdde235eaada1f3af660d02930ce3345324a Mon Sep 17 00:00:00 2001 From: Zaeem Ansari <99063526+zaeemansari70@users.noreply.github.com> Date: Mon, 28 Aug 2023 19:19:47 +0500 Subject: [PATCH 32/55] Contributing (docs) : Added a reference to the link for the new section error_handling, so that it can render in the docs --- docs/overview/contributing.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/overview/contributing.rst b/docs/overview/contributing.rst index 74992ad5a4165..c5b3b1aca787c 100644 --- a/docs/overview/contributing.rst +++ b/docs/overview/contributing.rst @@ -35,6 +35,9 @@ The contributor guide is split into the sections below, it's best to go from sta | | (g) :ref:`Helpful Resources` | Resources you would find useful when learning Ivy 📖 +| +| (g) :ref:`Error Handling` +| Common errors you will be facing contributing to Ivy ❌ .. toctree:: :hidden: @@ -48,6 +51,7 @@ The contributor guide is split into the sections below, it's best to go from sta contributing/open_tasks.rst contributing/applied_libraries.rst contributing/helpful_resources.rst + contributing/error_handling.rst **Video** From 06e97a5d7b628e97c98cfd60078a1f1aa5225937 Mon Sep 17 00:00:00 2001 From: SUSHMANTH REDDY <73489688+sushmanthreddy@users.noreply.github.com> Date: Mon, 28 Aug 2023 20:14:49 +0530 Subject: [PATCH 33/55] huber_loss (#22375) --- ivy/data_classes/array/experimental/losses.py | 48 +++++ .../container/experimental/losses.py | 166 ++++++++++++++++++ .../backends/jax/experimental/losses.py | 16 ++ .../backends/numpy/experimental/losses.py | 23 +++ .../backends/paddle/experimental/losses.py | 28 +++ .../tensorflow/experimental/losses.py | 22 +++ .../backends/torch/experimental/losses.py | 17 ++ ivy/functional/ivy/experimental/losses.py | 65 +++++++ .../test_experimental/test_nn/test_losses.py | 49 ++++++ 9 files changed, 434 insertions(+) diff --git a/ivy/data_classes/array/experimental/losses.py b/ivy/data_classes/array/experimental/losses.py index c3f7be946cc8a..173f1510f5a36 100644 --- a/ivy/data_classes/array/experimental/losses.py +++ b/ivy/data_classes/array/experimental/losses.py @@ -49,6 +49,54 @@ def l1_loss( """ return ivy.l1_loss(self._data, target, reduction=reduction, out=out) + def huber_loss( + self: ivy.Array, + pred: Union[ivy.Array, ivy.NativeArray], + /, + *, + reduction: Optional[str] = "mean", + delta: Optional[float] = 1.0, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + ivy.Array instance method variant of huber_loss. This method simply wraps the + function, and so the docstring for huber_loss also applies to this method with + minimal changes. + + Parameters + ---------- + self + The true (ground truth) values. + pred + The predicted values by the model. + reduction : str, optional + The type of reduction to apply to the loss. + Possible values are "mean" (default) + and "sum". + delta + The threshold parameter that determines the point where the loss transitions + from squared error to absolute error. Default is 1.0. + out + Optional output array, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + The Huber loss between the true and predicted values. 
+ + Examples + -------- + >>> true = ivy.array([2, 4, 7, 1]) + >>> pred = ivy.array([2.5, 3.5, 8, 0.8]) + >>> loss = true.huber_loss(pred, delta=1.0) + >>> print(loss) + ivy.array([0.125, 0.125, 0.5 , 0.125]) + """ + return ivy.huber_loss( + self._data, pred, reduction=reduction, delta=delta, out=out + ) + def smooth_l1_loss( self: ivy.Array, target: Union[ivy.Array, ivy.NativeArray], diff --git a/ivy/data_classes/container/experimental/losses.py b/ivy/data_classes/container/experimental/losses.py index 3d9aa2510579a..048ef9d01a3ba 100644 --- a/ivy/data_classes/container/experimental/losses.py +++ b/ivy/data_classes/container/experimental/losses.py @@ -328,3 +328,169 @@ def smooth_l1_loss( map_sequences=map_sequences, out=out, ) + + @staticmethod + def _static_huber_loss( + true: Union[ivy.Container, ivy.Array, ivy.NativeArray], + pred: Union[ivy.Container, ivy.Array, ivy.NativeArray], + /, + *, + delta: Optional[Union[float, ivy.Container]] = 1.0, + reduction: Optional[Union[str, ivy.Container]] = "mean", + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of huber_loss. This method simply wraps the + function, and so the docstring for huber_loss also applies to this method with + minimal changes. + + Parameters + ---------- + true + true array or container containing true labels. + pred + true array or container containing the predicted labels. + delta + The threshold parameter that determines the point where the loss transitions + from squared error to absolute error. Default is 1.0. + reduction : str, optional + The type of reduction to apply to the loss. + Possible values are "mean" (default) + and "sum". + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If true, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``true``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. It must have a shape + that the true broadcast to. + + Returns + ------- + ret + The Huber loss between the true and predicted values. 
+ + Examples + -------- + With :class:`ivy.Container` trues: + + >>> x = ivy.Container(a=ivy.array([1, 0, 3]), b=ivy.array([0, 0, 2])) + >>> y = ivy.Container(a=ivy.array([1.5, 0.2, 2.8]), b=ivy.array([0.5, 0.2, 1.9]) + ) + >>> z = ivy.Container.static_huber_loss(x, y, delta=1.0) + >>> print(z) + { + a: ivy.array(0.0575), + b: ivy.array(0.005) + } + + With a mix of :class:`ivy.Array` and :class:`ivy.Container` trues: + + >>> x = ivy.array([1, 0, 3]) + >>> y = ivy.Container(a=ivy.array([1.5, 0.2, 2.8]), b=ivy.array([0.5, 0.2, 1.9]) + ) + >>> z = ivy.Container.static_huber_loss(x, y, delta=1.0) + >>> print(z) + { + a: ivy.array(0.0575), + b: ivy.array(0.005) + } + """ + return ContainerBase.cont_multi_map_in_function( + "huber_loss", + true, + pred, + delta=delta, + reduction=reduction, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def huber_loss( + self: ivy.Container, + pred: Union[ivy.Container, ivy.Array, ivy.NativeArray], + /, + *, + delta: Optional[Union[float, ivy.Container]] = 1.0, + reduction: Optional[Union[str, ivy.Container]] = "mean", + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of huber_loss. This method simply wraps + the function, and so the docstring for huber_loss also applies to this method + with minimal changes. + + Parameters + ---------- + self + true container containing true labels. + pred + true array or container containing the predicted labels. + delta + The threshold parameter that determines the point where the loss transitions + from squared error to absolute error. Default is 1.0. + reduction : str, optional + The type of reduction to apply to the loss. + Possible values are "mean" (default) + and "sum". + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If true, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``true``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. It must have a shape + that the trues broadcast to. + + Returns + ------- + ret + The Huber loss between the true and predicted values. 
+ + Examples + -------- + >>> x = ivy.Container(a=ivy.array([1, 0, 3]), b=ivy.array([0, 0, 2])) + >>> y = ivy.Container(a=ivy.array([1.5, 0.2, 2.8]), b=ivy.array([0.5, 0.2, 1.9]) + ) + >>> z = x.huber_loss(y, delta=1.0) + >>> print(z) + { + a: ivy.array(0.0575), + b: ivy.array(0.005) + } + """ + return self._static_huber_loss( + self, + pred, + delta=delta, + reduction=reduction, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) diff --git a/ivy/functional/backends/jax/experimental/losses.py b/ivy/functional/backends/jax/experimental/losses.py index 2709e81d9306d..160bd6760b78b 100644 --- a/ivy/functional/backends/jax/experimental/losses.py +++ b/ivy/functional/backends/jax/experimental/losses.py @@ -3,6 +3,22 @@ from ivy.functional.backends.jax import JaxArray +def huber_loss( + input: JaxArray, target: JaxArray, /, *, delta: float = 1.0, reduction: str = "mean" +) -> JaxArray: + residual = jnp.abs(input - target) + quadratic_loss = 0.5 * (residual**2) + linear_loss = delta * residual - 0.5 * (delta**2) + loss = jnp.where(residual < delta, quadratic_loss, linear_loss) + + if reduction == "mean": + loss = jnp.mean(loss) + elif reduction == "sum": + loss = jnp.sum(loss) + + return loss + + def smooth_l1_loss( input: JaxArray, target: JaxArray, diff --git a/ivy/functional/backends/numpy/experimental/losses.py b/ivy/functional/backends/numpy/experimental/losses.py index fae08dc5edd63..439406e93c6fc 100644 --- a/ivy/functional/backends/numpy/experimental/losses.py +++ b/ivy/functional/backends/numpy/experimental/losses.py @@ -5,6 +5,29 @@ from . import backend_version +@with_unsupported_dtypes({"1.25.2 and below": ("bool",)}, backend_version) +@_scalar_output_to_0d_array +def huber_loss( + input: np.ndarray, + target: np.ndarray, + /, + *, + delta: Optional[float] = 1.0, + reduction: Optional[str] = "mean", +) -> np.ndarray: + abs_diff = np.abs(input - target) + quadratic_loss = 0.5 * (abs_diff**2) + linear_loss = delta * (abs_diff - 0.5 * delta) + loss = np.where(abs_diff <= delta, quadratic_loss, linear_loss) + + if reduction == "sum": + return np.sum(loss) + elif reduction == "mean": + return np.mean(loss) + else: + return loss + + # Implementation of smooth_l1_loss in the given format @with_unsupported_dtypes({"1.25.2 and below": ("bool",)}, backend_version) @_scalar_output_to_0d_array diff --git a/ivy/functional/backends/paddle/experimental/losses.py b/ivy/functional/backends/paddle/experimental/losses.py index 843de2f24d250..70441eb4655f3 100644 --- a/ivy/functional/backends/paddle/experimental/losses.py +++ b/ivy/functional/backends/paddle/experimental/losses.py @@ -64,3 +64,31 @@ def smooth_l1_loss( return paddle.nn.functional.smooth_l1_loss( input, target, reduction=reduction, beta=beta ) + + +@with_unsupported_device_and_dtypes( + { + "2.5.1 and below": { + "cpu": ( + "float16", + "int8", + "int16", + "int32", + "int64", + "uint8", + "complex64", + "complex128", + "bool", + ) + } + }, + backend_version, +) +def huber_loss( + input: paddle.Tensor, + target: paddle.Tensor, + /, + *, + delta: Optional[float] = 1.0, +) -> paddle.Tensor: + return paddle.fluid.layers.huber_loss(input, target, delta=delta) diff --git a/ivy/functional/backends/tensorflow/experimental/losses.py b/ivy/functional/backends/tensorflow/experimental/losses.py index db165249ef644..19a6d4a54fa7b 100644 --- a/ivy/functional/backends/tensorflow/experimental/losses.py +++ b/ivy/functional/backends/tensorflow/experimental/losses.py @@ -4,6 +4,28 
@@
 from . import backend_version
 
 
+@with_unsupported_dtypes({"2.13.0 and below": "bool"}, backend_version)
+def huber_loss(
+    input: tf.Tensor,
+    target: tf.Tensor,
+    /,
+    *,
+    delta: Optional[float] = 1.0,
+    reduction: Optional[str] = "mean",
+) -> tf.Tensor:
+    abs_diff = tf.abs(input - target)
+    quadratic_loss = 0.5 * (abs_diff**2)
+    linear_loss = delta * (abs_diff - 0.5 * delta)
+    loss = tf.where(abs_diff <= delta, quadratic_loss, linear_loss)
+
+    if reduction == "sum":
+        return tf.reduce_sum(loss)
+    elif reduction == "mean":
+        return tf.reduce_mean(loss)
+    else:
+        return loss
+
+
 @with_unsupported_dtypes({"2.13.0 and below": "bool"}, backend_version)
 def smooth_l1_loss(
     input: tf.Tensor,
diff --git a/ivy/functional/backends/torch/experimental/losses.py b/ivy/functional/backends/torch/experimental/losses.py
index 77099de513a4c..db3442685a678 100644
--- a/ivy/functional/backends/torch/experimental/losses.py
+++ b/ivy/functional/backends/torch/experimental/losses.py
@@ -52,3 +52,20 @@ def smooth_l1_loss(
         beta=beta,
         reduction=reduction,
     )
+
+
+@with_unsupported_dtypes(
+    {"2.0.1 and below": ("uint8", "int8", "int16", "int32", "int64", "bool")},
+    backend_version,
+)
+def huber_loss(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    /,
+    *,
+    reduction: Optional[str] = "mean",
+    delta: Optional[float] = 1.0,
+) -> torch.Tensor:
+    return torch.nn.functional.huber_loss(
+        input, target, reduction=reduction, delta=delta
+    )
diff --git a/ivy/functional/ivy/experimental/losses.py b/ivy/functional/ivy/experimental/losses.py
index bcbb5128116e8..7b9f55d010c5a 100644
--- a/ivy/functional/ivy/experimental/losses.py
+++ b/ivy/functional/ivy/experimental/losses.py
@@ -154,6 +154,71 @@ def l1_loss(
     return ivy.inplace_update(out, loss) if out is not None else loss
 
 
+@handle_exceptions
+@handle_nestable
+@handle_array_like_without_promotion
+@inputs_to_ivy_arrays
+@handle_array_function
+def huber_loss(
+    true: Union[ivy.Array, ivy.NativeArray],
+    pred: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    delta: Optional[float] = 1.0,
+    reduction: Optional[str] = "mean",
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    """
+    Compute the Huber loss (smooth L1 loss) between true and predicted values.
+
+    Parameters
+    ----------
+    true: array_like
+        The true (ground truth) values.
+    pred : array_like
+        The predicted values by the model.
+    delta : float, optional
+        The threshold parameter that determines the point where the loss
+        transitions from squared error to absolute error.
+        Default is 1.0.
+    reduction : str, optional
+        The type of reduction to apply to the loss. Possible values are "mean" (default)
+        and "sum".
+    out : array_like, optional
+        Optional output array, for writing the result to. It must have a shape
+        that the inputs broadcast to.
+
+    Returns
+    -------
+    ret : array_like
+        The Huber loss between the true and predicted values.
+
+    Examples
+    --------
+    >>> true = ivy.array([2, 4, 7, 1])
+    >>> pred = ivy.array([2.5, 3.5, 8, 0.8])
+    >>> huber_loss(true, pred, delta=1.0)
+    ivy.array([0.125, 0.125, 0.5 , 0.125])
+
+    >>> huber_loss(true, pred, delta=2.0)
+    ivy.array([0.125, 0.125, 0.5 , 0.2 ])
+
+    >>> huber_loss(true, pred, delta=0.5)
+    ivy.array([0.25 , 0.25 , 0. 
, 0.125]) + """ + abs_diff = ivy.abs(true - pred) + quadratic_loss = 0.5 * (abs_diff**2) + linear_loss = delta * (abs_diff - 0.5 * delta) + loss = ivy.where(abs_diff <= delta, quadratic_loss, linear_loss) + + if reduction == "sum": + return ivy.sum(loss, out=out) + elif reduction == "mean": + return ivy.mean(loss, out=out) + else: + return ivy.inplace_update(out, loss) if out is not None else loss + + @handle_exceptions @handle_nestable @handle_array_like_without_promotion diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py index a59fbf12cf704..77520704697f3 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py @@ -151,3 +151,52 @@ def test_smooth_l1_loss( beta=beta, reduction=reduction, ) + + +# huber_loss +@handle_test( + fn_tree="functional.ivy.experimental.huber_loss", + dtype_and_true=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, + allow_inf=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=3, + ), + dtype_and_pred=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, + allow_inf=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=3, + ), + reduction=st.sampled_from(["none", "sum", "mean"]), + delta=helpers.floats(min_value=0.01, max_value=2.0), +) +def test_huber_loss( + dtype_and_true, + dtype_and_pred, + reduction, + delta, + test_flags, + backend_fw, + fn_name, + on_device, +): + true_dtype, true = dtype_and_true + pred_dtype, pred = dtype_and_pred + helpers.test_function( + input_dtypes=true_dtype + pred_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + true=true[0], + pred=pred[0], + reduction=reduction, + delta=delta, + ) From 84722d6a610ac218ca5e9f4913db976c4b63dffb Mon Sep 17 00:00:00 2001 From: seif mohamed <60402868+Seif-Mohamed1@users.noreply.github.com> Date: Mon, 28 Aug 2023 18:03:24 +0300 Subject: [PATCH 34/55] Added amin function to Paddle Mathematical Functions (#22632) Co-authored-by: Mahmoud Ashraf --- .../frontends/paddle/tensor/math.py | 7 ++++ .../test_paddle/test_tensor/test_math.py | 33 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py index 75242ca8e62a4..4df4724f52019 100644 --- a/ivy/functional/frontends/paddle/tensor/math.py +++ b/ivy/functional/frontends/paddle/tensor/math.py @@ -502,3 +502,10 @@ def tanh(x, name=None): @to_ivy_arrays_and_back def trunc(x, name=None): return ivy.trunc(x) + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def amin(x, axis=None, keepdim=False, name=None): + return ivy.min(x, axis=axis, keepdims=keepdim) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py index 50b4e48de810d..2bac85f674262 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py @@ -2137,3 +2137,36 @@ def test_paddle_stanh( scale_a=scale_a, scale_b=scale_b, ) + + +# amin +@handle_frontend_test( + fn_tree="paddle.tensor.math.amin", + 
dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + ), + keepdim=st.booleans(), +) +def test_paddle_amin( + *, + dtype_and_x, + keepdim, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + x=x[0], + axis = axis, + keepdim = keepdim, + ) From a2a4e7e7ac393a6599eac403d26d36fa8c7e61aa Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Tue, 29 Aug 2023 01:09:08 +0530 Subject: [PATCH 35/55] Add "_" before composite functions tests (#22710) --- .../test_jax/test__src/test_tree_util.py | 6 ++-- .../test_jax/test_lax/test_operators.py | 8 ++--- .../test_function/test_mindspore_nn_func.py | 8 ++--- .../test_scipy/test_fft/test_fft.py | 30 +++++++++---------- .../test_tensorflow/test_signal.py | 6 ++-- .../test_convolution_functions.py | 14 ++++----- .../test_functional/test_linear_functions.py | 4 +-- .../test_non_linear_activation_functions.py | 4 +-- .../test_functional/test_core/test_general.py | 4 +-- .../test_misc/test_backend_handler.py | 3 +- 10 files changed, 43 insertions(+), 44 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py b/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py index 022122dd69332..a72c24c973ca1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py @@ -32,14 +32,14 @@ def tree_strategy(max_depth=2): @st.composite -def tree_dict_strategy(draw): +def _tree_dict_strategy(draw): return draw(tree_strategy()) # tree_leaves @handle_frontend_test( fn_tree="jax._src.tree_util.tree_leaves", - tree=tree_dict_strategy(), + tree=_tree_dict_strategy(), ) def test_jax_tree_leaves( *, @@ -65,7 +65,7 @@ def test_jax_tree_leaves( # tree_map @handle_frontend_test( fn_tree="jax._src.tree_util.tree_map", - tree=tree_dict_strategy(), + tree=_tree_dict_strategy(), ) def test_jax_tree_map( *, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py index 4ed4f7e81bd96..455f0e5d26de3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py @@ -2032,7 +2032,7 @@ def test_jax_dot_general( @st.composite -def x_and_filters(draw, dim=2, transpose=False, general=False): +def _x_and_filters(draw, dim=2, transpose=False, general=False): if not isinstance(dim, int): dim = draw(dim) batch_size = draw(st.integers(1, 5)) @@ -2156,7 +2156,7 @@ def x_and_filters(draw, dim=2, transpose=False, general=False): @handle_frontend_test( fn_tree="jax.lax.conv", - x_f_d_other=x_and_filters(), + x_f_d_other=_x_and_filters(), test_with_out=st.just(False), ) def test_jax_conv( @@ -2188,7 +2188,7 @@ def test_jax_conv( @handle_frontend_test( fn_tree="jax.lax.conv_transpose", - x_f_d_other=x_and_filters(general=True, transpose=True), + x_f_d_other=_x_and_filters(general=True, transpose=True), test_with_out=st.just(False), ) def test_jax_conv_transpose( @@ -2223,7 +2223,7 @@ def test_jax_conv_transpose( @handle_frontend_test( fn_tree="jax.lax.conv_general_dilated", - x_f_d_other=x_and_filters(general=True), + 
x_f_d_other=_x_and_filters(general=True), test_with_out=st.just(False), ) def test_jax_conv_general_dilated( diff --git a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py index aa152fd6c5131..1eb8bded0d086 100644 --- a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py +++ b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py @@ -614,7 +614,7 @@ def test_mindspore_flatten( @st.composite -def x_and_filters(draw, dim: int = 2): +def _x_and_filters(draw, dim: int = 2): if not isinstance(dim, int): dim = draw(dim) strides = draw( @@ -705,7 +705,7 @@ def x_and_filters(draw, dim: int = 2): @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( fn_tree="mindspore.ops.function.nn_func.Conv1d", - dtype_vals=x_and_filters(dim=1), + dtype_vals=_x_and_filters(dim=1), ) def test_mindspore_conv1d( *, @@ -738,7 +738,7 @@ def test_mindspore_conv1d( @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( fn_tree="mindspore.ops.function.nn_func.Conv2d", - dtype_vals=x_and_filters(dim=2), + dtype_vals=_x_and_filters(dim=2), ) def test_mindspore_conv2d( *, @@ -771,7 +771,7 @@ def test_mindspore_conv2d( @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( fn_tree="mindspore.ops.function.nn_func.Conv3d", - dtype_vals=x_and_filters(dim=3), + dtype_vals=_x_and_filters(dim=3), ) def test_mindspore_conv3d( *, diff --git a/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py index 8384db2426fd7..c13d8e761a0fc 100644 --- a/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py @@ -10,7 +10,7 @@ @st.composite -def x_and_fft(draw, dtypes): +def _x_and_fft(draw, dtypes): min_fft_points = 2 dtype = draw(dtypes) x_dim = draw( @@ -33,7 +33,7 @@ def x_and_fft(draw, dtypes): @st.composite -def x_and_ifft(draw): +def _x_and_ifft(draw): min_fft_points = 2 dtype = draw(helpers.get_dtypes("complex")) x_dim = draw( @@ -56,7 +56,7 @@ def x_and_ifft(draw): @st.composite -def valid_dct(draw): +def _valid_dct(draw): dtype, x = draw( helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), @@ -84,7 +84,7 @@ def valid_dct(draw): @st.composite -def valid_idct(draw): +def _valid_idct(draw): dtype, x = draw( helpers.dtype_and_values( available_dtypes=["float32", "float64"], @@ -107,7 +107,7 @@ def valid_idct(draw): @st.composite -def x_and_fft2(draw): +def _x_and_fft2(draw): min_fft2_points = 2 dtype = draw(helpers.get_dtypes("float_and_complex", full=False)) x, dim = draw( @@ -128,10 +128,10 @@ def x_and_fft2(draw): @st.composite -def x_and_ifftn(draw): - x_and_ifftn = draw(x_and_fft2()) +def _x_and_ifftn(draw): + _x_and_ifftn = draw(_x_and_fft2()) workers = draw(st.integers(1, 4)) - return x_and_ifftn + (workers,) + return _x_and_ifftn + (workers,) # Tests @@ -140,7 +140,7 @@ def x_and_ifftn(draw): # fft @handle_frontend_test( fn_tree="scipy.fft.fft", - d_x_d_n_n=x_and_fft(helpers.get_dtypes("complex")), + d_x_d_n_n=_x_and_fft(helpers.get_dtypes("complex")), test_with_out=st.just(False), ) def test_scipy_fft( @@ -167,7 +167,7 @@ def test_scipy_fft( # ifft @handle_frontend_test( fn_tree="scipy.fft.ifft", - d_x_d_n_n=x_and_ifft(), + 
d_x_d_n_n=_x_and_ifft(), test_with_out=st.just(False), ) def test_scipy_ifft( @@ -194,7 +194,7 @@ def test_scipy_ifft( # dct @handle_frontend_test( fn_tree="scipy.fft.dct", - dtype_x_and_args=valid_dct(), + dtype_x_and_args=_valid_dct(), test_with_out=st.just(False), ) def test_scipy_dct( @@ -224,7 +224,7 @@ def test_scipy_dct( # idct @handle_frontend_test( fn_tree="scipy.fft.idct", - dtype_x_and_args=valid_idct(), + dtype_x_and_args=_valid_idct(), test_with_out=st.just(False), ) def test_scipy_idct( @@ -254,7 +254,7 @@ def test_scipy_idct( # fft2 @handle_frontend_test( fn_tree="scipy.fft.fft2", - d_x_d_s_n=x_and_fft2(), + d_x_d_s_n=_x_and_fft2(), test_with_out=st.just(False), ) def test_scipy_fft2( @@ -281,7 +281,7 @@ def test_scipy_fft2( # ifftn @handle_frontend_test( fn_tree="scipy.fft.ifftn", - d_x_d_s_n_workers=x_and_ifftn(), + d_x_d_s_n_workers=_x_and_ifftn(), test_with_out=st.just(False), ) def test_scipy_ifftn( @@ -309,7 +309,7 @@ def test_scipy_ifftn( # rfftn @handle_frontend_test( fn_tree="scipy.fft.rfftn", - d_x_d_s_n_workers=x_and_ifftn(), + d_x_d_s_n_workers=_x_and_ifftn(), test_with_out=st.just(False), ) def test_scipy_rfftn( diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py index 1b9504baeaf68..894e7d26989b7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py @@ -45,7 +45,7 @@ def test_tensorflow_kaiser_window( @st.composite -def valid_idct(draw): +def _valid_idct(draw): dtype, x = draw( helpers.dtype_and_values( available_dtypes=["float32", "float64"], @@ -68,7 +68,7 @@ def valid_idct(draw): # idct @handle_frontend_test( fn_tree="tensorflow.signal.idct", - dtype_x_and_args=valid_idct(), + dtype_x_and_args=_valid_idct(), test_with_out=st.just(False), ) def test_tensorflow_idct( @@ -111,7 +111,7 @@ def test_tensorflow_idct( n=helpers.ints(min_value=1, max_value=3), norm=st.sampled_from([None, "ortho"]), type=helpers.ints(min_value=1, max_value=4), - # dtype_x_and_args=valid_idct(), + # dtype_x_and_args=_valid_idct(), test_with_out=st.just(False), ) def test_tensorflow_dct( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py index 248251bee51c7..be32c522fb53f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py @@ -12,7 +12,7 @@ @st.composite -def x_and_filters(draw, dim: int = 2, transpose: bool = False): +def _x_and_filters(draw, dim: int = 2, transpose: bool = False): if not isinstance(dim, int): dim = draw(dim) strides = draw( @@ -139,7 +139,7 @@ def x_and_filters(draw, dim: int = 2, transpose: bool = False): @handle_frontend_test( fn_tree="torch.nn.functional.conv1d", - dtype_vals=x_and_filters(dim=1), + dtype_vals=_x_and_filters(dim=1), ) def test_torch_conv1d( *, @@ -170,7 +170,7 @@ def test_torch_conv1d( @handle_frontend_test( fn_tree="torch.nn.functional.conv2d", - dtype_vals=x_and_filters(dim=2), + dtype_vals=_x_and_filters(dim=2), ) def test_torch_conv2d( *, @@ -201,7 +201,7 @@ def test_torch_conv2d( @handle_frontend_test( fn_tree="torch.nn.functional.conv3d", - dtype_vals=x_and_filters(dim=3), + dtype_vals=_x_and_filters(dim=3), ) def test_torch_conv3d( *, 
@@ -251,7 +251,7 @@ def _output_shape( @handle_frontend_test( fn_tree="torch.nn.functional.conv_transpose1d", - dtype_vals=x_and_filters(dim=1, transpose=True), + dtype_vals=_x_and_filters(dim=1, transpose=True), ) def test_torch_conv_tranpose1d( *, @@ -292,7 +292,7 @@ def test_torch_conv_tranpose1d( @handle_frontend_test( fn_tree="torch.nn.functional.conv_transpose2d", - dtype_vals=x_and_filters(dim=2, transpose=True), + dtype_vals=_x_and_filters(dim=2, transpose=True), ) def test_torch_conv_tranpose2d( *, @@ -333,7 +333,7 @@ def test_torch_conv_tranpose2d( @handle_frontend_test( fn_tree="torch.nn.functional.conv_transpose3d", - dtype_vals=x_and_filters(dim=3, transpose=True), + dtype_vals=_x_and_filters(dim=3, transpose=True), ) def test_torch_conv_tranpose3d( *, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py index 0bcb364137fb0..2a2a3dd0f59fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py @@ -7,7 +7,7 @@ @st.composite -def x_and_linear(draw, dtypes): +def _x_and_linear(draw, dtypes): dtype = draw(dtypes) in_features = draw(helpers.ints(min_value=1, max_value=2)) out_features = draw(helpers.ints(min_value=1, max_value=2)) @@ -37,7 +37,7 @@ def x_and_linear(draw, dtypes): # linear @handle_frontend_test( fn_tree="torch.nn.functional.linear", - dtype_x_weight_bias=x_and_linear( + dtype_x_weight_bias=_x_and_linear( dtypes=helpers.get_dtypes("float", full=False), ), ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py index e07bcd8b68fe7..dd423523c004b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py @@ -1492,7 +1492,7 @@ def test_torch_multi_head_attention_forward( @st.composite -def x_and_scaled_attention(draw, dtypes): +def _x_and_scaled_attention(draw, dtypes): dtype = draw(dtypes) num_queries = draw(helpers.ints(min_value=2, max_value=4)) num_keys = draw(helpers.ints(min_value=2, max_value=4)) @@ -1549,7 +1549,7 @@ def x_and_scaled_attention(draw, dtypes): # scaled_dot_product_attention @handle_frontend_test( fn_tree="torch.nn.functional.scaled_dot_product_attention", - dtype_q_k_v_mask=x_and_scaled_attention( + dtype_q_k_v_mask=_x_and_scaled_attention( dtypes=helpers.get_dtypes("float"), ), dropout_p=st.floats(min_value=0, max_value=0.99), diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index 4203755e54956..30bb022a7f4cf 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -473,7 +473,7 @@ def test_clip_vector_norm( @st.composite -def values_and_ndindices( +def _values_and_ndindices( draw, *, array_dtypes, @@ -606,7 +606,7 @@ def test_scatter_flat(x, reduction, test_flags, backend_fw, fn_name, on_device): # scatter_nd @handle_test( fn_tree="functional.ivy.scatter_nd", - x=values_and_ndindices( + x=_values_and_ndindices( # ToDo: needs 
support for boolean arrays array_dtypes=helpers.get_dtypes("numeric"), indices_dtypes=["int32", "int64"], diff --git a/ivy_tests/test_ivy/test_misc/test_backend_handler.py b/ivy_tests/test_ivy/test_misc/test_backend_handler.py index 888546ce4869d..0185ad0857c22 100644 --- a/ivy_tests/test_ivy/test_misc/test_backend_handler.py +++ b/ivy_tests/test_ivy/test_misc/test_backend_handler.py @@ -186,10 +186,9 @@ def test_choose_random_backend(excluded): # Dynamic Backend backends = list(_backend_dict.keys()) -backend_combinations = [(a, b) for a in backends for b in backends if a != b] -@pytest.mark.parametrize("middle_backend,end_backend", backend_combinations) +@pytest.mark.parametrize("middle_backend,end_backend", [(a, b) for a in backends for b in backends if a != b]) def test_dynamic_backend_all_combos(middle_backend, end_backend): # create an ivy array, container and native container a = ivy.array([1, 2, 3]) From c2d84cd93ca19f390a8581b63f4ea3ab3a3bf444 Mon Sep 17 00:00:00 2001 From: umairjavaid Date: Tue, 29 Aug 2023 06:30:53 +0500 Subject: [PATCH 36/55] Sklearn decisiontree (#22722) --- ivy/functional/frontends/sklearn/base.py | 5 + .../frontends/sklearn/tree/_classes.py | 150 ++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 ivy/functional/frontends/sklearn/tree/_classes.py diff --git a/ivy/functional/frontends/sklearn/base.py b/ivy/functional/frontends/sklearn/base.py index 2e8efd6078ee3..1db8a5caf63e5 100644 --- a/ivy/functional/frontends/sklearn/base.py +++ b/ivy/functional/frontends/sklearn/base.py @@ -26,3 +26,8 @@ def fit(self, X, y, **kwargs): def predict(self, X): raise NotImplementedError + + +class MultiOutputMixin: + def _more_tags(self): + return {"multioutput": True} \ No newline at end of file diff --git a/ivy/functional/frontends/sklearn/tree/_classes.py b/ivy/functional/frontends/sklearn/tree/_classes.py new file mode 100644 index 0000000000000..c8352f80b8250 --- /dev/null +++ b/ivy/functional/frontends/sklearn/tree/_classes.py @@ -0,0 +1,150 @@ +from abc import ABCMeta, abstractmethod +from ..base import ( + BaseEstimator, + ClassifierMixin, + MultiOutputMixin, +) + +class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): + + @abstractmethod + def __init__( + self, + *, + criterion, + splitter, + max_depth, + min_samples_split, + min_samples_leaf, + min_weight_fraction_leaf, + max_features, + max_leaf_nodes, + random_state, + min_impurity_decrease, + class_weight=None, + ccp_alpha=0.0, + ): + self.criterion = criterion + self.splitter = splitter + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + self.random_state = random_state + self.min_impurity_decrease = min_impurity_decrease + self.class_weight = class_weight + self.ccp_alpha = ccp_alpha + + def get_depth(self): + raise NotImplementedError + + + def get_n_leaves(self): + raise NotImplementedError + + + def _support_missing_values(self, X): + raise NotImplementedError + + + def _compute_missing_values_in_feature_mask(self, X): + raise NotImplementedError + + + def _fit( + self, + X, + y, + sample_weight=None, + check_input=True, + missing_values_in_feature_mask=None, + ): + raise NotImplementedError + + + def _validate_X_predict(self, X, check_input): + raise NotImplementedError + + + def predict(self, X, check_input=True): + raise NotImplementedError + + def apply(self, X, 
check_input=True): + raise NotImplementedError + + + def decision_path(self, X, check_input=True): + raise NotImplementedError + + + def _prune_tree(self): + raise NotImplementedError + + + def cost_complexity_pruning_path(self, X, y, sample_weight=None): + raise NotImplementedError + + + @property + def feature_importances_(self): + raise NotImplementedError + + +class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): + + def __init__( + self, + *, + criterion="gini", + splitter="best", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=None, + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + class_weight=None, + ccp_alpha=0.0, + ): + super().__init__( + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + class_weight=class_weight, + random_state=random_state, + min_impurity_decrease=min_impurity_decrease, + ccp_alpha=ccp_alpha, + ) + + def fit(self, X, y, sample_weight=None, check_input=True): + + super()._fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + ) + return self + + def predict_proba(self, X, check_input=True): + raise NotImplementedError + + def predict_log_proba(self, X): + raise NotImplementedError + + def _more_tags(self): + allow_nan = self.splitter == "best" and self.criterion in { + "gini", + "log_loss", + "entropy", + } + return {"multilabel": True, "allow_nan": allow_nan} \ No newline at end of file From 7646fe53b672c6100b1cb74b4099a56907da8346 Mon Sep 17 00:00:00 2001 From: sherry30 <65318415+sherry30@users.noreply.github.com> Date: Tue, 29 Aug 2023 08:45:36 +0500 Subject: [PATCH 37/55] Updated `devices.rst` to also include handling of the `device` argument --- docs/overview/deep_dive/devices.rst | 30 ++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/docs/overview/deep_dive/devices.rst b/docs/overview/deep_dive/devices.rst index bb96a0af07170..091571a72ded3 100644 --- a/docs/overview/deep_dive/devices.rst +++ b/docs/overview/deep_dive/devices.rst @@ -158,7 +158,7 @@ doesn't care about this, it moves all the tensors to the same device before perf In Ivy, users can control the device on which the operation is to be executed using `ivy.set_soft_device_mode`_ flag. There are two cases for this, either the soft device mode is set to :code:`True` or :code:`False`. -1. When :code:`ivy.set_soft_device_mode(True)`: +**When ivy.set_soft_device_mode(True)**: a. All the input arrays are moved to :code:`ivy.default_device()` while performing an operation. If the array is already present in the default device, no device shifting is done. @@ -174,7 +174,14 @@ are moved to :code:`ivy.default_device()` while performing :code:`ivy.add` opera y = ivy.array([34], device="gpu:0") ivy.add(x, y) -2. When :code:`ivy.set_soft_device_mode(False)`: +The priority of device shifting is following in this mode: + +#. The ``device`` argument. +#. device the arrays are on. +#. :code:`default_device` + + +**When ivy.set_soft_device_mode(False)**: a. If any of the input arrays are on a different device, a device exception is raised. 
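A minimal sketch of the two modes described above (assuming a torch backend with both a CPU and a GPU available):

.. code-block:: python

    import ivy

    ivy.set_backend("torch")

    x = ivy.array([1, 2, 3], device="cpu")
    y = ivy.array([4, 5, 6], device="gpu:0")

    # soft device mode on: both inputs are first moved to ivy.default_device(),
    # so the mixed-device addition runs without an error
    ivy.set_soft_device_mode(True)
    ivy.add(x, y)

    # soft device mode off: the same call raises a device exception because
    # the inputs live on different devices
    ivy.set_soft_device_mode(False)
    ivy.add(x, y)  # raises
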
@@ -226,18 +233,16 @@ The code to handle all these cases are present inside `@handle_device_shifting`_
 all the functions that accept at least one array as input(except mixed and compositional functions) in `ivy.functional.ivy`_ submodule. The decorator calls :code:`ivy.handle_soft_device_variable` function under the hood to handle device shifting for each backend.
 
-**Soft Device Handling Function**
+The priority of device shifting in this mode is as follows:
 
-There is a backend specific implementation of :code:`ivy.handle_soft_device_variable` function for numpy and tensorflow. The reason being, for numpy there
-is no need for device shifting as it only support 'cpu' device, whereas, tensorflow automatically moves the inputs to 'gpu' if one is available and there is no way to turn this
-off globally.
+#. The ``device`` argument.
+#. :code:`default_device`
 
-The `numpy soft device handling function`_ just returns the inputs of the operation as it is without making any changes.
-Whereas the `tensorflow soft device handling function`_ move the input arrays to :code:`ivy.default_device()` using
-`tf.device`_ context manager.
+**Soft Device Handling Function**
+
+This is a function which plays a crucial role in the :code:`handle_device_shifting` decorator. The purpose of this function is to ensure that the function :code:`fn` passed to it is executed on the device passed in the :code:`device_shifting_dev` argument. If it is passed as :code:`None`, then the function will be executed on the default device.
 
-For the rest of the frameworks, the `ivy implementation`_ of soft device handling function is used, which loops through
-the inputs of the function and move the arrays to :code:`ivy.default_device()`, if not already on that device.
+Most of the backend implementations are very similar: first they move all the arrays to the desired device using :code:`ivy.nested_map`, and then execute the function inside the device handling context manager of that native framework. The purpose of executing the function inside the context manager is to handle functions that do not accept any arrays; in that case the only way to let the native framework know which device we want the function to be executed on is through the context manager. This approach is used in most backend implementations, the exceptions being tensorflow, where we don't have to move all the tensors to the desired device because just using its context manager is enough (it moves all the tensors itself internally), and numpy, since it only accepts `cpu` as device.
 
 **Forcing Operations on User Specified Device**
 
@@ -258,6 +263,9 @@ context manager. So from now on, all the operations will be executed on 'cpu' de
 On exiting the context manager(`__exit__`_ method), the default device and soft device mode is reset to the previous state using `ivy.unset_default_device()`_ and `ivy.unset_soft_device_mode()`_ respectively, to move back to the previous state.
 
+There are some functions (mostly creation functions) which accept a :code:`device` argument. This specifies the device on which the function is executed and the device of the returned array. :code:`handle_device_shifting` deals with this argument by first checking if it exists and then setting :code:`device_shifting_dev` to it, which is then passed to the :code:`handle_soft_device_variable` function depending on the :code:`soft_device` mode.
+
+
 **Round Up**
 
 This should have hopefully given you a good feel for devices, and how these are handled in Ivy.
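A short sketch of forcing operations onto a user-specified device, assuming :code:`ivy.DefaultDevice` is the context manager referred to above:

.. code-block:: python

    import ivy

    ivy.set_backend("torch")

    # __enter__ sets the default device to 'cpu' and turns soft device mode on,
    # so every operation inside the block executes on 'cpu'
    with ivy.DefaultDevice("cpu"):
        x = ivy.ones((3, 3))
        y = ivy.sum(x)

    # __exit__ restores the previous default device and soft device mode via
    # ivy.unset_default_device() and ivy.unset_soft_device_mode()
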
From 9be7dfba05eee56009a11c24e6f2be0a4db73a01 Mon Sep 17 00:00:00 2001 From: Sarvesh Kesharwani Date: Tue, 29 Aug 2023 09:44:16 +0530 Subject: [PATCH 38/55] fromfunction (#21432) --- ivy/functional/frontends/jax/numpy/logic.py | 28 +++++++++- .../test_jax/test_numpy/test_logic.py | 54 +++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index d93a852b28423..8435587202cfd 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -260,7 +260,9 @@ def right_shift(x1, x2, /): @to_ivy_arrays_and_back -@with_unsupported_dtypes({"0.4.14 and below": ("bfloat16", "bool")}, "jax") +@with_unsupported_dtypes( + {"0.4.14 and below": ("bfloat16", 'bool')}, "jax" +) def setxor1d(ar1, ar2, assume_unique=False): common_dtype = ivy.promote_types(ivy.dtype(ar1), ivy.dtype(ar2)) ar1 = ivy.asarray(ar1, dtype=common_dtype) @@ -287,3 +289,27 @@ def setxor1d(ar1, ar2, assume_unique=False): alltrue = all sometrue = any + +@to_ivy_arrays_and_back +@with_unsupported_dtypes({"0.4.14 and below": ("bfloat16",)}, "jax") +def fromfunction(function, shape, *, dtype=float, **kwargs): + def canonicalize_shape(shape, context="shape argument"): + if isinstance(shape, int): + return (shape,) + elif isinstance(shape, list): + return tuple(shape) + elif isinstance(shape, tuple): + return shape + else: + msg = "{} must be an int, list, or tuple, but got {}." + raise TypeError(msg.format(context, type(shape))) + + arr = ivy.zeros(shape, dtype=dtype) + shape = canonicalize_shape(shape) + # Iterate over the indices of the array + for indices in ivy.ndindex(shape): + f_indices = indices + ivy.set_nest_at_index( + arr, f_indices, ivy.asarray(function(*indices, **kwargs), dtype=dtype) + ) + return arr diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py index 81bdc4ef7f786..430c2adf2c988 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py @@ -1148,3 +1148,57 @@ def test_jax_numpy_packbits( bitorder=bitorder, backend_to_test=backend_fw, ) + + +@st.composite +def _func_and_shape_dtype_helper(draw): + # here assumption is that the input func will take the len(shape) no of parameters + def add_numbers(*args): + total = 0 + for num in args: + total += num + return total + + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ) + ) + + dtype = draw(helpers.get_dtypes("valid")) + + return add_numbers, shape, dtype[0] + + +# fromfunction +@handle_frontend_test( + fn_tree="jax.numpy.fromfunction", + input_dtype=helpers.get_dtypes("valid"), + function_and_shape_and_dtype=_func_and_shape_dtype_helper(), + test_with_out=st.just(False), +) +def test_jax_numpy_fromfunction( + input_dtype, + function_and_shape_and_dtype, + backend_fw, + frontend, + on_device, + fn_tree, + test_flags, +): + function, shape, dtype = function_and_shape_and_dtype + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + function=function, + shape=shape, + dtype=dtype, + ) From 8871d2a4ea8b4daafe0564398cb743f52cd3c4d3 Mon Sep 17 00:00:00 2001 From: sherry30 <65318415+sherry30@users.noreply.github.com> Date: 
Tue, 29 Aug 2023 09:25:39 +0500
Subject: [PATCH 39/55] Update `ivy_frontends_tests.rst` and added description
 of a few arguments in the `test_frontend_method`

---
 docs/overview/deep_dive/ivy_frontends_tests.rst | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/docs/overview/deep_dive/ivy_frontends_tests.rst b/docs/overview/deep_dive/ivy_frontends_tests.rst
index f602525617e0e..137889ba50ce0 100644
--- a/docs/overview/deep_dive/ivy_frontends_tests.rst
+++ b/docs/overview/deep_dive/ivy_frontends_tests.rst
@@ -629,7 +629,11 @@ for example, :code:`ndarray.__add__` would expect an array as input, despite the
 - :code:`init_tree` A full path to initialization function.
 - :code:`method_name` The name of the method to test.
 
-:func:`helpers.test_frontend_method` is used to test frontend instance methods. It is used in the same way as :func:`helpers.test_frontend_function`.
+:func:`helpers.test_frontend_method` is used to test frontend instance methods. It is used in the same way as :func:`helpers.test_frontend_function`. A few important arguments for this function are as follows:
+ - :code:`init_input_dtypes` Input dtypes of the arguments with which we are initializing the array.
+ - :code:`init_all_as_kwargs_np` The data to be passed when initializing; this will be a dictionary in which the numpy array containing the data is passed in the :code:`data` key.
+ - :code:`method_input_dtypes` The input dtypes of the arguments which are to be passed to the instance method after the initialization of the array.
+ - :code:`method_all_as_kwargs_np` All the arguments which are to be passed to the instance method.
 
 
 Frontend Instance Method Test Examples
@@ -822,4 +826,4 @@ If you have any questions, please feel free to reach out on `discord`_ in the `i
 
\ No newline at end of file
+
From 5e60112465c1056ed1d1bb06cf4ca532e1a83b39 Mon Sep 17 00:00:00 2001
From: Adithya Palle <35441394+4di03@users.noreply.github.com>
Date: Tue, 29 Aug 2023 01:17:39 -0400
Subject: [PATCH 40/55] (feat: frontends) added torch.chain_matmul to pytorch
 frontend (#21013)

Co-authored-by: Yusha Arif<101613943+YushaArif99@users.noreply.github.com>
---
 .../frontends/torch/blas_and_lapack_ops.py    |  5 ++
 .../test_torch/test_blas_and_lapack_ops.py    | 72 ++++++++++++++++++-
 2 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/ivy/functional/frontends/torch/blas_and_lapack_ops.py b/ivy/functional/frontends/torch/blas_and_lapack_ops.py
index 808c58d1176bc..9a13858480d3c 100644
--- a/ivy/functional/frontends/torch/blas_and_lapack_ops.py
+++ b/ivy/functional/frontends/torch/blas_and_lapack_ops.py
@@ -76,6 +76,11 @@ def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
     return ivy.add(beta_input, ret, out=out)
 
 
+@to_ivy_arrays_and_back
+def chain_matmul(*matrices, out=None):
+    return ivy.multi_dot(matrices, out=out)
+
+
 @to_ivy_arrays_and_back
 def bmm(input, mat2, *, out=None):
     if len(ivy.shape(input)) != 3 or len(ivy.shape(mat2)) != 3:
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py
index 12ad52d6aff57..701c2b85c36e5 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py
@@ -392,6 +392,76 @@ def test_torch_baddbmm(
     )
 
 
+@st.composite
+def _generate_chain_matmul_dtype_and_arrays(draw):
+    dtype = draw(helpers.get_dtypes("float", full=True))
+    input_dtype = [
+        
draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) + ] + matrices_dims = draw( + st.lists(st.integers(min_value=2, max_value=10), min_size=4, max_size=4) + ) + shape_1 = (matrices_dims[0], matrices_dims[1]) + shape_2 = (matrices_dims[1], matrices_dims[2]) + shape_3 = (matrices_dims[2], matrices_dims[3]) + + matrix_1 = draw( + helpers.dtype_and_values( + shape=shape_1, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) + ) + matrix_2 = draw( + helpers.dtype_and_values( + shape=shape_2, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) + ) + matrix_3 = draw( + helpers.dtype_and_values( + shape=shape_3, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) + ) + + return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]] + + +# chain_matmul +@handle_frontend_test( + fn_tree="torch.chain_matmul", + dtype_and_matrices=_generate_chain_matmul_dtype_and_arrays(), +) +def test_torch_chain_matmul( + *, + dtype_and_matrices, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, matrices = dtype_and_matrices + args = {f"x{i}": matrix for i, matrix in enumerate(matrices)} + test_flags.num_positional_args = len(matrices) + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-03, + **args, + ) + + # bmm @handle_frontend_test( fn_tree="torch.bmm", @@ -894,5 +964,5 @@ def test_torch_trapezoid( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **kwargs + **kwargs, ) From 47359a574231bd476f54fea63ee4fd8105d68991 Mon Sep 17 00:00:00 2001 From: Iggy Galang <80394307+Kalachuchis@users.noreply.github.com> Date: Tue, 29 Aug 2023 15:04:54 +0800 Subject: [PATCH 41/55] feat(paddle frontend) add paddle creation clone (#22695) Co-authored-by: iggy --- .../frontends/paddle/tensor/creation.py | 8 ++++++ .../test_paddle/test_tensor/test_creation.py | 26 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/creation.py b/ivy/functional/frontends/paddle/tensor/creation.py index 9ea67c4ce749f..7ed0e10847fac 100644 --- a/ivy/functional/frontends/paddle/tensor/creation.py +++ b/ivy/functional/frontends/paddle/tensor/creation.py @@ -29,6 +29,14 @@ def assign(x, output=None): return ret +@with_unsupported_dtypes( + {"2.5.1 and below": ("bfloat16", "uint16", "uint32", "uint64")}, "paddle" +) +@to_ivy_arrays_and_back +def clone(x): + return ivy.copy_array(x) + + @with_supported_dtypes( {"2.5.1 and below": ("float32", "float64")}, "paddle", diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py index bd9f251b06f43..208f3061b95c5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py @@ -760,3 +760,29 @@ def test_paddle_linspace( num=num, dtype=dtype[0], ) + + +# clone +@handle_frontend_test( + fn_tree="paddle.clone", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), +) +def test_paddle_clone( + *, + dtype_and_x, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) From 
39ac558d107ba2f2af99c7cca5025a0740e270b8 Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Tue, 29 Aug 2023 08:06:57 +0000 Subject: [PATCH 42/55] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/jax/numpy/logic.py | 53 +++++++++---------- .../frontends/paddle/tensor/math.py | 15 +++--- ivy/functional/frontends/sklearn/base.py | 2 +- .../frontends/sklearn/tree/_classes.py | 18 ++----- .../frontends/torch/blas_and_lapack_ops.py | 10 ++-- .../test_paddle/test_tensor/test_math.py | 4 +- .../test_misc/test_backend_handler.py | 4 +- 7 files changed, 48 insertions(+), 58 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index 8435587202cfd..24bf9716677f0 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -100,6 +100,31 @@ def equal(x1, x2, /): return ivy.equal(x1, x2) +@to_ivy_arrays_and_back +@with_unsupported_dtypes({"0.4.14 and below": ("bfloat16",)}, "jax") +def fromfunction(function, shape, *, dtype=float, **kwargs): + def canonicalize_shape(shape, context="shape argument"): + if isinstance(shape, int): + return (shape,) + elif isinstance(shape, list): + return tuple(shape) + elif isinstance(shape, tuple): + return shape + else: + msg = "{} must be an int, list, or tuple, but got {}." + raise TypeError(msg.format(context, type(shape))) + + arr = ivy.zeros(shape, dtype=dtype) + shape = canonicalize_shape(shape) + # Iterate over the indices of the array + for indices in ivy.ndindex(shape): + f_indices = indices + ivy.set_nest_at_index( + arr, f_indices, ivy.asarray(function(*indices, **kwargs), dtype=dtype) + ) + return arr + + @to_ivy_arrays_and_back def greater(x1, x2, /): x1, x2 = promote_jax_arrays(x1, x2) @@ -260,9 +285,7 @@ def right_shift(x1, x2, /): @to_ivy_arrays_and_back -@with_unsupported_dtypes( - {"0.4.14 and below": ("bfloat16", 'bool')}, "jax" -) +@with_unsupported_dtypes({"0.4.14 and below": ("bfloat16", "bool")}, "jax") def setxor1d(ar1, ar2, assume_unique=False): common_dtype = ivy.promote_types(ivy.dtype(ar1), ivy.dtype(ar2)) ar1 = ivy.asarray(ar1, dtype=common_dtype) @@ -289,27 +312,3 @@ def setxor1d(ar1, ar2, assume_unique=False): alltrue = all sometrue = any - -@to_ivy_arrays_and_back -@with_unsupported_dtypes({"0.4.14 and below": ("bfloat16",)}, "jax") -def fromfunction(function, shape, *, dtype=float, **kwargs): - def canonicalize_shape(shape, context="shape argument"): - if isinstance(shape, int): - return (shape,) - elif isinstance(shape, list): - return tuple(shape) - elif isinstance(shape, tuple): - return shape - else: - msg = "{} must be an int, list, or tuple, but got {}." 
- raise TypeError(msg.format(context, type(shape))) - - arr = ivy.zeros(shape, dtype=dtype) - shape = canonicalize_shape(shape) - # Iterate over the indices of the array - for indices in ivy.ndindex(shape): - f_indices = indices - ivy.set_nest_at_index( - arr, f_indices, ivy.asarray(function(*indices, **kwargs), dtype=dtype) - ) - return arr diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py index 4df4724f52019..00f91ca3103ea 100644 --- a/ivy/functional/frontends/paddle/tensor/math.py +++ b/ivy/functional/frontends/paddle/tensor/math.py @@ -57,6 +57,14 @@ def amax(x, axis=None, keepdims=False): return ivy.max(x, axis=axis, keepdims=keepdims) +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def amin(x, axis=None, keepdim=False, name=None): + return ivy.min(x, axis=axis, keepdims=keepdim) + + @with_supported_dtypes( {"2.5.1 and below": ("complex64", "complex128", "float32", "float64")}, "paddle", @@ -502,10 +510,3 @@ def tanh(x, name=None): @to_ivy_arrays_and_back def trunc(x, name=None): return ivy.trunc(x) - -@with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" -) -@to_ivy_arrays_and_back -def amin(x, axis=None, keepdim=False, name=None): - return ivy.min(x, axis=axis, keepdims=keepdim) diff --git a/ivy/functional/frontends/sklearn/base.py b/ivy/functional/frontends/sklearn/base.py index 1db8a5caf63e5..410c0fd4fe082 100644 --- a/ivy/functional/frontends/sklearn/base.py +++ b/ivy/functional/frontends/sklearn/base.py @@ -30,4 +30,4 @@ def predict(self, X): class MultiOutputMixin: def _more_tags(self): - return {"multioutput": True} \ No newline at end of file + return {"multioutput": True} diff --git a/ivy/functional/frontends/sklearn/tree/_classes.py b/ivy/functional/frontends/sklearn/tree/_classes.py index c8352f80b8250..d2e8dc8b17c30 100644 --- a/ivy/functional/frontends/sklearn/tree/_classes.py +++ b/ivy/functional/frontends/sklearn/tree/_classes.py @@ -5,8 +5,8 @@ MultiOutputMixin, ) -class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): +class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): @abstractmethod def __init__( self, @@ -39,20 +39,16 @@ def __init__( def get_depth(self): raise NotImplementedError - def get_n_leaves(self): raise NotImplementedError - def _support_missing_values(self, X): raise NotImplementedError - def _compute_missing_values_in_feature_mask(self, X): raise NotImplementedError - def _fit( self, X, @@ -63,37 +59,30 @@ def _fit( ): raise NotImplementedError - def _validate_X_predict(self, X, check_input): raise NotImplementedError - def predict(self, X, check_input=True): raise NotImplementedError - + def apply(self, X, check_input=True): raise NotImplementedError - def decision_path(self, X, check_input=True): raise NotImplementedError - def _prune_tree(self): raise NotImplementedError - def cost_complexity_pruning_path(self, X, y, sample_weight=None): raise NotImplementedError - @property def feature_importances_(self): raise NotImplementedError class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): - def __init__( self, *, @@ -126,7 +115,6 @@ def __init__( ) def fit(self, X, y, sample_weight=None, check_input=True): - super()._fit( X, y, @@ -147,4 +135,4 @@ def _more_tags(self): "log_loss", "entropy", } - return {"multilabel": True, "allow_nan": allow_nan} \ No newline at end of file + return {"multilabel": True, "allow_nan": 
allow_nan} diff --git a/ivy/functional/frontends/torch/blas_and_lapack_ops.py b/ivy/functional/frontends/torch/blas_and_lapack_ops.py index 9a13858480d3c..172b3e41272a7 100644 --- a/ivy/functional/frontends/torch/blas_and_lapack_ops.py +++ b/ivy/functional/frontends/torch/blas_and_lapack_ops.py @@ -76,11 +76,6 @@ def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None): return ivy.add(beta_input, ret, out=out) -@to_ivy_arrays_and_back -def chain_matmul(*matrices, out=None): - return ivy.multi_dot(matrices, out=out) - - @to_ivy_arrays_and_back def bmm(input, mat2, *, out=None): if len(ivy.shape(input)) != 3 or len(ivy.shape(mat2)) != 3: @@ -89,6 +84,11 @@ def bmm(input, mat2, *, out=None): return ivy.matmul(input, mat2, out=out) +@to_ivy_arrays_and_back +def chain_matmul(*matrices, out=None): + return ivy.multi_dot(matrices, out=out) + + @to_ivy_arrays_and_back def cholesky(input, upper=False, *, out=None): return ivy.cholesky(input, upper=upper, out=out) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py index 2bac85f674262..de856f6f19fcd 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py @@ -2167,6 +2167,6 @@ def test_paddle_amin( test_flags=test_flags, on_device=on_device, x=x[0], - axis = axis, - keepdim = keepdim, + axis=axis, + keepdim=keepdim, ) diff --git a/ivy_tests/test_ivy/test_misc/test_backend_handler.py b/ivy_tests/test_ivy/test_misc/test_backend_handler.py index 0185ad0857c22..1bb09f050e2ce 100644 --- a/ivy_tests/test_ivy/test_misc/test_backend_handler.py +++ b/ivy_tests/test_ivy/test_misc/test_backend_handler.py @@ -188,7 +188,9 @@ def test_choose_random_backend(excluded): backends = list(_backend_dict.keys()) -@pytest.mark.parametrize("middle_backend,end_backend", [(a, b) for a in backends for b in backends if a != b]) +@pytest.mark.parametrize( + "middle_backend,end_backend", [(a, b) for a in backends for b in backends if a != b] +) def test_dynamic_backend_all_combos(middle_backend, end_backend): # create an ivy array, container and native container a = ivy.array([1, 2, 3]) From c14c90debd9b80af86f8fc4f43b1dff08204eba9 Mon Sep 17 00:00:00 2001 From: Muhammad Kashif Date: Tue, 29 Aug 2023 14:44:14 +0500 Subject: [PATCH 43/55] Paddle Frontend tensor.min implementation (#22599) --- .../frontends/paddle/tensor/math.py | 8 +++++ .../test_paddle/test_tensor/test_math.py | 34 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py index 00f91ca3103ea..4a270dcfe71dd 100644 --- a/ivy/functional/frontends/paddle/tensor/math.py +++ b/ivy/functional/frontends/paddle/tensor/math.py @@ -322,6 +322,14 @@ def maximum(x, y, name=None): return ivy.maximum(x, y) +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) +@to_ivy_arrays_and_back +def min(x, axis=None, keepdim=False, name=None): + return ivy.min(x, axis=axis, keepdims=keepdim) + + @with_supported_dtypes( {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py index de856f6f19fcd..fe6e126e5759a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py @@ -2139,6 +2139,40 @@ def test_paddle_stanh( ) +# min +@handle_frontend_test( + fn_tree="paddle.tensor.math.min", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=False, + ), +) +def test_paddle_min( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + keepdim=False, + ) + + # amin @handle_frontend_test( fn_tree="paddle.tensor.math.amin", From 2b5e04fad12666b8d63e89477b55546ea335f542 Mon Sep 17 00:00:00 2001 From: Arsalan ali Date: Tue, 29 Aug 2023 14:53:57 +0500 Subject: [PATCH 44/55] refactored the rsqrt inplace function of the paddle frontend (#22656) Co-authored-by: @AnnaTz --- ivy/functional/frontends/paddle/tensor/math.py | 2 +- .../test_frontends/test_paddle/test_tensor/test_math.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py index 4a270dcfe71dd..dd4564b44302f 100644 --- a/ivy/functional/frontends/paddle/tensor/math.py +++ b/ivy/functional/frontends/paddle/tensor/math.py @@ -419,7 +419,7 @@ def rsqrt(x, name=None): @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def rsqrt_(x, name=None): - return ivy.inplace_update(x, ivy.reciprocal(ivy.inplace_update(x, ivy.sqrt(x)))) + return ivy.inplace_update(x, reciprocal(sqrt(x))) @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py index fe6e126e5759a..96344d22b0c1e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py @@ -1784,6 +1784,7 @@ def test_paddle_rsqrt( available_dtypes=helpers.get_dtypes("valid"), ), ) + def test_paddle_rsqrt_( *, dtype_and_x, From 7bd529e919f8f2d43f037f641a8dac3f27800792 Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:39:02 +0100 Subject: [PATCH 45/55] Refactored _parse_query of get_item/set_item to fix transpilation issues (#22691) --- ivy/functional/ivy/general.py | 277 ++++------------------------------ 1 file changed, 29 insertions(+), 248 deletions(-) diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 7cea3d0bef235..7752630d6cf97 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2790,17 +2790,11 @@ def get_item( query = ivy.nonzero(query, as_tuple=False) ret = ivy.gather_nd(x, query) else: - query, target_shape, vector_inds = _parse_query(query, x.shape) - if vector_inds is not None: - x = ivy.permute_dims( - x, - axes=[ - *vector_inds, - *[i for i in range(len(x.shape)) if i not in vector_inds], - ], - ) - ret = ivy.gather_nd(x, query) - ret = ivy.reshape(ret, target_shape) if target_shape != list(ret.shape) else ret + indices, target_shape = _parse_query(query, x.shape) + if indices is None: + return ivy.empty(target_shape, dtype=x.dtype) + ret = ivy.gather_nd(x, indices) + ret = ivy.reshape(ret, 
target_shape) return ret @@ -2867,27 +2861,21 @@ def set_item( """ if copy: x = ivy.copy_array(x) + if not ivy.is_array(val): + val = ivy.array(val) if 0 in x.shape or 0 in val.shape: return x - inv_perm = None if ivy.is_array(query) and ivy.is_bool_dtype(query): if not len(query.shape): query = ivy.tile(query, (x.shape[0],)) target_shape = ivy.get_item(x, query).shape - query = ivy.nonzero(query, as_tuple=False) + indices = ivy.nonzero(query, as_tuple=False) else: - query, target_shape, vector_inds = _parse_query(query, x.shape, scatter=True) - if vector_inds is not None: - perm = [ - *vector_inds, - *[i for i in range(len(x.shape)) if i not in vector_inds], - ] - x = ivy.permute_dims(x, axes=perm) - inv_perm = ivy.invert_permutation(perm).to_list() + indices, target_shape = _parse_query(query, x.shape) + if indices is None: + return x val = _broadcast_to(val, target_shape).astype(x.dtype) - ret = ivy.scatter_nd(query, val, reduction="replace", out=x) - if inv_perm is not None: - return ivy.permute_dims(x, axes=inv_perm) + ret = ivy.scatter_nd(indices, val, reduction="replace", out=x) return ret @@ -2901,240 +2889,33 @@ def set_item( } -def _parse_query(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - - # sequence and integer queries are dealt with as array queries - query = [ivy.array(q) if isinstance(q, (tuple, list, int)) else q for q in query] - - # check if non-slice queries are in consecutive positions - # if so, they have to be moved to the front - # https://numpy.org/neps/nep-0021-advanced-indexing.html#mixed-indexing - non_slice_q_idxs = [i for i, q in enumerate(query) if ivy.is_array(q)] - to_front = len(non_slice_q_idxs) > 1 and any(ivy.diff(non_slice_q_idxs) != 1) - - # extract newaxis queries - if not scatter: - new_axes = [i for i, q in enumerate(query) if q is None] - query = [q for q in query if q is not None] - query = [Ellipsis] if query == [] else query - - # parse ellipsis - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = _parse_ellipsis(query, len(x_shape)) - - # broadcast array queries - array_inds = [i for i, v in enumerate(query) if ivy.is_array(v)] - if array_inds: - array_queries = ivy.broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ivy.where(arr < 0, arr + x_shape[i], arr).astype(ivy.int64) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query[idx] = arr - - # convert slices to range arrays - query = [ - _parse_slice(q, x_shape[i]).astype(ivy.int64) if isinstance(q, slice) else q - for i, q in enumerate(query) - ] - - # fill in missing queries - if len(query) < len(x_shape): - query += [ivy.arange(0, s, 1).astype(ivy.int64) for s in x_shape[len(query) :]] - - # calculate target_shape, i.e. 
the shape the gathered/scattered values should have - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [list(query[i].shape) for i in range(len(query)) if i not in array_inds] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(query[i].shape) for i in range(0, array_inds[0])] - + [list(array_queries[0].shape)] - + [[] for _ in range(len(array_inds) - 1)] - + [list(query[i].shape) for i in range(array_inds[-1] + 1, len(query))] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - target_shape[: ellipsis_inds[0]] - + [target_shape[ellipsis_inds[0] : ellipsis_inds[1]]] - + target_shape[ellipsis_inds[1] :] - ) - if not scatter: - for ax in new_axes: - if len(array_inds) and to_front: - ax -= sum(1 for x in array_inds if x < ax) - 1 - target_shape = [*target_shape[:ax], 1, *target_shape[ax:]] - target_shape = _deep_flatten(target_shape) - - # calculate the indices mesh (indices in gather_nd/scatter_nd format) - query = [ivy.expand_dims(q) if not len(q.shape) else q for q in query] - if len(array_inds): - array_queries = [ - ( - arr.reshape((-1,)) - if len(arr.shape) > 1 - else ivy.expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = ivy.stack(array_queries, axis=1) - if len(array_inds) == len(query): # advanced indexing - indices = array_queries.reshape((*target_shape, len(x_shape))) - elif len(array_inds) == 0: # basic indexing - indices = ivy.stack(ivy.meshgrid(*query, indexing="ij"), axis=-1).reshape( - (*target_shape, len(x_shape)) - ) - else: # mixed indexing - if to_front: - post_array_queries = ( - ivy.stack( - ivy.meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ).reshape((-1, len(query) - len(array_inds))) - if len(array_inds) < len(query) - else ivy.empty((1, 0)) - ) - indices = ivy.array( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ).reshape((*target_shape, len(x_shape))) - else: - pre_array_queries = ( - ivy.stack( - ivy.meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ).reshape((-1, array_inds[0])) - if array_inds[0] > 0 - else ivy.empty((1, 0)) - ) - post_array_queries = ( - ivy.stack( - ivy.meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ).reshape((-1, len(query) - 1 - array_inds[-1])) - if array_inds[-1] < len(query) - 1 - else ivy.empty((1, 0)) - ) - indices = ivy.array( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ).reshape((*target_shape, len(x_shape))) - - return ( - indices.astype(ivy.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) +def _parse_query(query, x_shape): + query = (query,) if not isinstance(query, tuple) else query + query_ = tuple([q.to_numpy() if ivy.is_array(q) else q for q in query]) -def _parse_ellipsis(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) + # array containing all of x's flat indices + x_ = ivy.arange(0, 
_numel(x_shape)).reshape(x_shape) + # use numpy's __getitem__ to get the queried indices + x_idxs = ivy.array(x_.to_numpy()[query_]) + target_shape = x_idxs.shape -def _parse_slice(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < -s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = ivy.arange(start, stop, step).to_list() - q_i = [q for q in q_i if 0 <= q < s] - q_i = ( - ivy.array(q_i) - if len(q_i) or start == stop or idx.stop is not None - else ivy.arange(0, s, 1) - ) - return q_i + if 0 in x_idxs.shape or 0 in x_shape: + return None, target_shape + # convert the flat indices to multi-D indices + x_idxs = ivy.unravel_index(x_idxs, x_shape) -def _deep_flatten(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item + # stack the multi-D indices to bring them to gather_nd/scatter_nd format + x_idxs = ivy.stack(x_idxs, axis=-1).astype(ivy.int64) - return list(_flatten_gen(iterable)) + return x_idxs, target_shape def _numel(shape): - return math.prod(shape) if shape != () else 1 + shape = tuple(shape) + return ivy.prod(shape).to_scalar() if shape != () else 1 def _broadcast_to(input, target_shape): From 2aa51ff29f3389f13dde8d0c629b1bc8e55bb89c Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:17:34 +0530 Subject: [PATCH 46/55] refactor(tests): Added new pre-commit hook for ordering functions (#22600) --- .pre-commit-config.yaml | 2 +- .../numpy/fft/discrete_fourier_transform.py | 16 +- .../frontends/numpy/ma/MaskedArray.py | 9 +- .../frontends/tensorflow/raw_ops.py | 8 +- .../frontends/torch/comparison_ops.py | 4 +- ivy/functional/ivy/general.py | 2 - .../test_jax/test__src/test_tree_util.py | 54 +- .../test_frontends/test_jax/test_array.py | 1858 ++-- .../test_jax/test_func_wrapper.py | 8 + .../test_jax/test_general_functions.py | 142 +- .../test_lax/test_control_flow_operators.py | 132 +- .../test_jax/test_lax/test_linalg.py | 144 +- .../test_jax/test_lax/test_operators.py | 2904 +++--- .../test_nn/test_non_linear_activations.py | 458 +- .../test_jax/test_numpy/test_creation.py | 980 +- .../test_jax/test_numpy/test_dtype.py | 70 +- .../test_jax/test_numpy/test_indexing.py | 482 +- .../test_jax/test_numpy/test_linalg.py | 818 +- .../test_jax/test_numpy/test_logic.py | 856 +- .../test_jax/test_numpy/test_manipulations.py | 1854 ++-- .../test_numpy/test_mathematical_functions.py | 2342 ++--- .../test_numpy/test_searching_sorting.py | 350 +- .../test_jax/test_numpy/test_statistical.py | 1464 +-- .../test_frontends/test_jax/test_random.py | 706 +- .../test_function/test_mindspore_nn_func.py | 1024 +- .../test_frontends/test_numpy/helpers.py | 244 +- .../test_numpy/test_broadcast/test_methods.py | 56 +- .../test_building_matrices.py | 184 +- .../test_from_shape_or_value.py | 320 +- .../test_numerical_ranges.py | 98 +- 
.../test_data_type_routines/test_general.py | 48 +- .../test_discrete_fourier_transform.py | 149 +- .../test_numpy/test_func_wrapper.py | 152 +- .../test_generating_index_arrays.py | 100 +- .../test_indexing_like_operations.py | 82 +- .../test_inserting_data_into_arrays.py | 82 +- .../test_matrix_and_vector_products.py | 326 +- .../test_linalg/test_matrix_eigenvalues.py | 60 +- .../test_norms_and_other_numbers.py | 52 +- ...olving_equations_and_inverting_matrices.py | 218 +- .../test_logic/test_array_contents.py | 64 +- .../test_numpy/test_logic/test_comparison.py | 120 +- .../test_logic/test_logical_operations.py | 28 +- .../test_logic/test_truth_value_testing.py | 74 +- .../test_numpy/test_ma/test_MaskedArray.py | 126 +- .../test_adding_and_removing_elements.py | 90 +- .../test_basic_operations.py | 4 + .../test_changing_array_shape.py | 204 +- .../test_changing_number_of_dimensions.py | 164 +- .../test_joining_arrays.py | 58 +- .../test_padding_arrays.py | 52 +- .../test_rearranging_elements.py | 174 +- .../test_splitting_arrays.py | 72 +- .../test_tiling_arrays.py | 64 +- .../test_transpose_like_operations.py | 60 +- .../test_arithmetic_operations.py | 340 +- .../test_exponents_and_logarithms.py | 268 +- .../test_extrema_finding.py | 274 +- .../test_handling_complex_numbers.py | 62 +- .../test_hyperbolic_functions.py | 60 +- .../test_miscellaneous.py | 490 +- .../test_rounding.py | 130 +- .../test_sums_products_differences.py | 274 +- .../test_trigonometric_functions.py | 90 +- .../test_numpy/test_matrix/test_methods.py | 98 +- .../test_numpy/test_ndarray/test_ndarray.py | 2736 ++--- .../test_numpy/test_random/test_functions.py | 842 +- .../test_searching.py | 322 +- .../test_sorting.py | 78 +- .../test_averages_and_variances.py | 448 +- .../test_statistics/test_correlating.py | 180 +- .../test_statistics/test_order_statistics.py | 174 +- .../test_numpy/test_ufunc/test_methods.py | 28 +- .../test_onnx/test_elementwise.py | 202 +- .../test_frontends/test_onnx/test_tensor.py | 46 +- .../test_frontends/test_paddle/test_fft.py | 50 +- .../test_functional/test_activation.py | 470 +- .../test_nn/test_functional/test_common.py | 484 +- .../test_nn/test_functional/test_conv.py | 66 +- .../test_nn/test_functional/test_distance.py | 2 - .../test_nn/test_functional/test_extension.py | 2 - .../test_nn/test_functional/test_input.py | 2 - .../test_nn/test_functional/test_loss.py | 394 +- .../test_nn/test_functional/test_pooling.py | 238 +- .../test_nn/test_functional/test_vision.py | 118 +- .../test_frontends/test_paddle/test_signal.py | 2 - .../test_paddle/test_tensor/test_attribute.py | 36 +- .../test_paddle/test_tensor/test_creation.py | 1580 +-- .../test_paddle/test_tensor/test_einsum.py | 2 - .../test_paddle/test_tensor/test_linalg.py | 917 +- .../test_paddle/test_tensor/test_logic.py | 306 +- .../test_tensor/test_manipulation.py | 738 +- .../test_paddle/test_tensor/test_math.py | 1269 +-- .../test_paddle/test_tensor/test_random.py | 234 +- .../test_paddle/test_tensor/test_search.py | 100 +- .../test_paddle/test_tensor/test_stat.py | 100 +- .../test_paddle/test_tensor/test_tensor.py | 1614 +-- .../test_vision/test_transforms.py | 96 +- .../test_pandas/test_dataframe.py | 42 +- .../test_frontends/test_pandas/test_series.py | 74 +- .../test_scipy/test_fft/test_fft.py | 174 +- .../test_scipy/test_linalg/test_linalg.py | 230 +- .../test_compat/test_v1/test_nn.py | 106 +- .../test_tensorflow/test_func_wrapper.py | 76 +- .../test_tensorflow/test_general_functions.py | 2563 +++-- 
.../test_image/test_cropping.py | 8 + .../test_keras/test_activations.py | 1046 +- .../test_keras/test_metrics.py | 508 +- .../test_tensorflow/test_linalg.py | 1250 +-- .../test_tensorflow/test_math.py | 1982 ++-- .../test_frontends/test_tensorflow/test_nn.py | 1310 +-- .../test_tensorflow/test_random.py | 300 +- .../test_tensorflow/test_raw_ops.py | 3534 +++---- .../test_tensorflow/test_signal.py | 140 +- .../test_tensorflow/test_tensor.py | 810 +- .../test_torch/test_blas_and_lapack_ops.py | 378 +- .../test_torch/test_comparison_ops.py | 392 +- .../test_torch/test_creation_ops.py | 706 +- .../test_torch/test_func_wrapper.py | 64 +- ...t_indexing_slicing_joining_mutating_ops.py | 3200 +++--- .../test_frontends/test_torch/test_linalg.py | 1178 +-- .../test_torch/test_miscellaneous_ops.py | 3572 +++---- .../test_convolution_functions.py | 234 +- .../test_functional/test_dropout_functions.py | 8 +- .../test_functional/test_linear_functions.py | 8 + .../test_functional/test_loss_functions.py | 656 +- .../test_non_linear_activation_functions.py | 1534 +-- .../test_nn/test_functional/test_norms.py | 264 +- .../test_functional/test_pooling_functions.py | 382 +- .../test_functional/test_sparse_functions.py | 40 +- .../test_functional/test_vision_functions.py | 288 +- .../test_torch/test_pointwise_ops.py | 1684 ++-- .../test_torch/test_random_sampling.py | 336 +- .../test_torch/test_reduction_ops.py | 2072 ++-- .../test_frontends/test_torch/test_tensor.py | 8894 +++++++++-------- .../test_torch/test_tensor_functions.py | 58 +- .../test_torch/test_utilities.py | 58 +- .../test_core/test_creation.py | 818 +- .../test_functional/test_core/test_device.py | 841 +- .../test_functional/test_core/test_dtype.py | 1234 +-- .../test_core/test_elementwise.py | 1000 +- .../test_functional/test_core/test_general.py | 2643 ++--- .../test_core/test_gradients.py | 494 +- .../test_functional/test_core/test_linalg.py | 1160 +-- .../test_core/test_manipulation.py | 735 +- .../test_functional/test_core/test_meta.py | 378 +- .../test_functional/test_core/test_nest.py | 567 +- .../test_functional/test_core/test_random.py | 294 +- .../test_core/test_searching.py | 76 +- .../test_functional/test_core/test_set.py | 48 +- .../test_functional/test_core/test_sorting.py | 156 +- .../test_core/test_statistical.py | 318 +- .../test_core/test_creation.py | 654 +- .../test_core/test_elementwise.py | 868 +- .../test_core/test_general.py | 8 + .../test_core/test_linalg.py | 1870 ++-- .../test_core/test_manipulation.py | 1688 ++-- .../test_core/test_random.py | 162 +- .../test_core/test_searching.py | 4 + .../test_core/test_sorting.py | 8 + .../test_core/test_sparse_array.py | 229 +- .../test_core/test_statistical.py | 516 +- .../test_nn/test_activations.py | 96 +- .../test_experimental/test_nn/test_layers.py | 1708 ++-- .../test_experimental/test_nn/test_losses.py | 152 +- .../test_experimental/test_nn/test_norms.py | 192 +- .../test_nn/test_activations.py | 172 +- .../test_functional/test_nn/test_layers.py | 1282 +-- .../test_functional/test_nn/test_losses.py | 100 +- .../test_functional/test_nn/test_norms.py | 8 + ivy_tests/test_ivy/test_misc/test_array.py | 1730 ++-- .../test_ivy/test_misc/test_assertions.py | 608 +- .../test_misc/test_backend_handler.py | 262 +- .../test_ivy/test_misc/test_container.py | 5177 +++++----- .../test_ivy/test_misc/test_cp_tensor.py | 281 +- .../test_ivy/test_misc/test_exceptions.py | 38 +- .../test_ivy/test_misc/test_func_wrapper.py | 210 +- .../test_ivy/test_misc/test_inspection.py | 8 + 
.../test_ivy/test_misc/test_ivy_demos.py | 38 +- ivy_tests/test_ivy/test_misc/test_logging.py | 10 +- ivy_tests/test_ivy/test_misc/test_pickling.py | 38 +- ivy_tests/test_ivy/test_misc/test_shape.py | 392 +- .../test_ivy/test_misc/test_tucker_tensor.py | 174 +- .../test_ivy/test_misc/test_with_backend.py | 22 +- .../test_stateful/test_activations.py | 246 +- .../test_ivy/test_stateful/test_converters.py | 48 +- .../test_stateful/test_initializers.py | 220 +- .../test_ivy/test_stateful/test_layers.py | 1475 +-- .../test_ivy/test_stateful/test_losses.py | 246 +- .../test_ivy/test_stateful/test_modules.py | 1470 +-- .../test_ivy/test_stateful/test_norms.py | 94 +- .../test_ivy/test_stateful/test_optimizers.py | 123 +- .../test_ivy/test_stateful/test_sequential.py | 38 +- 193 files changed, 55220 insertions(+), 54540 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b922480bf477..a395e41b926f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,6 @@ repos: # Exclude everything in frontends except __init__.py, and func_wrapper.py exclude: 'ivy/functional/(frontends|backends)/(?!.*/func_wrapper\.py$).*(?!__init__\.py$)' - repo: https://github.com/unifyai/lint-hook - rev: b9a103a9f7991fec0ed636a2bcd4497691761e78 + rev: a90659d806c6d65f20ec41095a2da8e8920cc96f hooks: - id: ivy-lint diff --git a/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py b/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py index e9a777baeec07..9d0b890841772 100644 --- a/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py +++ b/ivy/functional/frontends/numpy/fft/discrete_fourier_transform.py @@ -3,14 +3,6 @@ from ivy.func_wrapper import with_unsupported_dtypes -_SWAP_DIRECTION_MAP = { - None: "forward", - "backward": "forward", - "ortho": "ortho", - "forward": "backward", -} - - # --- Helpers --- # # --------------- # @@ -148,3 +140,11 @@ def rfftfreq(n, d=1.0): def rfftn(a, s=None, axes=None, norm=None): a = ivy.asarray(a, dtype=ivy.complex128) return ivy.rfftn(a, s=s, axes=axes, norm=norm) + + +_SWAP_DIRECTION_MAP = { + None: "forward", + "backward": "forward", + "ortho": "ortho", + "forward": "backward", +} diff --git a/ivy/functional/frontends/numpy/ma/MaskedArray.py b/ivy/functional/frontends/numpy/ma/MaskedArray.py index d59c1633f6604..34ad474a5b58a 100644 --- a/ivy/functional/frontends/numpy/ma/MaskedArray.py +++ b/ivy/functional/frontends/numpy/ma/MaskedArray.py @@ -2,7 +2,6 @@ import ivy.functional.frontends.numpy as np_frontend import numpy as np -masked = True masked_print_options = "--" nomask = False @@ -194,10 +193,12 @@ def _array_in_str(self): def _is_masked_array(x): return isinstance(x, (np.ma.MaskedArray, np_frontend.ma.MaskedArray)) - # Instance Methods # - # ---------------- # - # TODO +masked = True +# Instance Methods # +# ---------------- # + +# TODO # masked_array (alias) diff --git a/ivy/functional/frontends/tensorflow/raw_ops.py b/ivy/functional/frontends/tensorflow/raw_ops.py index 5c46ccc7fa1e5..ce38aa3b4f19f 100644 --- a/ivy/functional/frontends/tensorflow/raw_ops.py +++ b/ivy/functional/frontends/tensorflow/raw_ops.py @@ -14,7 +14,6 @@ Acos = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.acos)) Acosh = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.acosh)) -Add = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.add)) AddN = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.add_n)) AddV2 = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.add)) 
ArgMax = to_ivy_arrays_and_back( @@ -220,7 +219,6 @@ ) Sin = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.sin)) Size = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.general_functions.size)) -Slice = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.slice)) Softmax = to_ivy_arrays_and_back( with_unsupported_dtypes( { @@ -253,7 +251,6 @@ Squeeze = to_ivy_arrays_and_back( map_raw_ops_alias(tf_frontend.general_functions.squeeze) ) -Sub = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.subtract)) Tan = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.tan)) Tanh = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.tanh)) Xlogy = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.xlogy)) @@ -831,3 +828,8 @@ def Xlog1py(*, x, y, name="Xlog1py"): @to_ivy_arrays_and_back def ZerosLike(*, x, name="ZerosLike"): return ivy.zeros_like(x) + + +Add = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.add)) +Slice = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.slice)) +Sub = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.subtract)) diff --git a/ivy/functional/frontends/torch/comparison_ops.py b/ivy/functional/frontends/torch/comparison_ops.py index 5eb04bad6d727..eeaadadfbf864 100644 --- a/ivy/functional/frontends/torch/comparison_ops.py +++ b/ivy/functional/frontends/torch/comparison_ops.py @@ -289,8 +289,8 @@ def topk(input, k, dim=None, largest=True, sorted=True, *, out=None): return ivy.top_k(input, k, axis=dim, largest=largest, sorted=sorted, out=out) -ge = greater_equal gt = greater +ne = not_equal +ge = greater_equal le = less_equal lt = less -ne = not_equal diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 7752630d6cf97..b2413ad0de071 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -3,7 +3,6 @@ # global import gc import inspect -import itertools import math from functools import wraps from numbers import Number @@ -2890,7 +2889,6 @@ def set_item( def _parse_query(query, x_shape): - query = (query,) if not isinstance(query, tuple) else query query_ = tuple([q.to_numpy() if ivy.is_array(q) else q for q in query]) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py b/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py index a72c24c973ca1..171cf37f2d825 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test__src/test_tree_util.py @@ -6,29 +6,8 @@ import hypothesis.strategies as st -def leaf_strategy(): - return st.lists(st.integers(1, 10)).map(ivy.array) - - -def tree_strategy(max_depth=2): - if max_depth == 0: - return leaf_strategy() - else: - return st.dictionaries( - keys=st.one_of( - *[ - st.text( - alphabet=st.characters(min_codepoint=97, max_codepoint=122), - min_size=1, - max_size=1, - ).filter(lambda x: x not in used_keys) - for used_keys in [set()] - ] - ), - values=st.one_of(leaf_strategy(), tree_strategy(max_depth - 1)), - min_size=1, - max_size=10, - ) +# --- Helpers --- # +# --------------- # @st.composite @@ -36,6 +15,14 @@ def _tree_dict_strategy(draw): return draw(tree_strategy()) +# --- Main --- # +# ------------ # + + +def leaf_strategy(): + return st.lists(st.integers(1, 10)).map(ivy.array) + + # tree_leaves @handle_frontend_test( fn_tree="jax._src.tree_util.tree_leaves", @@ -93,3 +80,24 @@ def square(x): assert ivy.equal(ivy.Container(result), expected) ivy.previous_backend() + + +def 
tree_strategy(max_depth=2): + if max_depth == 0: + return leaf_strategy() + else: + return st.dictionaries( + keys=st.one_of( + *[ + st.text( + alphabet=st.characters(min_codepoint=97, max_codepoint=122), + min_size=1, + max_size=1, + ).filter(lambda x: x not in used_keys) + for used_keys in [set()] + ] + ), + values=st.one_of(leaf_strategy(), tree_strategy(max_depth - 1)), + min_size=1, + max_size=10, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py index c9477a03b4f12..474165d43ddd2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py @@ -15,84 +15,141 @@ CLASS_TREE = "ivy.functional.frontends.jax.numpy.ndarray" -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ), -) -def test_jax_ivy_array( - dtype_x, - backend_fw, -): - _, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - jax_frontend = ivy_backend.utils.dynamic_import.import_module( - "ivy.functional.frontends.jax" +# --- Helpers --- # +# --------------- # + + +@st.composite +def _at_helper(draw): + _, data, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + num_arrays=2, + shared_dtype=True, + min_num_dims=1, + ret_shape=True, ) - x = jax_frontend.Array(data[0]) - ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=backend_fw) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend=backend_fw, - ground_truth_backend="jax", + ) + axis = draw(helpers.get_axis(shape=shape, force_tuple=True)) + index = () + for a in axis: + index = index + (draw(st.integers(min_value=0, max_value=shape[a] - 1)),) + return data, index + + +@st.composite +def _get_dtype_input_and_vectors(draw): + dim_size = draw(helpers.ints(min_value=2, max_value=5)) + dtype = draw(helpers.get_dtypes("numeric", index=1, full=False)) + vec1 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 + ) + ) + vec2 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 ) + ) + return dtype, [vec1, vec2] -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ), -) -def test_jax_array_dtype( - dtype_x, - backend_fw, -): - dtype, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - jax_frontend = ivy_backend.utils.dynamic_import.import_module( - "ivy.functional.frontends.jax" +@st.composite +def _get_dtype_x_and_int(draw, *, dtype="numeric"): + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes(dtype), + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ) - x = jax_frontend.Array(data[0]) - assert x.dtype == dtype[0] + ) + pow_dtype, x_int = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=0, + max_value=10, + max_num_dims=0, + max_dim_size=1, + small_abs_safety_factor=2, + large_abs_safety_factor=2, + safety_factor_scale="log", + ) + ) + x_dtype = x_dtype + pow_dtype + return x_dtype, x, x_int -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ), -) -def test_jax_array_ndim( - dtype_x, - backend_fw, 
-): - dtype, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - jax_frontend = ivy_backend.utils.dynamic_import.import_module( - "ivy.functional.frontends.jax" +# shifting helper +@st.composite +def _get_dtype_x_and_int_shift(draw, dtype): + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes(dtype), + num_arrays=2, + shared_dtype=True, ) - x = jax_frontend.Array(data[0]) - assert x.ndim == data[0].ndim + ) + x_dtype = x_dtype + x[1] = np.asarray(np.clip(x[0], 0, np.iinfo(x_dtype[0]).bits - 1), dtype=x_dtype[0]) + return x_dtype, x[0], x[1] -@given( - dtype_x_shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ), -) -def test_jax_array_shape( - dtype_x_shape, - backend_fw, -): - _, data, shape = dtype_x_shape - with BackendHandler.update_backend(backend_fw) as ivy_backend: - jax_frontend = ivy_backend.utils.dynamic_import.import_module( - "ivy.functional.frontends.jax" +# repeat +@st.composite +def _repeat_helper(draw): + shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="value_shape")) + axis = draw( + st.shared( + st.one_of(st.none(), helpers.get_axis(shape=shape, max_size=1)), key="axis" ) - x = jax_frontend.Array(data[0]) - assert x.shape == shape + ) + + if not isinstance(axis, int) and axis is not None: + axis = axis[0] + + repeat_shape = ( + (draw(st.one_of(st.just(1), st.just(shape[axis]))),) + if axis is not None + else (1,) + ) + repeat = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + shape=repeat_shape, + min_value=0, + max_value=10, + ) + ) + return repeat + + +# searchsorted +@st.composite +def _searchsorted(draw): + dtype_x, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + shape=(draw(st.integers(min_value=1, max_value=10)),), + ), + ) + dtype_v, v = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + min_num_dims=1, + ) + ) + + input_dtypes = dtype_x + dtype_v + xs = x + v + side = draw(st.sampled_from(["left", "right"])) + sorter = None + xs[0] = np.sort(xs[0], axis=-1) + return input_dtypes, xs, side, sorter @st.composite @@ -112,80 +169,39 @@ def _transpose_helper(draw): return x, xT -@given(x_transpose=_transpose_helper()) -def test_jax_array_property_T(x_transpose, backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - x, xT = x_transpose - jax_frontend = ivy_backend.utils.dynamic_import.import_module( - "ivy.functional.frontends.jax" - ) - x = jax_frontend.Array(x) - assert np.array_equal(x.T, xT) - - -@st.composite -def _at_helper(draw): - _, data, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - ret_shape=True, - ) - ) - axis = draw(helpers.get_axis(shape=shape, force_tuple=True)) - index = () - for a in axis: - index = index + (draw(st.integers(min_value=0, max_value=shape[a] - 1)),) - return data, index - - -@given( - x_y_index=_at_helper(), -) -def test_jax_array_at(x_y_index, backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - jax_frontend = ivy_backend.utils.dynamic_import.import_module( - "ivy.functional.frontends.jax" - ) - xy, idx = x_y_index - x = jax_frontend.Array(xy[0]) - y = jax_frontend.Array(xy[1]) - idx = idx[0] - x_set = x.at[idx].set(y[idx]) - assert 
x_set[idx] == y[idx] - assert x.at[idx].get() == x[idx] +# --- Main --- # +# ------------ # @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="copy", + method_name="__add__", dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_array_copy( +def test_jax___add__( dtype_x, - on_device, frontend, frontend_method_data, - backend_fw, init_flags, method_flags, + backend_fw, + on_device, ): input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, - backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -196,30 +212,32 @@ def test_jax_array_copy( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="diagonal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, + method_name="__div__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_array_diagonal( - dtype_and_x, - on_device, +def test_jax___div__( + dtype_x, frontend, - backend_fw, frontend_method_data, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( - backend_to_test=backend_fw, init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -228,29 +246,57 @@ def test_jax_array_diagonal( ) +# __getitem__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="all", - dtype_x_axis=helpers.dtype_values_axis( + method_name="__getitem__", + dtype_x_index=helpers.dtype_array_query( available_dtypes=helpers.get_dtypes("valid"), - force_int_axis=True, - valid_axis=True, - min_num_dims=1, - ), - keepdims=st.booleans(), + ).filter(lambda x: not (isinstance(x[-1], np.ndarray) and x[-1].dtype == np.bool_)), ) -def test_jax_array_all( - dtype_x_axis, - keepdims, +def test_jax___getitem__( + dtype_x_index, + frontend, + frontend_method_data, + init_flags, + method_flags, + backend_fw, on_device, +): + input_dtype, x, index = dtype_x_index + helpers.test_frontend_method( + init_input_dtypes=[input_dtype[0]], + backend_to_test=backend_fw, + init_all_as_kwargs_np={"object": x}, + method_input_dtypes=[*input_dtype[1:]], + method_all_as_kwargs_np={"idx": index}, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + ) + + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="jax.numpy.array", + method_name="__invert__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + ), +) +def test_jax___invert__( + dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = 
dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -258,10 +304,7 @@ def test_jax_array_all( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -273,30 +316,27 @@ def test_jax_array_all( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="astype", - dtype_and_x=_get_castable_dtype(), + method_name="__lshift__", + dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), ) -def test_jax_array_astype( - dtype_and_x, - on_device, +def test_jax___lshift__( + dtype_x_shift, frontend, - backend_fw, frontend_method_data, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x, _, castable_dtype = dtype_and_x - + input_dtype, x, shift = dtype_x_shift helpers.test_frontend_method( + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_input_dtypes=[input_dtype], init_all_as_kwargs_np={ - "object": x[0], - }, - method_input_dtypes=[input_dtype], - method_all_as_kwargs_np={ - "dtype": castable_dtype, + "object": x, }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"other": shift}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -308,26 +348,19 @@ def test_jax_array_astype( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="argmax", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - force_int_axis=True, - min_num_dims=1, - valid_axis=True, - ), - keepdims=st.booleans(), + method_name="__matmul__", + dtype_x=_get_dtype_input_and_vectors(), ) -def test_jax_array_argmax( - dtype_and_x, - keepdims, - on_device, +def test_jax___matmul__( + dtype_x, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -335,10 +368,7 @@ def test_jax_array_argmax( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -350,26 +380,24 @@ def test_jax_array_argmax( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="argmin", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - force_int_axis=True, - min_num_dims=1, - valid_axis=True, + method_name="__mod__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), - keepdims=st.booleans(), ) -def test_jax_array_argmin( - dtype_and_x, - keepdims, - on_device, +def test_jax___mod__( + dtype_x, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -377,10 +405,7 @@ def test_jax_array_argmin( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, - }, + 
method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -392,21 +417,23 @@ def test_jax_array_argmin( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="conj", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("real_and_complex"), + method_name="__mul__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_array_conj( - dtype_and_x, - on_device, +def test_jax___mul__( + dtype_x, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -414,7 +441,7 @@ def test_jax_array_conj( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -426,31 +453,33 @@ def test_jax_array_conj( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="conjugate", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("real_and_complex"), + method_name="__radd__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_array_conjugate( - dtype_and_x, - on_device, +def test_jax___radd__( + dtype_x, frontend, frontend_method_data, - backend_fw, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, - backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, on_device=on_device, @@ -460,29 +489,24 @@ def test_jax_array_conjugate( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="mean", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - force_int_axis=True, - min_num_dims=1, - valid_axis=True, + method_name="__rdiv__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), - keepdims=st.booleans(), ) -def test_jax_array_mean( - dtype_and_x, - keepdims, - on_device, +def test_jax___rdiv__( + dtype_x, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -490,56 +514,39 @@ def test_jax_array_mean( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, - rtol_=1e-3, - atol_=1e-3, ) @handle_frontend_method( class_tree=CLASS_TREE, 
init_tree="jax.numpy.array", - method_name="cumprod", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_value=-100, - max_value=100, - valid_axis=True, - allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - ), + method_name="__rlshift__", + dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), ) -def test_jax_array_cumprod( - dtype_and_x, - on_device, +def test_jax___rlshift__( + dtype_x_shift, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x, shift = dtype_x_shift helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": shift, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - }, + method_all_as_kwargs_np={"other": x}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -551,31 +558,28 @@ def test_jax_array_cumprod( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="cumsum", - dtype_and_x=_get_castable_dtype(), + method_name="__rmatmul__", + dtype_x=_get_dtype_input_and_vectors(), ) -def test_jax_array_cumsum( - dtype_and_x, - on_device, +def test_jax___rmatmul__( + dtype_x, frontend, frontend_method_data, - backend_fw, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x, axis, dtype = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( - init_input_dtypes=[input_dtype], + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=[input_dtype], - method_all_as_kwargs_np={ - "axis": axis, - "dtype": dtype, - }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, - backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -586,22 +590,24 @@ def test_jax_array_cumsum( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="nonzero", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, + method_name="__rmod__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_array_nonzero( - dtype_and_x, - on_device, +def test_jax___rmod__( + dtype_x, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -609,7 +615,7 @@ def test_jax_array_nonzero( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -621,26 +627,23 @@ def test_jax_array_nonzero( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="prod", - dtype_x_axis=helpers.dtype_values_axis( + method_name="__rmul__", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - valid_axis=True, - min_dim_size=2, - max_dim_size=10, - min_num_dims=2, + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_prod( 
- dtype_x_axis, - on_device, +def test_jax___rmul__( + dtype_x, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -648,55 +651,39 @@ def test_jax_prod( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, - atol_=1e-04, ) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="ravel", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - shape=helpers.get_shape( - min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ), - ), - order=st.sampled_from(["C", "F"]), + method_name="__rrshift__", + dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), ) -def test_jax_array_ravel( - dtype_and_x, - order, - on_device, +def test_jax___rrshift__( + dtype_x_shift, frontend, frontend_method_data, init_flags, method_flags, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x, shift = dtype_x_shift helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": shift, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "order": order, - }, + method_all_as_kwargs_np={"other": x}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -708,37 +695,27 @@ def test_jax_array_ravel( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="sort", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=["int64"], - force_int_axis=True, - min_axis=-1, - max_axis=-1, - min_dim_size=2, - max_dim_size=100, - min_num_dims=2, - ), + method_name="__rshift__", + dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), ) -def test_jax_array_sort( - dtype_x_axis, - on_device, +def test_jax___rshift__( + dtype_x_shift, frontend, frontend_method_data, - backend_fw, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x, shift = dtype_x_shift helpers.test_frontend_method( - backend_to_test=backend_fw, init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": x, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - }, + method_all_as_kwargs_np={"other": shift}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -750,79 +727,69 @@ def test_jax_array_sort( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="sum", - dtype_and_x=helpers.dtype_values_axis( + method_name="__rsub__", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - valid_axis=True, - force_int_axis=True, + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_sum( - dtype_and_x, - on_device, +def test_jax___rsub__( + dtype_x, frontend, frontend_method_data, - backend_fw, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x, axis = dtype_and_x + 
input_dtype, x = dtype_x helpers.test_frontend_method( - backend_to_test=backend_fw, init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, - atol_=1e-04, ) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="argsort", - dtype_x_axis=helpers.dtype_values_axis( + method_name="__rtruediv__", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + shared_dtype=True, + num_arrays=2, ), ) -def test_jax_array_argsort( - dtype_x_axis, - on_device, +def test_jax___rtruediv__( + dtype_x, frontend, - backend_fw, frontend_method_data, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, - backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -833,38 +800,32 @@ def test_jax_array_argsort( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="any", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - force_int_axis=True, - valid_axis=True, - min_num_dims=1, + method_name="__sub__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, ), - keepdims=st.booleans(), ) -def test_jax_array_any( - dtype_x_axis, - keepdims, - on_device, +def test_jax___sub__( + dtype_x, frontend, - backend_fw, frontend_method_data, init_flags, method_flags, + backend_fw, + on_device, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, - backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -875,13 +836,18 @@ def test_jax_array_any( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__pos__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="__truediv__", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shared_dtype=True, + num_arrays=2, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ), ) -def test_jax__pos_( - dtype_and_x, +def test_jax___truediv__( + dtype_x, frontend, frontend_method_data, init_flags, @@ -889,32 +855,33 @@ def test_jax__pos_( backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, 
backend_to_test=backend_fw, - on_device=on_device, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + on_device=on_device, ) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__neg__", + method_name="__abs__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_jax__neg_( +def test_jax__abs_( dtype_and_x, frontend, frontend_method_data, @@ -943,13 +910,14 @@ def test_jax__neg_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__eq__", + method_name="__and__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, + shared_dtype=True, ), ) -def test_jax__eq_( +def test_jax__and_( dtype_and_x, frontend, frontend_method_data, @@ -959,7 +927,6 @@ def test_jax__eq_( on_device, ): input_dtype, x = dtype_and_x - assume("bfloat16" not in input_dtype) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -967,9 +934,7 @@ def test_jax__eq_( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -981,13 +946,13 @@ def test_jax__eq_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__ne__", + method_name="__eq__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, ), ) -def test_jax__ne_( +def test_jax__eq_( dtype_and_x, frontend, frontend_method_data, @@ -1019,14 +984,13 @@ def test_jax__ne_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__lt__", + method_name="__ge__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_integer"), num_arrays=2, - shared_dtype=True, ), ) -def test_jax__lt_( +def test_jax__ge_( dtype_and_x, frontend, frontend_method_data, @@ -1036,6 +1000,7 @@ def test_jax__lt_( on_device, ): input_dtype, x = dtype_and_x + assume("bfloat16" not in input_dtype) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1057,13 +1022,13 @@ def test_jax__lt_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__le__", + method_name="__gt__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_integer"), num_arrays=2, ), ) -def test_jax__le_( +def test_jax__gt_( dtype_and_x, frontend, frontend_method_data, @@ -1095,13 +1060,13 @@ def test_jax__le_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__gt__", + method_name="__le__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_integer"), num_arrays=2, ), ) -def test_jax__gt_( +def test_jax__le_( dtype_and_x, frontend, frontend_method_data, @@ -1133,13 +1098,14 @@ def test_jax__gt_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__ge__", + method_name="__lt__", dtype_and_x=helpers.dtype_and_values( 
available_dtypes=helpers.get_dtypes("float_and_integer"), num_arrays=2, + shared_dtype=True, ), ) -def test_jax__ge_( +def test_jax__lt_( dtype_and_x, frontend, frontend_method_data, @@ -1149,7 +1115,6 @@ def test_jax__ge_( on_device, ): input_dtype, x = dtype_and_x - assume("bfloat16" not in input_dtype) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1171,12 +1136,13 @@ def test_jax__ge_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__abs__", + method_name="__ne__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_jax__abs_( +def test_jax__ne_( dtype_and_x, frontend, frontend_method_data, @@ -1186,6 +1152,7 @@ def test_jax__abs_( on_device, ): input_dtype, x = dtype_and_x + assume("bfloat16" not in input_dtype) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1193,7 +1160,9 @@ def test_jax__abs_( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1202,44 +1171,16 @@ def test_jax__abs_( ) -@st.composite -def _get_dtype_x_and_int(draw, *, dtype="numeric"): - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes(dtype), - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - ) - ) - pow_dtype, x_int = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=0, - max_value=10, - max_num_dims=0, - max_dim_size=1, - small_abs_safety_factor=2, - large_abs_safety_factor=2, - safety_factor_scale="log", - ) - ) - x_dtype = x_dtype + pow_dtype - return x_dtype, x, x_int - - @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__pow__", - dtype_x_pow=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + method_name="__neg__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("signed_integer"), ), ) -def test_jax__pow_( - dtype_x_pow, +def test_jax__neg_( + dtype_and_x, frontend, frontend_method_data, init_flags, @@ -1247,7 +1188,7 @@ def test_jax__pow_( backend_fw, on_device, ): - input_dtype, x = dtype_x_pow + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1255,9 +1196,7 @@ def test_jax__pow_( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1269,11 +1208,15 @@ def test_jax__pow_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rpow__", - dtype_x_pow=_get_dtype_x_and_int(), + method_name="__or__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + num_arrays=2, + shared_dtype=True, + ), ) -def test_jax__rpow_( - dtype_x_pow, +def test_jax__or_( + dtype_and_x, frontend, frontend_method_data, init_flags, @@ -1281,17 +1224,15 @@ def test_jax__rpow_( backend_fw, on_device, ): - input_dtype, x, pow = dtype_x_pow + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ 
- "object": pow[0], + "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[0], - }, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1303,14 +1244,12 @@ def test_jax__rpow_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__and__", + method_name="__pos__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_jax__and_( +def test_jax__pos_( dtype_and_x, frontend, frontend_method_data, @@ -1323,31 +1262,31 @@ def test_jax__and_( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, + on_device=on_device, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - on_device=on_device, ) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rand__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + method_name="__pow__", + dtype_x_pow=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, ), ) -def test_jax__rand_( - dtype_and_x, +def test_jax__pow_( + dtype_x_pow, frontend, frontend_method_data, init_flags, @@ -1355,7 +1294,7 @@ def test_jax__rand_( backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x_pow helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1363,7 +1302,9 @@ def test_jax__rand_( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1375,14 +1316,14 @@ def test_jax__rand_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__or__", + method_name="__rand__", dtype_and_x=helpers.dtype_and_values( available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, shared_dtype=True, ), ) -def test_jax__or_( +def test_jax__rand_( dtype_and_x, frontend, frontend_method_data, @@ -1447,15 +1388,11 @@ def test_jax__ror_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__xor__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - num_arrays=2, - shared_dtype=True, - ), + method_name="__rpow__", + dtype_x_pow=_get_dtype_x_and_int(), ) -def test_jax__xor_( - dtype_and_x, +def test_jax__rpow_( + dtype_x_pow, frontend, frontend_method_data, init_flags, @@ -1463,15 +1400,17 @@ def test_jax__xor_( backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x, pow = dtype_x_pow helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": pow[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "other": x[0], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, 
@@ -1519,12 +1458,14 @@ def test_jax__rxor_( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__invert__", + method_name="__xor__", dtype_and_x=helpers.dtype_and_values( available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + num_arrays=2, + shared_dtype=True, ), ) -def test_jax___invert__( +def test_jax__xor_( dtype_and_x, frontend, frontend_method_data, @@ -1541,7 +1482,7 @@ def test_jax___invert__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1550,45 +1491,40 @@ def test_jax___invert__( ) -# shifting helper -@st.composite -def _get_dtype_x_and_int_shift(draw, dtype): - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes(dtype), - num_arrays=2, - shared_dtype=True, - ) - ) - x_dtype = x_dtype - x[1] = np.asarray(np.clip(x[0], 0, np.iinfo(x_dtype[0]).bits - 1), dtype=x_dtype[0]) - return x_dtype, x[0], x[1] - - @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__lshift__", - dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), + method_name="all", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + force_int_axis=True, + valid_axis=True, + min_num_dims=1, + ), + keepdims=st.booleans(), ) -def test_jax___lshift__( - dtype_x_shift, +def test_jax_array_all( + dtype_x_axis, + keepdims, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x, shift = dtype_x_shift + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x, + "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": shift}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1600,28 +1536,38 @@ def test_jax___lshift__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rlshift__", - dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), + method_name="any", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + force_int_axis=True, + valid_axis=True, + min_num_dims=1, + ), + keepdims=st.booleans(), ) -def test_jax___rlshift__( - dtype_x_shift, +def test_jax_array_any( + dtype_x_axis, + keepdims, + on_device, frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x, shift = dtype_x_shift + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": shift, + "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, frontend=frontend, + backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1632,27 +1578,37 @@ def test_jax___rlshift__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rshift__", - dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), + method_name="argmax", + 
dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + force_int_axis=True, + min_num_dims=1, + valid_axis=True, + ), + keepdims=st.booleans(), ) -def test_jax___rshift__( - dtype_x_shift, +def test_jax_array_argmax( + dtype_and_x, + keepdims, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x, shift = dtype_x_shift + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x, + "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": shift}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1664,27 +1620,37 @@ def test_jax___rshift__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rrshift__", - dtype_x_shift=_get_dtype_x_and_int_shift(dtype="signed_integer"), + method_name="argmin", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + force_int_axis=True, + min_num_dims=1, + valid_axis=True, + ), + keepdims=st.booleans(), ) -def test_jax___rrshift__( - dtype_x_shift, +def test_jax_array_argmin( + dtype_and_x, + keepdims, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x, shift = dtype_x_shift + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": shift, + "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1696,32 +1662,36 @@ def test_jax___rrshift__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__add__", - dtype_x=helpers.dtype_and_values( + method_name="argsort", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), ) -def test_jax___add__( - dtype_x, +def test_jax_array_argsort( + dtype_x_axis, + on_device, frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x = dtype_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "axis": axis, + }, frontend=frontend, + backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1732,31 +1702,30 @@ def test_jax___add__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__radd__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, - ), + method_name="astype", + dtype_and_x=_get_castable_dtype(), ) -def test_jax___radd__( - dtype_x, +def test_jax_array_astype( + dtype_and_x, + on_device, frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x = 
dtype_x + input_dtype, x, _, castable_dtype = dtype_and_x + helpers.test_frontend_method( - init_input_dtypes=input_dtype, backend_to_test=backend_fw, + init_input_dtypes=[input_dtype], init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_input_dtypes=[input_dtype], + method_all_as_kwargs_np={ + "dtype": castable_dtype, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1765,26 +1734,41 @@ def test_jax___radd__( ) +@given( + x_y_index=_at_helper(), +) +def test_jax_array_at(x_y_index, backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + jax_frontend = ivy_backend.utils.dynamic_import.import_module( + "ivy.functional.frontends.jax" + ) + xy, idx = x_y_index + x = jax_frontend.Array(xy[0]) + y = jax_frontend.Array(xy[1]) + idx = idx[0] + x_set = x.at[idx].set(y[idx]) + assert x_set[idx] == y[idx] + assert x.at[idx].get() == x[idx] + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__sub__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + method_name="conj", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("real_and_complex"), ), ) -def test_jax___sub__( - dtype_x, +def test_jax_array_conj( + dtype_and_x, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x = dtype_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1792,7 +1776,7 @@ def test_jax___sub__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1804,33 +1788,31 @@ def test_jax___sub__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rsub__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + method_name="conjugate", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("real_and_complex"), ), ) -def test_jax___rsub__( - dtype_x, +def test_jax_array_conjugate( + dtype_and_x, + on_device, frontend, frontend_method_data, + backend_fw, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x = dtype_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, + backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, on_device=on_device, @@ -1840,32 +1822,31 @@ def test_jax___rsub__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__mul__", + method_name="copy", dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, ), ) -def test_jax___mul__( +def test_jax_array_copy( dtype_x, + on_device, frontend, frontend_method_data, + backend_fw, init_flags, method_flags, - backend_fw, - on_device, ): input_dtype, x = dtype_x 
helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, frontend=frontend, + backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1876,23 +1857,29 @@ def test_jax___mul__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rmul__", - dtype_x=helpers.dtype_and_values( + method_name="cumprod", + dtype_and_x=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + min_num_dims=1, + max_num_dims=5, + min_value=-100, + max_value=100, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, ), ) -def test_jax___rmul__( - dtype_x, +def test_jax_array_cumprod( + dtype_and_x, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x = dtype_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1900,7 +1887,9 @@ def test_jax___rmul__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "axis": axis, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1912,33 +1901,31 @@ def test_jax___rmul__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__div__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, - ), + method_name="cumsum", + dtype_and_x=_get_castable_dtype(), ) -def test_jax___div__( - dtype_x, +def test_jax_array_cumsum( + dtype_and_x, + on_device, frontend, frontend_method_data, + backend_fw, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x = dtype_x - assume(not np.any(np.isclose(x[1], 0))) + input_dtype, x, axis, dtype = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, + init_input_dtypes=[input_dtype], init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_input_dtypes=[input_dtype], + method_all_as_kwargs_np={ + "axis": axis, + "dtype": dtype, + }, frontend=frontend, + backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1949,32 +1936,30 @@ def test_jax___div__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rdiv__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + method_name="diagonal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, ), ) -def test_jax___rdiv__( - dtype_x, +def test_jax_array_diagonal( + dtype_and_x, + on_device, frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x = dtype_x - assume(not np.any(np.isclose(x[0], 0))) + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, backend_to_test=backend_fw, + init_input_dtypes=input_dtype, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - 
method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1983,30 +1968,48 @@ def test_jax___rdiv__( ) +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ), +) +def test_jax_array_dtype( + dtype_x, + backend_fw, +): + dtype, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + jax_frontend = ivy_backend.utils.dynamic_import.import_module( + "ivy.functional.frontends.jax" + ) + x = jax_frontend.Array(data[0]) + assert x.dtype == dtype[0] + + +# max @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__truediv__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", + method_name="max", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + force_int_axis=True, + min_num_dims=1, + valid_axis=True, ), + keepdims=st.booleans(), ) -def test_jax___truediv__( - dtype_x, +def test_jax_array_max( + dtype_and_x, + keepdims, + on_device, frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x = dtype_x - assume(not np.any(np.isclose(x[1], 0))) + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2014,7 +2017,10 @@ def test_jax___truediv__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2026,24 +2032,29 @@ def test_jax___truediv__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rtruediv__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + method_name="mean", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + force_int_axis=True, + min_num_dims=1, + valid_axis=True, ), + keepdims=st.booleans(), ) -def test_jax___rtruediv__( - dtype_x, +def test_jax_array_mean( + dtype_and_x, + keepdims, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x = dtype_x - assume(not np.any(np.isclose(x[0], 0))) + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2051,36 +2062,44 @@ def test_jax___rtruediv__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, + rtol_=1e-3, + atol_=1e-3, ) +# min @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__mod__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + method_name="min", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + 
force_int_axis=True, + min_num_dims=1, + valid_axis=True, ), + keepdims=st.booleans(), ) -def test_jax___mod__( - dtype_x, +def test_jax_array_min( + dtype_and_x, + keepdims, + on_device, frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - backend_fw, - on_device, ): - input_dtype, x = dtype_x - assume(not np.any(np.isclose(x[1], 0))) + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2088,7 +2107,10 @@ def test_jax___mod__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keepdims, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2097,27 +2119,43 @@ def test_jax___mod__( ) +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ), +) +def test_jax_array_ndim( + dtype_x, + backend_fw, +): + dtype, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + jax_frontend = ivy_backend.utils.dynamic_import.import_module( + "ivy.functional.frontends.jax" + ) + x = jax_frontend.Array(data[0]) + assert x.ndim == data[0].ndim + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rmod__", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=2, + method_name="nonzero", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, ), ) -def test_jax___rmod__( - dtype_x, +def test_jax_array_nonzero( + dtype_and_x, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x = dtype_x - assume(not np.any(np.isclose(x[0], 0))) + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2125,7 +2163,7 @@ def test_jax___rmod__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2134,48 +2172,58 @@ def test_jax___rmod__( ) -@st.composite -def _get_dtype_input_and_vectors(draw): - dim_size = draw(helpers.ints(min_value=2, max_value=5)) - dtype = draw(helpers.get_dtypes("numeric", index=1, full=False)) - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 +@given(x_transpose=_transpose_helper()) +def test_jax_array_property_T(x_transpose, backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + x, xT = x_transpose + jax_frontend = ivy_backend.utils.dynamic_import.import_module( + "ivy.functional.frontends.jax" ) - ) - return dtype, [vec1, vec2] + x = jax_frontend.Array(x) + assert np.array_equal(x.T, xT) +# ptp @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__matmul__", - dtype_x=_get_dtype_input_and_vectors(), + method_name="ptp", + dtype_and_x_axis_dtype=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, + num_arrays=1, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + min_num_dims=1, + valid_axis=True, + ), + 
keep_dims=st.booleans(), ) -def test_jax___matmul__( - dtype_x, +def test_jax_array_ptp( + dtype_and_x_axis_dtype, + keep_dims, frontend, frontend_method_data, + backend_fw, init_flags, method_flags, - backend_fw, on_device, ): - input_dtype, x = dtype_x + input_dtypes, x, axis = dtype_and_x_axis_dtype helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, + init_input_dtypes=input_dtypes, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis": axis, + "out": None, + "keepdims": keep_dims, + }, frontend=frontend, + backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2186,19 +2234,30 @@ def test_jax___matmul__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__rmatmul__", - dtype_x=_get_dtype_input_and_vectors(), + method_name="ravel", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + shape=helpers.get_shape( + min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ), + ), + order=st.sampled_from(["C", "F"]), ) -def test_jax___rmatmul__( - dtype_x, +def test_jax_array_ravel( + dtype_and_x, + order, + on_device, frontend, frontend_method_data, init_flags, method_flags, backend_fw, - on_device, ): - input_dtype, x = dtype_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2206,7 +2265,9 @@ def test_jax___rmatmul__( "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={ + "order": order, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2215,31 +2276,41 @@ def test_jax___rmatmul__( ) -# __getitem__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="__getitem__", - dtype_x_index=helpers.dtype_array_query( - available_dtypes=helpers.get_dtypes("valid"), - ).filter(lambda x: not (isinstance(x[-1], np.ndarray) and x[-1].dtype == np.bool_)), + method_name="reshape", + dtype_and_x_shape=_get_input_and_reshape(), + order=st.sampled_from(["C", "F"]), + input=st.booleans(), ) -def test_jax___getitem__( - dtype_x_index, +def test_jax_array_reshape( + dtype_and_x_shape, + order, + input, frontend, frontend_method_data, init_flags, method_flags, - backend_fw, on_device, + backend_fw, ): - input_dtype, x, index = dtype_x_index + input_dtype, x, shape = dtype_and_x_shape + if input: + method_flags.num_positional_args = len(shape) + kwargs = {f"{i}": shape[i] for i in range(len(shape))} + else: + kwargs = {"shape": shape} + method_flags.num_positional_args = 1 + kwargs["order"] = order helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={"object": x}, - method_input_dtypes=[*input_dtype[1:]], - method_all_as_kwargs_np={"idx": index}, + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np=kwargs, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2287,92 +2358,30 @@ def test_jax_array_round( ) -# var @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - 
method_name="var", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - force_int_axis=True, - min_num_dims=1, - valid_axis=True, - ), - ddof=st.booleans(), - keepdims=st.booleans(), + method_name="searchsorted", + dtype_x_v_side_sorter=_searchsorted(), ) -def test_jax_array_var( - dtype_and_x, - keepdims, - on_device, +def test_jax_array_searchsorted( + dtype_x_v_side_sorter, frontend, - ddof, - backend_fw, frontend_method_data, init_flags, method_flags, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "object": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "ddof": ddof, # You can adjust the ddof value as needed - "keepdims": keepdims, - }, - frontend=frontend, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - on_device=on_device, - rtol_=1e-3, - atol_=1e-3, - ) - - -# min -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="jax.numpy.array", - method_name="min", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - force_int_axis=True, - min_num_dims=1, - valid_axis=True, - ), - keepdims=st.booleans(), -) -def test_jax_array_min( - dtype_and_x, - keepdims, on_device, - frontend, backend_fw, - frontend_method_data, - init_flags, - method_flags, ): - input_dtype, x, axis = dtype_and_x + input_dtypes, xs, side, sorter = dtype_x_v_side_sorter helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, + "object": xs[0], }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={"v": xs[0], "side": side, "sorter": sorter}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2381,47 +2390,60 @@ def test_jax_array_min( ) -# ptp +@given( + dtype_x_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ), +) +def test_jax_array_shape( + dtype_x_shape, + backend_fw, +): + _, data, shape = dtype_x_shape + with BackendHandler.update_backend(backend_fw) as ivy_backend: + jax_frontend = ivy_backend.utils.dynamic_import.import_module( + "ivy.functional.frontends.jax" + ) + x = jax_frontend.Array(data[0]) + assert x.shape == shape + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="ptp", - dtype_and_x_axis_dtype=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, - num_arrays=1, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - min_num_dims=1, - valid_axis=True, + method_name="sort", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=["int64"], + force_int_axis=True, + min_axis=-1, + max_axis=-1, + min_dim_size=2, + max_dim_size=100, + min_num_dims=2, ), - keep_dims=st.booleans(), ) -def test_jax_array_ptp( - dtype_and_x_axis_dtype, - keep_dims, +def test_jax_array_sort( + dtype_x_axis, + on_device, frontend, frontend_method_data, backend_fw, init_flags, method_flags, - on_device, ): - input_dtypes, x, axis = dtype_and_x_axis_dtype + input_dtype, x, axis = dtype_x_axis 
helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + init_input_dtypes=input_dtype, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtypes, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ "axis": axis, - "out": None, - "keepdims": keep_dims, }, frontend=frontend, - backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2429,24 +2451,29 @@ def test_jax_array_ptp( ) -# max +# var @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", - method_name="max", + method_name="var", dtype_and_x=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", force_int_axis=True, min_num_dims=1, valid_axis=True, ), + ddof=st.booleans(), keepdims=st.booleans(), ) -def test_jax_array_max( +def test_jax_array_var( dtype_and_x, keepdims, on_device, frontend, + ddof, backend_fw, frontend_method_data, init_flags, @@ -2462,6 +2489,7 @@ def test_jax_array_max( method_input_dtypes=input_dtype, method_all_as_kwargs_np={ "axis": axis, + "ddof": ddof, # You can adjust the ddof value as needed "keepdims": keepdims, }, frontend=frontend, @@ -2469,141 +2497,78 @@ def test_jax_array_max( init_flags=init_flags, method_flags=method_flags, on_device=on_device, + rtol_=1e-3, + atol_=1e-3, ) -# searchsorted -@st.composite -def _searchsorted(draw): - dtype_x, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - shape=(draw(st.integers(min_value=1, max_value=10)),), - ), - ) - dtype_v, v = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - min_num_dims=1, - ) - ) - - input_dtypes = dtype_x + dtype_v - xs = x + v - side = draw(st.sampled_from(["left", "right"])) - sorter = None - xs[0] = np.sort(xs[0], axis=-1) - return input_dtypes, xs, side, sorter - - -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="jax.numpy.array", - method_name="searchsorted", - dtype_x_v_side_sorter=_searchsorted(), +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ), ) -def test_jax_array_searchsorted( - dtype_x_v_side_sorter, - frontend, - frontend_method_data, - init_flags, - method_flags, - on_device, +def test_jax_ivy_array( + dtype_x, backend_fw, ): - input_dtypes, xs, side, sorter = dtype_x_v_side_sorter - helpers.test_frontend_method( - init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "object": xs[0], - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={"v": xs[0], "side": side, "sorter": sorter}, - frontend=frontend, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - on_device=on_device, - ) + _, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + jax_frontend = ivy_backend.utils.dynamic_import.import_module( + "ivy.functional.frontends.jax" + ) + x = jax_frontend.Array(data[0]) + ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=backend_fw) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend=backend_fw, + ground_truth_backend="jax", + ) @handle_frontend_method( class_tree=CLASS_TREE, 
init_tree="jax.numpy.array", - method_name="reshape", - dtype_and_x_shape=_get_input_and_reshape(), - order=st.sampled_from(["C", "F"]), - input=st.booleans(), + method_name="prod", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + valid_axis=True, + min_dim_size=2, + max_dim_size=10, + min_num_dims=2, + ), ) -def test_jax_array_reshape( - dtype_and_x_shape, - order, - input, +def test_jax_prod( + dtype_x_axis, + on_device, frontend, frontend_method_data, init_flags, method_flags, - on_device, backend_fw, ): - input_dtype, x, shape = dtype_and_x_shape - if input: - method_flags.num_positional_args = len(shape) - kwargs = {f"{i}": shape[i] for i in range(len(shape))} - else: - kwargs = {"shape": shape} - method_flags.num_positional_args = 1 - kwargs["order"] = order + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( - backend_to_test=backend_fw, init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np=kwargs, + method_all_as_kwargs_np={ + "axis": axis, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, + atol_=1e-04, ) -# repeat -@st.composite -def _repeat_helper(draw): - shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="value_shape")) - axis = draw( - st.shared( - st.one_of(st.none(), helpers.get_axis(shape=shape, max_size=1)), key="axis" - ) - ) - - if not isinstance(axis, int) and axis is not None: - axis = axis[0] - - repeat_shape = ( - (draw(st.one_of(st.just(1), st.just(shape[axis]))),) - if axis is not None - else (1,) - ) - repeat = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - shape=repeat_shape, - min_value=0, - max_value=10, - ) - ) - return repeat - - @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array", @@ -2669,3 +2634,46 @@ def test_jax_repeat( method_flags=method_flags, on_device=on_device, ) + + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="jax.numpy.array", + method_name="sum", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + valid_axis=True, + force_int_axis=True, + ), +) +def test_jax_sum( + dtype_and_x, + on_device, + frontend, + frontend_method_data, + backend_fw, + init_flags, + method_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_method( + backend_to_test=backend_fw, + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + }, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + atol_=1e-04, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_func_wrapper.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_func_wrapper.py index d638b1e500798..8ad6a133c9cec 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_func_wrapper.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_func_wrapper.py @@ -13,6 +13,10 @@ import ivy.functional.frontends.jax as jax_frontend +# --- Helpers --- # +# --------------- # + + def _fn(x, check_default=False): if check_default and jax_frontend.config.jax_enable_x64: ivy.utils.assertions.check_equal( @@ -24,6 +28,10 @@ 
def _fn(x, check_default=False): return x +# --- Main --- # +# ------------ # + + @given( dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_general_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_general_functions.py index d193a8b41d439..2ac4e1b44523e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_general_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_general_functions.py @@ -7,6 +7,10 @@ import jax +# --- Helpers --- # +# --------------- # + + def _fn1(x, y): return ivy.matmul(x, y) @@ -19,6 +23,77 @@ def _fn3(x, y): return ivy.add(x, y) +# --- Main --- # +# ------------ # + + +# device_get +@handle_frontend_test( + fn_tree="jax.general_functions.device_get", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_jax_device_get( + *, + dtype_and_x, + test_flags, + fn_tree, + frontend, + backend_fw, + on_device, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + dtype, x = dtype_and_x + dtype = dtype[0] + x = x[0] + + x = ivy_backend.asarray(x) + if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): + x = ivy_backend.functional.ivy.gradients._variable(x) + + x_on_dev = ivy_backend.functional.frontends.jax.device_get(x).ivy_array + dev_from_new_x = ivy_backend.dev(x_on_dev) + + # value test + assert dev_from_new_x == "cpu" + + +# device_put +@handle_frontend_test( + fn_tree="jax.general_functions.device_put", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_jax_device_put( + *, + dtype_and_x, + test_flags, + fn_tree, + frontend, + backend_fw, + on_device, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + dtype, x = dtype_and_x + dtype = dtype[0] + x = x[0] + + x = ivy_backend.asarray(x) + if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): + x = ivy_backend.functional.ivy.gradients._variable(x) + + device = ivy_backend.dev(x) + x_on_dev = ivy_backend.functional.frontends.jax.device_put( + x, on_device + ).ivy_array + dev_from_new_x = ivy_backend.dev(x_on_dev) + + # value test + assert dev_from_new_x == device + + # vmap @handle_frontend_test( fn_tree="jax.general_functions.vmap", @@ -93,70 +168,3 @@ def test_jax_vmap( pass else: assert False, "One of the results is None while other isn't" - - -# device_put -@handle_frontend_test( - fn_tree="jax.general_functions.device_put", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_jax_device_put( - *, - dtype_and_x, - test_flags, - fn_tree, - frontend, - backend_fw, - on_device, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - dtype, x = dtype_and_x - dtype = dtype[0] - x = x[0] - - x = ivy_backend.asarray(x) - if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): - x = ivy_backend.functional.ivy.gradients._variable(x) - - device = ivy_backend.dev(x) - x_on_dev = ivy_backend.functional.frontends.jax.device_put( - x, on_device - ).ivy_array - dev_from_new_x = ivy_backend.dev(x_on_dev) - - # value test - assert dev_from_new_x == device - - -# device_get -@handle_frontend_test( - fn_tree="jax.general_functions.device_get", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_jax_device_get( - *, - dtype_and_x, - test_flags, - fn_tree, - frontend, - backend_fw, - on_device, -): - 
with BackendHandler.update_backend(backend_fw) as ivy_backend: - dtype, x = dtype_and_x - dtype = dtype[0] - x = x[0] - - x = ivy_backend.asarray(x) - if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): - x = ivy_backend.functional.ivy.gradients._variable(x) - - x_on_dev = ivy_backend.functional.frontends.jax.device_get(x).ivy_array - dev_from_new_x = ivy_backend.dev(x_on_dev) - - # value test - assert dev_from_new_x == "cpu" diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py index 722c5a283446b..0c506060704db 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_control_flow_operators.py @@ -50,25 +50,31 @@ def _test_false_fn(x): @handle_frontend_test( - fn_tree="jax.lax.map", + fn_tree="jax.lax.fori_loop", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + min_value=-1000, + max_value=1000, min_num_dims=1, min_dim_size=1, ), + lower=st.integers(min_value=-10, max_value=10), + upper=st.integers(min_value=-10, max_value=10), test_with_out=st.just(False), ) -def test_jax_map( +def test_jax_fori_loop( *, dtype_and_x, + lower, + upper, test_flags, on_device, fn_tree, frontend, backend_fw, ): - def _test_map_fn(x): - return x + x + def _test_body_fn(x, y): + return x + y input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -78,37 +84,34 @@ def _test_map_fn(x): frontend=frontend, fn_tree=fn_tree, on_device=on_device, - f=_test_map_fn, - xs=x[0], + lower=lower, + upper=upper, + body_fun=_test_body_fn, + init_val=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.switch", + fn_tree="jax.lax.map", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, min_dim_size=1, ), - index=helpers.ints(min_value=-10, max_value=10), test_with_out=st.just(False), ) -def test_jax_switch( +def test_jax_map( *, dtype_and_x, - index, test_flags, on_device, fn_tree, frontend, backend_fw, ): - def _test_branch_1(x): + def _test_map_fn(x): return x + x - def _test_branch_2(x): - return x * x - input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, @@ -117,14 +120,13 @@ def _test_branch_2(x): frontend=frontend, fn_tree=fn_tree, on_device=on_device, - index=index, - branches=[_test_branch_1, _test_branch_2], - operand=x[0], + f=_test_map_fn, + xs=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.fori_loop", + fn_tree="jax.lax.scan", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), min_value=-1000, @@ -132,23 +134,26 @@ def _test_branch_2(x): min_num_dims=1, min_dim_size=1, ), - lower=st.integers(min_value=-10, max_value=10), - upper=st.integers(min_value=-10, max_value=10), + length=st.integers(min_value=-10, max_value=10), + init=st.integers(min_value=-10, max_value=10), test_with_out=st.just(False), ) -def test_jax_fori_loop( +def test_jax_scan( *, dtype_and_x, - lower, - upper, + length, + init, test_flags, on_device, fn_tree, frontend, backend_fw, ): - def _test_body_fn(x, y): - return x + y + if length == 0 or length != len(dtype_and_x[1][0]): + return + + def _test_scan_fn(carry, x): + return carry + x, x * 2 input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -158,48 +163,38 @@ def _test_body_fn(x, y): frontend=frontend, fn_tree=fn_tree, on_device=on_device, - lower=lower, - upper=upper, - 
body_fun=_test_body_fn, - init_val=x[0], + f=_test_scan_fn, + init=init, + xs=x[0], + length=length, ) @handle_frontend_test( - fn_tree="jax.lax.while_loop", + fn_tree="jax.lax.switch", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_value=-1000, - max_value=1000, min_num_dims=1, min_dim_size=1, ), + index=helpers.ints(min_value=-10, max_value=10), + test_with_out=st.just(False), ) -def test_jax_while_loop( +def test_jax_switch( *, dtype_and_x, + index, test_flags, on_device, fn_tree, frontend, backend_fw, ): - def _test_cond_fn(x): - def any_negative_real(arr): - for elem in arr: - if isinstance(elem, (int, float)) and elem < 0: - return True - elif isinstance(elem, complex): - return False - elif isinstance(elem, (list, tuple)): - if any_negative_real(elem): - return True - return False - - return any_negative_real(x) + def _test_branch_1(x): + return x + x - def _test_body_fn(x): - return x + 1 + def _test_branch_2(x): + return x * x input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -209,14 +204,14 @@ def _test_body_fn(x): frontend=frontend, fn_tree=fn_tree, on_device=on_device, - cond_fun=_test_cond_fn, - body_fun=_test_body_fn, - init_val=x[0], + index=index, + branches=[_test_branch_1, _test_branch_2], + operand=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.scan", + fn_tree="jax.lax.while_loop", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), min_value=-1000, @@ -224,26 +219,32 @@ def _test_body_fn(x): min_num_dims=1, min_dim_size=1, ), - length=st.integers(min_value=-10, max_value=10), - init=st.integers(min_value=-10, max_value=10), - test_with_out=st.just(False), ) -def test_jax_scan( +def test_jax_while_loop( *, dtype_and_x, - length, - init, test_flags, on_device, fn_tree, frontend, backend_fw, ): - if length == 0 or length != len(dtype_and_x[1][0]): - return + def _test_cond_fn(x): + def any_negative_real(arr): + for elem in arr: + if isinstance(elem, (int, float)) and elem < 0: + return True + elif isinstance(elem, complex): + return False + elif isinstance(elem, (list, tuple)): + if any_negative_real(elem): + return True + return False - def _test_scan_fn(carry, x): - return carry + x, x * 2 + return any_negative_real(x) + + def _test_body_fn(x): + return x + 1 input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -253,8 +254,7 @@ def _test_scan_fn(carry, x): frontend=frontend, fn_tree=fn_tree, on_device=on_device, - f=_test_scan_fn, - init=init, - xs=x[0], - length=length, + cond_fun=_test_cond_fn, + body_fun=_test_body_fn, + init_val=x[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_linalg.py index 0c72cea0b4005..73194a41bd519 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_linalg.py @@ -9,9 +9,9 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test, BackendHandler -# svd +# cholesky @handle_frontend_test( - fn_tree="jax.lax.linalg.svd", + fn_tree="jax.lax.linalg.cholesky", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -23,15 +23,13 @@ and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon and np.linalg.det(np.asarray(x[1][0])) != 0 ), - full_matrices=st.booleans(), - compute_uv=st.booleans(), + symmetrize_input=st.booleans(), test_with_out=st.just(False), ) -def test_jax_svd( +def test_jax_cholesky( *, 
dtype_and_x, - full_matrices, - compute_uv, + symmetrize_input, on_device, fn_tree, frontend, @@ -42,52 +40,26 @@ def test_jax_svd( x = np.asarray(x[0], dtype=dtype[0]) # make symmetric positive-definite beforehand x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - - ret, frontend_ret = helpers.test_frontend_function( + fw_ret, gt_ret = helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, + rtol=1e-02, x=x, - full_matrices=full_matrices, - compute_uv=compute_uv, + symmetrize_input=symmetrize_input, + test_values=False, ) - - if compute_uv: - with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = [ivy_backend.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] - - u, s, vh = ret - frontend_u, frontend_s, frontend_vh = frontend_ret - - assert_all_close( - ret_np=u @ np.diag(s) @ vh, - ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh, - rtol=1e-2, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=frontend, - ) - else: - with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = ivy_backend.to_numpy(ret) - assert_all_close( - ret_np=ret, - ret_from_gt_np=np.asarray(frontend_ret[0]), - rtol=1e-2, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=frontend, - ) + # ToDo: turn value test on when jax cholesky is fixed in issue + # https: // github.com / google / jax / issues / 16185 + helpers.assertions.assert_same_type_and_shape([fw_ret, gt_ret]) -# cholesky +# eigh @handle_frontend_test( - fn_tree="jax.lax.linalg.cholesky", + fn_tree="jax.lax.linalg.eigh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -99,12 +71,14 @@ def test_jax_svd( and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon and np.linalg.det(np.asarray(x[1][0])) != 0 ), + lower=st.booleans(), symmetrize_input=st.booleans(), test_with_out=st.just(False), ) -def test_jax_cholesky( +def test_jax_eigh( *, dtype_and_x, + lower, symmetrize_input, on_device, fn_tree, @@ -113,29 +87,41 @@ def test_jax_cholesky( backend_fw, ): dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=dtype[0]) + x = np.array(x[0], dtype=dtype[0]) # make symmetric positive-definite beforehand x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - fw_ret, gt_ret = helpers.test_frontend_function( + + ret, frontend_ret = helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, + test_values=False, x=x, + lower=lower, symmetrize_input=symmetrize_input, - test_values=False, ) - # ToDo: turn value test on when jax cholesky is fixed in issue - # https: // github.com / google / jax / issues / 16185 - helpers.assertions.assert_same_type_and_shape([fw_ret, gt_ret]) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + ret = [ivy_backend.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + L, Q = ret + frontend_Q, frontend_L = frontend_ret -# eigh + assert_all_close( + ret_np=Q @ np.diag(L) @ Q.T, + ret_from_gt_np=frontend_Q @ np.diag(frontend_L) @ frontend_Q.T, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=frontend, + ) + + +# svd @handle_frontend_test( - fn_tree="jax.lax.linalg.eigh", + fn_tree="jax.lax.linalg.svd", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -147,15 +133,15 
@@ def test_jax_cholesky( and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon and np.linalg.det(np.asarray(x[1][0])) != 0 ), - lower=st.booleans(), - symmetrize_input=st.booleans(), + full_matrices=st.booleans(), + compute_uv=st.booleans(), test_with_out=st.just(False), ) -def test_jax_eigh( +def test_jax_svd( *, dtype_and_x, - lower, - symmetrize_input, + full_matrices, + compute_uv, on_device, fn_tree, frontend, @@ -163,7 +149,7 @@ def test_jax_eigh( backend_fw, ): dtype, x = dtype_and_x - x = np.array(x[0], dtype=dtype[0]) + x = np.asarray(x[0], dtype=dtype[0]) # make symmetric positive-definite beforehand x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 @@ -176,20 +162,34 @@ def test_jax_eigh( on_device=on_device, test_values=False, x=x, - lower=lower, - symmetrize_input=symmetrize_input, + full_matrices=full_matrices, + compute_uv=compute_uv, ) - with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = [ivy_backend.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] - L, Q = ret - frontend_Q, frontend_L = frontend_ret + if compute_uv: + with BackendHandler.update_backend(backend_fw) as ivy_backend: + ret = [ivy_backend.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] - assert_all_close( - ret_np=Q @ np.diag(L) @ Q.T, - ret_from_gt_np=frontend_Q @ np.diag(frontend_L) @ frontend_Q.T, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=frontend, - ) + u, s, vh = ret + frontend_u, frontend_s, frontend_vh = frontend_ret + + assert_all_close( + ret_np=u @ np.diag(s) @ vh, + ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh, + rtol=1e-2, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=frontend, + ) + else: + with BackendHandler.update_backend(backend_fw) as ivy_backend: + ret = ivy_backend.to_numpy(ret) + assert_all_close( + ret_np=ret, + ret_from_gt_np=np.asarray(frontend_ret[0]), + rtol=1e-2, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=frontend, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py index 455f0e5d26de3..b071645a3e7c0 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py @@ -24,125 +24,8 @@ ) -# imag -@handle_frontend_test( - fn_tree="jax.lax.imag", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("complex") - ), -) -def test_jax_imag( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=True, - x=x[0], - ) - - -# add -@handle_frontend_test( - fn_tree="jax.lax.add", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), - test_with_out=st.just(False), -) -def test_jax_add( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# tan -@handle_frontend_test( - fn_tree="jax.lax.tan", - 
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_jax_tan( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# max -@handle_frontend_test( - fn_tree="jax.lax.max", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), - test_with_out=st.just(False), -) -def test_jax_max( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) +# --- Helpers --- # +# --------------- # # noinspection DuplicatedCode @@ -183,32 +66,57 @@ def _arrays_idx_n_dtypes(draw): return xs, input_dtypes, unique_idx -# concat -@handle_frontend_test( - fn_tree="jax.lax.concatenate", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), - test_with_out=st.just(False), -) -def test_jax_concat( - *, - xs_n_input_dtypes_n_unique_idx, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - operands=xs, - dimension=unique_idx, +@st.composite +def _div_dtypes_and_xs(draw): + dtype, dividend, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ret_shape=True + ) + ) + divisor = draw( + helpers.array_values(dtype=dtype[0], min_value=-20, max_value=20, shape=shape) + ) + return dtype, [dividend[0], divisor] + + +# select +@st.composite +def _dtype_pred_ontrue_on_false(draw): + shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) + pred = draw(helpers.array_values(dtype="bool", shape=shape)) + dtypes, on_true_on_false = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shape=shape, + large_abs_safety_factor=16, + small_abs_safety_factor=16, + safety_factor_scale="log", + shared_dtype=True, + ) + ) + return dtypes, pred, on_true_on_false + + +@st.composite +def _dtype_values_dims(draw): + dtype, values, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + ret_shape=True, + ) + ) + size = len(shape) + permutations = draw( + st.lists( + st.integers(min_value=0, max_value=len(shape) - 1), + min_size=size, + max_size=size, + unique=True, + ) ) + return dtype, values, tuple(permutations) @st.composite @@ -222,59 +130,434 @@ def _fill_value(draw): return draw(helpers.floats(min_value=-5, max_value=5)) -@handle_frontend_test( - fn_tree="jax.lax.full", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - fill_value=_fill_value(), - dtypes=helpers.get_dtypes("numeric", full=False, key="dtype"), -) -def test_jax_full( - *, - shape, - fill_value, - dtypes, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - helpers.test_frontend_function( - 
input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - shape=shape, - fill_value=fill_value, - dtype=dtypes[0], +@st.composite +def _general_dot_helper(draw): + input_dtype, lhs, lshape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-1e04, + max_value=1e04, + min_num_dims=2, + ret_shape=True, + ) ) - - -# abs -@handle_frontend_test( - fn_tree="jax.lax.abs", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), - ), - test_with_out=st.just(False), -) -def test_jax_abs( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): + ndims = len(lshape) + perm_id = random.sample(list(range(ndims)), ndims) + rshape = [lshape[i] for i in perm_id] + input_dtype, rhs = draw( + helpers.dtype_and_values( + dtype=input_dtype, + min_value=-1e04, + max_value=1e04, + shape=rshape, + ) + ) + ind_list = list(range(ndims)) + batch_n = draw(st.integers(min_value=1, max_value=len(lshape) - 1)) + lhs_batch = random.sample(ind_list, batch_n) + rhs_batch = [perm_id.index(i) for i in lhs_batch] + lhs_contracting = [i for i in ind_list if i not in lhs_batch] + rhs_contracting = [perm_id.index(i) for i in lhs_contracting] + is_pref = draw(st.booleans()) + pref_dtype = None + if is_pref: + uint_cast_st = helpers.get_castable_dtype( + draw(helpers.get_dtypes("unsigned")), + input_dtype[0], + ) + int_cast_st = helpers.get_castable_dtype( + draw(helpers.get_dtypes("signed_integer")), + input_dtype[0], + ) + float_cast_st = helpers.get_castable_dtype( + draw(helpers.get_dtypes("float")), + input_dtype[0], + ) + complex_cast_st = helpers.get_castable_dtype( + draw(helpers.get_dtypes("complex")), + input_dtype[0], + ) + if "uint" in input_dtype[0]: + pref_dtype = draw(st.one_of(uint_cast_st, float_cast_st))[-1] + elif "int" in input_dtype[0]: + pref_dtype = draw(st.one_of(int_cast_st, float_cast_st))[-1] + elif "float" in input_dtype[0]: + pref_dtype = draw(float_cast_st)[-1] + elif "complex" in input_dtype[0]: + pref_dtype = draw(complex_cast_st)[-1] + else: + raise ivy.exceptions.IvyException("unsupported dtype") + return ( + input_dtype * 2, + (lhs[0], rhs[0]), + ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)), + pref_dtype, + ) + + +@st.composite +def _get_clamp_inputs(draw): + shape = draw( + helpers.get_shape( + min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ) + ) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=shape, + ) + ) + + min = draw( + helpers.array_values(dtype=x_dtype[0], shape=shape, min_value=-25, max_value=0) + ) + max = draw( + helpers.array_values(dtype=x_dtype[0], shape=shape, min_value=1, max_value=25) + ) + return x_dtype, x, min, max + + +@st.composite +def _get_dtype_inputs_for_batch_matmul(draw): + dtype, lhs = draw( + helpers.dtype_and_values( + min_num_dims=2, + max_num_dims=6, + min_value=2, + max_value=5, + ) + ) + lhs_shape = lhs[0].shape + rhs_shape = list(lhs_shape) + rhs_shape[-1], rhs_shape[-2] = rhs_shape[-2], rhs_shape[-1] + rhs_shape = tuple(rhs_shape) + rhs = draw( + helpers.array_values( + dtype=dtype[0], + shape=rhs_shape, + min_value=2, + max_value=5, + ) + ) + + return dtype, lhs[0], rhs + + +@st.composite +def _get_dtype_inputs_for_dot(draw): + dim_size = draw(helpers.ints(min_value=1, max_value=5)) + dtype = draw(helpers.get_dtypes("numeric", index=1, full=False)) + if dim_size == 1: + 
lhs = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 + ) + ) + rhs = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 + ) + ) + else: + lhs = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 + ) + ) + rhs = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 + ) + ) + is_pref = draw(st.booleans()) + if is_pref: + dtype, values, pref = draw( + helpers.get_castable_dtype( + draw(helpers.get_dtypes("numeric")), dtype[0], [lhs, rhs] + ) + ) + assume(can_cast(dtype, pref)) + return [dtype], pref, values[0], values[1] + else: + return dtype, None, lhs, rhs + + +def _get_reduce_func(dtype): + if dtype[0] == "bool": + return st.sampled_from([jnp.logical_and, jnp.logical_or]) + else: + return st.sampled_from([jlax.add, jlax.max, jlax.min, jlax.mul, jnp.multiply]) + + +@st.composite +def _pad_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("bool"), + ret_shape=True, + min_num_dims=1, + min_dim_size=2, + min_value=-100, + max_value=100, + ).filter(lambda _x: _x[0][0] not in ["float16", "bfloat16"]) + ) + ndim = len(shape) + min_dim = min(shape) + padding_config = draw( + st.lists( + st.tuples( + st.integers(min_value=-(min_dim - 1), max_value=min_dim - 1), + st.integers(min_value=-(min_dim - 1), max_value=min_dim - 1), + st.integers(min_value=0, max_value=min_dim - 1), + ), + min_size=ndim, + max_size=ndim, + ) + ) + padding_value = draw(st.booleans()) + return dtype, x[0], padding_value, padding_config + + +@st.composite +def _reshape_helper(draw): + # generate a shape s.t len(shape) > 0 + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ) + ) + + reshape_shape = draw(helpers.reshape_shapes(shape=shape)) + + dtypes, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=shape, + ) + ) + is_dim = draw(st.booleans()) + if is_dim: + dims = [x for x in range(len(shape))] + permut = draw(st.permutations(dims)) + return x, dtypes, reshape_shape, permut + else: + return x, dtypes, reshape_shape, None + + +@st.composite +def _slice_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + ret_shape=True, + ), + ) + start_indices, limit_indices, strides = [], [], [] + for i in shape: + start_indices += [draw(st.integers(min_value=0, max_value=i - 1))] + limit_indices += [ + draw( + st.integers(min_value=0, max_value=i - 1).filter( + lambda _x: _x > start_indices[-1] + ) + ) + ] + strides += [draw(st.integers(min_value=1, max_value=i))] + return dtype, x, start_indices, limit_indices, strides + + +@st.composite +def _slice_in_dim_helper(draw): + dtype, x, axis = draw( + helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + force_int_axis=True, + valid_axis=True, + ), + ) + operand = x[0] + start_index = draw( + st.integers(min_value=-abs(operand.shape[axis]), max_value=operand.shape[axis]) + ) + if start_index < 0: + limit_index = draw( + st.integers( + min_value=start_index + operand.shape[axis], + max_value=operand.shape[axis], + ) + ) + else: + limit_index = draw( + st.integers( + min_value=-abs(operand.shape[axis]), max_value=operand.shape[axis] + ).filter(lambda _x: _x >= start_index) + ) + stride = 
draw(st.integers(min_value=1, max_value=abs(limit_index + 1))) + return dtype, x, start_index, limit_index, stride, axis + + +# squeeze +@st.composite +def _squeeze_helper(draw): + shape = draw(st.shared(helpers.get_shape(), key="value_shape")) + valid_axes = [] + for index, axis in enumerate(shape): + if axis == 1: + valid_axes.append(index) + return valid_axes + + +@st.composite +def _x_and_filters(draw, dim=2, transpose=False, general=False): + if not isinstance(dim, int): + dim = draw(dim) + batch_size = draw(st.integers(1, 5)) + filter_shape = draw( + helpers.get_shape( + min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5 + ) + ) + dtype = draw(helpers.get_dtypes("float", full=False)) + padding = draw( + st.one_of( + st.lists( + st.tuples( + st.integers(min_value=0, max_value=3), + st.integers(min_value=0, max_value=3), + ), + min_size=dim, + max_size=dim, + ), + st.sampled_from(["SAME", "VALID"]), + ) + ) + input_channels = draw(st.integers(1, 3)) + output_channels = draw(st.integers(1, 3)) + group_list = [i for i in range(1, 6)] + if not transpose: + group_list = list(filter(lambda x: (input_channels % x == 0), group_list)) + else: + group_list = list(filter(lambda x: (output_channels % x == 0), group_list)) + fc = draw(st.sampled_from(group_list)) if general else 1 + strides = draw(st.lists(st.integers(1, 3), min_size=dim, max_size=dim)) + dilations = draw(st.lists(st.integers(1, 3), min_size=dim, max_size=dim)) + if general: + if dim == 2: + dim_num_st1 = st.sampled_from(["NCHW", "NHWC"]) + dim_num_st2 = st.sampled_from(["OIHW", "HWIO"]) + elif dim == 1: + dim_num_st1 = st.sampled_from(["NWC", "NCW"]) + dim_num_st2 = st.sampled_from(["OIW", "WIO"]) + else: + dim_num_st1 = st.sampled_from(["NDHWC", "NCDHW"]) + dim_num_st2 = st.sampled_from(["OIDHW", "DHWIO"]) + dim_seq = [*range(0, dim + 2)] + dimension_numbers = draw( + st.sampled_from( + [ + None, + (draw(dim_num_st1), draw(dim_num_st2), draw(dim_num_st1)), + ConvDimensionNumbers( + *map( + tuple, + draw( + st.lists( + st.permutations(dim_seq), min_size=3, max_size=3 + ) + ), + ) + ), + ] + ) + ) + else: + dimension_numbers = ( + ("NCH", "OIH", "NCH") + if dim == 1 + else ("NCHW", "OIHW", "NCHW") if dim == 2 else ("NCDHW", "OIDHW", "NCDHW") + ) + dim_nums = _dimension_numbers(dimension_numbers, dim + 2, transp=transpose) + if not transpose: + output_channels = output_channels * fc + channel_shape = (output_channels, input_channels // fc) + else: + input_channels = input_channels * fc + channel_shape = (output_channels // fc, input_channels) + x_dim = [] + for i in range(dim): + min_x = filter_shape[i] + (filter_shape[i] - 1) * (dilations[i] - 1) + x_dim.append(draw(st.integers(min_x, min_x + 1))) + x_shape = (batch_size, input_channels, *x_dim) + filter_shape = channel_shape + filter_shape + vals = draw( + helpers.array_values( + dtype=dtype[0], + shape=x_shape, + min_value=0.0, + max_value=1.0, + ) + ) + vals = ivy.permute_dims(vals, axes=_argsort_tuple(dim_nums[0])) + filters = draw( + helpers.array_values( + dtype=dtype[0], + shape=filter_shape, + min_value=0.0, + max_value=1.0, + ) + ) + filters = ivy.permute_dims(filters, axes=_argsort_tuple(dim_nums[1])) + if general and not transpose: + x_dilation = draw(st.lists(st.integers(1, 3), min_size=dim, max_size=dim)) + dilations = (dilations, x_dilation) + if draw(st.booleans()): + p_dtype, pref = draw( + helpers.get_castable_dtype(draw(helpers.get_dtypes("float")), dtype[0]) + ) + assume(can_cast(p_dtype, pref)) + else: + pref = None + return ( + dtype, + vals, + 
filters, + dilations, + dimension_numbers, + strides, + padding, + fc, + pref, + ) + + +# --- Main --- # +# ------------ # + + +# abs +@handle_frontend_test( + fn_tree="jax.lax.abs", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("signed_integer"), + ), + test_with_out=st.just(False), +) +def test_jax_abs( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, @@ -287,13 +570,13 @@ def test_jax_abs( ) -# sqrt +# acos @handle_frontend_test( - fn_tree="jax.lax.sqrt", + fn_tree="jax.lax.acos", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_sqrt( +def test_jax_acos( *, dtype_and_x, on_device, @@ -314,13 +597,118 @@ def test_jax_sqrt( ) -# acos +# add @handle_frontend_test( - fn_tree="jax.lax.acos", + fn_tree="jax.lax.add", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), + test_with_out=st.just(False), +) +def test_jax_add( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +@handle_frontend_test( + fn_tree="jax.lax.argmax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=1, + valid_axis=True, + force_int_axis=True, + allow_neg_axes=False, + ), + index_dtype=helpers.get_dtypes("integer", full=False), + test_with_out=st.just(False), +) +def test_jax_argmax( + *, + dtype_x_axis, + index_dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + operand=x[0], + axis=axis, + index_dtype=index_dtype[0], + ) + + +@handle_frontend_test( + fn_tree="jax.lax.argmin", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=1, + valid_axis=True, + force_int_axis=True, + allow_neg_axes=False, + ), + index_dtype=helpers.get_dtypes("integer", full=False), + test_with_out=st.just(False), +) +def test_jax_argmin( + *, + dtype_x_axis, + index_dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + operand=x[0], + axis=axis, + index_dtype=index_dtype[0], + ) + + +# asin +@handle_frontend_test( + fn_tree="jax.lax.asin", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_acos( +def test_jax_asin( *, dtype_and_x, on_device, @@ -341,13 +729,13 @@ def test_jax_acos( ) -# sin +# asinh @handle_frontend_test( - fn_tree="jax.lax.sin", + fn_tree="jax.lax.asinh", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_sin( +def test_jax_asinh( *, dtype_and_x, on_device, @@ -368,15 +756,71 @@ def 
test_jax_sin( ) -# sign +# atan @handle_frontend_test( - fn_tree="jax.lax.sign", + fn_tree="jax.lax.atan", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_jax_atan( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# atan2 +@handle_frontend_test( + fn_tree="jax.lax.atan2", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), +) +def test_jax_atan2( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# atanh +@handle_frontend_test( + fn_tree="jax.lax.atanh", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_sign( +def test_jax_atanh( *, dtype_and_x, on_device, @@ -397,40 +841,47 @@ def test_jax_sign( ) -# asin @handle_frontend_test( - fn_tree="jax.lax.asin", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.lax.batch_matmul", + dtypes_and_xs=_get_dtype_inputs_for_batch_matmul(), test_with_out=st.just(False), ) -def test_jax_asin( +def test_jax_batch_matmul( *, - dtype_and_x, + dtypes_and_xs, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, lhs, rhs = dtypes_and_xs helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=input_dtypes, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + rtol=1e-2, + atol=1e-2, + lhs=lhs, + rhs=rhs, + precision=None, ) -# sinh +# bitwise_and @handle_frontend_test( - fn_tree="jax.lax.sinh", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.lax.bitwise_and", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_jax_sinh( +def test_jax_bitwise_and( *, dtype_and_x, on_device, @@ -440,7 +891,6 @@ def test_jax_sinh( backend_fw, ): input_dtype, x = dtype_and_x - helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -449,19 +899,20 @@ def test_jax_sinh( fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# atan2 +# bitwise_not @handle_frontend_test( - fn_tree="jax.lax.atan2", + fn_tree="jax.lax.bitwise_not", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=1, ), + test_with_out=st.just(False), ) -def test_jax_atan2( +def test_jax_bitwise_not( *, dtype_and_x, on_device, @@ -471,6 +922,7 @@ def test_jax_atan2( backend_fw, ): input_dtype, x = dtype_and_x + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -479,113 +931,113 @@ def test_jax_atan2( 
fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) +# bitwise_or @handle_frontend_test( - fn_tree="jax.lax.min", - dtypes_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + fn_tree="jax.lax.bitwise_or", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_min( +def test_jax_bitwise_or( *, - dtypes_and_xs, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_xs + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], + y=x[1], ) +# bitwise_xor @handle_frontend_test( - fn_tree="jax.lax.eq", - dtypes_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + fn_tree="jax.lax.bitwise_xor", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_eq( +def test_jax_bitwise_xor( *, - dtypes_and_xs, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_xs + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], + y=x[1], ) @handle_frontend_test( - fn_tree="jax.lax.mul", - dtypes_and_xs=helpers.dtype_and_values( + fn_tree="jax.lax.broadcast", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - small_abs_safety_factor=2, - large_abs_safety_factor=2, - safety_factor_scale="log", ), + sizes=helpers.get_shape(min_num_dims=1), test_with_out=st.just(False), ) -def test_jax_mul( +def test_jax_broadcast( *, - dtypes_and_xs, + dtype_and_x, + sizes, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_xs + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + operand=x[0], + sizes=sizes, ) -# atan +# cbrt @handle_frontend_test( - fn_tree="jax.lax.atan", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.lax.cbrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), min_value=0.0 + ), test_with_out=st.just(False), ) -def test_jax_atan( +def test_jax_cbrt( *, dtype_and_x, on_device, @@ -634,95 +1086,84 @@ def test_jax_ceil( ) -# bitwise_and @handle_frontend_test( - fn_tree="jax.lax.bitwise_and", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, - ), + fn_tree="jax.lax.clamp", + dtype_x_min_max=_get_clamp_inputs(), test_with_out=st.just(False), ) -def test_jax_bitwise_and( +def test_jax_clamp( *, - dtype_and_x, + dtype_x_min_max, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, x, min_vals, max_vals = dtype_x_min_max helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, 
backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + min=min_vals, x=x[0], - y=x[1], + max=max_vals, ) -# bitwise_or +# concat @handle_frontend_test( - fn_tree="jax.lax.bitwise_or", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, - ), + fn_tree="jax.lax.concatenate", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), test_with_out=st.just(False), ) -def test_jax_bitwise_or( +def test_jax_concat( *, - dtype_and_x, + xs_n_input_dtypes_n_unique_idx, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + operands=xs, + dimension=unique_idx, ) -# bitwise_not +# conj @handle_frontend_test( - fn_tree="jax.lax.bitwise_not", + fn_tree="jax.lax.conj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=1, + available_dtypes=["complex64"], ), - test_with_out=st.just(False), ) -def test_jax_bitwise_not( +def test_jax_conj( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, - test_flags, backend_fw, ): input_dtype, x = dtype_and_x - helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, x=x[0], @@ -730,119 +1171,124 @@ def test_jax_bitwise_not( @handle_frontend_test( - fn_tree="jax.lax.neg", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), - ), + fn_tree="jax.lax.conv", + x_f_d_other=_x_and_filters(), test_with_out=st.just(False), ) -def test_jax_neg( +def test_jax_conv( *, - dtype_and_x, + x_f_d_other, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, x, filters, dilation, dim_num, stride, pad, fc, pref = x_f_d_other + _assume_tf_dilation_gt_1(backend_fw, on_device, dilation) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + lhs=x, + rhs=filters, + window_strides=stride, + padding=pad, + precision=None, + preferred_element_type=pref, ) @handle_frontend_test( - fn_tree="jax.lax.argmax", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=1, - valid_axis=True, - force_int_axis=True, - allow_neg_axes=False, - ), - index_dtype=helpers.get_dtypes("integer", full=False), + fn_tree="jax.lax.conv_general_dilated", + x_f_d_other=_x_and_filters(general=True), test_with_out=st.just(False), ) -def test_jax_argmax( +def test_jax_conv_general_dilated( *, - dtype_x_axis, - index_dtype, + x_f_d_other, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_x_axis + dtype, x, filters, dilations, dim_num, stride, pad, fc, pref = x_f_d_other + _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations[0]) + assume( + not (isinstance(pad, str) and not len(dilations[1]) == dilations[1].count(1)) + ) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, 
test_flags=test_flags, fn_tree=fn_tree, - operand=x[0], - axis=axis, - index_dtype=index_dtype[0], - ) - - -@handle_frontend_test( - fn_tree="jax.lax.argmin", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=1, - valid_axis=True, - force_int_axis=True, - allow_neg_axes=False, - ), - index_dtype=helpers.get_dtypes("integer", full=False), + on_device=on_device, + lhs=x, + rhs=filters, + window_strides=stride, + padding=pad, + lhs_dilation=dilations[1], + rhs_dilation=dilations[0], + dimension_numbers=dim_num, + feature_group_count=fc, + batch_group_count=1, + precision=None, + preferred_element_type=pref, + ) + + +@handle_frontend_test( + fn_tree="jax.lax.conv_transpose", + x_f_d_other=_x_and_filters(general=True, transpose=True), test_with_out=st.just(False), ) -def test_jax_argmin( +def test_jax_conv_transpose( *, - dtype_x_axis, - index_dtype, + x_f_d_other, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_x_axis + dtype, x, filters, dilation, dim_num, stride, pad, fc, pref = x_f_d_other + _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilation) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - operand=x[0], - axis=axis, - index_dtype=index_dtype[0], + lhs=x, + rhs=filters, + strides=stride, + padding=pad, + rhs_dilation=dilation, + dimension_numbers=dim_num, + transpose_kernel=False, + precision=None, + preferred_element_type=pref, ) -# bitwise_xor @handle_frontend_test( - fn_tree="jax.lax.bitwise_xor", + fn_tree="jax.lax.convert_element_type", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), ), + new_dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_jax_bitwise_xor( +def test_jax_convert_element_type( *, dtype_and_x, + new_dtype, on_device, fn_tree, frontend, @@ -850,34 +1296,28 @@ def test_jax_bitwise_xor( backend_fw, ): input_dtype, x = dtype_and_x + assume(can_cast(input_dtype[0], new_dtype[0])) helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=input_dtype + new_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + operand=x[0], + new_dtype=new_dtype[0], ) +# cos @handle_frontend_test( - fn_tree="jax.lax.full_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric", full=False, key="dtype") - ), - fill_val=_fill_value(), - shape=st.one_of(helpers.get_shape() | st.none()), - dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"), + fn_tree="jax.lax.cos", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_full_like( +def test_jax_cos( *, dtype_and_x, - fill_val, - shape, - dtype, on_device, fn_tree, frontend, @@ -885,7 +1325,6 @@ def test_jax_full_like( backend_fw, ): input_dtype, x = dtype_and_x - fill_val = fill_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -894,20 +1333,16 @@ def test_jax_full_like( fn_tree=fn_tree, on_device=on_device, x=x[0], - fill_value=fill_val, - dtype=dtype, - shape=shape, ) +# cosh 
@handle_frontend_test( - fn_tree="jax.lax.exp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="jax.lax.cosh", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_exp( +def test_jax_cosh( *, dtype_and_x, on_device, @@ -928,35 +1363,44 @@ def test_jax_exp( ) +# cummin @handle_frontend_test( - fn_tree="jax.lax.convert_element_type", - dtype_and_x=helpers.dtype_and_values( + fn_tree="jax.lax.cummin", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, ), - new_dtype=helpers.get_dtypes("valid", full=False), + reverse=st.booleans(), test_with_out=st.just(False), ) -def test_jax_convert_element_type( +def test_jax_cummin( *, - dtype_and_x, - new_dtype, + dtype_x_axis, + reverse, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x - assume(can_cast(input_dtype[0], new_dtype[0])) + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype + new_dtype, - frontend=frontend, + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-2, + atol=1e-2, operand=x[0], - new_dtype=new_dtype[0], + axis=axis, + reverse=reverse, ) @@ -1041,15 +1485,11 @@ def test_jax_cumsum( @handle_frontend_test( - fn_tree="jax.lax.ge", - dtypes_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), + fn_tree="jax.lax.div", + dtypes_and_xs=_div_dtypes_and_xs(), test_with_out=st.just(False), ) -def test_jax_ge( +def test_jax_div( *, dtypes_and_xs, on_device, @@ -1059,6 +1499,7 @@ def test_jax_ge( backend_fw, ): input_dtypes, xs = dtypes_and_xs + assume(not np.any(np.isclose(xs[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -1071,191 +1512,69 @@ def test_jax_ge( ) -@st.composite -def _reshape_helper(draw): - # generate a shape s.t len(shape) > 0 - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ) - ) - - reshape_shape = draw(helpers.reshape_shapes(shape=shape)) - - dtypes, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=shape, - ) - ) - is_dim = draw(st.booleans()) - if is_dim: - dims = [x for x in range(len(shape))] - permut = draw(st.permutations(dims)) - return x, dtypes, reshape_shape, permut - else: - return x, dtypes, reshape_shape, None - - -@handle_frontend_test( - fn_tree="jax.lax.reshape", - x_reshape_permut=_reshape_helper(), - test_with_out=st.just(False), -) -def test_jax_reshape( - *, - x_reshape_permut, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - x, dtype, shape, dimensions = x_reshape_permut - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - operand=x[0], - new_sizes=shape, - dimensions=dimensions, - ) - - -@handle_frontend_test( - fn_tree="jax.lax.broadcast", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ), - sizes=helpers.get_shape(min_num_dims=1), - test_with_out=st.just(False), -) -def test_jax_broadcast( - *, - 
dtype_and_x, - sizes, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - operand=x[0], - sizes=sizes, - ) - - -@handle_frontend_test( - fn_tree="jax.lax.reciprocal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), - test_with_out=st.just(False), -) -def test_jax_reciprocal( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - @handle_frontend_test( - fn_tree="jax.lax.sort", - dtype_x_bounded_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), - is_stable=st.booleans(), + fn_tree="jax.lax.dot", + dtypes_and_xs=_get_dtype_inputs_for_dot(), test_with_out=st.just(False), ) -def test_jax_sort( +def test_jax_dot( *, - dtype_x_bounded_axis, - is_stable, + dtypes_and_xs, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_x_bounded_axis + input_dtypes, dtype, lhs, rhs = dtypes_and_xs helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - operand=x[0], - dimension=axis, - is_stable=is_stable, + rtol=1e-2, + atol=1e-2, + lhs=lhs, + rhs=rhs, + precision=None, + preferred_element_type=dtype, ) @handle_frontend_test( - fn_tree="jax.lax.le", - dtypes_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), + fn_tree="jax.lax.dot_general", + dtypes_lr_dims=_general_dot_helper(), test_with_out=st.just(False), ) -def test_jax_le( +def test_jax_dot_general( *, - dtypes_and_xs, + dtypes_lr_dims, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_xs + dtypes, lr, dims, dtype = dtypes_lr_dims helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + lhs=lr[0], + rhs=lr[1], + dimension_numbers=dims, + precision=None, + preferred_element_type=dtype, ) @handle_frontend_test( - fn_tree="jax.lax.ne", + fn_tree="jax.lax.eq", dtypes_and_xs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -1263,7 +1582,7 @@ def test_jax_le( ), test_with_out=st.just(False), ) -def test_jax_ne( +def test_jax_eq( *, dtypes_and_xs, on_device, @@ -1285,13 +1604,12 @@ def test_jax_ne( ) -# cosh @handle_frontend_test( - fn_tree="jax.lax.cosh", + fn_tree="jax.lax.erf", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_cosh( +def test_jax_erf( *, dtype_and_x, on_device, @@ -1308,52 +1626,51 @@ def test_jax_cosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-2, + atol=1e-2, x=x[0], ) +# erfc @handle_frontend_test( - fn_tree="jax.lax.add", - dtypes_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - 
num_arrays=2, - shared_dtype=True, - ), + fn_tree="jax.lax.erfc", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), test_with_out=st.just(False), ) -def test_jax_lt( +def test_jax_erfc( *, - dtypes_and_xs, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_xs + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + rtol=1e-2, + atol=1e-2, + x=x[0], ) -# round @handle_frontend_test( - fn_tree="jax.lax.round", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - rounding_method=st.sampled_from([0, 1]), + fn_tree="jax.lax.exp", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_jax_round( +def test_jax_exp( *, dtype_and_x, - rounding_method, on_device, fn_tree, frontend, @@ -1369,138 +1686,125 @@ def test_jax_round( fn_tree=fn_tree, on_device=on_device, x=x[0], - rounding_method=rounding_method, ) +# expand_dims @handle_frontend_test( - fn_tree="jax.lax.pow", - dtypes_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + fn_tree="jax.lax.expand_dims", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + force_int_axis=True, + valid_axis=True, ), test_with_out=st.just(False), ) -def test_jax_pow( +def test_jax_expand_dims( *, - dtypes_and_values, + dtype_x_axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_values + x_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, + input_dtypes=x_dtype, frontend=frontend, + bakcend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], - ) - - -@st.composite -def _pad_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), - ret_shape=True, - min_num_dims=1, - min_dim_size=2, - min_value=-100, - max_value=100, - ).filter(lambda _x: _x[0][0] not in ["float16", "bfloat16"]) - ) - ndim = len(shape) - min_dim = min(shape) - padding_config = draw( - st.lists( - st.tuples( - st.integers(min_value=-(min_dim - 1), max_value=min_dim - 1), - st.integers(min_value=-(min_dim - 1), max_value=min_dim - 1), - st.integers(min_value=0, max_value=min_dim - 1), - ), - min_size=ndim, - max_size=ndim, - ) + array=x[0], + dimensions=(axis,), ) - padding_value = draw(st.booleans()) - return dtype, x[0], padding_value, padding_config @handle_frontend_test( - fn_tree="jax.lax.pad", - dtype_x_params=_pad_helper(), + fn_tree="jax.lax.expm1", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_jax_pad( +def test_jax_expm1( *, - dtype_x_params, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, operand, padding_value, padding_config = dtype_x_params + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - 
operand=operand, - padding_value=padding_value, - padding_config=padding_config, + x=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.gt", - dtypes_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + fn_tree="jax.lax.full", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, ), - test_with_out=st.just(False), + fill_value=_fill_value(), + dtypes=helpers.get_dtypes("numeric", full=False, key="dtype"), ) -def test_jax_gt( +def test_jax_full( *, - dtypes_and_xs, + shape, + fill_value, + dtypes, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_xs helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + shape=shape, + fill_value=fill_value, + dtype=dtypes[0], ) -# cos @handle_frontend_test( - fn_tree="jax.lax.cos", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.lax.full_like", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric", full=False, key="dtype") + ), + fill_val=_fill_value(), + shape=st.one_of(helpers.get_shape() | st.none()), + dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"), test_with_out=st.just(False), ) -def test_jax_cos( +def test_jax_full_like( *, dtype_and_x, + fill_val, + shape, + dtype, on_device, fn_tree, frontend, @@ -1508,6 +1812,7 @@ def test_jax_cos( backend_fw, ): input_dtype, x = dtype_and_x + fill_val = fill_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1516,47 +1821,62 @@ def test_jax_cos( fn_tree=fn_tree, on_device=on_device, x=x[0], + fill_value=fill_val, + dtype=dtype, + shape=shape, ) -@st.composite -def _get_clamp_inputs(draw): - shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ) - ) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=shape, - ) - ) - - min = draw( - helpers.array_values(dtype=x_dtype[0], shape=shape, min_value=-25, max_value=0) - ) - max = draw( - helpers.array_values(dtype=x_dtype[0], shape=shape, min_value=1, max_value=25) +@handle_frontend_test( + fn_tree="jax.lax.ge", + dtypes_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), + test_with_out=st.just(False), +) +def test_jax_ge( + *, + dtypes_and_xs, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, xs = dtypes_and_xs + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=xs[0], + y=xs[1], ) - return x_dtype, x, min, max @handle_frontend_test( - fn_tree="jax.lax.clamp", - dtype_x_min_max=_get_clamp_inputs(), + fn_tree="jax.lax.gt", + dtypes_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_jax_clamp( +def test_jax_gt( *, - dtype_x_min_max, + dtypes_and_xs, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, x, min_vals, max_vals = dtype_x_min_max + input_dtypes, xs = dtypes_and_xs helpers.test_frontend_function( 
input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -1564,21 +1884,19 @@ def test_jax_clamp( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - min=min_vals, - x=x[0], - max=max_vals, + x=xs[0], + y=xs[1], ) +# imag @handle_frontend_test( - fn_tree="jax.lax.log", + fn_tree="jax.lax.imag", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1, + available_dtypes=helpers.get_dtypes("complex") ), - test_with_out=st.just(False), ) -def test_jax_log( +def test_jax_imag( *, dtype_and_x, on_device, @@ -1595,122 +1913,107 @@ def test_jax_log( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + test_values=True, x=x[0], ) +# iota @handle_frontend_test( - fn_tree="jax.lax.rev", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=2, - valid_axis=True, - force_int_axis=True, - allow_neg_axes=False, - ), + fn_tree="jax.lax.iota", + dtypes=helpers.get_dtypes("valid", full=False), + size=helpers.ints(min_value=0, max_value=10), test_with_out=st.just(False), ) -def test_jax_rev( +def test_jax_iota( *, - dtype_x_axis, + dtypes, + size, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - operand=x[0], - dimensions=(axis,), - ) - - -@st.composite -def _div_dtypes_and_xs(draw): - dtype, dividend, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), ret_shape=True - ) - ) - divisor = draw( - helpers.array_values(dtype=dtype[0], min_value=-20, max_value=20, shape=shape) + dtype=dtypes[0], + size=size, ) - return dtype, [dividend[0], divisor] +# is_finite @handle_frontend_test( - fn_tree="jax.lax.div", - dtypes_and_xs=_div_dtypes_and_xs(), + fn_tree="jax.lax.is_finite", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_div( +def test_jax_is_finite( *, - dtypes_and_xs, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, xs = dtypes_and_xs - assume(not np.any(np.isclose(xs[1], 0))) + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.rsqrt", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="jax.lax.le", + dtypes_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_rsqrt( +def test_jax_le( *, - dtype_and_x, + dtypes_and_xs, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, xs = dtypes_and_xs helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - x=x[0], + x=xs[0], + y=xs[1], ) @handle_frontend_test( - fn_tree="jax.lax.expm1", + fn_tree="jax.lax.log", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=1, ), 
test_with_out=st.just(False), ) -def test_jax_expm1( +def test_jax_log( *, dtype_and_x, on_device, @@ -1762,42 +2065,58 @@ def test_jax_log1p( ) -@st.composite -def _dtype_values_dims(draw): - dtype, values, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - ret_shape=True, - ) - ) - size = len(shape) - permutations = draw( - st.lists( - st.integers(min_value=0, max_value=len(shape) - 1), - min_size=size, - max_size=size, - unique=True, - ) +@handle_frontend_test( + fn_tree="jax.lax.add", + dtypes_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), + test_with_out=st.just(False), +) +def test_jax_lt( + *, + dtypes_and_xs, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, xs = dtypes_and_xs + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=xs[0], + y=xs[1], ) - return dtype, values, tuple(permutations) +# max @handle_frontend_test( - fn_tree="jax.lax.transpose", - dtype_x_dims=_dtype_values_dims(), + fn_tree="jax.lax.max", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_jax_transpose( +def test_jax_max( *, - dtype_x_dims, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, dims = dtype_x_dims + input_dtype, x = dtype_and_x + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1805,56 +2124,21 @@ def test_jax_transpose( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - operand=x[0], - permutation=dims, + x=x[0], + y=x[1], ) -@st.composite -def _get_dtype_inputs_for_dot(draw): - dim_size = draw(helpers.ints(min_value=1, max_value=5)) - dtype = draw(helpers.get_dtypes("numeric", index=1, full=False)) - if dim_size == 1: - lhs = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) - rhs = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) - else: - lhs = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - rhs = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - is_pref = draw(st.booleans()) - if is_pref: - dtype, values, pref = draw( - helpers.get_castable_dtype( - draw(helpers.get_dtypes("numeric")), dtype[0], [lhs, rhs] - ) - ) - assume(can_cast(dtype, pref)) - return [dtype], pref, values[0], values[1] - else: - return dtype, None, lhs, rhs - - @handle_frontend_test( - fn_tree="jax.lax.dot", - dtypes_and_xs=_get_dtype_inputs_for_dot(), + fn_tree="jax.lax.min", + dtypes_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_jax_dot( +def test_jax_min( *, dtypes_and_xs, on_device, @@ -1863,7 +2147,7 @@ def test_jax_dot( test_flags, backend_fw, ): - input_dtypes, dtype, lhs, rhs = dtypes_and_xs + input_dtypes, xs = dtypes_and_xs helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -1871,47 +2155,24 @@ def test_jax_dot( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - lhs=lhs, - rhs=rhs, - 
precision=None, - preferred_element_type=dtype, - ) - - -@st.composite -def _get_dtype_inputs_for_batch_matmul(draw): - dtype, lhs = draw( - helpers.dtype_and_values( - min_num_dims=2, - max_num_dims=6, - min_value=2, - max_value=5, - ) - ) - lhs_shape = lhs[0].shape - rhs_shape = list(lhs_shape) - rhs_shape[-1], rhs_shape[-2] = rhs_shape[-2], rhs_shape[-1] - rhs_shape = tuple(rhs_shape) - rhs = draw( - helpers.array_values( - dtype=dtype[0], - shape=rhs_shape, - min_value=2, - max_value=5, - ) + x=xs[0], + y=xs[1], ) - return dtype, lhs[0], rhs - @handle_frontend_test( - fn_tree="jax.lax.batch_matmul", - dtypes_and_xs=_get_dtype_inputs_for_batch_matmul(), + fn_tree="jax.lax.mul", + dtypes_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + small_abs_safety_factor=2, + large_abs_safety_factor=2, + safety_factor_scale="log", + ), test_with_out=st.just(False), ) -def test_jax_batch_matmul( +def test_jax_mul( *, dtypes_and_xs, on_device, @@ -1920,357 +2181,212 @@ def test_jax_batch_matmul( test_flags, backend_fw, ): - input_dtypes, lhs, rhs = dtypes_and_xs + input_dtypes, xs = dtypes_and_xs helpers.test_frontend_function( input_dtypes=input_dtypes, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - lhs=lhs, - rhs=rhs, - precision=None, - ) - - -@st.composite -def _general_dot_helper(draw): - input_dtype, lhs, lshape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-1e04, - max_value=1e04, - min_num_dims=2, - ret_shape=True, - ) - ) - ndims = len(lshape) - perm_id = random.sample(list(range(ndims)), ndims) - rshape = [lshape[i] for i in perm_id] - input_dtype, rhs = draw( - helpers.dtype_and_values( - dtype=input_dtype, - min_value=-1e04, - max_value=1e04, - shape=rshape, - ) - ) - ind_list = list(range(ndims)) - batch_n = draw(st.integers(min_value=1, max_value=len(lshape) - 1)) - lhs_batch = random.sample(ind_list, batch_n) - rhs_batch = [perm_id.index(i) for i in lhs_batch] - lhs_contracting = [i for i in ind_list if i not in lhs_batch] - rhs_contracting = [perm_id.index(i) for i in lhs_contracting] - is_pref = draw(st.booleans()) - pref_dtype = None - if is_pref: - uint_cast_st = helpers.get_castable_dtype( - draw(helpers.get_dtypes("unsigned")), - input_dtype[0], - ) - int_cast_st = helpers.get_castable_dtype( - draw(helpers.get_dtypes("signed_integer")), - input_dtype[0], - ) - float_cast_st = helpers.get_castable_dtype( - draw(helpers.get_dtypes("float")), - input_dtype[0], - ) - complex_cast_st = helpers.get_castable_dtype( - draw(helpers.get_dtypes("complex")), - input_dtype[0], - ) - if "uint" in input_dtype[0]: - pref_dtype = draw(st.one_of(uint_cast_st, float_cast_st))[-1] - elif "int" in input_dtype[0]: - pref_dtype = draw(st.one_of(int_cast_st, float_cast_st))[-1] - elif "float" in input_dtype[0]: - pref_dtype = draw(float_cast_st)[-1] - elif "complex" in input_dtype[0]: - pref_dtype = draw(complex_cast_st)[-1] - else: - raise ivy.exceptions.IvyException("unsupported dtype") - return ( - input_dtype * 2, - (lhs[0], rhs[0]), - ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)), - pref_dtype, + x=xs[0], + y=xs[1], ) @handle_frontend_test( - fn_tree="jax.lax.dot_general", - dtypes_lr_dims=_general_dot_helper(), + fn_tree="jax.lax.ne", + dtypes_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + 
shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_jax_dot_general( +def test_jax_ne( *, - dtypes_lr_dims, + dtypes_and_xs, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtypes, lr, dims, dtype = dtypes_lr_dims + input_dtypes, xs = dtypes_and_xs helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - lhs=lr[0], - rhs=lr[1], - dimension_numbers=dims, - precision=None, - preferred_element_type=dtype, + x=xs[0], + y=xs[1], ) -@st.composite -def _x_and_filters(draw, dim=2, transpose=False, general=False): - if not isinstance(dim, int): - dim = draw(dim) - batch_size = draw(st.integers(1, 5)) - filter_shape = draw( - helpers.get_shape( - min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5 - ) - ) - dtype = draw(helpers.get_dtypes("float", full=False)) - padding = draw( - st.one_of( - st.lists( - st.tuples( - st.integers(min_value=0, max_value=3), - st.integers(min_value=0, max_value=3), - ), - min_size=dim, - max_size=dim, - ), - st.sampled_from(["SAME", "VALID"]), - ) - ) - input_channels = draw(st.integers(1, 3)) - output_channels = draw(st.integers(1, 3)) - group_list = [i for i in range(1, 6)] - if not transpose: - group_list = list(filter(lambda x: (input_channels % x == 0), group_list)) - else: - group_list = list(filter(lambda x: (output_channels % x == 0), group_list)) - fc = draw(st.sampled_from(group_list)) if general else 1 - strides = draw(st.lists(st.integers(1, 3), min_size=dim, max_size=dim)) - dilations = draw(st.lists(st.integers(1, 3), min_size=dim, max_size=dim)) - if general: - if dim == 2: - dim_num_st1 = st.sampled_from(["NCHW", "NHWC"]) - dim_num_st2 = st.sampled_from(["OIHW", "HWIO"]) - elif dim == 1: - dim_num_st1 = st.sampled_from(["NWC", "NCW"]) - dim_num_st2 = st.sampled_from(["OIW", "WIO"]) - else: - dim_num_st1 = st.sampled_from(["NDHWC", "NCDHW"]) - dim_num_st2 = st.sampled_from(["OIDHW", "DHWIO"]) - dim_seq = [*range(0, dim + 2)] - dimension_numbers = draw( - st.sampled_from( - [ - None, - (draw(dim_num_st1), draw(dim_num_st2), draw(dim_num_st1)), - ConvDimensionNumbers( - *map( - tuple, - draw( - st.lists( - st.permutations(dim_seq), min_size=3, max_size=3 - ) - ), - ) - ), - ] - ) - ) - else: - dimension_numbers = ( - ("NCH", "OIH", "NCH") - if dim == 1 - else ("NCHW", "OIHW", "NCHW") if dim == 2 else ("NCDHW", "OIDHW", "NCDHW") - ) - dim_nums = _dimension_numbers(dimension_numbers, dim + 2, transp=transpose) - if not transpose: - output_channels = output_channels * fc - channel_shape = (output_channels, input_channels // fc) - else: - input_channels = input_channels * fc - channel_shape = (output_channels // fc, input_channels) - x_dim = [] - for i in range(dim): - min_x = filter_shape[i] + (filter_shape[i] - 1) * (dilations[i] - 1) - x_dim.append(draw(st.integers(min_x, min_x + 1))) - x_shape = (batch_size, input_channels, *x_dim) - filter_shape = channel_shape + filter_shape - vals = draw( - helpers.array_values( - dtype=dtype[0], - shape=x_shape, - min_value=0.0, - max_value=1.0, - ) - ) - vals = ivy.permute_dims(vals, axes=_argsort_tuple(dim_nums[0])) - filters = draw( - helpers.array_values( - dtype=dtype[0], - shape=filter_shape, - min_value=0.0, - max_value=1.0, - ) - ) - filters = ivy.permute_dims(filters, axes=_argsort_tuple(dim_nums[1])) - if general and not transpose: - x_dilation = draw(st.lists(st.integers(1, 3), min_size=dim, max_size=dim)) - dilations = (dilations, 
x_dilation) - if draw(st.booleans()): - p_dtype, pref = draw( - helpers.get_castable_dtype(draw(helpers.get_dtypes("float")), dtype[0]) - ) - assume(can_cast(p_dtype, pref)) - else: - pref = None - return ( - dtype, - vals, - filters, - dilations, - dimension_numbers, - strides, - padding, - fc, - pref, +@handle_frontend_test( + fn_tree="jax.lax.neg", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("signed_integer"), + ), + test_with_out=st.just(False), +) +def test_jax_neg( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], ) +# nextafter @handle_frontend_test( - fn_tree="jax.lax.conv", - x_f_d_other=_x_and_filters(), + fn_tree="jax.lax.nextafter", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float32", "float64"], + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_jax_conv( +def test_jax_nextafter( *, - x_f_d_other, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, filters, dilation, dim_num, stride, pad, fc, pref = x_f_d_other - _assume_tf_dilation_gt_1(backend_fw, on_device, dilation) + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - lhs=x, - rhs=filters, - window_strides=stride, - padding=pad, - precision=None, - preferred_element_type=pref, + x1=x[0], + x2=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.conv_transpose", - x_f_d_other=_x_and_filters(general=True, transpose=True), + fn_tree="jax.lax.pad", + dtype_x_params=_pad_helper(), test_with_out=st.just(False), ) -def test_jax_conv_transpose( +def test_jax_pad( *, - x_f_d_other, + dtype_x_params, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, filters, dilation, dim_num, stride, pad, fc, pref = x_f_d_other - _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilation) + dtype, operand, padding_value, padding_config = dtype_x_params helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - lhs=x, - rhs=filters, - strides=stride, - padding=pad, - rhs_dilation=dilation, - dimension_numbers=dim_num, - transpose_kernel=False, - precision=None, - preferred_element_type=pref, + operand=operand, + padding_value=padding_value, + padding_config=padding_config, ) @handle_frontend_test( - fn_tree="jax.lax.conv_general_dilated", - x_f_d_other=_x_and_filters(general=True), + fn_tree="jax.lax.pow", + dtypes_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_jax_conv_general_dilated( +def test_jax_pow( *, - x_f_d_other, + dtypes_and_values, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, filters, dilations, dim_num, stride, pad, fc, pref = x_f_d_other - _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations[0]) - assume( - not (isinstance(pad, str) and not len(dilations[1]) == 
dilations[1].count(1)) + input_dtypes, xs = dtypes_and_values + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=xs[0], + y=xs[1], ) + + +# real +@handle_frontend_test( + fn_tree="jax.lax.real", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("complex") + ), +) +def test_jax_real( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - lhs=x, - rhs=filters, - window_strides=stride, - padding=pad, - lhs_dilation=dilations[1], - rhs_dilation=dilations[0], - dimension_numbers=dim_num, - feature_group_count=fc, - batch_group_count=1, - precision=None, - preferred_element_type=pref, + test_values=True, + x=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.sub", + fn_tree="jax.lax.reciprocal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_jax_sub( +def test_jax_reciprocal( *, dtype_and_x, on_device, @@ -2288,7 +2404,39 @@ def test_jax_sub( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], + ) + + +@handle_frontend_test( + fn_tree="jax.lax.reduce_window", + all_args=_reduce_window_helper(_get_reduce_func), + test_with_out=st.just(False), +) +def test_jax_reduce_window( + *, + all_args, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, operand, init_value, computation, others, padding = all_args + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + operand=operand[0], + init_value=init_value[0], + computation=computation, + window_dimensions=others[0], + window_strides=others[1], + padding=padding, + base_dilation=others[2], + window_dilation=None, ) @@ -2327,16 +2475,105 @@ def test_jax_rem( @handle_frontend_test( - fn_tree="jax.lax.square", - dtype_and_x=helpers.dtype_and_values( + fn_tree="jax.lax.reshape", + x_reshape_permut=_reshape_helper(), + test_with_out=st.just(False), +) +def test_jax_reshape( + *, + x_reshape_permut, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + x, dtype, shape, dimensions = x_reshape_permut + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + operand=x[0], + new_sizes=shape, + dimensions=dimensions, + ) + + +@handle_frontend_test( + fn_tree="jax.lax.rev", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - small_abs_safety_factor=2, - large_abs_safety_factor=2, - safety_factor_scale="log", + min_num_dims=1, + min_dim_size=2, + valid_axis=True, + force_int_axis=True, + allow_neg_axes=False, + ), + test_with_out=st.just(False), +) +def test_jax_rev( + *, + dtype_x_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + 
operand=x[0], + dimensions=(axis,), + ) + + +# round +@handle_frontend_test( + fn_tree="jax.lax.round", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + rounding_method=st.sampled_from([0, 1]), + test_with_out=st.just(False), +) +def test_jax_round( + *, + dtype_and_x, + rounding_method, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + rounding_method=rounding_method, + ) + + +@handle_frontend_test( + fn_tree="jax.lax.rsqrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_jax_square( +def test_jax_rsqrt( *, dtype_and_x, on_device, @@ -2353,35 +2590,36 @@ def test_jax_square( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-02, x=x[0], ) @handle_frontend_test( - fn_tree="jax.lax.erf", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.lax.select", + dtype_pred_ontrue_on_false=_dtype_pred_ontrue_on_false(), test_with_out=st.just(False), ) -def test_jax_erf( +def test_jax_select( *, - dtype_and_x, + dtype_pred_ontrue_on_false, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, pred, on_true_on_false = dtype_pred_ontrue_on_false helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=["bool"] + input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], + pred=pred, + on_true=on_true_on_false[0], + on_false=on_true_on_false[0], ) @@ -2460,162 +2698,15 @@ def test_jax_shift_right_logical( ) -@st.composite -def _slice_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - ret_shape=True, - ), - ) - start_indices, limit_indices, strides = [], [], [] - for i in shape: - start_indices += [draw(st.integers(min_value=0, max_value=i - 1))] - limit_indices += [ - draw( - st.integers(min_value=0, max_value=i - 1).filter( - lambda _x: _x > start_indices[-1] - ) - ) - ] - strides += [draw(st.integers(min_value=1, max_value=i))] - return dtype, x, start_indices, limit_indices, strides - - -@handle_frontend_test( - fn_tree="jax.lax.slice", - dtype_x_params=_slice_helper(), - test_with_out=st.just(False), -) -def test_jax_slice( - *, - dtype_x_params, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x, start_indices, limit_indices, strides = dtype_x_params - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - operand=x[0], - start_indices=start_indices, - limit_indices=limit_indices, - strides=strides, - ) - - -@st.composite -def _slice_in_dim_helper(draw): - dtype, x, axis = draw( - helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - force_int_axis=True, - valid_axis=True, - ), - ) - operand = x[0] - start_index = draw( - st.integers(min_value=-abs(operand.shape[axis]), max_value=operand.shape[axis]) - ) - if start_index < 0: - limit_index = draw( - st.integers( - min_value=start_index + operand.shape[axis], - 
max_value=operand.shape[axis], - ) - ) - else: - limit_index = draw( - st.integers( - min_value=-abs(operand.shape[axis]), max_value=operand.shape[axis] - ).filter(lambda _x: _x >= start_index) - ) - stride = draw(st.integers(min_value=1, max_value=abs(limit_index + 1))) - return dtype, x, start_index, limit_index, stride, axis - - -@handle_frontend_test( - fn_tree="jax.lax.slice_in_dim", - dtype_x_params=_slice_in_dim_helper(), - test_with_out=st.just(False), -) -def test_jax_slice_in_dim( - *, - dtype_x_params, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x, start_index, limit_index, stride, axis = dtype_x_params - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - operand=x[0], - start_index=start_index, - limit_index=limit_index, - stride=stride, - axis=axis, - ) - - -# expand_dims +# sign @handle_frontend_test( - fn_tree="jax.lax.expand_dims", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - force_int_axis=True, - valid_axis=True, + fn_tree="jax.lax.sign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") ), test_with_out=st.just(False), ) -def test_jax_expand_dims( - *, - dtype_x_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - x_dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=x_dtype, - frontend=frontend, - bakcend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - array=x[0], - dimensions=(axis,), - ) - - -# asinh -@handle_frontend_test( - fn_tree="jax.lax.asinh", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_jax_asinh( +def test_jax_sign( *, dtype_and_x, on_device, @@ -2636,13 +2727,13 @@ def test_jax_asinh( ) -# atanh +# sin @handle_frontend_test( - fn_tree="jax.lax.atanh", + fn_tree="jax.lax.sin", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_atanh( +def test_jax_sin( *, dtype_and_x, on_device, @@ -2663,200 +2754,115 @@ def test_jax_atanh( ) -# select -@st.composite -def _dtype_pred_ontrue_on_false(draw): - shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) - pred = draw(helpers.array_values(dtype="bool", shape=shape)) - dtypes, on_true_on_false = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shape=shape, - large_abs_safety_factor=16, - small_abs_safety_factor=16, - safety_factor_scale="log", - shared_dtype=True, - ) - ) - return dtypes, pred, on_true_on_false - - -@handle_frontend_test( - fn_tree="jax.lax.select", - dtype_pred_ontrue_on_false=_dtype_pred_ontrue_on_false(), - test_with_out=st.just(False), -) -def test_jax_select( - *, - dtype_pred_ontrue_on_false, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, pred, on_true_on_false = dtype_pred_ontrue_on_false - helpers.test_frontend_function( - input_dtypes=["bool"] + input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - pred=pred, - on_true=on_true_on_false[0], - on_false=on_true_on_false[0], - ) - - -# top_k +# sinh @handle_frontend_test( - fn_tree="jax.lax.top_k", - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - large_abs_safety_factor=8, - small_abs_safety_factor=8, - safety_factor_scale="log", - min_dim_size=4, - max_dim_size=10, - ), - k=helpers.ints(min_value=1, max_value=4), + fn_tree="jax.lax.sinh", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_top_k( +def test_jax_sinh( *, dtype_and_x, - k, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x + helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - operand=x[0], - k=k, - # test_values=False, + x=x[0], ) -def _get_reduce_func(dtype): - if dtype[0] == "bool": - return st.sampled_from([jnp.logical_and, jnp.logical_or]) - else: - return st.sampled_from([jlax.add, jlax.max, jlax.min, jlax.mul, jnp.multiply]) - - @handle_frontend_test( - fn_tree="jax.lax.reduce_window", - all_args=_reduce_window_helper(_get_reduce_func), + fn_tree="jax.lax.slice", + dtype_x_params=_slice_helper(), test_with_out=st.just(False), ) -def test_jax_reduce_window( +def test_jax_slice( *, - all_args, + dtype_x_params, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtypes, operand, init_value, computation, others, padding = all_args + dtype, x, start_indices, limit_indices, strides = dtype_x_params helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - operand=operand[0], - init_value=init_value[0], - computation=computation, - window_dimensions=others[0], - window_strides=others[1], - padding=padding, - base_dilation=others[2], - window_dilation=None, + operand=x[0], + start_indices=start_indices, + limit_indices=limit_indices, + strides=strides, ) -# real @handle_frontend_test( - fn_tree="jax.lax.real", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("complex") - ), + fn_tree="jax.lax.slice_in_dim", + dtype_x_params=_slice_in_dim_helper(), + test_with_out=st.just(False), ) -def test_jax_real( +def test_jax_slice_in_dim( *, - dtype_and_x, + dtype_x_params, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, x, start_index, limit_index, stride, axis = dtype_x_params helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=True, - x=x[0], + operand=x[0], + start_index=start_index, + limit_index=limit_index, + stride=stride, + axis=axis, ) -# squeeze -@st.composite -def _squeeze_helper(draw): - shape = draw(st.shared(helpers.get_shape(), key="value_shape")) - valid_axes = [] - for index, axis in enumerate(shape): - if axis == 1: - valid_axes.append(index) - return valid_axes - - @handle_frontend_test( - fn_tree="jax.lax.squeeze", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared( - helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=10, - min_dim_size=1, - max_dim_size=5, - ), - key="value_shape", - ), + fn_tree="jax.lax.sort", + dtype_x_bounded_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + 
valid_axis=True, + force_int_axis=True, ), - dim=_squeeze_helper(), + is_stable=st.booleans(), + test_with_out=st.just(False), ) -def test_jax_squeeze( +def test_jax_sort( *, - dtype_and_values, - dim, + dtype_x_bounded_axis, + is_stable, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, value = dtype_and_values + input_dtype, x, axis = dtype_x_bounded_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2864,28 +2870,19 @@ def test_jax_squeeze( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - array=value[0], - dimensions=dim, + operand=x[0], + dimension=axis, + is_stable=is_stable, ) -# nextafter +# sqrt @handle_frontend_test( - fn_tree="jax.lax.nextafter", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - num_arrays=2, - shared_dtype=True, - ), + fn_tree="jax.lax.sqrt", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_nextafter( +def test_jax_sqrt( *, dtype_and_x, on_device, @@ -2902,55 +2899,69 @@ def test_jax_nextafter( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[0], + x=x[0], ) -# conj @handle_frontend_test( - fn_tree="jax.lax.conj", + fn_tree="jax.lax.square", dtype_and_x=helpers.dtype_and_values( - available_dtypes=["complex64"], + available_dtypes=helpers.get_dtypes("numeric"), + small_abs_safety_factor=2, + large_abs_safety_factor=2, + safety_factor_scale="log", ), + test_with_out=st.just(False), ) -def test_jax_conj( +def test_jax_square( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, + test_flags, backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], ) -# is_finite @handle_frontend_test( - fn_tree="jax.lax.is_finite", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="jax.lax.squeeze", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared( + helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=10, + min_dim_size=1, + max_dim_size=5, + ), + key="value_shape", + ), + ), + dim=_squeeze_helper(), ) -def test_jax_is_finite( +def test_jax_squeeze( *, - dtype_and_x, + dtype_and_values, + dim, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, value = dtype_and_values helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2958,19 +2969,21 @@ def test_jax_is_finite( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + array=value[0], + dimensions=dim, ) -# cbrt @handle_frontend_test( - fn_tree="jax.lax.cbrt", + fn_tree="jax.lax.sub", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), min_value=0.0 + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_cbrt( +def test_jax_sub( *, dtype_and_x, on_device, @@ -2988,35 +3001,26 @@ def test_jax_cbrt( fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# cummin +# tan @handle_frontend_test( - fn_tree="jax.lax.cummin", - 
dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - valid_axis=True, - allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - ), - reverse=st.booleans(), + fn_tree="jax.lax.tan", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_cummin( +def test_jax_tan( *, - dtype_x_axis, - reverse, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3024,11 +3028,7 @@ def test_jax_cummin( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - operand=x[0], - axis=axis, - reverse=reverse, + x=x[0], ) @@ -3064,59 +3064,67 @@ def test_jax_tie_in( ) -# erfc +# top_k @handle_frontend_test( - fn_tree="jax.lax.erfc", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + fn_tree="jax.lax.top_k", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + large_abs_safety_factor=8, + small_abs_safety_factor=8, + safety_factor_scale="log", + min_dim_size=4, + max_dim_size=10, + ), + k=helpers.ints(min_value=1, max_value=4), test_with_out=st.just(False), ) -def test_jax_erfc( +def test_jax_top_k( *, dtype_and_x, + k, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], + operand=x[0], + k=k, + # test_values=False, ) -# iota @handle_frontend_test( - fn_tree="jax.lax.iota", - dtypes=helpers.get_dtypes("valid", full=False), - size=helpers.ints(min_value=0, max_value=10), + fn_tree="jax.lax.transpose", + dtype_x_dims=_dtype_values_dims(), test_with_out=st.just(False), ) -def test_jax_iota( +def test_jax_transpose( *, - dtypes, - size, + dtype_x_dims, on_device, fn_tree, frontend, test_flags, backend_fw, ): + input_dtype, x, dims = dtype_x_dims helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - dtype=dtypes[0], - size=size, + operand=x[0], + permutation=dims, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index c7d69200a4be9..f5984c3d9083d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -5,49 +5,80 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# Todo : turn on complex dtype activation tests once supported in all backends +# --- Helpers --- # +# --------------- # + + +# one_hot +@st.composite +def _dtype_indices_classes_axis(draw): + classes = draw(helpers.ints(min_value=2, max_value=100)) + dtype, indices, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=0, + max_value=classes - 1, + small_abs_safety_factor=4, + ret_shape=True, + ) + ) + + axis = draw(st.integers(min_value=-1, max_value=len(shape) - 1)) + 
return dtype, indices, classes, axis + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( - fn_tree="jax.nn.relu", + fn_tree="jax.nn.celu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=3, - small_abs_safety_factor=3, + available_dtypes=helpers.get_dtypes("float_and_integer"), + min_value=-5, + max_value=5, safety_factor_scale="linear", ), + alpha=helpers.floats(min_value=0.01, max_value=1), test_with_out=st.just(False), ) -def test_jax_relu( +def test_jax_celu( *, dtype_and_x, + alpha, test_flags, on_device, fn_tree, frontend, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x=xs[0], + alpha=alpha, ) +# elu @handle_frontend_test( - fn_tree="jax.nn.relu6", + fn_tree="jax.nn.elu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_integer"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, + min_value=-5, + max_value=5, safety_factor_scale="linear", + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_relu6( +def test_jax_elu( *, dtype_and_x, test_flags, @@ -56,31 +87,38 @@ def test_jax_relu6( frontend, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x=xs[0], + alpha=xs[1], + rtol=1e-03, + atol=1e-03, ) @handle_frontend_test( - fn_tree="jax.nn.soft_sign", + fn_tree="jax.nn.gelu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, + available_dtypes=helpers.get_dtypes("float_and_complex"), + large_abs_safety_factor=1, + small_abs_safety_factor=1, safety_factor_scale="linear", + min_value=-1e4, + max_value=1e4, ), + approximate=st.booleans(), test_with_out=st.just(False), ) -def test_jax_soft_sign( +def test_jax_gelu( *, dtype_and_x, + approximate, test_flags, on_device, fn_tree, @@ -88,6 +126,9 @@ def test_jax_soft_sign( backend_fw, ): input_dtype, x = dtype_and_x + # As erf function doesn't support complex dtype + if "complex" in str(x[0].dtype): + approximate = True helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -95,23 +136,33 @@ def test_jax_soft_sign( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-02, + atol=1e-02, x=x[0], + approximate=approximate, ) +# glu @handle_frontend_test( - fn_tree="jax.nn.silu", + fn_tree="jax.nn.glu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", + min_value=-2, + min_num_dims=1, + min_dim_size=4, + max_dim_size=4, ), + axis=helpers.ints(min_value=-1, max_value=0), test_with_out=st.just(False), ) -def test_jax_silu( +def test_jax_glu( *, dtype_and_x, + axis, test_flags, on_device, fn_tree, @@ -126,98 +177,84 @@ def test_jax_silu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-01, + atol=1e-01, x=x[0], + axis=axis, ) @handle_frontend_test( - 
fn_tree="jax.nn.leaky_relu", + fn_tree="jax.nn.hard_sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float_and_integer"), large_abs_safety_factor=2, small_abs_safety_factor=2, - safety_factor_scale="linear", ), - negative_slope=helpers.floats(min_value=0.0, max_value=1.0), test_with_out=st.just(False), ) -def test_jax_leaky_relu( +def test_jax_hard_sigmoid( *, dtype_and_x, - negative_slope, test_flags, on_device, fn_tree, frontend, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, - rtol=1e-01, - atol=1e-01, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - negative_slope=negative_slope, + x=xs[0], ) @handle_frontend_test( - fn_tree="jax.nn.gelu", + fn_tree="jax.nn.hard_silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - large_abs_safety_factor=1, - small_abs_safety_factor=1, - safety_factor_scale="linear", - min_value=-1e4, - max_value=1e4, + available_dtypes=helpers.get_dtypes("float_and_integer"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, ), - approximate=st.booleans(), test_with_out=st.just(False), ) -def test_jax_gelu( +def test_jax_hard_silu( *, dtype_and_x, - approximate, test_flags, on_device, fn_tree, frontend, backend_fw, ): - input_dtype, x = dtype_and_x - # As erf function doesn't support complex dtype - if "complex" in str(x[0].dtype): - approximate = True + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - atol=1e-02, - x=x[0], - approximate=approximate, + x=xs[0], ) @handle_frontend_test( - fn_tree="jax.nn.sigmoid", + fn_tree="jax.nn.hard_swish", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, + min_value=-10, + max_value=10, safety_factor_scale="linear", ), test_with_out=st.just(False), ) -def test_jax_sigmoid( +def test_jax_hard_swish( *, dtype_and_x, test_flags, @@ -230,6 +267,8 @@ def test_jax_sigmoid( helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, + rtol=1e-02, + atol=1e-02, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, @@ -238,41 +277,27 @@ def test_jax_sigmoid( ) -# one_hot -@st.composite -def _dtype_indices_classes_axis(draw): - classes = draw(helpers.ints(min_value=2, max_value=100)) - dtype, indices, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=0, - max_value=classes - 1, - small_abs_safety_factor=4, - ret_shape=True, - ) - ) - - axis = draw(st.integers(min_value=-1, max_value=len(shape) - 1)) - return dtype, indices, classes, axis - - @handle_frontend_test( - fn_tree="jax.nn.one_hot", - dtype_indices_classes_axis=_dtype_indices_classes_axis(), - dtype=helpers.get_dtypes("float", full=False), + fn_tree="jax.nn.hard_tanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_integer"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="linear", + ), test_with_out=st.just(False), ) -def test_jax_one_hot( +def test_jax_hard_tanh( *, - dtype_indices_classes_axis, - dtype, 
+ dtype_and_x, test_flags, on_device, fn_tree, frontend, backend_fw, ): - input_dtype, indices, num_classes, axis = dtype_indices_classes_axis + # TODO: enable this test for all valid dtypes as jax.nn.hard_tanh supports + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -280,63 +305,25 @@ def test_jax_one_hot( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - x=indices[0], - num_classes=num_classes, - dtype=dtype[0], - axis=axis, - ) - - -@handle_frontend_test( - fn_tree="jax.nn.softmax", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, - ), - test_with_out=st.just(False), -) -def test_jax_softmax( - *, - dtype_x_axis, - test_flags, - on_device, - fn_tree, - frontend, - backend_fw, -): - x_dtype, x, axis = dtype_x_axis - - helpers.test_frontend_function( - input_dtypes=x_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-02, - atol=1e-02, x=x[0], - axis=axis, ) @handle_frontend_test( - fn_tree="jax.nn.softplus", + fn_tree="jax.nn.leaky_relu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="linear", ), + negative_slope=helpers.floats(min_value=0.0, max_value=1.0), test_with_out=st.just(False), ) -def test_jax_softplus( +def test_jax_leaky_relu( *, dtype_and_x, + negative_slope, test_flags, on_device, fn_tree, @@ -347,11 +334,14 @@ def test_jax_softplus( helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, + rtol=1e-01, + atol=1e-01, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + negative_slope=negative_slope, ) @@ -428,44 +418,46 @@ def test_jax_log_softmax( ) -# glu @handle_frontend_test( - fn_tree="jax.nn.glu", + fn_tree="jax.nn.logsumexp", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_integer"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", - min_value=-2, - min_num_dims=1, - min_dim_size=4, - max_dim_size=4, + num_arrays=2, + shared_dtype=True, ), - axis=helpers.ints(min_value=-1, max_value=0), + axis=st.just(None), + keepdims=st.booleans(), + return_sign=st.booleans(), test_with_out=st.just(False), ) -def test_jax_glu( +def test_jax_logsumexp( *, dtype_and_x, axis, + keepdims, + return_sign, test_flags, on_device, fn_tree, frontend, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, - atol=1e-01, - x=x[0], + a=xs[0], axis=axis, + b=xs[1], + keepdims=keepdims, + return_sign=return_sign, ) @@ -520,26 +512,22 @@ def test_jax_normalize( @handle_frontend_test( - fn_tree="jax.nn.hard_tanh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="linear", - ), + fn_tree="jax.nn.one_hot", + 
dtype_indices_classes_axis=_dtype_indices_classes_axis(), + dtype=helpers.get_dtypes("float", full=False), test_with_out=st.just(False), ) -def test_jax_hard_tanh( +def test_jax_one_hot( *, - dtype_and_x, + dtype_indices_classes_axis, + dtype, test_flags, on_device, fn_tree, frontend, backend_fw, ): - # TODO: enable this test for all valid dtypes as jax.nn.hard_tanh supports - input_dtype, x = dtype_and_x + input_dtype, indices, num_classes, axis = dtype_indices_classes_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -547,58 +535,57 @@ def test_jax_hard_tanh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + rtol=1e-02, + x=indices[0], + num_classes=num_classes, + dtype=dtype[0], + axis=axis, ) +# Todo : turn on complex dtype activation tests once supported in all backends @handle_frontend_test( - fn_tree="jax.nn.celu", + fn_tree="jax.nn.relu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), - min_value=-5, - max_value=5, + available_dtypes=helpers.get_dtypes("numeric"), + large_abs_safety_factor=3, + small_abs_safety_factor=3, safety_factor_scale="linear", ), - alpha=helpers.floats(min_value=0.01, max_value=1), test_with_out=st.just(False), ) -def test_jax_celu( +def test_jax_relu( *, dtype_and_x, - alpha, test_flags, on_device, fn_tree, frontend, backend_fw, ): - input_dtypes, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - alpha=alpha, + x=x[0], ) -# elu @handle_frontend_test( - fn_tree="jax.nn.elu", + fn_tree="jax.nn.relu6", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_integer"), - min_value=-5, - max_value=5, + large_abs_safety_factor=2, + small_abs_safety_factor=2, safety_factor_scale="linear", - num_arrays=2, - shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_elu( +def test_jax_relu6( *, dtype_and_x, test_flags, @@ -607,66 +594,53 @@ def test_jax_elu( frontend, backend_fw, ): - input_dtypes, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - alpha=xs[1], - rtol=1e-03, - atol=1e-03, + x=x[0], ) @handle_frontend_test( - fn_tree="jax.nn.logsumexp", + fn_tree="jax.nn.selu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_integer"), large_abs_safety_factor=2, small_abs_safety_factor=2, - safety_factor_scale="linear", - num_arrays=2, - shared_dtype=True, + safety_factor_scale="log", ), - axis=st.just(None), - keepdims=st.booleans(), - return_sign=st.booleans(), test_with_out=st.just(False), ) -def test_jax_logsumexp( +def test_jax_selu( *, dtype_and_x, - axis, - keepdims, - return_sign, test_flags, on_device, fn_tree, frontend, backend_fw, ): - input_dtypes, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=xs[0], - axis=axis, - b=xs[1], - keepdims=keepdims, - return_sign=return_sign, + rtol=1e-02, + atol=1e-02, + x=x[0], ) @handle_frontend_test( - fn_tree="jax.nn.swish", + 
fn_tree="jax.nn.sigmoid", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=2, @@ -675,7 +649,7 @@ def test_jax_logsumexp( ), test_with_out=st.just(False), ) -def test_jax_swish( +def test_jax_sigmoid( *, dtype_and_x, test_flags, @@ -697,16 +671,16 @@ def test_jax_swish( @handle_frontend_test( - fn_tree="jax.nn.hard_swish", + fn_tree="jax.nn.silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-10, - max_value=10, + available_dtypes=helpers.get_dtypes("float_and_integer"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, safety_factor_scale="linear", ), test_with_out=st.just(False), ) -def test_jax_hard_swish( +def test_jax_silu( *, dtype_and_x, test_flags, @@ -719,8 +693,6 @@ def test_jax_hard_swish( helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - rtol=1e-02, - atol=1e-02, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, @@ -730,15 +702,16 @@ def test_jax_hard_swish( @handle_frontend_test( - fn_tree="jax.nn.hard_silu", + fn_tree="jax.nn.soft_sign", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_integer"), large_abs_safety_factor=2, small_abs_safety_factor=2, + safety_factor_scale="linear", ), test_with_out=st.just(False), ) -def test_jax_hard_silu( +def test_jax_soft_sign( *, dtype_and_x, test_flags, @@ -747,28 +720,65 @@ def test_jax_hard_silu( frontend, backend_fw, ): - input_dtypes, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], + x=x[0], ) @handle_frontend_test( - fn_tree="jax.nn.hard_sigmoid", + fn_tree="jax.nn.softmax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_axes_size=1, + force_int_axis=True, + valid_axis=True, + ), + test_with_out=st.just(False), +) +def test_jax_softmax( + *, + dtype_x_axis, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + x_dtype, x, axis = dtype_x_axis + + helpers.test_frontend_function( + input_dtypes=x_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-02, + atol=1e-02, + x=x[0], + axis=axis, + ) + + +@handle_frontend_test( + fn_tree="jax.nn.softplus", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, + available_dtypes=helpers.get_dtypes("numeric"), + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_jax_hard_sigmoid( +def test_jax_softplus( *, dtype_and_x, test_flags, @@ -777,29 +787,29 @@ def test_jax_hard_sigmoid( frontend, backend_fw, ): - input_dtypes, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], + x=x[0], ) @handle_frontend_test( - fn_tree="jax.nn.selu", + fn_tree="jax.nn.swish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=2, small_abs_safety_factor=2, - 
safety_factor_scale="log", + safety_factor_scale="linear", ), test_with_out=st.just(False), ) -def test_jax_selu( +def test_jax_swish( *, dtype_and_x, test_flags, @@ -816,7 +826,5 @@ def test_jax_selu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - atol=1e-02, x=x[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_creation.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_creation.py index e7fb44c3caee4..ccf37438b1bf0 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_creation.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_creation.py @@ -12,6 +12,62 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_dtype_and_range(draw): + dim = draw(helpers.ints(min_value=2, max_value=5)) + dtype = draw(helpers.get_dtypes("float", index=1, full=False)) + start = draw( + helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=-50, max_value=0) + ) + stop = draw( + helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=1, max_value=50) + ) + return dtype * 2, start, stop + + +# --- Main --- # +# ------------ # + + +# arange +@handle_frontend_test( + fn_tree="jax.numpy.arange", + start=st.integers(min_value=-100, max_value=100), + stop=st.integers(min_value=-100, max_value=100) | st.none(), + step=st.integers(min_value=-100, max_value=100).filter(lambda x: x != 0), + dtype=helpers.get_dtypes("numeric", full=False), + test_with_out=st.just(False), +) +def test_jax_arange( + *, + start, + stop, + step, + dtype, + on_device, + fn_tree, + test_flags, + frontend, + backend_fw, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + start=start, + stop=stop, + step=step, + dtype=dtype[0], + ) + + @handle_frontend_test( fn_tree="jax.numpy.array", dtype_and_x=helpers.dtype_and_values( @@ -66,67 +122,54 @@ def test_jax_array( ) -# zeros_like +# asarray @handle_frontend_test( - fn_tree="jax.numpy.zeros_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), - shape=helpers.get_shape( - allow_none=True, - min_num_dims=1, + fn_tree="jax.numpy.asarray", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_num_dims=0, max_num_dims=5, min_dim_size=1, - max_dim_size=10, + max_dim_size=5, ), - dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_jax_zeros_like( - dtype_and_x, - dtype, - shape, +def test_jax_asarray( + dtype_and_a, test_flags, frontend, backend_fw, fn_tree, on_device, ): - input_dtype, x = dtype_and_x + dtype, a = dtype_and_a helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], + a=a, dtype=dtype[0], - shape=shape, ) -# arange +# bool_ @handle_frontend_test( - fn_tree="jax.numpy.arange", - start=st.integers(min_value=-100, max_value=100), - stop=st.integers(min_value=-100, max_value=100) | st.none(), - step=st.integers(min_value=-100, max_value=100).filter(lambda x: x != 0), - dtype=helpers.get_dtypes("numeric", full=False), - test_with_out=st.just(False), + fn_tree="jax.numpy.bool_", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("bool")), ) -def test_jax_arange( - *, - 
start, - stop, - step, - dtype, - on_device, - fn_tree, +def test_jax_bool_( + dtype_and_x, test_flags, frontend, backend_fw, + fn_tree, + on_device, ): + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -134,70 +177,64 @@ def test_jax_arange( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - start=start, - stop=stop, - step=step, - dtype=dtype[0], + x=x[0], ) -# zeros @handle_frontend_test( - fn_tree="jax.numpy.zeros", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, + fn_tree="jax.numpy.cdouble", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("complex") ), - dtypes=helpers.get_dtypes("numeric", full=False), - test_with_out=st.just(False), ) -def test_jax_zeros( - *, - dtypes, - shape, - on_device, - fn_tree, +def test_jax_cdouble( + dtype_and_x, test_flags, frontend, backend_fw, + fn_tree, + on_device, ): + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - shape=shape, - dtype=dtypes[0], + x=x[0], ) -# ones @handle_frontend_test( - fn_tree="jax.numpy.ones", - shape=helpers.get_shape( - allow_none=False, + fn_tree="jax.numpy.compress", + dtype_arr_ax=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, max_num_dims=5, - min_dim_size=1, - max_dim_size=10, + min_dim_size=10, + max_dim_size=100, + valid_axis=True, + force_int_axis=True, + ), + condition=helpers.array_values( + dtype=helpers.get_dtypes("bool"), + shape=helpers.get_shape( + min_num_dims=1, max_num_dims=1, min_dim_size=1, max_dim_size=5 + ), ), - dtype=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), ) -def test_jax_ones( - shape, - dtype, - test_flags, +def test_jax_compress( + dtype_arr_ax, + condition, frontend, - backend_fw, + test_flags, fn_tree, + backend_fw, on_device, ): + dtype, arr, ax = dtype_arr_ax helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -205,73 +242,61 @@ def test_jax_ones( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - shape=shape, - dtype=dtype[0], + condition=condition, + a=arr[0], + axis=ax, ) -# ones_like +# copy @handle_frontend_test( - fn_tree="jax.numpy.ones_like", - dtype_and_x=helpers.dtype_and_values( + fn_tree="jax.numpy.copy", + dtype_and_a=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - ), - shape=helpers.get_shape( - allow_none=True, - min_num_dims=1, + num_arrays=1, + min_num_dims=0, max_num_dims=5, min_dim_size=1, - max_dim_size=10, + max_dim_size=5, ), - dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_jax_ones_like( - dtype_and_x, - shape, - dtype, - test_flags, +def test_jax_copy( + dtype_and_a, frontend, - backend_fw, + test_flags, fn_tree, + backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, a = dtype_and_a helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - dtype=dtype[0], - shape=shape, + a=a[0], ) -# asarray @handle_frontend_test( - fn_tree="jax.numpy.asarray", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=0, - max_num_dims=5, - min_dim_size=1, - 
max_dim_size=5, + fn_tree="jax.numpy.csingle", + aliases=["jax.numpy.complex64"], + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") ), - test_with_out=st.just(False), ) -def test_jax_asarray( - dtype_and_a, +def test_jax_csingle( + dtype_and_x, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype, a = dtype_and_a + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -279,110 +304,32 @@ def test_jax_asarray( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=a, - dtype=dtype[0], + x=x[0], ) -# hstack +# double @handle_frontend_test( - fn_tree="jax.numpy.hstack", - dtype_and_tup=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shared_dtype=True, - num_arrays=st.integers(min_value=2, max_value=2), - shape=helpers.get_shape( - min_num_dims=1, max_num_dims=3, min_dim_size=1, max_dim_size=5 - ), - ), - test_with_out=st.just(False), + fn_tree="jax.numpy.double", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_jax_hstack( - dtype_and_tup, +def test_jax_double( + dtype_and_x, test_flags, frontend, backend_fw, fn_tree, + on_device, ): - input_dtype, x = dtype_and_tup + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - tup=x, - ) - - -# eye -@handle_frontend_test( - fn_tree="jax.numpy.eye", - n=helpers.ints(min_value=3, max_value=10), - m=st.none() | helpers.ints(min_value=3, max_value=10), - k=helpers.ints(min_value=-2, max_value=2), - dtypes=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), -) -def test_jax_eye( - *, - n, - m, - k, - dtypes, - on_device, - fn_tree, - test_flags, - frontend, - backend_fw, -): - helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - N=n, - M=m, - k=k, - dtype=dtypes[0], - ) - - -# triu -@handle_frontend_test( - fn_tree="jax.numpy.triu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - k=helpers.ints(min_value=-10, max_value=10), - test_with_out=st.just(False), -) -def test_jax_triu( - dtype_and_x, - k, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=x[0], - k=k, + x=x[0], ) @@ -421,58 +368,33 @@ def test_jax_empty( ) -# vander +# empty_like @handle_frontend_test( - fn_tree="jax.numpy.vander", + fn_tree="jax.numpy.empty_like", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=st.tuples( - st.integers(min_value=1, max_value=5), - ), + available_dtypes=helpers.get_dtypes("valid"), ), - N=st.integers(min_value=0, max_value=5), - increasing=st.booleans(), + shape=helpers.get_shape( + allow_none=True, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), ) -def test_jax_vander( - *, +def test_jax_empty_like( dtype_and_x, - N, - increasing, + shape, + dtype, test_flags, - on_device, - fn_tree, 
frontend, backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - N=N, - increasing=increasing, - ) - - -# full_like -@handle_frontend_test( - fn_tree="jax.numpy.full_like", - input_fill_dtype=_input_fill_and_dtype(), - test_with_out=st.just(False), -) -def test_jax_full_like( - input_fill_dtype, - frontend, - test_flags, fn_tree, - backend_fw, on_device, ): - input_dtype, x, fill_value, dtype = input_fill_dtype + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -480,22 +402,26 @@ def test_jax_full_like( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - fill_value=fill_value, - dtype=dtype, + prototype=x[0], + dtype=dtype[0], + shape=shape, ) -# identity +# eye @handle_frontend_test( - fn_tree="jax.numpy.identity", + fn_tree="jax.numpy.eye", n=helpers.ints(min_value=3, max_value=10), + m=st.none() | helpers.ints(min_value=3, max_value=10), + k=helpers.ints(min_value=-2, max_value=2), dtypes=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_jax_identity( +def test_jax_eye( *, n, + m, + k, dtypes, on_device, fn_tree, @@ -510,63 +436,35 @@ def test_jax_identity( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - n=n, + N=n, + M=m, + k=k, dtype=dtypes[0], ) -# ndim -@handle_frontend_test( - fn_tree="jax.numpy.ndim", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), -) -def test_jax_ndim( - dtype_and_x, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - ) - - -# empty_like +# full @handle_frontend_test( - fn_tree="jax.numpy.empty_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), + fn_tree="jax.numpy.full", shape=helpers.get_shape( - allow_none=True, min_num_dims=1, max_num_dims=5, min_dim_size=1, max_dim_size=10, ), - dtype=helpers.get_dtypes("valid", full=False), + input_fill_dtype=_input_fill_and_dtype(), test_with_out=st.just(False), ) -def test_jax_empty_like( - dtype_and_x, +def test_jax_full( shape, - dtype, - test_flags, + input_fill_dtype, frontend, - backend_fw, + test_flags, fn_tree, + backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, _, fill_value, dtype = input_fill_dtype helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -574,26 +472,19 @@ def test_jax_empty_like( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - prototype=x[0], - dtype=dtype[0], shape=shape, + fill_value=fill_value, + dtype=dtype, ) -# full +# full_like @handle_frontend_test( - fn_tree="jax.numpy.full", - shape=helpers.get_shape( - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), + fn_tree="jax.numpy.full_like", input_fill_dtype=_input_fill_and_dtype(), test_with_out=st.just(False), ) -def test_jax_full( - shape, +def test_jax_full_like( input_fill_dtype, frontend, test_flags, @@ -601,7 +492,7 @@ def test_jax_full( backend_fw, on_device, ): - input_dtype, _, fill_value, dtype = input_fill_dtype + input_dtype, x, fill_value, dtype = input_fill_dtype helpers.test_frontend_function( 
input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -609,39 +500,23 @@ def test_jax_full( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - shape=shape, + a=x[0], fill_value=fill_value, dtype=dtype, ) -@st.composite -def _get_dtype_and_range(draw): - dim = draw(helpers.ints(min_value=2, max_value=5)) - dtype = draw(helpers.get_dtypes("float", index=1, full=False)) - start = draw( - helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=-50, max_value=0) - ) - stop = draw( - helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=1, max_value=50) - ) - return dtype * 2, start, stop - - -# logspace @handle_frontend_test( - fn_tree="jax.numpy.logspace", + fn_tree="jax.numpy.geomspace", dtype_start_stop=_get_dtype_and_range(), num=helpers.ints(min_value=5, max_value=50), - base=helpers.ints(min_value=2, max_value=10), - axis=helpers.ints(min_value=-1, max_value=0), + endpoint=st.booleans(), test_with_out=st.just(False), ) -def test_jax_logspace( +def test_jax_geomspace( dtype_start_stop, num, - base, - axis, + endpoint, frontend, test_flags, fn_tree, @@ -656,48 +531,90 @@ def test_jax_logspace( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, + rtol=1e-1, start=start, stop=stop, num=num, - endpoint=True, - base=base, + endpoint=endpoint, dtype=input_dtypes[0], - axis=axis, ) -# meshgrid +# hstack @handle_frontend_test( - fn_tree="jax.numpy.meshgrid", - dtype_and_arrays=helpers.dtype_and_values( + fn_tree="jax.numpy.hstack", + dtype_and_tup=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=st.integers(min_value=1, max_value=5), - min_num_dims=1, - max_num_dims=1, shared_dtype=True, - ), - sparse=st.booleans(), - indexing=st.sampled_from(["xy", "ij"]), + num_arrays=st.integers(min_value=2, max_value=2), + shape=helpers.get_shape( + min_num_dims=1, max_num_dims=3, min_dim_size=1, max_dim_size=5 + ), + ), test_with_out=st.just(False), ) -def test_jax_meshgrid( - dtype_and_arrays, - sparse, - indexing, +def test_jax_hstack( + dtype_and_tup, test_flags, frontend, backend_fw, fn_tree, +): + input_dtype, x = dtype_and_tup + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + tup=x, + ) + + +# identity +@handle_frontend_test( + fn_tree="jax.numpy.identity", + n=helpers.ints(min_value=3, max_value=10), + dtypes=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), +) +def test_jax_identity( + *, + n, + dtypes, on_device, + fn_tree, + test_flags, + frontend, + backend_fw, ): - dtype, arrays = dtype_and_arrays - kw = {} - i = 0 - for x_ in arrays: - kw["x{}".format(i)] = x_ - i += 1 - test_flags.num_positional_args = len(arrays) + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + n=n, + dtype=dtypes[0], + ) + + +@handle_frontend_test( + fn_tree="jax.numpy.iterable", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_jax_iterable( + dtype_and_x, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -705,9 +622,7 @@ def test_jax_meshgrid( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **kw, - sparse=sparse, - indexing=indexing, + y=x[0], ) @@ -749,53 +664,76 @@ def test_jax_linspace( 
) -# copy +# logspace @handle_frontend_test( - fn_tree="jax.numpy.copy", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - min_num_dims=0, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), + fn_tree="jax.numpy.logspace", + dtype_start_stop=_get_dtype_and_range(), + num=helpers.ints(min_value=5, max_value=50), + base=helpers.ints(min_value=2, max_value=10), + axis=helpers.ints(min_value=-1, max_value=0), test_with_out=st.just(False), ) -def test_jax_copy( - dtype_and_a, +def test_jax_logspace( + dtype_start_stop, + num, + base, + axis, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, a = dtype_and_a + input_dtypes, start, stop = dtype_start_stop helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=a[0], + rtol=1e-01, + start=start, + stop=stop, + num=num, + endpoint=True, + base=base, + dtype=input_dtypes[0], + axis=axis, ) -# single +# meshgrid @handle_frontend_test( - fn_tree="jax.numpy.single", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.numpy.meshgrid", + dtype_and_arrays=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=st.integers(min_value=1, max_value=5), + min_num_dims=1, + max_num_dims=1, + shared_dtype=True, + ), + sparse=st.booleans(), + indexing=st.sampled_from(["xy", "ij"]), + test_with_out=st.just(False), ) -def test_jax_single( - dtype_and_x, +def test_jax_meshgrid( + dtype_and_arrays, + sparse, + indexing, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype, x = dtype_and_x + dtype, arrays = dtype_and_arrays + kw = {} + i = 0 + for x_ in arrays: + kw["x{}".format(i)] = x_ + i += 1 + test_flags.num_positional_args = len(arrays) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -803,16 +741,18 @@ def test_jax_single( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + **kw, + sparse=sparse, + indexing=indexing, ) -# double +# ndim @handle_frontend_test( - fn_tree="jax.numpy.double", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.numpy.ndim", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), ) -def test_jax_double( +def test_jax_ndim( dtype_and_x, test_flags, frontend, @@ -828,103 +768,155 @@ def test_jax_double( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], ) -# bool_ @handle_frontend_test( - fn_tree="jax.numpy.bool_", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("bool")), + fn_tree="jax.numpy.frombuffer", + dtype_buffer_count_offset=_get_dtype_buffer_count_offset(), + test_with_out=st.just(False), ) -def test_jax_bool_( - dtype_and_x, - test_flags, +def test_jax_numpy_frombuffer( + *, + dtype_buffer_count_offset, + on_device, + fn_tree, frontend, backend_fw, - fn_tree, - on_device, + test_flags, ): - dtype, x = dtype_and_x + input_dtype, buffer, count, offset = dtype_buffer_count_offset helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + buffer=buffer, + dtype=input_dtype[0], + count=count, + offset=offset, + ) + + +@handle_frontend_test( + fn_tree="jax.numpy.in1d", + 
dtype_and_a=helpers.dtype_and_values(min_num_dims=1, max_num_dims=1), + dtype_and_b=helpers.dtype_and_values(min_num_dims=1, max_num_dims=1), + assume_unique=st.booleans(), + invert=st.booleans(), +) +def test_jax_numpy_in1d( + *, + dtype_and_a, + dtype_and_b, + assume_unique, + invert, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype_a, a = dtype_and_a + input_dtype_b, b = dtype_and_b + + helpers.test_frontend_function( + input_dtypes=input_dtype_a + input_dtype_b, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - x=x[0], + ar1=a[0], + ar2=b[0], + assume_unique=assume_unique, + invert=invert, ) +# ones @handle_frontend_test( - fn_tree="jax.numpy.geomspace", - dtype_start_stop=_get_dtype_and_range(), - num=helpers.ints(min_value=5, max_value=50), - endpoint=st.booleans(), + fn_tree="jax.numpy.ones", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_jax_geomspace( - dtype_start_stop, - num, - endpoint, - frontend, +def test_jax_ones( + shape, + dtype, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): - input_dtypes, start, stop = dtype_start_stop helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-1, - start=start, - stop=stop, - num=num, - endpoint=endpoint, - dtype=input_dtypes[0], + shape=shape, + dtype=dtype[0], ) +# ones_like @handle_frontend_test( - fn_tree="jax.numpy.csingle", - aliases=["jax.numpy.complex64"], + fn_tree="jax.numpy.ones_like", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + available_dtypes=helpers.get_dtypes("valid"), + ), + shape=helpers.get_shape( + allow_none=True, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, ), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), ) -def test_jax_csingle( +def test_jax_ones_like( dtype_and_x, + shape, + dtype, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], + dtype=dtype[0], + shape=shape, ) +# single @handle_frontend_test( - fn_tree="jax.numpy.cdouble", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("complex") - ), + fn_tree="jax.numpy.single", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_jax_cdouble( +def test_jax_single( dtype_and_x, test_flags, frontend, @@ -945,33 +937,30 @@ def test_jax_cdouble( @handle_frontend_test( - fn_tree="jax.numpy.compress", - dtype_arr_ax=helpers.dtype_values_axis( + fn_tree="jax.numpy.size", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), + abs_smallest_val=0, + num_arrays=1, min_num_dims=1, max_num_dims=5, - min_dim_size=10, + min_dim_size=1, max_dim_size=100, valid_axis=True, + allow_neg_axes=True, force_int_axis=True, ), - condition=helpers.array_values( - dtype=helpers.get_dtypes("bool"), - shape=helpers.get_shape( - min_num_dims=1, max_num_dims=1, min_dim_size=1, max_dim_size=5 - ), 
- ), ) -def test_jax_compress( - dtype_arr_ax, - condition, - frontend, +def test_jax_size( + dtype_x_axis, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): - dtype, arr, ax = dtype_arr_ax + dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -979,20 +968,28 @@ def test_jax_compress( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - condition=condition, - a=arr[0], - axis=ax, + a=x[0], + axis=axis, ) +# triu @handle_frontend_test( - fn_tree="jax.numpy.iterable", + fn_tree="jax.numpy.triu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, ), + k=helpers.ints(min_value=-10, max_value=10), + test_with_out=st.just(False), ) -def test_jax_iterable( +def test_jax_triu( dtype_and_x, + k, test_flags, frontend, backend_fw, @@ -1007,107 +1004,118 @@ def test_jax_iterable( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y=x[0], + m=x[0], + k=k, ) +# vander @handle_frontend_test( - fn_tree="jax.numpy.size", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - abs_smallest_val=0, - num_arrays=1, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=100, - valid_axis=True, - allow_neg_axes=True, - force_int_axis=True, + fn_tree="jax.numpy.vander", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.tuples( + st.integers(min_value=1, max_value=5), + ), ), + N=st.integers(min_value=0, max_value=5), + increasing=st.booleans(), ) -def test_jax_size( - dtype_x_axis, +def test_jax_vander( + *, + dtype_and_x, + N, + increasing, test_flags, + on_device, + fn_tree, frontend, backend_fw, - fn_tree, - on_device, ): - dtype, x, axis = dtype_x_axis - + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + x=x[0], + N=N, + increasing=increasing, ) +# zeros @handle_frontend_test( - fn_tree="jax.numpy.frombuffer", - dtype_buffer_count_offset=_get_dtype_buffer_count_offset(), + fn_tree="jax.numpy.zeros", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtypes=helpers.get_dtypes("numeric", full=False), test_with_out=st.just(False), ) -def test_jax_numpy_frombuffer( +def test_jax_zeros( *, - dtype_buffer_count_offset, + dtypes, + shape, on_device, fn_tree, + test_flags, frontend, backend_fw, - test_flags, ): - input_dtype, buffer, count, offset = dtype_buffer_count_offset helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - buffer=buffer, - dtype=input_dtype[0], - count=count, - offset=offset, + shape=shape, + dtype=dtypes[0], ) +# zeros_like @handle_frontend_test( - fn_tree="jax.numpy.in1d", - dtype_and_a=helpers.dtype_and_values(min_num_dims=1, max_num_dims=1), - dtype_and_b=helpers.dtype_and_values(min_num_dims=1, max_num_dims=1), - assume_unique=st.booleans(), - invert=st.booleans(), + fn_tree="jax.numpy.zeros_like", + dtype_and_x=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("valid"), + ), + shape=helpers.get_shape( + allow_none=True, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), ) -def test_jax_numpy_in1d( - *, - dtype_and_a, - dtype_and_b, - assume_unique, - invert, - on_device, - fn_tree, +def test_jax_zeros_like( + dtype_and_x, + dtype, + shape, + test_flags, frontend, backend_fw, - test_flags, + fn_tree, + on_device, ): - input_dtype_a, a = dtype_and_a - input_dtype_b, b = dtype_and_b - + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype_a + input_dtype_b, + input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - ar1=a[0], - ar2=b[0], - assume_unique=assume_unique, - invert=invert, + a=x[0], + dtype=dtype[0], + shape=shape, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py index 50d366dc025f1..1370f200e972f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py @@ -40,6 +40,41 @@ def test_jax_can_cast( ) +@handle_frontend_test( + fn_tree="jax.numpy.finfo", + dtype=helpers.get_dtypes("numeric", full=False), + test_with_out=st.just(False), +) +def test_jax_finfo(*, dtype, test_flags, on_device, fn_tree, frontend, backend_fw): + helpers.test_frontend_function( + input_dtypes=[], + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + dtype=dtype[0], + backend_to_test=backend_fw, + ) + + +@handle_frontend_test( + fn_tree="jax.numpy.iinfo", + dtype=helpers.get_dtypes("numeric", full=False), + test_with_out=st.just(False), +) +@settings(max_examples=200) +def test_jax_iinfo(*, dtype, test_flags, on_device, fn_tree, frontend, backend_fw): + helpers.test_frontend_function( + input_dtypes=[], + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + int_type=dtype[0], + backend_to_test=backend_fw, + ) + + # promote_types @handle_frontend_test( fn_tree="jax.numpy.promote_types", @@ -101,38 +136,3 @@ def test_jax_result_type( test_values=False, **kw, ) - - -@handle_frontend_test( - fn_tree="jax.numpy.iinfo", - dtype=helpers.get_dtypes("numeric", full=False), - test_with_out=st.just(False), -) -@settings(max_examples=200) -def test_jax_iinfo(*, dtype, test_flags, on_device, fn_tree, frontend, backend_fw): - helpers.test_frontend_function( - input_dtypes=[], - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - int_type=dtype[0], - backend_to_test=backend_fw, - ) - - -@handle_frontend_test( - fn_tree="jax.numpy.finfo", - dtype=helpers.get_dtypes("numeric", full=False), - test_with_out=st.just(False), -) -def test_jax_finfo(*, dtype, test_flags, on_device, fn_tree, frontend, backend_fw): - helpers.test_frontend_function( - input_dtypes=[], - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - dtype=dtype[0], - backend_to_test=backend_fw, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py index eadfd3da43c6f..3e7842c3fd9e1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_indexing.py @@ -15,6 +15,44 @@ import ivy.functional.frontends.jax.numpy as jnp_frontend +# --- Helpers --- # +# --------------- # + + +# diag +@st.composite +def _diag_helper(draw): + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + small_abs_safety_factor=2, + large_abs_safety_factor=2, + safety_factor_scale="log", + min_num_dims=1, + max_num_dims=2, + min_dim_size=1, + max_dim_size=50, + ) + ) + shape = x[0].shape + if len(shape) == 2: + k = draw(helpers.ints(min_value=-shape[0] + 1, max_value=shape[1] - 1)) + else: + k = draw(helpers.ints(min_value=0, max_value=shape[0])) + return dtype, x, k + + +@st.composite +def _get_dtype_square_x(draw): + dim_size = draw(helpers.ints(min_value=2, max_value=5)) + dtype_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), shape=(dim_size, dim_size) + ) + ) + return dtype_x + + # diagonal @st.composite def dims_and_offset(draw, shape): @@ -27,70 +65,69 @@ def dims_and_offset(draw, shape): return dim1, dim2, offset +# unravel_index +@st.composite +def max_value_as_shape_prod(draw): + shape = draw( + helpers.get_shape( + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ) + ) + dtype_and_x = draw( + helpers.dtype_values_axis( + available_dtypes=["int32", "int64"], + min_value=0, + max_value=np.prod(shape) - 1, + ) + ) + return dtype_and_x, shape + + +# --- Main --- # +# ------------ # + + +# choose @handle_frontend_test( - fn_tree="jax.numpy.diagonal", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - ), - dims_and_offset=dims_and_offset( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape") + fn_tree="jax.numpy.choose", + dtype_x_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], ), + out=st.none(), + mode=st.sampled_from(["wrap", "clip", "raise"]), + test_with_out=st.just(False), ) -def test_jax_diagonal( +def test_jax_choose( *, - dtype_and_values, - dims_and_offset, + dtype_x_indices_axis, + out, + mode, test_flags, - on_device, - fn_tree, frontend, backend_fw, + fn_tree, + on_device, ): - input_dtype, value = dtype_and_values - axis1, axis2, offset = dims_and_offset - a = value[0] - num_of_dims = len(np.shape(a)) - assume(axis1 != axis2) - if axis1 < 0: - assume(axis1 + num_of_dims != axis2) - if axis2 < 0: - assume(axis1 != axis2 + num_of_dims) + dtypes, x, indices, axis, _ = dtype_x_indices_axis + choices = ivy.array( + [np.random.randint(0, 10, size=x.shape) for _ in range(len(dtypes))] + ) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - a=a, - offset=offset, - axis1=axis1, - axis2=axis2, - ) - - -# diag -@st.composite -def _diag_helper(draw): - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - small_abs_safety_factor=2, - large_abs_safety_factor=2, - safety_factor_scale="log", - min_num_dims=1, - max_num_dims=2, - min_dim_size=1, - max_dim_size=50, - ) + arr=x, + choices=choices, + out=out, + mode=mode, ) - shape = x[0].shape - if len(shape) == 2: - k = draw(helpers.ints(min_value=-shape[0] + 1, max_value=shape[1] - 1)) - else: - k = draw(helpers.ints(min_value=0, max_value=shape[0])) - return 
dtype, x, k @handle_frontend_test( @@ -149,89 +186,88 @@ def test_jax_diag_indices( ) -# take_along_axis @handle_frontend_test( - fn_tree="jax.numpy.take_along_axis", - dtype_x_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - indices_same_dims=True, - valid_bounds=False, - ), - mode=st.sampled_from(["clip", "fill", "drop"]), + dtype_x=_get_dtype_square_x(), + fn_tree="jax.numpy.diag_indices_from", test_with_out=st.just(False), ) -def test_jax_take_along_axis( - *, - dtype_x_indices_axis, - mode, +def test_jax_diag_indices_from( + dtype_x, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtypes, x, indices, axis, _ = dtype_x_indices_axis + dtype, x = dtype_x helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - arr=x, - indices=indices, - axis=axis, - mode=mode, + arr=x[0], ) -# Tril_indices @handle_frontend_test( - fn_tree="jax.numpy.tril_indices", - n_rows=helpers.ints(min_value=1, max_value=10), - k=helpers.ints(min_value=2, max_value=10), - dtype=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), + fn_tree="jax.numpy.diagonal", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + ), + dims_and_offset=dims_and_offset( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape") + ), ) -def test_jax_tril_indices( - n_rows, - k, - dtype, +def test_jax_diagonal( + *, + dtype_and_values, + dims_and_offset, test_flags, + on_device, + fn_tree, frontend, backend_fw, - fn_tree, - on_device, ): + input_dtype, value = dtype_and_values + axis1, axis2, offset = dims_and_offset + a = value[0] + num_of_dims = len(np.shape(a)) + assume(axis1 != axis2) + if axis1 < 0: + assume(axis1 + num_of_dims != axis2) + if axis2 < 0: + assume(axis1 != axis2 + num_of_dims) helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - n=n_rows, - k=k, + a=a, + offset=offset, + axis1=axis1, + axis2=axis2, ) -# triu_indices @handle_frontend_test( - fn_tree="jax.numpy.triu_indices", - n=helpers.ints(min_value=2, max_value=10), - k=helpers.ints(min_value=-10, max_value=10), - input_dtypes=helpers.get_dtypes("valid", full=False), + fn_tree="jax.numpy.mask_indices", + n=helpers.ints(min_value=3, max_value=10), + mask_func=st.sampled_from([triu, tril]), + k=helpers.ints(min_value=-5, max_value=5), + input_dtype=helpers.get_dtypes("numeric"), test_with_out=st.just(False), + number_positional_args=st.just(2), ) -def test_jax_triu_indices( +def test_jax_mask_indices( n, + mask_func, k, - input_dtypes, + input_dtype, test_flags, frontend, backend_fw, @@ -239,39 +275,44 @@ def test_jax_triu_indices( on_device, ): helpers.test_frontend_function( - n=n, - k=k, - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, + n=n, + mask_func=mask_func, + k=k, ) -# triu_indices_from +@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_c_()) # dummy fn_tree +def test_jax_numpy_c_(inputs, backend_fw): + ret_gt = c_.__getitem__(tuple(inputs)) + with 
BackendHandler.update_backend(backend_fw): + ret = jnp_frontend.c_.__getitem__(tuple(inputs)) + assert np.allclose(ret.ivy_array, ret_gt) + + @handle_frontend_test( - fn_tree="jax.numpy.triu_indices_from", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=2, - max_num_dims=5, - ), - k=helpers.ints(min_value=-5, max_value=5), + fn_tree="jax.numpy.indices", + dimensions=helpers.get_shape(min_num_dims=1), + dtype=helpers.get_dtypes("numeric"), + sparse=st.booleans(), test_with_out=st.just(False), ) -def test_jax_triu_indices_from( - dtype_and_x, - k, +def test_jax_numpy_indices( + *, + dimensions, + dtype, + sparse, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -279,181 +320,178 @@ def test_jax_triu_indices_from( frontend=frontend, fn_tree=fn_tree, on_device=on_device, - arr=x[0], - k=k, + dimensions=dimensions, + dtype=dtype[0], + sparse=sparse, ) -# tril_indices_from +@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_r_()) # dummy fn_tree +def test_jax_numpy_r_(inputs, backend_fw): + inputs, *_ = inputs + ret_gt = r_.__getitem__(tuple(inputs)) + with BackendHandler.update_backend(backend_fw): + ret = jnp_frontend.r_.__getitem__(tuple(inputs)) + assert np.allclose(ret.ivy_array, ret_gt) + + +# take_along_axis @handle_frontend_test( - fn_tree="jax.numpy.tril_indices_from", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=2, + fn_tree="jax.numpy.take_along_axis", + dtype_x_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], + min_num_dims=1, max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + indices_same_dims=True, + valid_bounds=False, ), - k=helpers.ints(min_value=-5, max_value=5), + mode=st.sampled_from(["clip", "fill", "drop"]), test_with_out=st.just(False), ) -def test_jax_tril_indices_from( - dtype_and_x, - k, +def test_jax_take_along_axis( + *, + dtype_x_indices_axis, + mode, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype, x = dtype_and_x + dtypes, x, indices, axis, _ = dtype_x_indices_axis helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - arr=x[0], - k=k, - ) - - -# unravel_index -@st.composite -def max_value_as_shape_prod(draw): - shape = draw( - helpers.get_shape( - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ) - ) - dtype_and_x = draw( - helpers.dtype_values_axis( - available_dtypes=["int32", "int64"], - min_value=0, - max_value=np.prod(shape) - 1, - ) + arr=x, + indices=indices, + axis=axis, + mode=mode, ) - return dtype_and_x, shape +# Tril_indices @handle_frontend_test( - fn_tree="jax.numpy.unravel_index", - dtype_x_shape=max_value_as_shape_prod(), + fn_tree="jax.numpy.tril_indices", + n_rows=helpers.ints(min_value=1, max_value=10), + k=helpers.ints(min_value=2, max_value=10), + dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_jax_unravel_index( - *, - dtype_x_shape, +def test_jax_tril_indices( + n_rows, + k, + dtype, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype_and_x, shape = dtype_x_shape - input_dtype, x = dtype_and_x[0], dtype_and_x[1] helpers.test_frontend_function( - input_dtypes=input_dtype, 
+ input_dtypes=dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - indices=x[0], - shape=shape, + n=n_rows, + k=k, ) +# tril_indices_from @handle_frontend_test( - fn_tree="jax.numpy.mask_indices", - n=helpers.ints(min_value=3, max_value=10), - mask_func=st.sampled_from([triu, tril]), + fn_tree="jax.numpy.tril_indices_from", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_num_dims=2, + max_num_dims=5, + ), k=helpers.ints(min_value=-5, max_value=5), - input_dtype=helpers.get_dtypes("numeric"), test_with_out=st.just(False), - number_positional_args=st.just(2), ) -def test_jax_mask_indices( - n, - mask_func, +def test_jax_tril_indices_from( + dtype_and_x, k, - input_dtype, test_flags, frontend, backend_fw, fn_tree, on_device, ): + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - n=n, - mask_func=mask_func, + arr=x[0], k=k, ) -@st.composite -def _get_dtype_square_x(draw): - dim_size = draw(helpers.ints(min_value=2, max_value=5)) - dtype_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), shape=(dim_size, dim_size) - ) - ) - return dtype_x - - +# triu_indices @handle_frontend_test( - dtype_x=_get_dtype_square_x(), - fn_tree="jax.numpy.diag_indices_from", + fn_tree="jax.numpy.triu_indices", + n=helpers.ints(min_value=2, max_value=10), + k=helpers.ints(min_value=-10, max_value=10), + input_dtypes=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_jax_diag_indices_from( - dtype_x, +def test_jax_triu_indices( + n, + k, + input_dtypes, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype, x = dtype_x helpers.test_frontend_function( - input_dtypes=dtype, + n=n, + k=k, + input_dtypes=input_dtypes, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - arr=x[0], ) +# triu_indices_from @handle_frontend_test( - fn_tree="jax.numpy.indices", - dimensions=helpers.get_shape(min_num_dims=1), - dtype=helpers.get_dtypes("numeric"), - sparse=st.booleans(), + fn_tree="jax.numpy.triu_indices_from", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_num_dims=2, + max_num_dims=5, + ), + k=helpers.ints(min_value=-5, max_value=5), test_with_out=st.just(False), ) -def test_jax_numpy_indices( - *, - dimensions, - dtype, - sparse, +def test_jax_triu_indices_from( + dtype_and_x, + k, test_flags, frontend, backend_fw, fn_tree, on_device, ): + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -461,64 +499,34 @@ def test_jax_numpy_indices( frontend=frontend, fn_tree=fn_tree, on_device=on_device, - dimensions=dimensions, - dtype=dtype[0], - sparse=sparse, + arr=x[0], + k=k, ) -@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_r_()) # dummy fn_tree -def test_jax_numpy_r_(inputs, backend_fw): - inputs, *_ = inputs - ret_gt = r_.__getitem__(tuple(inputs)) - with BackendHandler.update_backend(backend_fw): - ret = jnp_frontend.r_.__getitem__(tuple(inputs)) - assert np.allclose(ret.ivy_array, ret_gt) - - -@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_c_()) # dummy fn_tree -def test_jax_numpy_c_(inputs, backend_fw): - ret_gt = c_.__getitem__(tuple(inputs)) - with 
BackendHandler.update_backend(backend_fw): - ret = jnp_frontend.c_.__getitem__(tuple(inputs)) - assert np.allclose(ret.ivy_array, ret_gt) - - -# choose @handle_frontend_test( - fn_tree="jax.numpy.choose", - dtype_x_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], - ), - out=st.none(), - mode=st.sampled_from(["wrap", "clip", "raise"]), + fn_tree="jax.numpy.unravel_index", + dtype_x_shape=max_value_as_shape_prod(), test_with_out=st.just(False), ) -def test_jax_choose( +def test_jax_unravel_index( *, - dtype_x_indices_axis, - out, - mode, + dtype_x_shape, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtypes, x, indices, axis, _ = dtype_x_indices_axis - choices = ivy.array( - [np.random.randint(0, 10, size=x.shape) for _ in range(len(dtypes))] - ) + dtype_and_x, shape = dtype_x_shape + input_dtype, x = dtype_and_x[0], dtype_and_x[1] helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - arr=x, - choices=choices, - out=out, - mode=mode, + indices=x[0], + shape=shape, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py index 50494c24e3e28..9317b6956dfaa 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py @@ -20,9 +20,130 @@ ) -# svd +# --- Helpers --- # +# --------------- # + + +# tensorinv +@st.composite +def _get_inv_square_matrices(draw): + dim_size = draw(helpers.ints(min_value=1, max_value=9)) + + batch_shape = draw(st.sampled_from([2, 4, 6, 8])) + + generated_shape = (dim_size,) * batch_shape + generated_ind = int(np.floor(len(generated_shape) / 2)) + + handpicked_shape, handpicked_ind = draw( + st.sampled_from([[(24, 6, 4), 1], [(8, 3, 6, 4), 2], [(6, 7, 8, 16, 21), 3]]) + ) + + shape, ind = draw( + st.sampled_from( + [(generated_shape, generated_ind), (handpicked_shape, handpicked_ind)] + ) + ) + + input_dtype = draw( + helpers.get_dtypes("float", index=1, full=False).filter( + lambda x: x not in ["float16", "bfloat16"] + ) + ) + invertible = False + while not invertible: + a = draw( + helpers.array_values( + dtype=input_dtype[0], + shape=shape, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + ).filter(lambda x: helpers.matrix_is_stable(x)) + ) + try: + np.linalg.tensorinv(a, ind) + invertible = True + except np.linalg.LinAlgError: + pass + + return input_dtype, a, ind + + +# tensorsolve +@st.composite +def _get_solve_matrices(draw): + # batch_shape, random_size, shared + + # float16 causes a crash when filtering out matrices + # for which `np.linalg.cond` is large. 
+ input_dtype_strategy = st.shared( + st.sampled_from(draw(helpers.get_dtypes("float"))).filter( + lambda x: "float16" not in x + ), + key="shared_dtype", + ) + input_dtype = draw(input_dtype_strategy) + + dim = draw(helpers.ints(min_value=2, max_value=5)) + + first_matrix = draw( + helpers.array_values( + dtype=input_dtype, + shape=(dim, dim, dim, dim), + min_value=1.2, + max_value=5, + ).filter(lambda x: np.linalg.det(x.reshape((dim**2, dim**2))) != 0) + ) + + second_matrix = draw( + helpers.array_values( + dtype=input_dtype, + shape=(dim, dim), + min_value=1.2, + max_value=3, + ) + ) + + return input_dtype, first_matrix, second_matrix + + +@st.composite +def norm_helper(draw): + dtype, x = draw( + helpers.dtype_and_values( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + safety_factor_scale="log", + large_abs_safety_factor=2, + ) + ) + axis = draw( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ) + ) + if type(axis) in [tuple, list]: + if len(axis) == 2: + ord_param = draw( + st.sampled_from(["fro", "nuc", 1, 2, -1, -2, np.inf, -np.inf]) + ) + else: + axis = axis[0] + ord_param = draw(st.sampled_from([0, 1, 2, -1, -2, np.inf, -np.inf])) + else: + ord_param = draw(st.sampled_from([0, 1, 2, -1, -2, np.inf, -np.inf])) + keepdims = draw(st.booleans()) + return dtype, x, ord_param, axis, keepdims + + +# --- Main --- # +# ------------ # + + +# cholesky @handle_frontend_test( - fn_tree="jax.numpy.linalg.svd", + fn_tree="jax.numpy.linalg.cholesky", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -32,68 +153,63 @@ lambda x: "float16" not in x[0] and "bfloat16" not in x[0] and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon - and np.linalg.det(np.asarray(x[1][0])) != 0 + and np.linalg.det(x[1][0]) != 0 ), - full_matrices=st.booleans(), - compute_uv=st.booleans(), test_with_out=st.just(False), ) -def test_jax_svd( +def test_jax_cholesky( *, dtype_and_x, - full_matrices, - compute_uv, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): dtype, x = dtype_and_x x = np.asarray(x[0], dtype=dtype[0]) - # make symmetric positive-definite beforehand + # make symmetric positive-definite x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - ret, frontend_ret = helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - x=x, - full_matrices=full_matrices, - compute_uv=compute_uv, + rtol=1e-02, + a=x, ) - if compute_uv: - with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = [ivy_backend.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] - - u, s, vh = ret - frontend_u, frontend_s, frontend_vh = frontend_ret - assert_all_close( - ret_np=u @ np.diag(s) @ vh, - ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh, - rtol=1e-2, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=frontend, - ) - else: - with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = ivy_backend.to_numpy(ret) - assert_all_close( - ret_np=ret, - ret_from_gt_np=np.asarray(frontend_ret[0]), - rtol=1e-2, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=frontend, - ) +@handle_frontend_test( + fn_tree="jax.numpy.linalg.cond", + dtype_x_p=helpers.cond_data_gen_helper(), + test_with_out=st.just(False), 
+) +def test_jax_cond( + *, + dtype_x_p, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + dtype, x = dtype_x_p + helpers.test_frontend_function( + input_dtypes=dtype, + test_flags=test_flags, + rtol=1e-01, + atol=1e-01, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + p=x[1], + ) # det @@ -245,23 +361,22 @@ def test_jax_eigh( ) -# inv +# eigvals @handle_frontend_test( - fn_tree="jax.numpy.linalg.inv", + fn_tree="jax.numpy.linalg.eigvals", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - shape=helpers.ints(min_value=1, max_value=10).map(lambda x: tuple([x, x])), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), ).filter( lambda x: "float16" not in x[0] and "bfloat16" not in x[0] and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon and np.linalg.det(np.asarray(x[1][0])) != 0 ), - test_with_out=st.just(False), ) -def test_jax_inv( +def test_jax_eigvals( *, dtype_and_x, on_device, @@ -270,19 +385,36 @@ def test_jax_inv( test_flags, backend_fw, ): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, + dtypes, x = dtype_and_x + x = np.array(x[0], dtype=dtypes[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtypes, backend_to_test=backend_fw, - rtol=1e-01, - atol=1e-01, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], + test_values=False, + a=x, ) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + # Calculate the magnitude of the complex numbers then sort them for testing + ret = np.sort(np.abs(ivy_backend.to_numpy(ret))).astype(np.float64) + frontend_ret = np.sort(np.abs(frontend_ret)).astype(np.float64) + + assert_all_close( + ret_np=ret, + ret_from_gt_np=frontend_ret, + backend=backend_fw, + ground_truth_backend=frontend, + atol=1e-2, + rtol=1e-2, + ) + # eigvalsh @handle_frontend_test( @@ -329,25 +461,25 @@ def test_jax_eigvalsh( ) -# qr +# inv @handle_frontend_test( - fn_tree="jax.numpy.linalg.qr", + fn_tree="jax.numpy.linalg.inv", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, - max_num_dims=5, - min_dim_size=2, - max_dim_size=5, - min_value=2, - max_value=5, + min_value=-100, + max_value=100, + shape=helpers.ints(min_value=1, max_value=10).map(lambda x: tuple([x, x])), + ).filter( + lambda x: "float16" not in x[0] + and "bfloat16" not in x[0] + and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon + and np.linalg.det(np.asarray(x[1][0])) != 0 ), - mode=st.sampled_from(("reduced", "complete")), test_with_out=st.just(False), ) -def test_jax_qr( +def test_jax_inv( *, dtype_and_x, - mode, on_device, fn_tree, frontend, @@ -355,107 +487,40 @@ def test_jax_qr( backend_fw, ): dtype, x = dtype_and_x - ret, frontend_ret = helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=dtype, - test_values=False, backend_to_test=backend_fw, + rtol=1e-01, + atol=1e-01, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=np.asarray(x[0], dtype[0]), - mode=mode, - ) - with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = [ivy_backend.to_numpy(x).astype(np.float64) for x in ret] - frontend_ret = [x.astype(np.float64) for x in frontend_ret] - - Q, R = ret - 
frontend_Q, frontend_R = frontend_ret - - assert_all_close( - ret_np=Q @ R, - ret_from_gt_np=frontend_Q @ frontend_R, - atol=1e-02, - backend=backend_fw, - ground_truth_backend=frontend, + a=x[0], ) -# eigvals +# matrix_power @handle_frontend_test( - fn_tree="jax.numpy.linalg.eigvals", + fn_tree="jax.numpy.linalg.matrix_power", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + min_value=-100, + max_value=100, + shape=helpers.ints(min_value=1, max_value=10).map(lambda x: tuple([x, x])), ).filter( lambda x: "float16" not in x[0] and "bfloat16" not in x[0] and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon and np.linalg.det(np.asarray(x[1][0])) != 0 ), -) -def test_jax_eigvals( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtypes, x = dtype_and_x - x = np.array(x[0], dtype=dtypes[0]) - # make symmetric positive-definite beforehand - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - - ret, frontend_ret = helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - a=x, - ) - - with BackendHandler.update_backend(backend_fw) as ivy_backend: - # Calculate the magnitude of the complex numbers then sort them for testing - ret = np.sort(np.abs(ivy_backend.to_numpy(ret))).astype(np.float64) - frontend_ret = np.sort(np.abs(frontend_ret)).astype(np.float64) - - assert_all_close( - ret_np=ret, - ret_from_gt_np=frontend_ret, - backend=backend_fw, - ground_truth_backend=frontend, - atol=1e-2, - rtol=1e-2, - ) - - -# cholesky -@handle_frontend_test( - fn_tree="jax.numpy.linalg.cholesky", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - ).filter( - lambda x: "float16" not in x[0] - and "bfloat16" not in x[0] - and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon - and np.linalg.det(x[1][0]) != 0 - ), + n=helpers.ints(min_value=1, max_value=8), test_with_out=st.just(False), ) -def test_jax_cholesky( +def test_jax_matrix_power( *, dtype_and_x, + n, on_device, fn_tree, frontend, @@ -463,57 +528,17 @@ def test_jax_cholesky( backend_fw, ): dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=dtype[0]) - # make symmetric positive-definite - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - helpers.test_frontend_function( input_dtypes=dtype, + rtol=1e-01, + atol=1e-01, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-02, - a=x, - ) - - -# slogdet -@handle_frontend_test( - fn_tree="jax.numpy.linalg.slogdet", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - max_value=100, - min_value=-100, - shape=st.tuples( - st.shared(st.integers(1, 5), key="sq"), - st.shared(st.integers(1, 5), key="sq"), - ), - num_arrays=1, - ), - test_with_out=st.just(False), -) -def test_jax_slogdet( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-4, - rtol=1e-4, - a=x[0], + a=np.asarray(x[0], 
dtype=dtype[0]), + n=n, ) @@ -546,67 +571,37 @@ def test_jax_matrix_rank( ) -# solve +# multi_dot @handle_frontend_test( - fn_tree="jax.numpy.linalg.solve", - x=helpers.get_first_solve_matrix(adjoint=False), - y=helpers.get_second_solve_matrix(), + fn_tree="jax.numpy.linalg.multi_dot", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + num_arrays=2, + shared_dtype=True, + ).filter( + lambda x: "float16" not in x[0] + and "bfloat16" not in x[0] + and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon + and np.linalg.det(np.asarray(x[1][0])) != 0 + ), test_with_out=st.just(False), ) -def test_jax_solve( - *, - x, - y, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, +def test_jax_multi_dot( + *, dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw ): - input_dtype1, x1, _ = x - input_dtype2, x2 = y + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[input_dtype1, input_dtype2], + input_dtypes=dtype, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - rtol=1e-1, - atol=1e-1, - a=x1, - b=x2, - ) - - -@st.composite -def norm_helper(draw): - dtype, x = draw( - helpers.dtype_and_values( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - safety_factor_scale="log", - large_abs_safety_factor=2, - ) - ) - axis = draw( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ) + arrays=(x[0], x[1]), + backend_to_test=backend_fw, ) - if type(axis) in [tuple, list]: - if len(axis) == 2: - ord_param = draw( - st.sampled_from(["fro", "nuc", 1, 2, -1, -2, np.inf, -np.inf]) - ) - else: - axis = axis[0] - ord_param = draw(st.sampled_from([0, 1, 2, -1, -2, np.inf, -np.inf])) - else: - ord_param = draw(st.sampled_from([0, 1, 2, -1, -2, np.inf, -np.inf])) - keepdims = draw(st.booleans()) - return dtype, x, ord_param, axis, keepdims # norm @@ -659,196 +654,242 @@ def test_jax_norm( ) -# matrix_power +# pinv @handle_frontend_test( - fn_tree="jax.numpy.linalg.matrix_power", + fn_tree="jax.numpy.linalg.pinv", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - shape=helpers.ints(min_value=1, max_value=10).map(lambda x: tuple([x, x])), - ).filter( - lambda x: "float16" not in x[0] - and "bfloat16" not in x[0] - and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon - and np.linalg.det(np.asarray(x[1][0])) != 0 + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", ), - n=helpers.ints(min_value=1, max_value=8), test_with_out=st.just(False), + rcond=st.floats(1e-5, 1e-3), ) -def test_jax_matrix_power( - *, +def test_jax_pinv( dtype_and_x, - n, - on_device, - fn_tree, frontend, + fn_tree, test_flags, backend_fw, + rcond, ): dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, - rtol=1e-01, - atol=1e-01, frontend=frontend, - test_flags=test_flags, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - a=np.asarray(x[0], dtype=dtype[0]), - n=n, + a=x[0], + rcond=rcond, + atol=1e-1, + rtol=1e-1, ) -# tensorsolve -@st.composite -def _get_solve_matrices(draw): - # batch_shape, random_size, shared - - # float16 causes a crash 
when filtering out matrices - # for which `np.linalg.cond` is large. - input_dtype_strategy = st.shared( - st.sampled_from(draw(helpers.get_dtypes("float"))).filter( - lambda x: "float16" not in x - ), - key="shared_dtype", +# qr +@handle_frontend_test( + fn_tree="jax.numpy.linalg.qr", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + min_value=2, + max_value=5, + ), + mode=st.sampled_from(("reduced", "complete")), + test_with_out=st.just(False), +) +def test_jax_qr( + *, + dtype_and_x, + mode, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, x = dtype_and_x + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + test_values=False, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=np.asarray(x[0], dtype[0]), + mode=mode, ) - input_dtype = draw(input_dtype_strategy) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + ret = [ivy_backend.to_numpy(x).astype(np.float64) for x in ret] + frontend_ret = [x.astype(np.float64) for x in frontend_ret] - dim = draw(helpers.ints(min_value=2, max_value=5)) + Q, R = ret + frontend_Q, frontend_R = frontend_ret - first_matrix = draw( - helpers.array_values( - dtype=input_dtype, - shape=(dim, dim, dim, dim), - min_value=1.2, - max_value=5, - ).filter(lambda x: np.linalg.det(x.reshape((dim**2, dim**2))) != 0) + assert_all_close( + ret_np=Q @ R, + ret_from_gt_np=frontend_Q @ frontend_R, + atol=1e-02, + backend=backend_fw, + ground_truth_backend=frontend, ) - second_matrix = draw( - helpers.array_values( - dtype=input_dtype, - shape=(dim, dim), - min_value=1.2, - max_value=3, - ) - ) - return input_dtype, first_matrix, second_matrix +# slogdet +@handle_frontend_test( + fn_tree="jax.numpy.linalg.slogdet", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + max_value=100, + min_value=-100, + shape=st.tuples( + st.shared(st.integers(1, 5), key="sq"), + st.shared(st.integers(1, 5), key="sq"), + ), + num_arrays=1, + ), + test_with_out=st.just(False), +) +def test_jax_slogdet( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-4, + rtol=1e-4, + a=x[0], + ) +# solve @handle_frontend_test( - fn_tree="jax.numpy.linalg.tensorsolve", - a_and_b=_get_solve_matrices(), + fn_tree="jax.numpy.linalg.solve", + x=helpers.get_first_solve_matrix(adjoint=False), + y=helpers.get_second_solve_matrix(), test_with_out=st.just(False), ) -def test_jax_tensorsolve( +def test_jax_solve( *, - a_and_b, + x, + y, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, y = a_and_b + input_dtype1, x1, _ = x + input_dtype2, x2 = y helpers.test_frontend_function( - input_dtypes=[input_dtype], + input_dtypes=[input_dtype1, input_dtype2], frontend=frontend, test_flags=test_flags, backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - a=x, - b=y, - atol=1e-2, - rtol=1e-2, + rtol=1e-1, + atol=1e-1, + a=x1, + b=x2, ) -# pinv +# svd @handle_frontend_test( - fn_tree="jax.numpy.linalg.pinv", + fn_tree="jax.numpy.linalg.svd", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - 
max_num_dims=5, - min_dim_size=2, - max_dim_size=5, - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + ).filter( + lambda x: "float16" not in x[0] + and "bfloat16" not in x[0] + and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon + and np.linalg.det(np.asarray(x[1][0])) != 0 ), + full_matrices=st.booleans(), + compute_uv=st.booleans(), test_with_out=st.just(False), - rcond=st.floats(1e-5, 1e-3), ) -def test_jax_pinv( +def test_jax_svd( + *, dtype_and_x, - frontend, + full_matrices, + compute_uv, + on_device, fn_tree, - test_flags, + frontend, backend_fw, - rcond, + test_flags, ): dtype, x = dtype_and_x - helpers.test_frontend_function( + x = np.asarray(x[0], dtype=dtype[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + + ret, frontend_ret = helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - a=x[0], - rcond=rcond, - atol=1e-1, - rtol=1e-1, + on_device=on_device, + test_values=False, + x=x, + full_matrices=full_matrices, + compute_uv=compute_uv, ) + if compute_uv: + with BackendHandler.update_backend(backend_fw) as ivy_backend: + ret = [ivy_backend.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] -# tensorinv -@st.composite -def _get_inv_square_matrices(draw): - dim_size = draw(helpers.ints(min_value=1, max_value=9)) - - batch_shape = draw(st.sampled_from([2, 4, 6, 8])) - - generated_shape = (dim_size,) * batch_shape - generated_ind = int(np.floor(len(generated_shape) / 2)) - - handpicked_shape, handpicked_ind = draw( - st.sampled_from([[(24, 6, 4), 1], [(8, 3, 6, 4), 2], [(6, 7, 8, 16, 21), 3]]) - ) - - shape, ind = draw( - st.sampled_from( - [(generated_shape, generated_ind), (handpicked_shape, handpicked_ind)] - ) - ) + u, s, vh = ret + frontend_u, frontend_s, frontend_vh = frontend_ret - input_dtype = draw( - helpers.get_dtypes("float", index=1, full=False).filter( - lambda x: x not in ["float16", "bfloat16"] + assert_all_close( + ret_np=u @ np.diag(s) @ vh, + ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh, + rtol=1e-2, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=frontend, ) - ) - invertible = False - while not invertible: - a = draw( - helpers.array_values( - dtype=input_dtype[0], - shape=shape, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - ).filter(lambda x: helpers.matrix_is_stable(x)) + else: + with BackendHandler.update_backend(backend_fw) as ivy_backend: + ret = ivy_backend.to_numpy(ret) + assert_all_close( + ret_np=ret, + ret_from_gt_np=np.asarray(frontend_ret[0]), + rtol=1e-2, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=frontend, ) - try: - np.linalg.tensorinv(a, ind) - invertible = True - except np.linalg.LinAlgError: - pass - - return input_dtype, a, ind @handle_frontend_test( @@ -879,62 +920,29 @@ def test_jax_tensorinv( @handle_frontend_test( - fn_tree="jax.numpy.linalg.cond", - dtype_x_p=helpers.cond_data_gen_helper(), + fn_tree="jax.numpy.linalg.tensorsolve", + a_and_b=_get_solve_matrices(), test_with_out=st.just(False), ) -def test_jax_cond( +def test_jax_tensorsolve( *, - dtype_x_p, - test_flags, + a_and_b, on_device, fn_tree, frontend, + test_flags, backend_fw, ): - dtype, x = dtype_x_p - helpers.test_frontend_function( - input_dtypes=dtype, - 
test_flags=test_flags, - rtol=1e-01, - atol=1e-01, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - p=x[1], - ) - - -# multi_dot -@handle_frontend_test( - fn_tree="jax.numpy.linalg.multi_dot", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - num_arrays=2, - shared_dtype=True, - ).filter( - lambda x: "float16" not in x[0] - and "bfloat16" not in x[0] - and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon - and np.linalg.det(np.asarray(x[1][0])) != 0 - ), - test_with_out=st.just(False), -) -def test_jax_multi_dot( - *, dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw -): - dtype, x = dtype_and_x + input_dtype, x, y = a_and_b helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=[input_dtype], frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - arrays=(x[0], x[1]), - backend_to_test=backend_fw, + a=x, + b=y, + atol=1e-2, + rtol=1e-2, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py index 430c2adf2c988..d0580c45643bb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py @@ -8,6 +8,80 @@ import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_helpers +# --- Helpers --- # +# --------------- # + + +@st.composite +def _func_and_shape_dtype_helper(draw): + # here assumption is that the input func will take the len(shape) no of parameters + def add_numbers(*args): + total = 0 + for num in args: + total += num + return total + + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ) + ) + + dtype = draw(helpers.get_dtypes("valid")) + + return add_numbers, shape, dtype[0] + + +# isin +@st.composite +def _isin_data_generation_helper(draw): + dtype_and_x = helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + ) + return draw(dtype_and_x) + + +# --- Main --- # +# ------------ # + + +# all +@handle_frontend_test( + fn_tree="jax.numpy.all", + # aliases=["jax.numpy.alltrue"], deprecated since 0.4.12. + # uncomment with multi-version testing pipeline + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + ), + test_with_out=st.just(False), +) +def test_jax_all( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + ) + + # allclose @handle_frontend_test( fn_tree="jax.numpy.allclose", @@ -43,6 +117,55 @@ def test_jax_allclose( ) +# any +@handle_frontend_test( + fn_tree="jax.numpy.any", + # aliases=["jax.numpy.sometrue"], deprecated since 0.4.12. 
+ # uncomment with multi-version testing pipeline + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + max_axes_size=1, + force_int_axis=True, + ), + keepdims=st.booleans(), + where=np_helpers.where(), + test_with_out=st.just(False), +) +def test_jax_any( + *, + dtype_x_axis, + keepdims, + where, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, x, axis = dtype_x_axis + if isinstance(axis, tuple): + axis = axis[0] + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + np_helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + axis=axis, + out=None, + keepdims=keepdims, + where=where, + ) + + @handle_frontend_test( fn_tree="jax.numpy.array_equal", dtype_and_x=helpers.dtype_and_values( @@ -110,22 +233,44 @@ def test_jax_array_equiv( ) -# isneginf +# bitwise_and +# TODO: add testing for other dtypes @handle_frontend_test( - fn_tree="jax.numpy.isneginf", + fn_tree="jax.numpy.bitwise_and", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-np.inf, - max_value=np.inf, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - allow_inf=True, + available_dtypes=helpers.get_dtypes("bool"), num_arrays=2 ), test_with_out=st.just(False), ) -def test_jax_isneginf( +def test_jax_bitwise_and( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x1=x[0], + x2=x[1], + ) + + +# bitwise_not +@handle_frontend_test( + fn_tree="jax.numpy.bitwise_not", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("bool")), + test_with_out=st.just(False), +) +def test_jax_bitwise_not( *, dtype_and_x, on_device, @@ -146,21 +291,16 @@ def test_jax_isneginf( ) -# isposinf +# bitwise_or +# TODO: add testing for other dtypes @handle_frontend_test( - fn_tree="jax.numpy.isposinf", + fn_tree="jax.numpy.bitwise_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-np.inf, - max_value=np.inf, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - allow_inf=True, + available_dtypes=helpers.get_dtypes("bool"), num_arrays=2 ), + test_with_out=st.just(False), ) -def test_jax_isposinf( +def test_jax_bitwise_or( *, dtype_and_x, on_device, @@ -177,25 +317,27 @@ def test_jax_isposinf( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) -# less +# bitwise_xor +# TODO: add testing for other dtypes @handle_frontend_test( - fn_tree="jax.numpy.less", + fn_tree="jax.numpy.bitwise_xor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("bool"), num_arrays=2 ), test_with_out=st.just(False), ) -def test_jax_less( +def test_jax_bitwise_xor( + *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): input_dtype, x = dtype_and_x @@ -205,14 +347,15 @@ def test_jax_less( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, 
x1=x[0], x2=x[1], ) -# less_equal +# equal @handle_frontend_test( - fn_tree="jax.numpy.less_equal", + fn_tree="jax.numpy.equal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -220,7 +363,7 @@ def test_jax_less( ), test_with_out=st.just(False), ) -def test_jax_less_equal( +def test_jax_equal( dtype_and_x, frontend, test_flags, @@ -297,34 +440,26 @@ def test_jax_greater_equal( ) -# isnan +# invert @handle_frontend_test( - fn_tree="jax.numpy.isnan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-np.inf, - max_value=np.inf, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - allow_inf=True, + fn_tree="jax.numpy.invert", + dtypes_values=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), ), ) -def test_jax_isnan( - *, - dtype_and_x, +def test_jax_invert( + dtypes_values, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + x_dtypes, x = dtypes_values + np_helpers.test_frontend_function( + input_dtypes=x_dtypes, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, @@ -332,50 +467,55 @@ def test_jax_isnan( ) -# equal +# isclose @handle_frontend_test( - fn_tree="jax.numpy.equal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + fn_tree="jax.numpy.isclose", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - shared_dtype=True, ), + equal_nan=st.booleans(), test_with_out=st.just(False), ) -def test_jax_equal( - dtype_and_x, +def test_jax_isclose( + *, + dtype_and_input, + equal_nan, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - x1=x[0], - x2=x[1], + on_device=on_device, + a=input[0], + b=input[1], + equal_nan=equal_nan, ) -# not_equal +# iscomplex @handle_frontend_test( - fn_tree="jax.numpy.not_equal", + fn_tree="jax.numpy.iscomplex", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("real_and_complex"), min_num_dims=1 ), test_with_out=st.just(False), ) -def test_jax_not_equal( +def test_jax_iscomplex( dtype_and_x, frontend, - test_flags, + on_device, + *, fn_tree, + test_flags, backend_fw, ): input_dtype, x = dtype_and_x @@ -385,27 +525,25 @@ def test_jax_not_equal( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - x1=x[0], - x2=x[1], + on_device=on_device, + x=x[0], ) -# all +# iscomplexobj @handle_frontend_test( - fn_tree="jax.numpy.all", - # aliases=["jax.numpy.alltrue"], deprecated since 0.4.12. 
- # uncomment with multi-version testing pipeline + fn_tree="jax.numpy.iscomplexobj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + available_dtypes=helpers.get_dtypes("valid"), ), test_with_out=st.just(False), ) -def test_jax_all( - *, +def test_jax_iscomplexobj( dtype_and_x, + frontend, on_device, + *, fn_tree, - frontend, test_flags, backend_fw, ): @@ -417,20 +555,19 @@ def test_jax_all( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], + x=x[0], ) -# bitwise_and -# TODO: add testing for other dtypes +# isfinite @handle_frontend_test( - fn_tree="jax.numpy.bitwise_and", + fn_tree="jax.numpy.isfinite", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), num_arrays=2 + available_dtypes=helpers.get_dtypes("numeric"), allow_nan=True ), test_with_out=st.just(False), ) -def test_jax_bitwise_and( +def test_jax_isfinite( *, dtype_and_x, on_device, @@ -447,48 +584,51 @@ def test_jax_bitwise_and( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], ) -# bitwise_not @handle_frontend_test( - fn_tree="jax.numpy.bitwise_not", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("bool")), + fn_tree="jax.numpy.isin", + assume_unique_and_dtype_and_x=_isin_data_generation_helper(), + invert=st.booleans(), test_with_out=st.just(False), ) -def test_jax_bitwise_not( +def test_jax_isin( *, - dtype_and_x, + assume_unique_and_dtype_and_x, + invert, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + x_and_dtype = assume_unique_and_dtype_and_x + dtypes, values = x_and_dtype + elements, test_elements = values helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=dtypes, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + element=elements, + test_elements=test_elements, + invert=invert, + backend_to_test=backend_fw, ) -# bitwise_or -# TODO: add testing for other dtypes +# isinf @handle_frontend_test( - fn_tree="jax.numpy.bitwise_or", + fn_tree="jax.numpy.isinf", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), num_arrays=2 + available_dtypes=helpers.get_dtypes("numeric"), allow_inf=True ), test_with_out=st.just(False), ) -def test_jax_bitwise_or( +def test_jax_isinf( *, dtype_and_x, on_device, @@ -505,21 +645,25 @@ def test_jax_bitwise_or( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], ) -# bitwise_xor -# TODO: add testing for other dtypes +# isnan @handle_frontend_test( - fn_tree="jax.numpy.bitwise_xor", + fn_tree="jax.numpy.isnan", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), num_arrays=2 + available_dtypes=helpers.get_dtypes("float"), + min_value=-np.inf, + max_value=np.inf, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + allow_inf=True, ), - test_with_out=st.just(False), ) -def test_jax_bitwise_xor( +def test_jax_isnan( *, dtype_and_x, on_device, @@ -536,109 +680,109 @@ def test_jax_bitwise_xor( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], ) -# any +# isneginf @handle_frontend_test( - fn_tree="jax.numpy.any", - # aliases=["jax.numpy.sometrue"], deprecated since 0.4.12. 
- # uncomment with multi-version testing pipeline - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - max_axes_size=1, - force_int_axis=True, + fn_tree="jax.numpy.isneginf", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-np.inf, + max_value=np.inf, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + allow_inf=True, ), - keepdims=st.booleans(), - where=np_helpers.where(), test_with_out=st.just(False), ) -def test_jax_any( +def test_jax_isneginf( *, - dtype_x_axis, - keepdims, - where, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, x, axis = dtype_x_axis - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, - out=None, - keepdims=keepdims, - where=where, + x=x[0], ) -# logical_and +# isposinf @handle_frontend_test( - fn_tree="jax.numpy.logical_and", - dtypes_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), - num_arrays=2, + fn_tree="jax.numpy.isposinf", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-np.inf, + max_value=np.inf, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + allow_inf=True, ), ) -def test_jax_logical_and( - dtypes_values, +def test_jax_isposinf( + *, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - x_dtypes, x = dtypes_values - np_helpers.test_frontend_function( - input_dtypes=x_dtypes, - frontend=frontend, + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], ) -# invert +# isreal @handle_frontend_test( - fn_tree="jax.numpy.invert", - dtypes_values=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + fn_tree="jax.numpy.isreal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-np.inf, + max_value=np.inf, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + allow_inf=True, ), ) -def test_jax_invert( - dtypes_values, +def test_jax_isreal( + *, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - x_dtypes, x = dtypes_values - np_helpers.test_frontend_function( - input_dtypes=x_dtypes, - frontend=frontend, + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, @@ -646,20 +790,19 @@ def test_jax_invert( ) -# isfinite @handle_frontend_test( - fn_tree="jax.numpy.isfinite", + fn_tree="jax.numpy.isrealobj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), allow_nan=True + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1 ), test_with_out=st.just(False), ) -def test_jax_isfinite( - *, +def test_jax_isrealobj( dtype_and_x, + frontend, 
on_device, + *, fn_tree, - frontend, test_flags, backend_fw, ): @@ -675,58 +818,43 @@ def test_jax_isfinite( ) -# isin -@st.composite -def _isin_data_generation_helper(draw): - dtype_and_x = helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - ) - return draw(dtype_and_x) - - +# isscalar @handle_frontend_test( - fn_tree="jax.numpy.isin", - assume_unique_and_dtype_and_x=_isin_data_generation_helper(), - invert=st.booleans(), - test_with_out=st.just(False), + fn_tree="jax.numpy.isscalar", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") + ), ) -def test_jax_isin( +def test_jax_isscalar( *, - assume_unique_and_dtype_and_x, - invert, + dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): - x_and_dtype = assume_unique_and_dtype_and_x - dtypes, values = x_and_dtype - elements, test_elements = values - helpers.test_frontend_function( - input_dtypes=dtypes, - frontend=frontend, + x_dtypes, x = dtype_and_x + np_helpers.test_frontend_function( + input_dtypes=x_dtypes, + frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - element=elements, - test_elements=test_elements, - invert=invert, - backend_to_test=backend_fw, + x=x[0], ) -# isinf +# left_shift @handle_frontend_test( - fn_tree="jax.numpy.isinf", + fn_tree="jax.numpy.left_shift", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), allow_inf=True + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, ), - test_with_out=st.just(False), ) -def test_jax_isinf( +def test_jax_left_shift( *, dtype_and_x, on_device, @@ -743,81 +871,78 @@ def test_jax_isinf( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) -# isclose +# less @handle_frontend_test( - fn_tree="jax.numpy.isclose", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="jax.numpy.less", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + shared_dtype=True, ), - equal_nan=st.booleans(), test_with_out=st.just(False), ) -def test_jax_isclose( - *, - dtype_and_input, - equal_nan, - on_device, - fn_tree, +def test_jax_less( + dtype_and_x, frontend, test_flags, + fn_tree, backend_fw, ): - input_dtype, input = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - a=input[0], - b=input[1], - equal_nan=equal_nan, + x1=x[0], + x2=x[1], ) -# logical_not +# less_equal @handle_frontend_test( - fn_tree="jax.numpy.logical_not", - dtypes_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), - num_arrays=1, + fn_tree="jax.numpy.less_equal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), + test_with_out=st.just(False), ) -def test_jax_logical_not( - dtypes_values, - on_device, - fn_tree, +def test_jax_less_equal( + dtype_and_x, frontend, test_flags, + fn_tree, backend_fw, ): - x_dtypes, x = dtypes_values - np_helpers.test_frontend_function( - input_dtypes=x_dtypes, - frontend=frontend, + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, 
fn_tree=fn_tree, - on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) -# logical_or +# logical_and @handle_frontend_test( - fn_tree="jax.numpy.logical_or", + fn_tree="jax.numpy.logical_and", dtypes_values=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("bool"), num_arrays=2, ), ) -def test_jax_logical_or( +def test_jax_logical_and( dtypes_values, on_device, fn_tree, @@ -829,32 +954,32 @@ def test_jax_logical_or( np_helpers.test_frontend_function( input_dtypes=x_dtypes, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x1=x[0], x2=x[1], - backend_to_test=backend_fw, ) -# isscalar +# logical_not @handle_frontend_test( - fn_tree="jax.numpy.isscalar", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + fn_tree="jax.numpy.logical_not", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("bool"), + num_arrays=1, ), ) -def test_jax_isscalar( - *, - dtype_and_x, +def test_jax_logical_not( + dtypes_values, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): - x_dtypes, x = dtype_and_x + x_dtypes, x = dtypes_values np_helpers.test_frontend_function( input_dtypes=x_dtypes, frontend=frontend, @@ -866,68 +991,32 @@ def test_jax_isscalar( ) -# left_shift +# logical_or @handle_frontend_test( - fn_tree="jax.numpy.left_shift", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + fn_tree="jax.numpy.logical_or", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("bool"), num_arrays=2, ), ) -def test_jax_left_shift( - *, - dtype_and_x, +def test_jax_logical_or( + dtypes_values, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + x_dtypes, x = dtypes_values + np_helpers.test_frontend_function( + input_dtypes=x_dtypes, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x1=x[0], x2=x[1], - ) - - -# isreal -@handle_frontend_test( - fn_tree="jax.numpy.isreal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-np.inf, - max_value=np.inf, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - allow_inf=True, - ), -) -def test_jax_isreal( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], ) @@ -960,56 +1049,22 @@ def test_jax_logical_xor( ) +# not_equal @handle_frontend_test( - fn_tree="jax.numpy.right_shift", + fn_tree="jax.numpy.not_equal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_right_shift( - *, +def test_jax_not_equal( dtype_and_x, frontend, test_flags, fn_tree, backend_fw, - on_device, -): - dtype, xs = dtype_and_x - - xs[1] = np.asarray(np.clip(xs[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) - - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x1=xs[0], - x2=xs[1], - ) - - -# iscomplex 
-@handle_frontend_test( - fn_tree="jax.numpy.iscomplex", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("real_and_complex"), min_num_dims=1 - ), - test_with_out=st.just(False), -) -def test_jax_iscomplex( - dtype_and_x, - frontend, - on_device, - *, - fn_tree, - test_flags, - backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1018,28 +1073,28 @@ def test_jax_iscomplex( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) +# fromfunction @handle_frontend_test( - fn_tree="jax.numpy.isrealobj", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1 - ), + fn_tree="jax.numpy.fromfunction", + input_dtype=helpers.get_dtypes("valid"), + function_and_shape_and_dtype=_func_and_shape_dtype_helper(), test_with_out=st.just(False), ) -def test_jax_isrealobj( - dtype_and_x, +def test_jax_numpy_fromfunction( + input_dtype, + function_and_shape_and_dtype, + backend_fw, frontend, on_device, - *, fn_tree, test_flags, - backend_fw, ): - input_dtype, x = dtype_and_x + function, shape, dtype = function_and_shape_and_dtype helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1047,20 +1102,29 @@ def test_jax_isrealobj( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + function=function, + shape=shape, + dtype=dtype, ) -# iscomplexobj +# packbits @handle_frontend_test( - fn_tree="jax.numpy.iscomplexobj", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="jax.numpy.packbits", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("integer"), + min_num_dims=1, + min_dim_size=1, + valid_axis=True, + max_axes_size=1, + force_int_axis=True, ), test_with_out=st.just(False), + bitorder=st.sampled_from(["big", "little"]), ) -def test_jax_iscomplexobj( - dtype_and_x, +def test_jax_numpy_packbits( + dtype_x_axis, + bitorder, frontend, on_device, *, @@ -1068,15 +1132,17 @@ def test_jax_iscomplexobj( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + axis=axis, + bitorder=bitorder, + backend_to_test=backend_fw, ) @@ -1112,93 +1178,35 @@ def test_jax_numpy_setxor1d( ) -# packbits @handle_frontend_test( - fn_tree="jax.numpy.packbits", - dtype_x_axis=helpers.dtype_values_axis( + fn_tree="jax.numpy.right_shift", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), - min_num_dims=1, - min_dim_size=1, - valid_axis=True, - max_axes_size=1, - force_int_axis=True, + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), - bitorder=st.sampled_from(["big", "little"]), ) -def test_jax_numpy_packbits( - dtype_x_axis, - bitorder, - frontend, - on_device, +def test_jax_right_shift( *, - fn_tree, + dtype_and_x, + frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - bitorder=bitorder, - backend_to_test=backend_fw, - ) - - -@st.composite -def _func_and_shape_dtype_helper(draw): - # here assumption is that the input func will take the len(shape) no of parameters 
- def add_numbers(*args): - total = 0 - for num in args: - total += num - return total - - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ) - ) - - dtype = draw(helpers.get_dtypes("valid")) - - return add_numbers, shape, dtype[0] + dtype, xs = dtype_and_x + xs[1] = np.asarray(np.clip(xs[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) -# fromfunction -@handle_frontend_test( - fn_tree="jax.numpy.fromfunction", - input_dtype=helpers.get_dtypes("valid"), - function_and_shape_and_dtype=_func_and_shape_dtype_helper(), - test_with_out=st.just(False), -) -def test_jax_numpy_fromfunction( - input_dtype, - function_and_shape_and_dtype, - backend_fw, - frontend, - on_device, - fn_tree, - test_flags, -): - function, shape, dtype = function_and_shape_and_dtype helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - function=function, - shape=shape, - dtype=dtype, + x1=xs[0], + x2=xs[1], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py index ad5823f5640cc..b4a95c76efacb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_manipulations.py @@ -18,6 +18,45 @@ ) +# --- Helpers --- # +# --------------- # + + +# concatenate +@st.composite +def _arrays_idx_n_dtypes(draw): + num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) + num_arrays = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") + ) + common_shape = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_dims - 1, + ) + ) + unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) + unique_dims = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_arrays, + ) + ) + xs = list() + input_dtypes = draw( + helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("valid"))) + ) + for ud, dt in zip(unique_dims, input_dtypes): + x = draw( + helpers.array_values( + shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], + dtype=dt, + ) + ) + xs.append(x) + return xs, input_dtypes, unique_idx + + @st.composite def _get_clip_inputs(draw): shape = draw( @@ -60,148 +99,82 @@ def _get_clip_inputs(draw): return x_dtype, x, min, max -# clip -@handle_frontend_test( - fn_tree="jax.numpy.clip", - input_and_ranges=_get_clip_inputs(), -) -def test_jax_clip( - *, - input_and_ranges, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - x_dtype, x, min, max = input_and_ranges - helpers.test_frontend_function( - input_dtypes=x_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - a_min=min, - a_max=max, +# block +@st.composite +def _get_input_and_block(draw): + shapes = draw( + st.lists( + helpers.get_shape( + min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ), + min_size=2, + max_size=10, + ) + ) + x_dtypes, xs = zip( + *[ + draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + shape=shape, + ) + ) + for shape in shapes + ] ) + return x_dtypes, xs -# concatenate +# broadcast_to 
@st.composite -def _arrays_idx_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) - num_arrays = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") - ) - common_shape = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_dims - 1, +def _get_input_and_broadcast_shape(draw): + dim1 = draw(helpers.ints(min_value=2, max_value=5)) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + shape=(dim1,), ) ) - unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) - unique_dims = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_arrays, + broadcast_dim = draw(helpers.ints(min_value=1, max_value=3)) + shape = () + for _ in range(broadcast_dim): + shape += (draw(helpers.ints(min_value=1, max_value=dim1)),) + shape += (dim1,) + return x_dtype, x, shape + + +# resize +@st.composite +def _get_input_and_new_shape(draw): + shape = draw( + helpers.get_shape( + min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 ) ) - xs = list() - input_dtypes = draw( - helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("valid"))) + new_shape = draw( + helpers.get_shape( + min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ).filter(lambda x: np.prod(x) == np.prod(shape)) ) - for ud, dt in zip(unique_dims, input_dtypes): - x = draw( - helpers.array_values( - shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], - dtype=dt, - ) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + shape=shape, ) - xs.append(x) - return xs, input_dtypes, unique_idx - - -@handle_frontend_test( - fn_tree="jax.numpy.concatenate", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), - test_with_out=st.just(False), -) -def test_jax_concat( - *, - xs_n_input_dtypes_n_unique_idx, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - arrays=xs, - axis=unique_idx, - ) - - -# repeat -@handle_frontend_test( - fn_tree="jax.numpy.repeat", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - axis=st.shared( - st.one_of( - st.none(), - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - max_size=1, - ), - ), - key="axis", - ), - repeat=st.one_of(st.integers(1, 10), _repeat_helper()), - test_with_out=st.just(False), -) -def test_jax_repeat( - *, - dtype_value, - axis, - repeat, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - value_dtype, value = dtype_value - - if not isinstance(repeat, int): - repeat_dtype, repeat_list = repeat - repeat = repeat_list[0] - value_dtype += repeat_dtype - - if not isinstance(axis, int) and axis is not None: - axis = axis[0] - - helpers.test_frontend_function( - input_dtypes=value_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=value[0], - repeats=repeat, - axis=axis, ) + return 
x_dtype, x, new_shape # reshape @@ -226,226 +199,145 @@ def _get_input_and_reshape(draw): return x_dtype, x, new_shape -@handle_frontend_test( - fn_tree="jax.numpy.reshape", - input_x_shape=_get_input_and_reshape(), - order=st.sampled_from(["C", "F"]), - test_with_out=st.just(False), -) -def test_jax_reshape( - *, - input_x_shape, - order, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - x_dtype, x, shape = input_x_shape - helpers.test_frontend_function( - input_dtypes=x_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - newshape=shape, - order=order, +# swapaxes +@st.composite +def _get_input_and_two_swapabble_axes(draw): + x_dtype, x, x_shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ret_shape=True, + min_num_dims=1, + max_num_dims=10, + ) ) - -# ravel -@handle_frontend_test( - fn_tree="jax.numpy.ravel", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - shape=helpers.get_shape( - min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ), - ), - order=st.sampled_from(["C", "F"]), - test_with_out=st.just(False), -) -def test_jax_ravel( - *, - dtype_and_values, - order, - on_device, - backend_fw, - fn_tree, - frontend, - test_flags, -): - input_dtypes, x = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - order=order, + axis1 = draw( + helpers.ints( + min_value=-1 * len(x_shape), + max_value=len(x_shape) - 1, + ) ) + axis2 = draw( + helpers.ints( + min_value=-1 * len(x_shape), + max_value=len(x_shape) - 1, + ) + ) + return x_dtype, x, axis1, axis2 -# resize +# pad @st.composite -def _get_input_and_new_shape(draw): - shape = draw( - helpers.get_shape( - min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 +def _pad_helper(draw): + mode = draw( + st.sampled_from( + [ + "constant", + "edge", + "linear_ramp", + "maximum", + "mean", + "median", + "minimum", + "reflect", + "symmetric", + "wrap", + ] ) ) - new_shape = draw( - helpers.get_shape( - min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ).filter(lambda x: np.prod(x) == np.prod(shape)) - ) - x_dtype, x = draw( + if mode == "median": + dtypes = "float" + else: + dtypes = "numeric" + dtype, input, shape = draw( helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - shape=shape, - ) + available_dtypes=helpers.get_dtypes(dtypes), + ret_shape=True, + min_num_dims=1, + min_value=-100, + max_value=100, + ).filter( + lambda x: x[0][0] not in ["float16", "bfloat16", "complex64", "complex128"] + ), ) - return x_dtype, x, new_shape + ndim = len(shape) + pad_width = draw(_st_tuples_or_int(ndim, min_val=0)) + kwargs = {} + if mode == "reflect" or mode == "symmetric": + kwargs["reflect_type"] = draw(st.sampled_from(["even", "odd"])) + if mode in ["maximum", "mean", "median", "minimum"]: + kwargs["stat_length"] = draw(_st_tuples_or_int(ndim, min_val=2)) + if mode in ["linear_ramp"]: + kwargs["end_values"] = draw(_st_tuples_or_int(ndim)) + if mode == "constant": + kwargs["constant_values"] = draw(_st_tuples_or_int(ndim)) + return dtype, input[0], pad_width, kwargs, mode -@handle_frontend_test( - 
fn_tree="jax.numpy.resize", - input_x_shape=_get_input_and_new_shape(), - test_with_out=st.just(True), -) -def test_jax_resize( - *, - input_x_shape, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - x_dtype, x, new_shape = input_x_shape - helpers.test_frontend_function( - input_dtypes=x_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - new_shape=new_shape, - ) +# TODO: uncomment when block is reimplemented +# @handle_frontend_test( +# fn_tree="jax.numpy.block", +# input_x_shape=_get_input_and_block(), +# test_with_out=st.just(False), +# ) +# def test_jax_block( +# *, +# input_x_shape, +# on_device, +# fn_tree, +# frontend, +# test_flags, +# ): +# x_dtypes, xs = input_x_shape +# helpers.test_frontend_function( +# input_dtypes=x_dtypes, +# frontend=frontend, +# test_flags=test_flags, +# fn_tree=fn_tree, +# on_device=on_device, +# arrays=xs, +# ) -# moveaxis -@handle_frontend_test( - fn_tree="jax.numpy.moveaxis", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-100, - max_value=100, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - ), - source=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - destination=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - test_with_out=st.just(False), -) -def test_jax_moveaxis( - *, - dtype_and_a, - source, - destination, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, a = dtype_and_a - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=a[0], - source=source, - destination=destination, - ) +@st.composite +def _squeeze_helper(draw): + shape = draw(st.shared(helpers.get_shape(), key="shape")) + valid_axes = [idx for idx in range(len(shape)) if shape[idx] == 1] + [None] + return draw(st.sampled_from(valid_axes)) -# flipud +# --- Main --- # +# ------------ # + + +# append @handle_frontend_test( - fn_tree="jax.numpy.flipud", - dtype_and_m=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, + fn_tree="jax.numpy.append", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shape=helpers.get_shape( + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + shared_dtype=True, + valid_axis=True, + allow_neg_axes=True, + force_int_axis=True, ), test_with_out=st.just(False), ) -def test_jax_flipud( - *, - dtype_and_m, +def test_jax_append( + dtype_values_axis, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, m = dtype_and_m + input_dtype, values, axis = dtype_values_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -453,61 +345,35 @@ def test_jax_flipud( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=m[0], - ) - - -# 
transpose -@handle_frontend_test( - fn_tree="jax.numpy.transpose", - array_and_axes=np_frontend_helpers._array_and_axes_permute_helper( - min_num_dims=0, - max_num_dims=5, - min_dim_size=0, - max_dim_size=10, - ), - test_with_out=st.just(False), -) -def test_jax_transpose( - *, - array_and_axes, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - array, dtype, axes = array_and_axes - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=array, - axes=axes, + arr=values[0], + values=values[1], + axis=axis, ) -# flip +# array_split @handle_frontend_test( - fn_tree="jax.numpy.flip", + fn_tree="jax.numpy.array_split", dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("integer"), shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - min_size=1, - max_size=1, - force_int=True, + indices_or_sections=_get_splits( + min_num_dims=1, allow_none=False, is_mod_split=True + ), + axis=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", ), test_with_out=st.just(False), ) -def test_jax_flip( +def test_jax_array_split( *, dtype_value, + indices_or_sections, axis, on_device, fn_tree, @@ -515,38 +381,44 @@ def test_jax_flip( backend_fw, test_flags, ): - dtype, value = dtype_value + input_dtype, value = dtype_value + assume(isinstance(indices_or_sections, int)) helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=value[0], + ary=value[0], + indices_or_sections=indices_or_sections, axis=axis, ) -# fliplr +# atleast_1d @handle_frontend_test( - fn_tree="jax.numpy.fliplr", - dtype_and_m=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, + fn_tree="jax.numpy.atleast_1d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=helpers.ints(min_value=1, max_value=10), ), test_with_out=st.just(False), ) -def test_jax_fliplr( +def test_jax_atleast_1d( *, - dtype_and_m, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, m = dtype_and_m + input_dtype, arrays = dtype_and_x + arys = {} + for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): + arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) + test_flags.num_positional_args = len(arys) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -554,33 +426,33 @@ def test_jax_fliplr( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=m[0], + **arys, ) -# expand_dims +# atleast_2d @handle_frontend_test( - fn_tree="jax.numpy.expand_dims", - dtype_x_axis=helpers.dtype_values_axis( + fn_tree="jax.numpy.atleast_2d", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - force_int_axis=True, - valid_axis=True, + num_arrays=helpers.ints(min_value=1, max_value=10), ), + test_with_out=st.just(False), ) -def test_jax_expand_dims( +def test_jax_atleast_2d( *, - dtype_x_axis, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, 
test_flags, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, arrays = dtype_and_x + arys = {} + for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): + arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) + test_flags.num_positional_args = len(arys) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -588,35 +460,33 @@ def test_jax_expand_dims( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + **arys, ) -# stack +# atleast_3d @handle_frontend_test( - fn_tree="jax.numpy.stack", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays"), - shape=helpers.get_shape(min_num_dims=1), - shared_dtype=True, - valid_axis=True, - allow_neg_axes=True, - force_int_axis=True, + fn_tree="jax.numpy.atleast_3d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=helpers.ints(min_value=1, max_value=10), ), - dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), ) -def test_jax_stack( - dtype_values_axis, - dtype, +def test_jax_atleast_3d( + *, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, values, axis = dtype_values_axis + input_dtype, arrays = dtype_and_x + arys = {} + for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): + arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) + test_flags.num_positional_args = len(arys) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -624,44 +494,31 @@ def test_jax_stack( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - arrays=values, - axis=axis, + **arys, ) -# take +# blackman @handle_frontend_test( - fn_tree="jax.numpy.take", - dtype_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - indices_same_dims=True, - ), + fn_tree="jax.numpy.blackman", + m=helpers.ints(min_value=0, max_value=20), ) -def test_jax_take( - *, - dtype_indices_axis, - on_device, - fn_tree, +def test_jax_blackman( + m, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): - input_dtypes, value, indices, axis, _ = dtype_indices_axis helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, + input_dtypes=["int64"], frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=value, - indices=indices, - axis=axis, + M=m, ) @@ -733,28 +590,6 @@ def test_jax_broadcast_shapes( assert ret == frontend_ret -# broadcast_to -@st.composite -def _get_input_and_broadcast_shape(draw): - dim1 = draw(helpers.ints(min_value=2, max_value=5)) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - shape=(dim1,), - ) - ) - broadcast_dim = draw(helpers.ints(min_value=1, max_value=3)) - shape = () - for _ in range(broadcast_dim): - shape += (draw(helpers.ints(min_value=1, max_value=dim1)),) - shape += (dim1,) - return x_dtype, x, shape - - @handle_frontend_test( fn_tree="jax.numpy.broadcast_to", input_x_broadcast=_get_input_and_broadcast_shape(), @@ -782,125 +617,117 @@ def test_jax_broadcast_to( ) -# append +# clip @handle_frontend_test( - fn_tree="jax.numpy.append", - 
dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shape=helpers.get_shape( - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - shared_dtype=True, - valid_axis=True, - allow_neg_axes=True, - force_int_axis=True, - ), - test_with_out=st.just(False), + fn_tree="jax.numpy.clip", + input_and_ranges=_get_clip_inputs(), ) -def test_jax_append( - dtype_values_axis, +def test_jax_clip( + *, + input_and_ranges, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, values, axis = dtype_values_axis + x_dtype, x, min, max = input_and_ranges helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=x_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - arr=values[0], - values=values[1], - axis=axis, + a=x[0], + a_min=min, + a_max=max, ) -# swapaxes -@st.composite -def _get_input_and_two_swapabble_axes(draw): - x_dtype, x, x_shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ret_shape=True, - min_num_dims=1, - max_num_dims=10, - ) - ) - - axis1 = draw( - helpers.ints( - min_value=-1 * len(x_shape), - max_value=len(x_shape) - 1, - ) - ) - axis2 = draw( - helpers.ints( - min_value=-1 * len(x_shape), - max_value=len(x_shape) - 1, - ) +# column_stack +@handle_frontend_test( + fn_tree="jax.numpy.column_stack", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + ), + factor=helpers.ints(min_value=2, max_value=6), +) +def test_jax_column_stack( + dtype_and_x, + factor, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + dtype, x = dtype_and_x + ys = [x[0]] + for i in range(factor): + ys += [x[0]] + helpers.test_frontend_function( + input_dtypes=[dtype[0]] * (factor + 1), + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tup=ys, ) - return x_dtype, x, axis1, axis2 @handle_frontend_test( - fn_tree="jax.numpy.swapaxes", - input_x_axis1_axis2=_get_input_and_two_swapabble_axes(), + fn_tree="jax.numpy.concatenate", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), test_with_out=st.just(False), ) -def test_jax_swapaxes( +def test_jax_concat( *, - input_x_axis1_axis2, - test_flags, + xs_n_input_dtypes_n_unique_idx, on_device, fn_tree, frontend, backend_fw, + test_flags, ): - x_dtype, x, axis1, axis2 = input_x_axis1_axis2 + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis1=axis1, - axis2=axis2, + arrays=xs, + axis=unique_idx, ) -# atleast_3d +# dsplit @handle_frontend_test( - fn_tree="jax.numpy.atleast_3d", - dtype_and_x=helpers.dtype_and_values( + fn_tree="jax.numpy.dsplit", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=1, max_value=10), + shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=3, axis=2, allow_none=False, is_mod_split=True ), test_with_out=st.just(False), ) -def test_jax_atleast_3d( +def test_jax_dsplit( *, - dtype_and_x, + dtype_value, + indices_or_sections, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, arrays = dtype_and_x - arys = {} - for 
i, (array, idtype) in enumerate(zip(arrays, input_dtype)): - arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) - test_flags.num_positional_args = len(arys) + input_dtype, value = dtype_value helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -908,20 +735,25 @@ def test_jax_atleast_3d( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **arys, + ary=value[0], + indices_or_sections=indices_or_sections, ) -# atleast_2d +# dstack @handle_frontend_test( - fn_tree="jax.numpy.atleast_2d", + fn_tree="jax.numpy.dstack", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + shared_dtype=True, num_arrays=helpers.ints(min_value=1, max_value=10), + shape=helpers.get_shape( + min_num_dims=1, + ), ), test_with_out=st.just(False), ) -def test_jax_atleast_2d( +def test_jax_dstack( *, dtype_and_x, on_device, @@ -930,11 +762,7 @@ def test_jax_atleast_2d( backend_fw, test_flags, ): - input_dtype, arrays = dtype_and_x - arys = {} - for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): - arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) - test_flags.num_positional_args = len(arys) + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -942,33 +770,33 @@ def test_jax_atleast_2d( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **arys, + tup=x, ) -# atleast_1d +# expand_dims @handle_frontend_test( - fn_tree="jax.numpy.atleast_1d", - dtype_and_x=helpers.dtype_and_values( + fn_tree="jax.numpy.expand_dims", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=1, max_value=10), + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + force_int_axis=True, + valid_axis=True, ), - test_with_out=st.just(False), ) -def test_jax_atleast_1d( +def test_jax_expand_dims( *, - dtype_and_x, + dtype_x_axis, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, arrays = dtype_and_x - arys = {} - for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): - arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) - test_flags.num_positional_args = len(arys) + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -976,34 +804,37 @@ def test_jax_atleast_1d( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **arys, + a=x[0], + axis=axis, ) -# tril +# flip @handle_frontend_test( - fn_tree="jax.numpy.tril", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, + fn_tree="jax.numpy.flip", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + min_size=1, + max_size=1, + force_int=True, ), - k=helpers.ints(min_value=-10, max_value=10), test_with_out=st.just(False), ) -def test_jax_tril( +def test_jax_flip( *, - dtype_and_x, - k, + dtype_value, + axis, on_device, fn_tree, frontend, backend_fw, test_flags, ): - dtype, x = dtype_and_x + dtype, value = dtype_value helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -1011,196 +842,140 @@ def test_jax_tril( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - 
m=x[0], - k=k, + m=value[0], + axis=axis, ) -# trim_zeros +# fliplr @handle_frontend_test( - fn_tree="jax.numpy.trim_zeros", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, max_num_dims=1 + fn_tree="jax.numpy.fliplr", + dtype_and_m=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, ), - trim=st.sampled_from(["f", "b", "fb"]), + test_with_out=st.just(False), ) -def test_jax_numpy_trim_zeros( - frontend, - on_device, +def test_jax_fliplr( *, - dtype_and_x, - backend_fw, - trim, + dtype_and_m, + on_device, fn_tree, + frontend, + backend_fw, test_flags, ): - dtype, x = dtype_and_x + input_dtype, m = dtype_and_m helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - backend_to_test=backend_fw, - filt=x[0], - trim=trim, - ) - - -# block -@st.composite -def _get_input_and_block(draw): - shapes = draw( - st.lists( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ), - min_size=2, - max_size=10, - ) - ) - x_dtypes, xs = zip( - *[ - draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - shape=shape, - ) - ) - for shape in shapes - ] + m=m[0], ) - return x_dtypes, xs - - -# TODO: uncomment when block is reimplemented -# @handle_frontend_test( -# fn_tree="jax.numpy.block", -# input_x_shape=_get_input_and_block(), -# test_with_out=st.just(False), -# ) -# def test_jax_block( -# *, -# input_x_shape, -# on_device, -# fn_tree, -# frontend, -# test_flags, -# ): -# x_dtypes, xs = input_x_shape -# helpers.test_frontend_function( -# input_dtypes=x_dtypes, -# frontend=frontend, -# test_flags=test_flags, -# fn_tree=fn_tree, -# on_device=on_device, -# arrays=xs, -# ) - - -@st.composite -def _squeeze_helper(draw): - shape = draw(st.shared(helpers.get_shape(), key="shape")) - valid_axes = [idx for idx in range(len(shape)) if shape[idx] == 1] + [None] - return draw(st.sampled_from(valid_axes)) -# squeeze +# flipud @handle_frontend_test( - fn_tree="jax.numpy.squeeze", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), + fn_tree="jax.numpy.flipud", + dtype_and_m=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), - axis=_squeeze_helper(), test_with_out=st.just(False), ) -def test_jax_squeeze( +def test_jax_flipud( *, - dtype_and_values, - axis, + dtype_and_m, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, values = dtype_and_values + input_dtype, m = dtype_and_m + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + m=m[0], + ) + + +# hamming +@handle_frontend_test( + fn_tree="jax.numpy.hamming", + m=helpers.ints(min_value=0, max_value=20), +) +def test_jax_hamming( + m, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=["int64"], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=values[0], - axis=axis, + M=m, ) -# 
rot90 +# hanning @handle_frontend_test( - fn_tree="jax.numpy.rot90", - dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - test_with_out=st.just(False), + fn_tree="jax.numpy.hanning", + m=helpers.ints(min_value=0, max_value=20), ) -def test_jax_rot90( - *, - dtype_m_k_axes, - on_device, - fn_tree, +def test_jax_hanning( + m, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): - input_dtype, m, k, axes = dtype_m_k_axes helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=["int64"], frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=m, - k=k, - axes=tuple(axes), + M=m, ) -# split +# hsplit @handle_frontend_test( - fn_tree="jax.numpy.split", + fn_tree="jax.numpy.hsplit", dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), ), indices_or_sections=_get_splits( - min_num_dims=1, allow_none=False, is_mod_split=True - ), - axis=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", + min_num_dims=2, axis=1, allow_none=False, is_mod_split=True ), test_with_out=st.just(False), ) -def test_jax_split( +def test_jax_hsplit( *, dtype_value, indices_or_sections, - axis, on_device, fn_tree, frontend, @@ -1217,42 +992,97 @@ def test_jax_split( on_device=on_device, ary=value[0], indices_or_sections=indices_or_sections, - axis=axis, ) -# array_split +# kaiser @handle_frontend_test( - fn_tree="jax.numpy.array_split", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + fn_tree="jax.numpy.kaiser", + m=helpers.ints(min_value=0, max_value=100), + beta=helpers.floats(min_value=-10, max_value=10), +) +def test_jax_kaiser( + m, + beta, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + helpers.test_frontend_function( + input_dtypes=["int64", "float64"], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + M=m, + beta=beta, + ) + + +# moveaxis +@handle_frontend_test( + fn_tree="jax.numpy.moveaxis", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-100, + max_value=100, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), ), - indices_or_sections=_get_splits( - min_num_dims=1, allow_none=False, is_mod_split=True + source=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, ), - axis=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, + destination=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", ), - key="target_axis", + min_size=1, + force_int=True, ), test_with_out=st.just(False), ) -def 
test_jax_array_split( +def test_jax_moveaxis( *, - dtype_value, - indices_or_sections, - axis, + dtype_and_a, + source, + destination, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, value = dtype_value - assume(isinstance(indices_or_sections, int)) + input_dtype, a = dtype_and_a helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1260,189 +1090,233 @@ def test_jax_array_split( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - ary=value[0], - indices_or_sections=indices_or_sections, - axis=axis, + a=a[0], + source=source, + destination=destination, ) -# dsplit +# trim_zeros @handle_frontend_test( - fn_tree="jax.numpy.dsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=3, axis=2, allow_none=False, is_mod_split=True + fn_tree="jax.numpy.trim_zeros", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, max_num_dims=1 ), - test_with_out=st.just(False), + trim=st.sampled_from(["f", "b", "fb"]), ) -def test_jax_dsplit( - *, - dtype_value, - indices_or_sections, +def test_jax_numpy_trim_zeros( + frontend, on_device, + *, + dtype_and_x, + backend_fw, + trim, fn_tree, + test_flags, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + backend_to_test=backend_fw, + filt=x[0], + trim=trim, + ) + + +@handle_frontend_test( + fn_tree="jax.numpy.pad", + dtype_and_input_and_other=_pad_helper(), + test_with_out=st.just(False), +) +def test_jax_pad( + *, + dtype_and_input_and_other, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): - input_dtype, value = dtype_value + ( + dtype, + input, + pad_width, + kwargs, + mode, + ) = dtype_and_input_and_other + + if isinstance(pad_width, int): + pad_width = ((pad_width, pad_width),) * input.ndim + else: + pad_width = tuple( + tuple(pair) if isinstance(pair, list) else pair for pair in pad_width + ) + helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - ary=value[0], - indices_or_sections=indices_or_sections, + array=input, + pad_width=pad_width, + mode=mode, + **kwargs, ) -# tile +# ravel @handle_frontend_test( - fn_tree="jax.numpy.tile", - dtype_value=helpers.dtype_and_values( + fn_tree="jax.numpy.ravel", + dtype_and_values=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - repeat=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape").map( - lambda rep: (len(rep),) + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + shape=helpers.get_shape( + min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=10 ), - min_value=0, - max_value=10, ), + order=st.sampled_from(["C", "F"]), test_with_out=st.just(False), ) -def test_jax_tile( +def test_jax_ravel( *, - dtype_value, - repeat, + dtype_and_values, + order, on_device, + backend_fw, fn_tree, frontend, - backend_fw, test_flags, ): - dtype, value = dtype_value - repeat_dtype, repeat_list = repeat + input_dtypes, x = dtype_and_values 
helpers.test_frontend_function( - input_dtypes=dtype + repeat_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - A=value[0], - reps=repeat_list[0], + a=x[0], + order=order, ) -# dstack +# repeat @handle_frontend_test( - fn_tree="jax.numpy.dstack", - dtype_and_x=helpers.dtype_and_values( + fn_tree="jax.numpy.repeat", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shared_dtype=True, - num_arrays=helpers.ints(min_value=1, max_value=10), - shape=helpers.get_shape( - min_num_dims=1, + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + axis=st.shared( + st.one_of( + st.none(), + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + max_size=1, + ), ), + key="axis", ), + repeat=st.one_of(st.integers(1, 10), _repeat_helper()), test_with_out=st.just(False), ) -def test_jax_dstack( +def test_jax_repeat( *, - dtype_and_x, + dtype_value, + axis, + repeat, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + value_dtype, value = dtype_value + + if not isinstance(repeat, int): + repeat_dtype, repeat_list = repeat + repeat = repeat_list[0] + value_dtype += repeat_dtype + + if not isinstance(axis, int) and axis is not None: + axis = axis[0] + helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=value_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tup=x, + a=value[0], + repeats=repeat, + axis=axis, ) -# vsplit @handle_frontend_test( - fn_tree="jax.numpy.vsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=2, axis=0, allow_none=False, is_mod_split=True - ), + fn_tree="jax.numpy.reshape", + input_x_shape=_get_input_and_reshape(), + order=st.sampled_from(["C", "F"]), test_with_out=st.just(False), ) -def test_jax_vsplit( +def test_jax_reshape( *, - dtype_value, - indices_or_sections, + input_x_shape, + order, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, value = dtype_value + x_dtype, x, shape = input_x_shape helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=x_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - ary=value[0], - indices_or_sections=indices_or_sections, + a=x[0], + newshape=shape, + order=order, ) -# hsplit @handle_frontend_test( - fn_tree="jax.numpy.hsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=2, axis=1, allow_none=False, is_mod_split=True - ), - test_with_out=st.just(False), + fn_tree="jax.numpy.resize", + input_x_shape=_get_input_and_new_shape(), + test_with_out=st.just(True), ) -def test_jax_hsplit( +def test_jax_resize( *, - dtype_value, - indices_or_sections, + input_x_shape, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, value = dtype_value + x_dtype, x, new_shape = input_x_shape helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=x_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - 
ary=value[0], - indices_or_sections=indices_or_sections, + a=x[0], + new_shape=new_shape, ) @@ -1517,6 +1391,41 @@ def test_jax_roll( ) +# rot90 +@handle_frontend_test( + fn_tree="jax.numpy.rot90", + dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + test_with_out=st.just(False), +) +def test_jax_rot90( + *, + dtype_m_k_axes, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, m, k, axes = dtype_m_k_axes + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + m=m, + k=k, + axes=tuple(axes), + ) + + # row_stack @handle_frontend_test( fn_tree="jax.numpy.row_stack", @@ -1550,202 +1459,254 @@ def test_jax_row_stack( ) -# column_stack +# split @handle_frontend_test( - fn_tree="jax.numpy.column_stack", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, + fn_tree="jax.numpy.split", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), ), - factor=helpers.ints(min_value=2, max_value=6), + indices_or_sections=_get_splits( + min_num_dims=1, allow_none=False, is_mod_split=True + ), + axis=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", + ), + test_with_out=st.just(False), ) -def test_jax_column_stack( - dtype_and_x, - factor, +def test_jax_split( + *, + dtype_value, + indices_or_sections, + axis, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + ary=value[0], + indices_or_sections=indices_or_sections, + axis=axis, + ) + + +# squeeze +@handle_frontend_test( + fn_tree="jax.numpy.squeeze", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="shape"), + ), + axis=_squeeze_helper(), + test_with_out=st.just(False), +) +def test_jax_squeeze( + *, + dtype_and_values, + axis, on_device, + fn_tree, + frontend, + backend_fw, + test_flags, ): - dtype, x = dtype_and_x - ys = [x[0]] - for i in range(factor): - ys += [x[0]] + input_dtype, values = dtype_and_values helpers.test_frontend_function( - input_dtypes=[dtype[0]] * (factor + 1), + input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=values[0], + axis=axis, + ) + + +# stack +@handle_frontend_test( + fn_tree="jax.numpy.stack", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays"), + shape=helpers.get_shape(min_num_dims=1), + shared_dtype=True, + valid_axis=True, + allow_neg_axes=True, + force_int_axis=True, + ), + dtype=helpers.get_dtypes("valid", full=False), +) +def test_jax_stack( + dtype_values_axis, + dtype, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, values, axis = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, 
backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tup=ys, - ) - - -# pad -@st.composite -def _pad_helper(draw): - mode = draw( - st.sampled_from( - [ - "constant", - "edge", - "linear_ramp", - "maximum", - "mean", - "median", - "minimum", - "reflect", - "symmetric", - "wrap", - ] - ) - ) - if mode == "median": - dtypes = "float" - else: - dtypes = "numeric" - dtype, input, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes(dtypes), - ret_shape=True, - min_num_dims=1, - min_value=-100, - max_value=100, - ).filter( - lambda x: x[0][0] not in ["float16", "bfloat16", "complex64", "complex128"] - ), + arrays=values, + axis=axis, ) - ndim = len(shape) - pad_width = draw(_st_tuples_or_int(ndim, min_val=0)) - kwargs = {} - if mode == "reflect" or mode == "symmetric": - kwargs["reflect_type"] = draw(st.sampled_from(["even", "odd"])) - if mode in ["maximum", "mean", "median", "minimum"]: - kwargs["stat_length"] = draw(_st_tuples_or_int(ndim, min_val=2)) - if mode in ["linear_ramp"]: - kwargs["end_values"] = draw(_st_tuples_or_int(ndim)) - if mode == "constant": - kwargs["constant_values"] = draw(_st_tuples_or_int(ndim)) - return dtype, input[0], pad_width, kwargs, mode @handle_frontend_test( - fn_tree="jax.numpy.pad", - dtype_and_input_and_other=_pad_helper(), + fn_tree="jax.numpy.swapaxes", + input_x_axis1_axis2=_get_input_and_two_swapabble_axes(), test_with_out=st.just(False), ) -def test_jax_pad( +def test_jax_swapaxes( *, - dtype_and_input_and_other, - frontend, - backend_fw, + input_x_axis1_axis2, test_flags, - fn_tree, on_device, + fn_tree, + frontend, + backend_fw, ): - ( - dtype, - input, - pad_width, - kwargs, - mode, - ) = dtype_and_input_and_other - - if isinstance(pad_width, int): - pad_width = ((pad_width, pad_width),) * input.ndim - else: - pad_width = tuple( - tuple(pair) if isinstance(pair, list) else pair for pair in pad_width - ) - + x_dtype, x, axis1, axis2 = input_x_axis1_axis2 helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=x_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - array=input, - pad_width=pad_width, - mode=mode, - **kwargs, + a=x[0], + axis1=axis1, + axis2=axis2, ) -# hamming +# take @handle_frontend_test( - fn_tree="jax.numpy.hamming", - m=helpers.ints(min_value=0, max_value=20), + fn_tree="jax.numpy.take", + dtype_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + indices_same_dims=True, + ), ) -def test_jax_hamming( - m, +def test_jax_take( + *, + dtype_indices_axis, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): + input_dtypes, value, indices, axis, _ = dtype_indices_axis helpers.test_frontend_function( - input_dtypes=["int64"], + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - M=m, + a=value, + indices=indices, + axis=axis, ) -# hanning +# tile @handle_frontend_test( - fn_tree="jax.numpy.hanning", - m=helpers.ints(min_value=0, max_value=20), + fn_tree="jax.numpy.tile", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + repeat=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("signed_integer"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape").map( + lambda rep: (len(rep),) + ), + min_value=0, + max_value=10, + ), + test_with_out=st.just(False), ) -def test_jax_hanning( - m, +def test_jax_tile( + *, + dtype_value, + repeat, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): + dtype, value = dtype_value + repeat_dtype, repeat_list = repeat helpers.test_frontend_function( - input_dtypes=["int64"], - frontend=frontend, + input_dtypes=dtype + repeat_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - M=m, + A=value[0], + reps=repeat_list[0], ) -# kaiser +# transpose @handle_frontend_test( - fn_tree="jax.numpy.kaiser", - m=helpers.ints(min_value=0, max_value=100), - beta=helpers.floats(min_value=-10, max_value=10), + fn_tree="jax.numpy.transpose", + array_and_axes=np_frontend_helpers._array_and_axes_permute_helper( + min_num_dims=0, + max_num_dims=5, + min_dim_size=0, + max_dim_size=10, + ), + test_with_out=st.just(False), ) -def test_jax_kaiser( - m, - beta, +def test_jax_transpose( + *, + array_and_axes, + on_device, + fn_tree, frontend, - backend_fw, test_flags, - fn_tree, - on_device, + backend_fw, ): + array, dtype, axes = array_and_axes helpers.test_frontend_function( - input_dtypes=["int64", "float64"], + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - M=m, - beta=beta, + a=array, + axes=axes, ) @@ -1783,25 +1744,72 @@ def test_jax_tri( ) -# blackman +# tril @handle_frontend_test( - fn_tree="jax.numpy.blackman", - m=helpers.ints(min_value=0, max_value=20), + fn_tree="jax.numpy.tril", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + k=helpers.ints(min_value=-10, max_value=10), + test_with_out=st.just(False), ) -def test_jax_blackman( - m, +def test_jax_tril( + *, + dtype_and_x, + k, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=["int64"], + input_dtypes=dtype, + backend_to_test=backend_fw, frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + m=x[0], + k=k, + ) + + +# vsplit +@handle_frontend_test( + fn_tree="jax.numpy.vsplit", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=2, axis=0, allow_none=False, is_mod_split=True + ), + test_with_out=st.just(False), +) +def test_jax_vsplit( + *, + dtype_value, + indices_or_sections, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - M=m, + ary=value[0], + indices_or_sections=indices_or_sections, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 0aa8cb07fd555..852293a989fdb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py
@@ -17,6 +17,72 @@
 )
 
 
+# --- Helpers --- #
+# --------------- #
+
+
+# trapz
+@st.composite
+def _either_x_dx(draw):
+    dtype_values_axis = draw(
+        helpers.dtype_values_axis(
+            available_dtypes=st.shared(helpers.get_dtypes("float"), key="trapz_dtype"),
+            min_value=-100,
+            max_value=100,
+            min_num_dims=1,
+            max_num_dims=3,
+            min_dim_size=1,
+            max_dim_size=3,
+            allow_neg_axes=True,
+            valid_axis=True,
+            force_int_axis=True,
+        ),
+    )
+    rand = draw(st.integers(min_value=0, max_value=1))
+    if rand == 0:
+        either_x_dx = draw(
+            helpers.dtype_and_x(
+                available_dtypes=st.shared(
+                    helpers.get_dtypes("float"), key="trapz_dtype"
+                ),
+                min_value=-100,
+                max_value=100,
+                min_num_dims=1,
+                max_num_dims=3,
+                min_dim_size=1,
+                max_dim_size=3,
+            )
+        )
+        return dtype_values_axis, rand, either_x_dx
+    else:
+        either_x_dx = draw(
+            st.floats(min_value=-10, max_value=10),
+        )
+        return dtype_values_axis, rand, either_x_dx
+
+
+# polyint
+@st.composite
+def _get_array_values_m_and_k(draw):
+    dtype_and_x = draw(
+        helpers.dtype_and_values(
+            available_dtypes=helpers.get_dtypes("float"),
+            num_arrays=1,
+            min_num_dims=1,
+            max_num_dims=1,
+            min_dim_size=1,
+        )
+    )
+    dtype, x = dtype_and_x
+    m = draw(st.integers(min_value=0, max_value=10))
+    max_bound = m - 1
+    if max_bound <= m:
+        k = None
+    else:
+        k = draw(st.integers(min_value=0, max_value=max_bound))
+    return dtype, x, m, k
+
+
 @st.composite
 def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False):
     available_dtypes = helpers.get_dtypes("numeric")
@@ -42,33 +108,48 @@ def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False):
     return [dtype1], [values], axis, dtype2
 
 
-# sign
-@handle_frontend_test(
-    fn_tree="jax.numpy.sign",
-    dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1
-    ),
-    test_with_out=st.just(False),
-)
-def test_jax_sign(
-    *,
-    dtype_and_x,
-    test_flags,
-    on_device,
-    fn_tree,
-    frontend,
-    backend_fw,
-):
-    input_dtype, x = dtype_and_x
-    helpers.test_frontend_function(
-        input_dtypes=input_dtype,
-        test_flags=test_flags,
-        frontend=frontend,
-        backend_to_test=backend_fw,
-        fn_tree=fn_tree,
-        on_device=on_device,
-        a=x[0],
-    )
+# diff
+@st.composite
+def _get_dtype_input_and_vector(draw):
+    size1 = draw(helpers.ints(min_value=1, max_value=5))
+    size2 = draw(helpers.ints(min_value=1, max_value=5))
+    dtype = draw(helpers.get_dtypes("integer"))
+    vec1 = draw(helpers.array_values(dtype=dtype[0], shape=(size1, size2)))
+    return dtype, vec1
+
+
+# dot
+@st.composite
+def _get_dtype_input_and_vectors(draw):
+    dim_size = draw(helpers.ints(min_value=1, max_value=5))
+    dtype = draw(helpers.get_dtypes("float", index=1, full=False))
+    if dim_size == 1:
+        vec1 = draw(
+            helpers.array_values(
+                dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5
+            )
+        )
+        vec2 = draw(
+            helpers.array_values(
+                dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5
+            )
+        )
+    else:
+        vec1 = draw(
+            helpers.array_values(
+                dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5
+            )
+        )
+        vec2 = draw(
+            helpers.array_values(
+                dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5
+            )
+        )
+    return dtype, vec1, vec2
+
+
+# --- Main --- #
+# ------------ #
 
 
 # absolute
@@ -173,73 +254,103 @@ def test_jax_angle(
     )
 
 
-# diff
-@st.composite
-def _get_dtype_input_and_vector(draw):
-    size1 = draw(helpers.ints(min_value=1, max_value=5))
-    size2 = draw(helpers.ints(min_value=1, 
max_value=5)) - dtype = draw(helpers.get_dtypes("integer")) - vec1 = draw(helpers.array_values(dtype=dtype[0], shape=(size1, size2))) - return dtype, vec1 +# arccos +@handle_frontend_test( + fn_tree="jax.numpy.arccos", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_jax_arccos( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) +# arccosh @handle_frontend_test( - fn_tree="jax.numpy.diff", - dtype_and_x=_get_dtype_input_and_vector(), - n=helpers.ints( - min_value=0, - max_value=10, - ), - axis=helpers.ints( - min_value=-1, - max_value=10, - ), + fn_tree="jax.numpy.arccosh", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), ) -def test_jax_diff( +def test_jax_arccosh( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, - n, - axis, + test_flags, ): input_dtype, x = dtype_and_x - if axis > (x[0].ndim - 1): - axis = x[0].ndim - 1 helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - n=n, - axis=axis, - prepend=None, - append=None, + x=x[0], ) -# ediff1d +# arcsin @handle_frontend_test( - fn_tree="jax.numpy.ediff1d", + fn_tree="jax.numpy.arcsin", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, max_num_dims=1 - ), - to_end=helpers.ints( - min_value=-1, - max_value=10, + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=4, + small_abs_safety_factor=4, ), - to_begin=helpers.ints( - min_value=-1, - max_value=10, +) +def test_jax_arcsin( + dtype_and_x, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + rtol=1e-2, + atol=1e-2, + ) + + +# arcsinh +@handle_frontend_test( + fn_tree="jax.numpy.arcsinh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=4, + small_abs_safety_factor=4, ), + test_with_out=st.just(False), ) -def test_jax_ediff1d( +def test_jax_arcsinh( *, dtype_and_x, on_device, @@ -247,20 +358,16 @@ def test_jax_ediff1d( frontend, backend_fw, test_flags, - to_end, - to_begin, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_flags=test_flags, - ary=x[0], - to_end=to_end, - to_begin=to_begin, + x=x[0], ) @@ -320,55 +427,45 @@ def test_jax_arctan2( ) -# convolve +# arctanh @handle_frontend_test( - fn_tree="jax.numpy.convolve", + fn_tree="jax.numpy.arctanh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - min_value=-1e04, - max_value=1e04, - shared_dtype=True, + min_num_dims=0, ), - mode=st.sampled_from(["valid", "same", "full"]), + test_with_out=st.just(False), ) -def test_jax_convolve( - *, +def 
test_jax_arctanh( dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, - mode, + fn_tree, ): input_dtype, x = dtype_and_x - assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - rtol=1e-2, - atol=1e-2, - on_device=on_device, - a=x[0], - v=x[1], - mode=mode, - precision=None, + x=x[0], ) +# around @handle_frontend_test( - fn_tree="jax.numpy.cos", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.numpy.around", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), + decimals=st.integers(min_value=0, max_value=5), ) -def test_jax_cos( +def test_jax_around( *, dtype_and_x, + decimals, on_device, fn_tree, frontend, @@ -383,17 +480,20 @@ def test_jax_cos( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], + decimals=decimals, ) -# cosh +# cbrt @handle_frontend_test( - fn_tree="jax.numpy.cosh", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.numpy.cbrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_jax_cosh( +def test_jax_cbrt( *, dtype_and_x, on_device, @@ -410,17 +510,19 @@ def test_jax_cosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-2, + atol=1e-2, x=x[0], ) -# tanh +# ceil @handle_frontend_test( - fn_tree="jax.numpy.tanh", + fn_tree="jax.numpy.ceil", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_tanh( +def test_jax_ceil( *, dtype_and_x, on_device, @@ -441,70 +543,120 @@ def test_jax_tanh( ) -# sinh +# conj @handle_frontend_test( - fn_tree="jax.numpy.sinh", + fn_tree="jax.numpy.conj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=4, - small_abs_safety_factor=4, + available_dtypes=helpers.get_dtypes("valid"), ), - test_with_out=st.just(False), ) -def test_jax_sinh( +def test_jax_conj( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], ) +# TODO: uncomment with multiversion pipeline (deprecated since 0.4.12) +# @handle_frontend_test( +# fn_tree="jax.numpy.product", +# dtype_x_axis_dtype_where=_get_castable_dtypes_values(use_where=True), +# keepdims=st.booleans(), +# initial=st.one_of(st.floats(min_value=-100, max_value=100)), +# promote_integers=st.booleans(), +# ) +# def test_jax_product( +# dtype_x_axis_dtype_where, +# keepdims, +# initial, +# promote_integers, +# frontend, +# test_flags, +# fn_tree, +# on_device, +# ): +# input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype_where +# if ivy.current_backend_str() == "torch": +# assume(not test_flags.as_variable[0]) +# where, input_dtypes, test_flags = np_frontend_helpers. 
+# handle_where_and_array_bools( +# where=where, +# input_dtype=input_dtypes, +# test_flags=test_flags, +# ) +# helpers.test_frontend_function( +# input_dtypes=input_dtypes, +# frontend=frontend, +# test_flags=test_flags, +# fn_tree=fn_tree, +# on_device=on_device, +# a=x[0], +# axis=axis, +# dtype=dtype, +# keepdims=keepdims, +# initial=initial, +# where=where, +# promote_integers=promote_integers, +# ) + + +# conjugate @handle_frontend_test( - fn_tree="jax.numpy.sin", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="jax.numpy.conjugate", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), ) -def test_jax_sin( +def test_jax_conjugate( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], ) -# floor +# convolve @handle_frontend_test( - fn_tree="jax.numpy.floor", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="jax.numpy.convolve", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_num_dims=1, + max_num_dims=1, + min_value=-1e04, + max_value=1e04, + shared_dtype=True, + ), + mode=st.sampled_from(["valid", "same", "full"]), ) -def test_jax_floor( +def test_jax_convolve( *, dtype_and_x, on_device, @@ -512,101 +664,97 @@ def test_jax_floor( frontend, backend_fw, test_flags, + mode, ): input_dtype, x = dtype_and_x + assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + rtol=1e-2, + atol=1e-2, on_device=on_device, - x=x[0], + a=x[0], + v=x[1], + mode=mode, + precision=None, ) -# tensordot +# copysign @handle_frontend_test( - fn_tree="jax.numpy.tensordot", - dtype_values_and_axes=_get_dtype_value1_value2_axis_for_tensordot( - helpers.get_dtypes(kind="numeric") + fn_tree="jax.numpy.copysign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_tensordot( - dtype_values_and_axes, +def test_jax_copysign( + *, + dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, ): - dtype, a, b, axes = dtype_values_and_axes - if ivy.current_backend_str() == "torch": - atol = 1e-3 - else: - atol = 1e-6 + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - a=a, - b=b, - atol=atol, - axes=axes, + on_device=on_device, + x1=x[0], + x2=x[0], ) -# divide @handle_frontend_test( - fn_tree="jax.numpy.divide", - aliases=["jax.numpy.true_divide"], - dtype_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, - ), - test_with_out=st.just(False), + fn_tree="jax.numpy.cos", + 
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_jax_divide( +def test_jax_cos( *, - dtype_values, + dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, ): - input_dtype, x = dtype_values - assume(not np.any(np.isclose(x[1], 0))) - if ivy.current_backend_str() == "paddle": - atol, rtol = 1e-2, 1e-2 - else: - atol, rtol = 1e-5, 1e-5 + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - a=x[0], - b=x[1], - atol=atol, - rtol=rtol, + on_device=on_device, + x=x[0], ) -# exp +# cosh @handle_frontend_test( - fn_tree="jax.numpy.exp", + fn_tree="jax.numpy.cosh", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_exp( +def test_jax_cosh( + *, dtype_and_x, on_device, fn_tree, @@ -614,9 +762,9 @@ def test_jax_exp( backend_fw, test_flags, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, @@ -626,79 +774,44 @@ def test_jax_exp( ) -# dot -@st.composite -def _get_dtype_input_and_vectors(draw): - dim_size = draw(helpers.ints(min_value=1, max_value=5)) - dtype = draw(helpers.get_dtypes("float", index=1, full=False)) - if dim_size == 1: - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) - else: - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - return dtype, vec1, vec2 - - +# deg2rad @handle_frontend_test( - fn_tree="jax.numpy.dot", - dtype_x_y=_get_dtype_input_and_vectors(), + fn_tree="jax.numpy.deg2rad", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_jax_dot( +def test_jax_deg2rad( *, - dtype_x_y, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x, y = dtype_x_y + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - rtol=1e-01, - atol=1e-01, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x, - b=y, - precision=None, + x=x[0], ) -# mod +# degrees @handle_frontend_test( - fn_tree="jax.numpy.mod", + fn_tree="jax.numpy.degrees", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_jax_mod( +def test_jax_degrees( *, dtype_and_x, on_device, @@ -708,7 +821,6 @@ def test_jax_mod( test_flags, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0)) and "bfloat16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -716,39 +828,90 @@ def test_jax_mod( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], ) -# modf @handle_frontend_test( - fn_tree="jax.numpy.modf", - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + fn_tree="jax.numpy.diff", + dtype_and_x=_get_dtype_input_and_vector(), + n=helpers.ints( + min_value=0, + max_value=10, + ), + axis=helpers.ints( + min_value=-1, + max_value=10, + ), +) +def test_jax_diff( + *, + dtype_and_x, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, + n, + axis, +): + input_dtype, x = dtype_and_x + if axis > (x[0].ndim - 1): + axis = x[0].ndim - 1 + helpers.test_frontend_function( + input_dtypes=input_dtype, + test_flags=test_flags, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + n=n, + axis=axis, + prepend=None, + append=None, + ) + + +# divide +@handle_frontend_test( + fn_tree="jax.numpy.divide", + aliases=["jax.numpy.true_divide"], + dtype_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + allow_inf=False, large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="log", + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_modf( - dtype_and_x, +def test_jax_divide( + *, + dtype_values, frontend, backend_fw, test_flags, fn_tree, - on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_values + assume(not np.any(np.isclose(x[1], 0))) + if ivy.current_backend_str() == "paddle": + atol, rtol = 1e-2, 1e-2 + else: + atol, rtol = 1e-5, 1e-5 helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x=x[0], + a=x[0], + b=x[1], + atol=atol, + rtol=rtol, ) @@ -789,40 +952,52 @@ def test_jax_divmod( ) -# tan @handle_frontend_test( - fn_tree="jax.numpy.tan", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.numpy.dot", + dtype_x_y=_get_dtype_input_and_vectors(), test_with_out=st.just(False), ) -def test_jax_tan( +def test_jax_dot( *, - dtype_and_x, + dtype_x_y, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x, y = dtype_x_y helpers.test_frontend_function( input_dtypes=input_dtype, + rtol=1e-01, + atol=1e-01, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x, + b=y, + precision=None, ) -# arccos +# ediff1d @handle_frontend_test( - fn_tree="jax.numpy.arccos", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="jax.numpy.ediff1d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, max_num_dims=1 + ), + to_end=helpers.ints( + min_value=-1, + max_value=10, + ), + to_begin=helpers.ints( + min_value=-1, + max_value=10, + ), ) -def test_jax_arccos( +def test_jax_ediff1d( *, dtype_and_x, on_device, @@ -830,27 +1005,30 @@ def test_jax_arccos( frontend, backend_fw, test_flags, + to_end, + to_begin, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + test_flags=test_flags, + ary=x[0], + to_end=to_end, + to_begin=to_begin, ) -# arccosh +# exp @handle_frontend_test( - fn_tree="jax.numpy.arccosh", + fn_tree="jax.numpy.exp", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def 
test_jax_arccosh( - *, +def test_jax_exp( dtype_and_x, on_device, fn_tree, @@ -858,9 +1036,9 @@ def test_jax_arccosh( backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, @@ -870,25 +1048,30 @@ def test_jax_arccosh( ) -# arcsin +# exp2 @handle_frontend_test( - fn_tree="jax.numpy.arcsin", + fn_tree="jax.numpy.exp2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=4, - small_abs_safety_factor=4, + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), + test_with_out=st.just(False), ) -def test_jax_arcsin( +def test_jax_exp2( + *, dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): input_dtype, x = dtype_and_x - helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -897,31 +1080,29 @@ def test_jax_arcsin( fn_tree=fn_tree, on_device=on_device, x=x[0], - rtol=1e-2, - atol=1e-2, + rtol=1e-01, + atol=1e-02, ) -# log1p +# expm1 @handle_frontend_test( - fn_tree="jax.numpy.log1p", + fn_tree="jax.numpy.expm1", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", ), + test_with_out=st.just(False), ) -def test_jax_log1p( +def test_jax_expm1( + *, dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): input_dtype, x = dtype_and_x - helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -933,17 +1114,14 @@ def test_jax_log1p( ) -# arcsinh +# fabs @handle_frontend_test( - fn_tree="jax.numpy.arcsinh", + fn_tree="jax.numpy.fabs", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=4, - small_abs_safety_factor=4, ), - test_with_out=st.just(False), ) -def test_jax_arcsinh( +def test_jax_fabs( *, dtype_and_x, on_device, @@ -964,16 +1142,18 @@ def test_jax_arcsinh( ) -# power +# fix @handle_frontend_test( - fn_tree="jax.numpy.power", + fn_tree="jax.numpy.fix", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float", index=2), + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), - test_with_out=st.just(False), ) -def test_jax_power( +def test_jax_fix( *, dtype_and_x, on_device, @@ -990,23 +1170,32 @@ def test_jax_power( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + test_values=False, + x=x[0], ) -# trunc +# float_power @handle_frontend_test( - fn_tree="jax.numpy.trunc", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_jax_trunc( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, + fn_tree="jax.numpy.float_power", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-10, + max_value=10, + num_arrays=2, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + test_with_out=st.just(False), +) +def test_jax_float_power( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, backend_fw, test_flags, ): @@ -1018,17 +1207,18 @@ def test_jax_trunc( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - 
x=x[0], + x1=x[0], + x2=x[1], ) -# ceil +# floor @handle_frontend_test( - fn_tree="jax.numpy.ceil", + fn_tree="jax.numpy.floor", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_ceil( +def test_jax_floor( *, dtype_and_x, on_device, @@ -1049,61 +1239,68 @@ def test_jax_ceil( ) -# float_power +# floor_divide @handle_frontend_test( - fn_tree="jax.numpy.float_power", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-10, - max_value=10, + fn_tree="jax.numpy.floor_divide", + dtype_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, + shared_dtype=True, + min_value=-10.0, + max_value=10.0, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="linear", ), - test_with_out=st.just(False), ) -def test_jax_float_power( +def test_jax_floor_divide( *, - dtype_and_x, - on_device, - fn_tree, + dtype_values, frontend, backend_fw, + fn_tree, + on_device, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_values + # Making sure division by zero doesn't occur + assume(not np.any(np.isclose(x[1], 0))) + # Absolute tolerance is 1, + # due to flooring can cause absolute error of 1 due to precision helpers.test_frontend_function( input_dtypes=input_dtype, + on_device=on_device, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x1=x[0], x2=x[1], + atol=1, ) -# deg2rad +# fmax @handle_frontend_test( - fn_tree="jax.numpy.deg2rad", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="jax.numpy.fmax", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=-np.inf, + max_value=np.inf, ), test_with_out=st.just(False), ) -def test_jax_deg2rad( +def test_jax_fmax( *, - dtype_and_x, + dtype_and_inputs, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, inputs = dtype_and_inputs helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1111,28 +1308,31 @@ def test_jax_deg2rad( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=inputs[0], + x2=inputs[1], ) -# radians +# fmin @handle_frontend_test( - fn_tree="jax.numpy.radians", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="jax.numpy.fmin", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=-np.inf, + max_value=np.inf, ), - test_with_out=st.just(False), ) -def test_jax_radians( +def test_jax_fmin( *, - dtype_and_x, + dtype_and_inputs, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, inputs = dtype_and_inputs helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1140,34 +1340,33 @@ def test_jax_radians( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=inputs[0], + x2=inputs[1], ) -# exp2 +# fmod @handle_frontend_test( - fn_tree="jax.numpy.exp2", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, + fn_tree="jax.numpy.fmod", 
+ dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=1.5, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_jax_exp2( +def test_jax_fmod( *, - dtype_and_x, + dtype_and_inputs, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_inputs + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1175,35 +1374,35 @@ def test_jax_exp2( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - rtol=1e-01, - atol=1e-02, + x1=x[0], + x2=x[1], ) -# expm1 +# frexp @handle_frontend_test( - fn_tree="jax.numpy.expm1", + fn_tree="jax.numpy.frexp", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=1, + max_value=100, ), - test_with_out=st.just(False), ) -def test_jax_expm1( +def test_jax_frexp( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], @@ -1247,21 +1446,22 @@ def test_jax_gcd( ) -# i0 +# heaviside @handle_frontend_test( - fn_tree="jax.numpy.i0", + fn_tree="jax.numpy.heaviside", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-10, - max_value=10, + min_value=-100, + max_value=100, min_num_dims=1, max_num_dims=3, min_dim_size=1, max_dim_size=3, + num_arrays=2, ), test_with_out=st.just(False), ) -def test_jax_i0( +def test_jax_heaviside( *, dtype_and_x, on_device, @@ -1278,63 +1478,62 @@ def test_jax_i0( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[0], ) -# kron +# hypot @handle_frontend_test( - fn_tree="jax.numpy.kron", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, - max_dim_size=3, + fn_tree="jax.numpy.hypot", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, + shared_dtype=True, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, ), - test_with_out=st.just(False), ) -def test_jax_kron( +def test_jax_hypot( *, - dtype_x, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - b=x[1], + atol=1e-2, + x1=x[0], + x2=x[1], + backend_to_test=backend_fw, ) -# lcm +# i0 @handle_frontend_test( - fn_tree="jax.numpy.lcm", + fn_tree="jax.numpy.i0", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + min_value=-10, + max_value=10, min_num_dims=1, max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=False, - small_abs_safety_factor=2, - large_abs_safety_factor=2, - safety_factor_scale="log", + min_dim_size=1, + max_dim_size=3, ), test_with_out=st.just(False), ) -def test_jax_lcm( +def test_jax_i0( *, dtype_and_x, on_device, @@ -1344,13 +1543,6 @@ def test_jax_lcm( test_flags, ): input_dtype, x = dtype_and_x - value_test = True - # Skip 
Tensorflow backend value test for lcm - # https://github.com/tensorflow/tensorflow/issues/58955 - if ivy.current_backend_str() == "tensorflow": - value_test = False - if ivy.current_backend_str() in ("jax", "numpy"): - assume(input_dtype[0] != "uint64" and input_dtype[1] != "uint64") helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1358,72 +1550,101 @@ def test_jax_lcm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], - test_values=value_test, + x=x[0], ) -# logaddexp2 +# imag @handle_frontend_test( - fn_tree="jax.numpy.logaddexp2", + fn_tree="jax.numpy.imag", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("float_and_complex"), + min_value=-20, + max_value=20, ), test_with_out=st.just(False), ) -def test_jax_logaddexp2( +def test_jax_imag( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, - atol=1e-02, - x1=x[0], - x2=x[1], - ) - + rtol=1e-5, + atol=1e-5, + val=x[0], + ) -# matmul + +# inner @handle_frontend_test( - fn_tree="jax.numpy.matmul", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[_get_first_matrix_and_dtype, _get_second_matrix_and_dtype], + fn_tree="jax.numpy.inner", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-10, + max_value=10, + num_arrays=2, + shared_dtype=True, ), ) -def test_jax_matmul( - dtypes_values_casting, - frontend, - backend_fw, +def test_jax_inner( + *, + dtype_and_x, test_flags, + on_device, fn_tree, + frontend, + backend_fw, +): + input_dtypes, xs = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + a=xs[0], + b=xs[1], + ) + + +# kron +@handle_frontend_test( + fn_tree="jax.numpy.kron", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=2, + min_dim_size=1, + max_dim_size=3, + num_arrays=2, + ), + test_with_out=st.just(False), +) +def test_jax_kron( + *, + dtype_x, on_device, + fn_tree, + frontend, + backend_fw, + test_flags, ): - dtypes, x, casting, dtype = dtypes_values_casting + input_dtype, x = dtype_x helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, @@ -1431,73 +1652,43 @@ def test_jax_matmul( on_device=on_device, a=x[0], b=x[1], - precision=None, - ) - - -# trapz -@st.composite -def _either_x_dx(draw): - dtype_values_axis = draw( - helpers.dtype_values_axis( - available_dtypes=st.shared(helpers.get_dtypes("float"), key="trapz_dtype"), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - allow_neg_axes=True, - valid_axis=True, - force_int_axis=True, - ), ) - rand = (draw(st.integers(min_value=0, max_value=1)),) - if rand == 0: - either_x_dx = draw( - helpers.dtype_and_x( - 
avaliable_dtypes=st.shared( - helpers.get_dtypes("float"), key="trapz_dtype" - ), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ) - ) - return dtype_values_axis, rand, either_x_dx - else: - either_x_dx = draw( - st.floats(min_value=-10, max_value=10), - ) - return dtype_values_axis, rand, either_x_dx +# lcm @handle_frontend_test( - fn_tree="jax.numpy.trapz", - dtype_x_axis_rand_either=_either_x_dx(), + fn_tree="jax.numpy.lcm", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=False, + small_abs_safety_factor=2, + large_abs_safety_factor=2, + safety_factor_scale="log", + ), test_with_out=st.just(False), ) -def test_jax_trapz( +def test_jax_lcm( *, - dtype_x_axis_rand_either, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - dtype_values_axis, rand, either_x_dx = dtype_x_axis_rand_either - input_dtype, y, axis = dtype_values_axis - if rand == 0: - dtype_x, x = either_x_dx - x = np.asarray(x, dtype=dtype_x) - dx = None - else: - x = None - dx = either_x_dx + input_dtype, x = dtype_and_x + value_test = True + # Skip Tensorflow backend value test for lcm + # https://github.com/tensorflow/tensorflow/issues/58955 + if ivy.current_backend_str() == "tensorflow": + value_test = False + if ivy.current_backend_str() in ("jax", "numpy"): + assume(input_dtype[0] != "uint64" and input_dtype[1] != "uint64") helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1505,49 +1696,52 @@ def test_jax_trapz( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - y=y[0], - x=x, - dx=dx, - axis=axis, + x1=x[0], + x2=x[1], + test_values=value_test, ) -# sqrt +# ldexp @handle_frontend_test( - fn_tree="jax.numpy.sqrt", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="jax.numpy.ldexp", + dtype_and_x=ldexp_args(), ) -def test_jax_sqrt( +def test_jax_ldexp( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) -# square +# log @handle_frontend_test( - fn_tree="jax.numpy.square", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="jax.numpy.log", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + ), test_with_out=st.just(False), ) -def test_jax_square( +def test_jax_log( *, dtype_and_x, on_device, @@ -1564,25 +1758,32 @@ def test_jax_square( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-01, + atol=1e-02, x=x[0], ) -# arctanh +# log10 @handle_frontend_test( - fn_tree="jax.numpy.arctanh", + fn_tree="jax.numpy.log10", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=0, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, ), test_with_out=st.just(False), ) -def test_jax_arctanh( +def test_jax_log10( + *, dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, ): input_dtype, x = dtype_and_x 
helpers.test_frontend_function( @@ -1591,51 +1792,53 @@ def test_jax_arctanh( backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, + rtol=1e-01, + atol=1e-02, x=x[0], ) -# multiply +# log1p @handle_frontend_test( - fn_tree="jax.numpy.multiply", + fn_tree="jax.numpy.log1p", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ), - test_with_out=st.just(False), ) -def test_jax_multiply( +def test_jax_log1p( dtype_and_x, frontend, backend_fw, test_flags, fn_tree, + on_device, ): input_dtype, x = dtype_and_x + helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - x1=x[0], - x2=x[1], + on_device=on_device, + x=x[0], ) -# log10 +# log2 @handle_frontend_test( - fn_tree="jax.numpy.log10", + fn_tree="jax.numpy.log2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, ), test_with_out=st.just(False), ) -def test_jax_log10( +def test_jax_log2( *, dtype_and_x, on_device, @@ -1645,6 +1848,7 @@ def test_jax_log10( test_flags, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1652,8 +1856,7 @@ def test_jax_log10( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, - atol=1e-02, + rtol=1e-2, x=x[0], ) @@ -1699,15 +1902,25 @@ def test_jax_logaddexp( ) -# degrees +# logaddexp2 @handle_frontend_test( - fn_tree="jax.numpy.degrees", + fn_tree="jax.numpy.logaddexp2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_jax_degrees( +def test_jax_logaddexp2( *, dtype_and_x, on_device, @@ -1724,53 +1937,57 @@ def test_jax_degrees( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + rtol=1e-01, + atol=1e-02, + x1=x[0], + x2=x[1], ) -# negative +# matmul @handle_frontend_test( - fn_tree="jax.numpy.negative", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1 + fn_tree="jax.numpy.matmul", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[_get_first_matrix_and_dtype, _get_second_matrix_and_dtype], ), - test_with_out=st.just(False), ) -def test_jax_negative( - dtype_and_x, +def test_jax_matmul( + dtypes_values_casting, frontend, backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, x = dtype_and_x + dtypes, x, casting, dtype = dtypes_values_casting helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], + b=x[1], + precision=None, ) -# positive +# maximum @handle_frontend_test( - fn_tree="jax.numpy.positive", + fn_tree="jax.numpy.maximum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1 + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), 
test_with_out=st.just(False), ) -def test_jax_positive( +def test_jax_maximum( dtype_and_x, frontend, backend_fw, test_flags, fn_tree, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1779,26 +1996,26 @@ def test_jax_positive( backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) -# rad2deg +# minimum @handle_frontend_test( - fn_tree="jax.numpy.rad2deg", + fn_tree="jax.numpy.minimum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), min_num_dims=1 + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), test_with_out=st.just(False), ) -def test_jax_rad2deg( +def test_jax_minimum( dtype_and_x, frontend, backend_fw, test_flags, fn_tree, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1807,64 +2024,34 @@ def test_jax_rad2deg( backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) -# fmax +# mod @handle_frontend_test( - fn_tree="jax.numpy.fmax", - dtype_and_inputs=helpers.dtype_and_values( + fn_tree="jax.numpy.mod", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - min_value=-np.inf, - max_value=np.inf, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_jax_fmax( - *, - dtype_and_inputs, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, inputs = dtype_and_inputs - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x1=inputs[0], - x2=inputs[1], - ) - - -# fmin -@handle_frontend_test( - fn_tree="jax.numpy.fmin", - dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=-np.inf, - max_value=np.inf, - ), -) -def test_jax_fmin( +def test_jax_mod( *, - dtype_and_inputs, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, inputs = dtype_and_inputs + input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0)) and "bfloat16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1872,26 +2059,29 @@ def test_jax_fmin( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=inputs[0], - x2=inputs[1], + x1=x[0], + x2=x[1], ) -# fabs +# modf @handle_frontend_test( - fn_tree="jax.numpy.fabs", + fn_tree="jax.numpy.modf", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_integer"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ), + test_with_out=st.just(False), ) -def test_jax_fabs( - *, +def test_jax_modf( dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1905,55 +2095,64 @@ def test_jax_fabs( ) -# fmod +# multiply @handle_frontend_test( - fn_tree="jax.numpy.fmod", - dtype_and_inputs=helpers.dtype_and_values( + fn_tree="jax.numpy.multiply", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=1.5, - safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_jax_fmod( - *, - 
dtype_and_inputs, - on_device, - fn_tree, +def test_jax_multiply( + dtype_and_x, frontend, backend_fw, test_flags, + fn_tree, ): - input_dtype, x = dtype_and_inputs - assume(not np.any(np.isclose(x[1], 0))) + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x1=x[0], x2=x[1], ) -# maximum +# nan_to_num @handle_frontend_test( - fn_tree="jax.numpy.maximum", + fn_tree="jax.numpy.nan_to_num", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=True, + allow_inf=True, ), + copy=st.booleans(), + nan=st.floats(min_value=0.0, max_value=100), + posinf=st.floats(min_value=5e100, max_value=5e100), + neginf=st.floats(min_value=-5e100, max_value=-5e100), test_with_out=st.just(False), ) -def test_jax_maximum( +def test_jax_nan_to_num( + *, dtype_and_x, + copy, + nan, + posinf, + neginf, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1962,26 +2161,30 @@ def test_jax_maximum( backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - x1=x[0], - x2=x[1], + on_device=on_device, + x=x[0], + copy=copy, + nan=nan, + posinf=posinf, + neginf=neginf, ) -# minimum +# negative @handle_frontend_test( - fn_tree="jax.numpy.minimum", + fn_tree="jax.numpy.negative", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1 ), test_with_out=st.just(False), ) -def test_jax_minimum( +def test_jax_negative( dtype_and_x, frontend, backend_fw, test_flags, fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1990,14 +2193,14 @@ def test_jax_minimum( backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - x1=x[0], - x2=x[1], + on_device=on_device, + x=x[0], ) -# heaviside +# nextafter @handle_frontend_test( - fn_tree="jax.numpy.heaviside", + fn_tree="jax.numpy.nextafter", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=-100, @@ -2007,10 +2210,11 @@ def test_jax_minimum( min_dim_size=1, max_dim_size=3, num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_jax_heaviside( +def test_jax_nextafter( *, dtype_and_x, on_device, @@ -2032,316 +2236,305 @@ def test_jax_heaviside( ) -# log +# outer @handle_frontend_test( - fn_tree="jax.numpy.log", + fn_tree="jax.numpy.outer", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=-10, + max_value=10, min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, + max_num_dims=1, + shared_dtype=True, ), - test_with_out=st.just(False), ) -def test_jax_log( +def test_jax_outer( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): - input_dtype, x = dtype_and_x + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, - atol=1e-02, - x=x[0], + a=xs[0], + b=xs[1], ) -# copysign +# 
poly @handle_frontend_test( - fn_tree="jax.numpy.copysign", + fn_tree="jax.numpy.poly", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, + num_arrays=1, min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - num_arrays=2, - shared_dtype=True, + max_num_dims=1, + min_value=-1e04, + max_value=1e04, ), - test_with_out=st.just(False), ) -def test_jax_copysign( +def test_jax_poly( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x + assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[0], + seq_of_zeros=x[0], + atol=1e-05, + rtol=1e-03, ) -# sinc +# polyadd @handle_frontend_test( - fn_tree="jax.numpy.sinc", + fn_tree="jax.numpy.polyadd", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, + max_num_dims=1, + min_dim_size=2, ), ) -def test_jax_sinc( +def test_jax_polyadd( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x + assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, - atol=1e-02, - x=x[0], + a1=x[0], + a2=x[1], ) -# nextafter +# polyder @handle_frontend_test( - fn_tree="jax.numpy.nextafter", + fn_tree="jax.numpy.polyder", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, + num_arrays=1, min_num_dims=1, - max_num_dims=3, + max_num_dims=1, min_dim_size=1, - max_dim_size=3, - num_arrays=2, - shared_dtype=True, ), - test_with_out=st.just(False), + m=st.integers(min_value=0, max_value=10), ) -def test_jax_nextafter( +def test_jax_polyder( *, dtype_and_x, + m, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[0], + p=x[0], + m=m, ) -# remainder +# polydiv @handle_frontend_test( - fn_tree="jax.numpy.remainder", + fn_tree="jax.numpy.polydiv", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - large_abs_safety_factor=6, - small_abs_safety_factor=6, - safety_factor_scale="log", + min_num_dims=1, + min_dim_size=1, + max_num_dims=1, + min_value=-1e04, + max_value=1e04, ), - test_with_out=st.just(False), ) -def test_jax_remainder( +def test_jax_polydiv( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x - - assume(not np.any(np.isclose(x[1], 0))) - + assume("float16" not in input_dtype) + # TODO: remove asumme when the decorator works helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], - rtol=1e-2, - atol=1e-2, + u=x[0], + v=x[1], + 
rtol=1e-01, + atol=1e-02, ) -# trace @handle_frontend_test( - fn_tree="jax.numpy.trace", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, - max_dim_size=10, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - ), - offset=st.integers(min_value=0, max_value=0), - axis1=st.integers(min_value=0, max_value=0), - axis2=st.integers(min_value=1, max_value=1), - test_with_out=st.just(False), + fn_tree="jax.numpy.polyint", + dtype_and_x_and_k=_get_array_values_m_and_k(), ) -def test_jax_trace( +def test_jax_polyint( *, - dtype_and_x, - offset, - axis1, - axis2, + dtype_and_x_and_k, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x, m, k = dtype_and_x_and_k helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-1, - atol=1e-1, - a=x[0], - offset=offset, - axis1=axis1, - axis2=axis2, + p=x[0], + m=m, + k=k, ) -# log2 +# polymul @handle_frontend_test( - fn_tree="jax.numpy.log2", + fn_tree="jax.numpy.polymul", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ), - test_with_out=st.just(False), + trim=st.booleans(), ) -def test_jax_log2( +def test_jax_polymul( *, dtype_and_x, + trim, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) + assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - x=x[0], + a1=x[0], + a2=x[1], + trim_leading_zeros=trim, + atol=1e-01, + rtol=1e-01, ) -# vdot +# polysub @handle_frontend_test( - fn_tree="jax.numpy.vdot", + fn_tree="jax.numpy.polysub", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + min_value=-1e04, + max_value=1e04, ), - test_with_out=st.just(False), ) -def test_jax_vdot( +def test_jax_polysub( *, dtype_and_x, + test_flags, on_device, fn_tree, frontend, backend_fw, - test_flags, ): input_dtype, x = dtype_and_x + assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - a=x[0], - b=x[1], + a1=x[0], + a2=x[1], ) -# cbrt +# positive @handle_frontend_test( - fn_tree="jax.numpy.cbrt", + fn_tree="jax.numpy.positive", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1 ), test_with_out=st.just(False), ) -def test_jax_cbrt( - *, +def test_jax_positive( dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2351,37 +2544,22 @@ def test_jax_cbrt( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - 
rtol=1e-2, - atol=1e-2, x=x[0], ) -# nan_to_num +# power @handle_frontend_test( - fn_tree="jax.numpy.nan_to_num", + fn_tree="jax.numpy.power", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=True, - allow_inf=True, + num_arrays=2, ), - copy=st.booleans(), - nan=st.floats(min_value=0.0, max_value=100), - posinf=st.floats(min_value=5e100, max_value=5e100), - neginf=st.floats(min_value=-5e100, max_value=-5e100), test_with_out=st.just(False), ) -def test_jax_nan_to_num( +def test_jax_power( *, dtype_and_x, - copy, - nan, - posinf, - neginf, on_device, fn_tree, frontend, @@ -2396,33 +2574,26 @@ def test_jax_nan_to_num( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - copy=copy, - nan=nan, - posinf=posinf, - neginf=neginf, + x1=x[0], + x2=x[1], ) -# fix +# rad2deg @handle_frontend_test( - fn_tree="jax.numpy.fix", + fn_tree="jax.numpy.rad2deg", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", index=2), - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1 ), + test_with_out=st.just(False), ) -def test_jax_fix( - *, +def test_jax_rad2deg( dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2432,25 +2603,19 @@ def test_jax_fix( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, x=x[0], ) -# hypot +# radians @handle_frontend_test( - fn_tree="jax.numpy.hypot", + fn_tree="jax.numpy.radians", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, ), + test_with_out=st.just(False), ) -def test_jax_hypot( +def test_jax_radians( *, dtype_and_x, on_device, @@ -2463,58 +2628,15 @@ def test_jax_hypot( helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, - x1=x[0], - x2=x[1], - backend_to_test=backend_fw, + x=x[0], ) -# floor_divide -@handle_frontend_test( - fn_tree="jax.numpy.floor_divide", - dtype_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - min_value=-10.0, - max_value=10.0, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="linear", - ), -) -def test_jax_floor_divide( - *, - dtype_values, - frontend, - backend_fw, - fn_tree, - on_device, - test_flags, -): - input_dtype, x = dtype_values - # Making sure division by zero doesn't occur - assume(not np.any(np.isclose(x[1], 0))) - # Absolute tolerance is 1, - # due to flooring can cause absolute error of 1 due to precision - helpers.test_frontend_function( - input_dtypes=input_dtype, - on_device=on_device, - test_flags=test_flags, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - x1=x[0], - x2=x[1], - atol=1, - ) - - -# real +# real @handle_frontend_test( fn_tree="jax.numpy.real", dtype_and_x=helpers.dtype_and_values( @@ -2543,18 +2665,18 @@ def test_jax_real( ) -# inner +# reciprocal @handle_frontend_test( - fn_tree="jax.numpy.inner", + fn_tree="jax.numpy.reciprocal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-10, - 
max_value=10, - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + small_abs_safety_factor=4, + large_abs_safety_factor=4, + safety_factor_scale="log", + num_arrays=1, ), ) -def test_jax_inner( +def test_jax_reciprocal( *, dtype_and_x, test_flags, @@ -2563,94 +2685,97 @@ def test_jax_inner( frontend, backend_fw, ): - input_dtypes, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - a=xs[0], - b=xs[1], + x=x[0], ) -# outer +# remainder @handle_frontend_test( - fn_tree="jax.numpy.outer", + fn_tree="jax.numpy.remainder", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=1, - shared_dtype=True, + large_abs_safety_factor=6, + small_abs_safety_factor=6, + safety_factor_scale="log", ), + test_with_out=st.just(False), ) -def test_jax_outer( +def test_jax_remainder( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): - input_dtypes, xs = dtype_and_x + input_dtype, x = dtype_and_x + + assume(not np.any(np.isclose(x[1], 0))) + helpers.test_frontend_function( - input_dtypes=input_dtypes, - test_flags=test_flags, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=xs[0], - b=xs[1], + x1=x[0], + x2=x[1], + rtol=1e-2, + atol=1e-2, ) -# reciprocal +# round @handle_frontend_test( - fn_tree="jax.numpy.reciprocal", + fn_tree="jax.numpy.round", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - small_abs_safety_factor=4, - large_abs_safety_factor=4, - safety_factor_scale="log", - num_arrays=1, ), + decimals=st.integers(min_value=0, max_value=5), ) -def test_jax_reciprocal( +def test_jax_round( *, dtype_and_x, - test_flags, + decimals, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], + decimals=decimals, ) -# conj +# sign @handle_frontend_test( - fn_tree="jax.numpy.conj", + fn_tree="jax.numpy.sign", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1 ), + test_with_out=st.just(False), ) -def test_jax_conj( +def test_jax_sign( *, dtype_and_x, test_flags, @@ -2667,54 +2792,44 @@ def test_jax_conj( backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], ) -# imag +# signbit @handle_frontend_test( - fn_tree="jax.numpy.imag", + fn_tree="jax.numpy.signbit", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - min_value=-20, - max_value=20, + available_dtypes=helpers.get_dtypes("float"), ), - test_with_out=st.just(False), ) -def test_jax_imag( +def test_jax_signbit( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-5, - atol=1e-5, - 
val=x[0], + x=x[0], ) -# subtract @handle_frontend_test( - fn_tree="jax.numpy.subtract", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), + fn_tree="jax.numpy.sin", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_jax_subtract( +def test_jax_sin( *, dtype_and_x, on_device, @@ -2731,23 +2846,24 @@ def test_jax_subtract( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[0], + x=x[0], ) -# around +# sinc @handle_frontend_test( - fn_tree="jax.numpy.around", + fn_tree="jax.numpy.sinc", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, ), - decimals=st.integers(min_value=0, max_value=5), ) -def test_jax_around( +def test_jax_sinc( *, dtype_and_x, - decimals, on_device, fn_tree, frontend, @@ -2762,23 +2878,25 @@ def test_jax_around( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - decimals=decimals, + rtol=1e-01, + atol=1e-02, + x=x[0], ) -# round +# sinh @handle_frontend_test( - fn_tree="jax.numpy.round", + fn_tree="jax.numpy.sinh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=4, + small_abs_safety_factor=4, ), - decimals=st.integers(min_value=0, max_value=5), + test_with_out=st.just(False), ) -def test_jax_round( +def test_jax_sinh( *, dtype_and_x, - decimals, on_device, fn_tree, frontend, @@ -2793,347 +2911,276 @@ def test_jax_round( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - decimals=decimals, + x=x[0], ) -# frexp +# sqrt @handle_frontend_test( - fn_tree="jax.numpy.frexp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1, - max_value=100, - ), + fn_tree="jax.numpy.sqrt", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), ) -def test_jax_frexp( +def test_jax_sqrt( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], ) -# ldexp +# square @handle_frontend_test( - fn_tree="jax.numpy.ldexp", - dtype_and_x=ldexp_args(), + fn_tree="jax.numpy.square", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), ) -def test_jax_ldexp( +def test_jax_square( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], ) -# poly +# subtract @handle_frontend_test( - fn_tree="jax.numpy.poly", + fn_tree="jax.numpy.subtract", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_num_dims=1, - max_num_dims=1, - min_value=-1e04, - max_value=1e04, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), + test_with_out=st.just(False), ) -def test_jax_poly( 
+def test_jax_subtract( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x - assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - seq_of_zeros=x[0], - atol=1e-05, - rtol=1e-03, + x1=x[0], + x2=x[0], ) -# polyadd +# tan @handle_frontend_test( - fn_tree="jax.numpy.polyadd", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - ), + fn_tree="jax.numpy.tan", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), ) -def test_jax_polyadd( +def test_jax_tan( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x - assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a1=x[0], - a2=x[1], + x=x[0], ) -# polyder +# tanh @handle_frontend_test( - fn_tree="jax.numpy.polyder", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - ), - m=st.integers(min_value=0, max_value=10), + fn_tree="jax.numpy.tanh", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), ) -def test_jax_polyder( +def test_jax_tanh( *, dtype_and_x, - m, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - p=x[0], - m=m, - ) - - -# polyint -@st.composite -def _get_array_values_m_and_k(draw): - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - ) - ) - dtype, x = dtype_and_x - m = draw(st.integers(min_value=0, max_value=10)) - max_bound = m - 1 - if max_bound <= m: - k = None - else: - k = draw(st.integers(min_value=0, max_value=max_bound)) - return dtype, x, m, k - - -@handle_frontend_test( - fn_tree="jax.numpy.polyint", - dtype_and_x_and_k=_get_array_values_m_and_k(), -) -def test_jax_polyint( - *, - dtype_and_x_and_k, - test_flags, - on_device, - fn_tree, - frontend, - backend_fw, -): - input_dtype, x, m, k = dtype_and_x_and_k - helpers.test_frontend_function( - input_dtypes=input_dtype, test_flags=test_flags, - frontend=frontend, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - p=x[0], - m=m, - k=k, + x=x[0], ) -# polydiv +# tensordot @handle_frontend_test( - fn_tree="jax.numpy.polydiv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - min_dim_size=1, - max_num_dims=1, - min_value=-1e04, - max_value=1e04, + fn_tree="jax.numpy.tensordot", + dtype_values_and_axes=_get_dtype_value1_value2_axis_for_tensordot( + helpers.get_dtypes(kind="numeric") ), + test_with_out=st.just(False), ) -def test_jax_polydiv( - *, - dtype_and_x, - test_flags, - on_device, - fn_tree, +def test_jax_tensordot( + 
dtype_values_and_axes, frontend, backend_fw, + test_flags, + fn_tree, ): - input_dtype, x = dtype_and_x - assume("float16" not in input_dtype) - # TODO: remove asumme when the decorator works + dtype, a, b, axes = dtype_values_and_axes + if ivy.current_backend_str() == "torch": + atol = 1e-3 + else: + atol = 1e-6 helpers.test_frontend_function( - input_dtypes=input_dtype, - test_flags=test_flags, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - u=x[0], - v=x[1], - rtol=1e-01, - atol=1e-02, + a=a, + b=b, + atol=atol, + axes=axes, ) -# polysub +# trace @handle_frontend_test( - fn_tree="jax.numpy.polysub", + fn_tree="jax.numpy.trace", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - min_value=-1e04, - max_value=1e04, + min_num_dims=2, + max_num_dims=2, + min_dim_size=1, + max_dim_size=10, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", ), + offset=st.integers(min_value=0, max_value=0), + axis1=st.integers(min_value=0, max_value=0), + axis2=st.integers(min_value=1, max_value=1), + test_with_out=st.just(False), ) -def test_jax_polysub( +def test_jax_trace( *, dtype_and_x, - test_flags, + offset, + axis1, + axis2, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x - assume("float16" not in input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a1=x[0], - a2=x[1], + rtol=1e-1, + atol=1e-1, + a=x[0], + offset=offset, + axis1=axis1, + axis2=axis2, ) -# polymul @handle_frontend_test( - fn_tree="jax.numpy.polymul", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - ), - trim=st.booleans(), + fn_tree="jax.numpy.trapz", + dtype_x_axis_rand_either=_either_x_dx(), + test_with_out=st.just(False), ) -def test_jax_polymul( +def test_jax_trapz( *, - dtype_and_x, - trim, - test_flags, + dtype_x_axis_rand_either, on_device, fn_tree, frontend, backend_fw, + test_flags, ): - input_dtype, x = dtype_and_x - assume("float16" not in input_dtype) + dtype_values_axis, rand, either_x_dx = dtype_x_axis_rand_either + input_dtype, y, axis = dtype_values_axis + if rand == 0: + dtype_x, x = either_x_dx + x = np.asarray(x, dtype=dtype_x) + dx = None + else: + x = None + dx = either_x_dx helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a1=x[0], - a2=x[1], - trim_leading_zeros=trim, - atol=1e-01, - rtol=1e-01, + rtol=1e-2, + atol=1e-2, + y=y[0], + x=x, + dx=dx, + axis=axis, ) -# signbit +# trunc @handle_frontend_test( - fn_tree="jax.numpy.signbit", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="jax.numpy.trunc", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), ) -def test_jax_signbit( +def test_jax_trunc( *, dtype_and_x, on_device, @@ -3154,72 +3201,33 @@ def test_jax_signbit( ) -# TODO: uncomment with multiversion pipeline (deprecated since 0.4.12) 
-# @handle_frontend_test( -# fn_tree="jax.numpy.product", -# dtype_x_axis_dtype_where=_get_castable_dtypes_values(use_where=True), -# keepdims=st.booleans(), -# initial=st.one_of(st.floats(min_value=-100, max_value=100)), -# promote_integers=st.booleans(), -# ) -# def test_jax_product( -# dtype_x_axis_dtype_where, -# keepdims, -# initial, -# promote_integers, -# frontend, -# test_flags, -# fn_tree, -# on_device, -# ): -# input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype_where -# if ivy.current_backend_str() == "torch": -# assume(not test_flags.as_variable[0]) -# where, input_dtypes, test_flags = np_frontend_helpers. -# handle_where_and_array_bools( -# where=where, -# input_dtype=input_dtypes, -# test_flags=test_flags, -# ) -# helpers.test_frontend_function( -# input_dtypes=input_dtypes, -# frontend=frontend, -# test_flags=test_flags, -# fn_tree=fn_tree, -# on_device=on_device, -# a=x[0], -# axis=axis, -# dtype=dtype, -# keepdims=keepdims, -# initial=initial, -# where=where, -# promote_integers=promote_integers, -# ) - - -# conjugate +# vdot @handle_frontend_test( - fn_tree="jax.numpy.conjugate", + fn_tree="jax.numpy.vdot", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), + test_with_out=st.just(False), ) -def test_jax_conjugate( +def test_jax_vdot( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + test_values=False, + a=x[0], + b=x[1], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_searching_sorting.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_searching_sorting.py index c5af497fb0df4..1e43433c4e00e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_searching_sorting.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_searching_sorting.py @@ -13,6 +13,62 @@ ) +# --- Helpers --- # +# --------------- # + + +# searchsorted +@st.composite +def _searchsorted(draw): + dtype_x, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + shape=(draw(st.integers(min_value=1, max_value=10)),), + ), + ) + dtype_v, v = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + min_num_dims=1, + ) + ) + + input_dtypes = dtype_x + dtype_v + xs = x + v + side = draw(st.sampled_from(["left", "right"])) + sorter = None + xs[0] = np.sort(xs[0], axis=-1) + return input_dtypes, xs, side, sorter + + +# unique +@st.composite +def _unique_helper(draw): + arr_dtype, arr, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + min_num_dims=1, + min_dim_size=2, + ret_shape=True, + ) + ) + axis = draw(st.sampled_from(list(range(len(shape))) + [None])) + return_index = draw(st.booleans()) + return_inverse = draw(st.booleans()) + return_counts = draw(st.booleans()) + return arr_dtype, arr, return_index, return_inverse, return_counts, axis + + +# --- Main --- # +# ------------ # + + # argmax @handle_frontend_test( fn_tree="jax.numpy.argmax", @@ -49,6 +105,40 @@ def test_jax_argmax( ) +# argsort +@handle_frontend_test( + fn_tree="jax.numpy.argsort", + dtype_x_axis=helpers.dtype_values_axis( + 
available_dtypes=helpers.get_dtypes("numeric"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, + ), + test_with_out=st.just(False), +) +def test_jax_argsort( + *, + dtype_x_axis, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + axis=axis, + ) + + # argwhere @handle_frontend_test( fn_tree="jax.numpy.argwhere", @@ -79,85 +169,49 @@ def test_jax_argwhere( ) -# argsort +# extract @handle_frontend_test( - fn_tree="jax.numpy.argsort", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, - ), - test_with_out=st.just(False), + fn_tree="jax.numpy.extract", + broadcastables=_broadcastable_trio(), ) -def test_jax_argsort( - *, - dtype_x_axis, +def test_jax_extract( + broadcastables, frontend, backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, x, axis = dtype_x_axis + cond, xs, dtype = broadcastables helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + condition=cond, + arr=xs[0], ) -# msort -# @handle_frontend_test( -# fn_tree="jax.numpy.msort", -# dtype_and_x=helpers.dtype_and_values( -# available_dtypes=helpers.get_dtypes("numeric"), -# min_num_dims=2, -# min_dim_size=2, -# ), -# test_with_out=st.just(False), -# ) -# def test_jax_msort( -# dtype_and_x, -# frontend, -# test_flags, -# fn_tree, -# ): -# input_dtype, x = dtype_and_x -# helpers.test_frontend_function( -# input_dtypes=input_dtype, -# frontend=frontend, -# test_flags=test_flags, -# fn_tree=fn_tree, -# a=x[0], -# ) -# TODO : deprecated since jax 0.4.1. \ -# Uncomment with multiversion testing pipeline enabled. - - -# nonzero +# flatnonzero @handle_frontend_test( - fn_tree="jax.numpy.nonzero", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="jax.numpy.flatnonzero", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), test_with_out=st.just(False), ) -def test_jax_nonzero( - dtype_and_a, +def test_jax_flatnonzero( + dtype_and_x, frontend, backend_fw, test_flags, fn_tree, on_device, ): - dtype, a = dtype_and_a + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -165,7 +219,7 @@ def test_jax_nonzero( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=a[0], + a=x[0], ) @@ -241,20 +295,51 @@ def test_jax_nanargmin( ) -# extract +# msort +# @handle_frontend_test( +# fn_tree="jax.numpy.msort", +# dtype_and_x=helpers.dtype_and_values( +# available_dtypes=helpers.get_dtypes("numeric"), +# min_num_dims=2, +# min_dim_size=2, +# ), +# test_with_out=st.just(False), +# ) +# def test_jax_msort( +# dtype_and_x, +# frontend, +# test_flags, +# fn_tree, +# ): +# input_dtype, x = dtype_and_x +# helpers.test_frontend_function( +# input_dtypes=input_dtype, +# frontend=frontend, +# test_flags=test_flags, +# fn_tree=fn_tree, +# a=x[0], +# ) +# TODO : deprecated since jax 0.4.1. \ +# Uncomment with multiversion testing pipeline enabled. 
+ + +# nonzero @handle_frontend_test( - fn_tree="jax.numpy.extract", - broadcastables=_broadcastable_trio(), + fn_tree="jax.numpy.nonzero", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + test_with_out=st.just(False), ) -def test_jax_extract( - broadcastables, +def test_jax_nonzero( + dtype_and_a, frontend, backend_fw, test_flags, fn_tree, on_device, ): - cond, xs, dtype = broadcastables + dtype, a = dtype_and_a helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -262,70 +347,69 @@ def test_jax_extract( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - condition=cond, - arr=xs[0], + a=a[0], ) -# sort @handle_frontend_test( - fn_tree="jax.numpy.sort", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, - ), + fn_tree="jax.numpy.searchsorted", + dtype_x_v_side_sorter=_searchsorted(), test_with_out=st.just(False), ) -def test_jax_sort( - *, - dtype_x_axis, +def test_jax_searchsorted( + dtype_x_v_side_sorter, frontend, backend_fw, + test_flags, fn_tree, on_device, - test_flags, ): - input_dtype, x, axis = dtype_x_axis + input_dtypes, xs, side, sorter = dtype_x_v_side_sorter helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + a=xs[0], + v=xs[1], + side=side, + sorter=sorter, ) -# flatnonzero +# sort @handle_frontend_test( - fn_tree="jax.numpy.flatnonzero", - dtype_and_x=helpers.dtype_and_values( + fn_tree="jax.numpy.sort", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), test_with_out=st.just(False), ) -def test_jax_flatnonzero( - dtype_and_x, +def test_jax_sort( + *, + dtype_x_axis, frontend, backend_fw, - test_flags, fn_tree, on_device, + test_flags, ): - dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, a=x[0], + axis=axis, ) @@ -364,59 +448,23 @@ def test_jax_sort_complex( ) -# searchsorted -@st.composite -def _searchsorted(draw): - dtype_x, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - shape=(draw(st.integers(min_value=1, max_value=10)),), - ), - ) - dtype_v, v = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - min_num_dims=1, - ) - ) - - input_dtypes = dtype_x + dtype_v - xs = x + v - side = draw(st.sampled_from(["left", "right"])) - sorter = None - xs[0] = np.sort(xs[0], axis=-1) - return input_dtypes, xs, side, sorter - - @handle_frontend_test( - fn_tree="jax.numpy.searchsorted", - dtype_x_v_side_sorter=_searchsorted(), - test_with_out=st.just(False), + fn_tree="jax.numpy.unique", fn_inputs=_unique_helper(), test_with_out=st.just(False) ) -def test_jax_searchsorted( - dtype_x_v_side_sorter, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtypes, xs, side, sorter = dtype_x_v_side_sorter +def test_jax_unique(fn_inputs, backend_fw, frontend, test_flags, fn_tree, on_device): + arr_dtype, arr, return_index, return_inverse, return_counts, axis = 
fn_inputs helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=arr_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=xs[0], - v=xs[1], - side=side, - sorter=sorter, + ar=arr[0], + return_index=return_index, + return_inverse=return_inverse, + return_counts=return_counts, + axis=axis, ) @@ -458,43 +506,3 @@ def test_jax_where( size=size, fill_value=fill_value, ) - - -# unique -@st.composite -def _unique_helper(draw): - arr_dtype, arr, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - min_num_dims=1, - min_dim_size=2, - ret_shape=True, - ) - ) - axis = draw(st.sampled_from(list(range(len(shape))) + [None])) - return_index = draw(st.booleans()) - return_inverse = draw(st.booleans()) - return_counts = draw(st.booleans()) - return arr_dtype, arr, return_index, return_inverse, return_counts, axis - - -@handle_frontend_test( - fn_tree="jax.numpy.unique", fn_inputs=_unique_helper(), test_with_out=st.just(False) -) -def test_jax_unique(fn_inputs, backend_fw, frontend, test_flags, fn_tree, on_device): - arr_dtype, arr, return_index, return_inverse, return_counts, axis = fn_inputs - helpers.test_frontend_function( - input_dtypes=arr_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - ar=arr[0], - return_index=return_index, - return_inverse=return_inverse, - return_counts=return_counts, - axis=axis, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py index b076d278fc809..7395504a913c7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py @@ -16,137 +16,210 @@ from ivy import inf -# einsum -@handle_frontend_test( - fn_tree="jax.numpy.einsum", - eq_n_op=st.sampled_from( - [ - ( - "ii", - np.arange(25).reshape(5, 5), - ), - ( - "ii->i", - np.arange(25).reshape(5, 5), - ), - ("ij,j", np.arange(25).reshape(5, 5), np.arange(5)), - ] - ), - dtype=helpers.get_dtypes("float", full=False), -) -def test_jax_einsum( - *, - eq_n_op, - dtype, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - kw = {} - i = 0 - for x_ in eq_n_op: - kw["x{}".format(i)] = x_ - i += 1 - test_flags.num_positional_args = i - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **kw, - out=None, - optimize="optimal", - precision=None, - _use_xeinsum=False, +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_array_axes_probs(draw): + array_dtypes = draw(helpers.get_dtypes(kind="float")) + array_dtype, array, axes = draw( + helpers.dtype_values_axis( + available_dtypes=array_dtypes, + small_abs_safety_factor=5, + large_abs_safety_factor=5, + min_num_dims=1, + max_num_dims=5, + max_dim_size=7, + max_axes_size=5, + valid_axis=True, + force_int_axis=True, + min_value=1, + max_value=300, + ) + ) + q = np.round( + np.array( + draw( + helpers.lists( + x=helpers.floats( + min_value=0, + max_value=1, + small_abs_safety_factor=50, + large_abs_safety_factor=50, + safety_factor_scale="log", + abs_smallest_val=1e-1, + mixed_fn_compos=False, + ), + min_size=1, + max_size=10, + ) + ) + ), + decimals=3, ) + return array_dtype, 
array, axes, q -# mean -@handle_frontend_test( - fn_tree="jax.numpy.mean", - dtype_x_axis=_statistical_dtype_values(function="mean"), - dtype=helpers.get_dtypes("float", full=False, none=True), - where=np_helpers.where(), - keepdims=st.booleans(), -) -def test_jax_mean( - *, - dtype_x_axis, - dtype, - keepdims, - where, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, + +# nanmean +@st.composite +def _get_castable_dtype_with_nan(draw): + available_dtypes = helpers.get_dtypes("float") + shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) + dtype, values = draw( + helpers.dtype_and_values( + available_dtypes=available_dtypes, + num_arrays=1, + large_abs_safety_factor=6, + small_abs_safety_factor=24, + safety_factor_scale="log", + shape=shape, + allow_nan=True, + allow_inf=True, + ) + ) + axis = draw(helpers.get_axis(shape=shape, force_int=True)) + dtype1, values, dtype2 = draw( + helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0]) + ) + return dtype1, [values], axis, dtype2 + + +@st.composite +def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): + available_dtypes = helpers.get_dtypes("numeric") + shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) + dtype, values = draw( + helpers.dtype_and_values( + available_dtypes=available_dtypes, + num_arrays=1, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + shape=shape, + allow_nan=allow_nan, + ) + ) + axis = draw(helpers.get_axis(shape=shape, force_int=True)) + dtype1, values, dtype2 = draw( + helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0]) + ) + if use_where: + where = draw(np_frontend_helpers.where(shape=shape)) + return [dtype1], [values], axis, dtype2, where + return [dtype1], [values], axis, dtype2 + + +# cov +@st.composite +def _get_dtype_value1_value2_cov( + draw, + available_dtypes, + min_num_dims=1, + max_num_dims=2, + min_dim_size=2, + max_dim_size=3, + abs_smallest_val=None, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + large_abs_safety_factor=50, + small_abs_safety_factor=50, + safety_factor_scale="log", ): - input_dtypes, x, axis = dtype_x_axis - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) ) - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-1, - rtol=1e-1, - a=x[0], - axis=axis, - dtype=dtype[0], - out=None, - keepdims=keepdims, - where=where, + dtype = draw(st.sampled_from(available_dtypes)) + + values = [] + for i in range(2): + values.append( + draw( + helpers.array_values( + dtype=dtype, + shape=shape, + abs_smallest_val=abs_smallest_val, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + ) + ) + ) + + value1, value2 = values[0], values[1] + + # modifiers: rowVar, bias, ddof + rowVar = draw(st.booleans()) + bias = 
draw(st.booleans()) + ddof = draw(helpers.ints(min_value=0, max_value=1)) + + numVals = None + if rowVar is False: + numVals = -1 if numVals == 0 else 0 + else: + numVals = 0 if len(shape) == 1 else -1 + + fweights = None + + aweights = draw( + helpers.array_values( + dtype=dtype, + shape=shape[numVals], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + small_abs_safety_factor=1, + ) ) + return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights -# var + +# --- Main --- # +# ------------ # + + +# argmin @handle_frontend_test( - fn_tree="jax.numpy.var", - dtype_x_axis=_statistical_dtype_values(function="var").filter( - lambda x: x[0][0] != "bfloat16" - ), - dtype=helpers.get_dtypes("float", full=False, none=True).filter( - lambda x: x != "bfloat16" + fn_tree="jax.numpy.argmin", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + valid_axis=True, ), - where=np_helpers.where(), keepdims=st.booleans(), ) -def test_jax_var( +def test_jax_argmin( *, - dtype_x_axis, - dtype, + dtype_and_x, keepdims, - where, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, x, axis, ddof = dtype_x_axis - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, @@ -154,49 +227,57 @@ def test_jax_var( on_device=on_device, a=x[0], axis=axis, - dtype=dtype[0], out=None, - ddof=ddof, keepdims=keepdims, - where=where, - atol=1e-3, - rtol=1e-3, ) -# argmin +# average @handle_frontend_test( - fn_tree="jax.numpy.argmin", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, + fn_tree="jax.numpy.average", + dtype_x_axis=helpers.dtype_values_axis( + num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", min_num_dims=1, + max_num_dims=5, + min_dim_size=2, valid_axis=True, + allow_neg_axes=False, + min_axes_size=1, ), - keepdims=st.booleans(), + returned=st.booleans(), ) -def test_jax_argmin( +def test_jax_average( *, - dtype_and_x, - keepdims, + dtype_x_axis, + returned, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + x_dtype, x, axis = dtype_x_axis + + if isinstance(axis, tuple): + axis = axis[0] + + np_helpers.test_frontend_function( + input_dtypes=x_dtype, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, + atol=2e-2, + rtol=2e-2, a=x[0], axis=axis, - out=None, - keepdims=keepdims, + weights=x[1], + returned=returned, ) @@ -241,417 +322,280 @@ def test_jax_bincount( ) -# cumprod @handle_frontend_test( - fn_tree="jax.numpy.cumprod", - # aliases=["jax.numpy.cumproduct"], deprecated since 0.4.12 - dtype_x_axis=_get_castable_dtype(), - test_with_out=st.just(False), + fn_tree="jax.numpy.corrcoef", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float32", "float64"], + num_arrays=2, + shared_dtype=True, + abs_smallest_val=1e-5, + min_num_dims=2, + 
max_num_dims=2, + min_dim_size=3, + max_dim_size=3, + min_value=-100, + max_value=100, + ), + rowvar=st.booleans(), ) -def test_jax_cumprod( - *, - dtype_x_axis, - on_device, - fn_tree, +def test_jax_corrcoef( + dtype_and_x, + rowvar, frontend, test_flags, - backend_fw, -): - input_dtype, x, axis, dtype = dtype_x_axis - helpers.test_frontend_function( - backend_to_test=backend_fw, - input_dtypes=[input_dtype], - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - a=x[0], - axis=axis, - dtype=dtype, - ) - - -# cumsum -@handle_frontend_test( - fn_tree="jax.numpy.cumsum", - dtype_x_axis=_get_castable_dtype(), - test_with_out=st.just(False), -) -def test_jax_cumsum( - *, - dtype_x_axis, - on_device, fn_tree, - frontend, - test_flags, backend_fw, -): - input_dtype, x, axis, dtype = dtype_x_axis - helpers.test_frontend_function( - backend_to_test=backend_fw, - input_dtypes=[input_dtype], - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - axis=axis, - dtype=dtype, - ) - - -# sum -@handle_frontend_test( - fn_tree="jax.numpy.sum", - dtype_x_axis_castable=_get_castable_dtype(), - initial=st.none() | st.floats(-10.0, 10.0), - where=np_helpers.where(), - keepdims=st.booleans(), -) -def test_jax_sum( - *, - dtype_x_axis_castable, - initial, - where, - keepdims, on_device, - fn_tree, - frontend, - test_flags, - backend_fw, ): - input_dtypes, x, axis, castable_dtype = dtype_x_axis_castable - - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=[input_dtypes], - test_flags=test_flags, - ) - + input_dtypes, x = dtype_and_x np_helpers.test_frontend_function( input_dtypes=input_dtypes, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-1, - atol=1e-2, - a=x[0], - axis=axis, - dtype=castable_dtype, - out=None, - keepdims=keepdims, - initial=initial, - where=where, - backend_to_test=backend_fw, + x=x[0], + y=x[1], + rowvar=rowvar, ) -# min +# correlate @handle_frontend_test( - fn_tree="jax.numpy.min", - aliases=["jax.numpy.amin"], - dtype_x_axis=_statistical_dtype_values(function="min"), - where=np_helpers.where(), - keepdims=st.booleans(), + fn_tree="jax.numpy.correlate", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_num_dims=1, + max_num_dims=1, + min_value=-1e04, + max_value=1e04, + shared_dtype=True, + ), + mode=st.sampled_from(["valid", "same", "full"]), ) -def test_jax_min( +def test_jax_correlate( *, - dtype_x_axis, - keepdims, - where, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, + mode, ): - input_dtypes, x, axis = dtype_x_axis - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + rtol=1e-4, + atol=1e-4, on_device=on_device, a=x[0], - axis=axis, - out=None, - keepdims=keepdims, - where=where, + v=x[1], + mode=mode, ) -# max @handle_frontend_test( - fn_tree="jax.numpy.max", - aliases=["jax.numpy.amax"], - dtype_x_axis=_statistical_dtype_values(function="max"), - 
where=np_helpers.where(), - keepdims=st.booleans(), + fn_tree="jax.numpy.cov", + dtypes_args=_get_dtype_value1_value2_cov(available_dtypes=["float64"]), + test_with_out=st.just(False), ) -def test_jax_max( +def test_jax_cov( *, - dtype_x_axis, - keepdims, - where, + dtypes_args, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): - input_dtypes, x, axis = dtype_x_axis - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, + dtype, value1, value2, rowvar, bias, ddof, fweights, aweights = dtypes_args + helpers.test_frontend_function( + input_dtypes=dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + rtol=1e-3, + atol=1e-3, on_device=on_device, - a=x[0], - axis=axis, - out=None, - keepdims=keepdims, - where=where, + m=value1, + y=value2, + rowvar=rowvar, + bias=bias, + ddof=ddof, + fweights=fweights, + aweights=aweights, ) -# average +# cumprod @handle_frontend_test( - fn_tree="jax.numpy.average", - dtype_x_axis=helpers.dtype_values_axis( - num_arrays=2, - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - valid_axis=True, - allow_neg_axes=False, - min_axes_size=1, - ), - returned=st.booleans(), + fn_tree="jax.numpy.cumprod", + # aliases=["jax.numpy.cumproduct"], deprecated since 0.4.12 + dtype_x_axis=_get_castable_dtype(), + test_with_out=st.just(False), ) -def test_jax_average( +def test_jax_cumprod( *, dtype_x_axis, - returned, on_device, fn_tree, frontend, test_flags, backend_fw, ): - x_dtype, x, axis = dtype_x_axis - - if isinstance(axis, tuple): - axis = axis[0] - - np_helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtype, x, axis, dtype = dtype_x_axis + helpers.test_frontend_function( + backend_to_test=backend_fw, + input_dtypes=[input_dtype], frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - atol=2e-2, - rtol=2e-2, + rtol=1e-2, a=x[0], axis=axis, - weights=x[1], - returned=returned, + dtype=dtype, ) -# nanmax +# cumsum @handle_frontend_test( - fn_tree="jax.numpy.nanmax", - aliases=["jax.numpy.nanmax"], - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - large_abs_safety_factor=2, - safety_factor_scale="log", - allow_nan=True, - allow_inf=True, - ), - initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), - keepdims=st.booleans(), - where=np_helpers.where(), + fn_tree="jax.numpy.cumsum", + dtype_x_axis=_get_castable_dtype(), + test_with_out=st.just(False), ) -def test_jax_nanmax( +def test_jax_cumsum( + *, dtype_x_axis, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, - where, - initial, - keepdims, ): - if initial is None and np.all(where) is not True: - assume(initial is -inf) - - input_dtypes, x, axis = dtype_x_axis - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtype, x, axis, dtype = dtype_x_axis + helpers.test_frontend_function( 
backend_to_test=backend_fw, + input_dtypes=[input_dtype], frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, a=x[0], axis=axis, - out=None, - keepdims=keepdims, - initial=initial, - where=where, + dtype=dtype, ) -# nanmin +# einsum @handle_frontend_test( - fn_tree="jax.numpy.nanmin", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float", full=False), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - large_abs_safety_factor=2, - safety_factor_scale="log", - allow_nan=True, - allow_inf=True, + fn_tree="jax.numpy.einsum", + eq_n_op=st.sampled_from( + [ + ( + "ii", + np.arange(25).reshape(5, 5), + ), + ( + "ii->i", + np.arange(25).reshape(5, 5), + ), + ("ij,j", np.arange(25).reshape(5, 5), np.arange(5)), + ] ), - initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), - keepdims=st.booleans(), - where=np_helpers.where(), + dtype=helpers.get_dtypes("float", full=False), ) -def test_jax_nanmin( - dtype_x_axis, +def test_jax_einsum( + *, + eq_n_op, + dtype, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, - where, - initial, - keepdims, ): - if initial is None and np.all(where) is not True: - assume(initial is inf) - - input_dtypes, x, axis = dtype_x_axis - where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, + kw = {} + i = 0 + for x_ in eq_n_op: + kw["x{}".format(i)] = x_ + i += 1 + test_flags.num_positional_args = i + helpers.test_frontend_function( + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + **kw, out=None, - keepdims=keepdims, - initial=initial, - where=where, + optimize="optimal", + precision=None, + _use_xeinsum=False, ) -# nanstd +# max @handle_frontend_test( - fn_tree="jax.numpy.nanstd", - dtype_and_a=_statistical_dtype_values(function="nanstd"), - dtype=helpers.get_dtypes("float", full=False, none=True), - where=np_frontend_helpers.where(), - keep_dims=st.booleans(), + fn_tree="jax.numpy.max", + aliases=["jax.numpy.amax"], + dtype_x_axis=_statistical_dtype_values(function="max"), + where=np_helpers.where(), + keepdims=st.booleans(), ) -def test_jax_nanstd( - dtype_and_a, - dtype, +def test_jax_max( + *, + dtype_x_axis, + keepdims, where, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, - keep_dims, ): - input_dtypes, a, axis, correction = dtype_and_a + input_dtypes, x, axis = dtype_x_axis if isinstance(axis, tuple): axis = axis[0] - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, test_flags=test_flags, ) - assume(np.dtype(dtype[0]) >= np.dtype(input_dtypes[0])) - np_frontend_helpers.test_frontend_function( + + np_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=a[0], + a=x[0], axis=axis, - dtype=dtype[0], out=None, - ddof=correction, - keepdims=keep_dims, + keepdims=keepdims, where=where, - atol=1e-2, - rtol=1e-2, ) -# nanvar +# mean @handle_frontend_test( - fn_tree="jax.numpy.nanvar", - dtype_x_axis=_statistical_dtype_values(function="nanvar").filter( - lambda x: x[0][0] != "bfloat16" - ), - 
dtype=helpers.get_dtypes("float", full=False, none=True).filter( - lambda x: x != "bfloat16" - ), + fn_tree="jax.numpy.mean", + dtype_x_axis=_statistical_dtype_values(function="mean"), + dtype=helpers.get_dtypes("float", full=False, none=True), where=np_helpers.where(), keepdims=st.booleans(), ) -def test_jax_nanvar( +def test_jax_mean( *, dtype_x_axis, dtype, @@ -663,7 +607,7 @@ def test_jax_nanvar( test_flags, backend_fw, ): - input_dtypes, x, axis, ddof = dtype_x_axis + input_dtypes, x, axis = dtype_x_axis if isinstance(axis, tuple): axis = axis[0] where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( @@ -671,7 +615,7 @@ def test_jax_nanvar( input_dtype=input_dtypes, test_flags=test_flags, ) - assume(np.dtype(dtype[0]) >= np.dtype(input_dtypes[0])) + np_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -679,90 +623,42 @@ def test_jax_nanvar( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-1, + rtol=1e-1, a=x[0], axis=axis, dtype=dtype[0], out=None, - ddof=ddof, keepdims=keepdims, where=where, - atol=1e-3, - rtol=1e-3, - ) - - -@st.composite -def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): - available_dtypes = helpers.get_dtypes("numeric") - shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) - dtype, values = draw( - helpers.dtype_and_values( - available_dtypes=available_dtypes, - num_arrays=1, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - shape=shape, - allow_nan=allow_nan, - ) - ) - axis = draw(helpers.get_axis(shape=shape, force_int=True)) - dtype1, values, dtype2 = draw( - helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0]) ) - if use_where: - where = draw(np_frontend_helpers.where(shape=shape)) - return [dtype1], [values], axis, dtype2, where - return [dtype1], [values], axis, dtype2 -# nancumprod +# median @handle_frontend_test( - fn_tree="jax.numpy.nancumprod", - dtype_and_x_axis_dtype=_get_castable_dtypes_values(allow_nan=True), + fn_tree="jax.numpy.median", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + min_value=-(2**10), + max_value=2**10, + valid_axis=True, + ), + keepdims=st.booleans(), ) -def test_jax_nancumprod( - dtype_and_x_axis_dtype, - frontend, - test_flags, - fn_tree, - backend_fw, +def test_jax_median( + *, + dtype_x_axis, + keepdims, on_device, -): - input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype - if ivy.current_backend_str() == "torch": - assume(not test_flags.as_variable[0]) - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - axis=axis, - dtype=dtype, - ) - - -# nancumsum -@handle_frontend_test( - fn_tree="jax.numpy.nancumsum", - dtype_and_x_axis_dtype=_get_castable_dtypes_values(allow_nan=True), -) -def test_jax_nancumsum( - dtype_and_x_axis_dtype, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype - if ivy.current_backend_str() == "torch": - assume(not test_flags.as_variable[0]) - np_frontend_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, @@ -770,22 +666,25 @@ def test_jax_nancumsum( 
on_device=on_device, a=x[0], axis=axis, - dtype=dtype, + out=None, + overwrite_input=False, + keepdims=keepdims, + atol=1e-3, + rtol=1e-3, ) -# std +# min @handle_frontend_test( - fn_tree="jax.numpy.std", - dtype_x_axis=_statistical_dtype_values(function="std"), - dtype=helpers.get_dtypes("float", full=False, none=True), + fn_tree="jax.numpy.min", + aliases=["jax.numpy.amin"], + dtype_x_axis=_statistical_dtype_values(function="min"), where=np_helpers.where(), keepdims=st.booleans(), ) -def test_jax_std( +def test_jax_min( *, dtype_x_axis, - dtype, keepdims, where, on_device, @@ -794,7 +693,7 @@ def test_jax_std( test_flags, backend_fw, ): - input_dtypes, x, axis, ddof = dtype_x_axis + input_dtypes, x, axis = dtype_x_axis if isinstance(axis, tuple): axis = axis[0] where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( @@ -812,80 +711,59 @@ def test_jax_std( on_device=on_device, a=x[0], axis=axis, - dtype=dtype[0], out=None, - ddof=ddof, keepdims=keepdims, where=where, - atol=1e-3, - rtol=1e-3, ) +# nancumprod @handle_frontend_test( - fn_tree="jax.numpy.corrcoef", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - num_arrays=2, - shared_dtype=True, - abs_smallest_val=1e-5, - min_num_dims=2, - max_num_dims=2, - min_dim_size=3, - max_dim_size=3, - min_value=-100, - max_value=100, - ), - rowvar=st.booleans(), + fn_tree="jax.numpy.nancumprod", + dtype_and_x_axis_dtype=_get_castable_dtypes_values(allow_nan=True), ) -def test_jax_corrcoef( - dtype_and_x, - rowvar, +def test_jax_nancumprod( + dtype_and_x_axis_dtype, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtypes, x = dtype_and_x - np_helpers.test_frontend_function( + input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype + if ivy.current_backend_str() == "torch": + assume(not test_flags.as_variable[0]) + helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], - rowvar=rowvar, + a=x[0], + axis=axis, + dtype=dtype, ) -# median +# nancumsum @handle_frontend_test( - fn_tree="jax.numpy.median", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - min_value=-(2**10), - max_value=2**10, - valid_axis=True, - ), - keepdims=st.booleans(), + fn_tree="jax.numpy.nancumsum", + dtype_and_x_axis_dtype=_get_castable_dtypes_values(allow_nan=True), ) -def test_jax_median( - *, - dtype_x_axis, - keepdims, - on_device, - fn_tree, +def test_jax_nancumsum( + dtype_and_x_axis_dtype, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype + if ivy.current_backend_str() == "torch": + assume(not test_flags.as_variable[0]) + np_frontend_helpers.test_frontend_function( + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, @@ -893,31 +771,49 @@ def test_jax_median( on_device=on_device, a=x[0], axis=axis, - out=None, - overwrite_input=False, - keepdims=keepdims, - atol=1e-3, - rtol=1e-3, + dtype=dtype, ) -# ptp +# nanmax @handle_frontend_test( - fn_tree="jax.numpy.ptp", - dtype_and_x_axis_dtype=_get_castable_dtypes_values(), - keep_dims=st.booleans(), + fn_tree="jax.numpy.nanmax", + aliases=["jax.numpy.nanmax"], + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + 
min_num_dims=1, + valid_axis=True, + force_int_axis=True, + large_abs_safety_factor=2, + safety_factor_scale="log", + allow_nan=True, + allow_inf=True, + ), + initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), + keepdims=st.booleans(), + where=np_helpers.where(), ) -def test_jax_ptp( - dtype_and_x_axis_dtype, +def test_jax_nanmax( + dtype_x_axis, frontend, test_flags, fn_tree, backend_fw, on_device, - keep_dims, + where, + initial, + keepdims, ): - input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype - np_frontend_helpers.test_frontend_function( + if initial is None and np.all(where) is not True: + assume(initial is -inf) + + input_dtypes, x, axis = dtype_x_axis + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + np_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, @@ -927,32 +823,10 @@ def test_jax_ptp( a=x[0], axis=axis, out=None, - keepdims=keep_dims, - ) - - -# nanmean -@st.composite -def _get_castable_dtype_with_nan(draw): - available_dtypes = helpers.get_dtypes("float") - shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) - dtype, values = draw( - helpers.dtype_and_values( - available_dtypes=available_dtypes, - num_arrays=1, - large_abs_safety_factor=6, - small_abs_safety_factor=24, - safety_factor_scale="log", - shape=shape, - allow_nan=True, - allow_inf=True, - ) - ) - axis = draw(helpers.get_axis(shape=shape, force_int=True)) - dtype1, values, dtype2 = draw( - helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0]) + keepdims=keepdims, + initial=initial, + where=where, ) - return dtype1, [values], axis, dtype2 @handle_frontend_test( @@ -1033,202 +907,187 @@ def test_jax_nanmedian( ) -# correlate +# nanmin @handle_frontend_test( - fn_tree="jax.numpy.correlate", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + fn_tree="jax.numpy.nanmin", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float", full=False), min_num_dims=1, - max_num_dims=1, - min_value=-1e04, - max_value=1e04, - shared_dtype=True, + valid_axis=True, + force_int_axis=True, + large_abs_safety_factor=2, + safety_factor_scale="log", + allow_nan=True, + allow_inf=True, ), - mode=st.sampled_from(["valid", "same", "full"]), + initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), + keepdims=st.booleans(), + where=np_helpers.where(), ) -def test_jax_correlate( - *, - dtype_and_x, - on_device, - fn_tree, +def test_jax_nanmin( + dtype_x_axis, frontend, test_flags, + fn_tree, backend_fw, - mode, + on_device, + where, + initial, + keepdims, ): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, + if initial is None and np.all(where) is not True: + assume(initial is inf) + + input_dtypes, x, axis = dtype_x_axis + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + np_helpers.test_frontend_function( + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - rtol=1e-4, - atol=1e-4, on_device=on_device, a=x[0], - v=x[1], - mode=mode, + axis=axis, + out=None, + keepdims=keepdims, + initial=initial, + where=where, ) -# cov -@st.composite -def _get_dtype_value1_value2_cov( - draw, - available_dtypes, - 
min_num_dims=1, - max_num_dims=2, - min_dim_size=2, - max_dim_size=3, - abs_smallest_val=None, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - large_abs_safety_factor=50, - small_abs_safety_factor=50, - safety_factor_scale="log", +# nanstd +@handle_frontend_test( + fn_tree="jax.numpy.nanstd", + dtype_and_a=_statistical_dtype_values(function="nanstd"), + dtype=helpers.get_dtypes("float", full=False, none=True), + where=np_frontend_helpers.where(), + keep_dims=st.booleans(), +) +def test_jax_nanstd( + dtype_and_a, + dtype, + where, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, + keep_dims, ): - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) + input_dtypes, a, axis, correction = dtype_and_a + if isinstance(axis, tuple): + axis = axis[0] + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, ) - - dtype = draw(st.sampled_from(available_dtypes)) - - values = [] - for i in range(2): - values.append( - draw( - helpers.array_values( - dtype=dtype, - shape=shape, - abs_smallest_val=abs_smallest_val, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, - ) - ) - ) - - value1, value2 = values[0], values[1] - - # modifiers: rowVar, bias, ddof - rowVar = draw(st.booleans()) - bias = draw(st.booleans()) - ddof = draw(helpers.ints(min_value=0, max_value=1)) - - numVals = None - if rowVar is False: - numVals = -1 if numVals == 0 else 0 - else: - numVals = 0 if len(shape) == 1 else -1 - - fweights = None - - aweights = draw( - helpers.array_values( - dtype=dtype, - shape=shape[numVals], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - small_abs_safety_factor=1, - ) + assume(np.dtype(dtype[0]) >= np.dtype(input_dtypes[0])) + np_frontend_helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=a[0], + axis=axis, + dtype=dtype[0], + out=None, + ddof=correction, + keepdims=keep_dims, + where=where, + atol=1e-2, + rtol=1e-2, ) - return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights - +# nanvar @handle_frontend_test( - fn_tree="jax.numpy.cov", - dtypes_args=_get_dtype_value1_value2_cov(available_dtypes=["float64"]), - test_with_out=st.just(False), + fn_tree="jax.numpy.nanvar", + dtype_x_axis=_statistical_dtype_values(function="nanvar").filter( + lambda x: x[0][0] != "bfloat16" + ), + dtype=helpers.get_dtypes("float", full=False, none=True).filter( + lambda x: x != "bfloat16" + ), + where=np_helpers.where(), + keepdims=st.booleans(), ) -def test_jax_cov( +def test_jax_nanvar( *, - dtypes_args, + dtype_x_axis, + dtype, + keepdims, + where, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): - dtype, value1, value2, rowvar, bias, ddof, fweights, aweights = dtypes_args - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, + input_dtypes, x, axis, ddof = dtype_x_axis + if isinstance(axis, tuple): + axis = axis[0] + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( + where=where, + 
input_dtype=input_dtypes, + test_flags=test_flags, + ) + assume(np.dtype(dtype[0]) >= np.dtype(input_dtypes[0])) + np_helpers.test_frontend_function( + input_dtypes=input_dtypes, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - rtol=1e-3, - atol=1e-3, on_device=on_device, - m=value1, - y=value2, - rowvar=rowvar, - bias=bias, + a=x[0], + axis=axis, + dtype=dtype[0], + out=None, ddof=ddof, - fweights=fweights, - aweights=aweights, + keepdims=keepdims, + where=where, + atol=1e-3, + rtol=1e-3, ) -@st.composite -def _get_array_axes_probs(draw): - array_dtypes = draw(helpers.get_dtypes(kind="float")) - array_dtype, array, axes = draw( - helpers.dtype_values_axis( - available_dtypes=array_dtypes, - small_abs_safety_factor=5, - large_abs_safety_factor=5, - min_num_dims=1, - max_num_dims=5, - max_dim_size=7, - max_axes_size=5, - valid_axis=True, - force_int_axis=True, - min_value=1, - max_value=300, - ) - ) - q = np.round( - np.array( - draw( - helpers.lists( - x=helpers.floats( - min_value=0, - max_value=1, - small_abs_safety_factor=50, - large_abs_safety_factor=50, - safety_factor_scale="log", - abs_smallest_val=1e-1, - mixed_fn_compos=False, - ), - min_size=1, - max_size=10, - ) - ) - ), - decimals=3, +# ptp +@handle_frontend_test( + fn_tree="jax.numpy.ptp", + dtype_and_x_axis_dtype=_get_castable_dtypes_values(), + keep_dims=st.booleans(), +) +def test_jax_ptp( + dtype_and_x_axis_dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, + keep_dims, +): + input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype + np_frontend_helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + axis=axis, + out=None, + keepdims=keep_dims, ) - return array_dtype, array, axes, q - @handle_frontend_test( fn_tree="jax.numpy.quantile", @@ -1264,3 +1123,152 @@ def test_jax_quantile( method=method, keepdims=keepdims, ) + + +# std +@handle_frontend_test( + fn_tree="jax.numpy.std", + dtype_x_axis=_statistical_dtype_values(function="std"), + dtype=helpers.get_dtypes("float", full=False, none=True), + where=np_helpers.where(), + keepdims=st.booleans(), +) +def test_jax_std( + *, + dtype_x_axis, + dtype, + keepdims, + where, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, x, axis, ddof = dtype_x_axis + if isinstance(axis, tuple): + axis = axis[0] + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + + np_helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + axis=axis, + dtype=dtype[0], + out=None, + ddof=ddof, + keepdims=keepdims, + where=where, + atol=1e-3, + rtol=1e-3, + ) + + +# sum +@handle_frontend_test( + fn_tree="jax.numpy.sum", + dtype_x_axis_castable=_get_castable_dtype(), + initial=st.none() | st.floats(-10.0, 10.0), + where=np_helpers.where(), + keepdims=st.booleans(), +) +def test_jax_sum( + *, + dtype_x_axis_castable, + initial, + where, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, x, axis, castable_dtype = dtype_x_axis_castable + + if isinstance(axis, tuple): + axis = axis[0] + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( + where=where, + input_dtype=[input_dtypes], + 
test_flags=test_flags, + ) + + np_helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-1, + atol=1e-2, + a=x[0], + axis=axis, + dtype=castable_dtype, + out=None, + keepdims=keepdims, + initial=initial, + where=where, + backend_to_test=backend_fw, + ) + + +# var +@handle_frontend_test( + fn_tree="jax.numpy.var", + dtype_x_axis=_statistical_dtype_values(function="var").filter( + lambda x: x[0][0] != "bfloat16" + ), + dtype=helpers.get_dtypes("float", full=False, none=True).filter( + lambda x: x != "bfloat16" + ), + where=np_helpers.where(), + keepdims=st.booleans(), +) +def test_jax_var( + *, + dtype_x_axis, + dtype, + keepdims, + where, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, x, axis, ddof = dtype_x_axis + if isinstance(axis, tuple): + axis = axis[0] + where, input_dtypes, test_flags = np_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + + np_helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + axis=axis, + dtype=dtype[0], + out=None, + ddof=ddof, + keepdims=keepdims, + where=where, + atol=1e-3, + rtol=1e-3, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py index 6d80a44cbbab3..10ddbd13320db 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py @@ -10,6 +10,35 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _all_gamma_params(draw): + shape = draw( + helpers.get_shape( + min_dim_size=1, max_dim_size=5, min_num_dims=2, max_num_dims=2 + ) + | st.just(None) + ) + if shape is None: + a = draw( + helpers.array_values( + min_value=0.0, + max_value=100.0, + dtype=helpers.get_dtypes("float", full=False), + exclude_min=True, + shape=helpers.get_shape( + min_dim_size=1, max_dim_size=5, min_num_dims=1, max_num_dims=2 + ), + ) + ) + return a[0], shape + a = draw(st.floats(min_value=0, max_value=5, exclude_min=True)) + return a, shape + + # ToDo: Find solution around torch and paddle not running with uints 32, # 64 and remove xfail fixture @@ -22,9 +51,73 @@ def _get_minval_maxval(draw): return minval, maxval +@st.composite +def dtype_p_shape(draw): + dtype = draw(helpers.array_dtypes(available_dtypes=("float32", "float64"))) + shape = draw(helpers.get_shape(allow_none=False, min_num_dims=1, max_num_dims=3)) + + dtype_and_probs = draw( + helpers.dtype_and_values( + available_dtypes=dtype, min_value=0, max_value=1, shape=shape + ) + ) + return dtype_and_probs, shape + + +@st.composite +def get_mean_cov_vector(draw): + input_dtype = draw( + st.shared( + st.sampled_from(draw(helpers.get_dtypes("float"))), + key="shared_dtype", + ) + ) + shared_size = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size") + ) + + # Generate shape for mean vector (..., n) + dtype_mean = draw( + helpers.array_values( + dtype=input_dtype, + shape=tuple([shared_size]), + min_value=2, + max_value=5, + ) + ) + + # Generate shape for covariance matrix (..., n, n) + dtype_cov = draw( + helpers.array_values( + dtype=input_dtype, + shape=tuple([shared_size, shared_size]), + min_value=2, + max_value=5, + ).filter(lambda x: 
np.linalg.cond(x.tolist()) < 1 / sys.float_info.epsilon) + ) + + batch_shape = dtype_cov.shape[:-2] + + return input_dtype, dtype_mean, dtype_cov, batch_shape + + +@st.composite +def get_shape_and_arrays(draw): + b_shapes = draw( + helpers.array_and_broadcastable_shape(dtype=helpers.get_dtypes("float")) + ) + b, shapes = b_shapes + shapes = draw(st.sampled_from([None, shapes])) + return b, shapes + + +# --- Main --- # +# ------------ # + + @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.uniform", + fn_tree="jax.random.ball", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -34,24 +127,28 @@ def _get_minval_maxval(draw): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape(), + shape=helpers.get_shape( + min_num_dims=1, max_num_dims=6, min_dim_size=1, max_dim_size=6 + ), dtype=helpers.get_dtypes("float", full=False), - dtype_minval_maxval=_get_minval_maxval(), + d=st.integers(min_value=1, max_value=100), + p=st.floats(min_value=1e-5, max_value=100, exclude_min=True), + test_with_out=st.just(False), ) -def test_jax_uniform( +def test_jax_ball( *, dtype_key, + d, + p, shape, dtype, - dtype_minval_maxval, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, key = dtype_key - minval, maxval = dtype_minval_maxval def call(): return helpers.test_frontend_function( @@ -63,10 +160,10 @@ def call(): on_device=on_device, test_values=False, key=key[0], + d=d, + p=p, shape=shape, dtype=dtype[0], - minval=minval, - maxval=maxval, ) ret = call() @@ -84,26 +181,22 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.orthogonal", + fn_tree="jax.random.bernoulli", dtype_key=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], + available_dtypes=helpers.get_dtypes("integer", full=False), min_value=0, max_value=2000, min_num_dims=1, - max_num_dims=3, + max_num_dims=1, min_dim_size=2, - max_dim_size=5, + max_dim_size=2, ), - n=helpers.get_shape(), - shape=helpers.get_shape(), - dtype=helpers.get_dtypes("float", full=False), + dtype_p_shape_=dtype_p_shape(), ) -def test_jax_orthogonal( +def test_jax_bernoulli( *, dtype_key, - n, - shape, - dtype, + dtype_p_shape_, on_device, fn_tree, frontend, @@ -111,20 +204,21 @@ def test_jax_orthogonal( backend_fw, ): input_dtype, key = dtype_key + dtype_p, shape = dtype_p_shape_ + dtype, p = dtype_p def call(): return helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtype + dtype, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, + backend_to_test=backend_fw, on_device=on_device, test_values=False, key=key[0], - n=n, + p=p[0], shape=shape, - dtype=dtype[0], ) ret = call() @@ -138,13 +232,11 @@ def call(): for u, v in zip(ret_np, ret_from_np): assert u.dtype == v.dtype assert u.shape == v.shape - # Check if the output matrices are orthogonal - assert ivy.allclose(ivy.eye(n), ivy.matmul(u.T, u)) @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.normal", + fn_tree="jax.random.beta", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -154,12 +246,19 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape(), + alpha=st.floats(min_value=0, max_value=5, exclude_min=True), + beta=st.floats(min_value=0, max_value=5, exclude_min=True), + shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=5 + ), dtype=helpers.get_dtypes("float", full=False), + test_with_out=st.just(False), ) -def 
test_jax_normal( +def test_jax_beta( *, dtype_key, + alpha, + beta, shape, dtype, on_device, @@ -180,6 +279,8 @@ def call(): on_device=on_device, test_values=False, key=key[0], + a=alpha, + b=beta, shape=shape, dtype=dtype[0], ) @@ -199,7 +300,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.beta", + fn_tree="jax.random.categorical", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -209,19 +310,14 @@ def call(): min_dim_size=2, max_dim_size=2, ), - alpha=st.floats(min_value=0, max_value=5, exclude_min=True), - beta=st.floats(min_value=0, max_value=5, exclude_min=True), shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=5 + min_dim_size=1, max_num_dims=6, max_dim_size=6, min_num_dims=1, allow_none=False ), dtype=helpers.get_dtypes("float", full=False), - test_with_out=st.just(False), ) -def test_jax_beta( +def test_jax_categorical( *, dtype_key, - alpha, - beta, shape, dtype, on_device, @@ -242,17 +338,13 @@ def call(): on_device=on_device, test_values=False, key=key[0], - a=alpha, - b=beta, shape=shape, dtype=dtype[0], ) ret = call() - if not ivy.exists(ret): return - ret_np, ret_from_np = ret ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) @@ -263,7 +355,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.dirichlet", + fn_tree="jax.random.cauchy", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -273,25 +365,12 @@ def call(): min_dim_size=2, max_dim_size=2, ), - dtype_alpha=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", full=False), - shape=st.tuples( - st.integers(min_value=2, max_value=5), - ), - min_value=1.1, - max_value=100.0, - exclude_min=True, - ), - shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=2, max_dim_size=5 - ), + shape=helpers.get_shape(), dtype=helpers.get_dtypes("float", full=False), - test_with_out=st.just(False), ) -def test_jax_dirichlet( +def test_jax_cauchy( *, dtype_key, - dtype_alpha, shape, dtype, on_device, @@ -301,7 +380,6 @@ def test_jax_dirichlet( backend_fw, ): input_dtype, key = dtype_key - _, alpha = dtype_alpha def call(): return helpers.test_frontend_function( @@ -313,7 +391,6 @@ def call(): on_device=on_device, test_values=False, key=key[0], - alpha=alpha[0], shape=shape, dtype=dtype[0], ) @@ -333,7 +410,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.cauchy", + fn_tree="jax.random.dirichlet", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -343,12 +420,25 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape(), + dtype_alpha=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", full=False), + shape=st.tuples( + st.integers(min_value=2, max_value=5), + ), + min_value=1.1, + max_value=100.0, + exclude_min=True, + ), + shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=2, max_dim_size=5 + ), dtype=helpers.get_dtypes("float", full=False), + test_with_out=st.just(False), ) -def test_jax_cauchy( +def test_jax_dirichlet( *, dtype_key, + dtype_alpha, shape, dtype, on_device, @@ -358,6 +448,7 @@ def test_jax_cauchy( backend_fw, ): input_dtype, key = dtype_key + _, alpha = dtype_alpha def call(): return helpers.test_frontend_function( @@ -369,6 +460,7 @@ def call(): on_device=on_device, test_values=False, key=key[0], + 
alpha=alpha[0], shape=shape, dtype=dtype[0], ) @@ -388,7 +480,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.poisson", + fn_tree="jax.random.exponential", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -398,17 +490,12 @@ def call(): min_dim_size=2, max_dim_size=2, ), - lam=st.floats(min_value=0, max_value=5, exclude_min=True), - shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=5 - ), - dtype=helpers.get_dtypes("integer", full=False), - test_with_out=st.just(False), + shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), + dtype=helpers.get_dtypes("float", full=False), ) -def test_jax_poisson( +def test_jax_exponential( *, dtype_key, - lam, shape, dtype, on_device, @@ -422,14 +509,13 @@ def test_jax_poisson( def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], - lam=lam, shape=shape, dtype=dtype[0], ) @@ -440,47 +526,16 @@ def call(): return ret_np, ret_from_np = ret - ret_np = helpers.flatten_and_to_np( - ret=ret_np, - backend=backend_fw, - ) - ret_from_np = helpers.flatten_and_to_np( - ret=ret_from_np, - backend=backend_fw, - ) + ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) + ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) for u, v in zip(ret_np, ret_from_np): assert u.dtype == v.dtype assert u.shape == v.shape -@st.composite -def _all_gamma_params(draw): - shape = draw( - helpers.get_shape( - min_dim_size=1, max_dim_size=5, min_num_dims=2, max_num_dims=2 - ) - | st.just(None) - ) - if shape is None: - a = draw( - helpers.array_values( - min_value=0.0, - max_value=100.0, - dtype=helpers.get_dtypes("float", full=False), - exclude_min=True, - shape=helpers.get_shape( - min_dim_size=1, max_dim_size=5, min_num_dims=1, max_num_dims=2 - ), - ) - ) - return a[0], shape - a = draw(st.floats(min_value=0, max_value=5, exclude_min=True)) - return a, shape - - @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.gamma", + fn_tree="jax.random.fold_in", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -490,15 +545,12 @@ def _all_gamma_params(draw): min_dim_size=2, max_dim_size=2, ), - a_shape=_all_gamma_params(), - dtype=helpers.get_dtypes("float", full=False), - test_with_out=st.just(False), + data=helpers.ints(), ) -def test_jax_gamma( +def test_jax_fold_in( *, dtype_key, - a_shape, - dtype, + data, on_device, fn_tree, frontend, @@ -506,21 +558,18 @@ def test_jax_gamma( backend_fw, ): input_dtype, key = dtype_key - a, shape = a_shape def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - backend_to_test=backend_fw, on_device=on_device, test_values=False, key=key[0], - a=a, - shape=shape, - dtype=dtype[0], + data=data, ) ret = call() @@ -538,7 +587,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.gumbel", + fn_tree="jax.random.gamma", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -548,15 +597,14 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape( - min_dim_size=1, max_num_dims=6, max_dim_size=6, min_num_dims=1, allow_none=False - ), + a_shape=_all_gamma_params(), 
dtype=helpers.get_dtypes("float", full=False), + test_with_out=st.just(False), ) -def test_jax_gumbel( +def test_jax_gamma( *, dtype_key, - shape, + a_shape, dtype, on_device, fn_tree, @@ -565,17 +613,19 @@ def test_jax_gumbel( backend_fw, ): input_dtype, key = dtype_key + a, shape = a_shape def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, + backend_to_test=backend_fw, on_device=on_device, test_values=False, key=key[0], + a=a, shape=shape, dtype=dtype[0], ) @@ -595,7 +645,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.t", + fn_tree="jax.random.generalized_normal", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -605,17 +655,17 @@ def call(): min_dim_size=2, max_dim_size=2, ), - df=st.floats(min_value=0, max_value=5, exclude_min=True), + p=st.floats(min_value=1e-5, max_value=100, exclude_min=True), shape=helpers.get_shape( min_num_dims=1, max_num_dims=6, min_dim_size=1, max_dim_size=6 ), dtype=helpers.get_dtypes("float", full=False), test_with_out=st.just(False), ) -def test_jax_t( +def test_jax_generalized_normal( *, dtype_key, - df, + p, shape, dtype, on_device, @@ -630,13 +680,13 @@ def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - test_flags=test_flags, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], - df=df, + p=p, shape=shape, dtype=dtype[0], ) @@ -656,7 +706,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.generalized_normal", + fn_tree="jax.random.gumbel", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -666,17 +716,14 @@ def call(): min_dim_size=2, max_dim_size=2, ), - p=st.floats(min_value=1e-5, max_value=100, exclude_min=True), shape=helpers.get_shape( - min_num_dims=1, max_num_dims=6, min_dim_size=1, max_dim_size=6 + min_dim_size=1, max_num_dims=6, max_dim_size=6, min_num_dims=1, allow_none=False ), dtype=helpers.get_dtypes("float", full=False), - test_with_out=st.just(False), ) -def test_jax_generalized_normal( +def test_jax_gumbel( *, dtype_key, - p, shape, dtype, on_device, @@ -691,13 +738,12 @@ def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], - p=p, shape=shape, dtype=dtype[0], ) @@ -715,9 +761,10 @@ def call(): assert u.shape == v.shape +# loggamma @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.rademacher", + fn_tree="jax.random.loggamma", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -727,13 +774,14 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), - dtype=helpers.get_dtypes("integer", full=False), + a_shape=_all_gamma_params(), + dtype=helpers.get_dtypes("float", full=False), + test_with_out=st.just(False), ) -def test_jax_rademacher( +def test_jax_loggamma( *, dtype_key, - shape, + a_shape, dtype, on_device, fn_tree, @@ -742,6 +790,7 @@ def test_jax_rademacher( backend_fw, ): input_dtype, key = dtype_key + a, shape = a_shape def call(): return helpers.test_frontend_function( @@ -753,6 +802,7 @@ def call(): on_device=on_device, test_values=False, key=key[0], + a=a, shape=shape, 
dtype=dtype[0], ) @@ -772,7 +822,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.randint", + fn_tree="jax.random.maxwell", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -782,38 +832,33 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), - dtype=helpers.get_dtypes("integer", full=False), - min_max=helpers.general_helpers.get_bounds(dtype="int16"), + shape=helpers.get_shape(), + dtype=helpers.get_dtypes("float", full=False), ) -def test_jax_randint( +def test_jax_maxwell( *, dtype_key, shape, dtype, - min_max, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, key = dtype_key - minval, maxval = min_max def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - test_flags=test_flags, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], shape=shape, - minval=minval, - maxval=maxval, dtype=dtype[0], ) @@ -830,24 +875,11 @@ def call(): assert u.shape == v.shape -@st.composite -def dtype_p_shape(draw): - dtype = draw(helpers.array_dtypes(available_dtypes=("float32", "float64"))) - shape = draw(helpers.get_shape(allow_none=False, min_num_dims=1, max_num_dims=3)) - - dtype_and_probs = draw( - helpers.dtype_and_values( - available_dtypes=dtype, min_value=0, max_value=1, shape=shape - ) - ) - return dtype_and_probs, shape - - @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.bernoulli", + fn_tree="jax.random.multivariate_normal", dtype_key=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer", full=False), + available_dtypes=["uint32"], min_value=0, max_value=2000, min_num_dims=1, @@ -855,34 +887,41 @@ def dtype_p_shape(draw): min_dim_size=2, max_dim_size=2, ), - dtype_p_shape_=dtype_p_shape(), + dtype=helpers.get_dtypes("float", full=False), + mean_cov_vector=get_mean_cov_vector(), + method=st.sampled_from(["cholesky", "eigh", "svd"]), + test_with_out=st.just(False), ) -def test_jax_bernoulli( +def test_jax_multivariate_normal( *, dtype_key, - dtype_p_shape_, - on_device, - fn_tree, + mean_cov_vector, + dtype, + method, frontend, - test_flags, backend_fw, + test_flags, + fn_tree, ): input_dtype, key = dtype_key - dtype_p, shape = dtype_p_shape_ - dtype, p = dtype_p + shared_dtype, mean, cov, shape = mean_cov_vector + + spd = np.matmul(cov.T, cov) + np.identity(cov.shape[0]) def call(): - return helpers.test_frontend_function( - input_dtypes=input_dtype + dtype, + helpers.test_frontend_function( + input_dtypes=input_dtype + [shared_dtype], frontend=frontend, test_flags=test_flags, - fn_tree=fn_tree, backend_to_test=backend_fw, - on_device=on_device, + fn_tree=fn_tree, test_values=False, key=key[0], - p=p[0], + mean=mean, + cov=spd, shape=shape, + dtype=dtype[0], + method=method, ) ret = call() @@ -900,7 +939,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.fold_in", + fn_tree="jax.random.normal", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -910,12 +949,14 @@ def call(): min_dim_size=2, max_dim_size=2, ), - data=helpers.ints(), + shape=helpers.get_shape(), + dtype=helpers.get_dtypes("float", full=False), ) -def test_jax_fold_in( +def test_jax_normal( *, dtype_key, - data, + shape, + dtype, on_device, fn_tree, frontend, @@ -928,13 +969,14 @@ def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, 
frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], - data=data, + shape=shape, + dtype=dtype[0], ) ret = call() @@ -952,24 +994,26 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.permutation", + fn_tree="jax.random.orthogonal", dtype_key=helpers.dtype_and_values( - available_dtypes=["uint32"], + available_dtypes=["float32", "float64"], min_value=0, max_value=2000, min_num_dims=1, - max_num_dims=1, + max_num_dims=3, min_dim_size=2, - max_dim_size=2, + max_dim_size=5, ), - x=st.integers(min_value=0, max_value=10), - axis=st.integers(min_value=0, max_value=0), + n=helpers.get_shape(), + shape=helpers.get_shape(), + dtype=helpers.get_dtypes("float", full=False), ) -def test_jax_permutation( +def test_jax_orthogonal( *, dtype_key, - x, - axis, + n, + shape, + dtype, on_device, fn_tree, frontend, @@ -988,8 +1032,9 @@ def call(): on_device=on_device, test_values=False, key=key[0], - x=x, - axis=axis, + n=n, + shape=shape, + dtype=dtype[0], ) ret = call() @@ -1003,12 +1048,13 @@ def call(): for u, v in zip(ret_np, ret_from_np): assert u.dtype == v.dtype assert u.shape == v.shape + # Check if the output matrices are orthogonal + assert ivy.allclose(ivy.eye(n), ivy.matmul(u.T, u)) -# loggamma @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.loggamma", + fn_tree="jax.random.pareto", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1018,23 +1064,22 @@ def call(): min_dim_size=2, max_dim_size=2, ), - a_shape=_all_gamma_params(), + b_shapes=get_shape_and_arrays(), dtype=helpers.get_dtypes("float", full=False), - test_with_out=st.just(False), ) -def test_jax_loggamma( +def test_jax_pareto( *, dtype_key, - a_shape, + b_shapes, dtype, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, key = dtype_key - a, shape = a_shape + b, shape = b_shapes def call(): return helpers.test_frontend_function( @@ -1046,7 +1091,7 @@ def call(): on_device=on_device, test_values=False, key=key[0], - a=a, + b=b, shape=shape, dtype=dtype[0], ) @@ -1055,18 +1100,16 @@ def call(): if not ivy.exists(ret): return - ret_np, ret_from_np = ret - ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) - ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) - for u, v in zip(ret_np, ret_from_np): - assert u.dtype == v.dtype - assert u.shape == v.shape + if shape is not None: + assert ret_np.shape == shape + else: + assert ret_np.shape == b.shape @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.shuffle", + fn_tree="jax.random.permutation", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1076,37 +1119,33 @@ def call(): min_dim_size=2, max_dim_size=2, ), - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=("float32", "float64"), - valid_axis=True, - max_axes_size=1, - force_int_axis=True, - ), + x=st.integers(min_value=0, max_value=10), + axis=st.integers(min_value=0, max_value=0), ) -def test_jax_shuffle( +def test_jax_permutation( *, dtype_key, - dtype_x_axis, + x, + axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - key_dtype, key = dtype_key - x_dtypes, x, axis = dtype_x_axis + input_dtype, key = dtype_key def call(): return helpers.test_frontend_function( - input_dtypes=key_dtype + x_dtypes, + input_dtypes=input_dtype, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + 
backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], - x=x[0], + x=x, axis=axis, ) @@ -1125,7 +1164,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.exponential", + fn_tree="jax.random.poisson", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1135,12 +1174,17 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), - dtype=helpers.get_dtypes("float", full=False), + lam=st.floats(min_value=0, max_value=5, exclude_min=True), + shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=5 + ), + dtype=helpers.get_dtypes("integer", full=False), + test_with_out=st.just(False), ) -def test_jax_exponential( +def test_jax_poisson( *, dtype_key, + lam, shape, dtype, on_device, @@ -1154,13 +1198,14 @@ def test_jax_exponential( def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], + lam=lam, shape=shape, dtype=dtype[0], ) @@ -1171,8 +1216,14 @@ def call(): return ret_np, ret_from_np = ret - ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) - ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) + ret_np = helpers.flatten_and_to_np( + ret=ret_np, + backend=backend_fw, + ) + ret_from_np = helpers.flatten_and_to_np( + ret=ret_from_np, + backend=backend_fw, + ) for u, v in zip(ret_np, ret_from_np): assert u.dtype == v.dtype assert u.shape == v.shape @@ -1180,7 +1231,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.weibull_min", + fn_tree="jax.random.rademacher", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1191,20 +1242,18 @@ def call(): max_dim_size=2, ), shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), - dtype=helpers.get_dtypes("float", full=False), + dtype=helpers.get_dtypes("integer", full=False), ) -def test_jax_weibull_min( +def test_jax_rademacher( *, dtype_key, shape, - scale, - concentration, dtype, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): input_dtype, key = dtype_key @@ -1212,15 +1261,13 @@ def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - backend_to_test=backend_fw, on_device=on_device, test_values=False, key=key[0], shape=shape, - scale=scale, - concentration=concentration, dtype=dtype[0], ) @@ -1237,19 +1284,9 @@ def call(): assert u.shape == v.shape -@st.composite -def get_shape_and_arrays(draw): - b_shapes = draw( - helpers.array_and_broadcastable_shape(dtype=helpers.get_dtypes("float")) - ) - b, shapes = b_shapes - shapes = draw(st.sampled_from([None, shapes])) - return b, shapes - - @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.pareto", + fn_tree="jax.random.randint", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1259,35 +1296,38 @@ def get_shape_and_arrays(draw): min_dim_size=2, max_dim_size=2, ), - b_shapes=get_shape_and_arrays(), - dtype=helpers.get_dtypes("float", full=False), + shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), + dtype=helpers.get_dtypes("integer", full=False), + min_max=helpers.general_helpers.get_bounds(dtype="int16"), ) 
-def test_jax_pareto( +def test_jax_randint( *, dtype_key, - b_shapes, + shape, dtype, + min_max, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): input_dtype, key = dtype_key - b, shape = b_shapes + minval, maxval = min_max def call(): return helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, test_values=False, key=key[0], - b=b, shape=shape, + minval=minval, + maxval=maxval, dtype=dtype[0], ) @@ -1295,16 +1335,18 @@ def call(): if not ivy.exists(ret): return + ret_np, ret_from_np = ret - if shape is not None: - assert ret_np.shape == shape - else: - assert ret_np.shape == b.shape + ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) + ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) + for u, v in zip(ret_np, ret_from_np): + assert u.dtype == v.dtype + assert u.shape == v.shape @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.maxwell", + fn_tree="jax.random.shuffle", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1314,25 +1356,29 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape(), - dtype=helpers.get_dtypes("float", full=False), + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=("float32", "float64"), + valid_axis=True, + max_axes_size=1, + force_int_axis=True, + ), ) -def test_jax_maxwell( +def test_jax_shuffle( *, dtype_key, - shape, - dtype, + dtype_x_axis, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): - input_dtype, key = dtype_key + key_dtype, key = dtype_key + x_dtypes, x, axis = dtype_x_axis def call(): return helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=key_dtype + x_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, @@ -1340,8 +1386,8 @@ def call(): on_device=on_device, test_values=False, key=key[0], - shape=shape, - dtype=dtype[0], + x=x[0], + axis=axis, ) ret = call() @@ -1359,7 +1405,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.ball", + fn_tree="jax.random.t", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1369,26 +1415,24 @@ def call(): min_dim_size=2, max_dim_size=2, ), + df=st.floats(min_value=0, max_value=5, exclude_min=True), shape=helpers.get_shape( min_num_dims=1, max_num_dims=6, min_dim_size=1, max_dim_size=6 ), dtype=helpers.get_dtypes("float", full=False), - d=st.integers(min_value=1, max_value=100), - p=st.floats(min_value=1e-5, max_value=100, exclude_min=True), test_with_out=st.just(False), ) -def test_jax_ball( +def test_jax_t( *, dtype_key, - d, - p, + df, shape, dtype, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): input_dtype, key = dtype_key @@ -1402,8 +1446,7 @@ def call(): on_device=on_device, test_values=False, key=key[0], - d=d, - p=p, + df=df, shape=shape, dtype=dtype[0], ) @@ -1421,46 +1464,9 @@ def call(): assert u.shape == v.shape -@st.composite -def get_mean_cov_vector(draw): - input_dtype = draw( - st.shared( - st.sampled_from(draw(helpers.get_dtypes("float"))), - key="shared_dtype", - ) - ) - shared_size = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size") - ) - - # Generate shape for mean vector (..., n) - dtype_mean = draw( - helpers.array_values( - dtype=input_dtype, - shape=tuple([shared_size]), - min_value=2, - max_value=5, - ) - ) - - # 
Generate shape for covariance matrix (..., n, n) - dtype_cov = draw( - helpers.array_values( - dtype=input_dtype, - shape=tuple([shared_size, shared_size]), - min_value=2, - max_value=5, - ).filter(lambda x: np.linalg.cond(x.tolist()) < 1 / sys.float_info.epsilon) - ) - - batch_shape = dtype_cov.shape[:-2] - - return input_dtype, dtype_mean, dtype_cov, batch_shape - - @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.multivariate_normal", + fn_tree="jax.random.uniform", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1470,41 +1476,39 @@ def get_mean_cov_vector(draw): min_dim_size=2, max_dim_size=2, ), + shape=helpers.get_shape(), dtype=helpers.get_dtypes("float", full=False), - mean_cov_vector=get_mean_cov_vector(), - method=st.sampled_from(["cholesky", "eigh", "svd"]), - test_with_out=st.just(False), + dtype_minval_maxval=_get_minval_maxval(), ) -def test_jax_multivariate_normal( +def test_jax_uniform( *, dtype_key, - mean_cov_vector, + shape, dtype, - method, + dtype_minval_maxval, + on_device, + fn_tree, frontend, - backend_fw, test_flags, - fn_tree, + backend_fw, ): input_dtype, key = dtype_key - shared_dtype, mean, cov, shape = mean_cov_vector - - spd = np.matmul(cov.T, cov) + np.identity(cov.shape[0]) + minval, maxval = dtype_minval_maxval def call(): - helpers.test_frontend_function( - input_dtypes=input_dtype + [shared_dtype], + return helpers.test_frontend_function( + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, backend_to_test=backend_fw, fn_tree=fn_tree, + on_device=on_device, test_values=False, key=key[0], - mean=mean, - cov=spd, shape=shape, dtype=dtype[0], - method=method, + minval=minval, + maxval=maxval, ) ret = call() @@ -1522,7 +1526,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.categorical", + fn_tree="jax.random.weibull_min", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0, @@ -1532,21 +1536,21 @@ def call(): min_dim_size=2, max_dim_size=2, ), - shape=helpers.get_shape( - min_dim_size=1, max_num_dims=6, max_dim_size=6, min_num_dims=1, allow_none=False - ), + shape=helpers.get_shape(allow_none=False, min_num_dims=1, min_dim_size=1), dtype=helpers.get_dtypes("float", full=False), ) -def test_jax_categorical( +def test_jax_weibull_min( *, dtype_key, shape, + scale, + concentration, dtype, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, key = dtype_key @@ -1555,18 +1559,22 @@ def call(): input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, + backend_to_test=backend_fw, on_device=on_device, test_values=False, key=key[0], shape=shape, + scale=scale, + concentration=concentration, dtype=dtype[0], ) ret = call() + if not ivy.exists(ret): return + ret_np, ret_from_np = ret ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) diff --git a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py index 1eb8bded0d086..fe825eda7098f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py +++ b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py @@ -12,231 +12,386 @@ import ivy -# dropout2d +# --- Helpers --- # +# 
--------------- # + + +def _calculate_same_padding(kernel_size, stride, shape): + padding = tuple( + [ + max( + 0, + math.ceil(((shape[i] - 1) * stride[i] + kernel_size[i] - shape[i]) / 2), + ) + for i in range(len(kernel_size)) + ] + ) + if all([kernel_size[i] / 2 >= padding[i] for i in range(len(kernel_size))]): + if _is_same_padding(padding, stride, kernel_size, shape): + return padding + return (0, 0) + + +def _is_same_padding(padding, stride, kernel_size, input_shape): + output_shape = tuple( + [ + (input_shape[i] + 2 * padding[i] - kernel_size[i]) // stride[i] + 1 + for i in range(len(padding)) + ] + ) + return all( + [ + output_shape[i] == math.ceil(input_shape[i] / stride[i]) + for i in range(len(padding)) + ] + ) + + +def _scale_factor_strategy(): + return st.one_of( + st.floats(min_value=0.1, max_value=2.0), + st.tuples(st.floats(min_value=0.1, max_value=2.0)), + st.lists(st.floats(min_value=0.1, max_value=2.0), min_size=3, max_size=3), + ) + + +def _size_and_scale_factor_strategy(): + return st.one_of( + st.tuples(_size_strategy(), st.just(None)), + st.tuples(st.just(None), _scale_factor_strategy()), + st.tuples(_size_strategy(), _scale_factor_strategy()), + ) + + +def _size_strategy(): + return st.one_of( + st.integers(min_value=1, max_value=10), + st.tuples(st.integers(min_value=1, max_value=10)), + st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3), + ) + + +@st.composite +def _x_and_filters(draw, dim: int = 2): + if not isinstance(dim, int): + dim = draw(dim) + strides = draw( + st.one_of( + st.lists( + st.integers(min_value=1, max_value=3), + min_size=dim, + max_size=dim, + ), + st.integers(min_value=1, max_value=3), + ) + ) + + pad_mode = draw(st.sampled_from(["valid", "same", "pad"])) + + padding = draw( + st.one_of( + st.integers(min_value=1, max_value=3), + st.lists(st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim), + ) + ) + + batch_size = draw(st.integers(1, 5)) + filter_shape = draw( + helpers.get_shape( + min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5 + ) + ) + dtype = draw(helpers.get_dtypes("float", full=False)) + input_channels = draw(st.integers(1, 3)) + output_channels = draw(st.integers(1, 3)) + group_list = [i for i in range(1, 3)] + + group_list = list(filter(lambda x: (input_channels % x == 0), group_list)) + + fc = draw(st.sampled_from(group_list)) + dilations = draw( + st.one_of( + st.lists( + st.integers(min_value=1, max_value=3), + min_size=dim, + max_size=dim, + ), + st.integers(min_value=1, max_value=3), + ) + ) + full_dilations = [dilations] * dim if isinstance(dilations, int) else dilations + + x_dim = [] + for i in range(dim): + min_x = filter_shape[i] + (filter_shape[i] - 1) * (full_dilations[i] - 1) + x_dim.append(draw(st.integers(min_x, 15))) + x_dim = tuple(x_dim) + + output_channels = output_channels * fc + filter_shape = (output_channels, input_channels // fc) + filter_shape + + x_shape = (batch_size, input_channels) + x_dim + vals = draw( + helpers.array_values( + dtype=dtype[0], + shape=x_shape, + min_value=0.0, + max_value=1.0, + ) + ) + filters = draw( + helpers.array_values( + dtype=dtype[0], + shape=filter_shape, + min_value=0.0, + max_value=1.0, + ) + ) + bias = draw( + helpers.array_values( + dtype=dtype[0], + shape=(output_channels,), + min_value=0.0, + max_value=1.0, + ) + ) + + return dtype, vals, filters, bias, dilations, strides, padding, fc, pad_mode + + +# --- Main --- # +# ------------ # + + +# adaptive_avg_pool2d @pytest.mark.skip("Testing pipeline not yet 
implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.dropout2d", - d_type_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - shared_dtype=True, - min_value=2, - max_value=5, - min_dim_size=4, - shape=( - st.integers(min_value=2, max_value=10), - 4, - st.integers(min_value=12, max_value=64), - st.integers(min_value=12, max_value=64), + fn_tree="mindspore.ops.function.nn_func.adaptive_avg_pool2d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=4, + max_num_dims=4, + min_dim_size=1, + max_value=100, + min_value=-100, + ), + output_size=st.one_of( + st.tuples( + helpers.ints(min_value=1, max_value=5), + helpers.ints(min_value=1, max_value=5), ), + helpers.ints(min_value=1, max_value=5), ), - p=st.floats(min_value=0.0, max_value=1.0), - training=st.booleans(), ) -def test_mindspore_dropout2d( +def test_mindspore_adaptive_avg_pool2d( *, - d_type_and_x, - p, - training, + dtype_and_x, + output_size, + test_flags, + frontend, on_device, fn_tree, - frontend, - test_flags, ): - dtype, x = d_type_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, + on_device=on_device, + fn_tree=fn_tree, + x=x[0], + output_size=output_size, + ) + + +# avg_pool2d +@pytest.mark.skip("Testing pipeline not yet implemented") +@handle_frontend_test( + fn_tree="mindspore.ops.function.nn_func.avg_pool2d", + dtype_x_k_s=helpers.arrays_for_pooling( + min_dims=4, + max_dims=4, + min_side=1, + max_side=4, + ), + pad_mode=st.booleans(), + count_include_pad=st.booleans(), + test_with_out=st.just(False), +) +def test_mindspore_avg_pool2d( + dtype_x_k_s, + count_include_pad, + pad_mode, + *, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtype, x, kernel_size, stride, pad_name = dtype_x_k_s + + if len(stride) == 1: + stride = (stride[0], stride[0]) + + if pad_name == "SAME": + padding = _calculate_same_padding(kernel_size, stride, x[0].shape[2:]) + else: + padding = (0, 0) + + x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], *x[0].shape[1:-1])) + + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, input=x[0], - p=p, - training=training, + kernel_size=kernel_size, + stride=stride, + padding=padding, + pad_mode=pad_mode, + count_include_pad=count_include_pad, + divisor_override=None, ) -# selu +# conv1d @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.selu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - safety_factor_scale="log", - small_abs_safety_factor=20, - ), + fn_tree="mindspore.ops.function.nn_func.Conv1d", + dtype_vals=_x_and_filters(dim=1), ) -def test_mindspore_selu( +def test_mindspore_conv1d( *, - dtype_and_x, + dtype_vals, on_device, fn_tree, frontend, test_flags, + backend_fw, ): - input_dtype, x = dtype_and_x + dtype, vals, weight, bias, dilations, strides, padding, fc, pad_mode = dtype_vals helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=vals, + weight=weight, + bias=bias, + stride=strides, + padding=padding, + dilation=dilations, + groups=fc, + 
pad_mode=pad_mode, ) -# kl_div @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.kl_div", - p=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - shared_dtype=True, - min_value=2, - max_value=5, - min_dim_size=4, - ), - q=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - shared_dtype=True, - min_value=2, - max_value=5, - min_dim_size=4, - ), - reduction=st.sampled_from(["none", "sum", "mean"]), + fn_tree="mindspore.ops.function.nn_func.Conv2d", + dtype_vals=_x_and_filters(dim=2), ) -def test_mindspore_kl_div( +def test_mindspore_conv2d( *, - p, - q, - reduction, + dtype_vals, on_device, fn_tree, frontend, test_flags, + backend_fw, ): + dtype, vals, weight, bias, dilations, strides, padding, fc, pad_mode = dtype_vals helpers.test_frontend_function( - input_dtypes=p[0], + input_dtypes=dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - p=p[1], - q=q[1], - reduction=reduction, + input=vals, + weight=weight, + bias=bias, + stride=strides, + padding=padding, + dilation=dilations, + groups=fc, + pad_mode=pad_mode, ) -# dropout3d @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.dropout3d", - d_type_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - shared_dtype=True, - min_value=2, - max_value=5, - min_dim_size=5, - shape=( - st.integers(min_value=2, max_value=10), - st.integers(min_value=12, max_value=64), - st.integers(min_value=12, max_value=64), - st.integers(min_value=12, max_value=64), - ), - ), - p=st.floats(min_value=0.0, max_value=1.0), - training=st.booleans(), + fn_tree="mindspore.ops.function.nn_func.Conv3d", + dtype_vals=_x_and_filters(dim=3), ) -def test_mindspore_dropout3d( +def test_mindspore_conv3d( *, - d_type_and_x, - p, - training, + dtype_vals, on_device, fn_tree, frontend, test_flags, + backend_fw, ): - dtype, x = d_type_and_x + dtype, vals, weight, bias, dilations, strides, padding, fc, pad_mode = dtype_vals + # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it. 
+ _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations) helpers.test_frontend_function( input_dtypes=dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - p=p, - training=training, - ) - - -def _size_strategy(): - return st.one_of( - st.integers(min_value=1, max_value=10), - st.tuples(st.integers(min_value=1, max_value=10)), - st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3), - ) - - -def _scale_factor_strategy(): - return st.one_of( - st.floats(min_value=0.1, max_value=2.0), - st.tuples(st.floats(min_value=0.1, max_value=2.0)), - st.lists(st.floats(min_value=0.1, max_value=2.0), min_size=3, max_size=3), - ) - - -def _size_and_scale_factor_strategy(): - return st.one_of( - st.tuples(_size_strategy(), st.just(None)), - st.tuples(st.just(None), _scale_factor_strategy()), - st.tuples(_size_strategy(), _scale_factor_strategy()), + input=vals, + weight=weight, + bias=bias, + stride=strides, + padding=padding, + dilation=dilations, + groups=fc, + pad_mode=pad_mode, ) +# dropout2d @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.interpolate", - dtype_and_x=helpers.dtype_and_values( + fn_tree="mindspore.ops.function.nn_func.dropout2d", + d_type_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=1, shared_dtype=True, + min_value=2, + max_value=5, + min_dim_size=4, + shape=( + st.integers(min_value=2, max_value=10), + 4, + st.integers(min_value=12, max_value=64), + st.integers(min_value=12, max_value=64), + ), ), - mode=st.sampled_from( - [ - "nearest", - "linear", - "bilinear", - "bicubic", - "trilinear", - "area", - "nearest-exact", - ] - ), - align_corners=st.booleans(), - recompute_scale_factor=st.booleans(), - size_and_scale_factor=_size_and_scale_factor_strategy(), + p=st.floats(min_value=0.0, max_value=1.0), + training=st.booleans(), ) -def test_mindspore_interpolate( +def test_mindspore_dropout2d( *, - dtype_and_x, - size, - scale_factor, - mode, - align_corners, - recompute_scale_factor, + d_type_and_x, + p, + training, on_device, fn_tree, frontend, test_flags, - size_and_scale_factor, ): - dtype, x = dtype_and_x - size, scale_factor = size_and_scale_factor - + dtype, x = d_type_and_x helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -244,560 +399,413 @@ def test_mindspore_interpolate( fn_tree=fn_tree, on_device=on_device, input=x[0], - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor, + p=p, + training=training, ) -# hardswish -# @handle_frontend_test( -# fn_tree="mindspore.ops.function.nn_func.hardswish", -# dtype_and_x=helpers.dtype_and_values( -# available_dtypes=helpers.get_dtypes("valid"), -# ), -# ) -# def test_mindspore_hardswish( -# *, -# dtype_and_x, -# on_device, -# fn_tree, -# frontend, -# test_flags, -# ): -# input_dtype, x = dtype_and_x -# helpers.test_frontend_function( -# input_dtypes=input_dtype, -# frontend=frontend, -# test_flags=test_flags, -# fn_tree=fn_tree, -# on_device=on_device, -# x=x[0], -# ) - - -# pad +# dropout3d @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="pad", - input=helpers.dtype_and_values( + fn_tree="mindspore.ops.function.nn_func.dropout3d", + d_type_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=1, shared_dtype=True, min_value=2, 
max_value=5, - min_dim_size=4, - ), - pad_width=st.lists( - st.tuples( - st.integers(min_value=0, max_value=5), st.integers(min_value=0, max_value=5) - ) + min_dim_size=5, + shape=( + st.integers(min_value=2, max_value=10), + st.integers(min_value=12, max_value=64), + st.integers(min_value=12, max_value=64), + st.integers(min_value=12, max_value=64), + ), ), - mode=st.sampled_from(["constant", "reflect", "replicate", "circular"]), - constant_values=st.floats(min_value=0.0, max_value=1.0), + p=st.floats(min_value=0.0, max_value=1.0), + training=st.booleans(), ) -def test_mindspore_pad( +def test_mindspore_dropout3d( *, - input, - pad_width, - mode, - constant_values, + d_type_and_x, + p, + training, on_device, fn_tree, frontend, test_flags, ): + dtype, x = d_type_and_x helpers.test_frontend_function( - input_dtypes=input[0], + input_dtypes=dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[1], - pad_width=pad_width, - mode=mode, - constant_values=constant_values, + input=x[0], + p=p, + training=training, ) -# adaptive_avg_pool2d +# FastGelu @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.adaptive_avg_pool2d", + fn_tree="mindspore.ops.function.nn_func.fast_gelu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=4, - max_num_dims=4, - min_dim_size=1, - max_value=100, - min_value=-100, - ), - output_size=st.one_of( - st.tuples( - helpers.ints(min_value=1, max_value=5), - helpers.ints(min_value=1, max_value=5), - ), - helpers.ints(min_value=1, max_value=5), ), ) -def test_mindspore_adaptive_avg_pool2d( - *, +def test_mindspore_fast_gelu( dtype_and_x, - output_size, + *, test_flags, frontend, on_device, fn_tree, ): input_dtype, x = dtype_and_x + helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, - on_device=on_device, fn_tree=fn_tree, + on_device=on_device, x=x[0], - output_size=output_size, + input=x[0], ) -# log_softmax +# flatten @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.log_softmax", - dtype_and_x=helpers.dtype_and_values( + fn_tree="mindspore.ops.function.nn_func.flatten", + dtype_input_axes=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), - safety_factor_scale="log", - small_abs_safety_factor=20, + valid_axis=True, + min_num_dims=1, + min_axes_size=2, + max_axes_size=2, ), ) -def test_mindspore_log_softmax( +def test_mindspore_flatten( *, - dtype_and_x, + dtype_input_axes, on_device, fn_tree, frontend, test_flags, -): - input_dtype, x = dtype_and_x - - -def _is_same_padding(padding, stride, kernel_size, input_shape): - output_shape = tuple( - [ - (input_shape[i] + 2 * padding[i] - kernel_size[i]) // stride[i] + 1 - for i in range(len(padding)) - ] - ) - return all( - [ - output_shape[i] == math.ceil(input_shape[i] / stride[i]) - for i in range(len(padding)) - ] - ) - - -def _calculate_same_padding(kernel_size, stride, shape): - padding = tuple( - [ - max( - 0, - math.ceil(((shape[i] - 1) * stride[i] + kernel_size[i] - shape[i]) / 2), - ) - for i in range(len(kernel_size)) - ] - ) - if all([kernel_size[i] / 2 >= padding[i] for i in range(len(kernel_size))]): - if _is_same_padding(padding, stride, kernel_size, shape): - return padding - return (0, 0) - - -# avg_pool2d -@pytest.mark.skip("Testing pipeline not yet implemented") -@handle_frontend_test( - 
fn_tree="mindspore.ops.function.nn_func.avg_pool2d", - dtype_x_k_s=helpers.arrays_for_pooling( - min_dims=4, - max_dims=4, - min_side=1, - max_side=4, - ), - pad_mode=st.booleans(), - count_include_pad=st.booleans(), - test_with_out=st.just(False), -) -def test_mindspore_avg_pool2d( - dtype_x_k_s, - count_include_pad, - pad_mode, - *, - test_flags, - frontend, backend_fw, - fn_tree, - on_device, ): - input_dtype, x, kernel_size, stride, pad_name = dtype_x_k_s - - if len(stride) == 1: - stride = (stride[0], stride[0]) - - if pad_name == "SAME": - padding = _calculate_same_padding(kernel_size, stride, x[0].shape[2:]) + dtype, input, axes = dtype_input_axes + if isinstance(axes, int): + start_dim = axes + end_dim = -1 else: - padding = (0, 0) - - x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], *x[0].shape[1:-1])) - + start_dim = axes[0] + end_dim = axes[1] helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - kernel_size=kernel_size, - stride=stride, - padding=padding, - pad_mode=pad_mode, - count_include_pad=count_include_pad, - divisor_override=None, + input=input[0], + order="C", + start_dim=start_dim, + end_dim=end_dim, ) -# softshrink @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.softshrink", - dtype_and_input=helpers.dtype_and_values( + fn_tree="mindspore.ops.function.nn_func.interpolate", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + shared_dtype=True, ), - lambd=helpers.floats(min_value=0, max_value=1, exclude_min=True), + mode=st.sampled_from( + [ + "nearest", + "linear", + "bilinear", + "bicubic", + "trilinear", + "area", + "nearest-exact", + ] + ), + align_corners=st.booleans(), + recompute_scale_factor=st.booleans(), + size_and_scale_factor=_size_and_scale_factor_strategy(), ) -def test_mindspore_softshrink( +def test_mindspore_interpolate( *, - dtype_and_input, - lambd, + dtype_and_x, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, on_device, fn_tree, frontend, test_flags, - backend_fw, + size_and_scale_factor, ): - input_dtype, x = dtype_and_input + dtype, x = dtype_and_x + size, scale_factor = size_and_scale_factor + helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - lambd=lambd, + input=x[0], + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, ) -# gumbel_softmax +# kl_div @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.gumbel_softmax", - dtype_and_x=helpers.dtype_and_values( + fn_tree="mindspore.ops.function.nn_func.kl_div", + p=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + shared_dtype=True, + min_value=2, + max_value=5, + min_dim_size=4, + ), + q=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + shared_dtype=True, + min_value=2, + max_value=5, + min_dim_size=4, ), - tau=st.floats(min_value=0), - hard=st.booleans(), - dim=st.integers(), - test_with_out=st.just(False), - test_inplace=st.booleans(), + reduction=st.sampled_from(["none", "sum", 
"mean"]), ) -def test_torch_gumbel_softmax( +def test_mindspore_kl_div( *, - dtype_and_x, - tau, - hard, - dim, + p, + q, + reduction, on_device, fn_tree, frontend, test_flags, - backend_fw, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=p[0], frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - logits=x[0], - tau=tau, - hard=hard, - dim=dim, + p=p[1], + q=q[1], + reduction=reduction, ) -# FastGelu +# log_softmax @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.fast_gelu", + fn_tree="mindspore.ops.function.nn_func.log_softmax", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + safety_factor_scale="log", + small_abs_safety_factor=20, ), ) -def test_mindspore_fast_gelu( - dtype_and_x, +def test_mindspore_log_softmax( *, - test_flags, - frontend, + dtype_and_x, on_device, fn_tree, + frontend, + test_flags, ): input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - input=x[0], - ) - -# flatten +# hardswish +# @handle_frontend_test( +# fn_tree="mindspore.ops.function.nn_func.hardswish", +# dtype_and_x=helpers.dtype_and_values( +# available_dtypes=helpers.get_dtypes("valid"), +# ), +# ) +# def test_mindspore_hardswish( +# *, +# dtype_and_x, +# on_device, +# fn_tree, +# frontend, +# test_flags, +# ): +# input_dtype, x = dtype_and_x +# helpers.test_frontend_function( +# input_dtypes=input_dtype, +# frontend=frontend, +# test_flags=test_flags, +# fn_tree=fn_tree, +# on_device=on_device, +# x=x[0], +# ) + + +# pad @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.flatten", - dtype_input_axes=helpers.dtype_values_axis( + fn_tree="pad", + input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - min_num_dims=1, - min_axes_size=2, - max_axes_size=2, + num_arrays=1, + shared_dtype=True, + min_value=2, + max_value=5, + min_dim_size=4, + ), + pad_width=st.lists( + st.tuples( + st.integers(min_value=0, max_value=5), st.integers(min_value=0, max_value=5) + ) ), + mode=st.sampled_from(["constant", "reflect", "replicate", "circular"]), + constant_values=st.floats(min_value=0.0, max_value=1.0), ) -def test_mindspore_flatten( +def test_mindspore_pad( *, - dtype_input_axes, + input, + pad_width, + mode, + constant_values, on_device, fn_tree, frontend, test_flags, - backend_fw, ): - dtype, input, axes = dtype_input_axes - if isinstance(axes, int): - start_dim = axes - end_dim = -1 - else: - start_dim = axes[0] - end_dim = axes[1] helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, + input_dtypes=input[0], frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - order="C", - start_dim=start_dim, - end_dim=end_dim, - ) - - -@st.composite -def _x_and_filters(draw, dim: int = 2): - if not isinstance(dim, int): - dim = draw(dim) - strides = draw( - st.one_of( - st.lists( - st.integers(min_value=1, max_value=3), - min_size=dim, - max_size=dim, - ), - st.integers(min_value=1, max_value=3), - ) - ) - - pad_mode = draw(st.sampled_from(["valid", "same", "pad"])) - - padding = draw( - st.one_of( - 
st.integers(min_value=1, max_value=3), - st.lists(st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim), - ) - ) - - batch_size = draw(st.integers(1, 5)) - filter_shape = draw( - helpers.get_shape( - min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5 - ) - ) - dtype = draw(helpers.get_dtypes("float", full=False)) - input_channels = draw(st.integers(1, 3)) - output_channels = draw(st.integers(1, 3)) - group_list = [i for i in range(1, 3)] - - group_list = list(filter(lambda x: (input_channels % x == 0), group_list)) - - fc = draw(st.sampled_from(group_list)) - dilations = draw( - st.one_of( - st.lists( - st.integers(min_value=1, max_value=3), - min_size=dim, - max_size=dim, - ), - st.integers(min_value=1, max_value=3), - ) - ) - full_dilations = [dilations] * dim if isinstance(dilations, int) else dilations - - x_dim = [] - for i in range(dim): - min_x = filter_shape[i] + (filter_shape[i] - 1) * (full_dilations[i] - 1) - x_dim.append(draw(st.integers(min_x, 15))) - x_dim = tuple(x_dim) - - output_channels = output_channels * fc - filter_shape = (output_channels, input_channels // fc) + filter_shape - - x_shape = (batch_size, input_channels) + x_dim - vals = draw( - helpers.array_values( - dtype=dtype[0], - shape=x_shape, - min_value=0.0, - max_value=1.0, - ) - ) - filters = draw( - helpers.array_values( - dtype=dtype[0], - shape=filter_shape, - min_value=0.0, - max_value=1.0, - ) - ) - bias = draw( - helpers.array_values( - dtype=dtype[0], - shape=(output_channels,), - min_value=0.0, - max_value=1.0, - ) + input=input[1], + pad_width=pad_width, + mode=mode, + constant_values=constant_values, ) - return dtype, vals, filters, bias, dilations, strides, padding, fc, pad_mode - -# conv1d +# selu @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.Conv1d", - dtype_vals=_x_and_filters(dim=1), + fn_tree="mindspore.ops.function.nn_func.selu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + safety_factor_scale="log", + small_abs_safety_factor=20, + ), ) -def test_mindspore_conv1d( +def test_mindspore_selu( *, - dtype_vals, + dtype_and_x, on_device, fn_tree, frontend, test_flags, - backend_fw, ): - dtype, vals, weight, bias, dilations, strides, padding, fc, pad_mode = dtype_vals + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=vals, - weight=weight, - bias=bias, - stride=strides, - padding=padding, - dilation=dilations, - groups=fc, - pad_mode=pad_mode, + x=x[0], ) +# softshrink @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.Conv2d", - dtype_vals=_x_and_filters(dim=2), + fn_tree="mindspore.ops.function.nn_func.softshrink", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + lambd=helpers.floats(min_value=0, max_value=1, exclude_min=True), ) -def test_mindspore_conv2d( +def test_mindspore_softshrink( *, - dtype_vals, + dtype_and_input, + lambd, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, vals, weight, bias, dilations, strides, padding, fc, pad_mode = dtype_vals + input_dtype, x = dtype_and_input helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, 
test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=vals, - weight=weight, - bias=bias, - stride=strides, - padding=padding, - dilation=dilations, - groups=fc, - pad_mode=pad_mode, + x=x[0], + lambd=lambd, ) +# gumbel_softmax @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( - fn_tree="mindspore.ops.function.nn_func.Conv3d", - dtype_vals=_x_and_filters(dim=3), + fn_tree="mindspore.ops.function.nn_func.gumbel_softmax", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + tau=st.floats(min_value=0), + hard=st.booleans(), + dim=st.integers(), + test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_mindspore_conv3d( +def test_torch_gumbel_softmax( *, - dtype_vals, + dtype_and_x, + tau, + hard, + dim, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, vals, weight, bias, dilations, strides, padding, fc, pad_mode = dtype_vals - # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it. - _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations) + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=vals, - weight=weight, - bias=bias, - stride=strides, - padding=padding, - dilation=dilations, - groups=fc, - pad_mode=pad_mode, + test_values=False, + logits=x[0], + tau=tau, + hard=hard, + dim=dim, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py b/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py index 3a0d63f2706a9..8f5fbbf38c50e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py @@ -13,6 +13,10 @@ import ivy_tests.test_ivy.helpers.globals as test_globals +# --- Helpers --- # +# --------------- # + + @st.composite def _array_and_axes_permute_helper( draw, @@ -73,13 +77,97 @@ def _array_and_axes_permute_helper( return array, dtype, axes +def _flatten_frontend_return(*, ret, backend): + """Flattening the returned frontend value to a list of numpy arrays.""" + with BackendHandler.update_backend(backend) as ivy_backend: + if not isinstance(ret, tuple): + if not ivy_backend.is_ivy_array(ret): + ret_np_flat = helpers.flatten_frontend_to_np(backend=backend, ret=ret) + else: + ret_np_flat = _flatten_fw_return(ret=ret, backend=backend) + else: + if any([not ivy_backend.is_ivy_array(x) for x in ret]): + ret_np_flat = helpers.flatten_frontend_to_np(backend=backend, ret=ret) + else: + ret_np_flat = _flatten_fw_return(ret=ret, backend=backend) + return ret_np_flat + + +def _flatten_fw_return(ret, backend): + with BackendHandler.update_backend(backend) as ivy_backend: + if not isinstance(ret, tuple): + ret = (ret,) + ret_idxs = ivy_backend.nested_argwhere( + ret, lambda x: ivy_backend.is_ivy_array(x) or ivy_backend.is_native_array(x) + ) + if len(ret_idxs) == 0: + ret_idxs = ivy_backend.nested_argwhere(ret, ivy_backend.isscalar) + ret_flat = ivy_backend.multi_index_nest(ret, ret_idxs) + ret_flat = [ + ivy_backend.asarray( + x, dtype=ivy_backend.Dtype(str(np.asarray(x).dtype)) + ) + for x in ret_flat + ] + else: + ret_flat = ivy_backend.multi_index_nest(ret, ret_idxs) + + # convert the return to NumPy + ret_np_flat = [ivy_backend.to_numpy(x) for x in ret_flat] + return ret_np_flat + + @st.composite -def where(draw, *, shape=None): - if shape is None: - _, values = 
draw(helpers.dtype_and_values(dtype=["bool"])) +def _get_dtype_input_and_vectors(draw): + dim_size = draw(helpers.ints(min_value=1, max_value=5)) + dtype = draw(helpers.get_dtypes("float", index=1, full=False)) + if dim_size == 1: + vec1 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 + ) + ) + vec2 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 + ) + ) else: - _, values = draw(helpers.dtype_and_values(dtype=["bool"], shape=shape)) - return draw(st.just(values) | st.just(True)) + vec1 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 + ) + ) + vec2 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 + ) + ) + return dtype, vec1, vec2 + + +# Casting helper +@st.composite +def _get_safe_casting_dtype(draw, *, dtypes): + target_dtype = dtypes[0] + for dtype in dtypes[1:]: + if np_frontend.can_cast(target_dtype, dtype, casting="safe"): + target_dtype = dtype + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + if ivy_backend.is_float_dtype(target_dtype): + dtype = draw(st.sampled_from(["float64", None])) + elif ivy_backend.is_uint_dtype(target_dtype): + dtype = draw(st.sampled_from(["uint64", None])) + elif ivy_backend.is_int_dtype(target_dtype): + dtype = draw(st.sampled_from(["int64", None])) + elif ivy_backend.is_complex_dtype(target_dtype): + dtype = draw(st.sampled_from(["complex128", None])) + else: + dtype = draw(st.sampled_from(["bool", None])) + # filter uint64 as not supported by torch backend + if dtype == "uint64": + dtype = None + return dtype # noinspection PyShadowingNames @@ -153,94 +241,6 @@ def _test_frontend_function_ignoring_uninitialized(*args, **kwargs): ) -def _flatten_fw_return(ret, backend): - with BackendHandler.update_backend(backend) as ivy_backend: - if not isinstance(ret, tuple): - ret = (ret,) - ret_idxs = ivy_backend.nested_argwhere( - ret, lambda x: ivy_backend.is_ivy_array(x) or ivy_backend.is_native_array(x) - ) - if len(ret_idxs) == 0: - ret_idxs = ivy_backend.nested_argwhere(ret, ivy_backend.isscalar) - ret_flat = ivy_backend.multi_index_nest(ret, ret_idxs) - ret_flat = [ - ivy_backend.asarray( - x, dtype=ivy_backend.Dtype(str(np.asarray(x).dtype)) - ) - for x in ret_flat - ] - else: - ret_flat = ivy_backend.multi_index_nest(ret, ret_idxs) - - # convert the return to NumPy - ret_np_flat = [ivy_backend.to_numpy(x) for x in ret_flat] - return ret_np_flat - - -def _flatten_frontend_return(*, ret, backend): - """Flattening the returned frontend value to a list of numpy arrays.""" - with BackendHandler.update_backend(backend) as ivy_backend: - if not isinstance(ret, tuple): - if not ivy_backend.is_ivy_array(ret): - ret_np_flat = helpers.flatten_frontend_to_np(backend=backend, ret=ret) - else: - ret_np_flat = _flatten_fw_return(ret=ret, backend=backend) - else: - if any([not ivy_backend.is_ivy_array(x) for x in ret]): - ret_np_flat = helpers.flatten_frontend_to_np(backend=backend, ret=ret) - else: - ret_np_flat = _flatten_fw_return(ret=ret, backend=backend) - return ret_np_flat - - -# noinspection PyShadowingNames -def test_frontend_function(*args, where=None, **kwargs): - if not ivy.exists(where): - helpers.test_frontend_function(*args, **kwargs) - else: - kwargs["where"] = where - if "out" in kwargs and kwargs["out"] is None: - _test_frontend_function_ignoring_uninitialized(*args, **kwargs) - return - else: - 
helpers.test_frontend_function(*args, **kwargs) - - -# noinspection PyShadowingNames -def handle_where_and_array_bools(where, input_dtype, test_flags): - if isinstance(where, list) or isinstance(where, tuple): - where = where[0] - test_flags.as_variable += [False] - test_flags.native_arrays += [False] - input_dtype += ["bool"] - return where, input_dtype, test_flags - return where, input_dtype, test_flags - - -# Casting helper -@st.composite -def _get_safe_casting_dtype(draw, *, dtypes): - target_dtype = dtypes[0] - for dtype in dtypes[1:]: - if np_frontend.can_cast(target_dtype, dtype, casting="safe"): - target_dtype = dtype - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - if ivy_backend.is_float_dtype(target_dtype): - dtype = draw(st.sampled_from(["float64", None])) - elif ivy_backend.is_uint_dtype(target_dtype): - dtype = draw(st.sampled_from(["uint64", None])) - elif ivy_backend.is_int_dtype(target_dtype): - dtype = draw(st.sampled_from(["int64", None])) - elif ivy_backend.is_complex_dtype(target_dtype): - dtype = draw(st.sampled_from(["complex128", None])) - else: - dtype = draw(st.sampled_from(["bool", None])) - # filter uint64 as not supported by torch backend - if dtype == "uint64": - dtype = None - return dtype - - @st.composite def dtypes_values_casting_dtype( draw, @@ -298,29 +298,37 @@ def get_num_positional_args_ufunc(draw, *, fn_name=None): @st.composite -def _get_dtype_input_and_vectors(draw): - dim_size = draw(helpers.ints(min_value=1, max_value=5)) - dtype = draw(helpers.get_dtypes("float", index=1, full=False)) - if dim_size == 1: - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) +def where(draw, *, shape=None): + if shape is None: + _, values = draw(helpers.dtype_and_values(dtype=["bool"])) else: - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - return dtype, vec1, vec2 + _, values = draw(helpers.dtype_and_values(dtype=["bool"], shape=shape)) + return draw(st.just(values) | st.just(True)) + + +# --- Main --- # +# ------------ # + + +# noinspection PyShadowingNames +def handle_where_and_array_bools(where, input_dtype, test_flags): + if isinstance(where, list) or isinstance(where, tuple): + where = where[0] + test_flags.as_variable += [False] + test_flags.native_arrays += [False] + input_dtype += ["bool"] + return where, input_dtype, test_flags + return where, input_dtype, test_flags + + +# noinspection PyShadowingNames +def test_frontend_function(*args, where=None, **kwargs): + if not ivy.exists(where): + helpers.test_frontend_function(*args, **kwargs) + else: + kwargs["where"] = where + if "out" in kwargs and kwargs["out"] is None: + _test_frontend_function_ignoring_uninitialized(*args, **kwargs) + return + else: + helpers.test_frontend_function(*args, **kwargs) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_broadcast/test_methods.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_broadcast/test_methods.py index 594fef50f9f26..029e889d2b059 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_broadcast/test_methods.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_broadcast/test_methods.py @@ -9,6 +9,10 @@ from ivy_tests.test_ivy.helpers import 
handle_frontend_test +# --- Helpers --- # +# --------------- # + + @st.composite def _broadcastable_arrays(draw): num_of_array = draw(st.integers(1, 3)) @@ -23,24 +27,44 @@ def _broadcastable_arrays(draw): return xs +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="numpy.add", # dummy fn_tree args=_broadcastable_arrays(), ) -def test_numpy_broadcast_property_shape(args): +def test_numpy_broadcast_method_reset(args): ret = broadcast(*args) ret_gt = np.broadcast(*args) - assert ret.shape == ret_gt.shape + for _ in zip(ret, ret_gt): + pass + ret.reset() + ret_gt.reset() + assert ret.index == ret_gt.index @handle_frontend_test( fn_tree="numpy.add", # dummy fn_tree args=_broadcastable_arrays(), ) -def test_numpy_broadcast_property_size(args): +def test_numpy_broadcast_property_index(args): ret = broadcast(*args) ret_gt = np.broadcast(*args) - assert ret.size == ret_gt.size + assert ret.index == ret_gt.index + for _ in zip(ret, ret_gt): + assert ret.index == ret_gt.index + + +@handle_frontend_test( + fn_tree="numpy.add", # dummy fn_tree + args=_broadcastable_arrays(), +) +def test_numpy_broadcast_property_iters(args): + ret = list(map(list, broadcast(*args).iters)) + ret_gt = np.array(list(map(list, np.broadcast(*args).iters))) + assert ivy.all(ret == ret_gt) @handle_frontend_test( @@ -77,33 +101,17 @@ def test_numpy_broadcast_property_numiter(args): fn_tree="numpy.add", # dummy fn_tree args=_broadcastable_arrays(), ) -def test_numpy_broadcast_property_index(args): +def test_numpy_broadcast_property_shape(args): ret = broadcast(*args) ret_gt = np.broadcast(*args) - assert ret.index == ret_gt.index - for _ in zip(ret, ret_gt): - assert ret.index == ret_gt.index - - -@handle_frontend_test( - fn_tree="numpy.add", # dummy fn_tree - args=_broadcastable_arrays(), -) -def test_numpy_broadcast_property_iters(args): - ret = list(map(list, broadcast(*args).iters)) - ret_gt = np.array(list(map(list, np.broadcast(*args).iters))) - assert ivy.all(ret == ret_gt) + assert ret.shape == ret_gt.shape @handle_frontend_test( fn_tree="numpy.add", # dummy fn_tree args=_broadcastable_arrays(), ) -def test_numpy_broadcast_method_reset(args): +def test_numpy_broadcast_property_size(args): ret = broadcast(*args) ret_gt = np.broadcast(*args) - for _ in zip(ret, ret_gt): - pass - ret.reset() - ret_gt.reset() - assert ret.index == ret_gt.index + assert ret.size == ret_gt.size diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py index e8f7251e40c2f..901bc700a7ee7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py @@ -8,30 +8,49 @@ from ivy_tests.test_ivy.test_functional.test_core.test_linalg import _diag_helper -# tril +# --- Helpers --- # +# --------------- # + + +@st.composite +def _diag_flat_helper(draw): + x_shape = draw( + helpers.get_shape( + min_num_dims=1, max_num_dims=2, min_dim_size=1, max_dim_size=10 + ) + ) + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=x_shape, + small_abs_safety_factor=2, + large_abs_safety_factor=2, + safety_factor_scale="log", + ) + ) + k = draw(helpers.ints(min_value=-5, max_value=5)) + + return dtype_and_x[0], dtype_and_x[1], k + + +# --- Main --- # +# ------------ # + + +# diag @handle_frontend_test( - 
fn_tree="numpy.tril", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - k=helpers.ints(min_value=-10, max_value=10), - test_with_out=st.just(False), + fn_tree="numpy.diag", + dtype_and_x_k=_diag_helper(), ) -def test_numpy_tril( - dtype_and_x, - k, +def test_numpy_diag( + dtype_and_x_k, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + dtype, x, k = dtype_and_x_k helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -39,35 +58,25 @@ def test_numpy_tril( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=x[0], + v=x[0], k=k, ) -# triu +# diagflat @handle_frontend_test( - fn_tree="numpy.triu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - k=helpers.ints(min_value=-10, max_value=10), - test_with_out=st.just(False), + fn_tree="numpy.diagflat", + dtype_and_x_k=_diag_flat_helper(), ) -def test_numpy_triu( - dtype_and_x, - k, +def test_numpy_diagflat( + dtype_and_x_k, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + dtype, x, k = dtype_and_x_k helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -75,7 +84,7 @@ def test_numpy_triu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=x[0], + v=x[0], k=k, ) @@ -114,20 +123,30 @@ def test_numpy_tri( ) -# diag +# tril @handle_frontend_test( - fn_tree="numpy.diag", - dtype_and_x_k=_diag_helper(), + fn_tree="numpy.tril", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + k=helpers.ints(min_value=-10, max_value=10), + test_with_out=st.just(False), ) -def test_numpy_diag( - dtype_and_x_k, +def test_numpy_tril( + dtype_and_x, + k, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, k = dtype_and_x_k + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -135,7 +154,43 @@ def test_numpy_diag( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - v=x[0], + m=x[0], + k=k, + ) + + +# triu +@handle_frontend_test( + fn_tree="numpy.triu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + k=helpers.ints(min_value=-10, max_value=10), + test_with_out=st.just(False), +) +def test_numpy_triu( + dtype_and_x, + k, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + m=x[0], k=k, ) @@ -170,50 +225,3 @@ def test_numpy_vander( N=N, increasing=increasing, ) - - -@st.composite -def _diag_flat_helper(draw): - x_shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=2, min_dim_size=1, max_dim_size=10 - ) - ) - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=x_shape, - small_abs_safety_factor=2, - large_abs_safety_factor=2, - safety_factor_scale="log", - ) - ) - k = draw(helpers.ints(min_value=-5, max_value=5)) - - return 
dtype_and_x[0], dtype_and_x[1], k - - -# diagflat -@handle_frontend_test( - fn_tree="numpy.diagflat", - dtype_and_x_k=_diag_flat_helper(), -) -def test_numpy_diagflat( - dtype_and_x_k, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, x, k = dtype_and_x_k - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - v=x[0], - k=k, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_from_shape_or_value.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_from_shape_or_value.py index 555277d13fe6c..134d0c1a8576b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_from_shape_or_value.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_from_shape_or_value.py @@ -7,6 +7,84 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test, BackendHandler +# --- Helpers --- # +# --------------- # + + +# full and full_like helper +@st.composite +def _input_fill_and_dtype(draw): + dtype = draw(helpers.get_dtypes("float", full=False)) + dtype_and_input = draw(helpers.dtype_and_values(dtype=dtype)) + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + if ivy_backend.is_uint_dtype(dtype[0]): + fill_values = draw(st.integers(min_value=0, max_value=5)) + elif ivy_backend.is_int_dtype(dtype[0]): + fill_values = draw(st.integers(min_value=-5, max_value=5)) + else: + fill_values = draw( + helpers.floats( + min_value=-5, + max_value=5, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", + ) + ) + dtype_to_cast = draw(helpers.get_dtypes("float", full=False)) + return dtype, dtype_and_input[1], fill_values, dtype_to_cast[0] + + +@st.composite +def _shape_and_function( + draw, + *, + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, +): + shape = draw( + helpers.get_shape( + allow_none=allow_none, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + + VARS = "abcdefghijklmnopqrstuvw" + args = "" + out = "" + for i in range(len(shape)): + args += f"{VARS[i]}," + out += f"{VARS[i]}+" + fn_str = f"lambda {args[:-1]}: {out[:-1]}" + + def fn2(*args): + return args[0] + + if len(shape) > 1: + + def fn3(*args): + return args[0] == args[1] + + else: + + def fn3(*args): + return args[0] > 10 + + function = draw(st.sampled_from([eval(fn_str), fn2, fn3])) + + return shape, function + + +# --- Main --- # +# ------------ # + + # empty @handle_frontend_test( fn_tree="numpy.empty", @@ -119,15 +197,21 @@ def test_numpy_eye( ) -# identity @handle_frontend_test( - fn_tree="numpy.identity", - n=helpers.ints(min_value=1, max_value=10), - dtype=helpers.get_dtypes("valid", full=False), + fn_tree="numpy.fromfunction", + shape_and_function=_shape_and_function( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + # not using valid as bool a is problematic dtype + dtype=helpers.get_dtypes("numeric", full=False), test_with_out=st.just(False), ) -def test_numpy_identity( - n, +def test_numpy_fromfunction( + shape_and_function, dtype, frontend, test_flags, @@ -135,6 +219,7 @@ def test_numpy_identity( backend_fw, on_device, ): + shape, function = shape_and_function helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -142,14 
+227,16 @@ def test_numpy_identity( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - n=n, + test_values=False, + function=function, + shape=shape, dtype=dtype[0], ) -# ones +# full @handle_frontend_test( - fn_tree="numpy.ones", + fn_tree="numpy.full", shape=helpers.get_shape( allow_none=False, min_num_dims=1, @@ -157,36 +244,36 @@ def test_numpy_identity( min_dim_size=1, max_dim_size=10, ), - dtype=helpers.get_dtypes("valid", full=False), + input_fill_dtype=_input_fill_and_dtype(), test_with_out=st.just(False), ) -def test_numpy_ones( +def test_numpy_full( shape, - dtype, + input_fill_dtype, frontend, test_flags, fn_tree, backend_fw, on_device, ): + input_dtype, x, fill, dtype_to_cast = input_fill_dtype helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, shape=shape, - dtype=dtype[0], + fill_value=fill, + dtype=dtype_to_cast, ) -# ones_like +# full_like @handle_frontend_test( - fn_tree="numpy.ones_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), + fn_tree="numpy.full_like", + input_fill_dtype=_input_fill_and_dtype(), shape=helpers.get_shape( allow_none=True, min_num_dims=1, @@ -194,20 +281,18 @@ def test_numpy_ones( min_dim_size=1, max_dim_size=10, ), - dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_numpy_ones_like( - dtype_and_x, +def test_numpy_full_like( + input_fill_dtype, shape, - dtype, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x, fill, dtype_to_cast = input_fill_dtype helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -216,16 +301,45 @@ def test_numpy_ones_like( fn_tree=fn_tree, on_device=on_device, a=x[0], - dtype=dtype[0], + fill_value=fill, + dtype=dtype_to_cast, order="K", subok=True, shape=shape, ) -# zeros +# identity @handle_frontend_test( - fn_tree="numpy.zeros", + fn_tree="numpy.identity", + n=helpers.ints(min_value=1, max_value=10), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), +) +def test_numpy_identity( + n, + dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + n=n, + dtype=dtype[0], + ) + + +# ones +@handle_frontend_test( + fn_tree="numpy.ones", shape=helpers.get_shape( allow_none=False, min_num_dims=1, @@ -236,7 +350,7 @@ def test_numpy_ones_like( dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_numpy_zeros( +def test_numpy_ones( shape, dtype, frontend, @@ -257,9 +371,9 @@ def test_numpy_zeros( ) -# zeros_like +# ones_like @handle_frontend_test( - fn_tree="numpy.zeros_like", + fn_tree="numpy.ones_like", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), @@ -273,10 +387,10 @@ def test_numpy_zeros( dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_numpy_zeros_like( +def test_numpy_ones_like( dtype_and_x, - dtype, shape, + dtype, frontend, test_flags, fn_tree, @@ -299,33 +413,9 @@ def test_numpy_zeros_like( ) -# full and full_like helper -@st.composite -def _input_fill_and_dtype(draw): - dtype = draw(helpers.get_dtypes("float", full=False)) - dtype_and_input = 
draw(helpers.dtype_and_values(dtype=dtype)) - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - if ivy_backend.is_uint_dtype(dtype[0]): - fill_values = draw(st.integers(min_value=0, max_value=5)) - elif ivy_backend.is_int_dtype(dtype[0]): - fill_values = draw(st.integers(min_value=-5, max_value=5)) - else: - fill_values = draw( - helpers.floats( - min_value=-5, - max_value=5, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", - ) - ) - dtype_to_cast = draw(helpers.get_dtypes("float", full=False)) - return dtype, dtype_and_input[1], fill_values, dtype_to_cast[0] - - -# full +# zeros @handle_frontend_test( - fn_tree="numpy.full", + fn_tree="numpy.zeros", shape=helpers.get_shape( allow_none=False, min_num_dims=1, @@ -333,36 +423,36 @@ def _input_fill_and_dtype(draw): min_dim_size=1, max_dim_size=10, ), - input_fill_dtype=_input_fill_and_dtype(), + dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_numpy_full( +def test_numpy_zeros( shape, - input_fill_dtype, + dtype, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x, fill, dtype_to_cast = input_fill_dtype helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, shape=shape, - fill_value=fill, - dtype=dtype_to_cast, + dtype=dtype[0], ) -# full_like +# zeros_like @handle_frontend_test( - fn_tree="numpy.full_like", - input_fill_dtype=_input_fill_and_dtype(), + fn_tree="numpy.zeros_like", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), shape=helpers.get_shape( allow_none=True, min_num_dims=1, @@ -370,10 +460,12 @@ def test_numpy_full( min_dim_size=1, max_dim_size=10, ), + dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_numpy_full_like( - input_fill_dtype, +def test_numpy_zeros_like( + dtype_and_x, + dtype, shape, frontend, test_flags, @@ -381,7 +473,7 @@ def test_numpy_full_like( backend_fw, on_device, ): - input_dtype, x, fill, dtype_to_cast = input_fill_dtype + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -390,92 +482,8 @@ def test_numpy_full_like( fn_tree=fn_tree, on_device=on_device, a=x[0], - fill_value=fill, - dtype=dtype_to_cast, + dtype=dtype[0], order="K", subok=True, shape=shape, ) - - -@st.composite -def _shape_and_function( - draw, - *, - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, -): - shape = draw( - helpers.get_shape( - allow_none=allow_none, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - - VARS = "abcdefghijklmnopqrstuvw" - args = "" - out = "" - for i in range(len(shape)): - args += f"{VARS[i]}," - out += f"{VARS[i]}+" - fn_str = f"lambda {args[:-1]}: {out[:-1]}" - - def fn2(*args): - return args[0] - - if len(shape) > 1: - - def fn3(*args): - return args[0] == args[1] - - else: - - def fn3(*args): - return args[0] > 10 - - function = draw(st.sampled_from([eval(fn_str), fn2, fn3])) - - return shape, function - - -@handle_frontend_test( - fn_tree="numpy.fromfunction", - shape_and_function=_shape_and_function( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - # not using valid as bool a is problematic dtype - 
dtype=helpers.get_dtypes("numeric", full=False), - test_with_out=st.just(False), -) -def test_numpy_fromfunction( - shape_and_function, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - shape, function = shape_and_function - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - function=function, - shape=shape, - dtype=dtype[0], - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_numerical_ranges.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_numerical_ranges.py index 131e48e73ee3f..0bea5cb7070c8 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_numerical_ranges.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_numerical_ranges.py @@ -10,6 +10,23 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test, handle_frontend_method +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_dtype_and_range(draw): + dim = draw(helpers.ints(min_value=2, max_value=5)) + dtype = draw(helpers.get_dtypes("float", index=1, full=False)) + start = draw( + helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=-50, max_value=0) + ) + stop = draw( + helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=1, max_value=50) + ) + return dtype * 2, start, stop + + # helpers @st.composite def _get_range_for_grid(draw): @@ -31,17 +48,8 @@ def _get_range_for_grid(draw): return start, stop, None -@st.composite -def _get_dtype_and_range(draw): - dim = draw(helpers.ints(min_value=2, max_value=5)) - dtype = draw(helpers.get_dtypes("float", index=1, full=False)) - start = draw( - helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=-50, max_value=0) - ) - stop = draw( - helpers.array_values(dtype=dtype[0], shape=(dim,), min_value=1, max_value=50) - ) - return dtype * 2, start, stop +# --- Main --- # +# ------------ # # arange @@ -78,6 +86,40 @@ def test_numpy_arange( ) +@handle_frontend_test( + fn_tree="numpy.geomspace", + dtype_start_stop=_get_dtype_and_range(), + num=helpers.ints(min_value=5, max_value=50), + endpoint=st.booleans(), + test_with_out=st.just(False), +) +def test_numpy_geomspace( + dtype_start_stop, + num, + endpoint, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtypes, start, stop = dtype_start_stop + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-1, + start=start, + stop=stop, + num=num, + endpoint=endpoint, + dtype=input_dtypes[0], + ) + + # linspace @handle_frontend_test( fn_tree="numpy.linspace", @@ -283,37 +325,3 @@ def test_numpy_ogrid(range, class_, method_name, backend_fw, frontend): backend=backend_fw, ground_truth_backend=frontend, ) - - -@handle_frontend_test( - fn_tree="numpy.geomspace", - dtype_start_stop=_get_dtype_and_range(), - num=helpers.ints(min_value=5, max_value=50), - endpoint=st.booleans(), - test_with_out=st.just(False), -) -def test_numpy_geomspace( - dtype_start_stop, - num, - endpoint, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtypes, start, stop = dtype_start_stop - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - 
on_device=on_device, - rtol=1e-1, - start=start, - stop=stop, - num=num, - endpoint=endpoint, - dtype=input_dtypes[0], - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_data_type_routines/test_general.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_data_type_routines/test_general.py index eb7f4c8263859..e07a4274ebafa 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_data_type_routines/test_general.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_data_type_routines/test_general.py @@ -41,25 +41,26 @@ def test_numpy_can_cast( ) -# promote_types @handle_frontend_test( - fn_tree="numpy.promote_types", - type1=helpers.get_dtypes("valid", full=False), - type2=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), + fn_tree="numpy.min_scalar_type", + x=st.one_of( + helpers.ints(min_value=-256, max_value=256), + st.booleans(), + helpers.floats(min_value=-256, max_value=256), + ), ) -# there are 100 combinations of dtypes, so run 200 examples to make sure all are tested @settings(max_examples=200) -def test_numpy_promote_types( +def test_numpy_min_scalar_type( *, - type1, - type2, + x, on_device, fn_tree, frontend, test_flags, backend_fw, -): +): # skip torch backend uint + if ivy.current_backend_str() == "torch": + assume(not isinstance(x, int)) ret, frontend_ret = helpers.test_frontend_function( input_dtypes=[], backend_to_test=backend_fw, @@ -67,33 +68,31 @@ def test_numpy_promote_types( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - type1=type1[0], - type2=type2[0], + a=x, test_values=False, ) assert ret._ivy_dtype == frontend_ret[0].name +# promote_types @handle_frontend_test( - fn_tree="numpy.min_scalar_type", - x=st.one_of( - helpers.ints(min_value=-256, max_value=256), - st.booleans(), - helpers.floats(min_value=-256, max_value=256), - ), + fn_tree="numpy.promote_types", + type1=helpers.get_dtypes("valid", full=False), + type2=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), ) +# there are 100 combinations of dtypes, so run 200 examples to make sure all are tested @settings(max_examples=200) -def test_numpy_min_scalar_type( +def test_numpy_promote_types( *, - x, + type1, + type2, on_device, fn_tree, frontend, test_flags, backend_fw, -): # skip torch backend uint - if ivy.current_backend_str() == "torch": - assume(not isinstance(x, int)) +): ret, frontend_ret = helpers.test_frontend_function( input_dtypes=[], backend_to_test=backend_fw, @@ -101,7 +100,8 @@ def test_numpy_min_scalar_type( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x, + type1=type1[0], + type2=type2[0], test_values=False, ) assert ret._ivy_dtype == frontend_ret[0].name diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py index 1e4557bf521d3..cf81f0b7bdf41 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_fft/test_discrete_fourier_transform.py @@ -9,15 +9,22 @@ x_and_rfftn, ) -# ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py - @handle_frontend_test( - fn_tree="numpy.fft.ifft", - dtype_and_x=x_and_ifft(), + fn_tree="numpy.fft.fft", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float_and_complex"), + shape=(2,), + min_axis=-1, + force_int_axis=True, + ), + norm=st.sampled_from(["backward", 
"ortho", "forward"]), + n=st.integers(min_value=2, max_value=10), ) -def test_numpy_ifft(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device): - input_dtype, x, dim, norm, n = dtype_and_x +def test_numpy_fft( + dtype_input_axis, norm, n, backend_fw, frontend, test_flags, fn_tree, on_device +): + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -26,20 +33,42 @@ def test_numpy_ifft(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_d fn_tree=fn_tree, on_device=on_device, test_values=True, - a=x, + a=x[0], n=n, - axis=dim, + axis=axis, norm=norm, ) @handle_frontend_test( - fn_tree="numpy.fft.ifftshift", + fn_tree="numpy.fft.fftfreq", + n=st.integers(min_value=10, max_value=100), + sample_rate=st.integers(min_value=1, max_value=10), +) +def test_numpy_fftfreq( + n, sample_rate, backend_fw, frontend, test_flags, fn_tree, on_device +): + d = 1 / sample_rate + helpers.test_frontend_function( + input_dtypes=[int], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + n=n, + d=d, + ) + + +@handle_frontend_test( + fn_tree="numpy.fft.fftshift", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), shape=(4,), array_api_dtypes=True ), ) -def test_numpy_ifftshift( +def test_numpy_fftshift( dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device ): input_dtype, arr = dtype_and_x @@ -56,21 +85,15 @@ def test_numpy_ifftshift( ) +# ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py + + @handle_frontend_test( - fn_tree="numpy.fft.fft", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float_and_complex"), - shape=(2,), - min_axis=-1, - force_int_axis=True, - ), - norm=st.sampled_from(["backward", "ortho", "forward"]), - n=st.integers(min_value=2, max_value=10), + fn_tree="numpy.fft.ifft", + dtype_and_x=x_and_ifft(), ) -def test_numpy_fft( - dtype_input_axis, norm, n, backend_fw, frontend, test_flags, fn_tree, on_device -): - input_dtype, x, axis = dtype_input_axis +def test_numpy_ifft(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device): + input_dtype, x, dim, norm, n = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -79,20 +102,41 @@ def test_numpy_fft( fn_tree=fn_tree, on_device=on_device, test_values=True, - a=x[0], + a=x, n=n, - axis=axis, + axis=dim, norm=norm, ) @handle_frontend_test( - fn_tree="numpy.fft.fftshift", + fn_tree="numpy.fft.ifftn", + dtype_and_x=x_and_ifft(), +) +def test_numpy_ifftn(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device): + input_dtype, x, dim, norm, n = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + a=x, + s=None, + axes=None, + norm=norm, + ) + + +@handle_frontend_test( + fn_tree="numpy.fft.ifftshift", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), shape=(4,), array_api_dtypes=True ), ) -def test_numpy_fftshift( +def test_numpy_ifftshift( dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device ): input_dtype, arr = dtype_and_x @@ -110,7 +154,7 @@ def test_numpy_fftshift( @handle_frontend_test( - fn_tree="numpy.fft.rfft", + fn_tree="numpy.fft.ihfft", dtype_input_axis=helpers.dtype_values_axis( 
available_dtypes=helpers.get_dtypes("float_and_complex"), shape=(2,), @@ -120,7 +164,7 @@ def test_numpy_fftshift( norm=st.sampled_from(["backward", "ortho", "forward"]), n=st.integers(min_value=2, max_value=5), ) -def test_numpy_rfft( +def test_numpy_ihfft( dtype_input_axis, norm, n, backend_fw, frontend, test_flags, fn_tree, on_device ): input_dtype, x, axis = dtype_input_axis @@ -140,7 +184,7 @@ def test_numpy_rfft( @handle_frontend_test( - fn_tree="numpy.fft.ihfft", + fn_tree="numpy.fft.rfft", dtype_input_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float_and_complex"), shape=(2,), @@ -150,7 +194,7 @@ def test_numpy_rfft( norm=st.sampled_from(["backward", "ortho", "forward"]), n=st.integers(min_value=2, max_value=5), ) -def test_numpy_ihfft( +def test_numpy_rfft( dtype_input_axis, norm, n, backend_fw, frontend, test_flags, fn_tree, on_device ): input_dtype, x, axis = dtype_input_axis @@ -169,28 +213,6 @@ def test_numpy_ihfft( ) -@handle_frontend_test( - fn_tree="numpy.fft.fftfreq", - n=st.integers(min_value=10, max_value=100), - sample_rate=st.integers(min_value=1, max_value=10), -) -def test_numpy_fftfreq( - n, sample_rate, backend_fw, frontend, test_flags, fn_tree, on_device -): - d = 1 / sample_rate - helpers.test_frontend_function( - input_dtypes=[int], - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=True, - n=n, - d=d, - ) - - @handle_frontend_test( fn_tree="numpy.fft.rfftfreq", n=st.integers(min_value=10, max_value=100), @@ -213,27 +235,6 @@ def test_numpy_rfftfreq( ) -@handle_frontend_test( - fn_tree="numpy.fft.ifftn", - dtype_and_x=x_and_ifft(), -) -def test_numpy_ifftn(dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device): - input_dtype, x, dim, norm, n = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=True, - a=x, - s=None, - axes=None, - norm=norm, - ) - - @handle_frontend_test( fn_tree="numpy.fft.rfftn", dtype_and_x=x_and_rfftn(), diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_func_wrapper.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_func_wrapper.py index fab8a56e01a50..0df77d4cac224 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_func_wrapper.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_func_wrapper.py @@ -16,6 +16,33 @@ import ivy.functional.frontends.numpy as np_frontend +# --- Helpers --- # +# --------------- # + + +@st.composite +def _dtype_helper(draw): + return draw( + st.sampled_from( + [ + draw(st.sampled_from([int, float, bool])), + ivy.as_native_dtype( + draw(helpers.get_dtypes("valid", full=False, prune_function=False))[ + 0 + ] + ), + np_frontend.dtype( + draw(helpers.get_dtypes("valid", full=False, prune_function=False))[ + 0 + ] + ), + draw(st.sampled_from(list(np_frontend.numpy_scalar_to_dtype.keys()))), + draw(st.sampled_from(list(np_frontend.numpy_str_to_type_table.keys()))), + ] + ) + ) + + def _fn(*args, check_default=False, dtype=None): if ( check_default @@ -38,6 +65,59 @@ def _fn(*args, check_default=False, dtype=None): return args[0] +def _zero_dim_to_scalar_checks(x, ret_x): + if len(x.shape) > 0: + assert ivy.all(ivy.array(ret_x) == ivy.array(x)) + else: + assert issubclass(type(ret_x), np_frontend.generic) + assert ret_x.ivy_array == ivy.array(x) + + +@st.composite +def _zero_dim_to_scalar_helper(draw): + dtype = draw( 
+ helpers.get_dtypes("valid", prune_function=False, full=False).filter( + lambda x: "bfloat16" not in x + ) + )[0] + shape = draw(helpers.get_shape()) + return draw( + st.one_of( + helpers.array_values(shape=shape, dtype=dtype), + st.lists(helpers.array_values(shape=shape, dtype=dtype), min_size=1).map( + tuple + ), + ) + ) + + +# --- Main --- # +# ------------ # + + +@given( + dtype=_dtype_helper(), +) +def test_handle_numpy_dtype(dtype, backend_fw): + ivy.set_backend(backend_fw) + ret_dtype = handle_numpy_dtype(_fn)(None, dtype=dtype) + assert isinstance(ret_dtype, ivy.Dtype) + ivy.previous_backend() + + +@given(x=_zero_dim_to_scalar_helper()) +def test_numpy_from_zero_dim_arrays_to_scalar(x, backend_fw): + ivy.set_backend(backend_fw) + ret_x = from_zero_dim_arrays_to_scalar(_fn)(x) + if isinstance(x, tuple): + assert isinstance(ret_x, tuple) + for x_i, ret_x_i in zip(x, ret_x): + _zero_dim_to_scalar_checks(x_i, ret_x_i) + else: + _zero_dim_to_scalar_checks(x, ret_x) + ivy.previous_backend() + + @given( dtype_x_shape=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False), @@ -162,75 +242,3 @@ def test_numpy_to_ivy_arrays_and_back(dtype_x_shape, dtype, backend_fw): assert ivy.default_float_dtype_stack == ivy.default_int_dtype_stack == [] ivy.previous_backend() - - -@st.composite -def _zero_dim_to_scalar_helper(draw): - dtype = draw( - helpers.get_dtypes("valid", prune_function=False, full=False).filter( - lambda x: "bfloat16" not in x - ) - )[0] - shape = draw(helpers.get_shape()) - return draw( - st.one_of( - helpers.array_values(shape=shape, dtype=dtype), - st.lists(helpers.array_values(shape=shape, dtype=dtype), min_size=1).map( - tuple - ), - ) - ) - - -def _zero_dim_to_scalar_checks(x, ret_x): - if len(x.shape) > 0: - assert ivy.all(ivy.array(ret_x) == ivy.array(x)) - else: - assert issubclass(type(ret_x), np_frontend.generic) - assert ret_x.ivy_array == ivy.array(x) - - -@given(x=_zero_dim_to_scalar_helper()) -def test_numpy_from_zero_dim_arrays_to_scalar(x, backend_fw): - ivy.set_backend(backend_fw) - ret_x = from_zero_dim_arrays_to_scalar(_fn)(x) - if isinstance(x, tuple): - assert isinstance(ret_x, tuple) - for x_i, ret_x_i in zip(x, ret_x): - _zero_dim_to_scalar_checks(x_i, ret_x_i) - else: - _zero_dim_to_scalar_checks(x, ret_x) - ivy.previous_backend() - - -@st.composite -def _dtype_helper(draw): - return draw( - st.sampled_from( - [ - draw(st.sampled_from([int, float, bool])), - ivy.as_native_dtype( - draw(helpers.get_dtypes("valid", full=False, prune_function=False))[ - 0 - ] - ), - np_frontend.dtype( - draw(helpers.get_dtypes("valid", full=False, prune_function=False))[ - 0 - ] - ), - draw(st.sampled_from(list(np_frontend.numpy_scalar_to_dtype.keys()))), - draw(st.sampled_from(list(np_frontend.numpy_str_to_type_table.keys()))), - ] - ) - ) - - -@given( - dtype=_dtype_helper(), -) -def test_handle_numpy_dtype(dtype, backend_fw): - ivy.set_backend(backend_fw) - ret_dtype = handle_numpy_dtype(_fn)(None, dtype=dtype) - assert isinstance(ret_dtype, ivy.Dtype) - ivy.previous_backend() diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_generating_index_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_generating_index_arrays.py index ca62aa3a46b1d..612bf8ba82d04 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_generating_index_arrays.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_generating_index_arrays.py @@ -7,35 +7,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -@handle_frontend_test( - fn_tree="numpy.indices", - dimensions=helpers.get_shape(min_num_dims=1), - dtype=helpers.get_dtypes(kind="float", full=False), - sparse=st.booleans(), - test_with_out=st.just(False), -) -def test_numpy_indices( - *, - dimensions, - dtype, - sparse, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - dimensions=dimensions, - dtype=dtype[0], - sparse=sparse, - ) +# --- Helpers --- # +# --------------- # # unravel_index @@ -60,44 +33,46 @@ def max_value_as_shape_prod(draw): @handle_frontend_test( - fn_tree="numpy.unravel_index", - dtype_x_shape=max_value_as_shape_prod(), + fn_tree="numpy.diag_indices", + n=helpers.ints(min_value=1, max_value=10), + ndim=helpers.ints(min_value=2, max_value=10), + dtype=helpers.get_dtypes("valid", full=False), test_with_out=st.just(False), ) -def test_numpy_unravel_index( - *, - dtype_x_shape, +def test_numpy_diag_indices( + n, + ndim, + dtype, test_flags, frontend, backend_fw, fn_tree, on_device, ): - dtype_and_x, shape = dtype_x_shape - input_dtype, x = dtype_and_x[0], dtype_and_x[1] helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - indices=x[0], - shape=shape, + n=n, + ndim=ndim, ) @handle_frontend_test( - fn_tree="numpy.diag_indices", - n=helpers.ints(min_value=1, max_value=10), - ndim=helpers.ints(min_value=2, max_value=10), - dtype=helpers.get_dtypes("valid", full=False), + fn_tree="numpy.indices", + dimensions=helpers.get_shape(min_num_dims=1), + dtype=helpers.get_dtypes(kind="float", full=False), + sparse=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_diag_indices( - n, - ndim, +def test_numpy_indices( + *, + dimensions, dtype, + sparse, test_flags, frontend, backend_fw, @@ -111,8 +86,9 @@ def test_numpy_diag_indices( frontend=frontend, fn_tree=fn_tree, on_device=on_device, - n=n, - ndim=ndim, + dimensions=dimensions, + dtype=dtype[0], + sparse=sparse, ) @@ -145,3 +121,31 @@ def test_numpy_tril_indices( k=k, m=m, ) + + +@handle_frontend_test( + fn_tree="numpy.unravel_index", + dtype_x_shape=max_value_as_shape_prod(), + test_with_out=st.just(False), +) +def test_numpy_unravel_index( + *, + dtype_x_shape, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + dtype_and_x, shape = dtype_x_shape + input_dtype, x = dtype_and_x[0], dtype_and_x[1] + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + indices=x[0], + shape=shape, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_indexing_like_operations.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_indexing_like_operations.py index c2e486be78ac3..bb974d204c426 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_indexing_like_operations.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_indexing_like_operations.py @@ -9,38 +9,43 @@ @handle_frontend_test( - fn_tree="numpy.take_along_axis", - 
dtype_x_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], + fn_tree="numpy.compress", + dtype_arr_ax=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - indices_same_dims=True, + min_dim_size=10, + max_dim_size=100, + valid_axis=True, + force_int_axis=True, + ), + condition=helpers.array_values( + dtype=helpers.get_dtypes("bool"), + shape=helpers.get_shape( + min_num_dims=1, max_num_dims=1, min_dim_size=1, max_dim_size=5 + ), ), - test_with_out=st.just(False), ) -def test_numpy_take_along_axis( - *, - dtype_x_indices_axis, - test_flags, +def test_numpy_compress( + dtype_arr_ax, + condition, frontend, + test_flags, backend_fw, fn_tree, on_device, ): - dtypes, x, indices, axis, _ = dtype_x_indices_axis + dtype, arr, ax = dtype_arr_ax helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=dtype, + frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, - frontend=frontend, fn_tree=fn_tree, on_device=on_device, - arr=x, - indices=indices, - axis=axis, + condition=condition, + a=arr[0], + axis=ax, ) @@ -150,41 +155,36 @@ def test_numpy_put_along_axis( @handle_frontend_test( - fn_tree="numpy.compress", - dtype_arr_ax=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="numpy.take_along_axis", + dtype_x_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], min_num_dims=1, max_num_dims=5, - min_dim_size=10, - max_dim_size=100, - valid_axis=True, - force_int_axis=True, - ), - condition=helpers.array_values( - dtype=helpers.get_dtypes("bool"), - shape=helpers.get_shape( - min_num_dims=1, max_num_dims=1, min_dim_size=1, max_dim_size=5 - ), + min_dim_size=1, + max_dim_size=10, + indices_same_dims=True, ), + test_with_out=st.just(False), ) -def test_numpy_compress( - dtype_arr_ax, - condition, - frontend, +def test_numpy_take_along_axis( + *, + dtype_x_indices_axis, test_flags, + frontend, backend_fw, fn_tree, on_device, ): - dtype, arr, ax = dtype_arr_ax + dtypes, x, indices, axis, _ = dtype_x_indices_axis helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, + input_dtypes=dtypes, backend_to_test=backend_fw, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - condition=condition, - a=arr[0], - axis=ax, + arr=x, + indices=indices, + axis=axis, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py index 63cc4ef1f79d2..186794403dc41 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_indexing_routines/test_inserting_data_into_arrays.py @@ -9,6 +9,37 @@ import ivy.functional.frontends.numpy as np_frontend +# --- Helpers --- # +# --------------- # + + +@st.composite +def _helper_c_(draw): + dim = draw(st.integers(1, 3)) + num_of_elems = draw(st.integers(1, 5)) + elem_shape = draw(helpers.get_shape(min_num_dims=dim, max_num_dims=dim)) + ret = [] + if dim == 1: + start = draw(st.integers(min_value=-100, max_value=100)) + step = draw(st.integers(1, 3)) + stop = start + 1 + (tuple(elem_shape)[0] - 1) * step + elem = slice(start, stop, step) + ret.append(elem) + input_dtypes, x, 
casting, dtype = draw( + np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=elem_shape, + num_arrays=num_of_elems, + shared_dtype=True, + ) + ], + ), + ) + return x + ret + + @st.composite def _helper_r_(draw): elems_in_last_dim = draw(st.integers(min_value=2, max_value=8)) @@ -84,31 +115,20 @@ def _helper_r_(draw): return ret, elems_in_last_dim, dim -@st.composite -def _helper_c_(draw): - dim = draw(st.integers(1, 3)) - num_of_elems = draw(st.integers(1, 5)) - elem_shape = draw(helpers.get_shape(min_num_dims=dim, max_num_dims=dim)) - ret = [] - if dim == 1: - start = draw(st.integers(min_value=-100, max_value=100)) - step = draw(st.integers(1, 3)) - stop = start + 1 + (tuple(elem_shape)[0] - 1) * step - elem = slice(start, stop, step) - ret.append(elem) - input_dtypes, x, casting, dtype = draw( - np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=elem_shape, - num_arrays=num_of_elems, - shared_dtype=True, - ) - ], - ), - ) - return x + ret +# --- Main --- # +# ------------ # + + +@handle_frontend_test(fn_tree="numpy.add", inputs=_helper_c_()) # dummy fn_tree +def test_numpy_c_(inputs, backend_fw): + ret_gt = np.c_.__getitem__(tuple(inputs)) + with BackendHandler.update_backend(backend_fw): + ret = np_frontend.c_.__getitem__(tuple(inputs)) + if isinstance(inputs[0], str) and inputs[0] in ["r", "c"]: + ret = ret._data + else: + ret = ret.ivy_array + assert np.allclose(ret, ret_gt) @handle_frontend_test( @@ -158,15 +178,3 @@ def test_numpy_r_(inputs, backend_fw): else: ret = ret.ivy_array assert np.allclose(ret, ret_gt) - - -@handle_frontend_test(fn_tree="numpy.add", inputs=_helper_c_()) # dummy fn_tree -def test_numpy_c_(inputs, backend_fw): - ret_gt = np.c_.__getitem__(tuple(inputs)) - with BackendHandler.update_backend(backend_fw): - ret = np_frontend.c_.__getitem__(tuple(inputs)) - if isinstance(inputs[0], str) and inputs[0] in ["r", "c"]: - ret = ret._data - else: - ret = ret.ivy_array - assert np.allclose(ret, ret_gt) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_and_vector_products.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_and_vector_products.py index 6b7560078a5b4..015c1bf04c36e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_and_vector_products.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_and_vector_products.py @@ -16,71 +16,8 @@ ) -# outer -@handle_frontend_test( - fn_tree="numpy.outer", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-10, - max_value=10, - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - shared_dtype=True, - ), -) -def test_numpy_outer( - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtypes, xs = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=xs[0], - b=xs[1], - ) - - -# inner -@handle_frontend_test( - fn_tree="numpy.inner", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-10, - max_value=10, - num_arrays=2, - shared_dtype=True, - ), - test_with_out=st.just(False), -) -def test_numpy_inner( - dtype_and_x, - frontend, - 
test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtypes, xs = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=xs[0], - b=xs[1], - ) +# --- Helpers --- # +# --------------- # # cross @@ -186,102 +123,112 @@ def test_numpy_cross( ) -# matmul +# dot @handle_frontend_test( - fn_tree="numpy.matmul", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[_get_first_matrix_and_dtype, _get_second_matrix_and_dtype], - ), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="matmul" - ), + fn_tree="numpy.dot", + dtype_a_b=np_frontend_helpers._get_dtype_input_and_vectors(), ) -def test_numpy_matmul( - dtypes_values_casting, +def test_numpy_dot( + dtype_a_b, frontend, + backend_fw, test_flags, fn_tree, - backend_fw, on_device, ): - dtypes, x, casting, dtype = dtypes_values_casting + dtype, a, b = dtype_a_b helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, + input_dtypes=dtype, frontend=frontend, - test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], - out=None, - casting=casting, - order="K", - dtype=dtype, - # The arguments below are currently unused. - # subok=True, + test_flags=test_flags, + rtol=1e-01, + atol=1e-01, + a=a, + b=b, ) -# matrix_power +# einsum @handle_frontend_test( - fn_tree="numpy.linalg.matrix_power", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=50, - shape=helpers.ints(min_value=2, max_value=8).map(lambda x: tuple([x, x])), + fn_tree="numpy.einsum", + args=st.sampled_from( + [ + ( + "ii", + np.arange(25).reshape(5, 5), + ), + ( + "ii->i", + np.arange(25).reshape(5, 5), + ), + ("ij,j", np.arange(25).reshape(5, 5), np.arange(5)), + ] ), - n=helpers.ints(min_value=1, max_value=8), - test_with_out=st.just(False), + dtype=helpers.get_dtypes("float", full=False), ) -def test_numpy_matrix_power( - dtype_and_x, - n, +def test_numpy_einsum( + *, + args, + dtype, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + kw = {} + i = 0 + for arg in args: + kw[f"x{i}"] = arg + i += 1 + test_flags.num_positional_args = i helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - n=n, + test_flags=test_flags, + **kw, + optimize=False, + order="K", + casting="safe", ) -# tensordot +# inner @handle_frontend_test( - fn_tree="numpy.tensordot", - dtype_values_and_axes=_get_dtype_value1_value2_axis_for_tensordot( - helpers.get_dtypes(kind="numeric") + fn_tree="numpy.inner", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-10, + max_value=10, + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_numpy_tensordot( - dtype_values_and_axes, +def test_numpy_inner( + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, + on_device, ): - dtype, a, b, axes = dtype_values_and_axes + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - a=a, - b=b, - axes=axes, + on_device=on_device, + a=xs[0], + b=xs[1], ) @@ -317,6 +264,77 @@ def test_numpy_kron( ) +# 
matmul +@handle_frontend_test( + fn_tree="numpy.matmul", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[_get_first_matrix_and_dtype, _get_second_matrix_and_dtype], + ), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="matmul" + ), +) +def test_numpy_matmul( + dtypes_values_casting, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtypes, x, casting, dtype = dtypes_values_casting + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x1=x[0], + x2=x[1], + out=None, + casting=casting, + order="K", + dtype=dtype, + # The arguments below are currently unused. + # subok=True, + ) + + +# matrix_power +@handle_frontend_test( + fn_tree="numpy.linalg.matrix_power", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=50, + shape=helpers.ints(min_value=2, max_value=8).map(lambda x: tuple([x, x])), + ), + n=helpers.ints(min_value=1, max_value=8), + test_with_out=st.just(False), +) +def test_numpy_matrix_power( + dtype_and_x, + n, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + n=n, + ) + + # multi_dot @handle_frontend_test( fn_tree="numpy.linalg.multi_dot", @@ -344,77 +362,63 @@ def test_numpy_multi_dot( ) -# dot +# outer @handle_frontend_test( - fn_tree="numpy.dot", - dtype_a_b=np_frontend_helpers._get_dtype_input_and_vectors(), + fn_tree="numpy.outer", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-10, + max_value=10, + num_arrays=2, + min_num_dims=1, + max_num_dims=1, + shared_dtype=True, + ), ) -def test_numpy_dot( - dtype_a_b, +def test_numpy_outer( + dtype_and_x, frontend, - backend_fw, test_flags, fn_tree, + backend_fw, on_device, ): - dtype, a, b = dtype_a_b + input_dtypes, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, + input_dtypes=input_dtypes, backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_flags=test_flags, - rtol=1e-01, - atol=1e-01, - a=a, - b=b, + a=xs[0], + b=xs[1], ) -# einsum +# tensordot @handle_frontend_test( - fn_tree="numpy.einsum", - args=st.sampled_from( - [ - ( - "ii", - np.arange(25).reshape(5, 5), - ), - ( - "ii->i", - np.arange(25).reshape(5, 5), - ), - ("ij,j", np.arange(25).reshape(5, 5), np.arange(5)), - ] + fn_tree="numpy.tensordot", + dtype_values_and_axes=_get_dtype_value1_value2_axis_for_tensordot( + helpers.get_dtypes(kind="numeric") ), - dtype=helpers.get_dtypes("float", full=False), + test_with_out=st.just(False), ) -def test_numpy_einsum( - *, - args, - dtype, +def test_numpy_tensordot( + dtype_values_and_axes, frontend, test_flags, fn_tree, backend_fw, - on_device, ): - kw = {} - i = 0 - for arg in args: - kw[f"x{i}"] = arg - i += 1 - test_flags.num_positional_args = i + dtype, a, b, axes = dtype_values_and_axes helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, test_flags=test_flags, - **kw, - optimize=False, - order="K", - casting="safe", + fn_tree=fn_tree, + a=a, + b=b, + axes=axes, ) diff --git 
a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py index 1571e95b7a9a3..d240cef1dc48b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_matrix_eigenvalues.py @@ -15,36 +15,6 @@ ) -# eigvalsh -@handle_frontend_test( - fn_tree="numpy.linalg.eigvalsh", - x=_get_dtype_and_matrix(symmetric=True), - UPLO=st.sampled_from(["L", "U"]), -) -def test_numpy_eigvalsh( - x, - UPLO, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtypes, xs = x - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - a=xs, - UPLO=UPLO, - ) - - # eig @handle_frontend_test( fn_tree="numpy.linalg.eig", @@ -158,3 +128,33 @@ def test_numpy_eigh( backend=backend_fw, ground_truth_backend=frontend, ) + + +# eigvalsh +@handle_frontend_test( + fn_tree="numpy.linalg.eigvalsh", + x=_get_dtype_and_matrix(symmetric=True), + UPLO=st.sampled_from(["L", "U"]), +) +def test_numpy_eigvalsh( + x, + UPLO, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtypes, xs = x + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + a=xs, + UPLO=UPLO, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_norms_and_other_numbers.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_norms_and_other_numbers.py index 87656505375db..90ee682e46531 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_norms_and_other_numbers.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_norms_and_other_numbers.py @@ -13,6 +13,10 @@ ) +# --- Helpers --- # +# --------------- # + + # norm @st.composite def _norm_helper(draw): @@ -66,24 +70,25 @@ def _vector_norm_example(): return _matrix_norm_example() +# --- Main --- # +# ------------ # + + +# det @handle_frontend_test( - fn_tree="numpy.linalg.norm", - norm_values=_norm_helper(), - keepdims=st.booleans(), + fn_tree="numpy.linalg.det", + dtype_and_x=_get_dtype_and_matrix(), test_with_out=st.just(False), ) -def test_numpy_norm( - norm_values, - keepdims, +def test_numpy_det( + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis, ord, check_stable = norm_values - if check_stable: - assume(matrix_is_stable(x[0], cond_limit=10)) + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -91,10 +96,9 @@ def test_numpy_norm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - ord=ord, - axis=axis, - keepdims=keepdims, + rtol=1e-2, + atol=1e-2, + a=x[0], ) @@ -127,21 +131,24 @@ def test_numpy_matrix_rank( ) -# det @handle_frontend_test( - fn_tree="numpy.linalg.det", - dtype_and_x=_get_dtype_and_matrix(), + fn_tree="numpy.linalg.norm", + norm_values=_norm_helper(), + keepdims=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_det( - dtype_and_x, +def test_numpy_norm( + norm_values, + keepdims, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + dtype, x, axis, ord, check_stable = norm_values + if check_stable: + 
assume(matrix_is_stable(x[0], cond_limit=10)) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -149,9 +156,10 @@ def test_numpy_det( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - a=x[0], + x=x[0], + ord=ord, + axis=axis, + keepdims=keepdims, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_solving_equations_and_inverting_matrices.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_solving_equations_and_inverting_matrices.py index ba7c6f7af175a..2f991e2c6e47b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_solving_equations_and_inverting_matrices.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_solving_equations_and_inverting_matrices.py @@ -8,95 +8,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# solve -@handle_frontend_test( - fn_tree="numpy.linalg.solve", - x=helpers.get_first_solve_matrix(adjoint=True), - y=helpers.get_second_solve_matrix(), - test_with_out=st.just(False), -) -def test_numpy_solve( - x, - y, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype1, x1, _ = x - dtype2, x2 = y - helpers.test_frontend_function( - input_dtypes=[dtype1, dtype2], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x1, - b=x2, - ) - - -# inv -@handle_frontend_test( - fn_tree="numpy.linalg.inv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - small_abs_safety_factor=2, - safety_factor_scale="log", - shape=helpers.ints(min_value=2, max_value=20).map(lambda x: tuple([x, x])), - ).filter(lambda x: np.linalg.cond(x[1][0].tolist()) < 1 / sys.float_info.epsilon), - test_with_out=st.just(False), -) -def test_numpy_inv( - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - ) - - -# pinv -@handle_frontend_test( - fn_tree="numpy.linalg.pinv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=2, - ), - test_with_out=st.just(False), -) -def test_numpy_pinv( - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - ) +# --- Helpers --- # +# --------------- # # tensorinv @@ -143,35 +56,38 @@ def _get_inv_square_matrices(draw): return input_dtype, a, ind +# --- Main --- # +# ------------ # + + +# inv @handle_frontend_test( - fn_tree="numpy.linalg.tensorinv", - params=_get_inv_square_matrices(), + fn_tree="numpy.linalg.inv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + small_abs_safety_factor=2, + safety_factor_scale="log", + shape=helpers.ints(min_value=2, max_value=20).map(lambda x: tuple([x, x])), + ).filter(lambda x: np.linalg.cond(x[1][0].tolist()) < 1 / sys.float_info.epsilon), test_with_out=st.just(False), ) -def test_numpy_tensorinv( - *, - params, +def test_numpy_inv( + dtype_and_x, + frontend, test_flags, - on_device, fn_tree, - frontend, backend_fw, + on_device, ): - dtype, x, ind = params - if backend_fw 
== "paddle": - # Paddle only supports ndim from 0 to 9 - assume(x.ndim <= 9) + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, - rtol=1e-01, - atol=1e-01, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x, - ind=ind, + a=x[0], ) @@ -217,3 +133,95 @@ def test_numpy_lstsq( # ground_truth_backend="numpy", # ) return + + +# pinv +@handle_frontend_test( + fn_tree="numpy.linalg.pinv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=2, + ), + test_with_out=st.just(False), +) +def test_numpy_pinv( + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + ) + + +# solve +@handle_frontend_test( + fn_tree="numpy.linalg.solve", + x=helpers.get_first_solve_matrix(adjoint=True), + y=helpers.get_second_solve_matrix(), + test_with_out=st.just(False), +) +def test_numpy_solve( + x, + y, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype1, x1, _ = x + dtype2, x2 = y + helpers.test_frontend_function( + input_dtypes=[dtype1, dtype2], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x1, + b=x2, + ) + + +@handle_frontend_test( + fn_tree="numpy.linalg.tensorinv", + params=_get_inv_square_matrices(), + test_with_out=st.just(False), +) +def test_numpy_tensorinv( + *, + params, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + dtype, x, ind = params + if backend_fw == "paddle": + # Paddle only supports ndim from 0 to 9 + assume(x.ndim <= 9) + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + rtol=1e-01, + atol=1e-01, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + a=x, + ind=ind, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_array_contents.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_array_contents.py index 7cfbbb7014e55..56447b9024f47 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_array_contents.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_array_contents.py @@ -8,17 +8,19 @@ @handle_frontend_test( - fn_tree="numpy.isneginf", + fn_tree="numpy.allclose", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), - min_value=-np.inf, - max_value=np.inf, + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), + equal_nan=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_isneginf( +def test_numpy_allclose( *, dtype_and_x, + equal_nan, on_device, fn_tree, frontend, @@ -33,22 +35,26 @@ def test_numpy_isneginf( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], + b=x[1], + equal_nan=equal_nan, ) @handle_frontend_test( - fn_tree="numpy.isposinf", + fn_tree="numpy.isclose", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), - min_value=-np.inf, - max_value=np.inf, + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), + equal_nan=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_isposinf( +def 
test_numpy_isclose( *, dtype_and_x, + equal_nan, on_device, fn_tree, frontend, @@ -63,24 +69,24 @@ def test_numpy_isposinf( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], + b=x[1], + equal_nan=equal_nan, ) @handle_frontend_test( - fn_tree="numpy.allclose", + fn_tree="numpy.isneginf", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float_and_integer"), + min_value=-np.inf, + max_value=np.inf, ), - equal_nan=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_allclose( +def test_numpy_isneginf( *, dtype_and_x, - equal_nan, on_device, fn_tree, frontend, @@ -95,26 +101,22 @@ def test_numpy_allclose( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - b=x[1], - equal_nan=equal_nan, + x=x[0], ) @handle_frontend_test( - fn_tree="numpy.isclose", + fn_tree="numpy.isposinf", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float_and_integer"), + min_value=-np.inf, + max_value=np.inf, ), - equal_nan=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_isclose( +def test_numpy_isposinf( *, dtype_and_x, - equal_nan, on_device, fn_tree, frontend, @@ -129,7 +131,5 @@ def test_numpy_isclose( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - b=x[1], - equal_nan=equal_nan, + x=x[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_comparison.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_comparison.py index 4325312d5e597..4ec4993f14786 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_comparison.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_comparison.py @@ -7,6 +7,66 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +@handle_frontend_test( + fn_tree="numpy.array_equal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + ), + equal_nan=st.booleans(), +) +def test_numpy_array_equal( + *, + dtype_and_x, + equal_nan, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a1=x[0], + a2=x[1], + equal_nan=equal_nan, + ) + + +@handle_frontend_test( + fn_tree="numpy.array_equiv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True + ), + test_with_out=st.just(False), +) +def test_numpy_array_equiv( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a1=x[0], + a2=x[1], + ) + + # equal @handle_frontend_test( fn_tree="numpy.equal", @@ -57,37 +117,6 @@ def test_numpy_equal( ) -@handle_frontend_test( - fn_tree="numpy.array_equal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True - ), - equal_nan=st.booleans(), -) -def test_numpy_array_equal( - *, - dtype_and_x, - equal_nan, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - 
dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a1=x[0], - a2=x[1], - equal_nan=equal_nan, - ) - - @handle_frontend_test( fn_tree="numpy.greater", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( @@ -333,32 +362,3 @@ def test_numpy_not_equal( dtype=None, subok=True, ) - - -@handle_frontend_test( - fn_tree="numpy.array_equiv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True - ), - test_with_out=st.just(False), -) -def test_numpy_array_equiv( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a1=x[0], - a2=x[1], - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_logical_operations.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_logical_operations.py index eb50edf203c0d..6f1aaabe01357 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_logical_operations.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_logical_operations.py @@ -55,24 +55,23 @@ def test_numpy_logical_and( ) -# logical_or +# logical_not @handle_frontend_test( - fn_tree="numpy.logical_or", + fn_tree="numpy.logical_not", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=("bool",), - num_arrays=2, ) ], special=True, ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="logical_or" + fn_name="logical_not" ), ) -def test_numpy_logical_or( +def test_numpy_logical_not( dtypes_values_casting, where, on_device, @@ -81,7 +80,7 @@ def test_numpy_logical_or( test_flags, backend_fw, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, x, casting, _ = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -94,8 +93,7 @@ def test_numpy_logical_or( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], out=None, where=where, casting=casting, @@ -105,23 +103,24 @@ def test_numpy_logical_or( ) -# logical_not +# logical_or @handle_frontend_test( - fn_tree="numpy.logical_not", + fn_tree="numpy.logical_or", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=("bool",), + num_arrays=2, ) ], special=True, ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="logical_not" + fn_name="logical_or" ), ) -def test_numpy_logical_not( +def test_numpy_logical_or( dtypes_values_casting, where, on_device, @@ -130,7 +129,7 @@ def test_numpy_logical_not( test_flags, backend_fw, ): - input_dtypes, x, casting, _ = dtypes_values_casting + input_dtypes, x, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -143,7 +142,8 @@ def test_numpy_logical_not( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], 
out=None, where=where, casting=casting, diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_truth_value_testing.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_truth_value_testing.py index a259488e65b52..cf7cfa8231110 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_truth_value_testing.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_logic/test_truth_value_testing.py @@ -99,38 +99,41 @@ def test_numpy_any( @handle_frontend_test( - fn_tree="numpy.isscalar", - element=st.booleans() | st.floats() | st.integers() | st.complex_numbers(), + fn_tree="numpy.iscomplex", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("real_and_complex"), min_num_dims=1 + ), test_with_out=st.just(False), ) -def test_numpy_isscalar( +def test_numpy_iscomplex( *, - element, + dtype_and_x, + frontend, on_device, fn_tree, - frontend, test_flags, backend_fw, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=ivy.all_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - element=element, + x=x[0], ) @handle_frontend_test( - fn_tree="numpy.isfortran", + fn_tree="numpy.iscomplexobj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), min_num_dims=1 + available_dtypes=helpers.get_dtypes("real_and_complex"), ), test_with_out=st.just(False), ) -def test_numpy_isfortran( +def test_numpy_iscomplexobj( *, dtype_and_x, frontend, @@ -139,9 +142,10 @@ def test_numpy_isfortran( test_flags, backend_fw, ): - if ivy.current_backend() != "numpy": - return input_dtype, x = dtype_and_x + if ivy.current_backend_str() == "paddle": + # mostly paddle doesn't support unsigned int + assume(input_dtype[0] not in ["int8", "uint8", "int16"]) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -149,17 +153,18 @@ def test_numpy_isfortran( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], + x=x[0], ) @handle_frontend_test( - fn_tree="numpy.isreal", + fn_tree="numpy.isfortran", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex") + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1 ), + test_with_out=st.just(False), ) -def test_numpy_isreal( +def test_numpy_isfortran( *, dtype_and_x, frontend, @@ -168,6 +173,8 @@ def test_numpy_isreal( test_flags, backend_fw, ): + if ivy.current_backend() != "numpy": + return input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, @@ -176,18 +183,17 @@ def test_numpy_isreal( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + a=x[0], ) @handle_frontend_test( - fn_tree="numpy.isrealobj", + fn_tree="numpy.isreal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("real_and_complex"), min_num_dims=1 + available_dtypes=helpers.get_dtypes("float_and_complex") ), - test_with_out=st.just(False), ) -def test_numpy_isrealobj( +def test_numpy_isreal( *, dtype_and_x, frontend, @@ -209,13 +215,13 @@ def test_numpy_isrealobj( @handle_frontend_test( - fn_tree="numpy.iscomplexobj", + fn_tree="numpy.isrealobj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("real_and_complex"), + available_dtypes=helpers.get_dtypes("real_and_complex"), min_num_dims=1 ), test_with_out=st.just(False), ) -def test_numpy_iscomplexobj( +def test_numpy_isrealobj( *, dtype_and_x, 
frontend, @@ -225,9 +231,6 @@ def test_numpy_iscomplexobj( backend_fw, ): input_dtype, x = dtype_and_x - if ivy.current_backend_str() == "paddle": - # mostly paddle doesn't support unsigned int - assume(input_dtype[0] not in ["int8", "uint8", "int16"]) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -240,28 +243,25 @@ def test_numpy_iscomplexobj( @handle_frontend_test( - fn_tree="numpy.iscomplex", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("real_and_complex"), min_num_dims=1 - ), + fn_tree="numpy.isscalar", + element=st.booleans() | st.floats() | st.integers() | st.complex_numbers(), test_with_out=st.just(False), ) -def test_numpy_iscomplex( +def test_numpy_isscalar( *, - dtype_and_x, - frontend, + element, on_device, fn_tree, + frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=ivy.all_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + element=element, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ma/test_MaskedArray.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ma/test_MaskedArray.py index b17df5f7408f7..9da26eafecd0d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ma/test_MaskedArray.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ma/test_MaskedArray.py @@ -9,6 +9,10 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + @st.composite def _array_mask(draw): dtype = draw(helpers.get_dtypes("valid", prune_function=False, full=False)) @@ -21,6 +25,10 @@ def _array_mask(draw): return dtype[0], x_mask +# --- Main --- # +# ------------ # + + # data @handle_frontend_test( fn_tree="numpy.add", # dummy fn_tree @@ -34,15 +42,12 @@ def test_numpy_data( assert ivy.all(x.data == ivy.array(data[0])) -# mask -@handle_frontend_test( - fn_tree="numpy.add", # dummy fn_tree - args=_array_mask(), -) -def test_numpy_mask(args): - dtype, data = args - x = MaskedArray(data[0], mask=ivy.array(data[1]), dtype=dtype, shrink=False) - assert ivy.all(x.mask == ivy.array(data[1])) +# dtype +@handle_frontend_test(fn_tree="numpy.add", dtype_x_mask=_array_mask()) # dummy fn_tree +def test_numpy_dtype(dtype_x_mask): + dtype, data = dtype_x_mask + x = MaskedArray(data[0], mask=data[1], dtype=dtype) + assert x.dtype == dtype # fill_value @@ -72,97 +77,12 @@ def test_numpy_hardmask(dtype_x_mask, hard): assert x.hardmask == hard -# dtype -@handle_frontend_test(fn_tree="numpy.add", dtype_x_mask=_array_mask()) # dummy fn_tree -def test_numpy_dtype(dtype_x_mask): - dtype, data = dtype_x_mask - x = MaskedArray(data[0], mask=data[1], dtype=dtype) - assert x.dtype == dtype - - -# @st.composite -# def _getitem_helper(draw): -# dtype_x_index = draw( -# helpers.array_indices_axis( -# array_dtypes=helpers.get_dtypes("numeric"), -# indices_dtypes=ivy_torch.valid_int_dtypes, -# indices_same_dims=True, -# ) -# ) -# dtype, x, index = dtype_x_index[:3] -# mask = draw( -# helpers.dtype_and_values( -# available_dtypes=helpers.get_dtypes("bool"), -# shape=x.shape, -# ) -# ) -# return dtype[0], x, mask[1][0], index - - -# # __getitem__ -# @handle_frontend_test( -# fn_tree="numpy.add", # dummy fn_tree -# args=_getitem_helper(), -# ) -# def test_numpy___getitem__( -# args, -# ): -# dtype, x, mask, index = args -# data = MaskedArray(x, mask=mask, dtype=dtype) -# ret = data.__getitem__(index) -# 
data_gt = np.ma.MaskedArray(x, mask=mask, dtype=dtype) -# ret_gt = data_gt.__getitem__(index) -# ret = ivy.to_numpy(ivy.flatten(ret.data)) -# ret_gt = np.array(np.ravel(ret_gt)) -# helpers.value_test( -# ret_np_flat=ret, -# ret_np_from_gt_flat=ret_gt, -# ground_truth_backend="numpy", -# ) - - -# @st.composite -# def _setitem_helper(draw): -# dtype_x_index = draw( -# helpers.array_indices_axis( -# array_dtypes=st.shared(helpers.get_dtypes("numeric"), key="dtype"), -# indices_dtypes=ivy_torch.valid_int_dtypes, -# indices_same_dims=True, -# ) -# ) -# dtype, x, index = dtype_x_index[:3] -# mask = draw( -# helpers.dtype_and_values( -# available_dtypes=helpers.get_dtypes("bool"), -# shape=x.shape, -# ) -# ) -# value = draw( -# helpers.dtype_and_values( -# available_dtypes=st.shared(helpers.get_dtypes("numeric"), key="dtype"), -# shape=index.shape, -# ) -# ) -# return dtype[0], x, mask[1][0], index, value[1][0] - - -# # __setitem__ -# @handle_frontend_test( -# fn_tree="numpy.add", # dummy fn_tree -# args=_setitem_helper(), -# ) -# def test_numpy___setitem__( -# args, -# ): -# dtype, x, mask, index, value = args -# data = MaskedArray(x, mask=mask, dtype=dtype) -# data_gt = np.ma.MaskedArray(x, mask=mask, dtype=dtype) -# data = data.__setitem__(index, value) -# data_gt.__setitem__(index, value) -# ret = ivy.to_numpy(ivy.flatten(data.data)) -# ret_gt = np.array(np.ravel(data_gt)) -# helpers.value_test( -# ret_np_flat=ret, -# ret_np_from_gt_flat=ret_gt, -# ground_truth_backend="numpy", -# ) +# mask +@handle_frontend_test( + fn_tree="numpy.add", # dummy fn_tree + args=_array_mask(), +) +def test_numpy_mask(args): + dtype, data = args + x = MaskedArray(data[0], mask=ivy.array(data[1]), dtype=dtype, shrink=False) + assert ivy.all(x.mask == ivy.array(data[1])) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_adding_and_removing_elements.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_adding_and_removing_elements.py index 81e739efb9c60..04a10cba43996 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_adding_and_removing_elements.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_adding_and_removing_elements.py @@ -8,51 +8,6 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# unique -@handle_frontend_test( - fn_tree="numpy.unique", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - force_int_axis=True, - valid_axis=True, - ), - return_index=st.booleans(), - return_inverse=st.booleans(), - return_counts=st.booleans(), - none_axis=st.booleans(), - test_with_out=st.just(False), -) -def test_numpy_unique( - *, - dtype_x_axis, - return_index, - return_inverse, - return_counts, - none_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtypes, xs, axis = dtype_x_axis - if none_axis: - axis = None - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - array=xs[0], - return_index=return_index, - return_inverse=return_inverse, - return_counts=return_counts, - axis=axis, - ) - - # append @handle_frontend_test( fn_tree="numpy.append", @@ -124,3 +79,48 @@ def test_numpy_trim_zeros( filt=x[0], trim=trim, ) + + +# unique +@handle_frontend_test( + fn_tree="numpy.unique", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + 
force_int_axis=True, + valid_axis=True, + ), + return_index=st.booleans(), + return_inverse=st.booleans(), + return_counts=st.booleans(), + none_axis=st.booleans(), + test_with_out=st.just(False), +) +def test_numpy_unique( + *, + dtype_x_axis, + return_index, + return_inverse, + return_counts, + none_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, xs, axis = dtype_x_axis + if none_axis: + axis = None + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + array=xs[0], + return_index=return_index, + return_inverse=return_inverse, + return_counts=return_counts, + axis=axis, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py index c000ce64cfc6c..822ccfe4b46ab 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py @@ -8,6 +8,10 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test, BackendHandler +# --- Helpers --- # +# --------------- # + + @st.composite def generate_copyto_args(draw): input_dtypes, xs, casting, _ = draw( diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_array_shape.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_array_shape.py index 0938f14c0372c..6b830f3d2cf09 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_array_shape.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_array_shape.py @@ -7,6 +7,37 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +# resize +@st.composite +def dtype_and_resize(draw): + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + ), + ) + ) + new_shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + ), + ) + return dtype, x, new_shape + + @st.composite def dtypes_x_reshape(draw): dtypes, x = draw( @@ -25,56 +56,48 @@ def dtypes_x_reshape(draw): return dtypes, x, shape -# reshape +# asarray_chkfinite @handle_frontend_test( - fn_tree="numpy.reshape", - dtypes_x_shape=dtypes_x_reshape(), - order=st.sampled_from(["C", "F", "A"]), + fn_tree="numpy.asarray_chkfinite", + dtype_and_a=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), ) -def test_numpy_reshape( +def test_numpy_asarray_chkfinite( *, - dtypes_x_shape, - order, + dtype_and_a, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtypes, x, shape = dtypes_x_shape + dtype, a = dtype_and_a helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - newshape=shape, - order=order, + a=a[0], ) +# asfarray @handle_frontend_test( - fn_tree="numpy.broadcast_to", - dtype_x_shape=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("float"), ret_shape=True - ), - factor=helpers.ints(min_value=1, max_value=5), - test_with_out=st.just(False), + fn_tree="numpy.asfarray", + dtype_and_a=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_numpy_broadcast_to( +def test_numpy_asfarray( *, - dtype_x_shape, - factor, + dtype_and_a, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, shape = dtype_x_shape - broadcast_shape = (factor,) + shape + dtype, a = dtype_and_a helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -82,30 +105,30 @@ def test_numpy_broadcast_to( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - array=x[0], - shape=broadcast_shape, + a=a[0], ) @handle_frontend_test( - fn_tree="numpy.ravel", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="numpy.broadcast_to", + dtype_x_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ret_shape=True ), - order=st.sampled_from(["C", "F", "A", "K"]), + factor=helpers.ints(min_value=1, max_value=5), test_with_out=st.just(False), ) -def test_numpy_ravel( +def test_numpy_broadcast_to( *, - dtype_and_x, - order, + dtype_x_shape, + factor, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x = dtype_and_x + dtype, x, shape = dtype_x_shape + broadcast_shape = (factor,) + shape helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -113,8 +136,8 @@ def test_numpy_ravel( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - order=order, + array=x[0], + shape=broadcast_shape, ) @@ -192,47 +215,25 @@ def test_numpy_moveaxis( ) -# resize -@st.composite -def dtype_and_resize(draw): - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - ), - ) - ) - new_shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=2, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - ), - ) - return dtype, x, new_shape - - @handle_frontend_test( - fn_tree="numpy.resize", - dtypes_x_shape=dtype_and_resize(), + fn_tree="numpy.ravel", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + order=st.sampled_from(["C", "F", "A", "K"]), + test_with_out=st.just(False), ) -def test_numpy_resize( +def test_numpy_ravel( *, - dtypes_x_shape, + dtype_and_x, + order, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, new_shape = dtypes_x_shape + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -240,93 +241,96 @@ def test_numpy_resize( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - newshape=new_shape, + a=x[0], + order=order, ) -# asfarray +# require @handle_frontend_test( - fn_tree="numpy.asfarray", + fn_tree="numpy.require", dtype_and_a=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + requirements=st.sampled_from(["C", "F", "A", "O", "W", "E"]), + like=st.just(None), + test_with_out=st.just(False), ) -def test_numpy_asfarray( +def test_numpy_require( *, dtype_and_a, + requirements, + like, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): dtype, a = dtype_and_a helpers.test_frontend_function( input_dtypes=dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, 
test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, a=a[0], + dtype=np.dtype(dtype[0]), + requirements=requirements, + like=like, ) -# asarray_chkfinite +# reshape @handle_frontend_test( - fn_tree="numpy.asarray_chkfinite", - dtype_and_a=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="numpy.reshape", + dtypes_x_shape=dtypes_x_reshape(), + order=st.sampled_from(["C", "F", "A"]), ) -def test_numpy_asarray_chkfinite( +def test_numpy_reshape( *, - dtype_and_a, + dtypes_x_shape, + order, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, a = dtype_and_a + dtypes, x, shape = dtypes_x_shape helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=a[0], + x=x[0], + newshape=shape, + order=order, ) -# require @handle_frontend_test( - fn_tree="numpy.require", - dtype_and_a=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - requirements=st.sampled_from(["C", "F", "A", "O", "W", "E"]), - like=st.just(None), - test_with_out=st.just(False), + fn_tree="numpy.resize", + dtypes_x_shape=dtype_and_resize(), ) -def test_numpy_require( +def test_numpy_resize( *, - dtype_and_a, - requirements, - like, + dtypes_x_shape, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): - dtype, a = dtype_and_a + dtype, x, new_shape = dtypes_x_shape helpers.test_frontend_function( input_dtypes=dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=a[0], - dtype=np.dtype(dtype[0]), - requirements=requirements, - like=like, + x=x[0], + newshape=new_shape, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_number_of_dimensions.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_number_of_dimensions.py index 37b2a2366d6f3..28a1f00d3f087 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_number_of_dimensions.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_changing_number_of_dimensions.py @@ -7,6 +7,10 @@ from ivy_tests.test_ivy.test_functional.test_core.test_dtype import dtypes_shared +# --- Helpers --- # +# --------------- # + + # squeeze @st.composite def _squeeze_helper(draw): @@ -20,25 +24,46 @@ def _squeeze_helper(draw): return draw(st.sampled_from(valid_axes)) +# broadcast_arrays +@st.composite +def broadcastable_arrays(draw, dtypes): + num_arrays = st.shared(helpers.ints(min_value=2, max_value=5), key="num_arrays") + shapes = draw(num_arrays.flatmap(helpers.mutually_broadcastable_shapes)) + dtypes = draw(dtypes) + arrays = [] + for c, (shape, dtype) in enumerate(zip(shapes, dtypes), 1): + x = draw(helpers.array_values(dtype=dtype, shape=shape), label=f"x{c}").tolist() + arrays.append(x) + return arrays + + +# --- Main --- # +# ------------ # + + +# atleast_1d @handle_frontend_test( - fn_tree="numpy.squeeze", + fn_tree="numpy.atleast_1d", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), + num_arrays=helpers.ints(min_value=1, max_value=10), ), - axis=_squeeze_helper(), + test_with_out=st.just(False), ) -def test_numpy_squeeze( +def test_numpy_atleast_1d( *, dtype_and_x, - axis, on_device, fn_tree, frontend, 
test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, arrays = dtype_and_x + arys = {} + for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): + arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) + test_flags.num_positional_args = len(arys) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -46,35 +71,33 @@ def test_numpy_squeeze( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + **arys, ) +# atleast_2d @handle_frontend_test( - fn_tree="numpy.expand_dims", + fn_tree="numpy.atleast_2d", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="value_shape"), - min_size=1, - max_size=1, - force_int=True, + num_arrays=helpers.ints(min_value=1, max_value=10), ), + test_with_out=st.just(False), ) -def test_numpy_expand_dims( +def test_numpy_atleast_2d( *, dtype_and_x, - axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, arrays = dtype_and_x + arys = {} + for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): + arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) + test_flags.num_positional_args = len(arys) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -82,21 +105,20 @@ def test_numpy_expand_dims( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + **arys, ) -# atleast_2d +# atleast_3d @handle_frontend_test( - fn_tree="numpy.atleast_2d", + fn_tree="numpy.atleast_3d", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=helpers.ints(min_value=1, max_value=10), ), test_with_out=st.just(False), ) -def test_numpy_atleast_2d( +def test_numpy_atleast_3d( *, dtype_and_x, on_device, @@ -121,63 +143,61 @@ def test_numpy_atleast_2d( ) -# atleast_3d @handle_frontend_test( - fn_tree="numpy.atleast_3d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=1, max_value=10), - ), + fn_tree="numpy.broadcast_arrays", + arrays=broadcastable_arrays(dtypes_shared("num_arrays")), + input_dtypes=dtypes_shared("num_arrays"), test_with_out=st.just(False), ) -def test_numpy_atleast_3d( +def test_numpy_broadcast_arrays( *, - dtype_and_x, + arrays, + input_dtypes, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, arrays = dtype_and_x - arys = {} - for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): - arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) - test_flags.num_positional_args = len(arys) + args = {} + for i, (array, dtype) in enumerate(zip(arrays, input_dtypes)): + args["x{}".format(i)] = np.asarray(array, dtype=dtype) + test_flags.num_positional_args = len(args) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - **arys, + **args, ) -# atleast_1d @handle_frontend_test( - fn_tree="numpy.atleast_1d", + fn_tree="numpy.expand_dims", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=1, max_value=10), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + axis=helpers.get_axis( + 
shape=st.shared(helpers.get_shape(), key="value_shape"), + min_size=1, + max_size=1, + force_int=True, ), - test_with_out=st.just(False), ) -def test_numpy_atleast_1d( +def test_numpy_expand_dims( *, dtype_and_x, + axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, arrays = dtype_and_x - arys = {} - for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): - arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype) - test_flags.num_positional_args = len(arys) + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -185,49 +205,37 @@ def test_numpy_atleast_1d( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **arys, + a=x[0], + axis=axis, ) -# broadcast_arrays -@st.composite -def broadcastable_arrays(draw, dtypes): - num_arrays = st.shared(helpers.ints(min_value=2, max_value=5), key="num_arrays") - shapes = draw(num_arrays.flatmap(helpers.mutually_broadcastable_shapes)) - dtypes = draw(dtypes) - arrays = [] - for c, (shape, dtype) in enumerate(zip(shapes, dtypes), 1): - x = draw(helpers.array_values(dtype=dtype, shape=shape), label=f"x{c}").tolist() - arrays.append(x) - return arrays - - @handle_frontend_test( - fn_tree="numpy.broadcast_arrays", - arrays=broadcastable_arrays(dtypes_shared("num_arrays")), - input_dtypes=dtypes_shared("num_arrays"), - test_with_out=st.just(False), + fn_tree="numpy.squeeze", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + axis=_squeeze_helper(), ) -def test_numpy_broadcast_arrays( +def test_numpy_squeeze( *, - arrays, - input_dtypes, + dtype_and_x, + axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - args = {} - for i, (array, dtype) in enumerate(zip(arrays, input_dtypes)): - args["x{}".format(i)] = np.asarray(array, dtype=dtype) - test_flags.num_positional_args = len(args) + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **args, + a=x[0], + axis=axis, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_joining_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_joining_arrays.py index 794bcc1b71a49..3102a5f40bc25 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_joining_arrays.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_joining_arrays.py @@ -7,6 +7,10 @@ import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers +# --- Helpers --- # +# --------------- # + + @st.composite def _arrays_idx_n_dtypes(draw): num_arrays = draw( @@ -33,6 +37,10 @@ def _arrays_idx_n_dtypes(draw): return x, input_dtypes, axis, casting, dtype +# --- Main --- # +# ------------ # + + # concat @handle_frontend_test( fn_tree="numpy.concatenate", @@ -64,12 +72,19 @@ def test_numpy_concatenate( ) -# stack +# hstack @handle_frontend_test( - fn_tree="numpy.stack", - dtype_and_x=_arrays_idx_n_dtypes(), + fn_tree="numpy.hstack", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shared_dtype=True, + num_arrays=helpers.ints(min_value=2, max_value=10), + shape=helpers.get_shape( + min_num_dims=1, + ), + ), ) -def test_numpy_stack( +def 
test_numpy_hstack( dtype_and_x, frontend, test_flags, @@ -77,32 +92,24 @@ def test_numpy_stack( backend_fw, on_device, ): - xs, input_dtypes, unique_idx, _, _ = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - arrays=xs, - axis=unique_idx, + tup=xs, ) -# vstack +# stack @handle_frontend_test( - fn_tree="numpy.vstack", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shared_dtype=True, - num_arrays=helpers.ints(min_value=2, max_value=10), - shape=helpers.get_shape( - min_num_dims=1, - ), - ), + fn_tree="numpy.stack", + dtype_and_x=_arrays_idx_n_dtypes(), ) -def test_numpy_vstack( +def test_numpy_stack( dtype_and_x, frontend, test_flags, @@ -110,21 +117,22 @@ def test_numpy_vstack( backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + xs, input_dtypes, unique_idx, _, _ = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tup=xs, + arrays=xs, + axis=unique_idx, ) -# hstack +# vstack @handle_frontend_test( - fn_tree="numpy.hstack", + fn_tree="numpy.vstack", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), shared_dtype=True, @@ -134,7 +142,7 @@ def test_numpy_vstack( ), ), ) -def test_numpy_hstack( +def test_numpy_vstack( dtype_and_x, frontend, test_flags, diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py index cb41299539350..2fa46f5db1cf2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py @@ -6,28 +6,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -def st_tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False): - return st.lists( - elements, - min_size=min_size, - max_size=max_size, - unique_by=unique_by, - unique=unique, - ).map(tuple) - - -def _st_tuples_or_int(n_pairs, min_val=0): - return st.one_of( - st_tuples( - st.tuples( - st.integers(min_value=min_val, max_value=4), - st.integers(min_value=min_val, max_value=4), - ), - min_size=n_pairs, - max_size=n_pairs, - ), - helpers.ints(min_value=min_val, max_value=4), - ) +# --- Helpers --- # +# --------------- # @st.composite @@ -77,6 +57,34 @@ def _pad_helper(draw): return dtype, input[0], pad_width, kwargs, mode +def _st_tuples_or_int(n_pairs, min_val=0): + return st.one_of( + st_tuples( + st.tuples( + st.integers(min_value=min_val, max_value=4), + st.integers(min_value=min_val, max_value=4), + ), + min_size=n_pairs, + max_size=n_pairs, + ), + helpers.ints(min_value=min_val, max_value=4), + ) + + +# --- Main --- # +# ------------ # + + +def st_tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False): + return st.lists( + elements, + min_size=min_size, + max_size=max_size, + unique_by=unique_by, + unique=unique, + ).map(tuple) + + # pad @handle_frontend_test( fn_tree="numpy.pad", diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_rearranging_elements.py 
b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_rearranging_elements.py index 8c507648cb3b7..f1030b3807416 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_rearranging_elements.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_rearranging_elements.py @@ -5,41 +5,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# roll -@handle_frontend_test( - fn_tree="numpy.roll", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - min_dim_size=2, - ), - shift=helpers.ints(min_value=1, max_value=10), - axis=helpers.ints(min_value=-1, max_value=1), - test_with_out=st.just(False), -) -def test_numpy_roll( - *, - dtype_and_x, - shift, - axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - shift=shift, - axis=axis, - ) +# --- Helpers --- # +# --------------- # @st.composite @@ -49,6 +16,62 @@ def _dtype_x_bounded_axis(draw, **kwargs): return dtype, x, axis +@st.composite +def _get_dtype_values_k_axes_for_rot90( + draw, + available_dtypes, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + min_num_dims=1, + max_num_dims=10, + min_dim_size=1, + max_dim_size=10, +): + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + k = draw(helpers.ints(min_value=-4, max_value=4)) + axes = tuple( + draw( + st.lists( + helpers.ints(min_value=-(len(shape) - 1), max_value=len(shape) - 2), + min_size=2, + max_size=2, + unique=True, + ).filter(lambda axes: abs(axes[0] - axes[1]) != len(shape) - 1) + ) + ) + dtype = draw(st.sampled_from(draw(available_dtypes))) + values = draw( + helpers.array_values( + dtype=dtype, + shape=shape, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=72, + small_abs_safety_factor=72, + safety_factor_scale="log", + ) + ) + return [dtype], values, k, axes + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="numpy.flip", dtype_x_axis=_dtype_x_bounded_axis( @@ -145,56 +168,41 @@ def test_numpy_flipud( ) -@st.composite -def _get_dtype_values_k_axes_for_rot90( - draw, - available_dtypes, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - min_num_dims=1, - max_num_dims=10, - min_dim_size=1, - max_dim_size=10, +# roll +@handle_frontend_test( + fn_tree="numpy.roll", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + min_dim_size=2, + ), + shift=helpers.ints(min_value=1, max_value=10), + axis=helpers.ints(min_value=-1, max_value=1), + test_with_out=st.just(False), +) +def test_numpy_roll( + *, + dtype_and_x, + shift, + axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, ): - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - k = draw(helpers.ints(min_value=-4, max_value=4)) - axes = tuple( - draw( - st.lists( - helpers.ints(min_value=-(len(shape) - 
1), max_value=len(shape) - 2), - min_size=2, - max_size=2, - unique=True, - ).filter(lambda axes: abs(axes[0] - axes[1]) != len(shape) - 1) - ) - ) - dtype = draw(st.sampled_from(draw(available_dtypes))) - values = draw( - helpers.array_values( - dtype=dtype, - shape=shape, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=72, - small_abs_safety_factor=72, - safety_factor_scale="log", - ) + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + shift=shift, + axis=axis, ) - return [dtype], values, k, axes # rot90 diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_splitting_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_splitting_arrays.py index 063ec1d40a0aa..15d3d92c8d0c5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_splitting_arrays.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_splitting_arrays.py @@ -10,9 +10,9 @@ ) -# split +# array_split @handle_frontend_test( - fn_tree="numpy.split", + fn_tree="numpy.array_split", dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), @@ -29,7 +29,7 @@ ), test_with_out=st.just(False), ) -def test_numpy_split( +def test_numpy_array_split( *, dtype_value, indices_or_sections, @@ -41,6 +41,7 @@ def test_numpy_split( backend_fw, ): input_dtype, value = dtype_value + assume(isinstance(indices_or_sections, int)) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -54,30 +55,22 @@ def test_numpy_split( ) -# array_split +# dsplit @handle_frontend_test( - fn_tree="numpy.array_split", + fn_tree="numpy.dsplit", dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), ), indices_or_sections=_get_splits( - min_num_dims=1, allow_none=False, is_mod_split=True - ), - axis=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", + min_num_dims=3, axis=2, allow_none=False, is_mod_split=True ), test_with_out=st.just(False), ) -def test_numpy_array_split( +def test_numpy_dsplit( *, dtype_value, indices_or_sections, - axis, on_device, fn_tree, frontend, @@ -85,7 +78,8 @@ def test_numpy_array_split( backend_fw, ): input_dtype, value = dtype_value - assume(isinstance(indices_or_sections, int)) + if isinstance(indices_or_sections, np.ndarray): + assume(indices_or_sections.ndim == 0) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -95,23 +89,22 @@ def test_numpy_array_split( on_device=on_device, ary=value[0], indices_or_sections=indices_or_sections, - axis=axis, ) -# dsplit +# hsplit @handle_frontend_test( - fn_tree="numpy.dsplit", + fn_tree="numpy.hsplit", dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), + shape=st.shared(helpers.get_shape(min_num_dims=2), 
key="value_shape"), ), indices_or_sections=_get_splits( - min_num_dims=3, axis=2, allow_none=False, is_mod_split=True + min_num_dims=2, axis=1, allow_none=False, is_mod_split=True ), test_with_out=st.just(False), ) -def test_numpy_dsplit( +def test_numpy_hsplit( *, dtype_value, indices_or_sections, @@ -136,22 +129,30 @@ def test_numpy_dsplit( ) -# vsplit +# split @handle_frontend_test( - fn_tree="numpy.vsplit", + fn_tree="numpy.split", dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + available_dtypes=helpers.get_dtypes("integer"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), ), indices_or_sections=_get_splits( - min_num_dims=2, axis=0, allow_none=False, is_mod_split=True + min_num_dims=1, allow_none=False, is_mod_split=True + ), + axis=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", ), test_with_out=st.just(False), ) -def test_numpy_vsplit( +def test_numpy_split( *, dtype_value, indices_or_sections, + axis, on_device, fn_tree, frontend, @@ -159,8 +160,6 @@ def test_numpy_vsplit( backend_fw, ): input_dtype, value = dtype_value - if isinstance(indices_or_sections, np.ndarray): - assume(indices_or_sections.ndim == 0) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -170,22 +169,23 @@ def test_numpy_vsplit( on_device=on_device, ary=value[0], indices_or_sections=indices_or_sections, + axis=axis, ) -# hsplit +# vsplit @handle_frontend_test( - fn_tree="numpy.hsplit", + fn_tree="numpy.vsplit", dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), ), indices_or_sections=_get_splits( - min_num_dims=2, axis=1, allow_none=False, is_mod_split=True + min_num_dims=2, axis=0, allow_none=False, is_mod_split=True ), test_with_out=st.just(False), ) -def test_numpy_hsplit( +def test_numpy_vsplit( *, dtype_value, indices_or_sections, diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_tiling_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_tiling_arrays.py index 1288a418e06d6..ba37b511864b3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_tiling_arrays.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_tiling_arrays.py @@ -6,27 +6,23 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# tile +# repeat @handle_frontend_test( - fn_tree="numpy.tile", + fn_tree="numpy.repeat", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - dtype_and_repeats=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape").map( - lambda rep: (len(rep),) - ), - min_value=0, - max_value=10, + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + min_dim_size=2, ), + repeats=helpers.ints(min_value=2, max_value=5), + axis=helpers.ints(min_value=-1, max_value=1), test_with_out=st.just(False), ) -def test_numpy_tile( +def test_numpy_repeat( *, dtype_and_x, - dtype_and_repeats, + repeats, + axis, on_device, fn_tree, frontend, @@ -34,36 +30,40 @@ def test_numpy_tile( backend_fw, ): 
input_dtype, x = dtype_and_x - repeats_dtype, repeats = dtype_and_repeats helpers.test_frontend_function( - input_dtypes=input_dtype + repeats_dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - A=x[0], - reps=repeats[0], + a=x[0], + repeats=repeats, + axis=axis, ) -# repeat +# tile @handle_frontend_test( - fn_tree="numpy.repeat", + fn_tree="numpy.tile", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - min_dim_size=2, + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + dtype_and_repeats=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("signed_integer"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape").map( + lambda rep: (len(rep),) + ), + min_value=0, + max_value=10, ), - repeats=helpers.ints(min_value=2, max_value=5), - axis=helpers.ints(min_value=-1, max_value=1), test_with_out=st.just(False), ) -def test_numpy_repeat( +def test_numpy_tile( *, dtype_and_x, - repeats, - axis, + dtype_and_repeats, on_device, fn_tree, frontend, @@ -71,14 +71,14 @@ def test_numpy_repeat( backend_fw, ): input_dtype, x = dtype_and_x + repeats_dtype, repeats = dtype_and_repeats helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtype + repeats_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - repeats=repeats, - axis=axis, + A=x[0], + reps=repeats[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_transpose_like_operations.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_transpose_like_operations.py index 17fc05e0df885..edf781f22c5be 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_transpose_like_operations.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_transpose_like_operations.py @@ -7,36 +7,40 @@ import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers -# transpose +# rollaxis @handle_frontend_test( - fn_tree="numpy.transpose", - array_and_axes=np_frontend_helpers._array_and_axes_permute_helper( - min_num_dims=0, - max_num_dims=5, - min_dim_size=0, - max_dim_size=10, + fn_tree="numpy.rollaxis", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=3, + min_dim_size=2, ), + axis=helpers.ints(min_value=-2, max_value=2), + start=helpers.ints(min_value=-2, max_value=2), test_with_out=st.just(False), ) -def test_numpy_transpose( +def test_numpy_rollaxis( *, - array_and_axes, + dtype_and_a, + axis, + start, on_device, fn_tree, frontend, test_flags, backend_fw, ): - array, dtype, axes = array_and_axes + input_dtype, a = dtype_and_a helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - array=array, - axes=axes, + a=a[0], + axis=axis, + start=start, ) @@ -81,38 +85,34 @@ def test_numpy_swapaxes( ) -# rollaxis +# transpose @handle_frontend_test( - fn_tree="numpy.rollaxis", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=3, - min_dim_size=2, + fn_tree="numpy.transpose", + 
array_and_axes=np_frontend_helpers._array_and_axes_permute_helper( + min_num_dims=0, + max_num_dims=5, + min_dim_size=0, + max_dim_size=10, ), - axis=helpers.ints(min_value=-2, max_value=2), - start=helpers.ints(min_value=-2, max_value=2), test_with_out=st.just(False), ) -def test_numpy_rollaxis( +def test_numpy_transpose( *, - dtype_and_a, - axis, - start, + array_and_axes, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, a = dtype_and_a + array, dtype, axes = array_and_axes helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=a[0], - axis=axis, - start=start, + array=array, + axes=axes, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py index 86372aab70c91..ca9af362f1222 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py @@ -59,13 +59,14 @@ def test_numpy_add( ) -# subtract +# divide @handle_frontend_test( - fn_tree="numpy.subtract", + fn_tree="numpy.divide", + aliases=["numpy.true_divide"], dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, ) @@ -73,10 +74,10 @@ def test_numpy_add( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="subtract" + fn_name="divide" ), ) -def test_numpy_subtract( +def test_numpy_divide( dtypes_values_casting, where, frontend, @@ -86,6 +87,7 @@ def test_numpy_subtract( on_device, ): input_dtypes, xs, casting, dtype = dtypes_values_casting + assume(not np.any(np.isclose(xs[1], 0.0))) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -98,6 +100,8 @@ def test_numpy_subtract( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-03, + rtol=1e-03, x1=xs[0], x2=xs[1], out=None, @@ -109,79 +113,97 @@ def test_numpy_subtract( ) -# vdot +# divmod @handle_frontend_test( - fn_tree="numpy.vdot", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 + fn_tree="numpy.divmod", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=6, + safety_factor_scale="linear", + shared_dtype=True, + ) + ], ), + where=np_frontend_helpers.where(), test_with_out=st.just(False), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="divmod" + ), ) -def test_numpy_vdot( - dtype_and_x, +def test_numpy_divmod( + dtypes_values_casting, + where, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtypes, xs = dtype_and_x - helpers.test_frontend_function( + input_dtypes, xs, casting, dtype = dtypes_values_casting + assume(not np.any(np.isclose(xs[1], 0))) + if dtype: + assume(np.dtype(dtype) >= np.dtype(input_dtypes[0])) + assume(np.dtype(dtype) >= np.dtype(input_dtypes[1])) 
+ assume(not np.any(np.isclose(xs[1].astype(dtype), 0))) + + assume("uint" not in input_dtypes[0] and "uint" not in input_dtypes[1]) + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - a=xs[0], - b=xs[1], + x1=xs[0], + x2=xs[1], ) -# divide +# float_power @handle_frontend_test( - fn_tree="numpy.divide", - aliases=["numpy.true_divide"], + fn_tree="numpy.float_power", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - ) - ], + arr_func=[lambda: _float_power_helper()], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="divide" + fn_name="float_power" ), ) -def test_numpy_divide( +def test_numpy_float_power( dtypes_values_casting, where, frontend, test_flags, fn_tree, backend_fw, - on_device, ): input_dtypes, xs, casting, dtype = dtypes_values_casting - assume(not np.any(np.isclose(xs[1], 0.0))) + xs = list(xs[0]) + input_dtypes = list(input_dtypes[0]) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, test_flags=test_flags, ) + # removing casting options as they raise errors for this function + assume(casting == "same_kind") + assume(dtype != "bool") np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - atol=1e-03, - rtol=1e-03, x1=xs[0], x2=xs[1], out=None, @@ -193,24 +215,26 @@ def test_numpy_divide( ) -# multiply +# floor_divide @handle_frontend_test( - fn_tree="numpy.multiply", + fn_tree="numpy.floor_divide", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=4, shared_dtype=True, + safety_factor_scale="linear", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="multiply" + fn_name="floor_divide" ), ) -def test_numpy_multiply( +def test_numpy_floor_divide( dtypes_values_casting, where, frontend, @@ -219,7 +243,11 @@ def test_numpy_multiply( backend_fw, on_device, ): - input_dtypes, xs, casting, dtype = dtypes_values_casting + input_dtypes, x, casting, dtype = dtypes_values_casting + assume(not np.any(np.isclose(x[1], 0, rtol=1e-1, atol=1e-1))) + assume(not np.any(np.isclose(x[0], 0, rtol=1e-1, atol=1e-1))) + if dtype: + assume(np.dtype(dtype) >= np.dtype(input_dtypes[0])) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -232,37 +260,39 @@ def test_numpy_multiply( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], + x1=x[0], + x2=x[1], out=None, where=where, casting=casting, order="K", dtype=dtype, subok=True, + atol=1e-2, + rtol=1e-2, ) -# power @handle_frontend_test( - fn_tree="numpy.power", + fn_tree="numpy.fmod", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( 
- available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - min_value=0, - max_value=7, shared_dtype=True, + large_abs_safety_factor=6, + small_abs_safety_factor=6, + safety_factor_scale="log", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="power" + fn_name="fmod" ), ) -def test_numpy_power( +def test_numpy_fmod( dtypes_values_casting, where, frontend, @@ -272,6 +302,12 @@ def test_numpy_power( on_device, ): input_dtypes, xs, casting, dtype = dtypes_values_casting + assume(not np.any(np.isclose(xs[1], 0.0))) + assume(not np.any(np.isclose(xs[0], 0.0))) + if dtype: + assume(not np.any(np.isclose(xs[1].astype(dtype), 0.0))) + assume(not np.any(np.isclose(xs[0].astype(dtype), 0.0))) + assume("uint" not in input_dtypes[0] and "uint" not in input_dtypes[1]) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -295,42 +331,47 @@ def test_numpy_power( ) -# float_power +# mod @handle_frontend_test( - fn_tree="numpy.float_power", + fn_tree="numpy.mod", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[lambda: _float_power_helper()], + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=0, + exclude_min=True, + shared_dtype=True, + ) + ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="float_power" + fn_name="mod" ), ) -def test_numpy_float_power( +def test_numpy_mod( dtypes_values_casting, where, frontend, test_flags, fn_tree, backend_fw, + on_device, ): input_dtypes, xs, casting, dtype = dtypes_values_casting - xs = list(xs[0]) - input_dtypes = list(input_dtypes[0]) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, test_flags=test_flags, ) - # removing casting options as they raise errors for this function - assume(casting == "same_kind") - assume(dtype != "bool") np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, x1=xs[0], x2=xs[1], out=None, @@ -339,25 +380,31 @@ def test_numpy_float_power( order="K", dtype=dtype, subok=True, + rtol=1e-5, + atol=1e-5, ) -# positive +# modf @handle_frontend_test( - fn_tree="numpy.positive", + fn_tree="numpy.modf", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_value=0, + exclude_min=True, ) ], ), where=np_frontend_helpers.where(), + test_with_out=st.just(False), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="positive" + fn_name="modf" ), ) -def test_numpy_positive( +def test_numpy_modf( dtypes_values_casting, where, frontend, @@ -389,22 +436,24 @@ def test_numpy_positive( ) -# negative +# multiply @handle_frontend_test( - fn_tree="numpy.negative", + fn_tree="numpy.multiply", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="negative" 
+ fn_name="multiply" ), ) -def test_numpy_negative( +def test_numpy_multiply( dtypes_values_casting, where, frontend, @@ -413,7 +462,7 @@ def test_numpy_negative( backend_fw, on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, xs, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -426,7 +475,8 @@ def test_numpy_negative( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=xs[0], + x2=xs[1], out=None, where=where, casting=casting, @@ -436,26 +486,22 @@ def test_numpy_negative( ) -# floor_divide +# negative @handle_frontend_test( - fn_tree="numpy.floor_divide", + fn_tree="numpy.negative", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=4, - shared_dtype=True, - safety_factor_scale="linear", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="floor_divide" + fn_name="negative" ), ) -def test_numpy_floor_divide( +def test_numpy_negative( dtypes_values_casting, where, frontend, @@ -465,10 +511,6 @@ def test_numpy_floor_divide( on_device, ): input_dtypes, x, casting, dtype = dtypes_values_casting - assume(not np.any(np.isclose(x[1], 0, rtol=1e-1, atol=1e-1))) - assume(not np.any(np.isclose(x[0], 0, rtol=1e-1, atol=1e-1))) - if dtype: - assume(np.dtype(dtype) >= np.dtype(input_dtypes[0])) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -481,39 +523,32 @@ def test_numpy_floor_divide( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], out=None, where=where, casting=casting, order="K", dtype=dtype, subok=True, - atol=1e-2, - rtol=1e-2, ) -# mod +# positive @handle_frontend_test( - fn_tree="numpy.mod", + fn_tree="numpy.positive", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=0, - exclude_min=True, - shared_dtype=True, ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="mod" + fn_name="positive" ), ) -def test_numpy_mod( +def test_numpy_positive( dtypes_values_casting, where, frontend, @@ -522,7 +557,7 @@ def test_numpy_mod( backend_fw, on_device, ): - input_dtypes, xs, casting, dtype = dtypes_values_casting + input_dtypes, x, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -535,39 +570,36 @@ def test_numpy_mod( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], + x=x[0], out=None, where=where, casting=casting, order="K", dtype=dtype, subok=True, - rtol=1e-5, - atol=1e-5, ) -# modf +# power @handle_frontend_test( - fn_tree="numpy.modf", + fn_tree="numpy.power", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, min_value=0, - exclude_min=True, + max_value=7, + shared_dtype=True, ) ], ), where=np_frontend_helpers.where(), - 
test_with_out=st.just(False), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="modf" + fn_name="power" ), ) -def test_numpy_modf( +def test_numpy_power( dtypes_values_casting, where, frontend, @@ -576,7 +608,7 @@ def test_numpy_modf( backend_fw, on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, xs, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -589,7 +621,8 @@ def test_numpy_modf( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=xs[0], + x2=xs[1], out=None, where=where, casting=casting, @@ -650,51 +683,44 @@ def test_numpy_reciprocal( ) +# remainder @handle_frontend_test( - fn_tree="numpy.fmod", + fn_tree="numpy.remainder", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=6, - small_abs_safety_factor=6, - safety_factor_scale="log", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="fmod" + fn_name="remainder" ), ) -def test_numpy_fmod( +def test_numpy_remainder( dtypes_values_casting, where, frontend, test_flags, - fn_tree, backend_fw, + fn_tree, on_device, ): input_dtypes, xs, casting, dtype = dtypes_values_casting - assume(not np.any(np.isclose(xs[1], 0.0))) - assume(not np.any(np.isclose(xs[0], 0.0))) - if dtype: - assume(not np.any(np.isclose(xs[1].astype(dtype), 0.0))) - assume(not np.any(np.isclose(xs[0].astype(dtype), 0.0))) - assume("uint" not in input_dtypes[0] and "uint" not in input_dtypes[1]) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, test_flags=test_flags, ) + assume(not np.any(np.isclose(xs[1], 0.0))) np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, x1=xs[0], @@ -708,28 +734,24 @@ def test_numpy_fmod( ) -# divmod +# subtract @handle_frontend_test( - fn_tree="numpy.divmod", + fn_tree="numpy.subtract", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - allow_inf=False, - large_abs_safety_factor=6, - safety_factor_scale="linear", shared_dtype=True, ) ], ), where=np_frontend_helpers.where(), - test_with_out=st.just(False), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="divmod" + fn_name="subtract" ), ) -def test_numpy_divmod( +def test_numpy_subtract( dtypes_values_casting, where, frontend, @@ -739,13 +761,6 @@ def test_numpy_divmod( on_device, ): input_dtypes, xs, casting, dtype = dtypes_values_casting - assume(not np.any(np.isclose(xs[1], 0))) - if dtype: - assume(np.dtype(dtype) >= np.dtype(input_dtypes[0])) - assume(np.dtype(dtype) >= np.dtype(input_dtypes[1])) - assume(not np.any(np.isclose(xs[1].astype(dtype), 0))) - - assume("uint" not in input_dtypes[0] and "uint" not in input_dtypes[1]) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -760,55 +775,40 @@ def test_numpy_divmod( on_device=on_device, x1=xs[0], x2=xs[1], + out=None, + 
where=where, + casting=casting, + order="K", + dtype=dtype, + subok=True, ) -# remainder +# vdot @handle_frontend_test( - fn_tree="numpy.remainder", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ) - ], - ), - where=np_frontend_helpers.where(), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="remainder" + fn_tree="numpy.vdot", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), + test_with_out=st.just(False), ) -def test_numpy_remainder( - dtypes_values_casting, - where, +def test_numpy_vdot( + dtype_and_x, frontend, test_flags, - backend_fw, fn_tree, + backend_fw, on_device, ): - input_dtypes, xs, casting, dtype = dtypes_values_casting - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - assume(not np.any(np.isclose(xs[1], 0.0))) - np_frontend_helpers.test_frontend_function( + input_dtypes, xs = dtype_and_x + helpers.test_frontend_function( input_dtypes=input_dtypes, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], - out=None, - where=where, - casting=casting, - order="K", - dtype=dtype, - subok=True, + test_values=False, + a=xs[0], + b=xs[1], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_exponents_and_logarithms.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_exponents_and_logarithms.py index e50c2ae06b392..0ac629ea8dde4 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_exponents_and_logarithms.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_exponents_and_logarithms.py @@ -59,9 +59,9 @@ def test_numpy_exp( ) -# expm1 +# exp2 @handle_frontend_test( - fn_tree="numpy.expm1", + fn_tree="numpy.exp2", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -71,10 +71,10 @@ def test_numpy_exp( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="expm1" + fn_name="exp2" ), ) -def test_numpy_expm1( +def test_numpy_exp2( dtypes_values_casting, where, frontend, @@ -97,7 +97,6 @@ def test_numpy_expm1( fn_tree=fn_tree, on_device=on_device, rtol=1e-02, - atol=1e-02, x=x[0], out=None, where=where, @@ -108,9 +107,9 @@ def test_numpy_expm1( ) -# exp2 +# expm1 @handle_frontend_test( - fn_tree="numpy.exp2", + fn_tree="numpy.expm1", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -120,10 +119,10 @@ def test_numpy_expm1( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="exp2" + fn_name="expm1" ), ) -def test_numpy_exp2( +def test_numpy_expm1( dtypes_values_casting, where, frontend, @@ -146,6 +145,7 @@ def test_numpy_exp2( fn_tree=fn_tree, on_device=on_device, rtol=1e-02, + atol=1e-02, x=x[0], out=None, where=where, @@ -156,22 +156,125 @@ def test_numpy_exp2( ) -# log10 +# frexp @handle_frontend_test( - fn_tree="numpy.log10", + fn_tree="numpy.frexp", + dtype_and_x=helpers.dtype_and_values( + 
available_dtypes=["float32", "float64"], + num_arrays=1, + shared_dtype=True, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + ), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="frexp" + ), +) +def test_numpy_frexp( + *, + dtype_and_x, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# i0 +@handle_frontend_test( + fn_tree="numpy.i0", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), +) +def test_numpy_i0( + *, + dtype_and_x, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# ldexp +@handle_frontend_test( + fn_tree="numpy.ldexp", + dtype_and_x=ldexp_args(), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="ldexp" + ), +) +def test_numpy_ldexp( + *, + dtype_and_x, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + x1=x[0], + x2=x[1], + ) + + +# log +@handle_frontend_test( + fn_tree="numpy.log", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + small_abs_safety_factor=2, + safety_factor_scale="log", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="log10" + fn_name="log" ), ) -def test_numpy_log10( +def test_numpy_log( dtypes_values_casting, where, frontend, @@ -193,8 +296,7 @@ def test_numpy_log10( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, + rtol=1e-03, x=x[0], out=None, where=where, @@ -205,24 +307,22 @@ def test_numpy_log10( ) -# log +# log10 @handle_frontend_test( - fn_tree="numpy.log", + fn_tree="numpy.log10", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - small_abs_safety_factor=2, - safety_factor_scale="log", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="log" + fn_name="log10" ), ) -def test_numpy_log( +def test_numpy_log10( dtypes_values_casting, where, frontend, @@ -244,7 +344,8 @@ def test_numpy_log( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, + rtol=1e-2, + atol=1e-2, x=x[0], out=None, where=where, @@ -255,24 +356,24 @@ def test_numpy_log( ) -# log2 +# log1p @handle_frontend_test( - fn_tree="numpy.log2", + fn_tree="numpy.log1p", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), small_abs_safety_factor=2, - safety_factor_scale="linear", + 
safety_factor_scale="log", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="log2" + fn_name="log1p" ), ) -def test_numpy_log2( +def test_numpy_log1p( dtypes_values_casting, where, frontend, @@ -295,7 +396,6 @@ def test_numpy_log2( fn_tree=fn_tree, on_device=on_device, rtol=1e-3, - atol=1e-3, x=x[0], out=None, where=where, @@ -306,24 +406,24 @@ def test_numpy_log2( ) -# log1p +# log2 @handle_frontend_test( - fn_tree="numpy.log1p", + fn_tree="numpy.log2", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), small_abs_safety_factor=2, - safety_factor_scale="log", + safety_factor_scale="linear", ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="log1p" + fn_name="log2" ), ) -def test_numpy_log1p( +def test_numpy_log2( dtypes_values_casting, where, frontend, @@ -346,6 +446,7 @@ def test_numpy_log1p( fn_tree=fn_tree, on_device=on_device, rtol=1e-3, + atol=1e-3, x=x[0], out=None, where=where, @@ -460,104 +561,3 @@ def test_numpy_logaddexp2( dtype=dtype, subok=True, ) - - -# i0 -@handle_frontend_test( - fn_tree="numpy.i0", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), -) -def test_numpy_i0( - *, - dtype_and_x, - test_flags, - on_device, - fn_tree, - frontend, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# frexp -@handle_frontend_test( - fn_tree="numpy.frexp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - num_arrays=1, - shared_dtype=True, - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - ), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="frexp" - ), -) -def test_numpy_frexp( - *, - dtype_and_x, - test_flags, - on_device, - fn_tree, - frontend, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# ldexp -@handle_frontend_test( - fn_tree="numpy.ldexp", - dtype_and_x=ldexp_args(), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="ldexp" - ), -) -def test_numpy_ldexp( - *, - dtype_and_x, - test_flags, - on_device, - fn_tree, - frontend, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - x1=x[0], - x2=x[1], - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_extrema_finding.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_extrema_finding.py index 1c8df0ce44aaa..2d5f59ff7b45e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_extrema_finding.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_extrema_finding.py @@ -9,33 +9,35 @@ from ivy 
import inf -# minimum +# amax @handle_frontend_test( - fn_tree="numpy.minimum", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ) - ], + fn_tree="numpy.amax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + large_abs_safety_factor=2, + safety_factor_scale="log", ), + initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), + keepdims=st.booleans(), where=np_frontend_helpers.where(), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="minimum" - ), ) -def test_numpy_minimum( - dtypes_values_casting, - where, +def test_numpy_amax( + dtype_x_axis, frontend, test_flags, fn_tree, backend_fw, on_device, + where, + initial, + keepdims, ): - input_dtypes, xs, casting, dtype = dtypes_values_casting + if initial is None and np.all(where) is not True: + assume(initial is +inf) + input_dtypes, x, axis = dtype_x_axis where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -48,14 +50,11 @@ def test_numpy_minimum( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], - out=None, + a=x[0], + axis=axis, + keepdims=keepdims, + initial=initial, where=where, - casting=casting, - order="K", - dtype=dtype, - subok=True, ) @@ -109,35 +108,31 @@ def test_numpy_amin( ) -# amax +# fmax @handle_frontend_test( - fn_tree="numpy.amax", - dtype_x_axis=helpers.dtype_values_axis( + fn_tree="numpy.fmax", + dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - large_abs_safety_factor=2, - safety_factor_scale="log", + num_arrays=2, + min_value=-np.inf, + max_value=np.inf, + shared_dtype=True, ), - initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), - keepdims=st.booleans(), where=np_frontend_helpers.where(), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="fmax" + ), ) -def test_numpy_amax( - dtype_x_axis, +def test_numpy_fmax( + dtype_and_inputs, + where, frontend, test_flags, fn_tree, backend_fw, on_device, - where, - initial, - keepdims, ): - if initial is None and np.all(where) is not True: - assume(initial is +inf) - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_inputs where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -150,46 +145,40 @@ def test_numpy_amax( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, - keepdims=keepdims, - initial=initial, + x1=xs[0], + x2=xs[1], + out=None, where=where, ) -# nanmin +# fmin @handle_frontend_test( - fn_tree="numpy.nanmin", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - large_abs_safety_factor=2, - safety_factor_scale="log", - allow_nan=True, - allow_inf=True, + fn_tree="numpy.fmin", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ) + ], ), - initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), - 
keepdims=st.booleans(), where=np_frontend_helpers.where(), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="fmin" + ), ) -def test_numpy_nanmin( - dtype_x_axis, +def test_numpy_fmin( + dtypes_values_casting, + where, frontend, test_flags, fn_tree, backend_fw, on_device, - where, - initial, - keepdims, ): - if initial is None and np.all(where) is not True: - assume(initial is inf) - - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -202,12 +191,14 @@ def test_numpy_nanmin( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + x1=xs[0], + x2=xs[1], out=None, - keepdims=keepdims, - initial=initial, where=where, + casting=casting, + order="K", + dtype=dtype, + subok=True, ) @@ -261,38 +252,33 @@ def test_numpy_maximum( ) -# nanmax +# minimum @handle_frontend_test( - fn_tree="numpy.nanmax", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - large_abs_safety_factor=2, - safety_factor_scale="log", - allow_nan=True, - allow_inf=True, + fn_tree="numpy.minimum", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ) + ], ), - initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), - keepdims=st.booleans(), where=np_frontend_helpers.where(), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="minimum" + ), ) -def test_numpy_nanmax( - dtype_x_axis, +def test_numpy_minimum( + dtypes_values_casting, + where, frontend, test_flags, fn_tree, backend_fw, on_device, - where, - initial, - keepdims, ): - if initial is None and np.all(where) is not True: - assume(initial is -inf) - - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -305,40 +291,49 @@ def test_numpy_nanmax( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - axis=axis, + x1=xs[0], + x2=xs[1], out=None, - keepdims=keepdims, - initial=initial, where=where, + casting=casting, + order="K", + dtype=dtype, + subok=True, ) -# fmax +# nanmax @handle_frontend_test( - fn_tree="numpy.fmax", - dtype_and_inputs=helpers.dtype_and_values( + fn_tree="numpy.nanmax", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-np.inf, - max_value=np.inf, - shared_dtype=True, + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + large_abs_safety_factor=2, + safety_factor_scale="log", + allow_nan=True, + allow_inf=True, ), + initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), + keepdims=st.booleans(), where=np_frontend_helpers.where(), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="fmax" - ), ) -def test_numpy_fmax( - dtype_and_inputs, - where, +def test_numpy_nanmax( + dtype_x_axis, frontend, test_flags, fn_tree, backend_fw, on_device, + where, + initial, + keepdims, ): - input_dtypes, xs = dtype_and_inputs + if initial is None and np.all(where) is not True: + assume(initial is -inf) + + input_dtypes, x, 
axis = dtype_x_axis where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -351,40 +346,47 @@ def test_numpy_fmax( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], + a=x[0], + axis=axis, out=None, + keepdims=keepdims, + initial=initial, where=where, ) -# fmin +# nanmin @handle_frontend_test( - fn_tree="numpy.fmin", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ) - ], + fn_tree="numpy.nanmin", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + large_abs_safety_factor=2, + safety_factor_scale="log", + allow_nan=True, + allow_inf=True, ), + initial=st.one_of(st.floats(min_value=-1000, max_value=1000), st.none()), + keepdims=st.booleans(), where=np_frontend_helpers.where(), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="fmin" - ), ) -def test_numpy_fmin( - dtypes_values_casting, - where, +def test_numpy_nanmin( + dtype_x_axis, frontend, test_flags, fn_tree, backend_fw, on_device, + where, + initial, + keepdims, ): - input_dtypes, xs, casting, dtype = dtypes_values_casting + if initial is None and np.all(where) is not True: + assume(initial is inf) + + input_dtypes, x, axis = dtype_x_axis where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -397,12 +399,10 @@ def test_numpy_fmin( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], + a=x[0], + axis=axis, out=None, + keepdims=keepdims, + initial=initial, where=where, - casting=casting, - order="K", - dtype=dtype, - subok=True, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_handling_complex_numbers.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_handling_complex_numbers.py index 1ff05270aae3d..8fad9822b3f27 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_handling_complex_numbers.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_handling_complex_numbers.py @@ -37,39 +37,45 @@ def test_numpy_angle( ) -# imag +# conj @handle_frontend_test( - fn_tree="numpy.imag", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - test_with_out=st.just(False), + fn_tree="numpy.conj", + aliases=["numpy.conjugate"], + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), + ), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="conj" + ), ) -def test_numpy_imag( - dtype_and_x, +def test_numpy_conj( + on_device, frontend, - test_flags, + *, + dtype_and_x, fn_tree, + test_flags, backend_fw, - on_device, ): - input_dtypes, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - val=x[0], + x=x[0], ) -# real +# imag @handle_frontend_test( - fn_tree="numpy.real", + fn_tree="numpy.imag", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), test_with_out=st.just(False), ) -def 
test_numpy_real( +def test_numpy_imag( dtype_and_x, frontend, test_flags, @@ -89,33 +95,27 @@ def test_numpy_real( ) -# conj +# real @handle_frontend_test( - fn_tree="numpy.conj", - aliases=["numpy.conjugate"], - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - ), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="conj" - ), + fn_tree="numpy.real", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + test_with_out=st.just(False), ) -def test_numpy_conj( - on_device, - frontend, - *, +def test_numpy_real( dtype_and_x, - fn_tree, + frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtypes, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + val=x[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_hyperbolic_functions.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_hyperbolic_functions.py index d676761631de2..bd01f0c675fdc 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_hyperbolic_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_hyperbolic_functions.py @@ -6,9 +6,9 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# sinh +# arccosh @handle_frontend_test( - fn_tree="numpy.sinh", + fn_tree="numpy.arccosh", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -20,10 +20,10 @@ ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="sinh" + fn_name="arccosh" ), ) -def test_numpy_sinh( +def test_numpy_arccosh( dtypes_values_casting, where, frontend, @@ -45,8 +45,8 @@ def test_numpy_sinh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, rtol=1e-2, + atol=1e-2, x=x[0], out=None, where=where, @@ -57,9 +57,9 @@ def test_numpy_sinh( ) -# cosh +# arcsinh @handle_frontend_test( - fn_tree="numpy.cosh", + fn_tree="numpy.arcsinh", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -71,10 +71,10 @@ def test_numpy_sinh( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="cosh" + fn_name="arcsinh" ), ) -def test_numpy_cosh( +def test_numpy_arcsinh( dtypes_values_casting, where, frontend, @@ -96,6 +96,8 @@ def test_numpy_cosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-2, + atol=1e-2, x=x[0], out=None, where=where, @@ -106,9 +108,9 @@ def test_numpy_cosh( ) -# tanh +# arctanh @handle_frontend_test( - fn_tree="numpy.tanh", + fn_tree="numpy.arctanh", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -120,10 +122,10 @@ def test_numpy_cosh( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="tanh" + fn_name="arctanh" ), ) -def test_numpy_tanh( +def test_numpy_arctanh( dtypes_values_casting, where, frontend, @@ -145,8 +147,6 @@ def test_numpy_tanh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-3, - rtol=1e-3, x=x[0], 
out=None, where=where, @@ -157,9 +157,9 @@ def test_numpy_tanh( ) -# arcsinh +# cosh @handle_frontend_test( - fn_tree="numpy.arcsinh", + fn_tree="numpy.cosh", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -171,10 +171,10 @@ def test_numpy_tanh( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="arcsinh" + fn_name="cosh" ), ) -def test_numpy_arcsinh( +def test_numpy_cosh( dtypes_values_casting, where, frontend, @@ -196,8 +196,6 @@ def test_numpy_arcsinh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, x=x[0], out=None, where=where, @@ -208,9 +206,9 @@ def test_numpy_arcsinh( ) -# arccosh +# sinh @handle_frontend_test( - fn_tree="numpy.arccosh", + fn_tree="numpy.sinh", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -222,10 +220,10 @@ def test_numpy_arcsinh( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="arccosh" + fn_name="sinh" ), ) -def test_numpy_arccosh( +def test_numpy_sinh( dtypes_values_casting, where, frontend, @@ -247,8 +245,8 @@ def test_numpy_arccosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, atol=1e-2, + rtol=1e-2, x=x[0], out=None, where=where, @@ -259,9 +257,9 @@ def test_numpy_arccosh( ) -# arctanh +# tanh @handle_frontend_test( - fn_tree="numpy.arctanh", + fn_tree="numpy.tanh", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -273,10 +271,10 @@ def test_numpy_arccosh( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="arctanh" + fn_name="tanh" ), ) -def test_numpy_arctanh( +def test_numpy_tanh( dtypes_values_casting, where, frontend, @@ -298,6 +296,8 @@ def test_numpy_arctanh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-3, + rtol=1e-3, x=x[0], out=None, where=where, diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py index 9e20eba4b779d..254a462ccfaa9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py @@ -9,6 +9,10 @@ import ivy +# --- Helpers --- # +# --------------- # + + @st.composite def _get_clip_inputs(draw): shape = draw( @@ -36,18 +40,27 @@ def _get_clip_inputs(draw): return x_dtype, x, min, max, casting, dtype -# clip +# --- Main --- # +# ------------ # + + +# absolute @handle_frontend_test( - fn_tree="numpy.clip", - input_and_ranges=_get_clip_inputs(), + fn_tree="numpy.absolute", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ) + ], + ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="clip" + fn_name="absolute" ), - test_with_out=st.just(False), ) -def test_numpy_clip( - input_and_ranges, +def test_numpy_absolute( + dtypes_values_casting, where, frontend, test_flags, @@ -55,7 +68,7 @@ def test_numpy_clip( backend_fw, on_device, ): - 
input_dtypes, x, min, max, casting, dtype = input_and_ranges + input_dtypes, x, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -68,9 +81,7 @@ def test_numpy_clip( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - a_min=min, - a_max=max, + x=x[0], out=None, where=where, casting=casting, @@ -129,23 +140,18 @@ def test_numpy_cbrt( ) -# sqrt +# clip @handle_frontend_test( - fn_tree="numpy.sqrt", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ) - ], - ), + fn_tree="numpy.clip", + input_and_ranges=_get_clip_inputs(), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="sqrt" + fn_name="clip" ), + test_with_out=st.just(False), ) -def test_numpy_sqrt( - dtypes_values_casting, +def test_numpy_clip( + input_and_ranges, where, frontend, test_flags, @@ -153,7 +159,7 @@ def test_numpy_sqrt( backend_fw, on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, x, min, max, casting, dtype = input_and_ranges where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -166,8 +172,9 @@ def test_numpy_sqrt( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - atol=1e-2, + a=x[0], + a_min=min, + a_max=max, out=None, where=where, casting=casting, @@ -177,22 +184,64 @@ def test_numpy_sqrt( ) -# reciprocal @handle_frontend_test( - fn_tree="numpy.reciprocal", + fn_tree="numpy.convolve", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=1, + num_arrays=2, + min_value=-10, + max_value=10, + shared_dtype=True, + ), + mode=st.sampled_from(["valid", "same", "full"]), + test_with_out=st.just(False), +) +def test_numpy_convolve( + dtype_and_x, + mode, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtypes, xs = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=xs[0], + v=xs[1], + mode=mode, + ) + + +# copysign +@handle_frontend_test( + fn_tree="numpy.copysign", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_value=-100, + max_value=100, ) ], ), where=np_frontend_helpers.where(), + test_with_out=st.just(False), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="reciprocal" + fn_name="copysign" ), ) -def test_numpy_reciprocal( +def test_numpy_copysign( dtypes_values_casting, where, frontend, @@ -201,7 +250,7 @@ def test_numpy_reciprocal( backend_fw, on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, xs, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -214,8 +263,10 @@ def test_numpy_reciprocal( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + rtol=1e-2, atol=1e-2, + x1=xs[0], + x2=xs[1], out=None, where=where, casting=casting, @@ -225,22 +276,22 @@ def test_numpy_reciprocal( ) 
-# square +# fabs @handle_frontend_test( - fn_tree="numpy.square", + fn_tree="numpy.fabs", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="square" + fn_name="fabs" ), ) -def test_numpy_square( +def test_numpy_fabs( dtypes_values_casting, where, frontend, @@ -272,23 +323,26 @@ def test_numpy_square( ) -# absolute +# gcd @handle_frontend_test( - fn_tree="numpy.absolute", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ) - ], + fn_tree="numpy.gcd", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=False, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=False, ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="absolute" + fn_name="gcd" ), ) -def test_numpy_absolute( - dtypes_values_casting, +def test_numpy_gcd( + dtype_and_inputs, where, frontend, test_flags, @@ -296,7 +350,7 @@ def test_numpy_absolute( backend_fw, on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, xs = dtype_and_inputs where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -309,32 +363,31 @@ def test_numpy_absolute( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=xs[0], + x2=xs[1], out=None, where=where, - casting=casting, - order="K", - dtype=dtype, - subok=True, ) -# fabs +# heaviside @handle_frontend_test( - fn_tree="numpy.fabs", + fn_tree="numpy.heaviside", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ) ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="fabs" + fn_name="heaviside" ), ) -def test_numpy_fabs( +def test_numpy_heaviside( dtypes_values_casting, where, frontend, @@ -343,7 +396,7 @@ def test_numpy_fabs( backend_fw, on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, (x1_list, x2_list), casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -356,7 +409,8 @@ def test_numpy_fabs( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x1=x1_list, + x2=x2_list, out=None, where=where, casting=casting, @@ -366,73 +420,94 @@ def test_numpy_fabs( ) -# sign +# interp @handle_frontend_test( - fn_tree="numpy.sign", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ) - ], + fn_tree="numpy.interp", + xp_and_fp=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_num_dims=1, + max_num_dims=1, + min_dim_size=3, + min_value=-10000, + max_value=10000, ), - where=np_frontend_helpers.where(), - 
number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="sign" + x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + left=st.one_of(st.none(), st.floats()), + right=st.one_of(st.none(), st.floats()), + period=st.one_of( + st.none(), + st.floats( + allow_nan=False, + allow_infinity=False, + allow_subnormal=False, + min_value=0.1, + max_value=1.0e5, + exclude_min=True, + ), ), + test_with_out=st.just(False), ) -def test_numpy_sign( - dtypes_values_casting, - where, +def test_numpy_interp( frontend, test_flags, fn_tree, backend_fw, on_device, + xp_and_fp, + x, + left, + right, + period, ): - input_dtypes, x, casting, dtype = dtypes_values_casting - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - + input_dtypes, xp_fp = xp_and_fp + xp = ivy.array(xp_fp[0]) + fp = ivy.array(xp_fp[1]) + if period is None: + xp_order = ivy.argsort(xp) + xp = xp[xp_order] + fp = fp[xp_order] + previous = xp[0] + for i in xp[1:]: + assume(i > previous) + previous = i + x_dtype, x = x np_frontend_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtypes + x_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - out=None, - where=where, - casting=casting, - order="K", - dtype=dtype, - subok=True, + xp=xp, + fp=fp, + left=left, + right=right, + period=period, ) -# heaviside +# lcm @handle_frontend_test( - fn_tree="numpy.heaviside", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - ) - ], + fn_tree="numpy.lcm", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=True, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=False, ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="heaviside" + fn_name="lcm" ), ) -def test_numpy_heaviside( - dtypes_values_casting, +def test_numpy_lcm( + dtype_and_inputs, where, frontend, test_flags, @@ -440,7 +515,7 @@ def test_numpy_heaviside( backend_fw, on_device, ): - input_dtypes, (x1_list, x2_list), casting, dtype = dtypes_values_casting + input_dtypes, xs = dtype_and_inputs where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -453,14 +528,10 @@ def test_numpy_heaviside( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=x1_list, - x2=x2_list, + x1=xs[0], + x2=xs[1], out=None, where=where, - casting=casting, - order="K", - dtype=dtype, - subok=True, ) @@ -537,132 +608,70 @@ def test_numpy_real_if_close( ) -# interp +# reciprocal @handle_frontend_test( - fn_tree="numpy.interp", - xp_and_fp=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - min_dim_size=3, - min_value=-10000, - max_value=10000, + fn_tree="numpy.reciprocal", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ) + ], ), - x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - left=st.one_of(st.none(), st.floats()), - 
right=st.one_of(st.none(), st.floats()), - period=st.one_of( - st.none(), - st.floats( - allow_nan=False, - allow_infinity=False, - allow_subnormal=False, - min_value=0.1, - max_value=1.0e5, - exclude_min=True, - ), + where=np_frontend_helpers.where(), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="reciprocal" ), - test_with_out=st.just(False), ) -def test_numpy_interp( +def test_numpy_reciprocal( + dtypes_values_casting, + where, frontend, test_flags, fn_tree, backend_fw, on_device, - xp_and_fp, - x, - left, - right, - period, ): - input_dtypes, xp_fp = xp_and_fp - xp = ivy.array(xp_fp[0]) - fp = ivy.array(xp_fp[1]) - if period is None: - xp_order = ivy.argsort(xp) - xp = xp[xp_order] - fp = fp[xp_order] - previous = xp[0] - for i in xp[1:]: - assume(i > previous) - previous = i - x_dtype, x = x - np_frontend_helpers.test_frontend_function( - input_dtypes=input_dtypes + x_dtype, - backend_to_test=backend_fw, - frontend=frontend, + input_dtypes, x, casting, dtype = dtypes_values_casting + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - xp=xp, - fp=fp, - left=left, - right=right, - period=period, ) - - -@handle_frontend_test( - fn_tree="numpy.convolve", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=1, - num_arrays=2, - min_value=-10, - max_value=10, - shared_dtype=True, - ), - mode=st.sampled_from(["valid", "same", "full"]), - test_with_out=st.just(False), -) -def test_numpy_convolve( - dtype_and_x, - mode, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtypes, xs = dtype_and_x - helpers.test_frontend_function( + np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=xs[0], - v=xs[1], - mode=mode, + x=x[0], + atol=1e-2, + out=None, + where=where, + casting=casting, + order="K", + dtype=dtype, + subok=True, ) -# copysign +# sign @handle_frontend_test( - fn_tree="numpy.copysign", + fn_tree="numpy.sign", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=-100, - max_value=100, + available_dtypes=helpers.get_dtypes("numeric"), ) ], ), where=np_frontend_helpers.where(), - test_with_out=st.just(False), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="copysign" + fn_name="sign" ), ) -def test_numpy_copysign( +def test_numpy_sign( dtypes_values_casting, where, frontend, @@ -671,12 +680,13 @@ def test_numpy_copysign( backend_fw, on_device, ): - input_dtypes, xs, casting, dtype = dtypes_values_casting + input_dtypes, x, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, test_flags=test_flags, ) + np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -684,10 +694,7 @@ def test_numpy_copysign( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - x1=xs[0], - x2=xs[1], + x=x[0], out=None, where=where, casting=casting, @@ -697,26 +704,23 @@ def test_numpy_copysign( ) -# lcm +# sqrt 
@handle_frontend_test( - fn_tree="numpy.lcm", - dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=False, + fn_tree="numpy.sqrt", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ) + ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="lcm" + fn_name="sqrt" ), ) -def test_numpy_lcm( - dtype_and_inputs, +def test_numpy_sqrt( + dtypes_values_casting, where, frontend, test_flags, @@ -724,7 +728,7 @@ def test_numpy_lcm( backend_fw, on_device, ): - input_dtypes, xs = dtype_and_inputs + input_dtypes, x, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -737,33 +741,34 @@ def test_numpy_lcm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], + x=x[0], + atol=1e-2, out=None, where=where, + casting=casting, + order="K", + dtype=dtype, + subok=True, ) -# gcd +# square @handle_frontend_test( - fn_tree="numpy.gcd", - dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=False, - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=False, + fn_tree="numpy.square", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ) + ], ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="gcd" + fn_name="square" ), ) -def test_numpy_gcd( - dtype_and_inputs, +def test_numpy_square( + dtypes_values_casting, where, frontend, test_flags, @@ -771,7 +776,7 @@ def test_numpy_gcd( backend_fw, on_device, ): - input_dtypes, xs = dtype_and_inputs + input_dtypes, x, casting, dtype = dtypes_values_casting where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, @@ -784,8 +789,11 @@ def test_numpy_gcd( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x1=xs[0], - x2=xs[1], + x=x[0], out=None, where=where, + casting=casting, + order="K", + dtype=dtype, + subok=True, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py index 2fd5ded27a81a..c8432734b140b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py @@ -7,56 +7,40 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# ceil +# around @handle_frontend_test( - fn_tree="numpy.ceil", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ) - ], - ), - where=np_frontend_helpers.where(), - number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="ceil" + fn_tree="numpy.around", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), + 
decimals=st.integers(min_value=0, max_value=5), ) -def test_numpy_ceil( - dtypes_values_casting, - where, +def test_numpy_around( + *, + dtype_and_x, + decimals, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - np_frontend_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - out=None, - where=where, - casting=casting, - order="K", - dtype=dtype, - subok=True, + a=x[0], + decimals=decimals, ) -# floor +# ceil @handle_frontend_test( - fn_tree="numpy.floor", + fn_tree="numpy.ceil", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -66,10 +50,10 @@ def test_numpy_ceil( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="floor" + fn_name="ceil" ), ) -def test_numpy_floor( +def test_numpy_ceil( dtypes_values_casting, where, frontend, @@ -129,9 +113,9 @@ def test_numpy_fix( ) -# trunc +# floor @handle_frontend_test( - fn_tree="numpy.trunc", + fn_tree="numpy.floor", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -141,10 +125,10 @@ def test_numpy_fix( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="trunc" + fn_name="floor" ), ) -def test_numpy_trunc( +def test_numpy_floor( dtypes_values_casting, where, frontend, @@ -223,23 +207,25 @@ def test_numpy_rint( ) -# around +# round @handle_frontend_test( - fn_tree="numpy.around", + fn_tree="numpy.round", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), + max_value=50, + min_value=-50, ), - decimals=st.integers(min_value=0, max_value=5), + decimals=st.integers(min_value=0, max_value=3), ) -def test_numpy_around( +def test_numpy_round( *, dtype_and_x, decimals, on_device, + backend_fw, fn_tree, frontend, test_flags, - backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -254,34 +240,48 @@ def test_numpy_around( ) -# round +# trunc @handle_frontend_test( - fn_tree="numpy.round", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - max_value=50, - min_value=-50, + fn_tree="numpy.trunc", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ) + ], + ), + where=np_frontend_helpers.where(), + number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( + fn_name="trunc" ), - decimals=st.integers(min_value=0, max_value=3), ) -def test_numpy_round( - *, - dtype_and_x, - decimals, - on_device, - backend_fw, - fn_tree, +def test_numpy_trunc( + dtypes_values_casting, + where, frontend, test_flags, + fn_tree, + backend_fw, + on_device, ): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes, x, casting, dtype = dtypes_values_casting + where, input_dtypes, test_flags = 
np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + np_frontend_helpers.test_frontend_function( + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - decimals=decimals, + x=x[0], + out=None, + where=where, + casting=casting, + order="K", + dtype=dtype, + subok=True, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py index 8d85475a44891..76686c58f938d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py @@ -8,6 +8,36 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +# trapz +@st.composite +def _either_x_dx(draw): + rand = (draw(st.integers(min_value=0, max_value=1)),) + if rand == 0: + either_x_dx = draw( + helpers.dtype_and_values( + avaliable_dtypes=st.shared( + helpers.get_dtypes("float"), key="trapz_dtype" + ), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ) + ) + return rand, either_x_dx + else: + either_x_dx = draw( + st.floats(min_value=-10, max_value=10), + ) + return rand, either_x_dx + + # helpers @st.composite def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): @@ -34,32 +64,29 @@ def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): return [dtype1], [values], axis, dtype2 -# sum +# --- Main --- # +# ------------ # + + +# cumprod @handle_frontend_test( - fn_tree="numpy.sum", - dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), - keep_dims=st.booleans(), - initial=st.one_of(st.floats(min_value=-100, max_value=100)), + fn_tree="numpy.cumprod", + dtype_x_axis_dtypes=_get_castable_dtypes_values(), ) -def test_numpy_sum( - dtype_x_axis_dtype, - keep_dims, - initial, +def test_numpy_cumprod( + dtype_x_axis_dtypes, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype + input_dtypes, x, axis, dtype = dtype_x_axis_dtypes + # ToDo: set as_variable_flags as the parameter generated by test_cumprod once + # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 if backend_fw == "torch": assume(not test_flags.as_variable[0]) - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - helpers.test_frontend_function( + np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, @@ -69,38 +96,28 @@ def test_numpy_sum( x=x[0], axis=axis, dtype=dtype, - keepdims=keep_dims, - initial=initial, - where=where, ) -# prod +# cumsum @handle_frontend_test( - fn_tree="numpy.prod", - dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), - keep_dims=st.booleans(), - initial=st.one_of(st.floats(min_value=-100, max_value=100)), + fn_tree="numpy.cumsum", + dtype_and_x_axis_dtype=_get_castable_dtypes_values(), ) -def test_numpy_prod( - dtype_x_axis_dtype, - keep_dims, - initial, +def test_numpy_cumsum( + dtype_and_x_axis_dtype, frontend, test_flags, fn_tree, 
backend_fw, on_device, ): - input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype + input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype + # ToDo: set as_variable_flags as the parameter generated by test_cumprod once + # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 if backend_fw == "torch": assume(not test_flags.as_variable[0]) - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - helpers.test_frontend_function( + np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, @@ -110,71 +127,75 @@ def test_numpy_prod( x=x[0], axis=axis, dtype=dtype, - keepdims=keep_dims, - initial=initial, - where=where, ) -# cumsum +# diff @handle_frontend_test( - fn_tree="numpy.cumsum", - dtype_and_x_axis_dtype=_get_castable_dtypes_values(), + fn_tree="numpy.diff", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), ) -def test_numpy_cumsum( - dtype_and_x_axis_dtype, +def test_numpy_diff( + dtype_x_axis, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtypes, x, axis, dtype = dtype_and_x_axis_dtype - # ToDo: set as_variable_flags as the parameter generated by test_cumprod once - # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 - if backend_fw == "torch": - assume(not test_flags.as_variable[0]) + input_dtype, x, axis = dtype_x_axis np_frontend_helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, x=x[0], axis=axis, - dtype=dtype, ) -# cumprod +# ediff1d @handle_frontend_test( - fn_tree="numpy.cumprod", - dtype_x_axis_dtypes=_get_castable_dtypes_values(), + fn_tree="numpy.ediff1d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, max_num_dims=1 + ), + to_end=st.one_of( + st.integers(-1, 10), st.lists(st.integers(-1, 10), min_size=1, max_size=10) + ), + to_begin=st.one_of( + st.integers(-1, 10), st.lists(st.integers(-1, 10), min_size=1, max_size=10) + ), ) -def test_numpy_cumprod( - dtype_x_axis_dtypes, +def test_numpy_ediff1d( + *, + dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, + to_end, + to_begin, ): - input_dtypes, x, axis, dtype = dtype_x_axis_dtypes - # ToDo: set as_variable_flags as the parameter generated by test_cumprod once - # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 - if backend_fw == "torch": - assume(not test_flags.as_variable[0]) - np_frontend_helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - axis=axis, - dtype=dtype, + test_flags=test_flags, + ary=x[0], + to_end=to_end, + to_begin=to_begin, ) @@ -318,101 +339,88 @@ def test_numpy_nansum( ) -# diff +# prod @handle_frontend_test( - fn_tree="numpy.diff", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), + fn_tree="numpy.prod", + 
dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), + keep_dims=st.booleans(), + initial=st.one_of(st.floats(min_value=-100, max_value=100)), ) -def test_numpy_diff( - dtype_x_axis, +def test_numpy_prod( + dtype_x_axis_dtype, + keep_dims, + initial, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x, axis = dtype_x_axis - np_frontend_helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, + input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype + if backend_fw == "torch": + assume(not test_flags.as_variable[0]) + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, test_flags=test_flags, + ) + helpers.test_frontend_function( + input_dtypes=input_dtypes, backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], axis=axis, + dtype=dtype, + keepdims=keep_dims, + initial=initial, + where=where, ) -# ediff1d +# sum @handle_frontend_test( - fn_tree="numpy.ediff1d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, max_num_dims=1 - ), - to_end=st.one_of( - st.integers(-1, 10), st.lists(st.integers(-1, 10), min_size=1, max_size=10) - ), - to_begin=st.one_of( - st.integers(-1, 10), st.lists(st.integers(-1, 10), min_size=1, max_size=10) - ), + fn_tree="numpy.sum", + dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), + keep_dims=st.booleans(), + initial=st.one_of(st.floats(min_value=-100, max_value=100)), ) -def test_numpy_ediff1d( - *, - dtype_and_x, - on_device, - fn_tree, +def test_numpy_sum( + dtype_x_axis_dtype, + keep_dims, + initial, frontend, test_flags, + fn_tree, backend_fw, - to_end, - to_begin, + on_device, ): - input_dtype, x = dtype_and_x + input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype + if backend_fw == "torch": + assume(not test_flags.as_variable[0]) + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_flags=test_flags, - ary=x[0], - to_end=to_end, - to_begin=to_begin, + x=x[0], + axis=axis, + dtype=dtype, + keepdims=keep_dims, + initial=initial, + where=where, ) -# trapz -@st.composite -def _either_x_dx(draw): - rand = (draw(st.integers(min_value=0, max_value=1)),) - if rand == 0: - either_x_dx = draw( - helpers.dtype_and_values( - avaliable_dtypes=st.shared( - helpers.get_dtypes("float"), key="trapz_dtype" - ), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ) - ) - return rand, either_x_dx - else: - either_x_dx = draw( - st.floats(min_value=-10, max_value=10), - ) - return rand, either_x_dx - - @handle_frontend_test( fn_tree="numpy.trapz", dtype_values_axis=helpers.dtype_values_axis( diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_trigonometric_functions.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_trigonometric_functions.py index fd895bc87d69a..316962589a3b1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_trigonometric_functions.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_trigonometric_functions.py @@ -4,9 +4,9 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# cos +# arccos @handle_frontend_test( - fn_tree="numpy.cos", + fn_tree="numpy.arccos", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -16,10 +16,10 @@ ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="cos" + fn_name="arccos" ), ) -def test_numpy_cos( +def test_numpy_arccos( dtypes_values_casting, where, frontend, @@ -41,7 +41,8 @@ def test_numpy_cos( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-03, + rtol=1e-2, + atol=1e-2, x=x[0], out=None, where=where, @@ -52,9 +53,9 @@ def test_numpy_cos( ) -# tan +# arccosh @handle_frontend_test( - fn_tree="numpy.tan", + fn_tree="numpy.arccosh", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -64,16 +65,16 @@ def test_numpy_cos( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="tan" + fn_name="arccosh" ), ) -def test_numpy_tan( +def test_numpy_arccosh( dtypes_values_casting, where, frontend, + backend_fw, test_flags, fn_tree, - backend_fw, on_device, ): input_dtypes, x, casting, dtype = dtypes_values_casting @@ -84,13 +85,12 @@ def test_numpy_tan( ) np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - atol=1e-02, + atol=1e-2, x=x[0], out=None, where=where, @@ -150,9 +150,9 @@ def test_numpy_arcsin( ) -# arccos +# arctan @handle_frontend_test( - fn_tree="numpy.arccos", + fn_tree="numpy.arctan", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -162,10 +162,10 @@ def test_numpy_arcsin( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="arccos" + fn_name="arctan" ), ) -def test_numpy_arccos( +def test_numpy_arctan( dtypes_values_casting, where, frontend, @@ -187,8 +187,7 @@ def test_numpy_arccos( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, + atol=1e-3, x=x[0], out=None, where=where, @@ -199,9 +198,9 @@ def test_numpy_arccos( ) -# arctan +# cos @handle_frontend_test( - fn_tree="numpy.arctan", + fn_tree="numpy.cos", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -211,10 +210,10 @@ def test_numpy_arccos( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="arctan" + fn_name="cos" ), ) -def test_numpy_arctan( +def test_numpy_cos( dtypes_values_casting, where, frontend, @@ -236,7 +235,7 @@ def test_numpy_arctan( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-3, + atol=1e-03, x=x[0], out=None, where=where, @@ -247,9 +246,9 @@ def test_numpy_arctan( ) -# arccosh +# deg2rad @handle_frontend_test( - fn_tree="numpy.arccosh", + fn_tree="numpy.deg2rad", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -259,16 +258,16 @@ def test_numpy_arctan( ), where=np_frontend_helpers.where(), 
number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="arccosh" + fn_name="deg2rad" ), ) -def test_numpy_arccosh( +def test_numpy_deg2rad( dtypes_values_casting, where, frontend, - backend_fw, test_flags, fn_tree, + backend_fw, on_device, ): input_dtypes, x, casting, dtype = dtypes_values_casting @@ -279,11 +278,12 @@ def test_numpy_arccosh( ) np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-2, atol=1e-2, x=x[0], out=None, @@ -295,9 +295,9 @@ def test_numpy_arccosh( ) -# rad2deg +# degrees @handle_frontend_test( - fn_tree="numpy.rad2deg", + fn_tree="numpy.degrees", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -307,10 +307,10 @@ def test_numpy_arccosh( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="rad2deg" + fn_name="degrees" ), ) -def test_numpy_rad2deg( +def test_numpy_degrees( dtypes_values_casting, where, frontend, @@ -333,6 +333,7 @@ def test_numpy_rad2deg( fn_tree=fn_tree, on_device=on_device, rtol=1e-2, + atol=1e-2, x=x[0], out=None, where=where, @@ -343,9 +344,9 @@ def test_numpy_rad2deg( ) -# deg2rad +# rad2deg @handle_frontend_test( - fn_tree="numpy.deg2rad", + fn_tree="numpy.rad2deg", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -355,10 +356,10 @@ def test_numpy_rad2deg( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="deg2rad" + fn_name="rad2deg" ), ) -def test_numpy_deg2rad( +def test_numpy_rad2deg( dtypes_values_casting, where, frontend, @@ -381,7 +382,6 @@ def test_numpy_deg2rad( fn_tree=fn_tree, on_device=on_device, rtol=1e-2, - atol=1e-2, x=x[0], out=None, where=where, @@ -392,9 +392,9 @@ def test_numpy_deg2rad( ) -# degrees +# tan @handle_frontend_test( - fn_tree="numpy.degrees", + fn_tree="numpy.tan", dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( arr_func=[ lambda: helpers.dtype_and_values( @@ -404,10 +404,10 @@ def test_numpy_deg2rad( ), where=np_frontend_helpers.where(), number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc( - fn_name="degrees" + fn_name="tan" ), ) -def test_numpy_degrees( +def test_numpy_tan( dtypes_values_casting, where, frontend, @@ -429,8 +429,8 @@ def test_numpy_degrees( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, + rtol=1e-02, + atol=1e-02, x=x[0], out=None, where=where, diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_matrix/test_methods.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_matrix/test_methods.py index bb578dec5337f..ad62b2590adab 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_matrix/test_methods.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_matrix/test_methods.py @@ -13,16 +13,8 @@ CLASS_TREE = "ivy.functional.frontends.numpy.matrix" -def _to_string_matrix(num_matrix): - str_matrix = "" - for i, row in enumerate(num_matrix): - for j, elem in enumerate(row): - str_matrix += str(elem) - if j < num_matrix.shape[1] - 1: - str_matrix += " " - elif i < num_matrix.shape[0] - 1: - str_matrix += "; " - return str_matrix +# --- Helpers --- # +# --------------- # def _get_x_matrix(x, to_str): @@ -51,6 +43,22 @@ def 
_property_helper(draw): return data, data_gt +def _to_string_matrix(num_matrix): + str_matrix = "" + for i, row in enumerate(num_matrix): + for j, elem in enumerate(row): + str_matrix += str(elem) + if j < num_matrix.shape[1] - 1: + str_matrix += " " + elif i < num_matrix.shape[0] - 1: + str_matrix += "; " + return str_matrix + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="numpy.add", # dummy fn_tree matrices=_property_helper(), @@ -140,38 +148,11 @@ def test_numpy_dtype(matrices): ) -@handle_frontend_test( - fn_tree="numpy.add", # dummy fn_tree - matrices=_property_helper(), -) -def test_numpy_ndim(matrices): - data, data_gt = matrices - ivy.utils.assertions.check_equal(data.ndim, data_gt.ndim, as_array=False) - - -@handle_frontend_test( - fn_tree="numpy.add", # dummy fn_tree - matrices=_property_helper(), -) -def test_numpy_shape(matrices): - data, data_gt = matrices - ivy.utils.assertions.check_equal(data.shape, data_gt.shape, as_array=False) - - -@handle_frontend_test( - fn_tree="numpy.add", # dummy fn_tree - matrices=_property_helper(), -) -def test_numpy_size(matrices): - data, data_gt = matrices - ivy.utils.assertions.check_equal(data.size, data_gt.size, as_array=False) - - -# argmax +# any @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.matrix", - method_name="argmax", + method_name="any", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=2, @@ -183,7 +164,7 @@ def test_numpy_size(matrices): ), to_str=st.booleans(), ) -def test_numpy_matrix_argmax( +def test_numpy_matrix_any( dtype_x_axis, to_str, init_flags, @@ -206,7 +187,7 @@ def test_numpy_matrix_argmax( "data": x, "dtype": input_dtype[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={ "axis": axis, }, @@ -216,11 +197,11 @@ def test_numpy_matrix_argmax( ) -# any +# argmax @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.matrix", - method_name="any", + method_name="argmax", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=2, @@ -232,7 +213,7 @@ def test_numpy_matrix_argmax( ), to_str=st.booleans(), ) -def test_numpy_matrix_any( +def test_numpy_matrix_argmax( dtype_x_axis, to_str, init_flags, @@ -255,7 +236,7 @@ def test_numpy_matrix_any( "data": x, "dtype": input_dtype[0], }, - method_input_dtypes=[], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ "axis": axis, }, @@ -263,3 +244,30 @@ def test_numpy_matrix_any( frontend_method_data=frontend_method_data, on_device=on_device, ) + + +@handle_frontend_test( + fn_tree="numpy.add", # dummy fn_tree + matrices=_property_helper(), +) +def test_numpy_ndim(matrices): + data, data_gt = matrices + ivy.utils.assertions.check_equal(data.ndim, data_gt.ndim, as_array=False) + + +@handle_frontend_test( + fn_tree="numpy.add", # dummy fn_tree + matrices=_property_helper(), +) +def test_numpy_shape(matrices): + data, data_gt = matrices + ivy.utils.assertions.check_equal(data.shape, data_gt.shape, as_array=False) + + +@handle_frontend_test( + fn_tree="numpy.add", # dummy fn_tree + matrices=_property_helper(), +) +def test_numpy_size(matrices): + data, data_gt = matrices + ivy.utils.assertions.check_equal(data.size, data_gt.size, as_array=False) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py index 482eae0ee5897..064d2135f358f 100644 --- 
a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py @@ -32,93 +32,101 @@ CLASS_TREE = "ivy.functional.frontends.numpy.ndarray" -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ), -) -def test_numpy_ndarray_ivy_array( - dtype_x, - frontend, - backend_fw, -): - dtype, data, shape = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) - x.ivy_array = data[0] - ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=frontend) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend=backend_fw, - ground_truth_backend="numpy", - ) +# --- Helpers --- # +# --------------- # -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ), -) -def test_numpy_ndarray_dtype(dtype_x, backend_fw, frontend): - dtype, data, shape = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) - x.ivy_array = data[0] - ivy_backend.utils.assertions.check_equal( - x.dtype, ivy.Dtype(dtype[0]), as_array=False +# item +@st.composite +def _item_helper(draw): + dtype = draw( + helpers.array_dtypes( + num_arrays=1, + available_dtypes=helpers.get_dtypes("numeric"), ) + ) + shape = draw( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=10, + ) + ) + array = draw( + helpers.array_values( + dtype=dtype[0], + shape=shape, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + ) + ) + index = () + for s in shape: + index += (draw(st.integers(min_value=-s + 1, max_value=s - 1)),) -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ), -) -def test_numpy_ndarray_shape( - dtype_x, - backend_fw, -): - dtype, data, shape = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) - x.ivy_array = data[0] - ivy_backend.utils.assertions.check_equal( - x.shape, ivy.Shape(shape), as_array=False + index_samples = [index, draw(helpers.ints(min_value=0, max_value=array.size - 1))] + + if array.size == 1: + index_samples.append(None) + + sampled_index = draw(st.sampled_from(index_samples)) + + if sampled_index is None: + method_all_as_kwargs_np = {} + num_positional_args = 0 + else: + method_all_as_kwargs_np = {"args": sampled_index} + num_positional_args = 1 + + return dtype, array, method_all_as_kwargs_np, num_positional_args + + +# swapaxes +@st.composite +def dtype_values_and_axes(draw): + dtype, x, x_shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + ret_shape=True, + ) + ) + axis1, axis2 = draw( + helpers.get_axis( + shape=x_shape, + sort_values=False, + unique=True, + min_size=2, + max_size=2, + force_tuple=True, ) + ) + return dtype, x, axis1, axis2 -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ), -) -def test_numpy_ndarray_property_ndim(dtype_x, backend_fw): - dtype, data, shape = dtype_x - with 
BackendHandler.update_backend(backend_fw) as ivy_backend: - x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) - x.ivy_array = data[0] - ivy_backend.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) +@st.composite +def dtypes_x_reshape(draw): + dtypes, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + ) + ) + shape = draw(helpers.reshape_shapes(shape=np.array(x).shape)) + return dtypes, x, shape -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ), -) -def test_numpy_ndarray_size( - dtype_x, -): - dtype, data, shape = dtype_x - x = ndarray(shape, dtype[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal(x.size, data[0].size, as_array=False) +# --- Main --- # +# ------------ # @given( @@ -151,46 +159,55 @@ def test_numpy_T( ) -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric", prune_function=False), - num_arrays=1, - ret_shape=True, - ) +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="numpy.array", + method_name="__abs__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + allow_inf=False, + large_abs_safety_factor=4, + safety_factor_scale="linear", + ), ) -def test_numpy_ndarray_flat(dtype_x, backend_fw): - dtype, data, shape = dtype_x - - with BackendHandler.update_backend(backend_fw) as ivy_backend: - x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) - x.ivy_array = data[0] +def test_numpy___abs__( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + backend_fw, + frontend, + on_device, +): + input_dtypes, x = dtype_and_x - flat_ivy = x.flat - flat_ivy = flat_ivy.ivy_array.to_numpy() - flat_generated = ivy_backend.to_numpy(data[0]).flatten() - ivy_backend.utils.assertions.check_equal( - flat_ivy, flat_generated, as_array=True - ) + helpers.test_frontend_method( + init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={}, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + ) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="astype", - dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( - arr_func=[ - lambda: helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ) - ], + method_name="__add__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), - order=st.sampled_from(["C", "F", "A", "K"]), - copy=st.booleans(), ) -def test_numpy_ndarray_astype( - dtypes_values_casting, - order, - copy, +def test_numpy___add__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -198,19 +215,17 @@ def test_numpy_ndarray_astype( frontend, on_device, ): - input_dtypes, x, casting, dtype = dtypes_values_casting + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "dtype": dtype if dtype else input_dtypes[0], - "order": order, 
- "casting": casting, - "copy": copy, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -223,21 +238,14 @@ def test_numpy_ndarray_astype( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="argmax", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=st.one_of( - helpers.get_dtypes("numeric"), - ), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="__and__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=("bool",), + num_arrays=2, ), - keep_dims=st.booleans(), ) -def test_numpy_ndarray_argmax( - dtype_x_axis, - keep_dims, +def test_numpy___and__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -245,17 +253,17 @@ def test_numpy_ndarray_argmax( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keep_dims, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -265,34 +273,55 @@ def test_numpy_ndarray_argmax( ) -@st.composite -def dtypes_x_reshape(draw): - dtypes, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - ) +# __array__ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="numpy.array", + method_name="__array__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_numpy___array__( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + backend_fw, + frontend, + on_device, +): + input_dtypes, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "dtype": np.dtype(input_dtypes[0]), + }, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + frontend_method_data=frontend_method_data, + on_device=on_device, ) - shape = draw(helpers.reshape_shapes(shape=np.array(x).shape)) - return dtypes, x, shape +# __array_wrap__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="reshape", - dtypes_x_shape=dtypes_x_reshape(), - order=st.sampled_from(["C", "F", "A"]), + method_name="__array_wrap__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + ), ) -def test_numpy_ndarray_reshape( - dtypes_x_shape, - order, +def test_numpy___array_wrap__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -300,22 +329,22 @@ def test_numpy_ndarray_reshape( frontend, on_device, ): - input_dtypes, x, shape = dtypes_x_shape + input_dtypes, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=[], + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "newshape": shape, - "order": order, + "array": x[1], + "context": None, }, - frontend=frontend, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, + frontend_method_data=frontend_method_data, 
on_device=on_device, ) @@ -323,16 +352,14 @@ def test_numpy_ndarray_reshape( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="transpose", - array_and_axes=np_frontend_helpers._array_and_axes_permute_helper( - min_num_dims=2, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, + method_name="__bool__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + max_dim_size=1, ), ) -def test_numpy_ndarray_transpose( - array_and_axes, +def test_numpy___bool__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -340,17 +367,16 @@ def test_numpy_ndarray_transpose( frontend, on_device, ): - array, input_dtypes, axes = array_and_axes + input_dtypes, x = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": np.array(array), + "object": x[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "axes": axes, - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -359,56 +385,71 @@ def test_numpy_ndarray_transpose( ) -# swapaxes -@st.composite -def dtype_values_and_axes(draw): - dtype, x, x_shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - ret_shape=True, - ) - ) - axis1, axis2 = draw( - helpers.get_axis( - shape=x_shape, - sort_values=False, - unique=True, - min_size=2, - max_size=2, - force_tuple=True, - ) +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="numpy.array", + method_name="__complex__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_dim_size=1, + max_dim_size=1, + ), +) +def test_numpy___complex__( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + backend_fw, + frontend, + on_device, +): + input_dtypes, xs = dtype_and_x + + helpers.test_frontend_method( + init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, + init_all_as_kwargs_np={ + "object": xs[0], + }, + method_all_as_kwargs_np={}, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, ) - return dtype, x, axis1, axis2 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="swapaxes", - dtype_x_and_axes=dtype_values_and_axes(), + method_name="__contains__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), ) -def test_numpy_ndarray_swapaxes( - dtype_x_and_axes, - frontend, +def test_numpy___contains__( + dtype_and_x, frontend_method_data, init_flags, method_flags, backend_fw, + frontend, on_device, ): - input_dtypes, x, axis1, axis2 = dtype_x_and_axes + input_dtypes, xs = dtype_and_x + key = np.asarray(xs[0].reshape(-1)[0]) helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_all_as_kwargs_np={ - "axis1": axis1, - "axis2": axis2, + "key": key, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -418,27 +459,17 @@ def test_numpy_ndarray_swapaxes( ) -# any @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="any", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), + method_name="__copy__", + 
dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - valid_axis=True, - force_int_axis=True, - allow_neg_axes=True, ), - keepdims=st.booleans(), - where=np_frontend_helpers.where(), ) -def test_numpy_ndarray_any( - dtype_x_axis, - keepdims, - where, +def test_numpy___copy__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -446,31 +477,16 @@ def test_numpy_ndarray_any( frontend, on_device, ): - init_input_dtypes, x, axis = dtype_x_axis - ( - where, - method_input_dtypes, - method_flags, - ) = np_frontend_helpers.handle_where_and_array_bools( - where=[where[0][0]] if isinstance(where, list) else where, - input_dtype=init_input_dtypes, - test_flags=method_flags, - ) + input_dtypes, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=init_input_dtypes, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=method_input_dtypes[1:], - method_all_as_kwargs_np={ - "axis": axis, - "dtype": bool, - "out": None, - "keepdims": keepdims, - "where": where, - }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -482,23 +498,14 @@ def test_numpy_ndarray_any( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="all", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid", full=True), + method_name="__deepcopy__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - valid_axis=True, - force_int_axis=True, - allow_neg_axes=True, ), - keepdims=st.booleans(), - where=np_frontend_helpers.where(), ) -def test_numpy_ndarray_all( - dtype_x_axis, - keepdims, - where, +def test_numpy___deepcopy__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -506,30 +513,17 @@ def test_numpy_ndarray_all( frontend, on_device, ): - init_input_dtypes, x, axis = dtype_x_axis - ( - where, - method_input_dtypes, - method_flags, - ) = np_frontend_helpers.handle_where_and_array_bools( - where=[where[0][0]] if isinstance(where, list) else where, - input_dtype=init_input_dtypes, - test_flags=method_flags, - ) + input_dtypes, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=init_input_dtypes, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=method_input_dtypes[1:], + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, - "dtype": bool, - "out": None, - "keepdims": keepdims, - "where": where, + "memo": {}, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -542,17 +536,14 @@ def test_numpy_ndarray_all( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="argsort", - dtype_x_axis=helpers.dtype_values_axis( + method_name="__eq__", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + num_arrays=2, ), ) -def test_numpy_ndarray_argsort( - dtype_x_axis, +def test_numpy___eq__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -560,23 +551,22 @@ def test_numpy_ndarray_argsort( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_x + 
helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], + }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - method_all_as_kwargs_np={ - "axis": axis, - "kind": None, - "order": None, - }, on_device=on_device, ) @@ -584,17 +574,14 @@ def test_numpy_ndarray_argsort( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="mean", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="__float__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + max_num_dims=0, ), ) -def test_numpy_ndarray_mean( - dtype_x_axis, +def test_numpy___float__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -602,45 +589,41 @@ def test_numpy_ndarray_mean( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_x + # Numpy doesn't support complex to float conversion + assume(not np.issubdtype(input_dtypes[0], np.complexfloating)) helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], - }, - method_all_as_kwargs_np={ - "axis": axis, - "dtype": "float64", - "out": None, + "object": xs[0], }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - rtol_=1e-2, - atol_=1e-2, on_device=on_device, ) +# __floordiv__ test @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="min", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="__floordiv__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=4, + safety_factor_scale="linear", + shared_dtype=True, ), - keepdims=st.booleans(), ) -def test_numpy_ndarray_min( - dtype_x_axis, - keepdims, +def test_numpy___floordiv__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -648,40 +631,38 @@ def test_numpy_ndarray_min( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis - + input_dtypes, xs = dtype_and_x + assume(not np.any(np.isclose(xs[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, + "value": xs[1], }, - frontend=frontend, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, + frontend=frontend, + atol_=1, on_device=on_device, ) -# prod @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="prod", - dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), - keep_dims=st.booleans(), - initial=st.one_of(st.floats(min_value=-100, max_value=100)), + method_name="__ge__", + dtype_and_x=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + ), ) -def test_numpy_ndarray_prod( - dtype_x_axis_dtype, - keep_dims, - initial, +def test_numpy___ge__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -689,33 +670,17 @@ def test_numpy_ndarray_prod( frontend, on_device, ): - input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype - if ivy.current_backend_str() == "torch": - assume(not method_flags.as_variable[0]) + input_dtypes, xs = dtype_and_x - ( - where, - input_dtypes, - method_flags, - ) = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=method_flags, - ) - where = ivy.array(where, dtype="bool") helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, - "dtype": dtype, - "keepdims": keep_dims, - "initial": initial, - "where": where, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -725,75 +690,54 @@ def test_numpy_ndarray_prod( ) -# sum @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="sum", - dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), - keep_dims=st.booleans(), - initial=st.one_of(st.floats(min_value=-100, max_value=100)), + method_name="__gt__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + ), ) -def test_numpy_ndarray_sum( - dtype_x_axis_dtype, - keep_dims, - initial, +def test_numpy___gt__( + dtype_and_x, frontend_method_data, init_flags, method_flags, + backend_fw, frontend, on_device, - backend_fw, -): - input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype - if ivy.current_backend_str() == "torch": - assume(not method_flags.as_variable[0]) - - where, input_dtypes, method_flags = ( - np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=method_flags, - ) - ) - where = ivy.array(where, dtype="bool") +): + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, - "dtype": dtype, - "keepdims": keep_dims, - "initial": initial, - "where": where, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, - backend_to_test=backend_fw, ) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="argmin", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, + method_name="__iadd__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), - keepdims=st.booleans(), ) -def test_numpy_ndarray_argmin( - dtype_x_axis, - keepdims, +def test_numpy___iadd__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -801,17 +745,17 @@ def test_numpy_ndarray_argmin( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], 
}, - method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -824,11 +768,14 @@ def test_numpy_ndarray_argmin( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="clip", - input_and_ranges=_get_clip_inputs(), + method_name="__iand__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=("bool",), + num_arrays=2, + ), ) -def test_numpy_ndarray_clip( - input_and_ranges, +def test_numpy___iand__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -836,17 +783,17 @@ def test_numpy_ndarray_clip( frontend, on_device, ): - input_dtypes, x, min, max = input_and_ranges + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_all_as_kwargs_np={ - "min": min, - "max": max, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -859,50 +806,42 @@ def test_numpy_ndarray_clip( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="compress", - dtype_arr_ax=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=10, - max_dim_size=100, - valid_axis=True, - force_int_axis=True, - ), - condition=helpers.array_values( - dtype=helpers.get_dtypes("bool"), - shape=helpers.get_shape( - min_num_dims=1, max_num_dims=1, min_dim_size=1, max_dim_size=5 - ), + method_name="__ifloordiv__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=4, + safety_factor_scale="linear", + shared_dtype=True, ), ) -def test_numpy_ndarray_compress( - dtype_arr_ax, - condition, +def test_numpy___ifloordiv__( + dtype_and_x, frontend_method_data, init_flags, method_flags, - frontend, backend_fw, + frontend, on_device, ): - input_dtypes, arr, ax = dtype_arr_ax + input_dtypes, xs = dtype_and_x + assume(not np.any(np.isclose(xs[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": arr[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "condition": condition, - "axis": ax, - "out": None, + "value": xs[1], }, - frontend=frontend, - backend_to_test=backend_fw, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, + frontend=frontend, + atol_=1, on_device=on_device, ) @@ -910,29 +849,34 @@ def test_numpy_ndarray_compress( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="conj", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("real_and_complex"), + method_name="__imod__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=0, + exclude_min=True, ), ) -def test_numpy_ndarray_conj( +def test_numpy___imod__( dtype_and_x, - on_device, - frontend, frontend_method_data, init_flags, method_flags, backend_fw, + frontend, + on_device, ): - input_dtype, x, axis = dtype_and_x + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + 
method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], + }, + method_all_as_kwargs_np={ + "value": xs[1], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -944,19 +888,13 @@ def test_numpy_ndarray_conj( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="max", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="__imul__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), - keepdims=st.booleans(), ) -def test_numpy_ndarray_max( - dtype_x_axis, - keepdims, +def test_numpy___imul__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -964,18 +902,17 @@ def test_numpy_ndarray_max( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, - method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, - "keepdims": keepdims, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -988,19 +925,16 @@ def test_numpy_ndarray_max( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="cumprod", - dtype_x_axis=helpers.dtype_values_axis( + method_name="__int__", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + max_num_dims=0, + min_value=-1e15, + max_value=1e15, ), - dtype=helpers.get_dtypes("float", full=False, none=True), ) -def test_numpy_ndarray_cumprod( - dtype_x_axis, - dtype, +def test_numpy___int__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1008,20 +942,17 @@ def test_numpy_ndarray_cumprod( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis - + input_dtypes, xs = dtype_and_x + # Numpy doesn't support complex to int conversion + assume(not np.issubdtype(input_dtypes[0], np.complexfloating)) helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "object": x[0], - }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "axis": axis, - "dtype": dtype[0], - "out": None, + init_all_as_kwargs_np={ + "object": xs[0], }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1033,31 +964,31 @@ def test_numpy_ndarray_cumprod( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="cumsum", - dtype_x_axis_dtype=_get_castable_dtypes_values(), + method_name="__invert__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes(kind="integer"), + num_arrays=1, + ), ) -def test_numpy_ndarray_cumsum( - dtype_x_axis_dtype, +def test_numpy___invert__( + dtype_and_x, frontend_method_data, init_flags, method_flags, - backend_fw, frontend, + backend_fw, on_device, ): - input_dtypes, x, axis, dtype = dtype_x_axis_dtype + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, 
method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "axis": axis, - "dtype": dtype, - "out": None, - }, + backend_to_test=backend_fw, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1069,10 +1000,13 @@ def test_numpy_ndarray_cumsum( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="dot", - dtype_and_x=np_frontend_helpers._get_dtype_input_and_vectors(), + method_name="__ior__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=("bool",), + num_arrays=2, + ), ) -def test_numpy_ndarray_dot( +def test_numpy___ior__( dtype_and_x, frontend_method_data, init_flags, @@ -1081,21 +1015,22 @@ def test_numpy_ndarray_dot( frontend, on_device, ): - input_dtype, x, other = dtype_and_x + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x, + "object": xs[0], }, - method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "b": other, + "value": xs[1], }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) @@ -1103,19 +1038,15 @@ def test_numpy_ndarray_dot( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="diagonal", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - min_axes_size=2, - max_axes_size=2, - valid_axis=True, + method_name="__ipow__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), - offset=st.integers(min_value=-2, max_value=2), + power=helpers.ints(min_value=1, max_value=3), ) -def test_numpy_ndarray_diagonal( - dtype_x_axis, - offset, +def test_numpy___ipow__( + dtype_and_x, + power, frontend_method_data, init_flags, method_flags, @@ -1123,19 +1054,17 @@ def test_numpy_ndarray_diagonal( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, - method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis1": axis[0], - "axis2": axis[1], - "offset": offset, + "value": power, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1148,17 +1077,13 @@ def test_numpy_ndarray_diagonal( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="sort", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="__isub__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), ) -def test_numpy_ndarray_sort( - dtype_x_axis, +def test_numpy___isub__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1166,46 +1091,38 @@ def test_numpy_ndarray_sort( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, xs = dtype_and_x - ret, frontend_ret = helpers.test_frontend_method( + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, - 
method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - test_values=False, on_device=on_device, ) - frontend_ret = np.sort(x[0], axis=axis) - assert_all_close( - ret_np=ret, - ret_from_gt_np=frontend_ret, - rtol=1e-2, - atol=1e-2, - backend=backend_fw, - ground_truth_backend="numpy", - ) +# __itruediv__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="copy", + method_name="__itruediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), ) -def test_numpy_ndarray_copy( +def test_numpy___itruediv__( dtype_and_x, frontend_method_data, init_flags, @@ -1214,20 +1131,22 @@ def test_numpy_ndarray_copy( frontend, on_device, ): - input_dtypes, x = dtype_and_x + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, - frontend=frontend, + method_all_as_kwargs_np={ + "value": xs[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -1235,13 +1154,14 @@ def test_numpy_ndarray_copy( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="nonzero", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="__ixor__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=("bool",), + num_arrays=2, ), ) -def test_numpy_ndarray_nonzero( - dtype_and_a, +def test_numpy___ixor__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1249,16 +1169,18 @@ def test_numpy_ndarray_nonzero( frontend, on_device, ): - input_dtypes, a = dtype_and_a + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": a[0], + "object": xs[0], + }, + method_all_as_kwargs_np={ + "value": xs[1], }, - method_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1270,13 +1192,14 @@ def test_numpy_ndarray_nonzero( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="ravel", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="__le__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_numpy_ndarray_ravel( - dtype_and_a, +def test_numpy___le__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1284,16 +1207,18 @@ def test_numpy_ndarray_ravel( frontend, on_device, ): - input_dtypes, a = dtype_and_a + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": a[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "value": xs[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, 
init_flags=init_flags, @@ -1302,22 +1227,19 @@ def test_numpy_ndarray_ravel( ) +# __len__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="repeat", + method_name="__len__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, - min_dim_size=2, + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, ), - repeats=helpers.ints(min_value=2, max_value=5), - axis=helpers.ints(min_value=-1, max_value=1), ) -def test_numpy_ndarray_repeat( +def test_numpy___len__( dtype_and_x, - repeats, - axis, frontend_method_data, init_flags, method_flags, @@ -1326,7 +1248,6 @@ def test_numpy_ndarray_repeat( on_device, ): input_dtypes, x = dtype_and_x - helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -1334,14 +1255,11 @@ def test_numpy_ndarray_repeat( "object": x[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "repeats": repeats, - "axis": axis, - }, - frontend=frontend, - frontend_method_data=frontend_method_data, + method_all_as_kwargs_np={}, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, + frontend_method_data=frontend_method_data, on_device=on_device, ) @@ -1349,18 +1267,14 @@ def test_numpy_ndarray_repeat( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="searchsorted", - dtype_x_v=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), - min_num_dims=1, - max_num_dims=1, + method_name="__lt__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, ), - side=st.sampled_from(["left", "right"]), ) -def test_numpy_ndarray_searchsorted( - dtype_x_v, - side, +def test_numpy___lt__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1368,19 +1282,17 @@ def test_numpy_ndarray_searchsorted( frontend, on_device, ): - input_dtypes, xs = dtype_x_v + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": xs[0], }, method_input_dtypes=input_dtypes, - backend_to_test=backend_fw, method_all_as_kwargs_np={ - "v": xs[1], - "side": side, - "sorter": np.argsort(xs[0]), + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1393,16 +1305,13 @@ def test_numpy_ndarray_searchsorted( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="squeeze", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - axis=_squeeze_helper(), + method_name="__matmul__", + x=_get_first_matrix_and_dtype(), + y=_get_second_matrix_and_dtype(), ) -def test_numpy_ndarray_squeeze( - dtype_and_x, - axis, +def test_numpy___matmul__( + x, + y, frontend_method_data, init_flags, method_flags, @@ -1410,17 +1319,19 @@ def test_numpy_ndarray_squeeze( frontend, on_device, ): - input_dtype, x = dtype_and_x + dtype1, x1 = x + dtype2, x2 = y + input_dtypes = dtype1 + dtype2 helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": x1, }, - method_input_dtypes=input_dtype, + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, + "value": x2, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1430,23 +1341,23 @@ 
def test_numpy_ndarray_squeeze( ) +# mod @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="std", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - max_value=100, - valid_axis=True, - force_int_axis=True, + method_name="__mod__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + min_value=0, + exclude_min=True, ), - keepdims=st.booleans(), - where=np_frontend_helpers.where(), ) -def test_numpy_ndarray_std( - dtype_x_axis, - keepdims, - where, +def test_numpy___mod__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1454,51 +1365,38 @@ def test_numpy_ndarray_std( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis - ( - where, - input_dtypes, - method_flags, - ) = np_frontend_helpers.handle_where_and_array_bools( - where=[where[0][0]] if isinstance(where, list) else where, - input_dtype=input_dtypes, - test_flags=method_flags, - ) + input_dtypes, xs = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "data": x[0], + "object": xs[0], }, method_all_as_kwargs_np={ - "axis": axis, - "out": None, - "ddof": 0, - "keepdims": keepdims, - "where": where, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, + rtol_=1e-5, + atol_=1e-5, ) -# fill @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="fill", + method_name="__mul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), - num=st.integers(min_value=1, max_value=10) | st.floats(min_value=1, max_value=10), ) -def test_numpy_ndarray_fill( +def test_numpy___mul__( dtype_and_x, - num, frontend_method_data, init_flags, method_flags, @@ -1506,16 +1404,17 @@ def test_numpy_ndarray_fill( frontend, on_device, ): - input_dtypes, x = dtype_and_x + input_dtypes, xs = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, - method_input_dtypes=[], + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "num": num, + "value": xs[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1528,12 +1427,13 @@ def test_numpy_ndarray_fill( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__add__", + method_name="__ne__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_numpy___add__( +def test_numpy___ne__( dtype_and_x, frontend_method_data, init_flags, @@ -1565,12 +1465,13 @@ def test_numpy___add__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__radd__", + method_name="__neg__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, ), ) -def test_numpy___radd__( +def test_numpy___neg__( dtype_and_x, frontend_method_data, init_flags, @@ -1579,22 +1480,20 @@ def test_numpy___radd__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x = 
dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "value": xs[1], - }, - method_flags=method_flags, - init_flags=init_flags, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, on_device=on_device, ) @@ -1602,12 +1501,13 @@ def test_numpy___radd__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__sub__", + method_name="__or__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 + available_dtypes=("bool",), + num_arrays=2, ), ) -def test_numpy___sub__( +def test_numpy___or__( dtype_and_x, frontend_method_data, init_flags, @@ -1619,8 +1519,8 @@ def test_numpy___sub__( input_dtypes, xs = dtype_and_x helpers.test_frontend_method( - backend_to_test=backend_fw, init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": xs[0], }, @@ -1639,13 +1539,13 @@ def test_numpy___sub__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__mul__", + method_name="__pos__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + min_num_dims=1, ), ) -def test_numpy___mul__( +def test_numpy___pos__( dtype_and_x, frontend_method_data, init_flags, @@ -1654,18 +1554,16 @@ def test_numpy___mul__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "value": xs[1], - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1677,14 +1575,15 @@ def test_numpy___mul__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__rmul__", + method_name="__pow__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), ), + power=helpers.ints(min_value=1, max_value=3), ) -def test_numpy___rmul__( +def test_numpy___pow__( dtype_and_x, + power, frontend_method_data, init_flags, method_flags, @@ -1702,7 +1601,7 @@ def test_numpy___rmul__( }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "value": power, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1712,21 +1611,15 @@ def test_numpy___rmul__( ) -# __floordiv__ test @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__floordiv__", + method_name="__radd__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=4, - safety_factor_scale="linear", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), ) -def test_numpy___floordiv__( +def test_numpy___radd__( dtype_and_x, frontend_method_data, init_flags, @@ -1736,7 +1629,7 @@ def test_numpy___floordiv__( on_device, ): input_dtypes, xs = dtype_and_x - assume(not np.any(np.isclose(xs[1], 0))) + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -1747,11 
+1640,10 @@ def test_numpy___floordiv__( method_all_as_kwargs_np={ "value": xs[1], }, - init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, + init_flags=init_flags, frontend=frontend, - atol_=1, + frontend_method_data=frontend_method_data, on_device=on_device, ) @@ -1759,13 +1651,13 @@ def test_numpy___floordiv__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__truediv__", + method_name="__rmul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, ), ) -def test_numpy___truediv__( +def test_numpy___rmul__( dtype_and_x, frontend_method_data, init_flags, @@ -1775,7 +1667,6 @@ def test_numpy___truediv__( on_device, ): input_dtypes, xs = dtype_and_x - assume(not np.any(np.isclose(xs[0], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtypes, @@ -1787,10 +1678,10 @@ def test_numpy___truediv__( method_all_as_kwargs_np={ "value": xs[1], }, + frontend=frontend, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, - frontend=frontend, on_device=on_device, ) @@ -1798,38 +1689,48 @@ def test_numpy___truediv__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__rtruediv__", + method_name="__rshift__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, ), ) -def test_numpy___rtruediv__( +def test_numpy___rshift__( dtype_and_x, frontend_method_data, init_flags, method_flags, - backend_fw, frontend, + backend_fw, on_device, ): - input_dtypes, xs = dtype_and_x - assume(not np.any(np.isclose(xs[0], 0))) - + input_dtypes, x = dtype_and_x + max_bits = np.iinfo(input_dtypes[0]).bits + max_shift = max_bits - 1 + x[1] = np.asarray(np.clip(x[1], 0, max_shift), dtype=input_dtypes[1]) + max_value_before_shift = 2 ** (max_bits - x[1]) - 1 + overflow_threshold = 2 ** (max_bits - 1) + x[0] = np.asarray( + np.clip(x[0], None, max_value_before_shift), dtype=input_dtypes[0] + ) + if np.any(x[0] > overflow_threshold): + x[0] = np.clip(x[0], None, overflow_threshold) + if np.any(x[0] < 0): + x[0] = np.abs(x[0]) helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, + backend_to_test=backend_fw, method_all_as_kwargs_np={ - "value": xs[1], + "value": x[1], }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) @@ -1837,15 +1738,14 @@ def test_numpy___rtruediv__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__pow__", + method_name="__rtruediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), - power=helpers.ints(min_value=1, max_value=3), ) -def test_numpy___pow__( +def test_numpy___rtruediv__( dtype_and_x, - power, frontend_method_data, init_flags, method_flags, @@ -1854,6 +1754,7 @@ def test_numpy___pow__( on_device, ): input_dtypes, xs = dtype_and_x + assume(not np.any(np.isclose(xs[0], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtypes, @@ -1863,12 +1764,12 @@ def test_numpy___pow__( }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - 
"value": power, + "value": xs[1], }, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -1876,13 +1777,12 @@ def test_numpy___pow__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__and__", + method_name="__sub__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=("bool",), - num_arrays=2, + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), ) -def test_numpy___and__( +def test_numpy___sub__( dtype_and_x, frontend_method_data, init_flags, @@ -1894,8 +1794,8 @@ def test_numpy___and__( input_dtypes, xs = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + init_input_dtypes=input_dtypes, init_all_as_kwargs_np={ "object": xs[0], }, @@ -1914,13 +1814,13 @@ def test_numpy___and__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__or__", + method_name="__truediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=("bool",), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, ), ) -def test_numpy___or__( +def test_numpy___truediv__( dtype_and_x, frontend_method_data, init_flags, @@ -1930,6 +1830,7 @@ def test_numpy___or__( on_device, ): input_dtypes, xs = dtype_and_x + assume(not np.any(np.isclose(xs[0], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtypes, @@ -1941,10 +1842,10 @@ def test_numpy___or__( method_all_as_kwargs_np={ "value": xs[1], }, - frontend=frontend, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, + frontend=frontend, on_device=on_device, ) @@ -1987,16 +1888,17 @@ def test_numpy___xor__( ) +# __getitem__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__matmul__", - x=_get_first_matrix_and_dtype(), - y=_get_second_matrix_and_dtype(), + method_name="__getitem__", + dtype_x_index=helpers.dtype_array_query( + available_dtypes=helpers.get_dtypes("valid"), + ), ) -def test_numpy___matmul__( - x, - y, +def test_numpy_getitem( + dtype_x_index, frontend_method_data, init_flags, method_flags, @@ -2004,24 +1906,17 @@ def test_numpy___matmul__( frontend, on_device, ): - dtype1, x1 = x - dtype2, x2 = y - input_dtypes = dtype1 + dtype2 - + input_dtype, x, index = dtype_x_index helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=[input_dtype[0]], + init_all_as_kwargs_np={"object": x}, + method_input_dtypes=[*input_dtype[1:]], + method_all_as_kwargs_np={"key": index}, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "object": x1, - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "value": x2, - }, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -2029,31 +1924,46 @@ def test_numpy___matmul__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__copy__", + method_name="__lshift__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + max_dim_size=1, + max_value=2**31 - 1, ), ) -def test_numpy___copy__( +def test_numpy_instance_lshift__( dtype_and_x, frontend_method_data, init_flags, method_flags, - backend_fw, frontend, + backend_fw, on_device, ): 
input_dtypes, x = dtype_and_x - + max_bits = np.iinfo(input_dtypes[0]).bits + max_shift = max_bits - 1 + x[1] = np.asarray(np.clip(x[1], 0, max_shift), dtype=input_dtypes[1]) + max_value_before_shift = 2 ** (max_bits - x[1]) - 1 + overflow_threshold = 2 ** (max_bits - 1) + x[0] = np.asarray( + np.clip(x[0], None, max_value_before_shift), dtype=input_dtypes[0] + ) + if np.any(x[0] > overflow_threshold): + x[0] = np.clip(x[0], None, overflow_threshold) + if np.any(x[0] < 0): + x[0] = np.abs(x[0]) helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, + backend_to_test=backend_fw, + method_all_as_kwargs_np={ + "value": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2065,14 +1975,23 @@ def test_numpy___copy__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__deepcopy__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + method_name="all", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid", full=True), min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + valid_axis=True, + force_int_axis=True, + allow_neg_axes=True, ), + keepdims=st.booleans(), + where=np_frontend_helpers.where(), ) -def test_numpy___deepcopy__( - dtype_and_x, +def test_numpy_ndarray_all( + dtype_x_axis, + keepdims, + where, frontend_method_data, init_flags, method_flags, @@ -2080,17 +1999,30 @@ def test_numpy___deepcopy__( frontend, on_device, ): - input_dtypes, x = dtype_and_x + init_input_dtypes, x, axis = dtype_x_axis + ( + where, + method_input_dtypes, + method_flags, + ) = np_frontend_helpers.handle_where_and_array_bools( + where=[where[0][0]] if isinstance(where, list) else where, + input_dtype=init_input_dtypes, + test_flags=method_flags, + ) helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=init_input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtypes, + method_input_dtypes=method_input_dtypes[1:], method_all_as_kwargs_np={ - "memo": {}, + "axis": axis, + "dtype": bool, + "out": None, + "keepdims": keepdims, + "where": where, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2100,17 +2032,27 @@ def test_numpy___deepcopy__( ) +# any @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__neg__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + method_name="any", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + valid_axis=True, + force_int_axis=True, + allow_neg_axes=True, ), + keepdims=st.booleans(), + where=np_frontend_helpers.where(), ) -def test_numpy___neg__( - dtype_and_x, +def test_numpy_ndarray_any( + dtype_x_axis, + keepdims, + where, frontend_method_data, init_flags, method_flags, @@ -2118,16 +2060,31 @@ def test_numpy___neg__( frontend, on_device, ): - input_dtypes, x = dtype_and_x + init_input_dtypes, x, axis = dtype_x_axis + ( + where, + method_input_dtypes, + method_flags, + ) = np_frontend_helpers.handle_where_and_array_bools( + where=[where[0][0]] if isinstance(where, list) else where, + input_dtype=init_input_dtypes, + test_flags=method_flags, + ) helpers.test_frontend_method( - init_input_dtypes=input_dtypes, 
+ init_input_dtypes=init_input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, + method_input_dtypes=method_input_dtypes[1:], + method_all_as_kwargs_np={ + "axis": axis, + "dtype": bool, + "out": None, + "keepdims": keepdims, + "where": where, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2139,14 +2096,21 @@ def test_numpy___neg__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__pos__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + method_name="argmax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=st.one_of( + helpers.get_dtypes("numeric"), + ), + min_axis=-1, + max_axis=0, min_num_dims=1, + force_int_axis=True, ), + keep_dims=st.booleans(), ) -def test_numpy___pos__( - dtype_and_x, +def test_numpy_ndarray_argmax( + dtype_x_axis, + keep_dims, frontend_method_data, init_flags, method_flags, @@ -2154,8 +2118,7 @@ def test_numpy___pos__( frontend, on_device, ): - input_dtypes, x = dtype_and_x - + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, @@ -2163,7 +2126,10 @@ def test_numpy___pos__( "object": x[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "axis": axis, + "keepdims": keep_dims, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2175,18 +2141,18 @@ def test_numpy___pos__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__ifloordiv__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=4, - safety_factor_scale="linear", - shared_dtype=True, + method_name="argmin", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, ), + keepdims=st.booleans(), ) -def test_numpy___ifloordiv__( - dtype_and_x, +def test_numpy_ndarray_argmin( + dtype_x_axis, + keepdims, frontend_method_data, init_flags, method_flags, @@ -2194,23 +2160,22 @@ def test_numpy___ifloordiv__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - assume(not np.any(np.isclose(xs[1], 0))) + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, + "keepdims": keepdims, }, + frontend=frontend, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, - frontend=frontend, - atol_=1, on_device=on_device, ) @@ -2218,14 +2183,17 @@ def test_numpy___ifloordiv__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__bool__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - max_dim_size=1, + method_name="argsort", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), ) -def test_numpy___bool__( - dtype_and_x, +def test_numpy_ndarray_argsort( + dtype_x_axis, frontend_method_data, init_flags, 
method_flags, @@ -2233,20 +2201,23 @@ def test_numpy___bool__( frontend, on_device, ): - input_dtypes, x = dtype_and_x - + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + method_all_as_kwargs_np={ + "axis": axis, + "kind": None, + "order": None, + }, on_device=on_device, ) @@ -2254,14 +2225,21 @@ def test_numpy___bool__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__ne__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + method_name="astype", + dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype( + arr_func=[ + lambda: helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ) + ], ), + order=st.sampled_from(["C", "F", "A", "K"]), + copy=st.booleans(), ) -def test_numpy___ne__( - dtype_and_x, +def test_numpy_ndarray_astype( + dtypes_values_casting, + order, + copy, frontend_method_data, init_flags, method_flags, @@ -2269,17 +2247,19 @@ def test_numpy___ne__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - + input_dtypes, x, casting, dtype = dtypes_values_casting helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "dtype": dtype if dtype else input_dtypes[0], + "order": order, + "casting": casting, + "copy": copy, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2292,14 +2272,11 @@ def test_numpy___ne__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__eq__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - ), + method_name="clip", + input_and_ranges=_get_clip_inputs(), ) -def test_numpy___eq__( - dtype_and_x, +def test_numpy_ndarray_clip( + input_and_ranges, frontend_method_data, init_flags, method_flags, @@ -2307,17 +2284,17 @@ def test_numpy___eq__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - + input_dtypes, x, min, max = input_and_ranges helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, - method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "min": min, + "max": max, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2330,34 +2307,47 @@ def test_numpy___eq__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__ge__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + method_name="compress", + dtype_arr_ax=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=10, + max_dim_size=100, + valid_axis=True, + force_int_axis=True, + ), + condition=helpers.array_values( + dtype=helpers.get_dtypes("bool"), + shape=helpers.get_shape( + min_num_dims=1, max_num_dims=1, min_dim_size=1, max_dim_size=5 + ), ), ) -def test_numpy___ge__( - dtype_and_x, +def 
test_numpy_ndarray_compress( + dtype_arr_ax, + condition, frontend_method_data, init_flags, method_flags, - backend_fw, frontend, + backend_fw, on_device, ): - input_dtypes, xs = dtype_and_x - + input_dtypes, arr, ax = dtype_arr_ax helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": arr[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "condition": condition, + "axis": ax, + "out": None, }, frontend=frontend, + backend_to_test=backend_fw, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2368,33 +2358,29 @@ def test_numpy___ge__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__gt__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + method_name="conj", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("real_and_complex"), ), ) -def test_numpy___gt__( +def test_numpy_ndarray_conj( dtype_and_x, + on_device, + frontend, frontend_method_data, init_flags, method_flags, backend_fw, - frontend, - on_device, ): - input_dtypes, xs = dtype_and_x - + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "value": xs[1], + "object": x[0], }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2406,13 +2392,13 @@ def test_numpy___gt__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__le__", + method_name="copy", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + min_num_dims=1, ), ) -def test_numpy___le__( +def test_numpy_ndarray_copy( dtype_and_x, frontend_method_data, init_flags, @@ -2421,18 +2407,16 @@ def test_numpy___le__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "value": xs[1], - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2444,14 +2428,19 @@ def test_numpy___le__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__lt__", - dtype_and_x=helpers.dtype_and_values( + method_name="cumprod", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), + dtype=helpers.get_dtypes("float", full=False, none=True), ) -def test_numpy___lt__( - dtype_and_x, +def test_numpy_ndarray_cumprod( + dtype_x_axis, + dtype, frontend_method_data, init_flags, method_flags, @@ -2459,17 +2448,19 @@ def test_numpy___lt__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, 
method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, + "dtype": dtype[0], + "out": None, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2482,16 +2473,11 @@ def test_numpy___lt__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__int__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - max_num_dims=0, - min_value=-1e15, - max_value=1e15, - ), + method_name="cumsum", + dtype_x_axis_dtype=_get_castable_dtypes_values(), ) -def test_numpy___int__( - dtype_and_x, +def test_numpy_ndarray_cumsum( + dtype_x_axis_dtype, frontend_method_data, init_flags, method_flags, @@ -2499,17 +2485,19 @@ def test_numpy___int__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - # Numpy doesn't support complex to int conversion - assume(not np.issubdtype(input_dtypes[0], np.complexfloating)) + input_dtypes, x, axis, dtype = dtype_x_axis_dtype helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], + }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis": axis, + "dtype": dtype, + "out": None, }, - method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2521,14 +2509,19 @@ def test_numpy___int__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__float__", - dtype_and_x=helpers.dtype_and_values( + method_name="diagonal", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - max_num_dims=0, + min_num_dims=2, + min_axes_size=2, + max_axes_size=2, + valid_axis=True, ), + offset=st.integers(min_value=-2, max_value=2), ) -def test_numpy___float__( - dtype_and_x, +def test_numpy_ndarray_diagonal( + dtype_x_axis, + offset, frontend_method_data, init_flags, method_flags, @@ -2536,17 +2529,20 @@ def test_numpy___float__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - # Numpy doesn't support complex to float conversion - assume(not np.issubdtype(input_dtypes[0], np.complexfloating)) + input_dtypes, x, axis = dtype_x_axis + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], + }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis1": axis[0], + "axis2": axis[1], + "offset": offset, }, - method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2558,14 +2554,10 @@ def test_numpy___float__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__complex__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_dim_size=1, - max_dim_size=1, - ), + method_name="dot", + dtype_and_x=np_frontend_helpers._get_dtype_input_and_vectors(), ) -def test_numpy___complex__( +def test_numpy_ndarray_dot( dtype_and_x, frontend_method_data, init_flags, @@ -2574,34 +2566,54 @@ def test_numpy___complex__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - + input_dtype, x, other = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x, + }, + 
method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "b": other, }, - method_all_as_kwargs_np={}, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ), +) +def test_numpy_ndarray_dtype(dtype_x, backend_fw, frontend): + dtype, data, shape = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) + x.ivy_array = data[0] + ivy_backend.utils.assertions.check_equal( + x.dtype, ivy.Dtype(dtype[0]), as_array=False + ) + + +# fill @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__contains__", + method_name="fill", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("numeric"), ), + num=st.integers(min_value=1, max_value=10) | st.floats(min_value=1, max_value=10), ) -def test_numpy___contains__( +def test_numpy_ndarray_fill( dtype_and_x, + num, frontend_method_data, init_flags, method_flags, @@ -2609,17 +2621,16 @@ def test_numpy___contains__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - key = np.asarray(xs[0].reshape(-1)[0]) + input_dtypes, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, + method_input_dtypes=[], method_all_as_kwargs_np={ - "key": key, + "num": num, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2629,35 +2640,51 @@ def test_numpy___contains__( ) +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric", prune_function=False), + num_arrays=1, + ret_shape=True, + ) +) +def test_numpy_ndarray_flat(dtype_x, backend_fw): + dtype, data, shape = dtype_x + + with BackendHandler.update_backend(backend_fw) as ivy_backend: + x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) + x.ivy_array = data[0] + + flat_ivy = x.flat + flat_ivy = flat_ivy.ivy_array.to_numpy() + flat_generated = ivy_backend.to_numpy(data[0]).flatten() + ivy_backend.utils.assertions.check_equal( + flat_ivy, flat_generated, as_array=True + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__iadd__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 - ), + method_name="item", + args_kwargs=_item_helper(), ) -def test_numpy___iadd__( - dtype_and_x, +def test_numpy_ndarray_item( + args_kwargs, frontend_method_data, init_flags, method_flags, - backend_fw, frontend, on_device, + backend_fw, ): - input_dtypes, xs = dtype_and_x - + input_dtype, x, method_all_as_kwargs_np, num_positional_args = args_kwargs + method_flags.num_positional_args = num_positional_args helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={"object": x}, + method_input_dtypes=input_dtype, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, - init_all_as_kwargs_np={ - "object": xs[0], - }, - method_all_as_kwargs_np={ - "value": xs[1], - }, + method_all_as_kwargs_np=method_all_as_kwargs_np, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2666,16 +2693,47 @@ def 
test_numpy___iadd__( ) +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ), +) +def test_numpy_ndarray_ivy_array( + dtype_x, + frontend, + backend_fw, +): + dtype, data, shape = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) + x.ivy_array = data[0] + ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=frontend) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend=backend_fw, + ground_truth_backend="numpy", + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__isub__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 + method_name="max", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), + keepdims=st.booleans(), ) -def test_numpy___isub__( - dtype_and_x, +def test_numpy_ndarray_max( + dtype_x_axis, + keepdims, frontend_method_data, init_flags, method_flags, @@ -2683,17 +2741,18 @@ def test_numpy___isub__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, + "keepdims": keepdims, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2706,13 +2765,17 @@ def test_numpy___isub__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__imul__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 + method_name="mean", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), ) -def test_numpy___imul__( - dtype_and_x, +def test_numpy_ndarray_mean( + dtype_x_axis, frontend_method_data, init_flags, method_flags, @@ -2720,39 +2783,45 @@ def test_numpy___imul__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, + "dtype": "float64", + "out": None, }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + rtol_=1e-2, + atol_=1e-2, on_device=on_device, ) -# __itruediv__ @handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="numpy.array", - method_name="__itruediv__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + class_tree=CLASS_TREE, + init_tree="numpy.array", + method_name="min", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), + keepdims=st.booleans(), ) -def test_numpy___itruediv__( - dtype_and_x, +def test_numpy_ndarray_min( + 
dtype_x_axis, + keepdims, frontend_method_data, init_flags, method_flags, @@ -2760,22 +2829,23 @@ def test_numpy___itruediv__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, + "keepdims": keepdims, }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) @@ -2783,15 +2853,13 @@ def test_numpy___itruediv__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__ipow__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="nonzero", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), - power=helpers.ints(min_value=1, max_value=3), ) -def test_numpy___ipow__( - dtype_and_x, - power, +def test_numpy_ndarray_nonzero( + dtype_and_a, frontend_method_data, init_flags, method_flags, @@ -2799,18 +2867,16 @@ def test_numpy___ipow__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, a = dtype_and_a helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], - }, - method_all_as_kwargs_np={ - "value": power, + "object": a[0], }, + method_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2819,17 +2885,19 @@ def test_numpy___ipow__( ) +# prod @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__iand__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=("bool",), - num_arrays=2, - ), + method_name="prod", + dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), + keep_dims=st.booleans(), + initial=st.one_of(st.floats(min_value=-100, max_value=100)), ) -def test_numpy___iand__( - dtype_and_x, +def test_numpy_ndarray_prod( + dtype_x_axis_dtype, + keep_dims, + initial, frontend_method_data, init_flags, method_flags, @@ -2837,17 +2905,33 @@ def test_numpy___iand__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype + if ivy.current_backend_str() == "torch": + assume(not method_flags.as_variable[0]) + ( + where, + input_dtypes, + method_flags, + ) = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=method_flags, + ) + where = ivy.array(where, dtype="bool") helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, + "dtype": dtype, + "keepdims": keep_dims, + "initial": initial, + "where": where, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2857,17 +2941,32 @@ def test_numpy___iand__( ) +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ), +) +def test_numpy_ndarray_property_ndim(dtype_x, backend_fw): + dtype, data, shape = dtype_x + with 
BackendHandler.update_backend(backend_fw) as ivy_backend: + x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) + x.ivy_array = data[0] + ivy_backend.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) + + +# ptp @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__ior__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=("bool",), - num_arrays=2, + method_name="ptp", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, ), ) -def test_numpy___ior__( - dtype_and_x, +def test_numpy_ndarray_ptp( + dtype_x_axis, frontend_method_data, init_flags, method_flags, @@ -2875,17 +2974,16 @@ def test_numpy___ior__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x - + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2898,14 +2996,13 @@ def test_numpy___ior__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__ixor__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=("bool",), - num_arrays=2, + method_name="ravel", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_numpy___ixor__( - dtype_and_x, +def test_numpy_ndarray_ravel( + dtype_and_a, frontend_method_data, init_flags, method_flags, @@ -2913,18 +3010,16 @@ def test_numpy___ixor__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, a = dtype_and_a helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], - }, - method_all_as_kwargs_np={ - "value": xs[1], + "object": a[0], }, + method_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2936,16 +3031,19 @@ def test_numpy___ixor__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__imod__", + method_name="repeat", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=0, - exclude_min=True, + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + min_dim_size=2, ), + repeats=helpers.ints(min_value=2, max_value=5), + axis=helpers.ints(min_value=-1, max_value=1), ) -def test_numpy___imod__( +def test_numpy_ndarray_repeat( dtype_and_x, + repeats, + axis, frontend_method_data, init_flags, method_flags, @@ -2953,16 +3051,18 @@ def test_numpy___imod__( frontend, on_device, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "repeats": repeats, + "axis": axis, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -2975,17 +3075,13 @@ def test_numpy___imod__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__abs__", - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - allow_inf=False, - large_abs_safety_factor=4, - safety_factor_scale="linear", - ), + method_name="reshape", + dtypes_x_shape=dtypes_x_reshape(), + order=st.sampled_from(["C", "F", "A"]), ) -def test_numpy___abs__( - dtype_and_x, +def test_numpy_ndarray_reshape( + dtypes_x_shape, + order, frontend_method_data, init_flags, method_flags, @@ -2993,16 +3089,18 @@ def test_numpy___abs__( frontend, on_device, ): - input_dtypes, x = dtype_and_x - + input_dtypes, x, shape = dtypes_x_shape helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, + method_input_dtypes=[], + method_all_as_kwargs_np={ + "newshape": shape, + "order": order, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -3011,19 +3109,22 @@ def test_numpy___abs__( ) -# __len__ +# round @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__len__", + method_name="round", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, + available_dtypes=helpers.get_dtypes("float", full=False), + num_arrays=1, + max_value=50, + min_value=-50, ), + decimals=st.integers(min_value=0, max_value=3), ) -def test_numpy___len__( +def test_numpy_ndarray_round( dtype_and_x, + decimals, frontend_method_data, init_flags, method_flags, @@ -3031,34 +3132,40 @@ def test_numpy___len__( frontend, on_device, ): - input_dtypes, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, + method_input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, init_all_as_kwargs_np={ - "object": x[0], + "object": x, }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "decimals": decimals, + }, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, - frontend_method_data=frontend_method_data, on_device=on_device, ) -# __array__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__array__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="searchsorted", + dtype_x_v=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("signed_integer"), + min_num_dims=1, + max_num_dims=1, + num_arrays=2, ), + side=st.sampled_from(["left", "right"]), ) -def test_numpy___array__( - dtype_and_x, +def test_numpy_ndarray_searchsorted( + dtype_x_v, + side, frontend_method_data, init_flags, method_flags, @@ -3066,37 +3173,39 @@ def test_numpy___array__( frontend, on_device, ): - input_dtypes, x = dtype_and_x + input_dtypes, xs = dtype_x_v + helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "object": xs[0], }, method_input_dtypes=input_dtypes, + backend_to_test=backend_fw, method_all_as_kwargs_np={ - "dtype": np.dtype(input_dtypes[0]), + "v": xs[1], + "side": side, + "sorter": np.argsort(xs[0]), }, - init_flags=init_flags, - method_flags=method_flags, frontend=frontend, frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, on_device=on_device, ) -# __array_wrap__ +# __setitem__ 
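# NOTE (illustrative sketch, not part of this patch; it refers to the `searchsorted`
# method test above): supplying `sorter=np.argsort(xs[0])` lets searchsorted run on
# an unsorted array, since the sorter is the permutation that would sort it.
# A standalone NumPy equivalent, assuming plain NumPy semantics:
import numpy as np

a = np.array([30, 10, 20])
order = np.argsort(a)                                    # array([1, 2, 0])
idx = np.searchsorted(a, 25, side="left", sorter=order)  # 2, i.e. between 20 and 30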
@handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__array_wrap__", - dtype_and_x=helpers.dtype_and_values( + method_name="__setitem__", + dtypes_x_index_val=helpers.dtype_array_query_val( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, ), ) -def test_numpy___array_wrap__( - dtype_and_x, +def test_numpy_ndarray_setitem( + dtypes_x_index_val, frontend_method_data, init_flags, method_flags, @@ -3104,37 +3213,29 @@ def test_numpy___array_wrap__( frontend, on_device, ): - input_dtypes, x = dtype_and_x + input_dtype, x, index, val = dtypes_x_index_val helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=[input_dtype[0]], + init_all_as_kwargs_np={"object": x}, + method_input_dtypes=[*input_dtype[1:]], + method_all_as_kwargs_np={"key": index, "value": val}, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "object": x[0], - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "array": x[1], - "context": None, - }, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - frontend_method_data=frontend_method_data, on_device=on_device, ) -# tobytes @given( dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False), ret_shape=True, ), - order=st.sampled_from(["C", "F"]), ) -def test_numpy_ndarray_tobytes( +def test_numpy_ndarray_shape( dtype_x, - order, backend_fw, ): dtype, data, shape = dtype_x @@ -3142,27 +3243,39 @@ def test_numpy_ndarray_tobytes( x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) x.ivy_array = data[0] ivy_backend.utils.assertions.check_equal( - x.tobytes(order=order), data[0].tobytes(order=order), as_array=False + x.shape, ivy.Shape(shape), as_array=False ) -# tofile +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ), +) +def test_numpy_ndarray_size( + dtype_x, +): + dtype, data, shape = dtype_x + x = ndarray(shape, dtype[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal(x.size, data[0].size, as_array=False) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="tofile", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), - path=st.text( - alphabet=st.characters(whitelist_categories=("Lu", "Ll", "Nd", "Pc")), - min_size=1, - max_size=50, + method_name="sort", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), ) -def test_numpy_ndarray_tofile( - dtype_and_x, - path, +def test_numpy_ndarray_sort( + dtype_x_axis, frontend_method_data, init_flags, method_flags, @@ -3170,8 +3283,9 @@ def test_numpy_ndarray_tofile( frontend, on_device, ): - input_dtypes, x = dtype_and_x - helpers.test_frontend_method( + input_dtypes, x, axis = dtype_x_axis + + ret, frontend_ret = helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ @@ -3179,27 +3293,39 @@ def test_numpy_ndarray_tofile( }, method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "fid": path, + "axis": axis, }, - init_flags=init_flags, - method_flags=method_flags, frontend=frontend, frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + test_values=False, on_device=on_device, ) + frontend_ret = np.sort(x[0], axis=axis) + 
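# NOTE (not part of this patch): the sort test above sets test_values=False, so the
# helper skips its automatic value comparison; the assert_all_close call below then
# checks the frontend result against NumPy's own np.sort with loose tolerances.
# Roughly equivalent, in plain NumPy terms, to:
#     expected = np.sort(x[0], axis=axis)
#     np.testing.assert_allclose(np.asarray(ret), expected, rtol=1e-2, atol=1e-2)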
assert_all_close( + ret_np=ret, + ret_from_gt_np=frontend_ret, + rtol=1e-2, + atol=1e-2, + backend=backend_fw, + ground_truth_backend="numpy", + ) -# tolist @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="tolist", + method_name="squeeze", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), ), + axis=_squeeze_helper(), ) -def test_numpy_ndarray_tolist( +def test_numpy_ndarray_squeeze( dtype_and_x, + axis, frontend_method_data, init_flags, method_flags, @@ -3207,68 +3333,43 @@ def test_numpy_ndarray_tolist( frontend, on_device, ): - input_dtypes, x = dtype_and_x + input_dtype, x = dtype_and_x + helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, - init_flags=init_flags, - method_flags=method_flags, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + }, frontend=frontend, frontend_method_data=frontend_method_data, - on_device=on_device, - test_values=False, # Todo change this after we add __iter__ to ndarray - ) - - -# __getitem__ -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="numpy.array", - method_name="__getitem__", - dtype_x_index=helpers.dtype_array_query( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_numpy_getitem( - dtype_x_index, - frontend_method_data, - init_flags, - method_flags, - backend_fw, - frontend, - on_device, -): - input_dtype, x, index = dtype_x_index - helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], - init_all_as_kwargs_np={"object": x}, - method_input_dtypes=[*input_dtype[1:]], - method_all_as_kwargs_np={"key": index}, - backend_to_test=backend_fw, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# __setitem__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__setitem__", - dtypes_x_index_val=helpers.dtype_array_query_val( + method_name="std", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), + max_value=100, + valid_axis=True, + force_int_axis=True, ), + keepdims=st.booleans(), + where=np_frontend_helpers.where(), ) -def test_numpy_ndarray_setitem( - dtypes_x_index_val, +def test_numpy_ndarray_std( + dtype_x_axis, + keepdims, + where, frontend_method_data, init_flags, method_flags, @@ -3276,130 +3377,118 @@ def test_numpy_ndarray_setitem( frontend, on_device, ): - input_dtype, x, index, val = dtypes_x_index_val - helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], - init_all_as_kwargs_np={"object": x}, - method_input_dtypes=[*input_dtype[1:]], - method_all_as_kwargs_np={"key": index, "value": val}, - backend_to_test=backend_fw, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, + input_dtypes, x, axis = dtype_x_axis + ( + where, + input_dtypes, + method_flags, + ) = np_frontend_helpers.handle_where_and_array_bools( + where=[where[0][0]] if isinstance(where, list) else where, + input_dtype=input_dtypes, + test_flags=method_flags, ) - - -# view -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="numpy.array", - method_name="view", - dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_numpy_ndarray_view( - dtype_and_x, - frontend_method_data, - init_flags, - method_flags, - backend_fw, - frontend, - on_device, -): - input_dtypes, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": x[0], + "data": x[0], + }, + method_all_as_kwargs_np={ + "axis": axis, + "out": None, + "ddof": 0, + "keepdims": keepdims, + "where": where, }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={}, - init_flags=init_flags, - method_flags=method_flags, frontend=frontend, frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, on_device=on_device, ) -# mod +# sum @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__mod__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - min_value=0, - exclude_min=True, - ), + method_name="sum", + dtype_x_axis_dtype=_get_castable_dtypes_values(use_where=True), + keep_dims=st.booleans(), + initial=st.one_of(st.floats(min_value=-100, max_value=100)), ) -def test_numpy___mod__( - dtype_and_x, +def test_numpy_ndarray_sum( + dtype_x_axis_dtype, + keep_dims, + initial, frontend_method_data, init_flags, method_flags, - backend_fw, frontend, on_device, + backend_fw, ): - input_dtypes, xs = dtype_and_x + input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype + if ivy.current_backend_str() == "torch": + assume(not method_flags.as_variable[0]) + + where, input_dtypes, method_flags = ( + np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=method_flags, + ) + ) + where = ivy.array(where, dtype="bool") helpers.test_frontend_method( init_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ - "object": xs[0], + "object": x[0], }, + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "value": xs[1], + "axis": axis, + "dtype": dtype, + "keepdims": keep_dims, + "initial": initial, + "where": where, }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, - rtol_=1e-5, - atol_=1e-5, + backend_to_test=backend_fw, ) -# ptp @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="ptp", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - ), + method_name="swapaxes", + dtype_x_and_axes=dtype_values_and_axes(), ) -def test_numpy_ndarray_ptp( - dtype_x_axis, +def test_numpy_ndarray_swapaxes( + dtype_x_and_axes, + frontend, frontend_method_data, init_flags, method_flags, backend_fw, - frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, x, axis1, axis2 = dtype_x_and_axes helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, + method_input_dtypes=input_dtypes, init_all_as_kwargs_np={ "object": x[0], }, - method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ - "axis": axis, + "axis1": axis1, + "axis2": axis2, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -3409,212 +3498,138 @@ def test_numpy_ndarray_ptp( ) -# item -@st.composite -def _item_helper(draw): - dtype = draw( - helpers.array_dtypes( 
- num_arrays=1, - available_dtypes=helpers.get_dtypes("numeric"), - ) - ) - shape = draw( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=10, - ) - ) - array = draw( - helpers.array_values( - dtype=dtype[0], - shape=shape, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - ) - ) - - index = () - for s in shape: - index += (draw(st.integers(min_value=-s + 1, max_value=s - 1)),) - - index_samples = [index, draw(helpers.ints(min_value=0, max_value=array.size - 1))] - - if array.size == 1: - index_samples.append(None) - - sampled_index = draw(st.sampled_from(index_samples)) - - if sampled_index is None: - method_all_as_kwargs_np = {} - num_positional_args = 0 - else: - method_all_as_kwargs_np = {"args": sampled_index} - num_positional_args = 1 - - return dtype, array, method_all_as_kwargs_np, num_positional_args - - -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="numpy.array", - method_name="item", - args_kwargs=_item_helper(), +# tobytes +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ), + order=st.sampled_from(["C", "F"]), ) -def test_numpy_ndarray_item( - args_kwargs, - frontend_method_data, - init_flags, - method_flags, - frontend, - on_device, +def test_numpy_ndarray_tobytes( + dtype_x, + order, backend_fw, ): - input_dtype, x, method_all_as_kwargs_np, num_positional_args = args_kwargs - method_flags.num_positional_args = num_positional_args - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - init_all_as_kwargs_np={"object": x}, - method_input_dtypes=input_dtype, - backend_to_test=backend_fw, - method_all_as_kwargs_np=method_all_as_kwargs_np, - frontend=frontend, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - on_device=on_device, - ) + dtype, data, shape = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + x = ivy_backend.functional.frontends.numpy.ndarray(shape, dtype[0]) + x.ivy_array = data[0] + ivy_backend.utils.assertions.check_equal( + x.tobytes(order=order), data[0].tobytes(order=order), as_array=False + ) +# tofile @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__rshift__", + method_name="tofile", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("valid"), + ), + path=st.text( + alphabet=st.characters(whitelist_categories=("Lu", "Ll", "Nd", "Pc")), + min_size=1, + max_size=50, ), ) -def test_numpy___rshift__( +def test_numpy_ndarray_tofile( dtype_and_x, + path, frontend_method_data, init_flags, method_flags, - frontend, backend_fw, + frontend, on_device, ): input_dtypes, x = dtype_and_x - max_bits = np.iinfo(input_dtypes[0]).bits - max_shift = max_bits - 1 - x[1] = np.asarray(np.clip(x[1], 0, max_shift), dtype=input_dtypes[1]) - max_value_before_shift = 2 ** (max_bits - x[1]) - 1 - overflow_threshold = 2 ** (max_bits - 1) - x[0] = np.asarray( - np.clip(x[0], None, max_value_before_shift), dtype=input_dtypes[0] - ) - if np.any(x[0] > overflow_threshold): - x[0] = np.clip(x[0], None, overflow_threshold) - if np.any(x[0] < 0): - x[0] = np.abs(x[0]) helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtypes, - backend_to_test=backend_fw, method_all_as_kwargs_np={ - "value": x[1], + "fid": path, }, - 
frontend=frontend, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, + frontend_method_data=frontend_method_data, on_device=on_device, ) +# tolist @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__lshift__", + method_name="tolist", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - max_dim_size=1, - max_value=2**31 - 1, + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_numpy_instance_lshift__( +def test_numpy_ndarray_tolist( dtype_and_x, frontend_method_data, init_flags, method_flags, - frontend, backend_fw, + frontend, on_device, ): input_dtypes, x = dtype_and_x - max_bits = np.iinfo(input_dtypes[0]).bits - max_shift = max_bits - 1 - x[1] = np.asarray(np.clip(x[1], 0, max_shift), dtype=input_dtypes[1]) - max_value_before_shift = 2 ** (max_bits - x[1]) - 1 - overflow_threshold = 2 ** (max_bits - 1) - x[0] = np.asarray( - np.clip(x[0], None, max_value_before_shift), dtype=input_dtypes[0] - ) - if np.any(x[0] > overflow_threshold): - x[0] = np.clip(x[0], None, overflow_threshold) - if np.any(x[0] < 0): - x[0] = np.abs(x[0]) helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "object": x[0], }, method_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - method_all_as_kwargs_np={ - "value": x[1], - }, - frontend=frontend, - frontend_method_data=frontend_method_data, + method_all_as_kwargs_np={}, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, + frontend_method_data=frontend_method_data, on_device=on_device, + test_values=False, # Todo change this after we add __iter__ to ndarray ) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="__invert__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes(kind="integer"), - num_arrays=1, + method_name="transpose", + array_and_axes=np_frontend_helpers._array_and_axes_permute_helper( + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, ), ) -def test_numpy___invert__( - dtype_and_x, +def test_numpy_ndarray_transpose( + array_and_axes, frontend_method_data, init_flags, method_flags, - frontend, backend_fw, + frontend, on_device, ): - input_dtypes, xs = dtype_and_x - + array, input_dtypes, axes = array_and_axes helpers.test_frontend_method( init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": xs[0], + "object": np.array(array), }, method_input_dtypes=input_dtypes, - backend_to_test=backend_fw, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "axes": axes, + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -3623,22 +3638,17 @@ def test_numpy___invert__( ) -# round +# view @handle_frontend_method( class_tree=CLASS_TREE, init_tree="numpy.array", - method_name="round", + method_name="view", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", full=False), - num_arrays=1, - max_value=50, - min_value=-50, + available_dtypes=helpers.get_dtypes("valid"), ), - decimals=st.integers(min_value=0, max_value=3), ) -def test_numpy_ndarray_round( +def test_numpy_ndarray_view( dtype_and_x, - decimals, frontend_method_data, init_flags, method_flags, @@ -3646,20 +3656,18 @@ def test_numpy_ndarray_round( frontend, on_device, ): - input_dtype, x = dtype_and_x + input_dtypes, x = dtype_and_x 
helpers.test_frontend_method( - init_input_dtypes=input_dtype, - method_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - frontend=frontend, init_all_as_kwargs_np={ - "object": x, - }, - method_all_as_kwargs_np={ - "decimals": decimals, + "object": x[0], }, - frontend_method_data=frontend_method_data, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={}, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, + frontend_method_data=frontend_method_data, on_device=on_device, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_random/test_functions.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_random/test_functions.py index 7010d265a5244..b2cae02335ca7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_random/test_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_random/test_functions.py @@ -7,14 +7,21 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# random_sample +# beta @handle_frontend_test( - fn_tree="numpy.random.random_sample", - input_dtypes=helpers.get_dtypes("integer", full=False), - size=helpers.get_shape(allow_none=True), - test_with_out=st.just(False), + fn_tree="numpy.random.beta", + input_dtypes=helpers.get_dtypes("float", index=2), + a=st.floats( + allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True + ), + b=st.floats( + allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True + ), + size=st.tuples( + st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + ), ) -def test_numpy_random_sample( +def test_numpy_beta( input_dtypes, size, frontend, @@ -22,6 +29,8 @@ def test_numpy_random_sample( fn_tree, backend_fw, on_device, + a, + b, ): helpers.test_frontend_function( input_dtypes=input_dtypes, @@ -31,29 +40,84 @@ def test_numpy_random_sample( fn_tree=fn_tree, on_device=on_device, test_values=False, + a=a, + b=b, size=size, ) -# dirichlet +# binomial @handle_frontend_test( - fn_tree="numpy.random.dirichlet", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.tuples( - st.integers(min_value=2, max_value=5), - ), - min_value=1, - max_value=100, - exclude_min=True, + fn_tree="numpy.random.binomial", + n=st.integers(min_value=0, max_value=2), + dtype=helpers.get_dtypes("float", full=False, index=2), + p=st.floats( + allow_nan=False, allow_infinity=False, width=32, min_value=0, max_value=1 ), size=st.tuples( st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), +) +def test_numpy_binomial( + dtype, + size, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, + n, + p, +): + helpers.test_frontend_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + n=n, + p=p, + size=size, + ) + + +# chisquare +# The test values are restricted to (0, 1000] because df<=0 is invalid +# and very large df can cause problems with type conversions +@handle_frontend_test( + fn_tree="numpy.random.chisquare", + dtypes=helpers.get_dtypes("float", full=False), + df=st.one_of( + st.floats( + min_value=0, + max_value=1000, + exclude_min=True, + allow_subnormal=False, + width=32, + ), + st.integers(min_value=1, max_value=1000), + st.lists( + st.one_of( + st.floats( + min_value=0, + max_value=1000, + exclude_min=True, + allow_subnormal=False, + width=32, + ) + | st.integers(min_value=1, max_value=1000) + ), + 
min_size=1, + ), + ), + size=helpers.get_shape(allow_none=True), test_with_out=st.just(False), ) -def test_numpy_dirichlet( - dtype_and_x, +def test_numpy_chisquare( + dtypes, + df, size, frontend, test_flags, @@ -61,122 +125,148 @@ def test_numpy_dirichlet( backend_fw, on_device, ): - input_dtype, x = dtype_and_x + # make sure `size` is something `df` can be broadcast to + if ( + hasattr(df, "__len__") + and size is not None + and (len(size) == 0 or size[-1] != len(df)) + ): + size = (*size, len(df)) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - alpha=x[0], test_values=False, + df=df, size=size, ) -# uniform +# dirichlet @handle_frontend_test( - fn_tree="numpy.random.uniform", - input_dtypes=helpers.get_dtypes("float", index=2), - low=st.floats(allow_nan=False, allow_infinity=False, width=32), - high=st.floats(allow_nan=False, allow_infinity=False, width=32), + fn_tree="numpy.random.dirichlet", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.tuples( + st.integers(min_value=2, max_value=5), + ), + min_value=1, + max_value=100, + exclude_min=True, + ), size=st.tuples( st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), + test_with_out=st.just(False), ) -def test_numpy_uniform( - input_dtypes, +def test_numpy_dirichlet( + dtype_and_x, size, frontend, test_flags, fn_tree, backend_fw, on_device, - low, - high, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + alpha=x[0], test_values=False, - low=low, - high=high, size=size, ) -# normal @handle_frontend_test( - fn_tree="numpy.random.normal", - input_dtypes=helpers.get_dtypes("float", index=2), - loc=st.floats(allow_nan=False, allow_infinity=False, width=32), - scale=st.floats(allow_nan=False, allow_infinity=False, width=32, min_value=0), - size=st.tuples( - st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + fn_tree="numpy.random.f", + input_dtypes=helpers.get_dtypes("float"), + dfn=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=1, + max_value=1000, + exclude_min=True, + ), + dfd=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=1, + max_value=1000, + exclude_min=True, ), + size=helpers.get_shape(allow_none=False), ) -def test_numpy_normal( +def test_numpy_f( input_dtypes, size, frontend, test_flags, - fn_tree, backend_fw, + fn_tree, on_device, - loc, - scale, + dfn, + dfd, ): + test_flags.num_positional_args = 2 helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, test_values=False, - loc=loc, - scale=scale, + on_device=on_device, + dfn=dfn, + dfd=dfd, size=size, ) -# poisson @handle_frontend_test( - fn_tree="numpy.random.poisson", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.tuples(st.integers(min_value=1, max_value=2)), - min_value=1, - max_value=100, + fn_tree="numpy.random.gamma", + input_dtypes=helpers.get_dtypes("float", full=False), + shape=st.floats( + allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True + ), + scale=st.floats( + allow_nan=False, 
allow_infinity=False, width=32, min_value=0, exclude_min=True ), size=st.tuples( - st.integers(min_value=1, max_value=10), st.integers(min_value=2, max_value=2) + st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), + test_with_out=st.just(False), ) -def test_numpy_poisson( - dtype_and_x, +def test_numpy_gamma( + input_dtypes, + shape, + scale, size, - test_flags, frontend, - backend_fw, + test_flags, fn_tree, on_device, + backend_fw, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - lam=x[0], test_values=False, + shape=shape, + scale=scale, size=size, ) @@ -218,82 +308,113 @@ def test_numpy_geometric( ) -# multinomial +# gumbel @handle_frontend_test( - fn_tree="numpy.random.multinomial", - n=helpers.ints(min_value=2, max_value=10), - dtype=helpers.get_dtypes("float", full=False), - size=st.tuples( - st.integers(min_value=1, max_value=10), st.integers(min_value=2, max_value=2) + fn_tree="numpy.random.gumbel", + input_dtypes=helpers.get_dtypes("float"), + loc=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + max_value=1000, + ), + scale=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=0, + max_value=1000, + exclude_min=True, ), + size=helpers.get_shape(allow_none=True), ) -def test_numpy_multinomial( - n, - dtype, - size, - test_flags, +def test_numpy_gumbel( + input_dtypes, frontend, + test_flags, backend_fw, fn_tree, on_device, + loc, + scale, + size, ): helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, - n=n, - pvals=np.array([1 / n] * n, dtype=dtype[0]), + loc=loc, + scale=scale, size=size, ) -# permutation +# logistic @handle_frontend_test( - fn_tree="numpy.random.permutation", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1 + fn_tree="numpy.random.logistic", + input_dtypes=helpers.get_dtypes("float", full=False), + loc=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=0, + exclude_min=True, + ), + scale=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=0, + exclude_min=True, ), + size=helpers.get_shape(allow_none=True), + test_with_out=st.just(False), ) -def test_numpy_permutation( - dtype_and_x, +def test_numpy_logistic( + input_dtypes, + size, frontend, test_flags, fn_tree, - backend_fw, on_device, + backend_fw, + loc, + scale, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, test_values=False, - x=x[0], + loc=loc, + scale=scale, + size=size, ) -# beta +# lognormal +# min value is set 0 @handle_frontend_test( - fn_tree="numpy.random.beta", + fn_tree="numpy.random.lognormal", input_dtypes=helpers.get_dtypes("float", index=2), - a=st.floats( - allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True + mean=st.floats( + allow_nan=False, allow_infinity=False, width=32, min_value=-5, max_value=5 ), - b=st.floats( - allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True + sigma=st.floats( + 
allow_nan=False, allow_infinity=False, width=32, min_value=0, max_value=5 ), size=st.tuples( st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), ) -def test_numpy_beta( +def test_numpy_lognormal( input_dtypes, size, frontend, @@ -301,8 +422,8 @@ def test_numpy_beta( fn_tree, backend_fw, on_device, - a, - b, + mean, + sigma, ): helpers.test_frontend_function( input_dtypes=input_dtypes, @@ -312,132 +433,88 @@ def test_numpy_beta( fn_tree=fn_tree, on_device=on_device, test_values=False, - a=a, - b=b, + mean=mean, + sigma=sigma, size=size, ) -# shuffle +# multinomial @handle_frontend_test( - fn_tree="numpy.random.shuffle", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), min_num_dims=1 + fn_tree="numpy.random.multinomial", + n=helpers.ints(min_value=2, max_value=10), + dtype=helpers.get_dtypes("float", full=False), + size=st.tuples( + st.integers(min_value=1, max_value=10), st.integers(min_value=2, max_value=2) ), ) -def test_numpy_shuffle( - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - x=x[0], - ) - - -@handle_frontend_test( - fn_tree="numpy.random.standard_normal", - input_dtypes=helpers.get_dtypes("integer", full=False), - size=helpers.get_shape(allow_none=True), - test_with_out=st.just(False), -) -def test_numpy_standard_normal( - input_dtypes, +def test_numpy_multinomial( + n, + dtype, size, - frontend, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, test_values=False, + n=n, + pvals=np.array([1 / n] * n, dtype=dtype[0]), size=size, ) +# negative_binomial @handle_frontend_test( - fn_tree="numpy.random.standard_gamma", - shape_dtypes=helpers.get_dtypes("float", full=False), - shape=st.floats( - allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True + fn_tree="numpy.random.negative_binomial", + input_dtypes=helpers.get_dtypes("float", index=2), + # max value for n and min value for p are restricted in testing + # as they can blow up poisson lambda, which will cause an + # error (lam value too large). 
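# NOTE (illustrative sketch, not part of this patch): the restriction above is
# consistent with the classic gamma-Poisson construction of the negative binomial,
# where the Poisson rate grows roughly like n * (1 - p) / p, so a large n combined
# with a tiny p produces an enormous lambda. A standalone sketch using NumPy's
# Generator API:
import numpy as np

rng = np.random.default_rng(0)
n, p = 5.0, 0.3
lam = rng.gamma(n, (1.0 - p) / p)  # rate blows up as p -> 0 or n -> infinity
sample = rng.poisson(lam)          # one draw from the gamma-Poisson mixture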
+ n=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=0, + max_value=100000, + exclude_min=True, ), - size=st.tuples( - st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + p=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=9.999999747378752e-06, + exclude_min=True, + max_value=1, + exclude_max=True, ), - size_dtypes=helpers.get_dtypes("integer", full=False), + size=helpers.get_shape(allow_none=True), test_with_out=st.just(False), ) -def test_numpy_standard_gamma( - shape, - shape_dtypes, +def test_numpy_negative_binomial( + input_dtypes, size, - size_dtypes, frontend, test_flags, fn_tree, backend_fw, on_device, -): - assume("float16" not in shape_dtypes) - helpers.test_frontend_function( - input_dtypes=shape_dtypes + size_dtypes, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - shape=shape, - size=size, - ) - - -# binomial -@handle_frontend_test( - fn_tree="numpy.random.binomial", - n=st.integers(min_value=0, max_value=2), - dtype=helpers.get_dtypes("float", full=False, index=2), - p=st.floats( - allow_nan=False, allow_infinity=False, width=32, min_value=0, max_value=1 - ), - size=st.tuples( - st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) - ), -) -def test_numpy_binomial( - dtype, - size, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, n, p, ): helpers.test_frontend_function( - input_dtypes=dtype, - test_flags=test_flags, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, @@ -447,93 +524,65 @@ def test_numpy_binomial( ) -# chisquare -# The test values are restricted to (0, 1000] because df<=0 is invalid -# and very large df can cause problems with type conversions +# normal @handle_frontend_test( - fn_tree="numpy.random.chisquare", - dtypes=helpers.get_dtypes("float", full=False), - df=st.one_of( - st.floats( - min_value=0, - max_value=1000, - exclude_min=True, - allow_subnormal=False, - width=32, - ), - st.integers(min_value=1, max_value=1000), - st.lists( - st.one_of( - st.floats( - min_value=0, - max_value=1000, - exclude_min=True, - allow_subnormal=False, - width=32, - ) - | st.integers(min_value=1, max_value=1000) - ), - min_size=1, - ), + fn_tree="numpy.random.normal", + input_dtypes=helpers.get_dtypes("float", index=2), + loc=st.floats(allow_nan=False, allow_infinity=False, width=32), + scale=st.floats(allow_nan=False, allow_infinity=False, width=32, min_value=0), + size=st.tuples( + st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), - size=helpers.get_shape(allow_none=True), - test_with_out=st.just(False), ) -def test_numpy_chisquare( - dtypes, - df, +def test_numpy_normal( + input_dtypes, size, frontend, test_flags, fn_tree, backend_fw, on_device, + loc, + scale, ): - # make sure `size` is something `df` can be broadcast to - if ( - hasattr(df, "__len__") - and size is not None - and (len(size) == 0 or size[-1] != len(df)) - ): - size = (*size, len(df)) helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, - df=df, + loc=loc, + scale=scale, size=size, ) -# lognormal -# min value is set 0 +# pareto @handle_frontend_test( - fn_tree="numpy.random.lognormal", + 
fn_tree="numpy.random.pareto", input_dtypes=helpers.get_dtypes("float", index=2), - mean=st.floats( - allow_nan=False, allow_infinity=False, width=32, min_value=-5, max_value=5 - ), - sigma=st.floats( - allow_nan=False, allow_infinity=False, width=32, min_value=0, max_value=5 - ), - size=st.tuples( - st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + a=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=1, + max_value=1000, + exclude_min=True, ), + size=helpers.get_shape(allow_none=True), + test_with_out=st.just(False), ) -def test_numpy_lognormal( +def test_numpy_pareto( input_dtypes, - size, frontend, test_flags, - fn_tree, backend_fw, + fn_tree, on_device, - mean, - sigma, + a, + size, ): helpers.test_frontend_function( input_dtypes=input_dtypes, @@ -543,116 +592,89 @@ def test_numpy_lognormal( fn_tree=fn_tree, on_device=on_device, test_values=False, - mean=mean, - sigma=sigma, + a=a, size=size, ) -# negative_binomial +# permutation @handle_frontend_test( - fn_tree="numpy.random.negative_binomial", - input_dtypes=helpers.get_dtypes("float", index=2), - # max value for n and min value for p are restricted in testing - # as they can blow up poisson lambda, which will cause an - # error (lam value too large). - n=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=0, - max_value=100000, - exclude_min=True, - ), - p=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=9.999999747378752e-06, - exclude_min=True, - max_value=1, - exclude_max=True, + fn_tree="numpy.random.permutation", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1 ), - size=helpers.get_shape(allow_none=True), - test_with_out=st.just(False), ) -def test_numpy_negative_binomial( - input_dtypes, - size, +def test_numpy_permutation( + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, - n, - p, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, - n=n, - p=p, - size=size, + x=x[0], ) -# weibull +# poisson @handle_frontend_test( - fn_tree="numpy.random.weibull", - input_dtypes=helpers.get_dtypes("float", index=2), - a=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, + fn_tree="numpy.random.poisson", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.tuples(st.integers(min_value=1, max_value=2)), min_value=1, - max_value=1000, - exclude_min=True, + max_value=100, + ), + size=st.tuples( + st.integers(min_value=1, max_value=10), st.integers(min_value=2, max_value=2) ), - size=helpers.get_shape(allow_none=True), - test_with_out=st.just(False), ) -def test_numpy_weibull( - input_dtypes, - frontend, +def test_numpy_poisson( + dtype_and_x, + size, test_flags, + frontend, backend_fw, fn_tree, on_device, - a, - size, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, + lam=x[0], test_values=False, - a=a, size=size, ) -# standard_cauchy +# random_sample @handle_frontend_test( - fn_tree="numpy.random.standard_cauchy", + fn_tree="numpy.random.random_sample", input_dtypes=helpers.get_dtypes("integer", 
full=False), size=helpers.get_shape(allow_none=True), test_with_out=st.just(False), ) -def test_numpy_standard_cauchy( +def test_numpy_random_sample( input_dtypes, size, frontend, - backend_fw, test_flags, fn_tree, + backend_fw, on_device, ): helpers.test_frontend_function( @@ -704,125 +726,88 @@ def test_numpy_rayleigh( ) -# gumbel +# shuffle @handle_frontend_test( - fn_tree="numpy.random.gumbel", - input_dtypes=helpers.get_dtypes("float"), - loc=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - max_value=1000, - ), - scale=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=0, - max_value=1000, - exclude_min=True, + fn_tree="numpy.random.shuffle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1 ), - size=helpers.get_shape(allow_none=True), ) -def test_numpy_gumbel( - input_dtypes, +def test_numpy_shuffle( + dtype_and_x, frontend, test_flags, - backend_fw, fn_tree, + backend_fw, on_device, - loc, - scale, - size, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, - loc=loc, - scale=scale, - size=size, + x=x[0], ) +# standard_cauchy @handle_frontend_test( - fn_tree="numpy.random.f", - input_dtypes=helpers.get_dtypes("float"), - dfn=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=1, - max_value=1000, - exclude_min=True, - ), - dfd=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=1, - max_value=1000, - exclude_min=True, - ), - size=helpers.get_shape(allow_none=False), + fn_tree="numpy.random.standard_cauchy", + input_dtypes=helpers.get_dtypes("integer", full=False), + size=helpers.get_shape(allow_none=True), + test_with_out=st.just(False), ) -def test_numpy_f( +def test_numpy_standard_cauchy( input_dtypes, size, frontend, - test_flags, backend_fw, + test_flags, fn_tree, on_device, - dfn, - dfd, ): - test_flags.num_positional_args = 2 helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - test_values=False, on_device=on_device, - dfn=dfn, - dfd=dfd, + test_values=False, size=size, ) @handle_frontend_test( - fn_tree="numpy.random.gamma", - input_dtypes=helpers.get_dtypes("float", full=False), + fn_tree="numpy.random.standard_gamma", + shape_dtypes=helpers.get_dtypes("float", full=False), shape=st.floats( allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True ), - scale=st.floats( - allow_nan=False, allow_infinity=False, width=32, min_value=0, exclude_min=True - ), size=st.tuples( st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), + size_dtypes=helpers.get_dtypes("integer", full=False), test_with_out=st.just(False), ) -def test_numpy_gamma( - input_dtypes, +def test_numpy_standard_gamma( shape, - scale, + shape_dtypes, size, + size_dtypes, frontend, test_flags, fn_tree, - on_device, backend_fw, + on_device, ): + assume("float16" not in shape_dtypes) helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=shape_dtypes + size_dtypes, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, @@ -830,81 +815,112 @@ def test_numpy_gamma( on_device=on_device, test_values=False, shape=shape, - scale=scale, size=size, ) -# logistic @handle_frontend_test( - 
fn_tree="numpy.random.logistic", - input_dtypes=helpers.get_dtypes("float", full=False), - loc=st.floats( + fn_tree="numpy.random.standard_normal", + input_dtypes=helpers.get_dtypes("integer", full=False), + size=helpers.get_shape(allow_none=True), + test_with_out=st.just(False), +) +def test_numpy_standard_normal( + input_dtypes, + size, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + size=size, + ) + + +@handle_frontend_test( + fn_tree="numpy.random.triangular", + input_dtypes=helpers.get_dtypes("float"), + left=st.floats( allow_nan=False, allow_infinity=False, width=32, min_value=0, + max_value=10, + ), + mode=st.floats( + allow_nan=False, + allow_infinity=False, + width=32, + min_value=10, + max_value=100, exclude_min=True, ), - scale=st.floats( + right=st.floats( allow_nan=False, allow_infinity=False, width=32, - min_value=0, + min_value=100, + max_value=1000, exclude_min=True, ), size=helpers.get_shape(allow_none=True), - test_with_out=st.just(False), ) -def test_numpy_logistic( +def test_numpy_triangular( input_dtypes, size, frontend, test_flags, + backend_fw, fn_tree, on_device, - backend_fw, - loc, - scale, + left, + mode, + right, ): helpers.test_frontend_function( input_dtypes=input_dtypes, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, - loc=loc, - scale=scale, + left=left, + mode=mode, + right=right, size=size, ) -# pareto +# uniform @handle_frontend_test( - fn_tree="numpy.random.pareto", + fn_tree="numpy.random.uniform", input_dtypes=helpers.get_dtypes("float", index=2), - a=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=1, - max_value=1000, - exclude_min=True, + low=st.floats(allow_nan=False, allow_infinity=False, width=32), + high=st.floats(allow_nan=False, allow_infinity=False, width=32), + size=st.tuples( + st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), - size=helpers.get_shape(allow_none=True), - test_with_out=st.just(False), ) -def test_numpy_pareto( +def test_numpy_uniform( input_dtypes, + size, frontend, test_flags, - backend_fw, fn_tree, + backend_fw, on_device, - a, - size, + low, + high, ): helpers.test_frontend_function( input_dtypes=input_dtypes, @@ -914,7 +930,8 @@ def test_numpy_pareto( fn_tree=fn_tree, on_device=on_device, test_values=False, - a=a, + low=low, + high=high, size=size, ) @@ -966,45 +983,30 @@ def test_numpy_wald( ) +# weibull @handle_frontend_test( - fn_tree="numpy.random.triangular", - input_dtypes=helpers.get_dtypes("float"), - left=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=0, - max_value=10, - ), - mode=st.floats( - allow_nan=False, - allow_infinity=False, - width=32, - min_value=10, - max_value=100, - exclude_min=True, - ), - right=st.floats( + fn_tree="numpy.random.weibull", + input_dtypes=helpers.get_dtypes("float", index=2), + a=st.floats( allow_nan=False, allow_infinity=False, width=32, - min_value=100, + min_value=1, max_value=1000, exclude_min=True, ), size=helpers.get_shape(allow_none=True), + test_with_out=st.just(False), ) -def test_numpy_triangular( +def test_numpy_weibull( input_dtypes, - size, frontend, test_flags, backend_fw, fn_tree, on_device, - left, - mode, - right, + a, + size, ): 
helpers.test_frontend_function( input_dtypes=input_dtypes, @@ -1014,8 +1016,6 @@ def test_numpy_triangular( fn_tree=fn_tree, on_device=on_device, test_values=False, - left=left, - mode=mode, - right=right, + a=a, size=size, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_searching.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_searching.py index 51ba4c4b55ca1..24f1fce8c31d3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_searching.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_searching.py @@ -7,6 +7,10 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + @st.composite def _broadcastable_trio(draw): dtype = draw(helpers.get_dtypes("valid", full=False)) @@ -20,65 +24,87 @@ def _broadcastable_trio(draw): return cond, x1, x2, (dtype * 2) -# where -@handle_frontend_test( - fn_tree="numpy.where", - broadcastables=_broadcastable_trio(), - test_with_out=st.just(False), -) -def test_numpy_where( - broadcastables, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - cond, x1, x2, dtype = broadcastables - helpers.test_frontend_function( - input_dtypes=["bool", dtype], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - cond=cond, - x1=x1, - x2=x2, +@st.composite +def _extract_strategy(draw): + dtype_and_cond = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ) + ) + dtype_and_arr = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ) ) + return dtype_and_cond, dtype_and_arr -# nonzero -@handle_frontend_test( - fn_tree="numpy.nonzero", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), - test_with_out=st.just(False), -) -def test_numpy_nonzero( - dtype_and_a, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, a = dtype_and_a - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=a[0], - ) +# searchsorted +@st.composite +def _search_sorted_values(draw): + case = st.booleans() + if case: + # when x is 1-D and v is N-D + dtype_x, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + shape=(draw(st.integers(min_value=1, max_value=5)),), + ), + ) + dtype_v, v = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + min_num_dims=1, + ) + ) + else: + # when x is N-D and v is N-D + lead_dim = draw( + helpers.get_shape(min_num_dims=1), + ) + nx = draw(st.integers(min_value=1, max_value=5)) + nv = draw(st.integers(min_value=1, max_value=5)) + dtype_x, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + shape=lead_dim + (nx,), + ), + ) + dtype_v, v = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "numeric", full=False, key="searchsorted" + ), + shape=lead_dim + (nv,), + ), + ) + input_dtypes = dtype_x + dtype_v + xs = x + v + side = draw(st.sampled_from(["left", "right"])) + use_sorter = draw(st.booleans()) + if use_sorter: + sorter_dtype = draw(st.sampled_from(["int32", "int64"])) + 
input_dtypes.append(sorter_dtype) + sorter = np.argsort(xs[0], axis=-1).astype(sorter_dtype) + else: + sorter = None + xs[0] = np.sort(xs[0], axis=-1) + return input_dtypes, xs, side, sorter -# argmin +# --- Main --- # +# ------------ # + + +# argmax @handle_frontend_test( - fn_tree="numpy.argmin", + fn_tree="numpy.argmax", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), min_axis=-1, @@ -89,7 +115,7 @@ def test_numpy_nonzero( keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_argmin( +def test_numpy_argmax( dtype_x_axis, frontend, test_flags, @@ -112,9 +138,9 @@ def test_numpy_argmin( ) -# argmax +# argmin @handle_frontend_test( - fn_tree="numpy.argmax", + fn_tree="numpy.argmin", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), min_axis=-1, @@ -125,7 +151,7 @@ def test_numpy_argmin( keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_numpy_argmax( +def test_numpy_argmin( dtype_x_axis, frontend, test_flags, @@ -148,15 +174,15 @@ def test_numpy_argmax( ) -# flatnonzero +# argwhere @handle_frontend_test( - fn_tree="numpy.flatnonzero", + fn_tree="numpy.argwhere", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), test_with_out=st.just(False), ) -def test_numpy_flatnonzero( +def test_numpy_argwhere( dtype_and_x, frontend, test_flags, @@ -176,102 +202,44 @@ def test_numpy_flatnonzero( ) -# searchsorted -@st.composite -def _search_sorted_values(draw): - case = st.booleans() - if case: - # when x is 1-D and v is N-D - dtype_x, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - shape=(draw(st.integers(min_value=1, max_value=5)),), - ), - ) - dtype_v, v = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - min_num_dims=1, - ) - ) - else: - # when x is N-D and v is N-D - lead_dim = draw( - helpers.get_shape(min_num_dims=1), - ) - nx = draw(st.integers(min_value=1, max_value=5)) - nv = draw(st.integers(min_value=1, max_value=5)) - dtype_x, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - shape=lead_dim + (nx,), - ), - ) - dtype_v, v = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "numeric", full=False, key="searchsorted" - ), - shape=lead_dim + (nv,), - ), - ) - input_dtypes = dtype_x + dtype_v - xs = x + v - side = draw(st.sampled_from(["left", "right"])) - use_sorter = draw(st.booleans()) - if use_sorter: - sorter_dtype = draw(st.sampled_from(["int32", "int64"])) - input_dtypes.append(sorter_dtype) - sorter = np.argsort(xs[0], axis=-1).astype(sorter_dtype) - else: - sorter = None - xs[0] = np.sort(xs[0], axis=-1) - return input_dtypes, xs, side, sorter - - +# extract @handle_frontend_test( - fn_tree="numpy.searchsorted", - dtype_x_v_side_sorter=_search_sorted_values(), + fn_tree="numpy.extract", + dtype_and_x=_extract_strategy(), test_with_out=st.just(False), ) -def test_numpy_searchsorted( - dtype_x_v_side_sorter, +def test_numpy_extract( + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtypes, xs, side, sorter = dtype_x_v_side_sorter + dtype_cond, cond = dtype_and_x[0] + dtype_arr, arr = dtype_and_x[1] + helpers.test_frontend_function( - input_dtypes=input_dtypes + ["int64"], - frontend=frontend, + 
input_dtypes=dtype_cond + dtype_arr, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=xs[0], - v=xs[1], - side=side, - sorter=sorter, + cond=cond[0], + arr=arr[0], ) -# argwhere +# flatnonzero @handle_frontend_test( - fn_tree="numpy.argwhere", + fn_tree="numpy.flatnonzero", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_numpy_argwhere( +def test_numpy_flatnonzero( dtype_and_x, frontend, test_flags, @@ -363,45 +331,85 @@ def test_numpy_nanargmin( ) -@st.composite -def _extract_strategy(draw): - dtype_and_cond = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ) - ) - dtype_and_arr = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ) +# nonzero +@handle_frontend_test( + fn_tree="numpy.nonzero", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + test_with_out=st.just(False), +) +def test_numpy_nonzero( + dtype_and_a, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, a = dtype_and_a + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=a[0], ) - return dtype_and_cond, dtype_and_arr -# extract @handle_frontend_test( - fn_tree="numpy.extract", - dtype_and_x=_extract_strategy(), + fn_tree="numpy.searchsorted", + dtype_x_v_side_sorter=_search_sorted_values(), test_with_out=st.just(False), ) -def test_numpy_extract( - dtype_and_x, +def test_numpy_searchsorted( + dtype_x_v_side_sorter, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype_cond, cond = dtype_and_x[0] - dtype_arr, arr = dtype_and_x[1] + input_dtypes, xs, side, sorter = dtype_x_v_side_sorter + helpers.test_frontend_function( + input_dtypes=input_dtypes + ["int64"], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=xs[0], + v=xs[1], + side=side, + sorter=sorter, + ) + +# where +@handle_frontend_test( + fn_tree="numpy.where", + broadcastables=_broadcastable_trio(), + test_with_out=st.just(False), +) +def test_numpy_where( + broadcastables, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + cond, x1, x2, dtype = broadcastables helpers.test_frontend_function( - input_dtypes=dtype_cond + dtype_arr, + input_dtypes=["bool", dtype], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - cond=cond[0], - arr=arr[0], + cond=cond, + x1=x1, + x2=x2, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_sorting.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_sorting.py index 5b0cad9ba10d4..14f70253f680c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_sorting.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_sorting_searching_counting/test_sorting.py @@ -40,9 +40,9 @@ def test_numpy_argsort( @handle_frontend_test( - fn_tree="numpy.sort", + fn_tree="numpy.lexsort", dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), min_axis=-1, max_axis=0, min_num_dims=1, @@ -50,7 +50,7 @@ def test_numpy_argsort( ), 
test_with_out=st.just(False), ) -def test_numpy_sort( +def test_numpy_lexsort( *, dtype_x_axis, frontend, @@ -60,7 +60,6 @@ def test_numpy_sort( on_device, ): input_dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -68,7 +67,7 @@ def test_numpy_sort( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], + keys=x[0], axis=axis, ) @@ -106,43 +105,47 @@ def test_numpy_msort( @handle_frontend_test( - fn_tree="numpy.sort_complex", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, + fn_tree="numpy.partition", + dtype_x_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int64"], min_dim_size=1, - min_axis=-1, - max_axis=0, + max_num_dims=1, + indices_same_dims=False, + disable_random_axis=False, + axis_zero=False, + valid_bounds=True, ), test_with_out=st.just(False), ) -def test_numpy_sort_complex( +def test_numpy_partition( *, dtype_x_axis, frontend, test_flags, - fn_tree, backend_fw, + fn_tree, on_device, ): - input_dtype, x, axis = dtype_x_axis - + dtypes, x, kth, axis, _ = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=dtypes, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], test_values=False, + a=x, + kth=kth, + axis=axis, ) @handle_frontend_test( - fn_tree="numpy.lexsort", + fn_tree="numpy.sort", dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), min_axis=-1, max_axis=0, min_num_dims=1, @@ -150,7 +153,7 @@ def test_numpy_sort_complex( ), test_with_out=st.just(False), ) -def test_numpy_lexsort( +def test_numpy_sort( *, dtype_x_axis, frontend, @@ -160,6 +163,7 @@ def test_numpy_lexsort( on_device, ): input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -167,44 +171,40 @@ def test_numpy_lexsort( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - keys=x[0], + a=x[0], axis=axis, ) @handle_frontend_test( - fn_tree="numpy.partition", - dtype_x_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int64"], + fn_tree="numpy.sort_complex", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, min_dim_size=1, - max_num_dims=1, - indices_same_dims=False, - disable_random_axis=False, - axis_zero=False, - valid_bounds=True, + min_axis=-1, + max_axis=0, ), test_with_out=st.just(False), ) -def test_numpy_partition( +def test_numpy_sort_complex( *, dtype_x_axis, frontend, test_flags, - backend_fw, fn_tree, + backend_fw, on_device, ): - dtypes, x, kth, axis, _ = dtype_x_axis + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( - input_dtypes=dtypes, - frontend=frontend, + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + a=x[0], test_values=False, - a=x, - kth=kth, - axis=axis, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py index 139b47b8c71fe..38b605e86d1f8 100644 --- 
a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py @@ -12,6 +12,197 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_dtype_value1_value2_cov( + draw, + available_dtypes, + min_num_dims, + max_num_dims, + min_dim_size, + max_dim_size, + abs_smallest_val=None, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", +): + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + + dtype = draw(st.sampled_from(draw(available_dtypes))) + + values = [] + for i in range(2): + values.append( + draw( + helpers.array_values( + dtype=dtype, + shape=shape, + abs_smallest_val=abs_smallest_val, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + ) + ) + ) + + value1, value2 = values[0], values[1] + + # modifiers: rowVar, bias, ddof + rowVar = draw(st.booleans()) + bias = draw(st.booleans()) + ddof = draw(helpers.ints(min_value=0, max_value=1)) + + numVals = None + if rowVar is False: + numVals = -1 if numVals == 0 else 0 + else: + numVals = 0 if len(shape) == 1 else -1 + + fweights = draw( + helpers.array_values( + dtype="int64", + shape=shape[numVals], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + ) + ) + + aweights = draw( + helpers.array_values( + dtype="float64", + shape=shape[numVals], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + small_abs_safety_factor=1, + ) + ) + + return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights + + +# --- Main --- # +# ------------ # + + +# average +@handle_frontend_test( + fn_tree="numpy.average", + dtype_and_a=_statistical_dtype_values(function="average"), + dtype_and_x=_statistical_dtype_values(function="average"), + keep_dims=st.booleans(), + returned=st.booleans(), + test_with_out=st.just(False), +) +def test_numpy_average( + dtype_and_a, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + keep_dims, + returned, + on_device, +): + try: + input_dtype, a, axis = dtype_and_a + + input_dtypes, xs, axiss = dtype_and_x + + if isinstance(axis, tuple): + axis = axis[0] + + helpers.test_frontend_function( + a=a[0], + input_dtypes=input_dtype, + backend_to_test=backend_fw, + weights=xs[0], + axis=axis, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + keepdims=keep_dims, + returned=returned, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + ) + except ZeroDivisionError: + assume(False) + except AssertionError: + assume(False) + + +# cov +@handle_frontend_test( + fn_tree="numpy.cov", + dtype_x1_x2_cov=_get_dtype_value1_value2_cov( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=2, + min_dim_size=2, + max_dim_size=5, + min_value=1, + max_value=1e10, + abs_smallest_val=0.01, + large_abs_safety_factor=2, + safety_factor_scale="log", + ), + test_with_out=st.just(False), +) +def test_numpy_cov( + dtype_x1_x2_cov, + test_flags, + frontend, + fn_tree, + backend_fw, 
+ on_device, +): + dtype, x1, x2, rowvar, bias, ddof, fweights, aweights = dtype_x1_x2_cov + np_frontend_helpers.test_frontend_function( + input_dtypes=[dtype[0], dtype[0], "int64", "float64"], + frontend=frontend, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + m=x1, + y=x2, + rowvar=rowvar, + bias=bias, + ddof=ddof, + fweights=fweights, + aweights=aweights, + ) + + # mean @handle_frontend_test( fn_tree="numpy.mean", @@ -108,34 +299,23 @@ def test_numpy_nanmean( ) -# std @handle_frontend_test( - fn_tree="numpy.std", - dtype_and_x=_statistical_dtype_values(function="std"), - dtype=helpers.get_dtypes("float", full=False, none=True), - where=np_frontend_helpers.where(), + fn_tree="numpy.nanmedian", keep_dims=st.booleans(), + overwrite_input=st.booleans(), + dtype_x_axis=_statistical_dtype_values(function="nanmedian"), ) -def test_numpy_std( - dtype_and_x, - dtype, - where, +def test_numpy_nanmedian( + dtype_x_axis, frontend, - backend_fw, test_flags, fn_tree, + backend_fw, on_device, keep_dims, + overwrite_input, ): - input_dtypes, x, axis, correction = dtype_and_x - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - assume(np.dtype(dtype[0]) >= np.dtype(input_dtypes[0])) + input_dtypes, x, axis = dtype_x_axis np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, frontend=frontend, @@ -143,67 +323,14 @@ def test_numpy_std( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-1, - atol=1e-1, - x=x[0], + a=x[0], axis=axis, - ddof=correction, - keepdims=keep_dims, + overwrite_input=overwrite_input, out=None, - dtype=dtype[0], - where=where, + keepdims=keep_dims, ) -# average -@handle_frontend_test( - fn_tree="numpy.average", - dtype_and_a=_statistical_dtype_values(function="average"), - dtype_and_x=_statistical_dtype_values(function="average"), - keep_dims=st.booleans(), - returned=st.booleans(), - test_with_out=st.just(False), -) -def test_numpy_average( - dtype_and_a, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - keep_dims, - returned, - on_device, -): - try: - input_dtype, a, axis = dtype_and_a - - input_dtypes, xs, axiss = dtype_and_x - - if isinstance(axis, tuple): - axis = axis[0] - - helpers.test_frontend_function( - a=a[0], - input_dtypes=input_dtype, - backend_to_test=backend_fw, - weights=xs[0], - axis=axis, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - keepdims=keep_dims, - returned=returned, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - ) - except ZeroDivisionError: - assume(False) - except AssertionError: - assume(False) - - # nanstd @handle_frontend_test( fn_tree="numpy.nanstd", @@ -251,140 +378,6 @@ def test_numpy_nanstd( ) -@st.composite -def _get_dtype_value1_value2_cov( - draw, - available_dtypes, - min_num_dims, - max_num_dims, - min_dim_size, - max_dim_size, - abs_smallest_val=None, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", -): - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - - dtype = draw(st.sampled_from(draw(available_dtypes))) - - values = [] - for i in range(2): - values.append( - draw( - 
helpers.array_values( - dtype=dtype, - shape=shape, - abs_smallest_val=abs_smallest_val, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, - ) - ) - ) - - value1, value2 = values[0], values[1] - - # modifiers: rowVar, bias, ddof - rowVar = draw(st.booleans()) - bias = draw(st.booleans()) - ddof = draw(helpers.ints(min_value=0, max_value=1)) - - numVals = None - if rowVar is False: - numVals = -1 if numVals == 0 else 0 - else: - numVals = 0 if len(shape) == 1 else -1 - - fweights = draw( - helpers.array_values( - dtype="int64", - shape=shape[numVals], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - ) - ) - - aweights = draw( - helpers.array_values( - dtype="float64", - shape=shape[numVals], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - small_abs_safety_factor=1, - ) - ) - - return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights - - -# cov -@handle_frontend_test( - fn_tree="numpy.cov", - dtype_x1_x2_cov=_get_dtype_value1_value2_cov( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=2, - min_dim_size=2, - max_dim_size=5, - min_value=1, - max_value=1e10, - abs_smallest_val=0.01, - large_abs_safety_factor=2, - safety_factor_scale="log", - ), - test_with_out=st.just(False), -) -def test_numpy_cov( - dtype_x1_x2_cov, - test_flags, - frontend, - fn_tree, - backend_fw, - on_device, -): - dtype, x1, x2, rowvar, bias, ddof, fweights, aweights = dtype_x1_x2_cov - np_frontend_helpers.test_frontend_function( - input_dtypes=[dtype[0], dtype[0], "int64", "float64"], - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - m=x1, - y=x2, - rowvar=rowvar, - bias=bias, - ddof=ddof, - fweights=fweights, - aweights=aweights, - ) - - # nanvar @handle_frontend_test( fn_tree="numpy.nanvar", @@ -432,23 +425,34 @@ def test_numpy_nanvar( ) +# std @handle_frontend_test( - fn_tree="numpy.nanmedian", + fn_tree="numpy.std", + dtype_and_x=_statistical_dtype_values(function="std"), + dtype=helpers.get_dtypes("float", full=False, none=True), + where=np_frontend_helpers.where(), keep_dims=st.booleans(), - overwrite_input=st.booleans(), - dtype_x_axis=_statistical_dtype_values(function="nanmedian"), ) -def test_numpy_nanmedian( - dtype_x_axis, +def test_numpy_std( + dtype_and_x, + dtype, + where, frontend, + backend_fw, test_flags, fn_tree, - backend_fw, on_device, keep_dims, - overwrite_input, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, x, axis, correction = dtype_and_x + if isinstance(axis, tuple): + axis = axis[0] + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + assume(np.dtype(dtype[0]) >= np.dtype(input_dtypes[0])) np_frontend_helpers.test_frontend_function( input_dtypes=input_dtypes, frontend=frontend, @@ -456,11 +460,15 @@ def test_numpy_nanmedian( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], + rtol=1e-1, + atol=1e-1, + x=x[0], axis=axis, - overwrite_input=overwrite_input, - out=None, + ddof=correction, keepdims=keep_dims, + out=None, + dtype=dtype[0], + where=where, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_correlating.py 
b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_correlating.py index 49ba1fd548e3a..ee73f44e8a3d6 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_correlating.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_correlating.py @@ -1,90 +1,90 @@ -# global -from hypothesis import strategies as st - -# local -import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test -import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_helpers - - -# corrcoef -@handle_frontend_test( - fn_tree="numpy.corrcoef", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - abs_smallest_val=1e-5, - min_num_dims=2, - max_num_dims=2, - min_dim_size=3, - max_dim_size=3, - min_value=-100, - max_value=100, - ), - rowvar=st.booleans(), - dtype=helpers.get_dtypes("float", full=False), -) -def test_numpy_corrcoef( - dtype_and_x, - rowvar, - frontend, - test_flags, - fn_tree, - on_device, - dtype, - backend_fw, -): - input_dtypes, x = dtype_and_x - np_helpers.test_frontend_function( - input_dtypes=input_dtypes, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - rowvar=rowvar, - dtype=dtype[0], - backend_to_test=backend_fw, - ) - - -# correlate -@handle_frontend_test( - fn_tree="numpy.correlate", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=1, - num_arrays=2, - shared_dtype=True, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - ), - mode=st.sampled_from(["valid", "same", "full"]), - test_with_out=st.just(False), -) -def test_numpy_correlate( - dtype_and_x, - mode, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtypes, xs = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - a=xs[0], - v=xs[1], - mode=mode, - ) +# global +from hypothesis import strategies as st + +# local +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test +import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_helpers + + +# corrcoef +@handle_frontend_test( + fn_tree="numpy.corrcoef", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + abs_smallest_val=1e-5, + min_num_dims=2, + max_num_dims=2, + min_dim_size=3, + max_dim_size=3, + min_value=-100, + max_value=100, + ), + rowvar=st.booleans(), + dtype=helpers.get_dtypes("float", full=False), +) +def test_numpy_corrcoef( + dtype_and_x, + rowvar, + frontend, + test_flags, + fn_tree, + on_device, + dtype, + backend_fw, +): + input_dtypes, x = dtype_and_x + np_helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], + rowvar=rowvar, + dtype=dtype[0], + backend_to_test=backend_fw, + ) + + +# correlate +@handle_frontend_test( + fn_tree="numpy.correlate", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=1, + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + ), + 
mode=st.sampled_from(["valid", "same", "full"]), + test_with_out=st.just(False), +) +def test_numpy_correlate( + dtype_and_x, + mode, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtypes, xs = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + a=xs[0], + v=xs[1], + mode=mode, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_order_statistics.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_order_statistics.py index 0544a1f34a538..b9d641135e97d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_order_statistics.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_order_statistics.py @@ -1,87 +1,87 @@ -# global -from hypothesis import strategies as st - - -# local -from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( - _statistical_dtype_values, -) -import ivy_tests.test_ivy.helpers as helpers -import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test - - -# ptp -@handle_frontend_test( - fn_tree="numpy.ptp", - dtype_values_axis=_statistical_dtype_values(function="ptp"), - keep_dims=st.booleans(), - test_with_out=st.just(False), -) -def test_numpy_ptp( - dtype_values_axis, - frontend, - backend_fw, - test_flags, - fn_tree, - keep_dims, -): - input_dtypes, values, axis = dtype_values_axis - if isinstance(axis, tuple): - axis = axis[0] - - helpers.test_frontend_function( - a=values[0], - axis=axis, - out=None, - keepdims=keep_dims, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - input_dtypes=input_dtypes, - ) - - -# nanpercentile -@handle_frontend_test( - fn_tree="numpy.nanpercentile", - dtype_values_axis=_statistical_dtype_values(function="nanpercentile"), - where=np_frontend_helpers.where(), - keep_dims=st.booleans(), -) -def test_numpy_nanpercentile( - dtype_values_axis, - where, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, - keep_dims, -): - input_dtypes, values, axis = dtype_values_axis - if isinstance(axis, tuple): - axis = axis[0] - - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( - where=where, - input_dtype=input_dtypes, - test_flags=test_flags, - ) - - np_frontend_helpers.test_frontend_function( - a=values[0][0], - q=values[0][1], - axis=axis, - out=None, - backend_to_test=backend_fw, - overwrite_input=None, - method=None, - keepdims=keep_dims, - interpolation=None, - frontend=frontend, - fn_tree=fn_tree, - test_flags=test_flags, - input_dtypes=input_dtypes, - ) +# global +from hypothesis import strategies as st + + +# local +from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( + _statistical_dtype_values, +) +import ivy_tests.test_ivy.helpers as helpers +import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +# nanpercentile +@handle_frontend_test( + fn_tree="numpy.nanpercentile", + dtype_values_axis=_statistical_dtype_values(function="nanpercentile"), + where=np_frontend_helpers.where(), + keep_dims=st.booleans(), +) +def test_numpy_nanpercentile( + dtype_values_axis, + where, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, + keep_dims, +): + input_dtypes, 
values, axis = dtype_values_axis + if isinstance(axis, tuple): + axis = axis[0] + + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + + np_frontend_helpers.test_frontend_function( + a=values[0][0], + q=values[0][1], + axis=axis, + out=None, + backend_to_test=backend_fw, + overwrite_input=None, + method=None, + keepdims=keep_dims, + interpolation=None, + frontend=frontend, + fn_tree=fn_tree, + test_flags=test_flags, + input_dtypes=input_dtypes, + ) + + +# ptp +@handle_frontend_test( + fn_tree="numpy.ptp", + dtype_values_axis=_statistical_dtype_values(function="ptp"), + keep_dims=st.booleans(), + test_with_out=st.just(False), +) +def test_numpy_ptp( + dtype_values_axis, + frontend, + backend_fw, + test_flags, + fn_tree, + keep_dims, +): + input_dtypes, values, axis = dtype_values_axis + if isinstance(axis, tuple): + axis = axis[0] + + helpers.test_frontend_function( + a=values[0], + axis=axis, + out=None, + keepdims=keep_dims, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + test_flags=test_flags, + input_dtypes=input_dtypes, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ufunc/test_methods.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ufunc/test_methods.py index 8dae8a11d80d7..57415db4fce5f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ufunc/test_methods.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ufunc/test_methods.py @@ -9,59 +9,63 @@ ) +# --- Helpers --- # +# --------------- # + + # strategy to generate a ufunc from given list @st.composite def generate_ufunc(draw, ufuncs=ufuncs): return draw(st.sampled_from(ufuncs)) -# nargs +# identity @given( ufunc_name=generate_ufunc(), ) -def test_numpy_nargs( +def test_numpy_identity( ufunc_name, ): assume(hasattr(np_frontend, ufunc_name)) frontend_ufunc = getattr(np_frontend, ufunc_name) np_ufunc = getattr(np, ufunc_name) - assert frontend_ufunc.nargs == np_ufunc.nargs + assert frontend_ufunc.identity == np_ufunc.identity -# nin +# nargs @given( ufunc_name=generate_ufunc(), ) -def test_numpy_nin( +def test_numpy_nargs( ufunc_name, ): assume(hasattr(np_frontend, ufunc_name)) frontend_ufunc = getattr(np_frontend, ufunc_name) np_ufunc = getattr(np, ufunc_name) - assert frontend_ufunc.nin == np_ufunc.nin + assert frontend_ufunc.nargs == np_ufunc.nargs -# nout +# nin @given( ufunc_name=generate_ufunc(), ) -def test_numpy_nout( +def test_numpy_nin( ufunc_name, ): assume(hasattr(np_frontend, ufunc_name)) frontend_ufunc = getattr(np_frontend, ufunc_name) np_ufunc = getattr(np, ufunc_name) - assert frontend_ufunc.nout == np_ufunc.nout + assert frontend_ufunc.nin == np_ufunc.nin -# identity +# nout @given( ufunc_name=generate_ufunc(), ) -def test_numpy_identity( +def test_numpy_nout( ufunc_name, ): assume(hasattr(np_frontend, ufunc_name)) frontend_ufunc = getattr(np_frontend, ufunc_name) np_ufunc = getattr(np, ufunc_name) - assert frontend_ufunc.identity == np_ufunc.identity + assert frontend_ufunc.nout == np_ufunc.nout diff --git a/ivy_tests/test_ivy/test_frontends/test_onnx/test_elementwise.py b/ivy_tests/test_ivy/test_frontends/test_onnx/test_elementwise.py index ad1f388fa679c..4d4108730bd9c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_onnx/test_elementwise.py +++ b/ivy_tests/test_ivy/test_frontends/test_onnx/test_elementwise.py @@ -41,6 +41,33 @@ def test_onnx_abs( ) +@pytest.mark.skip("Testing pipeline not yet implemented") +@given( + 
dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric", prune_function=False), + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ) +) +def test_onnx_abs_v2(dtype_x): + _, data = dtype_x + x_onnx = onnx.Tensor(data[0]) + x_torch = torch.Tensor(data[0]) + + onnx_abs = onnx.abs(x_onnx) + torch_abs = torch.abs(x_torch) + + ret = helpers.flatten_and_to_np(ret=onnx_abs) + ret_gt = helpers.flatten_and_to_np(ret=torch_abs) + + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + ground_truth_backend="torch", + ) + + @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( fn_tree="onnx.acos", @@ -67,6 +94,30 @@ def test_onnx_acos( ) +@pytest.mark.skip("Testing pipeline not yet implemented") +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", prune_function=False), + ).filter(lambda x: "float16" not in x[0]), +) +def test_onnx_acos_v2(dtype_x): + _, data = dtype_x + x_onnx = onnx.Tensor(data[0]) + x_torch = torch.Tensor(data[0]) + + onnx_acos = onnx.acos(x_onnx) + torch_acos = torch.acos(x_torch) + + ret = helpers.flatten_and_to_np(ret=onnx_acos) + ret_gt = helpers.flatten_and_to_np(ret=torch_acos) + + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + ground_truth_backend="tensorflow", + ) + + @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( fn_tree="onnx.acosh", @@ -93,6 +144,30 @@ def test_onnx_acosh( ) +@pytest.mark.skip("Testing pipeline not yet implemented") +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", prune_function=False), + ).filter(lambda x: "float16" not in x[0]), +) +def test_onnx_acosh_v2(dtype_x): + _, data = dtype_x + x_onnx = onnx.Tensor(data[0]) + x_torch = torch.Tensor(data[0]) + + onnx_acosh = onnx.acosh(x_onnx) + torch_acosh = torch.acosh(x_torch) + + ret = helpers.flatten_and_to_np(ret=onnx_acosh) + ret_gt = helpers.flatten_and_to_np(ret=torch_acosh) + + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + ground_truth_backend="tensorflow", + ) + + @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( fn_tree="onnx.add", @@ -128,107 +203,6 @@ def test_onnx_add( ) -@pytest.mark.skip("Testing pipeline not yet implemented") -@handle_frontend_test( - fn_tree="onnx.asin", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_onnx_asin( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - ) - - -@pytest.mark.skip("Testing pipeline not yet implemented") -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric", prune_function=False), - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ) -) -def test_onnx_abs_v2(dtype_x): - _, data = dtype_x - x_onnx = onnx.Tensor(data[0]) - x_torch = torch.Tensor(data[0]) - - onnx_abs = onnx.abs(x_onnx) - torch_abs = torch.abs(x_torch) - - ret = helpers.flatten_and_to_np(ret=onnx_abs) - ret_gt = helpers.flatten_and_to_np(ret=torch_abs) - - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - ground_truth_backend="torch", - ) - - -@pytest.mark.skip("Testing pipeline not yet implemented") -@given( - 
dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", prune_function=False), - ).filter(lambda x: "float16" not in x[0]), -) -def test_onnx_acos_v2(dtype_x): - _, data = dtype_x - x_onnx = onnx.Tensor(data[0]) - x_torch = torch.Tensor(data[0]) - - onnx_acos = onnx.acos(x_onnx) - torch_acos = torch.acos(x_torch) - - ret = helpers.flatten_and_to_np(ret=onnx_acos) - ret_gt = helpers.flatten_and_to_np(ret=torch_acos) - - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - ground_truth_backend="tensorflow", - ) - - -@pytest.mark.skip("Testing pipeline not yet implemented") -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", prune_function=False), - ).filter(lambda x: "float16" not in x[0]), -) -def test_onnx_acosh_v2(dtype_x): - _, data = dtype_x - x_onnx = onnx.Tensor(data[0]) - x_torch = torch.Tensor(data[0]) - - onnx_acosh = onnx.acosh(x_onnx) - torch_acosh = torch.acosh(x_torch) - - ret = helpers.flatten_and_to_np(ret=onnx_acosh) - ret_gt = helpers.flatten_and_to_np(ret=torch_acosh) - - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - ground_truth_backend="tensorflow", - ) - - @pytest.mark.skip("Testing pipeline not yet implemented") @given( dtype_x=helpers.dtype_and_values( @@ -259,6 +233,32 @@ def test_onnx_add_v2(dtype_x): ) +@pytest.mark.skip("Testing pipeline not yet implemented") +@handle_frontend_test( + fn_tree="onnx.asin", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_onnx_asin( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ) + + @pytest.mark.skip("Testing pipeline not yet implemented") @given( dtype_x=helpers.dtype_and_values( diff --git a/ivy_tests/test_ivy/test_frontends/test_onnx/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_onnx/test_tensor.py index 71940c9ae36cf..a0971385a7b49 100644 --- a/ivy_tests/test_ivy/test_frontends/test_onnx/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_onnx/test_tensor.py @@ -14,18 +14,14 @@ available_dtypes=helpers.get_dtypes("valid", prune_function=False) ).filter(lambda x: "bfloat16" not in x[0]), ) -def test_onnx_tensor_property_ivy_array( +def test_onnx_tensor_property_device( dtype_x, ): _, data = dtype_x x = Tensor(data[0]) x.ivy_array = data[0] - ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend="torch") - ret_gt = helpers.flatten_and_to_np(ret=data[0], backend="torch") - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend="torch", + ivy.utils.assertions.check_equal( + x.device, ivy.dev(ivy.array(data[0])), as_array=False ) @@ -35,15 +31,13 @@ def test_onnx_tensor_property_ivy_array( available_dtypes=helpers.get_dtypes("valid", prune_function=False) ).filter(lambda x: "bfloat16" not in x[0]), ) -def test_onnx_tensor_property_device( +def test_onnx_tensor_property_dtype( dtype_x, ): - _, data = dtype_x + dtype, data = dtype_x x = Tensor(data[0]) x.ivy_array = data[0] - ivy.utils.assertions.check_equal( - x.device, ivy.dev(ivy.array(data[0])), as_array=False - ) + ivy.utils.assertions.check_equal(x.dtype, dtype[0], as_array=False) @pytest.mark.skip("Testing pipeline not yet implemented") @@ -52,13 +46,19 @@ def test_onnx_tensor_property_device( available_dtypes=helpers.get_dtypes("valid", 
prune_function=False) ).filter(lambda x: "bfloat16" not in x[0]), ) -def test_onnx_tensor_property_dtype( +def test_onnx_tensor_property_ivy_array( dtype_x, ): - dtype, data = dtype_x + _, data = dtype_x x = Tensor(data[0]) x.ivy_array = data[0] - ivy.utils.assertions.check_equal(x.dtype, dtype[0], as_array=False) + ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend="torch") + ret_gt = helpers.flatten_and_to_np(ret=data[0], backend="torch") + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend="torch", + ) @pytest.mark.skip("Testing pipeline not yet implemented") @@ -68,12 +68,12 @@ def test_onnx_tensor_property_dtype( ret_shape=True, ).filter(lambda x: "bfloat16" not in x[0]), ) -def test_onnx_tensor_property_shape(dtype_x): +def test_onnx_tensor_property_ndim( + dtype_x, +): dtype, data, shape = dtype_x x = Tensor(data[0]) - ivy.utils.assertions.check_equal( - x.ivy_array.shape, ivy.Shape(shape), as_array=False - ) + ivy.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) @pytest.mark.skip("Testing pipeline not yet implemented") @@ -83,9 +83,9 @@ def test_onnx_tensor_property_shape(dtype_x): ret_shape=True, ).filter(lambda x: "bfloat16" not in x[0]), ) -def test_onnx_tensor_property_ndim( - dtype_x, -): +def test_onnx_tensor_property_shape(dtype_x): dtype, data, shape = dtype_x x = Tensor(data[0]) - ivy.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) + ivy.utils.assertions.check_equal( + x.ivy_array.shape, ivy.Shape(shape), as_array=False + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py index 8cc081e98f9c9..d6c0779852f07 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py @@ -115,75 +115,75 @@ def test_paddle_ifft( @handle_frontend_test( - fn_tree="paddle.fft.irfft", + fn_tree="paddle.fft.ifftshift", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), min_value=-10, max_value=10, min_num_dims=1, - min_dim_size=2, valid_axis=True, force_int_axis=True, ), - n=st.one_of( - st.integers(min_value=2, max_value=10), - st.just(None), - ), - norm=st.sampled_from(["backward", "ortho", "forward"]), ) -def test_paddle_irfft( +def test_paddle_ifftshift( dtype_x_axis, - n, - norm, frontend, test_flags, fn_tree, + on_device, backend_fw, ): - input_dtypes, x, axis = dtype_x_axis + input_dtype, x, axes = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, + test_values=True, x=x[0], - n=n, - axis=axis, - norm=norm, - valid_axis=True, - force_int_axis=True, + axes=axes, ) @handle_frontend_test( - fn_tree="paddle.fft.ifftshift", + fn_tree="paddle.fft.irfft", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), min_value=-10, max_value=10, min_num_dims=1, + min_dim_size=2, valid_axis=True, force_int_axis=True, ), + n=st.one_of( + st.integers(min_value=2, max_value=10), + st.just(None), + ), + norm=st.sampled_from(["backward", "ortho", "forward"]), ) -def test_paddle_ifftshift( +def test_paddle_irfft( dtype_x_axis, + n, + norm, frontend, test_flags, fn_tree, - on_device, backend_fw, ): - input_dtype, x, axes = dtype_x_axis + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, 
backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - test_values=True, x=x[0], - axes=axes, + n=n, + axis=axis, + norm=norm, + valid_axis=True, + force_int_axis=True, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py index 3923891f557cb..dca9cef8c70a9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py @@ -7,89 +7,111 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# selu +# --- Helpers --- # +# --------------- # + + +@st.composite +def _generate_prelu_arrays(draw): + arr_size = draw(helpers.ints(min_value=2, max_value=5)) + dtype = draw(helpers.get_dtypes("float", index=1, full=False)) + input = draw( + helpers.array_values( + dtype=dtype[0], shape=(arr_size), min_value=0, max_value=10 + ) + ) + weight = draw( + helpers.array_values(dtype=dtype[0], shape=(1,), min_value=0, max_value=1.0) + ) + input_weight = input, weight + return dtype, input_weight + + +# --- Main --- # +# ------------ # + + +# celu @handle_frontend_test( - fn_tree="paddle.nn.functional.selu", + fn_tree="paddle.nn.functional.celu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - safety_factor_scale="log", - small_abs_safety_factor=20, ), - scale=helpers.ints(min_value=2, max_value=10), alpha=helpers.ints(min_value=1, max_value=10), ) -def test_paddle_selu( +def test_paddle_celu( *, dtype_and_x, - scale, alpha, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], alpha=alpha, - scale=scale, ) -# hardshrink +# elu @handle_frontend_test( - fn_tree="paddle.nn.functional.hardshrink", - dtype_and_x=helpers.dtype_and_values( + fn_tree="paddle.nn.functional.elu", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), - threshold=helpers.floats(min_value=0, max_value=1, exclude_min=True), + alpha=helpers.floats(min_value=0, max_value=1, exclude_min=True), ) -def test_paddle_hardshrink( +def test_paddle_elu( *, - dtype_and_x, - threshold, + dtype_and_input, + alpha, on_device, fn_tree, + backend_fw, frontend, test_flags, - backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - threshold=threshold, + alpha=alpha, ) -# hardswish +# gelu @handle_frontend_test( - fn_tree="paddle.nn.functional.hardswish", + fn_tree="paddle.nn.functional.gelu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), safety_factor_scale="log", + small_abs_safety_factor=20, ), + approximate=st.booleans(), ) -def test_paddle_hardswish( +def test_paddle_gelu( *, dtype_and_x, + approximate, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -99,30 +121,42 @@ def test_paddle_hardswish( 
test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-2, + atol=1e-2, x=x[0], + approximate=approximate, ) -# hardtanh +# gumbel_softmax @handle_frontend_test( - fn_tree="paddle.nn.functional.hardtanh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="paddle.nn.functional.gumbel_softmax", + dtype_x_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_axes_size=1, + force_int_axis=True, + valid_axis=True, + min_value=-30.0, + max_value=30.0, ), - max_val=helpers.floats(min_value=0, max_value=1, exclude_min=True), + dtypes=helpers.get_dtypes("float", none=False, full=False), + temperature=st.floats(min_value=1e-3, max_value=10), + hard=st.booleans(), ) -def test_paddle_hardtanh( +def test_paddle_gumbel_softmax( *, - dtype_and_x, - max_val, + dtype_x_and_axis, + dtypes, + temperature, + hard, on_device, + backend_fw, fn_tree, frontend, test_flags, - backend_fw, ): - input_dtype, x = dtype_and_x - max_min = max_val, -max_val + input_dtype, x, axis = dtype_x_and_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -131,25 +165,25 @@ def test_paddle_hardtanh( fn_tree=fn_tree, on_device=on_device, x=x[0], - min=max_min[1], - max=max_min[0], + axis=axis, + dtype=ivy.as_ivy_dtype(dtypes[0]), + temperature=temperature, + hard=hard, ) -# gelu +# hardshrink @handle_frontend_test( - fn_tree="paddle.nn.functional.gelu", + fn_tree="paddle.nn.functional.hardshrink", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - safety_factor_scale="log", - small_abs_safety_factor=20, ), - approximate=st.booleans(), + threshold=helpers.floats(min_value=0, max_value=1, exclude_min=True), ) -def test_paddle_gelu( +def test_paddle_hardshrink( *, dtype_and_x, - approximate, + threshold, on_device, fn_tree, frontend, @@ -164,10 +198,8 @@ def test_paddle_gelu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, x=x[0], - approximate=approximate, + threshold=threshold, ) @@ -205,21 +237,22 @@ def test_paddle_hardsigmoid( ) -# relu6 +# hardswish @handle_frontend_test( - fn_tree="paddle.nn.functional.relu6", + fn_tree="paddle.nn.functional.hardswish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", ), ) -def test_paddle_relu6( +def test_paddle_hardswish( *, dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -233,25 +266,26 @@ def test_paddle_relu6( ) -# softshrink +# hardtanh @handle_frontend_test( - fn_tree="paddle.nn.functional.softshrink", - dtype_and_input=helpers.dtype_and_values( + fn_tree="paddle.nn.functional.hardtanh", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), - threshold=helpers.floats(min_value=0, max_value=1, exclude_min=True), + max_val=helpers.floats(min_value=0, max_value=1, exclude_min=True), ) -def test_paddle_softshrink( +def test_paddle_hardtanh( *, - dtype_and_input, - threshold, + dtype_and_x, + max_val, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x + max_min = max_val, -max_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -260,27 +294,58 @@ def test_paddle_softshrink( fn_tree=fn_tree, 
on_device=on_device, x=x[0], - threshold=threshold, + min=max_min[1], + max=max_min[0], ) -# softsign @handle_frontend_test( - fn_tree="paddle.nn.functional.softsign", + fn_tree="paddle.nn.functional.leaky_relu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - safety_factor_scale="log", - small_abs_safety_factor=20, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_softsign( +def test_paddle_leaky_relu( *, dtype_and_x, on_device, + backend_fw, fn_tree, frontend, test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + negative_slope=0.01, + x=x[0], + ) + + +# log_sigmoid +@handle_frontend_test( + fn_tree="paddle.nn.functional.log_sigmoid", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=3, + small_abs_safety_factor=3, + safety_factor_scale="linear", + ), + test_with_out=st.just(False), +) +def test_paddle_log_sigmoid( + *, + dtype_and_x, + frontend, backend_fw, + test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -333,116 +398,71 @@ def test_paddle_log_softmax( ) -@st.composite -def _generate_prelu_arrays(draw): - arr_size = draw(helpers.ints(min_value=2, max_value=5)) - dtype = draw(helpers.get_dtypes("float", index=1, full=False)) - input = draw( - helpers.array_values( - dtype=dtype[0], shape=(arr_size), min_value=0, max_value=10 - ) - ) - weight = draw( - helpers.array_values(dtype=dtype[0], shape=(1,), min_value=0, max_value=1.0) - ) - input_weight = input, weight - return dtype, input_weight - - -# prelu -@handle_frontend_test( - fn_tree="paddle.nn.functional.prelu", - dtype_input_and_weight=_generate_prelu_arrays(), -) -def test_paddle_prelu( - *, - dtype_input_and_weight, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x = dtype_input_and_weight - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - weight=x[1], - ) - - -# celu +# mish @handle_frontend_test( - fn_tree="paddle.nn.functional.celu", - dtype_and_x=helpers.dtype_and_values( + fn_tree="paddle.nn.functional.mish", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + safety_factor_scale="log", + small_abs_safety_factor=20, ), - alpha=helpers.ints(min_value=1, max_value=10), ) -def test_paddle_celu( +def test_paddle_mish( *, - dtype_and_x, - alpha, + dtype_and_input, on_device, + backend_fw, fn_tree, frontend, - backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( - backend_to_test=backend_fw, input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - alpha=alpha, ) +# prelu @handle_frontend_test( - fn_tree="paddle.nn.functional.rrelu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="paddle.nn.functional.prelu", + dtype_input_and_weight=_generate_prelu_arrays(), ) -def test_paddle_rrelu( +def test_paddle_prelu( *, - dtype_and_x, + dtype_input_and_weight, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, x = 
dtype_input_and_weight helpers.test_frontend_function( - backend_to_test=backend_fw, - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, x=x[0], + weight=x[1], ) -# tanhshrink +# relu6 @handle_frontend_test( - fn_tree="paddle.nn.functional.tanhshrink", + fn_tree="paddle.nn.functional.relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tanhshrink( +def test_paddle_relu6( *, dtype_and_x, on_device, @@ -459,7 +479,6 @@ def test_paddle_tanhshrink( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, x=x[0], ) @@ -489,56 +508,57 @@ def test_paddle_relu_( ) -# elu @handle_frontend_test( - fn_tree="paddle.nn.functional.elu", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="paddle.nn.functional.rrelu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), - alpha=helpers.floats(min_value=0, max_value=1, exclude_min=True), ) -def test_paddle_elu( +def test_paddle_rrelu( *, - dtype_and_input, - alpha, + dtype_and_x, on_device, fn_tree, - backend_fw, frontend, test_flags, + backend_fw, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( + backend_to_test=backend_fw, input_dtypes=input_dtype, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + test_values=False, x=x[0], - alpha=alpha, ) -# mish +# selu @handle_frontend_test( - fn_tree="paddle.nn.functional.mish", - dtype_and_input=helpers.dtype_and_values( + fn_tree="paddle.nn.functional.selu", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), safety_factor_scale="log", small_abs_safety_factor=20, ), + scale=helpers.ints(min_value=2, max_value=10), + alpha=helpers.ints(min_value=1, max_value=10), ) -def test_paddle_mish( +def test_paddle_selu( *, - dtype_and_input, + dtype_and_x, + scale, + alpha, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -547,25 +567,28 @@ def test_paddle_mish( fn_tree=fn_tree, on_device=on_device, x=x[0], + alpha=alpha, + scale=scale, ) +# silu @handle_frontend_test( - fn_tree="paddle.nn.functional.leaky_relu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="paddle.nn.functional.silu", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_leaky_relu( +def test_paddle_silu( *, - dtype_and_x, + dtype_and_input, on_device, backend_fw, fn_tree, frontend, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -573,32 +596,35 @@ def test_paddle_leaky_relu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - negative_slope=0.01, x=x[0], ) -# log_sigmoid +# softmax_ @handle_frontend_test( - fn_tree="paddle.nn.functional.log_sigmoid", - dtype_and_x=helpers.dtype_and_values( + fn_tree="paddle.nn.functional.softmax_", + dtype_x_and_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), - 
large_abs_safety_factor=3, - small_abs_safety_factor=3, - safety_factor_scale="linear", + min_num_dims=1, + max_axes_size=1, + force_int_axis=True, + valid_axis=True, + min_value=-30.0, + max_value=30.0, ), - test_with_out=st.just(False), + dtypes=helpers.get_dtypes("float", none=False, full=False), ) -def test_paddle_log_sigmoid( +def test_paddle_softmax_( *, - dtype_and_x, - frontend, + dtype_x_and_axis, + dtypes, + on_device, backend_fw, - test_flags, fn_tree, - on_device, + frontend, + test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_and_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -607,19 +633,27 @@ def test_paddle_log_sigmoid( fn_tree=fn_tree, on_device=on_device, x=x[0], + axis=axis, + dtype=ivy.as_ivy_dtype(dtypes[0]), ) -# silu +# softplus @handle_frontend_test( - fn_tree="paddle.nn.functional.silu", + fn_tree="paddle.nn.functional.softplus", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), + beta=st.floats(min_value=1e-3, max_value=10), # strategy for the beta argument + threshold=st.floats( + min_value=1e-3, max_value=10 + ), # strategy for the threshold argument ) -def test_paddle_silu( +def test_paddle_softplus( *, dtype_and_input, + beta, + threshold, on_device, backend_fw, fn_tree, @@ -635,30 +669,28 @@ def test_paddle_silu( fn_tree=fn_tree, on_device=on_device, x=x[0], + beta=beta, + threshold=threshold, ) -# softplus +# softshrink @handle_frontend_test( - fn_tree="paddle.nn.functional.softplus", + fn_tree="paddle.nn.functional.softshrink", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), - beta=st.floats(min_value=1e-3, max_value=10), # strategy for the beta argument - threshold=st.floats( - min_value=1e-3, max_value=10 - ), # strategy for the threshold argument + threshold=helpers.floats(min_value=0, max_value=1, exclude_min=True), ) -def test_paddle_softplus( +def test_paddle_softshrink( *, dtype_and_input, - beta, threshold, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): input_dtype, x = dtype_and_input helpers.test_frontend_function( @@ -669,36 +701,29 @@ def test_paddle_softplus( fn_tree=fn_tree, on_device=on_device, x=x[0], - beta=beta, threshold=threshold, ) -# softmax_ +# softsign @handle_frontend_test( - fn_tree="paddle.nn.functional.softmax_", - dtype_x_and_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, - min_value=-30.0, - max_value=30.0, + fn_tree="paddle.nn.functional.softsign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + safety_factor_scale="log", + small_abs_safety_factor=20, ), - dtypes=helpers.get_dtypes("float", none=False, full=False), ) -def test_paddle_softmax_( +def test_paddle_softsign( *, - dtype_x_and_axis, - dtypes, + dtype_and_x, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): - input_dtype, x, axis = dtype_x_and_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -707,8 +732,6 @@ def test_paddle_softmax_( fn_tree=fn_tree, on_device=on_device, x=x[0], - axis=axis, - dtype=ivy.as_ivy_dtype(dtypes[0]), ) @@ -740,35 +763,23 @@ def test_paddle_tanh_( ) -# gumbel_softmax +# tanhshrink @handle_frontend_test( - fn_tree="paddle.nn.functional.gumbel_softmax", - dtype_x_and_axis=helpers.dtype_values_axis( + 
fn_tree="paddle.nn.functional.tanhshrink", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, - min_value=-30.0, - max_value=30.0, ), - dtypes=helpers.get_dtypes("float", none=False, full=False), - temperature=st.floats(min_value=1e-3, max_value=10), - hard=st.booleans(), ) -def test_paddle_gumbel_softmax( +def test_paddle_tanhshrink( *, - dtype_x_and_axis, - dtypes, - temperature, - hard, + dtype_and_x, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): - input_dtype, x, axis = dtype_x_and_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -776,9 +787,6 @@ def test_paddle_gumbel_softmax( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + test_values=False, x=x[0], - axis=axis, - dtype=ivy.as_ivy_dtype(dtypes[0]), - temperature=temperature, - hard=hard, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py index 70dfac25f65a4..d6b095a9536ee 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py @@ -10,194 +10,8 @@ ) -# Cosine Similarity -@handle_frontend_test( - fn_tree="paddle.nn.functional.common.cosine_similarity", - d_type_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=2, - max_value=5, - min_dim_size=2, - shape=(4, 4), - ), - axis=st.integers(min_value=-1, max_value=1), -) -def test_paddle_cosine_similarity( - *, - d_type_and_x, - axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x = d_type_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-01, - x1=x[0], - x2=x[1], - axis=axis, - ) - - -# Dropout2d -@handle_frontend_test( - fn_tree="paddle.nn.functional.common.dropout2d", - d_type_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - shared_dtype=True, - min_value=2, - max_value=5, - min_dim_size=4, - shape=( - st.integers(min_value=2, max_value=10), - 4, - st.integers(min_value=12, max_value=64), - st.integers(min_value=12, max_value=64), - ), - ), - p=st.floats(min_value=0.0, max_value=1.0), - training=st.booleans(), - data_format=st.sampled_from(["NCHW", "NHWC"]), -) -def test_paddle_dropout2d( - *, - d_type_and_x, - p, - training, - data_format, - backend_fw, - on_device, - fn_tree, - frontend, - test_flags, -): - dtype, x = d_type_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - p=p, - training=training, - data_format=data_format, - ) - - -# dropout -@handle_frontend_test( - fn_tree="paddle.nn.functional.common.dropout", - d_type_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - shared_dtype=True, - min_value=2, - max_value=5, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - ), - p=st.floats(min_value=0.0, max_value=1.0), - axis=st.integers(min_value=0, max_value=1), - training=st.booleans(), - mode=st.one_of( - 
*[st.just(seq) for seq in ["upscale_in_train", "downscale_in_infer"]] - ), -) -def test_paddle_dropout( - *, - d_type_and_x, - p, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, - training, - axis, - mode, -): - dtype, x = d_type_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - p=p, - frontend=frontend, - backend_to_test=backend_fw, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - x=x[0], - training=training, - axis=axis, - mode=mode, - ) - - -# zeropad2d -@st.composite -def _zero2pad(draw): - dtype, input, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ret_shape=True, - min_num_dims=4, - max_num_dims=4, - min_value=-100, - max_value=100, - ) - ) - ndim = len(shape) - min_dim = min(shape) - padding = draw( - st.lists( - st.integers(min_value=0, max_value=min_dim), - min_size=ndim, - max_size=ndim, - ) - ) - return dtype, input, padding - - -@handle_frontend_test( - fn_tree="paddle.nn.functional.common.zeropad2d", - d_type_and_x_paddings=_zero2pad(), - dataformat=st.sampled_from(["NCHW", "NHWC"]), -) -def test_paddle_zeropad2d( - *, - d_type_and_x_paddings, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, - dataformat, -): - dtype, x, padding = d_type_and_x_paddings - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - padding=padding, - data_format=dataformat, - ) +# --- Helpers --- # +# --------------- # # interpolate @@ -350,61 +164,212 @@ def _interp_args(draw, mode=None, mode_list=None): return (dtype, x, mode, size, align_corners, scale_factor, recompute_scale_factor) +# zeropad2d +@st.composite +def _zero2pad(draw): + dtype, input, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ret_shape=True, + min_num_dims=4, + max_num_dims=4, + min_value=-100, + max_value=100, + ) + ) + ndim = len(shape) + min_dim = min(shape) + padding = draw( + st.lists( + st.integers(min_value=0, max_value=min_dim), + min_size=ndim, + max_size=ndim, + ) + ) + return dtype, input, padding + + +@st.composite +def paddle_unfold_handler(draw, dtype): + dtype = draw(dtype) + h_size = draw(helpers.ints(min_value=10, max_value=30)) + w_size = draw(helpers.ints(min_value=10, max_value=30)) + channels = draw(helpers.ints(min_value=1, max_value=3)) + batch = draw(helpers.ints(min_value=1, max_value=10)) + + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=[batch, channels, h_size, w_size], + min_value=0, + max_value=1, + ) + ) + + kernel_sizes = draw(helpers.ints(min_value=1, max_value=3)) + strides = draw(helpers.ints(min_value=1, max_value=3)) + paddings = draw(helpers.ints(min_value=1, max_value=3)) + dilations = draw(helpers.ints(min_value=1, max_value=3)) + return dtype, x, kernel_sizes, strides, paddings, dilations + + +# --- Main --- # +# ------------ # + + +# linear +@handle_frontend_test( + fn_tree="paddle.nn.functional.common.linear", + dtype_x_weight_bias=x_and_linear( + dtypes=helpers.get_dtypes("valid", full=False), + ), +) +def test_linear( + *, + dtype_x_weight_bias, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + dtype, x, weight, bias = dtype_x_weight_bias + weight = ivy.swapaxes(weight, -1, -2) + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x, + weight=weight, 
+ bias=bias, + ) + + +# Cosine Similarity @handle_frontend_test( - fn_tree="paddle.nn.functional.common.interpolate", - dtype_x_mode=_interp_args(), + fn_tree="paddle.nn.functional.common.cosine_similarity", + d_type_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_value=2, + max_value=5, + min_dim_size=2, + shape=(4, 4), + ), + axis=st.integers(min_value=-1, max_value=1), ) -def test_paddle_interpolate( - dtype_x_mode, +def test_paddle_cosine_similarity( + *, + d_type_and_x, + axis, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): - ( - input_dtype, - x, - mode, - size, - align_corners, - scale_factor, - recompute_scale_factor, - ) = dtype_x_mode + dtype, x = d_type_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-01, + x1=x[0], + x2=x[1], + axis=axis, + ) + +# dropout +@handle_frontend_test( + fn_tree="paddle.nn.functional.common.dropout", + d_type_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + shared_dtype=True, + min_value=2, + max_value=5, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + ), + p=st.floats(min_value=0.0, max_value=1.0), + axis=st.integers(min_value=0, max_value=1), + training=st.booleans(), + mode=st.one_of( + *[st.just(seq) for seq in ["upscale_in_train", "downscale_in_infer"]] + ), +) +def test_paddle_dropout( + *, + d_type_and_x, + p, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, + training, + axis, + mode, +): + dtype, x = d_type_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, + p=p, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], - size=size, - scale_factor=scale_factor, + training=training, + axis=axis, mode=mode, - align_corners=align_corners, ) -# linear +# Dropout2d @handle_frontend_test( - fn_tree="paddle.nn.functional.common.linear", - dtype_x_weight_bias=x_and_linear( - dtypes=helpers.get_dtypes("valid", full=False), + fn_tree="paddle.nn.functional.common.dropout2d", + d_type_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + shared_dtype=True, + min_value=2, + max_value=5, + min_dim_size=4, + shape=( + st.integers(min_value=2, max_value=10), + 4, + st.integers(min_value=12, max_value=64), + st.integers(min_value=12, max_value=64), + ), ), + p=st.floats(min_value=0.0, max_value=1.0), + training=st.booleans(), + data_format=st.sampled_from(["NCHW", "NHWC"]), ) -def test_linear( +def test_paddle_dropout2d( *, - dtype_x_weight_bias, + d_type_and_x, + p, + training, + data_format, + backend_fw, on_device, fn_tree, - backend_fw, frontend, test_flags, ): - dtype, x, weight, bias = dtype_x_weight_bias - weight = ivy.swapaxes(weight, -1, -2) + dtype, x = d_type_and_x helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -412,9 +377,10 @@ def test_linear( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x, - weight=weight, - bias=bias, + x=x[0], + p=p, + training=training, + data_format=data_format, ) @@ -457,28 +423,70 @@ def test_paddle_dropout3d( ) -@st.composite -def paddle_unfold_handler(draw, dtype): - dtype = draw(dtype) - h_size = draw(helpers.ints(min_value=10, max_value=30)) - w_size = draw(helpers.ints(min_value=10, max_value=30)) - 
channels = draw(helpers.ints(min_value=1, max_value=3)) - batch = draw(helpers.ints(min_value=1, max_value=10)) +@handle_frontend_test( + fn_tree="paddle.nn.functional.common.interpolate", + dtype_x_mode=_interp_args(), +) +def test_paddle_interpolate( + dtype_x_mode, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + ( + input_dtype, + x, + mode, + size, + align_corners, + scale_factor, + recompute_scale_factor, + ) = dtype_x_mode - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=[batch, channels, h_size, w_size], - min_value=0, - max_value=1, - ) + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, ) - kernel_sizes = draw(helpers.ints(min_value=1, max_value=3)) - strides = draw(helpers.ints(min_value=1, max_value=3)) - paddings = draw(helpers.ints(min_value=1, max_value=3)) - dilations = draw(helpers.ints(min_value=1, max_value=3)) - return dtype, x, kernel_sizes, strides, paddings, dilations + +@handle_frontend_test( + fn_tree="paddle.nn.functional.common.zeropad2d", + d_type_and_x_paddings=_zero2pad(), + dataformat=st.sampled_from(["NCHW", "NHWC"]), +) +def test_paddle_zeropad2d( + *, + d_type_and_x_paddings, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, + dataformat, +): + dtype, x, padding = d_type_and_x_paddings + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + padding=padding, + data_format=dataformat, + ) @handle_frontend_test( diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_conv.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_conv.py index 34154d3e1c943..8d08fa2f37c85 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_conv.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_conv.py @@ -45,12 +45,12 @@ def test_paddle_conv1d( ) -# conv2d +# conv1d_transpose @handle_frontend_test( - fn_tree="paddle.nn.functional.conv2d", - dtype_vals=x_and_filters(dim=2), + fn_tree="paddle.nn.functional.conv1d_transpose", + dtype_vals=x_and_filters(dim=1, transpose=True), ) -def test_paddle_conv2d( +def test_paddle_conv1d_tranpose( *, dtype_vals, on_device, @@ -59,7 +59,16 @@ def test_paddle_conv2d( test_flags, backend_fw, ): - dtype, vals, weight, bias, dilations, strides, padding, fc = dtype_vals + dtype, vals, weight, bias, dilations, strides, padding, output_pad, fc = dtype_vals + dilations = 1 # ToDo: remove this when support for dilation > 1 is added + assume( + all( + x > 0 + for x in _output_shape( + 1, dilations, strides, padding, output_pad, vals.shape, weight.shape + ) + ) + ) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -72,17 +81,18 @@ def test_paddle_conv2d( bias=bias, stride=strides, padding=padding, - dilation=dilations, + output_padding=output_pad, groups=fc, + dilation=dilations, ) -# conv3d +# conv2d @handle_frontend_test( - fn_tree="paddle.nn.functional.conv3d", - dtype_vals=x_and_filters(dim=3), + fn_tree="paddle.nn.functional.conv2d", + dtype_vals=x_and_filters(dim=2), ) -def test_paddle_conv3d( +def test_paddle_conv2d( *, dtype_vals, on_device, @@ -92,8 +102,6 @@ def test_paddle_conv3d( backend_fw, 
): dtype, vals, weight, bias, dilations, strides, padding, fc = dtype_vals - # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it. - _assume_tf_dilation_gt_1(backend_fw, on_device, dilations) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -111,12 +119,12 @@ def test_paddle_conv3d( ) -# conv1d_transpose +# conv2d_transpose @handle_frontend_test( - fn_tree="paddle.nn.functional.conv1d_transpose", - dtype_vals=x_and_filters(dim=1, transpose=True), + fn_tree="paddle.nn.functional.conv2d_transpose", + dtype_vals=x_and_filters(dim=2, transpose=True), ) -def test_paddle_conv1d_tranpose( +def test_paddle_conv2d_tranpose( *, dtype_vals, on_device, @@ -131,7 +139,7 @@ def test_paddle_conv1d_tranpose( all( x > 0 for x in _output_shape( - 1, dilations, strides, padding, output_pad, vals.shape, weight.shape + 2, dilations, strides, padding, output_pad, vals.shape, weight.shape ) ) ) @@ -148,17 +156,17 @@ def test_paddle_conv1d_tranpose( stride=strides, padding=padding, output_padding=output_pad, - groups=fc, dilation=dilations, + groups=fc, ) -# conv2d_transpose +# conv3d @handle_frontend_test( - fn_tree="paddle.nn.functional.conv2d_transpose", - dtype_vals=x_and_filters(dim=2, transpose=True), + fn_tree="paddle.nn.functional.conv3d", + dtype_vals=x_and_filters(dim=3), ) -def test_paddle_conv2d_tranpose( +def test_paddle_conv3d( *, dtype_vals, on_device, @@ -167,16 +175,9 @@ def test_paddle_conv2d_tranpose( test_flags, backend_fw, ): - dtype, vals, weight, bias, dilations, strides, padding, output_pad, fc = dtype_vals - dilations = 1 # ToDo: remove this when support for dilation > 1 is added - assume( - all( - x > 0 - for x in _output_shape( - 2, dilations, strides, padding, output_pad, vals.shape, weight.shape - ) - ) - ) + dtype, vals, weight, bias, dilations, strides, padding, fc = dtype_vals + # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it. 
+ _assume_tf_dilation_gt_1(backend_fw, on_device, dilations) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -189,7 +190,6 @@ def test_paddle_conv2d_tranpose( bias=bias, stride=strides, padding=padding, - output_padding=output_pad, dilation=dilations, groups=fc, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_distance.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_distance.py index 0a9c8754f4fed..8b137891791fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_distance.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_distance.py @@ -1,3 +1 @@ -# global -# local diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_extension.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_extension.py index 0a9c8754f4fed..8b137891791fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_extension.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_extension.py @@ -1,3 +1 @@ -# global -# local diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_input.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_input.py index 0a9c8754f4fed..8b137891791fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_input.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_input.py @@ -1,3 +1 @@ -# global -# local diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py index 39ac775309039..7883b88d30559 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_loss.py @@ -7,6 +7,39 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +# cosine embedding loss +@st.composite +def _cos_embd_loss_helper(draw): + dtype_inputs_shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=2, + min_dim_size=2, + ret_shape=True, + num_arrays=2, + ) + ) + + input_dtypes, inputs, shape = dtype_inputs_shape + + _, label = draw( + helpers.dtype_and_values( + dtype=input_dtypes, shape=(shape[0],), min_value=-1, max_value=1 + ), + ) + + return input_dtypes, inputs, label + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="paddle.nn.functional.binary_cross_entropy_with_logits", dtype_and_x=helpers.dtype_and_values( @@ -55,68 +88,6 @@ def test_paddle_binary_cross_entropy_with_logits( ) -# mse_loss -@handle_frontend_test( - fn_tree="paddle.nn.functional.mse_loss", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ), - reduction=st.sampled_from(["mean", "none", "sum"]), -) -def test_paddle_mse_loss( - dtype_and_x, - reduction, - on_device, - backend_fw, - fn_tree, - frontend, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - 
test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - label=x[1], - reduction=reduction, - ) - - -# cosine embedding loss -@st.composite -def _cos_embd_loss_helper(draw): - dtype_inputs_shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=2, - min_dim_size=2, - ret_shape=True, - num_arrays=2, - ) - ) - - input_dtypes, inputs, shape = dtype_inputs_shape - - _, label = draw( - helpers.dtype_and_values( - dtype=input_dtypes, shape=(shape[0],), min_value=-1, max_value=1 - ), - ) - - return input_dtypes, inputs, label - - @handle_frontend_test( fn_tree="paddle.nn.functional.cosine_embedding_loss", dtype_xs_label=_cos_embd_loss_helper(), @@ -159,70 +130,84 @@ def test_paddle_cosine_embedding_loss( @handle_frontend_test( - fn_tree="paddle.nn.functional.hinge_embedding_loss", + fn_tree="paddle.nn.functional.dice_loss", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + shared_dtype=False, + min_num_dims=3, + min_dim_size=3, + max_num_dims=3, + max_dim_size=3, ), - margin=st.floats( - min_value=-1.0, - max_value=1.0, - width=16, + labels=st.lists( + ( + st.lists( + ( + st.lists( + st.integers(min_value=0, max_value=1), min_size=3, max_size=3 + ) + ), + min_size=3, + max_size=3, + ) + ), + min_size=1, + max_size=1, + ), + epsilon=st.floats( + min_value=1e-6, + max_value=1e-2, ), - reduction=st.sampled_from(["none", "mean", "sum"]), ) -def test_paddle_hinge_embedding_loss( +def test_paddle_dice_loss( dtype_and_x, - margin, - reduction, - test_flags, - backend_fw, + labels, + epsilon, + on_device, fn_tree, frontend, - on_device, + test_flags, + backend_fw, ): - input_dtype, x = dtype_and_x + x_dtype, x = dtype_and_x + x[0] = x[0].reshape([3, 3, 3]) + labels = ivy.array(labels, dtype=ivy.int64) + labels = labels.reshape([3, 3, 1]) helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, + input_dtypes=[ivy.int64] + [ivy.float64] + x_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - label=x[1], - margin=margin, - reduction=reduction, + label=labels, + epsilon=epsilon, ) -# log_loss @handle_frontend_test( - fn_tree="paddle.nn.functional.log_loss", + fn_tree="paddle.nn.functional.hinge_embedding_loss", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, - min_value=0, - max_value=1, - exclude_min=True, - exclude_max=True, shared_dtype=True, - min_num_dims=2, - max_num_dims=2, - max_dim_size=1, ), - epsilon=st.floats( - min_value=1e-7, + margin=st.floats( + min_value=-1.0, max_value=1.0, + width=16, ), + reduction=st.sampled_from(["none", "mean", "sum"]), ) -def test_paddle_log_loss( +def test_paddle_hinge_embedding_loss( dtype_and_x, - epsilon, - fn_tree, + margin, + reduction, test_flags, backend_fw, + fn_tree, frontend, on_device, ): @@ -236,50 +221,39 @@ def test_paddle_log_loss( on_device=on_device, input=x[0], label=x[1], - epsilon=epsilon, + margin=margin, + reduction=reduction, ) -# smooth_l1_loss @handle_frontend_test( - fn_tree="paddle.nn.functional.smooth_l1_loss", + fn_tree="paddle.nn.functional.kl_div", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, min_num_dims=2, max_num_dims=5, min_dim_size=1, 
max_dim_size=10, + min_value=1.0013580322265625e-05, ), - delta=st.floats( - min_value=0.1, - max_value=1.0, - ), - reduction=st.sampled_from(["mean", "sum", "none"]), + reduction=st.sampled_from(["mean", "batchmean", "sum", "none"]), ) -def test_paddle_smooth_l1_loss( - dtype_and_x, - delta, - reduction, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, +def test_paddle_kl_div( + dtype_and_x, reduction, on_device, backend_fw, fn_tree, frontend, test_flags ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], label=x[1], reduction=reduction, - delta=delta, ) @@ -319,34 +293,46 @@ def test_paddle_l1_loss( ) +# log_loss @handle_frontend_test( - fn_tree="paddle.nn.functional.kl_div", + fn_tree="paddle.nn.functional.log_loss", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, + min_value=0, + max_value=1, + exclude_min=True, + exclude_max=True, shared_dtype=True, min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - min_value=1.0013580322265625e-05, + max_num_dims=2, + max_dim_size=1, + ), + epsilon=st.floats( + min_value=1e-7, + max_value=1.0, ), - reduction=st.sampled_from(["mean", "batchmean", "sum", "none"]), ) -def test_paddle_kl_div( - dtype_and_x, reduction, on_device, backend_fw, fn_tree, frontend, test_flags +def test_paddle_log_loss( + dtype_and_x, + epsilon, + fn_tree, + test_flags, + backend_fw, + frontend, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], label=x[1], - reduction=reduction, + epsilon=epsilon, ) @@ -394,51 +380,39 @@ def test_paddle_margin_ranking_loss( ) +# mse_loss @handle_frontend_test( - fn_tree="paddle.nn.functional.triplet_margin_loss", - dtype_and_inputs=helpers.dtype_and_values( + fn_tree="paddle.nn.functional.mse_loss", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, - allow_inf=False, + num_arrays=2, shared_dtype=True, - min_value=0.0, - max_value=1.0, min_num_dims=1, - max_num_dims=2, - min_dim_size=1, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), - margin=st.floats(min_value=1e-6, max_value=1e6), - p=st.integers(min_value=0, max_value=2), - swap=st.booleans(), - reduction=st.sampled_from(["none", "mean", "sum"]), - test_with_out=st.just(False), + reduction=st.sampled_from(["mean", "none", "sum"]), ) -def test_paddle_triplet_margin_loss( - dtype_and_inputs, - margin, - p, - swap, +def test_paddle_mse_loss( + dtype_and_x, reduction, - test_flags, - fn_tree, + on_device, backend_fw, + fn_tree, frontend, - on_device, + test_flags, ): - input_dtype, x = dtype_and_inputs + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[input_dtype[0], input_dtype[1], input_dtype[2]], - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - positive=x[1], - negative=x[2], - margin=margin, - p=p, - swap=swap, + label=x[1], reduction=reduction, ) @@ -494,59 +468,93 @@ def test_paddle_nll_loss( ) +# smooth_l1_loss 
@handle_frontend_test( - fn_tree="paddle.nn.functional.dice_loss", + fn_tree="paddle.nn.functional.smooth_l1_loss", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - shared_dtype=False, - min_num_dims=3, - min_dim_size=3, - max_num_dims=3, - max_dim_size=3, - ), - labels=st.lists( - ( - st.lists( - ( - st.lists( - st.integers(min_value=0, max_value=1), min_size=3, max_size=3 - ) - ), - min_size=3, - max_size=3, - ) - ), - min_size=1, - max_size=1, + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, ), - epsilon=st.floats( - min_value=1e-6, - max_value=1e-2, + delta=st.floats( + min_value=0.1, + max_value=1.0, ), + reduction=st.sampled_from(["mean", "sum", "none"]), ) -def test_paddle_dice_loss( +def test_paddle_smooth_l1_loss( dtype_and_x, - labels, - epsilon, + delta, + reduction, on_device, fn_tree, frontend, test_flags, backend_fw, ): - x_dtype, x = dtype_and_x - x[0] = x[0].reshape([3, 3, 3]) - labels = ivy.array(labels, dtype=ivy.int64) - labels = labels.reshape([3, 3, 1]) + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[ivy.int64] + [ivy.float64] + x_dtype, + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + label=x[1], + reduction=reduction, + delta=delta, + ) + + +@handle_frontend_test( + fn_tree="paddle.nn.functional.triplet_margin_loss", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + allow_inf=False, + shared_dtype=True, + min_value=0.0, + max_value=1.0, + min_num_dims=1, + max_num_dims=2, + min_dim_size=1, + ), + margin=st.floats(min_value=1e-6, max_value=1e6), + p=st.integers(min_value=0, max_value=2), + swap=st.booleans(), + reduction=st.sampled_from(["none", "mean", "sum"]), + test_with_out=st.just(False), +) +def test_paddle_triplet_margin_loss( + dtype_and_inputs, + margin, + p, + swap, + reduction, + test_flags, + fn_tree, + backend_fw, + frontend, + on_device, +): + input_dtype, x = dtype_and_inputs + helpers.test_frontend_function( + input_dtypes=[input_dtype[0], input_dtype[1], input_dtype[2]], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - label=labels, - epsilon=epsilon, + positive=x[1], + negative=x[2], + margin=margin, + p=p, + swap=swap, + reduction=reduction, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py index 8bba5d93b0f77..881592b4e143a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_pooling.py @@ -9,111 +9,6 @@ ) -# avg_pool2d -@handle_frontend_test( - fn_tree="paddle.nn.functional.pooling.avg_pool2d", - dtype_x_k_s=helpers.arrays_for_pooling( - min_dims=4, - max_dims=4, - min_side=2, - max_side=4, - ), - ceil_mode=st.booleans(), - exclusive=st.booleans(), - data_format=st.sampled_from(["NCHW", "NHWC"]), -) -def test_paddle_avg_pool2d( - dtype_x_k_s, - exclusive, - ceil_mode, - data_format, - *, - test_flags, - backend_fw, - frontend, - fn_tree, - on_device, -): - input_dtype, x, kernel, stride, padding = dtype_x_k_s - - if data_format == "NCHW": - x[0] = 
x[0].reshape( - (x[0].shape[0], x[0].shape[3], x[0].shape[1], x[0].shape[2]) - ) - if len(stride) == 1: - stride = (stride[0], stride[0]) - if padding == "SAME": - padding = test_pooling_functions.calculate_same_padding( - kernel, stride, x[0].shape[2:] - ) - else: - padding = (0, 0) - helpers.test_frontend_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - kernel_size=kernel, - stride=stride, - padding=padding, - ceil_mode=ceil_mode, - exclusive=exclusive, - divisor_override=None, - data_format=data_format, - ) - - -# avg_pool1d -@handle_frontend_test( - fn_tree="paddle.nn.functional.avg_pool1d", - x_k_s_p_df=helpers.arrays_for_pooling( - min_dims=3, - max_dims=3, - min_side=2, - max_side=4, - data_format="channel_first", - ), - exclusive=st.booleans(), - ceil_mode=st.just(False), - test_with_out=st.just(False), -) -def test_paddle_avg_pool1d( - *, - x_k_s_p_df, - frontend, - test_flags, - backend_fw, - on_device, - fn_tree, - exclusive, - ceil_mode, -): - input_dtype, x, kernel_size, stride, padding = x_k_s_p_df - if padding == "SAME": - padding = test_pooling_functions.calculate_same_padding( - kernel_size, stride, [x[0].shape[2]] - ) - else: - padding = (0,) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - on_device=on_device, - fn_tree=fn_tree, - x=x[0], - kernel_size=kernel_size, - stride=stride, - padding=padding, - exclusive=exclusive, - ceil_mode=ceil_mode, - ) - - # adaptive_avg_pool1d @handle_frontend_test( fn_tree="paddle.nn.functional.adaptive_avg_pool1d", @@ -192,16 +87,14 @@ def test_paddle_adaptive_avg_pool2d( ) -# adaptive_max_pool2d +# adaptive_avg_pool3d @handle_frontend_test( - fn_tree="paddle.nn.functional.adaptive_max_pool2d", + fn_tree="paddle.nn.functional.adaptive_avg_pool3d", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=4, - max_num_dims=4, + min_num_dims=5, + max_num_dims=5, min_dim_size=1, - # Setting max and min value because this operation in paddle is not - # numerically stable max_value=100, min_value=-100, ), @@ -213,20 +106,18 @@ def test_paddle_adaptive_avg_pool2d( helpers.ints(min_value=1, max_value=5), ), ) -def test_paddle_adaptive_max_pool2d( +def test_paddle_adaptive_avg_pool3d( *, dtype_and_x, output_size, test_flags, frontend, on_device, - backend_fw, fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, on_device=on_device, @@ -236,14 +127,16 @@ def test_paddle_adaptive_max_pool2d( ) -# adaptive_avg_pool3d +# adaptive_max_pool2d @handle_frontend_test( - fn_tree="paddle.nn.functional.adaptive_avg_pool3d", + fn_tree="paddle.nn.functional.adaptive_max_pool2d", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=5, - max_num_dims=5, + min_num_dims=4, + max_num_dims=4, min_dim_size=1, + # Setting max and min value because this operation in paddle is not + # numerically stable max_value=100, min_value=-100, ), @@ -255,18 +148,20 @@ def test_paddle_adaptive_max_pool2d( helpers.ints(min_value=1, max_value=5), ), ) -def test_paddle_adaptive_avg_pool3d( +def test_paddle_adaptive_max_pool2d( *, dtype_and_x, output_size, test_flags, frontend, on_device, + backend_fw, fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( 
input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, on_device=on_device, @@ -276,6 +171,111 @@ def test_paddle_adaptive_avg_pool3d( ) +# avg_pool1d +@handle_frontend_test( + fn_tree="paddle.nn.functional.avg_pool1d", + x_k_s_p_df=helpers.arrays_for_pooling( + min_dims=3, + max_dims=3, + min_side=2, + max_side=4, + data_format="channel_first", + ), + exclusive=st.booleans(), + ceil_mode=st.just(False), + test_with_out=st.just(False), +) +def test_paddle_avg_pool1d( + *, + x_k_s_p_df, + frontend, + test_flags, + backend_fw, + on_device, + fn_tree, + exclusive, + ceil_mode, +): + input_dtype, x, kernel_size, stride, padding = x_k_s_p_df + if padding == "SAME": + padding = test_pooling_functions.calculate_same_padding( + kernel_size, stride, [x[0].shape[2]] + ) + else: + padding = (0,) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + on_device=on_device, + fn_tree=fn_tree, + x=x[0], + kernel_size=kernel_size, + stride=stride, + padding=padding, + exclusive=exclusive, + ceil_mode=ceil_mode, + ) + + +# avg_pool2d +@handle_frontend_test( + fn_tree="paddle.nn.functional.pooling.avg_pool2d", + dtype_x_k_s=helpers.arrays_for_pooling( + min_dims=4, + max_dims=4, + min_side=2, + max_side=4, + ), + ceil_mode=st.booleans(), + exclusive=st.booleans(), + data_format=st.sampled_from(["NCHW", "NHWC"]), +) +def test_paddle_avg_pool2d( + dtype_x_k_s, + exclusive, + ceil_mode, + data_format, + *, + test_flags, + backend_fw, + frontend, + fn_tree, + on_device, +): + input_dtype, x, kernel, stride, padding = dtype_x_k_s + + if data_format == "NCHW": + x[0] = x[0].reshape( + (x[0].shape[0], x[0].shape[3], x[0].shape[1], x[0].shape[2]) + ) + if len(stride) == 1: + stride = (stride[0], stride[0]) + if padding == "SAME": + padding = test_pooling_functions.calculate_same_padding( + kernel, stride, x[0].shape[2:] + ) + else: + padding = (0, 0) + helpers.test_frontend_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + kernel_size=kernel, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + exclusive=exclusive, + divisor_override=None, + data_format=data_format, + ) + + # max_unpool1d @handle_frontend_test( fn_tree="paddle.nn.functional.max_unpool1d", diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py index 39ea069851191..0659e33a95bf2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py @@ -7,46 +7,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# pixel_shuffle -@handle_frontend_test( - fn_tree="paddle.nn.functional.pixel_shuffle", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - min_value=0, - min_num_dims=4, - max_num_dims=4, - min_dim_size=3, - ), - factor=helpers.ints(min_value=1), - data_format=st.sampled_from(["NCHW", "NHWC"]), -) -def test_paddle_pixel_shuffle( - *, - dtype_and_x, - factor, - data_format, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - if data_format == "NCHW": - assume(ivy.shape(x[0])[1] % (factor**2) == 0) - else: - assume(ivy.shape(x[0])[3] % (factor**2) == 0) - 
helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - upscale_factor=factor, - data_format=data_format, - backend_to_test=backend_fw, - ) +# --- Helpers --- # +# --------------- # @st.composite @@ -92,6 +54,25 @@ def _affine_grid_helper(draw): return theta_dtype, theta[0], size, align_corners +@st.composite +def _image_shape_helper(draw, data_format): + n = draw(helpers.ints(min_value=1, max_value=10), label="batch") + c = draw(st.sampled_from([1, 3]), label="channel") + h = draw(helpers.ints(min_value=1, max_value=100), label="height") + w = draw(helpers.ints(min_value=1, max_value=100), label="width") + + if data_format == "NCHW": + shape = (n, c, h, w) + else: + shape = (n, h, w, c) + + return shape + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="paddle.nn.functional.affine_grid", dtype_and_input_and_other=_affine_grid_helper(), @@ -114,21 +95,6 @@ def test_paddle_affine_grid( ) -@st.composite -def _image_shape_helper(draw, data_format): - n = draw(helpers.ints(min_value=1, max_value=10), label="batch") - c = draw(st.sampled_from([1, 3]), label="channel") - h = draw(helpers.ints(min_value=1, max_value=100), label="height") - w = draw(helpers.ints(min_value=1, max_value=100), label="width") - - if data_format == "NCHW": - shape = (n, c, h, w) - else: - shape = (n, h, w, c) - - return shape - - # channel_shuffle @handle_frontend_test( fn_tree="paddle.nn.functional.channel_shuffle", @@ -166,3 +132,45 @@ def test_paddle_channel_shuffle( groups=groups, data_format=data_format, ) + + +# pixel_shuffle +@handle_frontend_test( + fn_tree="paddle.nn.functional.pixel_shuffle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float32", "float64"], + min_value=0, + min_num_dims=4, + max_num_dims=4, + min_dim_size=3, + ), + factor=helpers.ints(min_value=1), + data_format=st.sampled_from(["NCHW", "NHWC"]), +) +def test_paddle_pixel_shuffle( + *, + dtype_and_x, + factor, + data_format, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + if data_format == "NCHW": + assume(ivy.shape(x[0])[1] % (factor**2) == 0) + else: + assume(ivy.shape(x[0])[3] % (factor**2) == 0) + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + upscale_factor=factor, + data_format=data_format, + backend_to_test=backend_fw, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py index 0a9c8754f4fed..8b137891791fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py @@ -1,3 +1 @@ -# global -# local diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_attribute.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_attribute.py index 79275a1166171..f0a471073ab0e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_attribute.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_attribute.py @@ -6,19 +6,19 @@ @handle_frontend_test( - fn_tree="paddle.tensor.attribute.is_complex", + fn_tree="paddle.tensor.attribute.imag", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_is_complex( +def test_paddle_imag( *, dtype_and_x, on_device, fn_tree, 
+ backend_fw, frontend, test_flags, - backend_fw, ): input_dtype, input = dtype_and_x helpers.test_frontend_function( @@ -33,19 +33,19 @@ def test_paddle_is_complex( @handle_frontend_test( - fn_tree="paddle.tensor.attribute.is_integer", + fn_tree="paddle.tensor.attribute.is_complex", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_is_integer( +def test_paddle_is_complex( *, dtype_and_x, on_device, fn_tree, - backend_fw, frontend, test_flags, + backend_fw, ): input_dtype, input = dtype_and_x helpers.test_frontend_function( @@ -87,66 +87,66 @@ def test_paddle_is_floating_point( @handle_frontend_test( - fn_tree="paddle.tensor.attribute.rank", + fn_tree="paddle.tensor.attribute.is_integer", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_rank( +def test_paddle_is_integer( *, dtype_and_x, on_device, fn_tree, - frontend, backend_fw, + frontend, test_flags, ): input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + x=input[0], ) @handle_frontend_test( - fn_tree="paddle.tensor.attribute.real", + fn_tree="paddle.tensor.attribute.rank", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_real( +def test_paddle_rank( *, dtype_and_x, on_device, fn_tree, - backend_fw, frontend, + backend_fw, test_flags, ): input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=input[0], + input=input[0], ) @handle_frontend_test( - fn_tree="paddle.tensor.attribute.imag", + fn_tree="paddle.tensor.attribute.real", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_imag( +def test_paddle_real( *, dtype_and_x, on_device, diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py index 208f3061b95c5..805aea49ee07f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_creation.py @@ -1,788 +1,792 @@ -# global -from hypothesis import strategies as st - -# local -import ivy_tests.test_ivy.helpers as helpers -import ivy_tests.test_ivy.helpers.globals as test_globals -from ivy_tests.test_ivy.helpers import handle_frontend_test, BackendHandler - - -# Helpers # -# ------- # - - -@st.composite -def _input_fill_and_dtype(draw): - dtype = draw(helpers.get_dtypes("float", full=False)) - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - dtype_and_input = draw(helpers.dtype_and_values(dtype=dtype)) - if ivy_backend.is_uint_dtype(dtype[0]): - fill_values = draw(st.integers(min_value=0, max_value=5)) - elif ivy_backend.is_int_dtype(dtype[0]): - fill_values = draw(st.integers(min_value=-5, max_value=5)) - else: - fill_values = draw(st.floats(min_value=-5, max_value=5)) - dtype_to_cast = draw(helpers.get_dtypes("float", full=False)) - return dtype, dtype_and_input[1], fill_values, dtype_to_cast[0] - - -# Tests # -# ----- # - - -# to_tensor -@handle_frontend_test( - fn_tree="paddle.to_tensor", - 
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - dtype=helpers.get_dtypes("valid"), -) -def test_paddle_to_tensor( - *, - dtype_and_x, - dtype, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - data=input[0], - dtype=dtype[0], - place=on_device, - ) - - -# ones -@handle_frontend_test( - fn_tree="paddle.ones", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - dtype=helpers.get_dtypes("valid"), - test_with_out=st.just(False), -) -def test_paddle_ones( - shape, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - shape=shape, - dtype=dtype[0], - ) - - -# ones_like -@handle_frontend_test( - fn_tree="paddle.ones_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), - dtype=helpers.get_dtypes("valid"), - test_with_out=st.just(False), -) -def test_paddle_ones_like( - dtype_and_x, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - dtype=dtype[0], - ) - - -# zeros -@handle_frontend_test( - fn_tree="paddle.zeros", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - dtype=helpers.get_dtypes("valid"), - test_with_out=st.just(False), -) -def test_paddle_zeros( - shape, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - shape=shape, - dtype=dtype[0], - ) - - -# zeros_like -@handle_frontend_test( - fn_tree="paddle.zeros_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), - dtype=helpers.get_dtypes("valid"), - test_with_out=st.just(False), -) -def test_paddle_zeros_like( - dtype_and_x, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - dtype=dtype[0], - ) - - -# full -@handle_frontend_test( - fn_tree="paddle.full", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - input_fill_dtype=_input_fill_and_dtype(), - test_with_out=st.just(False), -) -def test_paddle_full( - shape, - input_fill_dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x, fill, dtype_to_cast = input_fill_dtype - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - shape=shape, - fill_value=fill, - 
dtype=dtype_to_cast, - ) - - -# full_like -@handle_frontend_test( - fn_tree="paddle.full_like", - input_fill_dtype=_input_fill_and_dtype(), - test_with_out=st.just(False), -) -def test_paddle_full_like( - input_fill_dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x, fill, dtype_to_cast = input_fill_dtype - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - fill_value=fill, - dtype=dtype_to_cast, - ) - - -# arange -@handle_frontend_test( - fn_tree="paddle.arange", - start=helpers.ints(min_value=-50, max_value=0), - end=helpers.ints(min_value=1, max_value=50), - step=helpers.ints(min_value=1, max_value=5), - dtype=helpers.get_dtypes("float"), - test_with_out=st.just(False), -) -def test_paddle_arange( - start, - end, - step, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - start=start, - end=end, - step=step, - dtype=dtype[0], - ) - - -# empty -@handle_frontend_test( - fn_tree="paddle.empty", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - dtype=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), -) -def test_paddle_empty( - shape, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - shape=shape, - dtype=dtype[0], - ) - - -# eye -@handle_frontend_test( - fn_tree="paddle.eye", - num_rows=helpers.ints(min_value=3, max_value=10), - num_columns=st.none() | helpers.ints(min_value=3, max_value=10), - dtypes=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), -) -def test_paddle_eye( - *, - num_rows, - num_columns, - dtypes, - on_device, - fn_tree, - test_flags, - frontend, - backend_fw, -): - helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - num_rows=num_rows, - num_columns=num_columns, - dtype=dtypes[0], - ) - - -# empty_like -@handle_frontend_test( - fn_tree="paddle.empty_like", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - dtype=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), -) -def test_paddle_empty_like( - dtype_and_x, - dtype, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - x=x[0], - dtype=dtype[0], - ) - - -# tril -@handle_frontend_test( - fn_tree="paddle.tril", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - ), - diagonal=st.integers(min_value=-100, max_value=100), -) -def test_paddle_tril( - *, - dtype_and_values, - diagonal, - backend_fw, - on_device, - fn_tree, - frontend, - test_flags, -): - dtype, values = dtype_and_values - helpers.test_frontend_function( - 
input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=values[0], - diagonal=diagonal, - ) - - -# triu -@handle_frontend_test( - fn_tree="paddle.triu", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - ), - diagonal=st.integers(min_value=-100, max_value=100), -) -def test_paddle_triu( - *, - dtype_and_values, - diagonal, - on_device, - backend_fw, - fn_tree, - frontend, - test_flags, -): - dtype, values = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=values[0], - diagonal=diagonal, - ) - - -# diagflat -@handle_frontend_test( - fn_tree="paddle.diagflat", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - offset=st.integers(min_value=-4, max_value=4), - test_with_out=st.just(False), -) -def test_paddle_diagflat( - dtype_and_values, - offset, - test_flags, - backend_fw, - frontend, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - x=x[0], - offset=offset, - ) - - -@handle_frontend_test( - fn_tree="paddle.meshgrid", - dtype_and_arrays=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=st.integers(min_value=2, max_value=5), - min_num_dims=1, - max_num_dims=1, - shared_dtype=True, - ), - test_with_out=st.just(False), -) -def test_paddle_meshgrid( - dtype_and_arrays, - test_flags, - backend_fw, - frontend, - fn_tree, - on_device, -): - input_dtype, arrays = dtype_and_arrays - args = {} - i = 0 - for x_ in arrays: - args["x{}".format(i)] = x_ - i += 1 - test_flags.num_positional_args = len(arrays) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **args, - ) - - -# tril_indices -@handle_frontend_test( - fn_tree="paddle.tril_indices", - dtype=helpers.get_dtypes("valid", full=False), - row=st.integers(min_value=1, max_value=5), - col=st.integers(min_value=1, max_value=5), - offset=st.integers(min_value=-4, max_value=4), - test_with_out=st.just(False), -) -def test_paddle_tril_indices( - row, - col, - offset, - dtype, - test_flags, - backend_fw, - frontend, - fn_tree, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - row=row, - col=col, - offset=offset, - dtype=dtype[0], - ) - - -# diag -@handle_frontend_test( - fn_tree="paddle.diag", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=2, - min_dim_size=1, - max_dim_size=5, - ), - k=helpers.ints(min_value=-1, max_value=1), - p=st.one_of( - helpers.ints(min_value=-25, max_value=25), - helpers.floats(min_value=-25, max_value=25), - ), -) -def test_paddle_diag( - dtype_and_x, - k, - p, - backend_fw, - frontend, - test_flags, - fn_tree, - on_device, -): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - 
backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - offset=k, - padding_value=p, - ) - - -# logspace -@handle_frontend_test( - fn_tree="paddle.logspace", - start=helpers.floats(min_value=-10, max_value=10), - stop=helpers.floats(min_value=-10, max_value=10), - num=helpers.ints(min_value=1, max_value=5), - base=st.floats(min_value=0.1, max_value=10.0), - dtype=helpers.get_dtypes("float"), - test_with_out=st.just(False), -) -def test_paddle_logspace( - start, - stop, - num, - base, - dtype, - frontend, - test_flags, - fn_tree, - on_device, - backend_fw, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - start=start, - stop=stop, - num=num, - base=base, - dtype=dtype[0], - ) - - -# assign -@handle_frontend_test( - fn_tree="paddle.assign", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - ), - test_with_out=st.just(True), -) -def test_paddle_assign( - dtype_and_x, - test_flags, - backend_fw, - frontend, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - output=x[1], - ) - - -# triu_indices -@handle_frontend_test( - fn_tree="paddle.triu_indices", - dtype=helpers.get_dtypes("valid", full=False), - row=st.integers(min_value=1, max_value=5), - col=st.integers(min_value=1, max_value=5), - offset=st.integers(min_value=-4, max_value=4), - test_with_out=st.just(False), -) -def test_paddle_triu_indices( - row, - col, - offset, - dtype, - test_flags, - backend_fw, - frontend, - fn_tree, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - row=row, - col=col, - offset=offset, - dtype=dtype[0], - ) - - -# complex -@handle_frontend_test( - fn_tree="paddle.complex", - dtype_and_arrays=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], shared_dtype=True, num_arrays=2 - ), -) -def test_paddle_complex( - dtype_and_arrays, - test_flags, - backend_fw, - frontend, - fn_tree, - on_device, -): - input_dtype, (real, imag) = dtype_and_arrays - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - real=real, - imag=imag, - ) - - -# linspace -@handle_frontend_test( - fn_tree="paddle.linspace", - start=helpers.floats(min_value=-10, max_value=10), - stop=helpers.floats(min_value=-10, max_value=10), - num=helpers.ints(min_value=1, max_value=5), - dtype=helpers.get_dtypes("float"), - test_with_out=st.just(False), -) -def test_paddle_linspace( - start, - stop, - num, - dtype, - frontend, - test_flags, - fn_tree, - on_device, - backend_fw, -): - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - start=start, - stop=stop, - num=num, - dtype=dtype[0], - ) - - -# clone -@handle_frontend_test( - fn_tree="paddle.clone", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), -) -def test_paddle_clone( - *, - 
dtype_and_x, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) +# global +from hypothesis import strategies as st + +# local +import ivy_tests.test_ivy.helpers as helpers +import ivy_tests.test_ivy.helpers.globals as test_globals +from ivy_tests.test_ivy.helpers import handle_frontend_test, BackendHandler + + +# --- Helpers --- # +# --------------- # + + +@st.composite +def _input_fill_and_dtype(draw): + dtype = draw(helpers.get_dtypes("float", full=False)) + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + dtype_and_input = draw(helpers.dtype_and_values(dtype=dtype)) + if ivy_backend.is_uint_dtype(dtype[0]): + fill_values = draw(st.integers(min_value=0, max_value=5)) + elif ivy_backend.is_int_dtype(dtype[0]): + fill_values = draw(st.integers(min_value=-5, max_value=5)) + else: + fill_values = draw(st.floats(min_value=-5, max_value=5)) + dtype_to_cast = draw(helpers.get_dtypes("float", full=False)) + return dtype, dtype_and_input[1], fill_values, dtype_to_cast[0] + + +# --- Main --- # +# ------------ # + + +# arange +@handle_frontend_test( + fn_tree="paddle.arange", + start=helpers.ints(min_value=-50, max_value=0), + end=helpers.ints(min_value=1, max_value=50), + step=helpers.ints(min_value=1, max_value=5), + dtype=helpers.get_dtypes("float"), + test_with_out=st.just(False), +) +def test_paddle_arange( + start, + end, + step, + dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + start=start, + end=end, + step=step, + dtype=dtype[0], + ) + + +# assign +@handle_frontend_test( + fn_tree="paddle.assign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + ), + test_with_out=st.just(True), +) +def test_paddle_assign( + dtype_and_x, + test_flags, + backend_fw, + frontend, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + output=x[1], + ) + + +# clone +@handle_frontend_test( + fn_tree="paddle.clone", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), +) +def test_paddle_clone( + *, + dtype_and_x, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# complex +@handle_frontend_test( + fn_tree="paddle.complex", + dtype_and_arrays=helpers.dtype_and_values( + available_dtypes=["float32", "float64"], shared_dtype=True, num_arrays=2 + ), +) +def test_paddle_complex( + dtype_and_arrays, + test_flags, + backend_fw, + frontend, + fn_tree, + on_device, +): + input_dtype, (real, imag) = dtype_and_arrays + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + real=real, + 
imag=imag, + ) + + +# diag +@handle_frontend_test( + fn_tree="paddle.diag", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=2, + min_dim_size=1, + max_dim_size=5, + ), + k=helpers.ints(min_value=-1, max_value=1), + p=st.one_of( + helpers.ints(min_value=-25, max_value=25), + helpers.floats(min_value=-25, max_value=25), + ), +) +def test_paddle_diag( + dtype_and_x, + k, + p, + backend_fw, + frontend, + test_flags, + fn_tree, + on_device, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + offset=k, + padding_value=p, + ) + + +# diagflat +@handle_frontend_test( + fn_tree="paddle.diagflat", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + offset=st.integers(min_value=-4, max_value=4), + test_with_out=st.just(False), +) +def test_paddle_diagflat( + dtype_and_values, + offset, + test_flags, + backend_fw, + frontend, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + x=x[0], + offset=offset, + ) + + +# empty +@handle_frontend_test( + fn_tree="paddle.empty", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), +) +def test_paddle_empty( + shape, + dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + shape=shape, + dtype=dtype[0], + ) + + +# empty_like +@handle_frontend_test( + fn_tree="paddle.empty_like", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), +) +def test_paddle_empty_like( + dtype_and_x, + dtype, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + x=x[0], + dtype=dtype[0], + ) + + +# eye +@handle_frontend_test( + fn_tree="paddle.eye", + num_rows=helpers.ints(min_value=3, max_value=10), + num_columns=st.none() | helpers.ints(min_value=3, max_value=10), + dtypes=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), +) +def test_paddle_eye( + *, + num_rows, + num_columns, + dtypes, + on_device, + fn_tree, + test_flags, + frontend, + backend_fw, +): + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + num_rows=num_rows, + num_columns=num_columns, + dtype=dtypes[0], + ) + + +# full +@handle_frontend_test( + fn_tree="paddle.full", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + 
input_fill_dtype=_input_fill_and_dtype(), + test_with_out=st.just(False), +) +def test_paddle_full( + shape, + input_fill_dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x, fill, dtype_to_cast = input_fill_dtype + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + shape=shape, + fill_value=fill, + dtype=dtype_to_cast, + ) + + +# full_like +@handle_frontend_test( + fn_tree="paddle.full_like", + input_fill_dtype=_input_fill_and_dtype(), + test_with_out=st.just(False), +) +def test_paddle_full_like( + input_fill_dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x, fill, dtype_to_cast = input_fill_dtype + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + fill_value=fill, + dtype=dtype_to_cast, + ) + + +# linspace +@handle_frontend_test( + fn_tree="paddle.linspace", + start=helpers.floats(min_value=-10, max_value=10), + stop=helpers.floats(min_value=-10, max_value=10), + num=helpers.ints(min_value=1, max_value=5), + dtype=helpers.get_dtypes("float"), + test_with_out=st.just(False), +) +def test_paddle_linspace( + start, + stop, + num, + dtype, + frontend, + test_flags, + fn_tree, + on_device, + backend_fw, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + start=start, + stop=stop, + num=num, + dtype=dtype[0], + ) + + +# logspace +@handle_frontend_test( + fn_tree="paddle.logspace", + start=helpers.floats(min_value=-10, max_value=10), + stop=helpers.floats(min_value=-10, max_value=10), + num=helpers.ints(min_value=1, max_value=5), + base=st.floats(min_value=0.1, max_value=10.0), + dtype=helpers.get_dtypes("float"), + test_with_out=st.just(False), +) +def test_paddle_logspace( + start, + stop, + num, + base, + dtype, + frontend, + test_flags, + fn_tree, + on_device, + backend_fw, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + start=start, + stop=stop, + num=num, + base=base, + dtype=dtype[0], + ) + + +@handle_frontend_test( + fn_tree="paddle.meshgrid", + dtype_and_arrays=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=st.integers(min_value=2, max_value=5), + min_num_dims=1, + max_num_dims=1, + shared_dtype=True, + ), + test_with_out=st.just(False), +) +def test_paddle_meshgrid( + dtype_and_arrays, + test_flags, + backend_fw, + frontend, + fn_tree, + on_device, +): + input_dtype, arrays = dtype_and_arrays + args = {} + i = 0 + for x_ in arrays: + args["x{}".format(i)] = x_ + i += 1 + test_flags.num_positional_args = len(arrays) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **args, + ) + + +# ones +@handle_frontend_test( + fn_tree="paddle.ones", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype=helpers.get_dtypes("valid"), + test_with_out=st.just(False), +) +def test_paddle_ones( + shape, + dtype, + frontend, + test_flags, + fn_tree, + 
backend_fw, + on_device, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + shape=shape, + dtype=dtype[0], + ) + + +# ones_like +@handle_frontend_test( + fn_tree="paddle.ones_like", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + dtype=helpers.get_dtypes("valid"), + test_with_out=st.just(False), +) +def test_paddle_ones_like( + dtype_and_x, + dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + dtype=dtype[0], + ) + + +# Tests # +# ----- # + + +# to_tensor +@handle_frontend_test( + fn_tree="paddle.to_tensor", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid"), +) +def test_paddle_to_tensor( + *, + dtype_and_x, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + data=input[0], + dtype=dtype[0], + place=on_device, + ) + + +# tril +@handle_frontend_test( + fn_tree="paddle.tril", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + ), + diagonal=st.integers(min_value=-100, max_value=100), +) +def test_paddle_tril( + *, + dtype_and_values, + diagonal, + backend_fw, + on_device, + fn_tree, + frontend, + test_flags, +): + dtype, values = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=values[0], + diagonal=diagonal, + ) + + +# tril_indices +@handle_frontend_test( + fn_tree="paddle.tril_indices", + dtype=helpers.get_dtypes("valid", full=False), + row=st.integers(min_value=1, max_value=5), + col=st.integers(min_value=1, max_value=5), + offset=st.integers(min_value=-4, max_value=4), + test_with_out=st.just(False), +) +def test_paddle_tril_indices( + row, + col, + offset, + dtype, + test_flags, + backend_fw, + frontend, + fn_tree, + on_device, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + row=row, + col=col, + offset=offset, + dtype=dtype[0], + ) + + +# triu +@handle_frontend_test( + fn_tree="paddle.triu", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + ), + diagonal=st.integers(min_value=-100, max_value=100), +) +def test_paddle_triu( + *, + dtype_and_values, + diagonal, + on_device, + backend_fw, + fn_tree, + frontend, + test_flags, +): + dtype, values = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=values[0], + diagonal=diagonal, + ) + + +# triu_indices +@handle_frontend_test( + fn_tree="paddle.triu_indices", + dtype=helpers.get_dtypes("valid", full=False), + row=st.integers(min_value=1, max_value=5), + 
col=st.integers(min_value=1, max_value=5), + offset=st.integers(min_value=-4, max_value=4), + test_with_out=st.just(False), +) +def test_paddle_triu_indices( + row, + col, + offset, + dtype, + test_flags, + backend_fw, + frontend, + fn_tree, + on_device, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + row=row, + col=col, + offset=offset, + dtype=dtype[0], + ) + + +# zeros +@handle_frontend_test( + fn_tree="paddle.zeros", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype=helpers.get_dtypes("valid"), + test_with_out=st.just(False), +) +def test_paddle_zeros( + shape, + dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + shape=shape, + dtype=dtype[0], + ) + + +# zeros_like +@handle_frontend_test( + fn_tree="paddle.zeros_like", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + dtype=helpers.get_dtypes("valid"), + test_with_out=st.just(False), +) +def test_paddle_zeros_like( + dtype_and_x, + dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + dtype=dtype[0], + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_einsum.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_einsum.py index 0a9c8754f4fed..8b137891791fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_einsum.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_einsum.py @@ -1,3 +1 @@ -# global -# local diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_linalg.py index bd55ed23c4ccd..6f3118326c4d9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_linalg.py @@ -12,6 +12,157 @@ _get_cholesky_matrix, ) + +# --- Helpers --- # +# --------------- # + + +@st.composite +def _dtype_values_axis(draw): + dtype_and_values = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + min_value=0.1, + max_value=1000.0, + ) + ) + + dtype, x = dtype_and_values + x = x[0] + r = len(x.shape) + + valid_axes = [None] + + for i in range(-r, r): + valid_axes.append(i) + for j in range(-r, r): + if i != j and abs(i - j) != r: + valid_axes.append([i, j]) + + axis = draw(st.sampled_from(valid_axes)) + + p_list = ["fro", 1, 2, ivy.inf, -ivy.inf] + if isinstance(axis, list) and len(axis) == 2: + p = draw( + st.one_of( + st.sampled_from(p_list), + st.floats(min_value=1.0, max_value=10.0, allow_infinity=False), + ) + ) + else: + p = draw( + st.one_of( + st.sampled_from(p_list + [0]), + st.floats(min_value=1.0, max_value=10.0, allow_infinity=False), + ) + ) + + return dtype, x, axis, p + + +# cond +@st.composite +def _get_dtype_and_matrix_non_singular(draw, dtypes): + 
while True: + matrix = draw( + helpers.dtype_and_values( + available_dtypes=dtypes, + min_value=-10, + max_value=10, + min_num_dims=2, + max_num_dims=2, + min_dim_size=1, + max_dim_size=5, + shape=st.tuples(st.integers(1, 5), st.integers(1, 5)).filter( + lambda x: x[0] == x[1] + ), + allow_inf=False, + allow_nan=False, + ) + ) + if np.linalg.det(matrix[1][0]) != 0: + break + + return matrix[0], matrix[1] + + +@st.composite +def _get_dtype_and_square_matrix(draw, real_and_complex_only=False): + if real_and_complex_only: + dtype = [ + draw(st.sampled_from(["float32", "float64", "complex64", "complex128"])) + ] + else: + dtype = draw(helpers.get_dtypes("valid")) + dim_size = draw(helpers.ints(min_value=2, max_value=5)) + mat = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=0, max_value=10 + ) + ) + return dtype, mat + + +@st.composite +def _get_dtype_input_and_vectors(draw): + dim_size = draw(helpers.ints(min_value=1, max_value=2)) + dtype = draw(helpers.get_dtypes("float")) + if dim_size == 1: + vec1 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 + ) + ) + vec2 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 + ) + ) + else: + vec1 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 + ) + ) + vec2 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 + ) + ) + return dtype, vec1, vec2 + + +# cholesky_solve +@st.composite +def _get_paddle_cholesky_matrix(draw): + input_dtype, spd_chol = draw(_get_cholesky_matrix()) + probability = draw(st.floats(min_value=0, max_value=1)) + if probability > 0.5: + spd_chol = spd_chol.T # randomly transpose the matrix + return input_dtype, spd_chol + + +# transpose +@st.composite +def _transpose_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=4, + min_dim_size=2, + max_dim_size=3, + ret_shape=True, + ) + ) + perm = draw(st.permutations([i for i in range(len(shape))])) + return dtype, x, perm + + # Helpers # # ------ # @@ -77,97 +228,181 @@ def dtype_value1_value2_axis( return [dtype], value1, value2, axis -@st.composite -def _get_dtype_input_and_vectors(draw): - dim_size = draw(helpers.ints(min_value=1, max_value=2)) - dtype = draw(helpers.get_dtypes("float")) - if dim_size == 1: - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 - ) - ) - else: - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=2, max_value=5 - ) - ) - return dtype, vec1, vec2 +# --- Main --- # +# ------------ # -@st.composite -def _get_dtype_and_square_matrix(draw, real_and_complex_only=False): - if real_and_complex_only: - dtype = [ - draw(st.sampled_from(["float32", "float64", "complex64", "complex128"])) - ] - else: - dtype = draw(helpers.get_dtypes("valid")) - dim_size = draw(helpers.ints(min_value=2, max_value=5)) - mat = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=0, max_value=10 - ) +# bincount +@handle_frontend_test( + fn_tree="paddle.bincount", + dtype_and_x=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("integer"), + min_value=1, + max_value=2, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=1, + ), + key="a_s_d", + ), + ), + test_with_out=st.just(False), +) +def test_paddle_bincount( + *, + dtype_and_x, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + weights=None, + minlength=0, ) - return dtype, mat -@st.composite -def _dtype_values_axis(draw): - dtype_and_values = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=5, - min_dim_size=2, - max_dim_size=5, - min_value=0.1, - max_value=1000.0, - ) +# bmm +@handle_frontend_test( + fn_tree="paddle.bmm", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=(3, 3, 3), + num_arrays=2, + shared_dtype=True, + min_value=-10, + max_value=10, + ), + aliases=["paddle.tensor.linalg.bmm"], + test_with_out=st.just(False), +) +def test_paddle_bmm( + *, + dtype_x, + frontend, + test_flags, + backend_fw, + fn_tree, + on_device, +): + input_dtype, x = dtype_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], ) - dtype, x = dtype_and_values + +# cholesky +@handle_frontend_test( + fn_tree="paddle.tensor.linalg.cholesky", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + ), + upper=st.booleans(), +) +def test_paddle_cholesky( + dtype_and_x, + upper, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x = dtype_and_x x = x[0] - r = len(x.shape) + x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite - valid_axes = [None] + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x, + upper=upper, + ) - for i in range(-r, r): - valid_axes.append(i) - for j in range(-r, r): - if i != j and abs(i - j) != r: - valid_axes.append([i, j]) - axis = draw(st.sampled_from(valid_axes)) +@handle_frontend_test( + fn_tree="paddle.tensor.linalg.cholesky_solve", + x=_get_second_matrix(), + y=_get_paddle_cholesky_matrix(), + test_with_out=st.just(False), +) +def test_paddle_cholesky_solve( + *, + x, + y, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype1, x1 = x + input_dtype2, x2 = y + helpers.test_frontend_function( + input_dtypes=[input_dtype1, input_dtype2], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-3, + atol=1e-3, + x=x1, + y=x2, + upper=np.array_equal(x2, np.triu(x2)), # check whether the matrix is upper + ) - p_list = ["fro", 1, 2, ivy.inf, -ivy.inf] - if isinstance(axis, list) and len(axis) == 2: - p = draw( - st.one_of( - st.sampled_from(p_list), - st.floats(min_value=1.0, max_value=10.0, allow_infinity=False), - ) - ) - else: - p = draw( - st.one_of( - st.sampled_from(p_list + [0]), - st.floats(min_value=1.0, max_value=10.0, allow_infinity=False), - ) - ) - return dtype, x, axis, p 
+@handle_frontend_test( + fn_tree="paddle.tensor.linalg.cond", + dtype_and_x=_get_dtype_and_matrix_non_singular(dtypes=["float32", "float64"]), + p=st.sampled_from([None, "fro", "nuc", np.inf, -np.inf, 1, -1, 2, -2]), + test_with_out=st.just(False), +) +def test_paddle_cond( + *, dtype_and_x, p, on_device, fn_tree, frontend, test_flags, backend_fw +): + dtype, x = dtype_and_x + + assume(matrix_is_stable(x[0])) + + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + x=x[0], + rtol=1e-5, + atol=1e-5, + p=p, + ) # Tests # @@ -212,78 +447,73 @@ def test_paddle_cross( ) -# matmul @handle_frontend_test( - fn_tree="paddle.matmul", - dtype_x=helpers.dtype_and_values( + fn_tree="paddle.dist", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - shape=(3, 3), num_arrays=2, shared_dtype=True, - min_value=-10, - max_value=10, - ), - aliases=["paddle.tensor.linalg.matmul"], - transpose_x=st.booleans(), - transpose_y=st.booleans(), - test_with_out=st.just(False), + min_value=-1e04, + max_value=1e04, + allow_inf=False, + ), + p=helpers.floats(min_value=1.0, max_value=10.0), ) -def test_paddle_matmul( +def test_paddle_dist( *, - dtype_x, - transpose_x, - transpose_y, - frontend, - test_flags, + dtype_and_input, + p, + on_device, fn_tree, backend_fw, - on_device, + frontend, + test_flags, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, x=x[0], y=x[1], - transpose_x=transpose_x, - transpose_y=transpose_y, + p=p, ) -# norm +# dot @handle_frontend_test( - fn_tree="paddle.tensor.linalg.norm", - dtype_values_axis=_dtype_values_axis(), - keepdims=st.booleans(), + fn_tree="paddle.tensor.linalg.dot", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_num_dims=1, + max_num_dims=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_paddle_norm( - dtype_values_axis, - keepdims, +def test_paddle_dot( + *, + dtype_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis, p = dtype_values_axis + input_dtype, x = dtype_x helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - x=x, - p=p, - axis=axis, - keepdim=keepdims, - atol=1e-1, - rtol=1e-1, + x=x[0], + y=x[1], ) @@ -334,73 +564,6 @@ def test_paddle_eig( ) -# eigvals -@handle_frontend_test( - fn_tree="paddle.tensor.linalg.eigvals", - dtype_x=_get_dtype_and_square_matrix(real_and_complex_only=True), - test_with_out=st.just(False), -) -def test_paddle_eigvals( - *, - dtype_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x = dtype_x - x = np.array(x[0], dtype=dtype[0]) - # make symmetric positive-definite beforehand - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - - ret, frontend_ret = helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - x=x, - ) - - -# eigvalsh -@handle_frontend_test( - fn_tree="paddle.tensor.linalg.eigvalsh", - 
dtype_x=_get_dtype_and_square_matrix(real_and_complex_only=True), - UPLO=st.sampled_from(("L", "U")), - test_with_out=st.just(False), -) -def test_paddle_eigvalsh( - *, - dtype_x, - UPLO, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x = dtype_x - x = np.asarray(x[0], dtype=dtype[0]) - # make symmetric positive-definite beforehand - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - x=x, - UPLO=UPLO, - ) - - # eigh @handle_frontend_test( fn_tree="paddle.tensor.linalg.eigh", @@ -451,152 +614,59 @@ def test_paddle_eigh( ) -# pinv -@handle_frontend_test( - fn_tree="paddle.tensor.linalg.pinv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=5, - min_dim_size=2, - max_dim_size=5, - min_value=3, - max_value=10, - large_abs_safety_factor=128, - safety_factor_scale="log", - ), - rcond=st.floats(1e-5, 1e-3), - test_with_out=st.just(False), -) -def test_paddle_pinv( - dtype_and_x, - rcond, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - # TODO: paddle returns nan for all values if the input - # matrix has the same value at all indices e.g. - # [[2., 2.], [2., 2.]] would return [[nan, nan], [nan, nan]], - # causing the tests to fail for other backends. - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-3, - atol=1e-3, - x=x[0], - rcond=rcond, - ) - - -# solve +# eigvals @handle_frontend_test( - fn_tree="paddle.solve", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=-10, - max_value=10, - ), - aliases=["paddle.tensor.linalg.solve"], + fn_tree="paddle.tensor.linalg.eigvals", + dtype_x=_get_dtype_and_square_matrix(real_and_complex_only=True), test_with_out=st.just(False), ) -def test_paddle_solve( +def test_paddle_eigvals( *, dtype_x, + on_device, + fn_tree, frontend, test_flags, backend_fw, - fn_tree, - on_device, ): - input_dtype, x = dtype_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - ) - - -# cholesky_solve -@st.composite -def _get_paddle_cholesky_matrix(draw): - input_dtype, spd_chol = draw(_get_cholesky_matrix()) - probability = draw(st.floats(min_value=0, max_value=1)) - if probability > 0.5: - spd_chol = spd_chol.T # randomly transpose the matrix - return input_dtype, spd_chol - + dtype, x = dtype_x + x = np.array(x[0], dtype=dtype[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 -@handle_frontend_test( - fn_tree="paddle.tensor.linalg.cholesky_solve", - x=_get_second_matrix(), - y=_get_paddle_cholesky_matrix(), - test_with_out=st.just(False), -) -def test_paddle_cholesky_solve( - *, - x, - y, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype1, x1 = x - input_dtype2, x2 = y - helpers.test_frontend_function( - input_dtypes=[input_dtype1, input_dtype2], - frontend=frontend, + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, backend_to_test=backend_fw, + 
frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-3, - atol=1e-3, - x=x1, - y=x2, - upper=np.array_equal(x2, np.triu(x2)), # check whether the matrix is upper + test_values=False, + x=x, ) - -# cholesky -@handle_frontend_test( - fn_tree="paddle.tensor.linalg.cholesky", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - ), - upper=st.booleans(), + +# eigvalsh +@handle_frontend_test( + fn_tree="paddle.tensor.linalg.eigvalsh", + dtype_x=_get_dtype_and_square_matrix(real_and_complex_only=True), + UPLO=st.sampled_from(("L", "U")), + test_with_out=st.just(False), ) -def test_paddle_cholesky( - dtype_and_x, - upper, +def test_paddle_eigvalsh( + *, + dtype_x, + UPLO, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - dtype, x = dtype_and_x - x = x[0] - x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite + dtype, x = dtype_x + x = np.asarray(x[0], dtype=dtype[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 helpers.test_frontend_function( input_dtypes=dtype, @@ -605,44 +675,51 @@ def test_paddle_cholesky( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + test_values=False, x=x, - upper=upper, + UPLO=UPLO, ) -# bmm +# matmul @handle_frontend_test( - fn_tree="paddle.bmm", + fn_tree="paddle.matmul", dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - shape=(3, 3, 3), + shape=(3, 3), num_arrays=2, shared_dtype=True, min_value=-10, max_value=10, ), - aliases=["paddle.tensor.linalg.bmm"], + aliases=["paddle.tensor.linalg.matmul"], + transpose_x=st.booleans(), + transpose_y=st.booleans(), test_with_out=st.just(False), ) -def test_paddle_bmm( +def test_paddle_matmul( *, dtype_x, + transpose_x, + transpose_y, frontend, test_flags, - backend_fw, fn_tree, + backend_fw, on_device, ): input_dtype, x = dtype_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], y=x[1], + transpose_x=transpose_x, + transpose_y=transpose_y, ) @@ -680,87 +757,112 @@ def test_paddle_matrix_power( ) -# cond -@st.composite -def _get_dtype_and_matrix_non_singular(draw, dtypes): - while True: - matrix = draw( - helpers.dtype_and_values( - available_dtypes=dtypes, - min_value=-10, - max_value=10, - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, - max_dim_size=5, - shape=st.tuples(st.integers(1, 5), st.integers(1, 5)).filter( - lambda x: x[0] == x[1] - ), - allow_inf=False, - allow_nan=False, - ) - ) - if np.linalg.det(matrix[1][0]) != 0: - break - - return matrix[0], matrix[1] +# norm +@handle_frontend_test( + fn_tree="paddle.tensor.linalg.norm", + dtype_values_axis=_dtype_values_axis(), + keepdims=st.booleans(), + test_with_out=st.just(False), +) +def test_paddle_norm( + dtype_values_axis, + keepdims, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x, axis, p = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x, + p=p, + axis=axis, + keepdim=keepdims, + atol=1e-1, + rtol=1e-1, + ) +# pinv @handle_frontend_test( - fn_tree="paddle.tensor.linalg.cond", - 
dtype_and_x=_get_dtype_and_matrix_non_singular(dtypes=["float32", "float64"]), - p=st.sampled_from([None, "fro", "nuc", np.inf, -np.inf, 1, -1, 2, -2]), + fn_tree="paddle.tensor.linalg.pinv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + min_value=3, + max_value=10, + large_abs_safety_factor=128, + safety_factor_scale="log", + ), + rcond=st.floats(1e-5, 1e-3), test_with_out=st.just(False), ) -def test_paddle_cond( - *, dtype_and_x, p, on_device, fn_tree, frontend, test_flags, backend_fw +def test_paddle_pinv( + dtype_and_x, + rcond, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, ): + # TODO: paddle returns nan for all values if the input + # matrix has the same value at all indices e.g. + # [[2., 2.], [2., 2.]] would return [[nan, nan], [nan, nan]], + # causing the tests to fail for other backends. dtype, x = dtype_and_x - - assume(matrix_is_stable(x[0])) - helpers.test_frontend_function( input_dtypes=dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - test_values=True, + rtol=1e-3, + atol=1e-3, x=x[0], - rtol=1e-5, - atol=1e-5, - p=p, + rcond=rcond, ) -# dot +# solve @handle_frontend_test( - fn_tree="paddle.tensor.linalg.dot", + fn_tree="paddle.solve", dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - min_num_dims=1, - max_num_dims=2, shared_dtype=True, + min_value=-10, + max_value=10, ), + aliases=["paddle.tensor.linalg.solve"], test_with_out=st.just(False), ) -def test_paddle_dot( +def test_paddle_solve( *, dtype_x, frontend, test_flags, - fn_tree, backend_fw, + fn_tree, on_device, ): input_dtype, x = dtype_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - test_flags=test_flags, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], @@ -768,23 +870,6 @@ def test_paddle_dot( ) -# transpose -@st.composite -def _transpose_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=4, - min_dim_size=2, - max_dim_size=3, - ret_shape=True, - ) - ) - perm = draw(st.permutations([i for i in range(len(shape))])) - return dtype, x, perm - - @handle_frontend_test( fn_tree="paddle.transpose", dtype_and_x_perm=_transpose_helper(), @@ -809,79 +894,3 @@ def test_paddle_transpose( x=x[0], perm=perm, ) - - -# bincount -@handle_frontend_test( - fn_tree="paddle.bincount", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=1, - max_value=2, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=1, - ), - key="a_s_d", - ), - ), - test_with_out=st.just(False), -) -def test_paddle_bincount( - *, - dtype_and_x, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - weights=None, - minlength=0, - ) - - -@handle_frontend_test( - fn_tree="paddle.dist", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=-1e04, - max_value=1e04, - allow_inf=False, 
- ), - p=helpers.floats(min_value=1.0, max_value=10.0), -) -def test_paddle_dist( - *, - dtype_and_input, - p, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, x = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - y=x[1], - p=p, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_logic.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_logic.py index 29e8745e1e938..179a1de46c787 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_logic.py @@ -7,167 +7,169 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# Helpers # -# ------- # - - -# Tests # -# ----- # - - -# equal +# allclose @handle_frontend_test( - fn_tree="paddle.equal", + fn_tree="paddle.allclose", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, ), + equal_nan=st.booleans(), ) -def test_paddle_equal( +def test_paddle_allclose( + *, dtype_and_x, - frontend, - test_flags, + equal_nan, + on_device, fn_tree, + frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, + backend_to_test=backend_fw, x=x[0], y=x[1], + equal_nan=equal_nan, ) -# not_equal @handle_frontend_test( - fn_tree="paddle.not_equal", + fn_tree="paddle.bitwise_and", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, ), + test_with_out=st.just(True), ) -def test_paddle_not_equal( +def test_paddle_bitwise_and( + *, dtype_and_x, - frontend, - test_flags, + on_device, fn_tree, backend_fw, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, x=x[0], y=x[1], ) -# greater_than +# bitwise_not @handle_frontend_test( - fn_tree="paddle.greater_than", + fn_tree="paddle.bitwise_not", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, ), + test_with_out=st.just(True), ) -def test_paddle_greater_than( +def test_paddle_bitwise_not( + *, dtype_and_x, - frontend, - test_flags, + on_device, fn_tree, + frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, + on_device=on_device, x=x[0], - y=x[1], ) -# greater_equal +# bitwise_or @handle_frontend_test( - fn_tree="paddle.greater_equal", + fn_tree="paddle.bitwise_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), + test_with_out=st.just(True), 
) -def test_paddle_greater_equal( +def test_paddle_bitwise_or( + *, dtype_and_x, - frontend, - test_flags, + on_device, fn_tree, backend_fw, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, x=x[0], y=x[1], ) -# less_than +# bitwise_xor @handle_frontend_test( - fn_tree="paddle.less_than", + fn_tree="paddle.bitwise_xor", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, ), + test_with_out=st.just(True), ) -def test_paddle_less_than( +def test_paddle_bitwise_xor( + *, dtype_and_x, - frontend, - test_flags, + on_device, fn_tree, + frontend, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, + on_device=on_device, x=x[0], y=x[1], ) -# less_equal +# Tests # +# ----- # + + +# equal @handle_frontend_test( - fn_tree="paddle.less_equal", + fn_tree="paddle.equal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, @@ -176,7 +178,7 @@ def test_paddle_less_than( small_abs_safety_factor=32, ), ) -def test_paddle_less_equal( +def test_paddle_equal( dtype_and_x, frontend, test_flags, @@ -229,21 +231,22 @@ def test_paddle_equal_all( ) -# logical_or +# greater_equal @handle_frontend_test( - fn_tree="paddle.logical_or", + fn_tree="paddle.greater_equal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, ), - test_with_out=st.just(True), ) -def test_paddle_logical_or( - *, +def test_paddle_greater_equal( dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, ): input_dtype, x = dtype_and_x @@ -253,27 +256,27 @@ def test_paddle_logical_or( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x=x[0], y=x[1], ) -# logical_xor +# greater_than @handle_frontend_test( - fn_tree="paddle.logical_xor", + fn_tree="paddle.greater_than", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, ), - test_with_out=st.just(True), ) -def test_paddle_logical_xor( - *, +def test_paddle_greater_than( dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, ): input_dtype, x = dtype_and_x @@ -283,47 +286,45 @@ def test_paddle_logical_xor( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x=x[0], y=x[1], ) @handle_frontend_test( - fn_tree="paddle.logical_not", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - test_with_out=st.just(True), + fn_tree="paddle.is_empty", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, + ), ) -def test_paddle_logical_not( - *, +def test_paddle_is_empty( 
dtype_and_x, - on_device, - fn_tree, frontend, test_flags, backend_fw, + fn_tree, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, - on_device=on_device, x=x[0], ) -# bitwise_or +# is_tensor @handle_frontend_test( - fn_tree="paddle.bitwise_or", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True - ), - test_with_out=st.just(True), + fn_tree="paddle.is_tensor", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), ) -def test_paddle_bitwise_or( +def test_paddle_is_tensor( *, dtype_and_x, on_device, @@ -335,31 +336,32 @@ def test_paddle_bitwise_or( input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + on_device=on_device, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x=x[0], - y=x[1], ) +# isclose @handle_frontend_test( - fn_tree="paddle.bitwise_and", + fn_tree="paddle.isclose", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True, ), - test_with_out=st.just(True), + equal_nan=st.booleans(), ) -def test_paddle_bitwise_and( +def test_paddle_isclose( *, dtype_and_x, + equal_nan, on_device, - fn_tree, backend_fw, + fn_tree, frontend, test_flags, ): @@ -367,94 +369,92 @@ def test_paddle_bitwise_and( helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, x=x[0], y=x[1], + equal_nan=equal_nan, ) -# bitwise_xor +# less_equal @handle_frontend_test( - fn_tree="paddle.bitwise_xor", + fn_tree="paddle.less_equal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, ), - test_with_out=st.just(True), ) -def test_paddle_bitwise_xor( - *, +def test_paddle_less_equal( dtype_and_x, - on_device, - fn_tree, frontend, - backend_fw, test_flags, + fn_tree, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, - on_device=on_device, x=x[0], y=x[1], ) -# bitwise_not +# less_than @handle_frontend_test( - fn_tree="paddle.bitwise_not", + fn_tree="paddle.less_than", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, ), - test_with_out=st.just(True), ) -def test_paddle_bitwise_not( - *, +def test_paddle_less_than( dtype_and_x, - on_device, - fn_tree, frontend, - backend_fw, test_flags, + fn_tree, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, - on_device=on_device, x=x[0], + y=x[1], ) -# allclose +# logical_and @handle_frontend_test( - fn_tree="paddle.allclose", + fn_tree="paddle.logical_and", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True, ), - equal_nan=st.booleans(), ) -def 
test_paddle_allclose( +def test_paddle_logical_and( *, dtype_and_x, - equal_nan, on_device, + backend_fw, fn_tree, frontend, - backend_fw, test_flags, ): input_dtype, x = dtype_and_x @@ -462,99 +462,93 @@ def test_paddle_allclose( input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - backend_to_test=backend_fw, x=x[0], y=x[1], - equal_nan=equal_nan, ) -# is_tensor @handle_frontend_test( - fn_tree="paddle.is_tensor", + fn_tree="paddle.logical_not", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + test_with_out=st.just(True), ) -def test_paddle_is_tensor( +def test_paddle_logical_not( *, dtype_and_x, on_device, fn_tree, - backend_fw, frontend, test_flags, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - on_device=on_device, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, x=x[0], ) -# isclose +# logical_or @handle_frontend_test( - fn_tree="paddle.isclose", + fn_tree="paddle.logical_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), - equal_nan=st.booleans(), + test_with_out=st.just(True), ) -def test_paddle_isclose( +def test_paddle_logical_or( *, dtype_and_x, - equal_nan, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, x=x[0], y=x[1], - equal_nan=equal_nan, ) -# logical_and +# logical_xor @handle_frontend_test( - fn_tree="paddle.logical_and", + fn_tree="paddle.logical_xor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), + test_with_out=st.just(True), ) -def test_paddle_logical_and( +def test_paddle_logical_xor( *, dtype_and_x, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, x=x[0], @@ -562,8 +556,9 @@ def test_paddle_logical_and( ) +# not_equal @handle_frontend_test( - fn_tree="paddle.is_empty", + fn_tree="paddle.not_equal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, @@ -572,19 +567,20 @@ def test_paddle_logical_and( small_abs_safety_factor=32, ), ) -def test_paddle_is_empty( +def test_paddle_not_equal( dtype_and_x, frontend, test_flags, - backend_fw, fn_tree, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, x=x[0], + y=x[1], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index a023cd7e3f24c..be36b2a116f1c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py 
+++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -10,6 +10,209 @@ ) +# --- Helpers --- # +# --------------- # + + +# stack +@st.composite +def _arrays_axis_n_dtypes(draw): + num_dims = draw(st.shared(helpers.ints(min_value=2, max_value=5), key="num_dims")) + num_arrays = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") + ) + common_shape = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_dims - 1, + ) + ) + axis = draw(st.sampled_from(list(range(num_dims)))) + xs = [] + input_dtypes = draw( + helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("numeric"))) + ) + dtype = draw(st.sampled_from(input_dtypes)) + for _ in range(num_arrays): + x = draw( + helpers.array_values( + shape=common_shape, + dtype=dtype, + ) + ) + xs.append(x) + input_dtypes = [dtype] * len(input_dtypes) + return xs, input_dtypes, axis + + +# concat +@st.composite +def _arrays_idx_n_dtypes(draw): + num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) + num_arrays = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") + ) + common_shape = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_dims - 1, + ) + ) + unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) + unique_dims = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_arrays, + ) + ) + xs = [] + input_dtypes = draw( + helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("valid"))) + ) + dtype = draw(st.sampled_from(input_dtypes)) + for ud in unique_dims: + x = draw( + helpers.array_values( + shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], + dtype=dtype, + ) + ) + xs.append(x) + input_dtypes = [dtype] * len(input_dtypes) + return xs, input_dtypes, unique_idx + + +@st.composite +def _broadcast_to_helper(draw): + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + + dtype, x = dtype_and_x + input_shape = x[0].shape + + max_num_dims = 6 - len(input_shape) + shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape + + return dtype, x, shape + + +# flip +@st.composite +def _dtype_x_axis(draw, **kwargs): + dtype, x, shape = draw(helpers.dtype_and_values(**kwargs, ret_shape=True)) + axis = draw( + st.lists( + helpers.ints(min_value=0, max_value=len(shape) - 1), + min_size=len(shape), + max_size=len(shape), + unique=True, + ) + ) + return dtype, x, axis + + +# expand +@st.composite +def _expand_helper(draw): + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + + dtype, x = dtype_and_x + input_shape = x[0].shape + + max_num_dims = 6 - len(input_shape) + shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape + + return dtype, x, shape + + +@st.composite +def _gather_helper(draw): + dtype_and_param = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + + dtype_and_indices = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=6, + ) + ) + dtype, param = dtype_and_param + dtype, indices = dtype_and_indices + return dtype, param, indices + + +# split +@st.composite +def _split_helper(draw): + dtypes, values, shape = draw( + helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=4, + min_dim_size=2, + max_dim_size=4, + ret_shape=True, + ) + ) + axis = draw(st.sampled_from(range(len(shape)))) + num_eles = shape[axis] + splits = [i for i in range(1, num_eles + 1) if num_eles % i == 0] + num_splits = draw(st.sampled_from(splits)) + return dtypes, values, num_splits, axis + + +# squeeze +@st.composite +def _squeeze_helper(draw): + shape = draw(st.shared(helpers.get_shape(), key="value_shape")) + valid_axes = [] + for index, axis in enumerate(shape): + if axis == 1: + valid_axes.append(index) + valid_axes.insert(0, None) + + return draw(st.sampled_from(valid_axes)) + + +# tile +@st.composite +def _tile_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=4, + min_dim_size=2, + max_dim_size=3, + ret_shape=True, + ) + ) + repeats = draw( + helpers.list_of_size( + x=helpers.ints(min_value=1, max_value=3), + size=len(shape), + ) + ) + return dtype, x, repeats + + # Helpers # # ------ # @@ -43,25 +246,27 @@ def dtypes_x_reshape_(draw): return dtypes, x, shape -# Tests # -# ----- # +# --- Main --- # +# ------------ # -# reshape +# abs @handle_frontend_test( - fn_tree="paddle.reshape", - dtypes_x_reshape=dtypes_x_reshape(), + fn_tree="paddle.abs", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_paddle_reshape( +def test_paddle_abs( *, - dtypes_x_reshape, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, shape = dtypes_x_reshape + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -70,25 +275,23 @@ def test_paddle_reshape( fn_tree=fn_tree, on_device=on_device, x=x[0], - shape=shape, ) -# reshape_ @handle_frontend_test( - fn_tree="paddle.reshape_", - dtypes_x_reshape=dtypes_x_reshape_(), + fn_tree="paddle.broadcast_to", + dtype_x_and_shape=_broadcast_to_helper(), ) -def test_paddle_reshape_( +def test_paddle_broadcast_to( *, - dtypes_x_reshape, + dtype_x_and_shape, on_device, fn_tree, + backend_fw, frontend, test_flags, - backend_fw, ): - input_dtype, x, shape = dtypes_x_reshape + input_dtype, x, shape = dtype_x_and_shape helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -101,21 +304,23 @@ def test_paddle_reshape_( ) -# abs +# cast @handle_frontend_test( - fn_tree="paddle.abs", + fn_tree="paddle.cast", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_paddle_abs( +def test_paddle_cast( *, dtype_and_x, + dtype, on_device, + backend_fw, fn_tree, frontend, test_flags, - backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -126,104 +331,10 @@ def test_paddle_abs( fn_tree=fn_tree, on_device=on_device, x=x[0], + dtype=dtype[0], ) -# stack -@st.composite -def _arrays_axis_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=2, max_value=5), key="num_dims")) - num_arrays = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") - ) - common_shape = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_dims - 1, - ) - ) - axis = draw(st.sampled_from(list(range(num_dims)))) - xs = [] - input_dtypes = draw( - helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("numeric"))) 
- ) - dtype = draw(st.sampled_from(input_dtypes)) - for _ in range(num_arrays): - x = draw( - helpers.array_values( - shape=common_shape, - dtype=dtype, - ) - ) - xs.append(x) - input_dtypes = [dtype] * len(input_dtypes) - return xs, input_dtypes, axis - - -@handle_frontend_test( - fn_tree="paddle.stack", - _arrays_n_dtypes_axis=_arrays_axis_n_dtypes(), - test_with_out=st.just(False), -) -def test_paddle_stack( - *, - _arrays_n_dtypes_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - xs, input_dtypes, axis = _arrays_n_dtypes_axis - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=xs, - axis=axis, - ) - - -# concat -@st.composite -def _arrays_idx_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) - num_arrays = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") - ) - common_shape = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_dims - 1, - ) - ) - unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) - unique_dims = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_arrays, - ) - ) - xs = [] - input_dtypes = draw( - helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("valid"))) - ) - dtype = draw(st.sampled_from(input_dtypes)) - for ud in unique_dims: - x = draw( - helpers.array_values( - shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], - dtype=dtype, - ) - ) - xs.append(x) - input_dtypes = [dtype] * len(input_dtypes) - return xs, input_dtypes, unique_idx - - @handle_frontend_test( fn_tree="paddle.concat", xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), @@ -251,135 +362,77 @@ def test_paddle_concat( ) -# tile -@st.composite -def _tile_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=4, - min_dim_size=2, - max_dim_size=3, - ret_shape=True, - ) - ) - repeats = draw( - helpers.list_of_size( - x=helpers.ints(min_value=1, max_value=3), - size=len(shape), - ) - ) - return dtype, x, repeats - - @handle_frontend_test( - fn_tree="paddle.tile", - dt_x_repeats=_tile_helper(), - test_with_out=st.just(False), + fn_tree="paddle.expand", + dtype_x_and_shape=_expand_helper(), ) -def test_paddle_tile( +def test_paddle_expand( *, - dt_x_repeats, + dtype_x_and_shape, on_device, fn_tree, - frontend, backend_fw, + frontend, test_flags, ): - input_dtypes, x, repeats = dt_x_repeats + input_dtype, x, shape = dtype_x_and_shape helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - repeat_times=repeats, - ) - - -# split -@st.composite -def _split_helper(draw): - dtypes, values, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, - max_num_dims=4, - min_dim_size=2, - max_dim_size=4, - ret_shape=True, - ) + shape=shape, ) - axis = draw(st.sampled_from(range(len(shape)))) - num_eles = shape[axis] - splits = [i for i in range(1, num_eles + 1) if num_eles % i == 0] - num_splits = draw(st.sampled_from(splits)) - return dtypes, values, num_splits, axis @handle_frontend_test( - fn_tree="paddle.split", - dt_x_num_splits_axis=_split_helper(), + fn_tree="paddle.flip", + 
dtype_x_axis=_dtype_x_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=1, + ), test_with_out=st.just(False), ) -def test_paddle_split( +def test_paddle_flip( *, - dt_x_num_splits_axis, + dtype_x_axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtypes, x, num_splits, axis = dt_x_num_splits_axis + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - num_or_sections=num_splits, axis=axis, ) -# squeeze -@st.composite -def _squeeze_helper(draw): - shape = draw(st.shared(helpers.get_shape(), key="value_shape")) - valid_axes = [] - for index, axis in enumerate(shape): - if axis == 1: - valid_axes.append(index) - valid_axes.insert(0, None) - - return draw(st.sampled_from(valid_axes)) - - @handle_frontend_test( - fn_tree="paddle.squeeze", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - axis=_squeeze_helper(), + fn_tree="paddle.gather", + dtype_param_and_indices=_gather_helper(), ) -def test_paddle_squeeze( +def test_paddle_gather( *, - dtype_and_x, - axis, + dtype_param_and_indices, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, param, indices = dtype_param_and_indices helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -387,45 +440,30 @@ def test_paddle_squeeze( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - axis=axis, - ) - - -# expand -@st.composite -def _expand_helper(draw): - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) + param=param[0], + indices=indices[0], ) - dtype, x = dtype_and_x - input_shape = x[0].shape - max_num_dims = 6 - len(input_shape) - shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape - - return dtype, x, shape +# Tests # +# ----- # +# reshape @handle_frontend_test( - fn_tree="paddle.expand", - dtype_x_and_shape=_expand_helper(), + fn_tree="paddle.reshape", + dtypes_x_reshape=dtypes_x_reshape(), ) -def test_paddle_expand( +def test_paddle_reshape( *, - dtype_x_and_shape, + dtypes_x_reshape, on_device, fn_tree, - backend_fw, frontend, test_flags, + backend_fw, ): - input_dtype, x, shape = dtype_x_and_shape + input_dtype, x, shape = dtypes_x_reshape helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -438,25 +476,21 @@ def test_paddle_expand( ) -# cast +# reshape_ @handle_frontend_test( - fn_tree="paddle.cast", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), - dtype=helpers.get_dtypes("valid", full=False), + fn_tree="paddle.reshape_", + dtypes_x_reshape=dtypes_x_reshape_(), ) -def test_paddle_cast( +def test_paddle_reshape_( *, - dtype_and_x, - dtype, + dtypes_x_reshape, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, shape = dtypes_x_reshape helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -465,43 +499,34 @@ def test_paddle_cast( fn_tree=fn_tree, on_device=on_device, x=x[0], - dtype=dtype[0], - ) - - -@st.composite -def _broadcast_to_helper(draw): - dtype_and_x = draw( - 
helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) + shape=shape, ) - dtype, x = dtype_and_x - input_shape = x[0].shape - - max_num_dims = 6 - len(input_shape) - shape = draw(helpers.get_shape(max_num_dims=max_num_dims)) + input_shape - - return dtype, x, shape - +# roll @handle_frontend_test( - fn_tree="paddle.broadcast_to", - dtype_x_and_shape=_broadcast_to_helper(), + fn_tree="paddle.roll", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + min_dim_size=2, + ), + shift=helpers.ints(min_value=1, max_value=10), + axis=helpers.ints(min_value=-1, max_value=1), + test_with_out=st.just(False), ) -def test_paddle_broadcast_to( +def test_paddle_roll( *, - dtype_x_and_shape, + dtype_and_x, + shift, + axis, on_device, fn_tree, - backend_fw, frontend, test_flags, + backend_fw, ): - input_dtype, x, shape = dtype_x_and_shape + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -510,46 +535,32 @@ def test_paddle_broadcast_to( fn_tree=fn_tree, on_device=on_device, x=x[0], - shape=shape, - ) - - -@st.composite -def _gather_helper(draw): - dtype_and_param = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) - ) - - dtype_and_indices = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=6, - ) + shifts=shift, + axis=axis, ) - dtype, param = dtype_and_param - dtype, indices = dtype_and_indices - return dtype, param, indices +# rot90 @handle_frontend_test( - fn_tree="paddle.gather", - dtype_param_and_indices=_gather_helper(), + fn_tree="paddle.rot90", + dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( + available_dtypes=helpers.get_dtypes(kind="valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), ) -def test_paddle_gather( +def test_paddle_rot90( *, - dtype_param_and_indices, + dtype_m_k_axes, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, param, indices = dtype_param_and_indices + input_dtype, m, k, axes = dtype_m_k_axes helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -557,82 +568,59 @@ def test_paddle_gather( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - param=param[0], - indices=indices[0], + x=m, + k=k, + axes=tuple(axes), ) -# unstack @handle_frontend_test( - fn_tree="paddle.unstack", - dtypes_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=2, - max_dim_size=1, - ), - number_positional_args=st.just(1), - axis=st.integers(-1, 0), + fn_tree="paddle.split", + dt_x_num_splits_axis=_split_helper(), test_with_out=st.just(False), ) -def test_paddle_unstack( +def test_paddle_split( *, - dtypes_values, - axis, + dt_x_num_splits_axis, on_device, fn_tree, - backend_fw, frontend, test_flags, + backend_fw, ): - x_dtype, x = dtypes_values - axis = axis + input_dtypes, x, num_splits, axis = dt_x_num_splits_axis helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + num_or_sections=num_splits, axis=axis, ) -# flip -@st.composite -def _dtype_x_axis(draw, **kwargs): - dtype, x, shape = draw(helpers.dtype_and_values(**kwargs, ret_shape=True)) - axis = draw( - st.lists( - 
helpers.ints(min_value=0, max_value=len(shape) - 1), - min_size=len(shape), - max_size=len(shape), - unique=True, - ) - ) - return dtype, x, axis - - @handle_frontend_test( - fn_tree="paddle.flip", - dtype_x_axis=_dtype_x_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=1, + fn_tree="paddle.squeeze", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), ), - test_with_out=st.just(False), + axis=_squeeze_helper(), ) -def test_paddle_flip( +def test_paddle_squeeze( *, - dtype_x_axis, + dtype_and_x, + axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -645,39 +633,29 @@ def test_paddle_flip( ) -# roll @handle_frontend_test( - fn_tree="paddle.roll", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - min_dim_size=2, - ), - shift=helpers.ints(min_value=1, max_value=10), - axis=helpers.ints(min_value=-1, max_value=1), + fn_tree="paddle.stack", + _arrays_n_dtypes_axis=_arrays_axis_n_dtypes(), test_with_out=st.just(False), ) -def test_paddle_roll( +def test_paddle_stack( *, - dtype_and_x, - shift, - axis, + _arrays_n_dtypes_axis, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + xs, input_dtypes, axis = _arrays_n_dtypes_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - shifts=shift, + x=xs, axis=axis, ) @@ -718,35 +696,65 @@ def test_paddle_take_along_axis( ) -# rot90 @handle_frontend_test( - fn_tree="paddle.rot90", - dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( - available_dtypes=helpers.get_dtypes(kind="valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), + fn_tree="paddle.tile", + dt_x_repeats=_tile_helper(), + test_with_out=st.just(False), ) -def test_paddle_rot90( +def test_paddle_tile( *, - dtype_m_k_axes, + dt_x_repeats, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, m, k, axes = dtype_m_k_axes + input_dtypes, x, repeats = dt_x_repeats helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=m, - k=k, - axes=tuple(axes), + x=x[0], + repeat_times=repeats, + ) + + +# unstack +@handle_frontend_test( + fn_tree="paddle.unstack", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=2, + max_dim_size=1, + ), + number_positional_args=st.just(1), + axis=st.integers(-1, 0), + test_with_out=st.just(False), +) +def test_paddle_unstack( + *, + dtypes_values, + axis, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + x_dtype, x = dtypes_values + axis = axis + helpers.test_frontend_function( + input_dtypes=x_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py index 96344d22b0c1e..62ab9db541e88 100644 --- 
a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py @@ -10,42 +10,37 @@ ) -# sin -@handle_frontend_test( - fn_tree="paddle.sin", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_paddle_sin( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], +# --- Helpers --- # +# --------------- # + + +@st.composite +def _test_paddle_take_helper(draw): + mode = draw(st.sampled_from(["raise", "clip", "wrap"])) + + safe_bounds = mode == "raise" + + dtypes, xs, indices, _, _ = draw( + helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("float_and_integer"), + indices_dtypes=["int32", "int64"], + valid_bounds=safe_bounds, + ) ) + return dtypes, xs, indices, mode + -# cos +# --- Main --- # +# ------------ # + + +# abs @handle_frontend_test( - fn_tree="paddle.cos", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="paddle.tensor.math.abs", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_paddle_cos( +def test_paddle_abs( *, dtype_and_x, on_device, @@ -95,14 +90,14 @@ def test_paddle_acos( ) -# cosh +# acosh @handle_frontend_test( - fn_tree="paddle.tensor.math.cosh", + fn_tree="paddle.tensor.math.acosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_cosh( +def test_paddle_acosh( *, dtype_and_x, on_device, @@ -124,15 +119,20 @@ def test_paddle_cosh( ) -# tanh +# add @handle_frontend_test( - fn_tree="paddle.tensor.math.tanh", - aliases=["paddle.tanh", "paddle.nn.functional.tanh"], + fn_tree="paddle.add", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_paddle_tanh( +def test_paddle_add( *, dtype_and_x, on_device, @@ -146,31 +146,45 @@ def test_paddle_tanh( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, - atol=1e-2, x=x[0], + y=x[1], ) -# acosh +# addmm @handle_frontend_test( - fn_tree="paddle.tensor.math.acosh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="paddle.tensor.math.addmm", + dtype_input_xy=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_paddle_acosh( +def test_paddle_addmm( *, - dtype_and_x, + dtype_input_xy, + beta, + alpha, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input, x, y = dtype_input_xy helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -178,116 +192,117 @@ def test_paddle_acosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, + input=input[0], x=x[0], + y=y[0], + beta=beta, + 
alpha=alpha, ) -# asin +# amax @handle_frontend_test( - fn_tree="paddle.tensor.math.asin", + fn_tree="paddle.tensor.math.amax", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + allow_inf=False, + shared_dtype=True, ), ) -def test_paddle_asin( +def test_paddle_amax( *, dtype_and_x, - frontend, - test_flags, + on_device, fn_tree, backend_fw, - on_device, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], ) -# log +# amin @handle_frontend_test( - fn_tree="paddle.log", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="paddle.tensor.math.amin", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, ), + keepdim=st.booleans(), ) -def test_paddle_log( +def test_paddle_amin( *, dtype_and_x, + keepdim, on_device, fn_tree, + backend_fw, frontend, test_flags, - backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], + axis=axis, + keepdim=keepdim, ) -# divide @handle_frontend_test( - fn_tree="paddle.divide", + fn_tree="paddle.tensor.math.angle", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=["float64", "complex64", "complex128"], ), ) -def test_paddle_divide( +def test_paddle_angle( *, dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# multiply +# any @handle_frontend_test( - fn_tree="paddle.multiply", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, + fn_tree="paddle.tensor.math.any", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=["bool"], + valid_axis=True, + allow_neg_axes=True, + force_int_axis=True, + min_num_dims=1, ), ) -def test_paddle_multiply( +def test_paddle_any( *, dtype_and_x, on_device, @@ -296,97 +311,85 @@ def test_paddle_multiply( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + backend_to_test=backend_fw, x=x[0], - y=x[1], + axis=axis, + keepdim=False, ) -# add +# asin @handle_frontend_test( - fn_tree="paddle.add", + fn_tree="paddle.tensor.math.asin", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - 
shared_dtype=True, ), ) -def test_paddle_add( +def test_paddle_asin( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, - fn_tree=fn_tree, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# subtract +# asinh @handle_frontend_test( - fn_tree="paddle.subtract", + fn_tree="paddle.tensor.math.asinh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, ), ) -def test_paddle_subtract( +def test_paddle_asinh( *, dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, - fn_tree=fn_tree, + backend_to_test=backend_fw, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, + atol=1e-2, x=x[0], - y=x[1], ) -# sqrt +# atan @handle_frontend_test( - fn_tree="paddle.tensor.math.sqrt", + fn_tree="paddle.tensor.math.atan", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_sqrt( +def test_paddle_atan( *, dtype_and_x, frontend, @@ -407,49 +410,56 @@ def test_paddle_sqrt( ) -# atanh +# atan2 @handle_frontend_test( - fn_tree="paddle.tensor.math.atanh", + fn_tree="paddle.atan2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_paddle_atanh( +def test_paddle_atan2( *, dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], + y=x[1], ) -# atan +# atanh @handle_frontend_test( - fn_tree="paddle.tensor.math.atan", + fn_tree="paddle.tensor.math.atanh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_atan( +def test_paddle_atanh( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -463,15 +473,14 @@ def test_paddle_atan( ) -# round +# ceil @handle_frontend_test( - fn_tree="paddle.tensor.math.round", + fn_tree="paddle.tensor.math.ceil", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=1, ), ) -def test_paddle_round( +def test_paddle_ceil( *, dtype_and_x, frontend, @@ -492,24 +501,23 @@ def test_paddle_round( ) -# round_ +# conj @handle_frontend_test( - fn_tree="paddle.tensor.math.round_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1, + fn_tree="paddle.tensor.math.conj", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_paddle_round_( +def test_paddle_conj( *, - dtype_and_x, + dtype_and_input, frontend, + backend_fw, test_flags, fn_tree, - backend_fw, 
on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -521,21 +529,21 @@ def test_paddle_round_( ) -# ceil +# cos @handle_frontend_test( - fn_tree="paddle.tensor.math.ceil", + fn_tree="paddle.cos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_ceil( +def test_paddle_cos( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -549,14 +557,14 @@ def test_paddle_ceil( ) -# sinh +# cosh @handle_frontend_test( - fn_tree="paddle.sinh", + fn_tree="paddle.tensor.math.cosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_sinh( +def test_paddle_cosh( *, dtype_and_x, on_device, @@ -573,89 +581,66 @@ def test_paddle_sinh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-2, x=x[0], ) -# pow +# cumprod @handle_frontend_test( - fn_tree="paddle.pow", - dtype_and_x=helpers.dtype_and_values( + fn_tree="paddle.tensor.math.cumprod", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - allow_inf=False, - shared_dtype=True, + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), ) -def test_paddle_pow( +def test_paddle_cumprod( *, - dtype_and_x, + dtype_x_axis, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], + dim=axis, ) -# abs +# deg2rad @handle_frontend_test( - fn_tree="paddle.tensor.math.abs", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="paddle.deg2rad", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_paddle_abs( +def test_paddle_deg2rad( *, dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# conj -@handle_frontend_test( - fn_tree="paddle.tensor.math.conj", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ), -) -def test_paddle_conj( - *, - dtype_and_input, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, @@ -663,37 +648,60 @@ def test_paddle_conj( ) -# floor +# diff @handle_frontend_test( - fn_tree="paddle.tensor.math.floor", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="paddle.tensor.math.diff", + dtype_n_x_n_axis=helpers.dtype_values_axis( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), + n=st.integers(min_value=1, max_value=1), + dtype_prepend=helpers.dtype_and_values( + 
available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + max_num_dims=1, + ), + dtype_append=helpers.dtype_and_values( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + max_num_dims=1, ), ) -def test_paddle_floor( +def test_paddle_diff( *, - dtype_and_x, + dtype_n_x_n_axis, + n, + dtype_prepend, + dtype_append, + test_flags, frontend, backend_fw, - test_flags, fn_tree, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_n_x_n_axis + _, prepend = dtype_prepend + _, append = dtype_append helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + n=n, + axis=axis, + prepend=prepend[0], + append=append[0], ) -# remainder +# divide @handle_frontend_test( - fn_tree="paddle.remainder", + fn_tree="paddle.divide", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, @@ -704,20 +712,20 @@ def test_paddle_floor( shared_dtype=True, ), ) -def test_paddle_remainder( +def test_paddle_divide( *, dtype_and_x, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, @@ -726,21 +734,21 @@ def test_paddle_remainder( ) -# log2 +# erf @handle_frontend_test( - fn_tree="paddle.log2", + fn_tree="paddle.tensor.math.erf", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_log2( +def test_paddle_erf( *, dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -754,15 +762,14 @@ def test_paddle_log2( ) -# log1p +# exp @handle_frontend_test( - fn_tree="paddle.log1p", + fn_tree="paddle.exp", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - max_value=1e5, ), ) -def test_paddle_log1p( +def test_paddle_exp( *, dtype_and_x, on_device, @@ -783,14 +790,14 @@ def test_paddle_log1p( ) -# rad2deg +# expm1 @handle_frontend_test( - fn_tree="paddle.rad2deg", + fn_tree="paddle.expm1", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_rad2deg( +def test_paddle_expm1( *, dtype_and_x, on_device, @@ -811,21 +818,21 @@ def test_paddle_rad2deg( ) -# deg2rad +# floor @handle_frontend_test( - fn_tree="paddle.deg2rad", + fn_tree="paddle.tensor.math.floor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_deg2rad( +def test_paddle_floor( *, dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -839,23 +846,22 @@ def test_paddle_deg2rad( ) -# tan @handle_frontend_test( - fn_tree="paddle.tensor.math.tan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="paddle.fmax", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tan( +def test_paddle_fmax( *, - dtype_and_x, + dtypes_and_x, 
on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtypes_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -863,61 +869,57 @@ def test_paddle_tan( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, x=x[0], + y=x[1], ) -# atan2 @handle_frontend_test( - fn_tree="paddle.atan2", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, + fn_tree="paddle.fmin", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_atan2( +def test_paddle_fmin( *, - dtype_and_x, + dtypes_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtypes_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, x=x[0], y=x[1], ) -# sign +# frac @handle_frontend_test( - fn_tree="paddle.tensor.math.sign", + fn_tree="paddle.tensor.math.frac", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + max_value=1e6, + min_value=-1e6, ), ) -def test_paddle_sign( +def test_paddle_frac( *, dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -931,14 +933,20 @@ def test_paddle_sign( ) -# neg +# gcd @handle_frontend_test( - fn_tree="paddle.neg", + fn_tree="paddle.gcd", dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64", "int8", "int16", "int32", "int64"], + available_dtypes=helpers.get_dtypes("valid"), + min_value=-100, + max_value=100, + min_num_dims=1, + min_dim_size=1, + num_arrays=2, + shared_dtype=True, ), ) -def test_paddle_neg( +def test_paddle_gcd( *, dtype_and_x, on_device, @@ -952,22 +960,28 @@ def test_paddle_neg( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], + y=x[1], ) -# lgamma +# heaviside @handle_frontend_test( - fn_tree="paddle.tensor.math.lgamma", + fn_tree="paddle.tensor.math.heaviside", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, safety_factor_scale="log", + shared_dtype=True, ), ) -def test_paddle_lgamma( +def test_paddle_heaviside( *, dtype_and_x, on_device, @@ -981,29 +995,29 @@ def test_paddle_lgamma( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, - atol=1e-4, x=x[0], + y=x[1], ) -# exp +# isfinite @handle_frontend_test( - fn_tree="paddle.exp", + fn_tree="paddle.isfinite", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_exp( +def test_paddle_isfinite( *, dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1017,21 +1031,21 @@ def test_paddle_exp( ) -# expm1 +# isinf @handle_frontend_test( 
- fn_tree="paddle.expm1", + fn_tree="paddle.isinf", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_expm1( +def test_paddle_isinf( *, dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1045,21 +1059,21 @@ def test_paddle_expm1( ) -# square +# isnan @handle_frontend_test( - fn_tree="paddle.tensor.math.square", + fn_tree="paddle.isnan", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_square( +def test_paddle_isnan( *, dtype_and_x, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1073,20 +1087,23 @@ def test_paddle_square( ) -# reciprocal +# kron @handle_frontend_test( - fn_tree="paddle.reciprocal", + fn_tree="paddle.tensor.math.kron", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + shared_dtype=True, ), ) -def test_paddle_reciprocal( +def test_paddle_kron( *, dtype_and_x, on_device, fn_tree, - frontend, backend_fw, + frontend, test_flags, ): input_dtype, x = dtype_and_x @@ -1094,10 +1111,11 @@ def test_paddle_reciprocal( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], + y=x[1], ) @@ -1135,54 +1153,51 @@ def test_paddle_lcm( ) -# cumprod +# lerp @handle_frontend_test( - fn_tree="paddle.tensor.math.cumprod", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, + fn_tree="paddle.tensor.math.lerp", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_paddle_cumprod( +def test_paddle_lerp( *, - dtype_x_axis, + dtype_and_x, on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], - dim=axis, + y=x[1], + weight=x[2], ) -# gcd +# lgamma @handle_frontend_test( - fn_tree="paddle.gcd", + fn_tree="paddle.tensor.math.lgamma", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-100, - max_value=100, - min_num_dims=1, - min_dim_size=1, - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", ), ) -def test_paddle_gcd( +def test_paddle_lgamma( *, dtype_and_x, on_device, @@ -1196,50 +1211,51 @@ def test_paddle_gcd( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, + atol=1e-4, x=x[0], - y=x[1], ) +# log @handle_frontend_test( - fn_tree="paddle.fmin", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + fn_tree="paddle.log", + dtype_and_x=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_fmin( +def test_paddle_log( *, - dtypes_and_x, + dtype_and_x, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# logit +# log1p @handle_frontend_test( - fn_tree="paddle.logit", + fn_tree="paddle.log1p", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + max_value=1e5, ), ) -def test_paddle_logit( +def test_paddle_log1p( *, dtype_and_x, on_device, @@ -1257,25 +1273,24 @@ def test_paddle_logit( fn_tree=fn_tree, on_device=on_device, x=x[0], - eps=1e-2, ) -# isnan +# log2 @handle_frontend_test( - fn_tree="paddle.isnan", + fn_tree="paddle.log2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_isnan( +def test_paddle_log2( *, dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1289,21 +1304,21 @@ def test_paddle_isnan( ) -# isfinite +# logit @handle_frontend_test( - fn_tree="paddle.isfinite", + fn_tree="paddle.logit", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_isfinite( +def test_paddle_logit( *, dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1314,26 +1329,31 @@ def test_paddle_isfinite( fn_tree=fn_tree, on_device=on_device, x=x[0], + eps=1e-2, ) -# isinf +# max @handle_frontend_test( - fn_tree="paddle.isinf", - dtype_and_x=helpers.dtype_and_values( + fn_tree="paddle.tensor.math.max", + dtype_and_x=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=False, ), ) -def test_paddle_isinf( +def test_paddle_max( *, dtype_and_x, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1342,16 +1362,21 @@ def test_paddle_isinf( fn_tree=fn_tree, on_device=on_device, x=x[0], + axis=axis, + keepdim=False, ) +# maximum @handle_frontend_test( - fn_tree="paddle.tensor.math.angle", + fn_tree="paddle.maximum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float64", "complex64", "complex128"], + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), ) -def test_paddle_angle( +def test_paddle_maximum( *, dtype_and_x, on_device, @@ -1369,25 +1394,31 @@ def test_paddle_angle( fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) +# min @handle_frontend_test( - fn_tree="paddle.fmax", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + fn_tree="paddle.tensor.math.min", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=False, ), ) -def test_paddle_fmax( +def test_paddle_min( *, - dtypes_and_x, + dtype_and_x, 
on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x = dtypes_and_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1396,7 +1427,8 @@ def test_paddle_fmax( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], + axis=axis, + keepdim=False, ) @@ -1430,76 +1462,76 @@ def test_paddle_minimum( ) -# erf +# mm @handle_frontend_test( - fn_tree="paddle.tensor.math.erf", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), + fn_tree="paddle.tensor.math.mm", + dtype_xy=_get_dtype_input_and_matrices(), ) -def test_paddle_erf( +def test_paddle_mm( *, - dtype_and_x, + dtype_xy, + on_device, + fn_tree, frontend, - backend_fw, test_flags, - fn_tree, - on_device, + backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, y = dtype_xy helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x, + mat2=y, ) -# trunc +# multiply @handle_frontend_test( - fn_tree="paddle.trunc", + fn_tree="paddle.multiply", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", "int"), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_paddle_trunc( +def test_paddle_multiply( *, dtype_and_x, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) +# neg @handle_frontend_test( - fn_tree="paddle.tensor.math.sgn", + fn_tree="paddle.neg", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - abs_smallest_val=1e-10, - min_value=-10, - max_value=10, + available_dtypes=["float32", "float64", "int8", "int16", "int32", "int64"], ), ) -def test_paddle_sgn( +def test_paddle_neg( *, dtype_and_x, on_device, @@ -1520,16 +1552,18 @@ def test_paddle_sgn( ) -# maximum +# outer @handle_frontend_test( - fn_tree="paddle.maximum", + fn_tree="paddle.tensor.math.outer", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, + min_num_dims=1, + max_num_dims=1, shared_dtype=True, ), ) -def test_paddle_maximum( +def test_paddle_outer( *, dtype_and_x, on_device, @@ -1543,53 +1577,91 @@ def test_paddle_maximum( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, + test_flags=test_flags, on_device=on_device, x=x[0], y=x[1], ) -# frac +# pow @handle_frontend_test( - fn_tree="paddle.tensor.math.frac", + fn_tree="paddle.pow", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - max_value=1e6, - min_value=-1e6, + num_arrays=2, + allow_inf=False, + shared_dtype=True, ), ) -def test_paddle_frac( +def test_paddle_pow( *, dtype_and_x, + on_device, + fn_tree, frontend, - backend_fw, test_flags, - fn_tree, - on_device, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - 
frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# asinh +# prod @handle_frontend_test( - fn_tree="paddle.tensor.math.asinh", + fn_tree="paddle.tensor.math.prod", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + min_value=-10, + max_value=10, + force_int_axis=False, + allow_nan=False, + ), +) +def test_paddle_prod( + *, + dtype_and_x, + on_device, + backend_fw, + fn_tree, + frontend, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + keepdim=False, + backend_to_test=backend_fw, + ) + + +# rad2deg +@handle_frontend_test( + fn_tree="paddle.rad2deg", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_asinh( +def test_paddle_rad2deg( *, dtype_and_x, on_device, @@ -1606,23 +1678,18 @@ def test_paddle_asinh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, x=x[0], ) -# max +# reciprocal @handle_frontend_test( - fn_tree="paddle.tensor.math.max", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=False, + fn_tree="paddle.reciprocal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_max( +def test_paddle_reciprocal( *, dtype_and_x, on_device, @@ -1631,7 +1698,7 @@ def test_paddle_max( backend_fw, test_flags, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1640,17 +1707,15 @@ def test_paddle_max( fn_tree=fn_tree, on_device=on_device, x=x[0], - axis=axis, - keepdim=False, ) -# lerp +# remainder @handle_frontend_test( - fn_tree="paddle.tensor.math.lerp", + fn_tree="paddle.remainder", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, + num_arrays=2, allow_inf=False, large_abs_safety_factor=2, small_abs_safety_factor=2, @@ -1658,7 +1723,7 @@ def test_paddle_max( shared_dtype=True, ), ) -def test_paddle_lerp( +def test_paddle_remainder( *, dtype_and_x, on_device, @@ -1672,80 +1737,69 @@ def test_paddle_lerp( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, x=x[0], y=x[1], - weight=x[2], ) -# outer +# round @handle_frontend_test( - fn_tree="paddle.tensor.math.outer", + fn_tree="paddle.tensor.math.round", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_num_dims=1, - max_num_dims=1, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + min_value=1, ), ) -def test_paddle_outer( +def test_paddle_round( *, dtype_and_x, - on_device, - fn_tree, frontend, - backend_fw, test_flags, + fn_tree, + backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, + frontend=frontend, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# heaviside +# round_ @handle_frontend_test( - fn_tree="paddle.tensor.math.heaviside", + 
fn_tree="paddle.tensor.math.round_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="log", - shared_dtype=True, + min_value=1, ), ) -def test_paddle_heaviside( +def test_paddle_round_( *, dtype_and_x, - on_device, - fn_tree, frontend, - backend_fw, test_flags, + fn_tree, + backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, + frontend=frontend, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) @@ -1784,7 +1838,6 @@ def test_paddle_rsqrt( available_dtypes=helpers.get_dtypes("valid"), ), ) - def test_paddle_rsqrt_( *, dtype_and_x, @@ -1806,143 +1859,113 @@ def test_paddle_rsqrt_( ) -# prod @handle_frontend_test( - fn_tree="paddle.tensor.math.prod", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_axis=-1, - max_axis=0, + fn_tree="paddle.tensor.math.sgn", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + abs_smallest_val=1e-10, min_value=-10, max_value=10, - force_int_axis=False, - allow_nan=False, ), ) -def test_paddle_prod( +def test_paddle_sgn( *, dtype_and_x, on_device, - backend_fw, fn_tree, frontend, + backend_fw, test_flags, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - axis=axis, - keepdim=False, - backend_to_test=backend_fw, ) -# any +# sign @handle_frontend_test( - fn_tree="paddle.tensor.math.any", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=["bool"], - valid_axis=True, - allow_neg_axes=True, - force_int_axis=True, - min_num_dims=1, + fn_tree="paddle.tensor.math.sign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_any( +def test_paddle_sign( *, dtype_and_x, on_device, fn_tree, frontend, - test_flags, backend_fw, + test_flags, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - backend_to_test=backend_fw, x=x[0], - axis=axis, - keepdim=False, ) -# diff +# sin @handle_frontend_test( - fn_tree="paddle.tensor.math.diff", - dtype_n_x_n_axis=helpers.dtype_values_axis( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), - n=st.integers(min_value=1, max_value=1), - dtype_prepend=helpers.dtype_and_values( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - max_num_dims=1, - ), - dtype_append=helpers.dtype_and_values( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - max_num_dims=1, + fn_tree="paddle.sin", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_diff( +def test_paddle_sin( *, - dtype_n_x_n_axis, - n, - dtype_prepend, - dtype_append, - test_flags, + dtype_and_x, + on_device, + fn_tree, frontend, + test_flags, backend_fw, 
- fn_tree, - on_device, ): - input_dtype, x, axis = dtype_n_x_n_axis - _, prepend = dtype_prepend - _, append = dtype_append + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - test_flags=test_flags, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - n=n, - axis=axis, - prepend=prepend[0], - append=append[0], ) -# mm +# sinh @handle_frontend_test( - fn_tree="paddle.tensor.math.mm", - dtype_xy=_get_dtype_input_and_matrices(), + fn_tree="paddle.sinh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_paddle_mm( +def test_paddle_sinh( *, - dtype_xy, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, y = dtype_xy + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1950,42 +1973,27 @@ def test_paddle_mm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - mat2=y, + x=x[0], ) -# addmm +# sqrt @handle_frontend_test( - fn_tree="paddle.tensor.math.addmm", - dtype_input_xy=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + fn_tree="paddle.tensor.math.sqrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_addmm( +def test_paddle_sqrt( *, - dtype_input_xy, - beta, - alpha, - on_device, - fn_tree, + dtype_and_x, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, input, x, y = dtype_input_xy + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1993,31 +2001,24 @@ def test_paddle_addmm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], x=x[0], - y=y[0], - beta=beta, - alpha=alpha, ) -# kron +# square @handle_frontend_test( - fn_tree="paddle.tensor.math.kron", + fn_tree="paddle.tensor.math.square", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - allow_inf=False, - shared_dtype=True, ), ) -def test_paddle_kron( +def test_paddle_square( *, dtype_and_x, on_device, fn_tree, - backend_fw, frontend, + backend_fw, test_flags, ): input_dtype, x = dtype_and_x @@ -2025,183 +2026,189 @@ def test_paddle_kron( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -@st.composite -def _test_paddle_take_helper(draw): - mode = draw(st.sampled_from(["raise", "clip", "wrap"])) - - safe_bounds = mode == "raise" - - dtypes, xs, indices, _, _ = draw( - helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("float_and_integer"), - indices_dtypes=["int32", "int64"], - valid_bounds=safe_bounds, - ) - ) - - return dtypes, xs, indices, mode - - -# take +# stanh @handle_frontend_test( - fn_tree="paddle.take", dtype_and_values=_test_paddle_take_helper() + fn_tree="paddle.tensor.math.stanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), + scale_a=st.floats(1e-5, 1e5), + scale_b=st.floats(1e-5, 1e5), ) -def test_paddle_take( +def test_paddle_stanh( *, - 
dtype_and_values, + dtype_and_x, on_device, fn_tree, - backend_fw, frontend, test_flags, + scale_a, + scale_b, ): - dtypes, xs, indices, modes = dtype_and_values + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs, - index=indices, - mode=modes, + x=x[0], + scale_a=scale_a, + scale_b=scale_b, ) -# amax +# subtract @handle_frontend_test( - fn_tree="paddle.tensor.math.amax", + fn_tree="paddle.subtract", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, allow_inf=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_paddle_amax( +def test_paddle_subtract( *, dtype_and_x, on_device, fn_tree, - backend_fw, frontend, test_flags, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, fn_tree=fn_tree, test_flags=test_flags, on_device=on_device, x=x[0], + y=x[1], ) -# stanh +# take @handle_frontend_test( - fn_tree="paddle.tensor.math.stanh", + fn_tree="paddle.take", dtype_and_values=_test_paddle_take_helper() +) +def test_paddle_take( + *, + dtype_and_values, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + dtypes, xs, indices, modes = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=xs, + index=indices, + mode=modes, + ) + + +# tan +@handle_frontend_test( + fn_tree="paddle.tensor.math.tan", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - scale_a=st.floats(1e-5, 1e5), - scale_b=st.floats(1e-5, 1e5), ) -def test_paddle_stanh( +def test_paddle_tan( *, dtype_and_x, on_device, fn_tree, frontend, + backend_fw, test_flags, - scale_a, - scale_b, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-2, x=x[0], - scale_a=scale_a, - scale_b=scale_b, ) -# min +# tanh @handle_frontend_test( - fn_tree="paddle.tensor.math.min", - dtype_and_x=helpers.dtype_values_axis( + fn_tree="paddle.tensor.math.tanh", + aliases=["paddle.tanh", "paddle.nn.functional.tanh"], + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=False, ), ) -def test_paddle_min( +def test_paddle_tanh( *, dtype_and_x, on_device, fn_tree, frontend, - backend_fw, test_flags, + backend_fw, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-2, x=x[0], - axis=axis, - keepdim=False, ) -# amin +# trunc @handle_frontend_test( - fn_tree="paddle.tensor.math.amin", - dtype_and_x=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, + fn_tree="paddle.trunc", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", "int"), ), - keepdim=st.booleans(), ) -def 
test_paddle_amin( +def test_paddle_trunc( *, dtype_and_x, - keepdim, on_device, fn_tree, - backend_fw, frontend, + backend_fw, test_flags, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, - fn_tree=fn_tree, test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, x=x[0], - axis=axis, - keepdim=keepdim, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py index df64bc9163ce2..71c1f9f9fc227 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_random.py @@ -7,64 +7,53 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# randint @handle_frontend_test( - fn_tree="paddle.randint", - low=helpers.ints(min_value=0, max_value=10), - high=helpers.ints(min_value=11, max_value=20), - dtype=helpers.get_dtypes("integer"), - shape=helpers.get_shape( - allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 + fn_tree="paddle.poisson", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=1000, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=2, ), ) -def test_paddle_randint( - low, - high, - dtype, - backend_fw, - frontend, - test_flags, - shape, - fn_tree, -): +def test_paddle_poisson(dtype_and_x, backend_fw, frontend, test_flags, fn_tree): + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, - backend_to_test=backend_fw, frontend=frontend, - test_values=False, - fn_tree=fn_tree, + backend_to_test=backend_fw, test_flags=test_flags, - low=low, - high=high, - shape=shape, + fn_tree=fn_tree, + test_values=False, + x=x[0], ) @handle_frontend_test( - fn_tree="paddle.uniform", - input_dtypes=helpers.get_dtypes("float"), - shape=st.tuples( - st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + fn_tree="paddle.rand", + input_dtypes=st.sampled_from(["int32", "int64"]), + shape=helpers.get_shape( + allow_none=False, + min_num_dims=0, + min_dim_size=1, ), dtype=helpers.get_dtypes("valid", full=False), - min=st.floats(allow_nan=False, allow_infinity=False, width=32), - max=st.floats(allow_nan=False, allow_infinity=False, width=32), - seed=st.integers(min_value=2, max_value=5), ) -def test_paddle_uniform( +def test_paddle_rand( + *, input_dtypes, shape, dtype, - min, - max, - seed, frontend, backend_fw, test_flags, fn_tree, ): helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=[input_dtypes], frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, @@ -72,65 +61,77 @@ def test_paddle_uniform( test_values=False, shape=shape, dtype=dtype[0], - min=min, - max=max, - seed=seed, ) +# randint @handle_frontend_test( - fn_tree="paddle.poisson", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=1000, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - max_dim_size=2, + fn_tree="paddle.randint", + low=helpers.ints(min_value=0, max_value=10), + high=helpers.ints(min_value=11, max_value=20), + dtype=helpers.get_dtypes("integer"), + shape=helpers.get_shape( + allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 ), ) -def test_paddle_poisson(dtype_and_x, backend_fw, frontend, test_flags, fn_tree): - dtype, x = dtype_and_x +def 
test_paddle_randint( + low, + high, + dtype, + backend_fw, + frontend, + test_flags, + shape, + fn_tree, +): helpers.test_frontend_function( input_dtypes=dtype, - frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, + frontend=frontend, test_values=False, - x=x[0], + fn_tree=fn_tree, + test_flags=test_flags, + low=low, + high=high, + shape=shape, ) @handle_frontend_test( - fn_tree="paddle.rand", - input_dtypes=st.sampled_from(["int32", "int64"]), - shape=helpers.get_shape( - allow_none=False, - min_num_dims=0, - min_dim_size=1, + fn_tree="paddle.randint_like", + input_dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=helpers.get_shape( + allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 + ), ), - dtype=helpers.get_dtypes("valid", full=False), + low=st.integers(min_value=0, max_value=10), + high=st.integers(min_value=11, max_value=20), + dtype=helpers.get_dtypes("integer"), ) -def test_paddle_rand( - *, - input_dtypes, - shape, +def test_paddle_randint_like( + input_dtype_and_x, + low, + high, dtype, frontend, backend_fw, test_flags, fn_tree, + on_device, ): + input_dtype, x = input_dtype_and_x helpers.test_frontend_function( - input_dtypes=[input_dtypes], + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, test_values=False, - shape=shape, + x=x[0], + low=low, + high=high, dtype=dtype[0], ) @@ -166,58 +167,53 @@ def test_paddle_randn( @handle_frontend_test( - fn_tree="paddle.tensor.random.uniform_", - min=helpers.floats(min_value=-1, max_value=0), - max=helpers.floats(min_value=0.1, max_value=1), - seed=st.integers(min_value=2, max_value=5), - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=1000, + fn_tree="paddle.standard_normal", + input_dtypes=st.sampled_from([["int32"], ["int64"]]), + shape=helpers.get_shape( min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - max_dim_size=2, + min_dim_size=1, ), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_paddle_uniform_( - fn_tree, - min, - max, - seed, - dtype_and_x, +def test_paddle_standard_normal( + input_dtypes, + shape, + dtype, frontend, backend_fw, test_flags, + fn_tree, ): - dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, test_values=False, - x=x[0], - min=min, - max=max, - seed=seed, + shape=shape, + dtype=dtype[0], ) @handle_frontend_test( - fn_tree="paddle.standard_normal", - input_dtypes=st.sampled_from([["int32"], ["int64"]]), - shape=helpers.get_shape( - min_num_dims=1, - min_dim_size=1, + fn_tree="paddle.uniform", + input_dtypes=helpers.get_dtypes("float"), + shape=st.tuples( + st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) ), dtype=helpers.get_dtypes("valid", full=False), + min=st.floats(allow_nan=False, allow_infinity=False, width=32), + max=st.floats(allow_nan=False, allow_infinity=False, width=32), + seed=st.integers(min_value=2, max_value=5), ) -def test_paddle_standard_normal( +def test_paddle_uniform( input_dtypes, shape, dtype, + min, + max, + seed, frontend, backend_fw, test_flags, @@ -232,43 +228,47 @@ def test_paddle_standard_normal( test_values=False, shape=shape, dtype=dtype[0], + min=min, + max=max, + seed=seed, ) @handle_frontend_test( - fn_tree="paddle.randint_like", - 
input_dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=helpers.get_shape( - allow_none=False, min_num_dims=2, max_num_dims=7, min_dim_size=2 - ), + fn_tree="paddle.tensor.random.uniform_", + min=helpers.floats(min_value=-1, max_value=0), + max=helpers.floats(min_value=0.1, max_value=1), + seed=st.integers(min_value=2, max_value=5), + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=1000, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=2, ), - low=st.integers(min_value=0, max_value=10), - high=st.integers(min_value=11, max_value=20), - dtype=helpers.get_dtypes("integer"), ) -def test_paddle_randint_like( - input_dtype_and_x, - low, - high, - dtype, +def test_paddle_uniform_( + fn_tree, + min, + max, + seed, + dtype_and_x, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): - input_dtype, x = input_dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, test_values=False, x=x[0], - low=low, - high=high, - dtype=dtype[0], + min=min, + max=max, + seed=seed, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_search.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_search.py index eeaf1825c6982..abf9ba56fb1fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_search.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_search.py @@ -7,6 +7,30 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +# masked_select +@st.composite +def _dtypes_input_mask(draw): + _shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) + _mask = draw(helpers.array_values(dtype="bool", shape=_shape)) + _dtype, _x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + shape=_shape, + ) + ) + + return _dtype, _x, _mask + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="paddle.argmax", dtype_x_and_axis=helpers.dtype_values_axis( @@ -105,38 +129,34 @@ def test_paddle_argsort( ) -# sort @handle_frontend_test( - fn_tree="paddle.tensor.search.sort", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), - descending=st.booleans(), + fn_tree="paddle.masked_select", + dtype_input_mask=_dtypes_input_mask(), ) -def test_paddle_sort( +def test_paddle_masked_select( *, - dtype_input_axis, - descending, + dtype_input_mask, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_input_axis + ( + input_dtype, + x, + mask, + ) = dtype_input_mask + helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtype + ["bool"], + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - backend_to_test=backend_fw, x=x[0], - axis=axis, - descending=descending, + mask=mask, ) @@ -210,50 +230,38 @@ def test_paddle_searchsorted( ) -# masked_select -@st.composite -def _dtypes_input_mask(draw): - _shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) - _mask = draw(helpers.array_values(dtype="bool", shape=_shape)) - _dtype, _x = draw( - helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - shape=_shape, - ) - ) - - return _dtype, _x, _mask - - +# sort @handle_frontend_test( - fn_tree="paddle.masked_select", - dtype_input_mask=_dtypes_input_mask(), + fn_tree="paddle.tensor.search.sort", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), + descending=st.booleans(), ) -def test_paddle_masked_select( +def test_paddle_sort( *, - dtype_input_mask, + dtype_input_axis, + descending, on_device, fn_tree, frontend, test_flags, backend_fw, ): - ( - input_dtype, - x, - mask, - ) = dtype_input_mask - + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_function( - input_dtypes=input_dtype + ["bool"], - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + backend_to_test=backend_fw, x=x[0], - mask=mask, + axis=axis, + descending=descending, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py index 9a1e99cf6a9ec..a7b41e0baf820 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py @@ -41,37 +41,37 @@ def test_paddle_mean( ) -# numel +# median @handle_frontend_test( - fn_tree="paddle.numel", - dtype_and_x=helpers.dtype_and_values( + fn_tree="paddle.median", + dtype_x_and_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + min_value=-1e10, + max_value=1e10, + valid_axis=True, + force_int_axis=True, ), + keepdim=st.booleans(), ) -def test_paddle_numel( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, +def test_paddle_median( + dtype_x_and_axis, keepdim, backend_fw, frontend, test_flags, fn_tree ): - input_dtype, x = dtype_and_x + input_dtypes, x, axis = dtype_x_and_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x=x[0], + axis=axis, + keepdim=keepdim, ) -# median @handle_frontend_test( - fn_tree="paddle.median", + fn_tree="paddle.nanmedian", dtype_x_and_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, @@ -82,14 +82,17 @@ def test_paddle_numel( ), keepdim=st.booleans(), ) -def test_paddle_median( - dtype_x_and_axis, keepdim, backend_fw, frontend, test_flags, fn_tree +def test_paddle_nanmedian( + dtype_x_and_axis, + keepdim, + frontend, + test_flags, + fn_tree, ): input_dtypes, x, axis = dtype_x_and_axis helpers.test_frontend_function( input_dtypes=input_dtypes, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, x=x[0], @@ -98,76 +101,73 @@ def test_paddle_median( ) -# var +# numel @handle_frontend_test( - fn_tree="paddle.var", - dtype_and_x=_statistical_dtype_values(function="var"), - unbiased=st.booleans(), - keepdim=st.booleans(), + fn_tree="paddle.numel", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), ) -def test_paddle_var( +def test_paddle_numel( *, - unbiased, dtype_and_x, - keepdim, + on_device, fn_tree, frontend, backend_fw, test_flags, ): - input_dtype, x, axis, _ = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( 
input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, x=x[0], - axis=axis, - unbiased=unbiased, - keepdim=keepdim, ) +# std @handle_frontend_test( - fn_tree="paddle.nanmedian", - dtype_x_and_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - min_value=-1e10, - max_value=1e10, - valid_axis=True, - force_int_axis=True, - ), + fn_tree="paddle.std", + dtype_and_x=_statistical_dtype_values(function="std"), + unbiased=st.booleans(), keepdim=st.booleans(), ) -def test_paddle_nanmedian( - dtype_x_and_axis, +def test_paddle_std( + *, + unbiased, + dtype_and_x, keepdim, + fn_tree, frontend, + backend_fw, test_flags, - fn_tree, ): - input_dtypes, x, axis = dtype_x_and_axis + input_dtype, x, axis, _ = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, x=x[0], axis=axis, + unbiased=unbiased, keepdim=keepdim, ) -# std +# var @handle_frontend_test( - fn_tree="paddle.std", - dtype_and_x=_statistical_dtype_values(function="std"), + fn_tree="paddle.var", + dtype_and_x=_statistical_dtype_values(function="var"), unbiased=st.booleans(), keepdim=st.booleans(), ) -def test_paddle_std( +def test_paddle_var( *, unbiased, dtype_and_x, diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index be9b7139c9745..2e9a57c783373 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -19,32 +19,46 @@ CLASS_TREE = "ivy.functional.frontends.paddle.Tensor" -# Helpers # -# ------- # +# --- Helpers --- # +# --------------- # -@st.composite -def _reshape_helper(draw): - # generate a shape s.t len(shape) > 0 - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, +def _filter_query(query): + return ( + query.ndim > 1 + if isinstance(query, np.ndarray) + else ( + not any(isinstance(i, np.ndarray) and i.ndim <= 1 for i in query) + if isinstance(query, tuple) + else True ) ) - reshape_shape = draw(helpers.reshape_shapes(shape=shape)) - dtypes, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape, +# cond +@st.composite +def _get_dtype_and_matrix_non_singular(draw, dtypes): + while True: + matrix = draw( + helpers.dtype_and_values( + available_dtypes=dtypes, + min_value=-10, + max_value=10, + min_num_dims=2, + max_num_dims=2, + min_dim_size=1, + max_dim_size=5, + shape=st.tuples(st.integers(1, 5), st.integers(1, 5)).filter( + lambda x: x[0] == x[1] + ), + allow_inf=False, + allow_nan=False, + ) ) - ) - return dtypes, x, reshape_shape + if np.linalg.det(matrix[1][0]) != 0: + break + + return matrix[0], matrix[1] @st.composite @@ -59,76 +73,45 @@ def _get_dtype_and_square_matrix(draw): return dtype, mat -# Tests # -# ----- # - - -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_paddle_tensor_device( - dtype_x, -): - _, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal( - x.place, ivy.dev(ivy.array(data[0])), as_array=False +@st.composite 
+def _reshape_helper(draw): + # generate a shape s.t len(shape) > 0 + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ) ) + reshape_shape = draw(helpers.reshape_shapes(shape=shape)) -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_paddle_tensor_dtype( - dtype_x, -): - dtype, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal(x.dtype, dtype[0], as_array=False) - - -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_paddle_tensor_shape(dtype_x): - _, data, shape = dtype_x - x = Tensor(data[0]) - ivy.utils.assertions.check_equal( - x.ivy_array.shape, ivy.Shape(shape), as_array=False + dtypes, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=shape, + ) ) + return dtypes, x, reshape_shape -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_paddle_tensor_ndim( - dtype_x, -): - _, data = dtype_x - x = Tensor(data[0]) - ivy.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) +# --- Main --- # +# ------------ # -# reshape +# __setitem__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="reshape", - dtype_x_shape=_reshape_helper(), + method_name="__setitem__", + dtypes_x_index_val=helpers.dtype_array_query_val( + available_dtypes=helpers.get_dtypes("valid"), + ).filter(lambda x: x[0][0] == x[0][-1] and _filter_query(x[-2])), ) -def test_paddle__reshape( - dtype_x_shape, +def test_paddle___setitem__( + dtypes_x_index_val, frontend_method_data, init_flags, method_flags, @@ -136,19 +119,13 @@ def test_paddle__reshape( on_device, backend_fw, ): - input_dtype, x, shape = dtype_x_shape - assume(len(shape) != 0) - shape = { - "shape": shape, - } + input_dtype, x, index, val = dtypes_x_index_val helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np=shape, + init_all_as_kwargs_np={"data": x}, + method_input_dtypes=[*input_dtype[1:]], + method_all_as_kwargs_np={"item": index, "value": val}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -157,18 +134,6 @@ def test_paddle__reshape( ) -def _filter_query(query): - return ( - query.ndim > 1 - if isinstance(query, np.ndarray) - else ( - not any(isinstance(i, np.ndarray) and i.ndim <= 1 for i in query) - if isinstance(query, tuple) - else True - ) - ) - - # __getitem__ @handle_frontend_method( class_tree=CLASS_TREE, @@ -203,17 +168,15 @@ def test_paddle__getitem__( ) -# __setitem__ +# reshape @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="__setitem__", - dtypes_x_index_val=helpers.dtype_array_query_val( - available_dtypes=helpers.get_dtypes("valid"), - ).filter(lambda x: x[0][0] == x[0][-1] and _filter_query(x[-2])), + method_name="reshape", + dtype_x_shape=_reshape_helper(), ) -def test_paddle___setitem__( - dtypes_x_index_val, +def test_paddle__reshape( + dtype_x_shape, 
frontend_method_data, init_flags, method_flags, @@ -221,13 +184,19 @@ def test_paddle___setitem__( on_device, backend_fw, ): - input_dtype, x, index, val = dtypes_x_index_val + input_dtype, x, shape = dtype_x_shape + assume(len(shape) != 0) + shape = { + "shape": shape, + } helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x}, - method_input_dtypes=[*input_dtype[1:]], - method_all_as_kwargs_np={"item": index, "value": val}, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np=shape, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -236,16 +205,16 @@ def test_paddle___setitem__( ) -# dim +# atan @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="dim", + method_name="atan", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_dim( +def test_paddle_instance_atan( dtype_and_x, frontend_method_data, init_flags, @@ -257,11 +226,11 @@ def test_paddle_tensor_dim( input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], + backend_to_test=backend_fw, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -271,51 +240,53 @@ def test_paddle_tensor_dim( ) -# abs +# var @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="abs", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="var", + dtype_and_x=_statistical_dtype_values(function="var"), + keepdim=st.booleans(), ) -def test_paddle_tensor_abs( +def test_paddle_instance_var( dtype_and_x, + keepdim, + frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - frontend, on_device, - backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis, correction = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "axis": axis, + "unbiased": bool(correction), + "keepdim": keepdim, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, + backend_to_test=backend_fw, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# sin +# abs @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="sin", + method_name="abs", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_sin( +def test_paddle_tensor_abs( dtype_and_x, frontend_method_data, init_flags, @@ -341,16 +312,16 @@ def test_paddle_tensor_sin( ) -# sinh +# acosh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="sinh", + method_name="acosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_sinh( +def test_paddle_tensor_acosh( dtype_and_x, frontend_method_data, init_flags, @@ -376,16 +347,18 @@ def test_paddle_tensor_sinh( ) -# asin +# __(add_)__ + + 
@handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="asin", + method_name="add_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_asin( +def test_paddle_tensor_add_( dtype_and_x, frontend_method_data, init_flags, @@ -411,17 +384,23 @@ def test_paddle_tensor_asin( ) -# asinh +# all @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="asinh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="all", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("bool"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), + keep_dims=st.booleans(), ) -def test_paddle_tensor_asinh( - dtype_and_x, +def test_paddle_tensor_all( + dtype_x_axis, + keep_dims, frontend_method_data, init_flags, method_flags, @@ -429,34 +408,45 @@ def test_paddle_tensor_asinh( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "object": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis": axis, + "keepdim": keep_dims, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# cosh +# allclose @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="cosh", + method_name="allclose", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), + # rtol=1e-05, + # atol=1e-08, + # equal_nan=st.booleans(), ) -def test_paddle_tensor_cosh( +def test_paddle_tensor_allclose( dtype_and_x, + # rtol, + # atol, + # equal_nan, frontend_method_data, init_flags, method_flags, @@ -472,25 +462,29 @@ def test_paddle_tensor_cosh( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + # "rtol": rtol, + # "atol": atol, + # "equal_nan": equal_nan, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# log @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="log", + method_name="angle", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=["float64", "complex64", "complex128"], ), ) -def test_paddle_tensor_log( +def test_paddle_tensor_angle( dtype_and_x, frontend_method_data, init_flags, @@ -516,11 +510,11 @@ def test_paddle_tensor_log( ) -# argmax +# any @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="argmax", + method_name="any", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=st.one_of(helpers.get_dtypes("float")), min_axis=-1, @@ -530,7 +524,7 @@ def test_paddle_tensor_log( ), keep_dims=st.booleans(), ) -def test_paddle_tensor_argmax( +def test_paddle_tensor_any( dtype_x_axis, keep_dims, frontend_method_data, @@ -545,7 +539,7 @@ def test_paddle_tensor_argmax( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "data": x[0], }, 
method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ @@ -560,24 +554,23 @@ def test_paddle_tensor_argmax( ) -# unsqueeze +# argmax @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="unsqueeze", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + method_name="argmax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=st.one_of(helpers.get_dtypes("float")), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), + keep_dims=st.booleans(), ) -def test_paddle_tensor_unsqueeze( - dtype_value, - axis, +def test_paddle_tensor_argmax( + dtype_x_axis, + keep_dims, frontend_method_data, init_flags, method_flags, @@ -585,36 +578,43 @@ def test_paddle_tensor_unsqueeze( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "object": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=input_dtypes, method_all_as_kwargs_np={ "axis": axis, + "keepdim": keep_dims, }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# exp +# argsort @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="exp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="argsort", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=st.one_of(helpers.get_dtypes("float")), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), + descending=st.booleans(), ) -def test_paddle_tensor_exp( - dtype_and_x, +def test_paddle_tensor_argsort( + dtype_x_axis, + descending, frontend_method_data, init_flags, method_flags, @@ -622,33 +622,36 @@ def test_paddle_tensor_exp( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "object": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis": axis, + "descending": descending, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# exp_ +# asin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="exp_", + method_name="asin", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_exp_( +def test_paddle_tensor_asin( dtype_and_x, frontend_method_data, init_flags, @@ -674,16 +677,16 @@ def test_paddle_tensor_exp_( ) -# cos +# asinh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="cos", + method_name="asinh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_cos( +def test_paddle_tensor_asinh( dtype_and_x, frontend_method_data, init_flags, @@ -709,17 +712,19 @@ def 
test_paddle_tensor_cos( ) -# log10 +# astype @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="log10", + method_name="astype", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), + dtype=st.one_of(helpers.get_dtypes("valid")), ) -def test_paddle_tensor_log10( +def test_paddle_tensor_astype( dtype_and_x, + dtype, frontend_method_data, init_flags, method_flags, @@ -728,6 +733,8 @@ def test_paddle_tensor_log10( backend_fw, ): input_dtype, x = dtype_and_x + if dtype is None: + dtype = input_dtype helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -735,7 +742,9 @@ def test_paddle_tensor_log10( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "dtype": dtype, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -744,23 +753,17 @@ def test_paddle_tensor_log10( ) -# argsort +# bitwise_and @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="argsort", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=st.one_of(helpers.get_dtypes("float")), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="bitwise_and", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), - descending=st.booleans(), ) -def test_paddle_tensor_argsort( - dtype_x_axis, - descending, +def test_paddle_tensor_bitwise_and( + dtypes_and_x, frontend_method_data, init_flags, method_flags, @@ -768,36 +771,31 @@ def test_paddle_tensor_argsort( on_device, backend_fw, ): - input_dtypes, x, axis = dtype_x_axis + input_dtype, x = dtypes_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "object": x[0], - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "axis": axis, - "descending": descending, - }, - frontend=frontend, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# floor +# bitwise_not @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="floor", + method_name="bitwise_not", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_floor( +def test_paddle_tensor_bitwise_not( dtype_and_x, frontend_method_data, init_flags, @@ -823,17 +821,16 @@ def test_paddle_tensor_floor( ) -# sqrt @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="sqrt", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="bitwise_or", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_sqrt( - dtype_and_x, +def test_paddle_tensor_bitwise_or( + dtypes_and_x, frontend_method_data, init_flags, method_flags, @@ -841,15 +838,13 @@ def test_paddle_tensor_sqrt( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtypes_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - 
init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -858,31 +853,31 @@ def test_paddle_tensor_sqrt( ) -# sqrt_ +# bitwise_xor @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="sqrt_", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="bitwise_xor", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), ) -def test_torch_tensor_sqrt_( - dtype_x, - frontend, +def test_paddle_tensor_bitwise_xor( + dtypes_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtypes_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -891,16 +886,16 @@ def test_torch_tensor_sqrt_( ) -# atan +# ceil @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="atan", + method_name="ceil", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_instance_atan( +def test_paddle_tensor_ceil( dtype_and_x, frontend_method_data, init_flags, @@ -912,10 +907,10 @@ def test_paddle_instance_atan( input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - backend_to_test=backend_fw, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, @@ -926,17 +921,17 @@ def test_paddle_instance_atan( ) -# tanh +# cholesky @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="tanh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_paddle_tensor_tanh( + method_name="cholesky", + dtype_and_x=_get_dtype_and_square_matrix(), + upper=st.booleans(), +) +def test_paddle_tensor_cholesky( dtype_and_x, + upper, frontend_method_data, init_flags, method_flags, @@ -945,35 +940,34 @@ def test_paddle_tensor_tanh( backend_fw, ): input_dtype, x = dtype_and_x + x = np.matmul(x.T, x) + np.identity(x.shape[0]) + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"upper": upper}, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# __(add_)__ - - @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="add_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="cond", + dtype_and_x=_get_dtype_and_matrix_non_singular(dtypes=["float32", "float64"]), + p=st.sampled_from([None, "fro", "nuc", np.inf, -np.inf, 1, -1, 2, -2]), ) -def test_paddle_tensor_add_( +def test_paddle_tensor_cond( dtype_and_x, + p, 
frontend_method_data, init_flags, method_flags, @@ -989,7 +983,7 @@ def test_paddle_tensor_add_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"p": p}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -998,16 +992,16 @@ def test_paddle_tensor_add_( ) -# square +# conj @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="square", + method_name="conj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_paddle_tensor_square( +def test_paddle_tensor_conj( dtype_and_x, frontend_method_data, init_flags, @@ -1033,18 +1027,16 @@ def test_paddle_tensor_square( ) -# remainder_ +# cos @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="remainder_", + method_name="cos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, ), ) -def test_paddle_tensor_remainder_( +def test_paddle_tensor_cos( dtype_and_x, frontend_method_data, init_flags, @@ -1058,12 +1050,10 @@ def test_paddle_tensor_remainder_( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "value": x[0], + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "y": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1072,17 +1062,17 @@ def test_paddle_tensor_remainder_( ) -# cholesky +# cosh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="cholesky", - dtype_and_x=_get_dtype_and_square_matrix(), - upper=st.booleans(), + method_name="cosh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_paddle_tensor_cholesky( +def test_paddle_tensor_cosh( dtype_and_x, - upper, frontend_method_data, init_flags, method_flags, @@ -1091,20 +1081,18 @@ def test_paddle_tensor_cholesky( backend_fw, ): input_dtype, x = dtype_and_x - x = np.matmul(x.T, x) + np.identity(x.shape[0]) - helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"upper": upper}, - frontend=frontend, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -1112,15 +1100,18 @@ def test_paddle_tensor_cholesky( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="multiply", - dtype_and_x=helpers.dtype_and_values( + method_name="cumprod", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), ) -def test_paddle_tensor_multiply( - dtype_and_x, +def test_paddle_tensor_cumprod( + dtype_x_axis, frontend_method_data, init_flags, method_flags, @@ -1128,17 +1119,15 @@ def test_paddle_tensor_multiply( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "value": x[0], + "data": x[0], }, 
method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "y": x[1], - }, + method_all_as_kwargs_np={"dim": axis}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1147,23 +1136,21 @@ def test_paddle_tensor_multiply( ) -# all @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="all", + method_name="cumsum", dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("bool"), - min_axis=-1, - max_axis=0, - min_num_dims=1, + available_dtypes=helpers.get_dtypes("float"), + valid_axis=True, force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), - keep_dims=st.booleans(), ) -def test_paddle_tensor_all( +def test_paddle_tensor_cumsum( dtype_x_axis, - keep_dims, frontend_method_data, init_flags, method_flags, @@ -1171,45 +1158,34 @@ def test_paddle_tensor_all( on_device, backend_fw, ): - input_dtypes, x, axis = dtype_x_axis + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "axis": axis, - "keepdim": keep_dims, + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"axis": axis}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# allclose +# deg2rad @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="allclose", + method_name="deg2rad", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, ), - # rtol=1e-05, - # atol=1e-08, - # equal_nan=st.booleans(), ) -def test_paddle_tensor_allclose( +def test_paddle_tensor_deg2rad( dtype_and_x, - # rtol, - # atol, - # equal_nan, frontend_method_data, init_flags, method_flags, @@ -1225,37 +1201,46 @@ def test_paddle_tensor_allclose( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - # "rtol": rtol, - # "atol": atol, - # "equal_nan": equal_nan, - }, - frontend=frontend, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# sort +# Tests # +# ----- # + + +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_paddle_tensor_device( + dtype_x, +): + _, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal( + x.place, ivy.dev(ivy.array(data[0])), as_array=False + ) + + +# dim @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="sort", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=st.one_of(helpers.get_dtypes("float")), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="dim", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), - descending=st.booleans(), ) -def test_paddle_tensor_sort( - dtype_x_axis, - descending, +def test_paddle_tensor_dim( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1263,43 +1248,38 @@ def test_paddle_tensor_sort( on_device, backend_fw, ): - input_dtypes, x, axis = dtype_x_axis + input_dtype, x = 
dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "axis": axis, - "descending": descending, + "data": x[0], }, - frontend=frontend, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# any +# divide @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="any", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=st.one_of(helpers.get_dtypes("float")), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, + method_name="divide", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, ), - keep_dims=st.booleans(), ) -def test_paddle_tensor_any( - dtype_x_axis, - keep_dims, +def test_paddle_tensor_divide( + dtypes_and_x, frontend_method_data, init_flags, method_flags, @@ -1307,37 +1287,48 @@ def test_paddle_tensor_any( on_device, backend_fw, ): - input_dtypes, x, axis = dtype_x_axis + input_dtype, x = dtypes_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "axis": axis, - "keepdim": keep_dims, - }, - frontend=frontend, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# isinf +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_paddle_tensor_dtype( + dtype_x, +): + dtype, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal(x.dtype, dtype[0], as_array=False) + + +# equal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="isinf", - dtype_and_x=helpers.dtype_and_values( + method_name="equal", + dtypes_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, ), ) -def test_paddle_tensor_isinf( - dtype_and_x, +def test_paddle_tensor_equal( + dtypes_and_x, frontend_method_data, init_flags, method_flags, @@ -1345,15 +1336,13 @@ def test_paddle_tensor_isinf( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtypes_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1362,19 +1351,23 @@ def test_paddle_tensor_isinf( ) -# astype +# equal_all @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="astype", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="equal_all", + 
dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-np.inf, + max_value=np.inf, + shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, ), - dtype=st.one_of(helpers.get_dtypes("valid")), ) -def test_paddle_tensor_astype( - dtype_and_x, - dtype, +def test_paddle_tensor_equal_all( + dtypes_and_x, frontend_method_data, init_flags, method_flags, @@ -1382,19 +1375,13 @@ def test_paddle_tensor_astype( on_device, backend_fw, ): - input_dtype, x = dtype_and_x - if dtype is None: - dtype = input_dtype + input_dtype, x = dtypes_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dtype": dtype, - }, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1403,16 +1390,16 @@ def test_paddle_tensor_astype( ) -# isfinite +# erf @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="isfinite", + method_name="erf", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_isfinite( +def test_paddle_tensor_erf( dtype_and_x, frontend_method_data, init_flags, @@ -1438,16 +1425,16 @@ def test_paddle_tensor_isfinite( ) -# erf +# exp @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="erf", + method_name="exp", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_erf( +def test_paddle_tensor_exp( dtype_and_x, frontend_method_data, init_flags, @@ -1473,17 +1460,17 @@ def test_paddle_tensor_erf( ) -# subtract +# exp_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="subtract", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + method_name="exp_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_subtract( - dtypes_and_x, +def test_paddle_tensor_exp_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1491,13 +1478,15 @@ def test_paddle_tensor_subtract( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1506,17 +1495,17 @@ def test_paddle_tensor_subtract( ) -# subtract_ +# floor @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="subtract_", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + method_name="floor", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_subtract_( - dtypes_and_x, +def test_paddle_tensor_floor( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1524,13 +1513,15 @@ def test_paddle_tensor_subtract_( 
on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1539,17 +1530,17 @@ def test_paddle_tensor_subtract_( ) -# bitwise_xor +# floor_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="bitwise_xor", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + method_name="floor_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_bitwise_xor( - dtypes_and_x, +def test_paddle_tensor_floor_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1557,13 +1548,15 @@ def test_paddle_tensor_bitwise_xor( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1572,16 +1565,22 @@ def test_paddle_tensor_bitwise_xor( ) -# logical_xor +# floor_divide @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="logical_xor", + method_name="floor_divide", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=2, + shared_dtype=True, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="linear", ), ) -def test_paddle_tensor_logical_xor( +def test_paddle_tensor_floor_divide( dtypes_and_x, frontend_method_data, init_flags, @@ -1591,6 +1590,7 @@ def test_paddle_tensor_logical_xor( backend_fw, ): input_dtype, x = dtypes_and_x + # Absolute tolerance is 1, helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1602,19 +1602,19 @@ def test_paddle_tensor_logical_xor( method_flags=method_flags, frontend=frontend, on_device=on_device, + atol_=1, ) -# logical_or @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="logical_or", + method_name="fmax", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_logical_or( +def test_paddle_tensor_fmax( dtypes_and_x, frontend_method_data, init_flags, @@ -1638,17 +1638,16 @@ def test_paddle_tensor_logical_or( ) -# rsqrt @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="rsqrt", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="fmin", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_rsqrt( - dtype_and_x, +def test_paddle_tensor_fmin( + dtypes_and_x, 
frontend_method_data, init_flags, method_flags, @@ -1656,15 +1655,13 @@ def test_paddle_tensor_rsqrt( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtypes_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1673,15 +1670,20 @@ def test_paddle_tensor_rsqrt( ) +# greater_than @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="bitwise_or", + method_name="greater_than", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + safety_factor_scale="log", + small_abs_safety_factor=32, ), ) -def test_paddle_tensor_bitwise_or( +def test_paddle_tensor_greater_than( dtypes_and_x, frontend_method_data, init_flags, @@ -1705,16 +1707,16 @@ def test_paddle_tensor_bitwise_or( ) -# ceil +# imag @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="ceil", + method_name="imag", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_ceil( +def test_paddle_tensor_imag( dtype_and_x, frontend_method_data, init_flags, @@ -1740,17 +1742,18 @@ def test_paddle_tensor_ceil( ) -# bitwise_and +# is_tensor @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="bitwise_and", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + method_name="is_tensor", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, ), ) -def test_paddle_tensor_bitwise_and( - dtypes_and_x, +def test_paddle_tensor_is_tensor( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1758,13 +1761,15 @@ def test_paddle_tensor_bitwise_and( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1773,20 +1778,16 @@ def test_paddle_tensor_bitwise_and( ) -# greater_than +# isclose @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="greater_than", + method_name="isclose", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_greater_than( +def test_paddle_tensor_isclose( dtypes_and_x, frontend_method_data, init_flags, @@ -1810,16 +1811,16 @@ def test_paddle_tensor_greater_than( ) -# bitwise_not +# isfinite @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="bitwise_not", + method_name="isfinite", 
dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_bitwise_not( +def test_paddle_tensor_isfinite( dtype_and_x, frontend_method_data, init_flags, @@ -1845,16 +1846,16 @@ def test_paddle_tensor_bitwise_not( ) -# reciprocal +# isinf @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="reciprocal", + method_name="isinf", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_reciprocal( +def test_paddle_tensor_isinf( dtype_and_x, frontend_method_data, init_flags, @@ -1880,17 +1881,17 @@ def test_paddle_tensor_reciprocal( ) -# logical_and +# isnan @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="logical_and", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + method_name="isnan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_logical_and( - dtypes_and_x, +def test_paddle_tensor_isnan( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1898,13 +1899,15 @@ def test_paddle_tensor_logical_and( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"self": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1913,20 +1916,16 @@ def test_paddle_tensor_logical_and( ) -# divide +# less_than @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="divide", + method_name="less_than", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_divide( +def test_paddle_tensor_less_than( dtypes_and_x, frontend_method_data, init_flags, @@ -1950,60 +1949,17 @@ def test_paddle_tensor_divide( ) +# log @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="cumprod", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, - ), -) -def test_paddle_tensor_cumprod( - dtype_x_axis, - frontend_method_data, - init_flags, - method_flags, - frontend, - on_device, - backend_fw, -): - input_dtype, x, axis = dtype_x_axis - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"dim": axis}, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, - ) - - -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="paddle.to_tensor", - method_name="cumsum", - dtype_x_axis=helpers.dtype_values_axis( + method_name="log", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - 
min_value=-5, - max_value=5, ), ) -def test_paddle_tensor_cumsum( - dtype_x_axis, +def test_paddle_tensor_log( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -2011,7 +1967,7 @@ def test_paddle_tensor_cumsum( on_device, backend_fw, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2019,7 +1975,7 @@ def test_paddle_tensor_cumsum( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"axis": axis}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2028,15 +1984,16 @@ def test_paddle_tensor_cumsum( ) +# log10 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="angle", + method_name="log10", dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float64", "complex64", "complex128"], + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_angle( +def test_paddle_tensor_log10( dtype_and_x, frontend_method_data, init_flags, @@ -2062,18 +2019,16 @@ def test_paddle_tensor_angle( ) -# equal +# logical_and @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="equal", + method_name="logical_and", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_equal( +def test_paddle_tensor_logical_and( dtypes_and_x, frontend_method_data, init_flags, @@ -2086,7 +2041,7 @@ def test_paddle_tensor_equal( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={"self": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, @@ -2097,16 +2052,16 @@ def test_paddle_tensor_equal( ) -# rad2deg +# logical_not @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="rad2deg", + method_name="logical_not", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_rad2deg( +def test_paddle_tensor_logical_not( dtype_and_x, frontend_method_data, init_flags, @@ -2132,15 +2087,16 @@ def test_paddle_tensor_rad2deg( ) +# logical_or @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="fmax", + method_name="logical_or", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_fmax( +def test_paddle_tensor_logical_or( dtypes_and_x, frontend_method_data, init_flags, @@ -2164,15 +2120,16 @@ def test_paddle_tensor_fmax( ) +# logical_xor @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="fmin", + method_name="logical_xor", dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_fmin( +def test_paddle_tensor_logical_xor( dtypes_and_x, frontend_method_data, init_flags, @@ -2196,18 +2153,23 @@ def 
test_paddle_tensor_fmin( ) +# max @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="minimum", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + method_name="max", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=st.one_of(helpers.get_dtypes("valid")), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=False, ), + keep_dims=st.booleans(), ) -def test_paddle_tensor_minimum( - dtype_and_x, +def test_paddle_tensor_max( + dtype_x_axis, + keep_dims, frontend_method_data, init_flags, method_flags, @@ -2215,32 +2177,38 @@ def test_paddle_tensor_minimum( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis": axis, + "keepdim": keep_dims, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# less_than @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="less_than", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + method_name="minimum", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), ) -def test_paddle_tensor_less_than( - dtypes_and_x, +def test_paddle_tensor_minimum( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -2248,7 +2216,7 @@ def test_paddle_tensor_less_than( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2263,23 +2231,18 @@ def test_paddle_tensor_less_than( ) -# max @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="max", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=st.one_of(helpers.get_dtypes("valid")), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=False, + method_name="multiply", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), - keep_dims=st.booleans(), ) -def test_paddle_tensor_max( - dtype_x_axis, - keep_dims, +def test_paddle_tensor_multiply( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -2287,41 +2250,56 @@ def test_paddle_tensor_max( on_device, backend_fw, ): - input_dtypes, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "object": x[0], + "value": x[0], }, - method_input_dtypes=input_dtypes, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "axis": axis, - "keepdim": keep_dims, + "y": x[1], }, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# deg2rad +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", 
prune_function=False), + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_paddle_tensor_ndim( + dtype_x, +): + _, data = dtype_x + x = Tensor(data[0]) + ivy.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) + + +# neg @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="deg2rad", + method_name="neg", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_paddle_tensor_deg2rad( +def test_paddle_tensor_neg( dtype_and_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): @@ -2342,41 +2320,33 @@ def test_paddle_tensor_deg2rad( ) -# rot90 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="rot90", - dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( + method_name="numel", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=3, - max_num_dims=6, - min_dim_size=1, - max_dim_size=10, + min_num_dims=1, ), ) -def test_paddle_tensor_rot90( - dtype_m_k_axes, +def test_paddle_tensor_numel( + dtype_and_x, frontend_method_data, init_flags, method_flags, frontend, - on_device, backend_fw, + on_device, ): - input_dtype, values, k, axes = dtype_m_k_axes - + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": values, + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "k": k, - "axes": axes, - }, + backend_to_test=backend_fw, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2385,16 +2355,16 @@ def test_paddle_tensor_rot90( ) -# imag +# rad2deg @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="imag", + method_name="rad2deg", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_imag( +def test_paddle_tensor_rad2deg( dtype_and_x, frontend_method_data, init_flags, @@ -2420,23 +2390,17 @@ def test_paddle_tensor_imag( ) -# floor_divide +# reciprocal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="floor_divide", - dtypes_and_x=helpers.dtype_and_values( + method_name="reciprocal", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=2, - shared_dtype=True, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="linear", ), ) -def test_paddle_tensor_floor_divide( - dtypes_and_x, +def test_paddle_tensor_reciprocal( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -2444,34 +2408,35 @@ def test_paddle_tensor_floor_divide( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x - # Absolute tolerance is 1, + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - atol_=1, ) -# is_tensor +# remainder_ @handle_frontend_method( 
class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="is_tensor", + method_name="remainder_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, + num_arrays=2, + shared_dtype=True, ), ) -def test_paddle_tensor_is_tensor( +def test_paddle_tensor_remainder_( dtype_and_x, frontend_method_data, init_flags, @@ -2485,10 +2450,12 @@ def test_paddle_tensor_is_tensor( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "y": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2497,17 +2464,21 @@ def test_paddle_tensor_is_tensor( ) -# isclose +# rot90 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="isclose", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + method_name="rot90", + dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=3, + max_num_dims=6, + min_dim_size=1, + max_dim_size=10, ), ) -def test_paddle_tensor_isclose( - dtypes_and_x, +def test_paddle_tensor_rot90( + dtype_m_k_axes, frontend_method_data, init_flags, method_flags, @@ -2515,13 +2486,19 @@ def test_paddle_tensor_isclose( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, values, k, axes = dtype_m_k_axes + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": values, + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={ + "k": k, + "axes": axes, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2530,23 +2507,17 @@ def test_paddle_tensor_isclose( ) -# equal_all +# rsqrt @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="equal_all", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-np.inf, - max_value=np.inf, - shared_dtype=True, - safety_factor_scale="log", - small_abs_safety_factor=32, + method_name="rsqrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_equal_all( - dtypes_and_x, +def test_paddle_tensor_rsqrt( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -2554,13 +2525,15 @@ def test_paddle_tensor_equal_all( on_device, backend_fw, ): - input_dtype, x = dtypes_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"y": x[1]}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2569,16 +2542,29 @@ def test_paddle_tensor_equal_all( ) -# conj +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_paddle_tensor_shape(dtype_x): + _, data, shape = dtype_x + x = Tensor(data[0]) + ivy.utils.assertions.check_equal( + x.ivy_array.shape, 
ivy.Shape(shape), as_array=False + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="conj", + method_name="sign", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_conj( +def test_paddle_tensor_sign( dtype_and_x, frontend_method_data, init_flags, @@ -2604,16 +2590,16 @@ def test_paddle_tensor_conj( ) -# floor_ +# sin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="floor_", + method_name="sin", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_floor_( +def test_paddle_tensor_sin( dtype_and_x, frontend_method_data, init_flags, @@ -2639,24 +2625,21 @@ def test_paddle_tensor_floor_( ) -# neg +# sinh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="neg", + method_name="sinh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-1e04, - max_value=1e04, - allow_inf=False, ), ) -def test_paddle_tensor_neg( +def test_paddle_tensor_sinh( dtype_and_x, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): @@ -2677,17 +2660,23 @@ def test_paddle_tensor_neg( ) -# isnan +# sort @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="isnan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="sort", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=st.one_of(helpers.get_dtypes("float")), + min_axis=-1, + max_axis=0, + min_num_dims=1, + force_int_axis=True, ), + descending=st.booleans(), ) -def test_paddle_tensor_isnan( - dtype_and_x, +def test_paddle_tensor_sort( + dtype_x_axis, + descending, frontend_method_data, init_flags, method_flags, @@ -2695,33 +2684,36 @@ def test_paddle_tensor_isnan( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, x, axis = dtype_x_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "object": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis": axis, + "descending": descending, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# logical_not +# sqrt @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="logical_not", + method_name="sqrt", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_paddle_tensor_logical_not( +def test_paddle_tensor_sqrt( dtype_and_x, frontend_method_data, init_flags, @@ -2747,15 +2739,16 @@ def test_paddle_tensor_logical_not( ) +# square @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="sign", + method_name="square", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_paddle_tensor_sign( +def test_paddle_tensor_square( dtype_and_x, frontend_method_data, init_flags, @@ -2781,17 +2774,17 @@ def test_paddle_tensor_sign( ) -# acosh +# subtract @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="acosh", - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="subtract", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), ) -def test_paddle_tensor_acosh( - dtype_and_x, +def test_paddle_tensor_subtract( + dtypes_and_x, frontend_method_data, init_flags, method_flags, @@ -2799,15 +2792,13 @@ def test_paddle_tensor_acosh( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtypes_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2816,42 +2807,17 @@ def test_paddle_tensor_acosh( ) -# cond -@st.composite -def _get_dtype_and_matrix_non_singular(draw, dtypes): - while True: - matrix = draw( - helpers.dtype_and_values( - available_dtypes=dtypes, - min_value=-10, - max_value=10, - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, - max_dim_size=5, - shape=st.tuples(st.integers(1, 5), st.integers(1, 5)).filter( - lambda x: x[0] == x[1] - ), - allow_inf=False, - allow_nan=False, - ) - ) - if np.linalg.det(matrix[1][0]) != 0: - break - - return matrix[0], matrix[1] - - +# subtract_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="cond", - dtype_and_x=_get_dtype_and_matrix_non_singular(dtypes=["float32", "float64"]), - p=st.sampled_from([None, "fro", "nuc", np.inf, -np.inf, 1, -1, 2, -2]), + method_name="subtract_", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + ), ) -def test_paddle_tensor_cond( - dtype_and_x, - p, +def test_paddle_tensor_subtract_( + dtypes_and_x, frontend_method_data, init_flags, method_flags, @@ -2859,15 +2825,13 @@ def test_paddle_tensor_cond( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtypes_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"p": p}, + method_all_as_kwargs_np={"y": x[1]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2876,39 +2840,37 @@ def test_paddle_tensor_cond( ) -# var +# tanh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="var", - dtype_and_x=_statistical_dtype_values(function="var"), - keepdim=st.booleans(), + method_name="tanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), ) -def test_paddle_instance_var( +def test_paddle_tensor_tanh( dtype_and_x, - keepdim, - frontend, - backend_fw, frontend_method_data, init_flags, method_flags, + frontend, on_device, + backend_fw, ): - input_dtype, x, axis, correction = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "axis": axis, - "unbiased": bool(correction), - "keepdim": keepdim, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + 
method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, - backend_to_test=backend_fw, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -2916,13 +2878,12 @@ def test_paddle_instance_var( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="numel", + method_name="trunc", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, ), ) -def test_paddle_tensor_numel( +def test_paddle_tensor_trunc( dtype_and_x, frontend_method_data, init_flags, @@ -2948,31 +2909,74 @@ def test_paddle_tensor_numel( ) +# unsqueeze @handle_frontend_method( class_tree=CLASS_TREE, init_tree="paddle.to_tensor", - method_name="trunc", - dtype_and_x=helpers.dtype_and_values( + method_name="unsqueeze", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="shape"), + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, ), ) -def test_paddle_tensor_trunc( - dtype_and_x, +def test_paddle_tensor_unsqueeze( + dtype_value, + axis, frontend_method_data, init_flags, method_flags, frontend, - backend_fw, on_device, + backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + +# sqrt_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="sqrt_", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_torch_tensor_sqrt_( + dtype_x, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x = dtype_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py index f869af563290a..998fba32ad3fb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py @@ -6,21 +6,44 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# to_tensor +# --- Helpers --- # +# --------------- # + + +@st.composite +def _chw_image_shape_helper(draw): + c = draw(st.sampled_from([1, 3]), label="channel") + h = draw(helpers.ints(min_value=1, max_value=100), label="height") + w = draw(helpers.ints(min_value=1, max_value=100), label="width") + + shape = (c, h, w) + return shape + + +# --- Main --- # +# ------------ # + + +# adjust_brightness @handle_frontend_test( - fn_tree="paddle.to_tensor", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + fn_tree="paddle.vision.transforms.adjust_brightness", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=_chw_image_shape_helper(), + ), + 
brightness_factor=helpers.floats(min_value=0), ) -def test_paddle_to_tensor( +def test_paddle_adjust_brightness( *, dtype_and_x, + brightness_factor, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -28,20 +51,11 @@ def test_paddle_to_tensor( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - pic=input[0], + img=x[0], + brightness_factor=brightness_factor, ) -@st.composite -def _chw_image_shape_helper(draw): - c = draw(st.sampled_from([1, 3]), label="channel") - h = draw(helpers.ints(min_value=1, max_value=100), label="height") - w = draw(helpers.ints(min_value=1, max_value=100), label="width") - - shape = (c, h, w) - return shape - - # adjust_hue @handle_frontend_test( fn_tree="paddle.vision.transforms.adjust_hue", @@ -81,19 +95,21 @@ def test_paddle_adjust_hue( ) -# adjust_brightness +# hflip @handle_frontend_test( - fn_tree="paddle.vision.transforms.adjust_brightness", + fn_tree="paddle.vision.transforms.hflip", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=_chw_image_shape_helper(), + available_dtypes=helpers.get_dtypes("numeric"), + min_value=0, + min_num_dims=3, + max_num_dims=3, + min_dim_size=3, + max_dim_size=3, ), - brightness_factor=helpers.floats(min_value=0), ) -def test_paddle_adjust_brightness( +def test_paddle_hflip( *, dtype_and_x, - brightness_factor, on_device, fn_tree, frontend, @@ -103,25 +119,21 @@ def test_paddle_adjust_brightness( input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + backend_to_test=backend_fw, img=x[0], - brightness_factor=brightness_factor, ) +# to_tensor @handle_frontend_test( - fn_tree="paddle.vision.transforms.vflip", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=3, - max_num_dims=4, - ), + fn_tree="paddle.to_tensor", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), ) -def test_paddle_vflip( +def test_paddle_to_tensor( *, dtype_and_x, on_device, @@ -130,31 +142,27 @@ def test_paddle_vflip( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - img=x[0], - backend_to_test=backend_fw, + pic=input[0], ) -# hflip @handle_frontend_test( - fn_tree="paddle.vision.transforms.hflip", + fn_tree="paddle.vision.transforms.vflip", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=0, + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=3, - max_num_dims=3, - min_dim_size=3, - max_dim_size=3, + max_num_dims=4, ), ) -def test_paddle_hflip( +def test_paddle_vflip( *, dtype_and_x, on_device, @@ -170,6 +178,6 @@ def test_paddle_hflip( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - backend_to_test=backend_fw, img=x[0], + backend_to_test=backend_fw, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_pandas/test_dataframe.py b/ivy_tests/test_ivy/test_frontends/test_pandas/test_dataframe.py index fdd19477ea001..074137130b30e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_pandas/test_dataframe.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_pandas/test_dataframe.py @@ -43,24 +43,25 @@ def test_pandas_series_abs( ) -@pytest.mark.skip("Testing pipeline not yet implemented") +@pytest.mark.xfail(reason="testing pipeline fixes") @handle_frontend_method( class_tree=CLASS_TREE, init_tree="pandas.DataFrame", - method_name="to_numpy", + method_name="mean", dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - na_values=st.sampled_from([None, np.nan, np.inf, -np.inf]), - copy=st.booleans(), + skipna=st.booleans(), + axis=st.sampled_from([None, 0, 1, "index", "columns"]), ) -def test_pandas_series_to_numpy( +def test_pandas_series_mean( dtype_x, frontend, - na_values, - copy, frontend_method_data, init_flags, method_flags, on_device, + backend_fw, + skipna, + axis, ): input_dtype, x = dtype_x helpers.test_frontend_method( @@ -69,15 +70,13 @@ def test_pandas_series_to_numpy( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "na_values": na_values, - "copy": copy, - }, + method_all_as_kwargs_np={"skipna": skipna, "axis": axis}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, + backend_to_test=backend_fw, ) @@ -124,25 +123,24 @@ def test_pandas_series_sum( ) -@pytest.mark.xfail(reason="testing pipeline fixes") +@pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_method( class_tree=CLASS_TREE, init_tree="pandas.DataFrame", - method_name="mean", + method_name="to_numpy", dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - skipna=st.booleans(), - axis=st.sampled_from([None, 0, 1, "index", "columns"]), + na_values=st.sampled_from([None, np.nan, np.inf, -np.inf]), + copy=st.booleans(), ) -def test_pandas_series_mean( +def test_pandas_series_to_numpy( dtype_x, frontend, + na_values, + copy, frontend_method_data, init_flags, method_flags, on_device, - backend_fw, - skipna, - axis, ): input_dtype, x = dtype_x helpers.test_frontend_method( @@ -151,11 +149,13 @@ def test_pandas_series_mean( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"skipna": skipna, "axis": axis}, + method_all_as_kwargs_np={ + "na_values": na_values, + "copy": copy, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - backend_to_test=backend_fw, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_pandas/test_series.py b/ivy_tests/test_ivy/test_frontends/test_pandas/test_series.py index 1820de4f12ec7..796045cb93a76 100644 --- a/ivy_tests/test_ivy/test_frontends/test_pandas/test_series.py +++ b/ivy_tests/test_ivy/test_frontends/test_pandas/test_series.py @@ -46,23 +46,22 @@ def test_pandas_series_abs( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="pandas.Series", - method_name="to_numpy", - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - na_value=st.sampled_from([None, np.nan, np.inf, -np.inf]), - copy=st.booleans(), + method_name="add", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True + ), + axis=st.sampled_from(["index", 0]), ) -def test_pandas_series_to_numpy( +def test_pandas_series_add( dtype_x, - na_value, - copy, + frontend, frontend_method_data, init_flags, method_flags, on_device, - frontend, backend_fw, + axis, ): - # todo add castable dtypes for output input_dtype, x = dtype_x helpers.test_frontend_method( 
init_input_dtypes=input_dtype, @@ -71,14 +70,16 @@ def test_pandas_series_to_numpy( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "na_value": na_value, - "copy": copy, + "other": x[1], + "level": None, + "axis": axis, + "fill_value": None, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - on_device=on_device, frontend=frontend, + on_device=on_device, backend_to_test=backend_fw, ) @@ -87,13 +88,12 @@ def test_pandas_series_to_numpy( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="pandas.Series", - method_name="sum", + method_name="mean", dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), skipna=st.booleans(), axis=st.sampled_from([None, 0]), - min_count=st.integers(min_value=0, max_value=5), ) -def test_pandas_series_sum( +def test_pandas_series_mean( dtype_x, frontend, frontend_method_data, @@ -103,7 +103,6 @@ def test_pandas_series_sum( backend_fw, skipna, axis, - min_count, ): input_dtype, x = dtype_x helpers.test_frontend_method( @@ -112,11 +111,7 @@ def test_pandas_series_sum( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "skipna": skipna, - "axis": axis, - "min_count": min_count, - }, + method_all_as_kwargs_np={"skipna": skipna, "axis": axis}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -130,12 +125,13 @@ def test_pandas_series_sum( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="pandas.Series", - method_name="mean", + method_name="sum", dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), skipna=st.booleans(), axis=st.sampled_from([None, 0]), + min_count=st.integers(min_value=0, max_value=5), ) -def test_pandas_series_mean( +def test_pandas_series_sum( dtype_x, frontend, frontend_method_data, @@ -145,6 +141,7 @@ def test_pandas_series_mean( backend_fw, skipna, axis, + min_count, ): input_dtype, x = dtype_x helpers.test_frontend_method( @@ -153,7 +150,11 @@ def test_pandas_series_mean( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"skipna": skipna, "axis": axis}, + method_all_as_kwargs_np={ + "skipna": skipna, + "axis": axis, + "min_count": min_count, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -167,22 +168,23 @@ def test_pandas_series_mean( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="pandas.Series", - method_name="add", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, shared_dtype=True - ), - axis=st.sampled_from(["index", 0]), + method_name="to_numpy", + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + na_value=st.sampled_from([None, np.nan, np.inf, -np.inf]), + copy=st.booleans(), ) -def test_pandas_series_add( +def test_pandas_series_to_numpy( dtype_x, - frontend, + na_value, + copy, frontend_method_data, init_flags, method_flags, on_device, + frontend, backend_fw, - axis, ): + # todo add castable dtypes for output input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, @@ -191,15 +193,13 @@ def test_pandas_series_add( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], - "level": None, - "axis": axis, - "fill_value": None, + "na_value": na_value, + "copy": copy, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, + 
frontend=frontend, backend_to_test=backend_fw, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py index c13d8e761a0fc..6e37c555f41df 100644 --- a/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_scipy/test_fft/test_fft.py @@ -6,53 +6,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# Helpers - - -@st.composite -def _x_and_fft(draw, dtypes): - min_fft_points = 2 - dtype = draw(dtypes) - x_dim = draw( - helpers.get_shape( - min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 - ) - ) - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=tuple(x_dim), - ) - ) - dim = draw( - helpers.get_axis(shape=x_dim, allow_neg=True, allow_none=False, max_size=1) - ) - norm = draw(st.sampled_from(["backward", "forward", "ortho"])) - n = draw(st.integers(min_fft_points, 256)) - return dtype, x, dim, norm, n - - -@st.composite -def _x_and_ifft(draw): - min_fft_points = 2 - dtype = draw(helpers.get_dtypes("complex")) - x_dim = draw( - helpers.get_shape( - min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 - ) - ) - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=tuple(x_dim), - min_value=-1e-10, - max_value=1e10, - ) - ) - dim = draw(st.integers(1 - len(list(x_dim)), len(list(x_dim)) - 1)) - norm = draw(st.sampled_from(["backward", "forward", "ortho"])) - n = draw(st.integers(min_fft_points, 256)) - return dtype, x, dim, norm, n +# --- Helpers --- # +# --------------- # @st.composite @@ -106,6 +61,32 @@ def _valid_idct(draw): return dtype, x, type, n, axis, norm +# Helpers + + +@st.composite +def _x_and_fft(draw, dtypes): + min_fft_points = 2 + dtype = draw(dtypes) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + ) + ) + dim = draw( + helpers.get_axis(shape=x_dim, allow_neg=True, allow_none=False, max_size=1) + ) + norm = draw(st.sampled_from(["backward", "forward", "ortho"])) + n = draw(st.integers(min_fft_points, 256)) + return dtype, x, dim, norm, n + + @st.composite def _x_and_fft2(draw): min_fft2_points = 2 @@ -127,6 +108,29 @@ def _x_and_fft2(draw): return dtype, x, s, dim, norm +@st.composite +def _x_and_ifft(draw): + min_fft_points = 2 + dtype = draw(helpers.get_dtypes("complex")) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e-10, + max_value=1e10, + ) + ) + dim = draw(st.integers(1 - len(list(x_dim)), len(list(x_dim)) - 1)) + norm = draw(st.sampled_from(["backward", "forward", "ortho"])) + n = draw(st.integers(min_fft_points, 256)) + return dtype, x, dim, norm, n + + @st.composite def _x_and_ifftn(draw): _x_and_ifftn = draw(_x_and_fft2()) @@ -134,43 +138,50 @@ def _x_and_ifftn(draw): return _x_and_ifftn + (workers,) -# Tests +# --- Main --- # +# ------------ # -# fft +# dct @handle_frontend_test( - fn_tree="scipy.fft.fft", - d_x_d_n_n=_x_and_fft(helpers.get_dtypes("complex")), + fn_tree="scipy.fft.dct", + dtype_x_and_args=_valid_dct(), test_with_out=st.just(False), ) -def test_scipy_fft( - d_x_d_n_n, +def test_scipy_dct( + dtype_x_and_args, frontend, test_flags, fn_tree, on_device, ): - dtype, x, dim, norm, n = d_x_d_n_n + input_dtype, x, _type, n, axis, norm = dtype_x_and_args 
helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x, - dim=dim, - norm=norm, + x=x[0], + type=_type, n=n, + axis=axis, + norm=norm, + rtol_=1e-3, + atol_=1e-1, ) -# ifft +# Tests + + +# fft @handle_frontend_test( - fn_tree="scipy.fft.ifft", - d_x_d_n_n=_x_and_ifft(), + fn_tree="scipy.fft.fft", + d_x_d_n_n=_x_and_fft(helpers.get_dtypes("complex")), test_with_out=st.just(False), ) -def test_scipy_ifft( +def test_scipy_fft( d_x_d_n_n, frontend, test_flags, @@ -191,33 +202,30 @@ def test_scipy_ifft( ) -# dct +# fft2 @handle_frontend_test( - fn_tree="scipy.fft.dct", - dtype_x_and_args=_valid_dct(), + fn_tree="scipy.fft.fft2", + d_x_d_s_n=_x_and_fft2(), test_with_out=st.just(False), ) -def test_scipy_dct( - dtype_x_and_args, +def test_scipy_fft2( + d_x_d_s_n, frontend, test_flags, fn_tree, on_device, ): - input_dtype, x, _type, n, axis, norm = dtype_x_and_args + dtype, x, s, ax, norm = d_x_d_s_n helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - type=_type, - n=n, - axis=axis, + s=s, + axes=ax, norm=norm, - rtol_=1e-3, - atol_=1e-1, ) @@ -251,30 +259,30 @@ def test_scipy_idct( ) -# fft2 +# ifft @handle_frontend_test( - fn_tree="scipy.fft.fft2", - d_x_d_s_n=_x_and_fft2(), + fn_tree="scipy.fft.ifft", + d_x_d_n_n=_x_and_ifft(), test_with_out=st.just(False), ) -def test_scipy_fft2( - d_x_d_s_n, +def test_scipy_ifft( + d_x_d_n_n, frontend, test_flags, fn_tree, on_device, ): - dtype, x, s, ax, norm = d_x_d_s_n + dtype, x, dim, norm, n = d_x_d_n_n helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - s=s, - axes=ax, + x=x, + dim=dim, norm=norm, + n=n, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py index 5444d9b047f9b..3f7d69283fa4f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_scipy/test_linalg/test_linalg.py @@ -10,8 +10,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# Helpers # -# ------- # +# --- Helpers --- # +# --------------- # @st.composite @@ -113,70 +113,25 @@ def _vector_norm_example(): return _matrix_norm_example() -# Tests # -# ----- # +# --- Main --- # +# ------------ # -# tril +# eigh_tridiagonal @handle_frontend_test( - fn_tree="scipy.linalg.tril", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - k=helpers.ints(min_value=-10, max_value=10), + fn_tree="scipy.linalg.eigh_tridiagonal", + all_args=_generate_eigh_tridiagonal_args(), test_with_out=st.just(False), ) -def test_scipy_tril( - dtype_and_x, - k, +def test_scipy_eigh_tridiagonal( + all_args, frontend, test_flags, fn_tree, on_device, backend_fw, ): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - m=x[0], - k=k, - ) - - -# triu -@handle_frontend_test( - fn_tree="scipy.linalg.triu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - min_num_dims=2, - 
max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ), - k=helpers.ints(min_value=-10, max_value=10), - test_with_out=st.just(False), -) -def test_scipy_triu( - dtype_and_x, - k, - test_flags, - frontend, - fn_tree, - on_device, - backend_fw, -): - dtype, x = dtype_and_x + dtype, alpha, beta, eigvals_only, select, select_range, tol = all_args helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -184,8 +139,12 @@ def test_scipy_triu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - m=x[0], - k=k, + alpha=alpha[0], + beta=beta[0], + eigvals_only=eigvals_only, + select=select, + select_range=select_range, + tol=tol, ) @@ -220,36 +179,6 @@ def test_scipy_inv( ) -# pinv -@handle_frontend_test( - fn_tree="scipy.linalg.pinv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=2, - ), - test_with_out=st.just(False), -) -def test_scipy_pinv( - dtype_and_x, - frontend, - test_flags, - fn_tree, - on_device, - backend_fw, -): - dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - ) - - # kron @handle_frontend_test( fn_tree="scipy.linalg.kron", @@ -278,34 +207,41 @@ def test_scipy_kron(dtype_and_x, frontend, test_flags, fn_tree, on_device, backe ) -# eigh_tridiagonal +# lu_factor @handle_frontend_test( - fn_tree="scipy.linalg.eigh_tridiagonal", - all_args=_generate_eigh_tridiagonal_args(), + fn_tree="scipy.linalg.lu_factor", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=50, + min_num_dims=2, + ), + overwrite_a=st.booleans(), + check_finite=st.booleans(), test_with_out=st.just(False), ) -def test_scipy_eigh_tridiagonal( - all_args, +def test_scipy_lu_factor( + dtype_and_x, + overwrite_a, + check_finite, frontend, test_flags, fn_tree, on_device, backend_fw, ): - dtype, alpha, beta, eigvals_only, select, select_range, tol = all_args + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, + test_values=False, fn_tree=fn_tree, on_device=on_device, - alpha=alpha[0], - beta=beta[0], - eigvals_only=eigvals_only, - select=select, - select_range=select_range, - tol=tol, + a=x[0], + overwrite_a=overwrite_a, + check_finite=check_finite, ) @@ -340,6 +276,36 @@ def test_scipy_norm( ) +# pinv +@handle_frontend_test( + fn_tree="scipy.linalg.pinv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=2, + ), + test_with_out=st.just(False), +) +def test_scipy_pinv( + dtype_and_x, + frontend, + test_flags, + fn_tree, + on_device, + backend_fw, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + ) + + # svd @handle_frontend_test( fn_tree="scipy.linalg.svd", @@ -416,23 +382,27 @@ def test_scipy_svdvals( ) -# lu_factor +# Tests # +# ----- # + + +# tril @handle_frontend_test( - fn_tree="scipy.linalg.lu_factor", + fn_tree="scipy.linalg.tril", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=50, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + 
max_dim_size=5, ), - overwrite_a=st.booleans(), - check_finite=st.booleans(), + k=helpers.ints(min_value=-10, max_value=10), test_with_out=st.just(False), ) -def test_scipy_lu_factor( +def test_scipy_tril( dtype_and_x, - overwrite_a, - check_finite, + k, frontend, test_flags, fn_tree, @@ -445,10 +415,44 @@ def test_scipy_lu_factor( backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - test_values=False, fn_tree=fn_tree, on_device=on_device, - a=x[0], - overwrite_a=overwrite_a, - check_finite=check_finite, + m=x[0], + k=k, + ) + + +# triu +@handle_frontend_test( + fn_tree="scipy.linalg.triu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + k=helpers.ints(min_value=-10, max_value=10), + test_with_out=st.just(False), +) +def test_scipy_triu( + dtype_and_x, + k, + test_flags, + frontend, + fn_tree, + on_device, + backend_fw, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + m=x[0], + k=k, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py index fd7dafe4ec31b..0e18c435ddd0c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py @@ -8,6 +8,10 @@ from ivy_tests.test_ivy.test_frontends.test_tensorflow.test_nn import _x_and_filters +# --- Helpers --- # +# --------------- # + + @st.composite def _batch_norm_helper(draw): num_dims = draw(st.integers(min_value=4, max_value=5)) @@ -41,6 +45,47 @@ def _batch_norm_helper(draw): return dtype + dtypes, x, epsilon, factor, training, data_format, vectors +# --- Main --- # +# ------------ # + + +@handle_frontend_test( + fn_tree="tensorflow.compat.v1.nn.depthwise_conv2d", + x_f_d_df=_x_and_filters( + dtypes=helpers.get_dtypes("float", full=False), + data_format=st.sampled_from(["NHWC"]), + padding=st.sampled_from(["VALID", "SAME"]), + type="depthwise", + ), + test_with_out=st.just(False), +) +def test_tensorflow_depthwise_conv2d( + *, + x_f_d_df, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x, + filter=filters, + strides=stride, + padding=padding, + rate=dilation, + name=None, + data_format=data_format, + ) + + @handle_frontend_test( fn_tree="tensorflow.compat.v1.nn.fused_batch_norm", dtypes_args=_batch_norm_helper(), @@ -76,26 +121,25 @@ def test_tensorflow_fused_batch_norm( ) +# max_pool @handle_frontend_test( - fn_tree="tensorflow.compat.v1.nn.depthwise_conv2d", - x_f_d_df=_x_and_filters( - dtypes=helpers.get_dtypes("float", full=False), - data_format=st.sampled_from(["NHWC"]), - padding=st.sampled_from(["VALID", "SAME"]), - type="depthwise", - ), + fn_tree="tensorflow.compat.v1.nn.max_pool", + data_format=st.just("NHWC"), + x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), test_with_out=st.just(False), ) -def test_tensorflow_depthwise_conv2d( +def test_tensorflow_max_pool( *, - x_f_d_df, 
+ x_k_s_p, + data_format, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df + input_dtype, x, ksize, strides, padding = x_k_s_p + data_format = data_format helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -103,12 +147,10 @@ def test_tensorflow_depthwise_conv2d( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - filter=filters, - strides=stride, + value=x[0], + ksize=ksize, + strides=strides, padding=padding, - rate=dilation, - name=None, data_format=data_format, ) @@ -149,37 +191,3 @@ def test_tensorflow_separable_conv2d( name=None, data_format=data_format, ) - - -# max_pool -@handle_frontend_test( - fn_tree="tensorflow.compat.v1.nn.max_pool", - data_format=st.just("NHWC"), - x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), - test_with_out=st.just(False), -) -def test_tensorflow_max_pool( - *, - x_k_s_p, - data_format, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x, ksize, strides, padding = x_k_s_p - data_format = data_format - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - value=x[0], - ksize=ksize, - strides=strides, - padding=padding, - data_format=data_format, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_func_wrapper.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_func_wrapper.py index ca4d295ef0f8e..036a6fa1e8c54 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_func_wrapper.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_func_wrapper.py @@ -15,12 +15,54 @@ import ivy.functional.frontends.numpy as np_frontend +# --- Helpers --- # +# --------------- # + + +@st.composite +def _dtype_helper(draw): + return draw( + st.sampled_from( + [ + draw(helpers.get_dtypes("valid", prune_function=False, full=False))[0], + ivy.as_native_dtype( + draw(helpers.get_dtypes("valid", prune_function=False, full=False))[ + 0 + ] + ), + draw( + st.sampled_from(list(tf_frontend.tensorflow_enum_to_type.values())) + ), + draw(st.sampled_from(list(tf_frontend.tensorflow_enum_to_type.keys()))), + np_frontend.dtype( + draw(helpers.get_dtypes("valid", prune_function=False, full=False))[ + 0 + ] + ), + draw(st.sampled_from(list(np_frontend.numpy_scalar_to_dtype.keys()))), + ] + ) + ) + + def _fn(x=None, dtype=None): if ivy.exists(dtype): return dtype return x +# --- Main --- # +# ------------ # + + +@given( + dtype=_dtype_helper(), +) +def test_tensorflow_handle_tf_dtype(dtype): + ret_dtype = handle_tf_dtype(_fn)(dtype=dtype) + assert isinstance(ret_dtype, ivy.Dtype) + + @given( dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False) @@ -100,37 +142,3 @@ def test_tensorflow_to_ivy_arrays_and_back(dtype_and_x): assert isinstance(output, EagerTensor) assert input_frontend.dtype == output.dtype assert ivy.all(input_frontend.ivy_array == output.ivy_array) - - -@st.composite -def _dtype_helper(draw): - return draw( - st.sampled_from( - [ - draw(helpers.get_dtypes("valid", prune_function=False, full=False))[0], - ivy.as_native_dtype( - draw(helpers.get_dtypes("valid", prune_function=False, full=False))[ - 0 - ] - ), - draw( - st.sampled_from(list(tf_frontend.tensorflow_enum_to_type.values())) - ), - 
draw(st.sampled_from(list(tf_frontend.tensorflow_enum_to_type.keys()))), - np_frontend.dtype( - draw(helpers.get_dtypes("valid", prune_function=False, full=False))[ - 0 - ] - ), - draw(st.sampled_from(list(np_frontend.numpy_scalar_to_dtype.keys()))), - ] - ) - ) - - -@given( - dtype=_dtype_helper(), -) -def test_tensorflow_handle_tf_dtype(dtype): - ret_dtype = handle_tf_dtype(_fn)(dtype=dtype) - assert isinstance(ret_dtype, ivy.Dtype) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py index 08465b897d18e..7bf8e2eb7f56a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py @@ -19,6 +19,55 @@ ) +# --- Helpers --- # +# --------------- # + + +@st.composite +def _boolean_mask_helper(draw): + tensor_shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=3, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + ) + + dtype = draw(st.sampled_from(["float32", "float64"])) + + # Param: tensor + # tensor = draw( + # helpers.array_values( + # dtype=dtype, shape=tensor_shape, min_value=-5.0, max_value=5.0 + # ), + # ) + + dtype, tensor, axis = draw( + helpers.dtype_values_axis( + available_dtypes=[dtype], + shape=tensor_shape, + min_value=-5.0, + max_value=5.0, + force_int_axis=True, + valid_axis=True, + ) + ) + mask_dim = draw(helpers.ints(min_value=1, max_value=len(tensor_shape) - axis)) + mask_shape = tensor_shape[axis : mask_dim + axis] + + # Param:stop + mask = draw( + helpers.array_values( + allow_nan=False, + dtype="bool", + shape=mask_shape, + ), + ) + return [dtype[0], "bool"], tensor, mask, axis + + @st.composite def _get_clip_inputs(draw): shape = draw( @@ -42,73 +91,6 @@ def _get_clip_inputs(draw): return x_dtype, x, min, max -# argsort -@handle_frontend_test( - fn_tree="tensorflow.argsort", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - min_axis=-1, - max_axis=0, - ), - direction=st.sampled_from(["ASCENDING", "DESCENDING"]), -) -def test_tensorflow_argsort( - *, - dtype_input_axis, - direction, - on_device, - fn_tree, - frontend, - backend_fw, - test_flags, -): - input_dtype, input, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - values=input[0], - axis=axis, - direction=direction, - ) - - -# clip_by_value -@handle_frontend_test( - fn_tree="tensorflow.clip_by_value", - input_and_ranges=_get_clip_inputs(), - test_with_out=st.just(False), -) -def test_tensorflow_clip_by_value( - *, - input_and_ranges, - frontend, - test_flags, - backend_fw, - fn_tree, - on_device, -): - x_dtype, x, min, max = input_and_ranges - helpers.test_frontend_function( - input_dtypes=x_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - t=x[0], - clip_value_min=min, - clip_value_max=max, - ) - - @st.composite def _get_global_norm_clip_inputs(draw): t_list_dtype, t_list = draw( @@ -149,49 +131,6 @@ def _get_global_norm_clip_inputs(draw): return t_list_dtype, t_list, norm_dtype, norm, global_norm_dtype, global_norm -# clip_by_global_norm -@handle_frontend_test( - fn_tree="tensorflow.clip_by_global_norm", - 
input_and_norm=_get_global_norm_clip_inputs(), - test_with_out=st.just(False), -) -def test_tensorflow_clip_by_global_norm( - *, - input_and_norm, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - ( - t_list_dtype, - t_list, - norm_dtype, - norm, - global_norm_dtype, - global_norm, - ) = input_and_norm - - input_dtypes = [t_list_dtype[0], norm_dtype[0]] - use_norm = None - if global_norm_dtype: - input_dtypes.append(global_norm_dtype[0]) - use_norm = global_norm[0] - - helpers.test_frontend_function( - input_dtypes=input_dtypes, - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - t_list=t_list, - clip_norm=norm[0], - use_norm=use_norm, - ) - - @st.composite def _get_norm_clip_inputs(draw): dtype = draw(helpers.get_dtypes("numeric", full=False)) @@ -212,322 +151,482 @@ def _get_norm_clip_inputs(draw): return x_dtype[0], x, axis, norm -# clip_by_norm -@handle_frontend_test( - fn_tree="tensorflow.clip_by_norm", - input_and_norm=_get_norm_clip_inputs(), - test_with_out=st.just(False), -) -def test_tensorflow_clip_by_norm( - *, - input_and_norm, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - x_dtype, x, axis, norm = input_and_norm - helpers.test_frontend_function( - input_dtypes=[x_dtype, x_dtype], - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - t=x[0], - clip_norm=norm[0], - axes=axis, +# transpose +@st.composite +def _get_perm_helper(draw): + shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="shape")) + dimensions = [x for x in range(len(shape))] + perm = draw(st.permutations(dimensions)) + return perm + + +@st.composite +def _linspace_helper(draw): + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=0, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), ) + dtype = draw(st.sampled_from(["float32", "float64"])) -# eye -@handle_frontend_test( - fn_tree="tensorflow.eye", - n_rows=helpers.ints(min_value=0, max_value=10), - n_cols=st.none() | helpers.ints(min_value=0, max_value=10), - batch_shape=st.lists( - helpers.ints(min_value=1, max_value=10), min_size=1, max_size=2 - ), - dtype=helpers.get_dtypes("valid", full=False), -) -def test_tensorflow_eye( - *, - n_rows, - n_cols, - batch_shape, - dtype, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - num_rows=n_rows, - num_columns=n_cols, - batch_shape=batch_shape, - dtype=dtype[0], + # Param: start + start = draw( + helpers.array_values( + dtype=dtype, + shape=shape, + min_value=-5.0, + max_value=5.0, + ), ) + # Param:stop + stop = draw( + helpers.array_values( + dtype=dtype, + shape=shape, + min_value=-4.0, + max_value=10.0, + ), + ) -# foldl -@handle_frontend_test( - fn_tree="tensorflow.foldl", - fn=st.sampled_from( - [ - lambda a, b: a + b, - lambda a, b: a - b, - lambda a, b: a * b, - ], - ), - initializer=st.one_of(st.none(), st.floats(min_value=-1000, max_value=1000)), - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", full=False), - min_value=-1000, - max_value=1000, - max_dim_size=10, - max_num_dims=4, - min_dim_size=1, + return [dtype] * 2, start, stop + + +# tile +@st.composite +def _multiple_shape_helper(draw): + input_dtype, input_array, input_shape = draw( + helpers.dtype_and_values( 
+ available_dtypes=helpers.get_dtypes("valid"), ret_shape=True + ) + ) + input_dims = len(input_shape) + + dt_n_multiples = draw( + helpers.dtype_and_values( + available_dtypes=["int32", "int64"], + min_value=0, + max_value=10, + shape=draw( + helpers.get_shape( + min_num_dims=1, + max_num_dims=1, + min_dim_size=input_dims, + max_dim_size=input_dims, + ) + ), + ) + ) + return input_dtype, input_array, dt_n_multiples + + +@st.composite +def _pad_helper(draw): + mode = draw( + st.sampled_from( + [ + "CONSTANT", + "REFLECT", + "SYMMETRIC", + ] + ) + ) + dtype, input, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ret_shape=True, + min_num_dims=1, + min_value=-100, + max_value=100, + ) + ) + ndim = len(shape) + min_dim = min(shape) + paddings = draw( + st.lists( + st.tuples( + st.integers(min_value=0, max_value=min_dim - 1), + st.integers(min_value=0, max_value=min_dim - 1), + ), + min_size=ndim, + max_size=ndim, + ) + ) + constant_values = draw(st.integers(min_value=0, max_value=4)) + return dtype, input[0], paddings, mode, constant_values + + +@st.composite +def _reshape_helper(draw): + shape = draw(helpers.get_shape(min_num_dims=1)) + reshape_shape = draw(helpers.reshape_shapes(shape=shape)) + dtype = draw(helpers.array_dtypes(num_arrays=1)) + x = draw(helpers.array_values(dtype=dtype[0], shape=shape)) + return x, dtype, reshape_shape + + +@st.composite +def _slice_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + ret_shape=True, + ), + ) + begin, size = [], [] + for i in shape: + begin += [draw(st.integers(min_value=0, max_value=i - 1))] + size += [draw(st.integers(min_value=0, max_value=i - begin[-1]))] + return dtype, x, np.array(begin), np.array(size) + + +# Squeeze +@st.composite +def _squeeze_helper(draw): + shape = draw(st.shared(helpers.get_shape(), key="value_shape")) + valid_axes = [] + for index, axis in enumerate(shape): + if axis == 1: + valid_axes.append(index) + valid_axes.insert(0, None) + return draw(st.sampled_from(valid_axes)) + + +@st.composite +def _strided_slice_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + ret_shape=True, + ), + ) + ndims = len(shape) + masks = draw( + st.lists( + st.integers(min_value=0, max_value=2**ndims - 1), min_size=5, max_size=5 + ).filter(lambda x: bin(x[2])[2:].count("1") <= min(len(shape) - 1, 1)) + ) + begin, end, strides = [], [], [] + for i in shape: + begin += [draw(st.integers(min_value=0, max_value=i - 1))] + end += [draw(st.integers(min_value=0, max_value=i - 1))] + if begin[-1] < end[-1]: + strides += [draw(st.integers(min_value=1, max_value=i))] + else: + strides += [draw(st.integers(max_value=-1, min_value=-i))] + ellipsis_mask = _num_to_bit_list(masks[2], ndims) + for i, v in enumerate(ellipsis_mask): + if v == 1: + skip = draw(st.integers(min_value=0, max_value=ndims)) + begin, end, strides = map( + lambda x: x[:i] + x[i + skip :] if i + skip < ndims else x[:i], + [begin, end, strides], + ) + break + return dtype, x, np.array(begin), np.array(end), np.array(strides), masks + + +@st.composite +def _x_cast_dtype_shape(draw): + x_dtype = draw(helpers.get_dtypes("valid", full=False)) + x_dtype, x = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=st.shared(helpers.get_shape(), key="value_shape"), + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", + ), + ) + to_shape = 
draw( + helpers.reshape_shapes(shape=st.shared(helpers.get_shape(), key="value_shape")), + ) + cast_dtype = x_dtype[0] + # known tensorflow bug when trying to cast to a different type + # https://github.com/tensorflow/tensorflow/issues/39554 + # cast_dtype = draw( + # helpers.get_dtypes("valid", full=False) + # .map(lambda t: t[0]) + # .filter(lambda t: ivy.can_cast(x_dtype[0], t)) + # ) + return x_dtype, x, cast_dtype, to_shape + + +# reverse +@st.composite +def reverse_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=8, + ret_shape=True, + ) + ) + axis_dtype, axis = draw( + helpers.dtype_and_values( + available_dtypes=["int32", "int64"], + min_num_dims=1, + max_num_dims=1, + min_value=-(len(shape) - 1), + max_value=len(shape) - 1, + shape=(1,), + ) + ) + return dtype, x, axis_dtype, axis + + +# --- Main --- # +# ------------ # + + +# argsort +@handle_frontend_test( + fn_tree="tensorflow.argsort", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + min_axis=-1, + max_axis=0, ), - parallel_iterations=st.just(10), - swap_memory=st.booleans(), - name=st.none(), + direction=st.sampled_from(["ASCENDING", "DESCENDING"]), ) -def test_tensorflow_foldl( +def test_tensorflow_argsort( *, - fn, - initializer, - dtype_and_values, + dtype_input_axis, + direction, + on_device, + fn_tree, frontend, backend_fw, - fn_tree, test_flags, - parallel_iterations, - swap_memory, - name, ): - dtype, elems = dtype_and_values - elems = np.atleast_1d(elems) + input_dtype, input, axis = dtype_input_axis helpers.test_frontend_function( - input_dtypes=dtype, - fn=fn, - elems=elems, - initializer=initializer, - backend_to_test=backend_fw, - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, - name=name, + input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, + values=input[0], + axis=axis, + direction=direction, + ) + + +# boolean_mask +@handle_frontend_test( + fn_tree="tensorflow.boolean_mask", + dtype_and_values=_boolean_mask_helper(), +) +def test_tensorflow_boolean_mask( + *, + dtype_and_values, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtype, tensor, mask, axis = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, test_flags=test_flags, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + tensor=tensor[0], + mask=mask, + axis=axis, ) -# ones +# clip_by_global_norm @handle_frontend_test( - fn_tree="tensorflow.ones", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - dtype=helpers.get_dtypes("valid", full=False), + fn_tree="tensorflow.clip_by_global_norm", + input_and_norm=_get_global_norm_clip_inputs(), test_with_out=st.just(False), ) -def test_tensorflow_ones( - shape, - dtype, +def test_tensorflow_clip_by_global_norm( + *, + input_and_norm, frontend, backend_fw, test_flags, fn_tree, on_device, ): + ( + t_list_dtype, + t_list, + norm_dtype, + norm, + global_norm_dtype, + global_norm, + ) = input_and_norm + + input_dtypes = [t_list_dtype[0], norm_dtype[0]] + use_norm = None + if global_norm_dtype: + input_dtypes.append(global_norm_dtype[0]) + use_norm = global_norm[0] + helpers.test_frontend_function( - input_dtypes=dtype, + 
input_dtypes=input_dtypes, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - shape=shape, - dtype=dtype[0], + t_list=t_list, + clip_norm=norm[0], + use_norm=use_norm, ) -# full +# clip_by_norm @handle_frontend_test( - fn_tree="tensorflow.fill", - shape=helpers.get_shape(), - input_fill_dtype=_input_fill_and_dtype(), + fn_tree="tensorflow.clip_by_norm", + input_and_norm=_get_norm_clip_inputs(), + test_with_out=st.just(False), ) -def test_tensorflow_fill( - shape, - input_fill_dtype, +def test_tensorflow_clip_by_norm( + *, + input_and_norm, frontend, backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, _, fill, dtype_to_cast = input_fill_dtype + x_dtype, x, axis, norm = input_and_norm helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[x_dtype, x_dtype], frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - rtol=1e-05, - dims=shape, - value=fill, + t=x[0], + clip_norm=norm[0], + axes=axis, ) -# einsum +# clip_by_value @handle_frontend_test( - fn_tree="tensorflow.einsum", - eq_n_op_n_shp=st.sampled_from( - [ - ("ii", (np.arange(25).reshape(5, 5),), ()), - ("ii->i", (np.arange(25).reshape(5, 5),), (5,)), - ("ij,j", (np.arange(25).reshape(5, 5), np.arange(5)), (5,)), - ] - ), - dtype=helpers.get_dtypes("float", full=False), + fn_tree="tensorflow.clip_by_value", + input_and_ranges=_get_clip_inputs(), + test_with_out=st.just(False), ) -def test_tensorflow_einsum( +def test_tensorflow_clip_by_value( *, - eq_n_op_n_shp, - dtype, - on_device, - fn_tree, + input_and_ranges, frontend, - backend_fw, test_flags, + backend_fw, + fn_tree, + on_device, ): - eq, operands, _ = eq_n_op_n_shp - kw = {} - i = 0 - for x_ in operands: - kw["x{}".format(i)] = x_ - i += 1 - # len(operands) + 1 because of the equation - test_flags.num_positional_args = len(operands) + 1 + x_dtype, x, min, max = input_and_ranges helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=x_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - equation=eq, - **kw, + t=x[0], + clip_value_min=min, + clip_value_max=max, ) -@st.composite -def _reshape_helper(draw): - shape = draw(helpers.get_shape(min_num_dims=1)) - reshape_shape = draw(helpers.reshape_shapes(shape=shape)) - dtype = draw(helpers.array_dtypes(num_arrays=1)) - x = draw(helpers.array_values(dtype=dtype[0], shape=shape)) - return x, dtype, reshape_shape - - -# reshape +# concat @handle_frontend_test( - fn_tree="tensorflow.reshape", - input_x_shape=_reshape_helper(), + fn_tree="tensorflow.concat", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=st.integers(min_value=1, max_value=4), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_tensorflow_reshape( +def test_tensorflow_concat( *, - input_x_shape, + dtype_input_axis, on_device, fn_tree, - backend_fw, frontend, + backend_fw, test_flags, ): - x, x_dtype, shape = input_x_shape + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tensor=x, - shape=shape, - ) - - -@st.composite -def _x_cast_dtype_shape(draw): - x_dtype = 
draw(helpers.get_dtypes("valid", full=False)) - x_dtype, x = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=st.shared(helpers.get_shape(), key="value_shape"), - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", - ), - ) - to_shape = draw( - helpers.reshape_shapes(shape=st.shared(helpers.get_shape(), key="value_shape")), + values=x, + axis=axis, ) - cast_dtype = x_dtype[0] - # known tensorflow bug when trying to cast to a different type - # https://github.com/tensorflow/tensorflow/issues/39554 - # cast_dtype = draw( - # helpers.get_dtypes("valid", full=False) - # .map(lambda t: t[0]) - # .filter(lambda t: ivy.can_cast(x_dtype[0], t)) - # ) - return x_dtype, x, cast_dtype, to_shape -# size -# output_dtype not generated as tf only accepts tf dtypes +# cond @handle_frontend_test( - fn_tree="tensorflow.size", + fn_tree="tensorflow.cond", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), max_num_dims=4 + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=1, ), - # output_dtype=st.sampled_from(["int32", "int64"]), + pred_cond=st.booleans(), + var=st.integers(min_value=1, max_value=100), test_with_out=st.just(False), ) -def test_tensorflow_size( +def test_tensorflow_cond( *, dtype_and_x, - frontend, - backend_fw, + pred_cond, + var, test_flags, + on_device, fn_tree, - on_device, # output_dtype + frontend, + backend_fw, ): - input_dtype, x = dtype_and_x + _test_true_fn = lambda: var + var + + _test_false_fn = lambda: var * var + + input_dtype, _ = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - # out_type=output_dtype, + pred=pred_cond, + true_fn=_test_true_fn, + false_fn=_test_false_fn, ) @@ -591,79 +690,36 @@ def test_tensorflow_convert_to_tensor( ) -# rank +# einsum @handle_frontend_test( - fn_tree="tensorflow.rank", - dtype_and_x=_matrix_rank_helper(), - test_with_out=st.just(False), + fn_tree="tensorflow.einsum", + eq_n_op_n_shp=st.sampled_from( + [ + ("ii", (np.arange(25).reshape(5, 5),), ()), + ("ii->i", (np.arange(25).reshape(5, 5),), (5,)), + ("ij,j", (np.arange(25).reshape(5, 5), np.arange(5)), (5,)), + ] + ), + dtype=helpers.get_dtypes("float", full=False), ) -def test_tensorflow_rank( +def test_tensorflow_einsum( *, - dtype_and_x, - backend_fw, - on_device, - fn_tree, - frontend, - test_flags, -): - dtype, x, _ = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - ) - - -# ones_like -@handle_frontend_test( - fn_tree="tensorflow.ones_like", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - dtype=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), -) -def test_tensorflow_ones_like( - dtype_and_x, + eq_n_op_n_shp, dtype, - frontend, - backend_fw, - test_flags, - fn_tree, on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dtype=dtype[0], - ) - - -# identity -@handle_frontend_test( - fn_tree="tensorflow.identity", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ), - 
test_with_out=st.just(False), -) -def test_tensorflow_identity( - dtype_and_x, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): - dtype, x = dtype_and_x + eq, operands, _ = eq_n_op_n_shp + kw = {} + i = 0 + for x_ in operands: + kw["x{}".format(i)] = x_ + i += 1 + # len(operands) + 1 because of the equation + test_flags.num_positional_args = len(operands) + 1 helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -671,38 +727,32 @@ def test_tensorflow_identity( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + equation=eq, + **kw, ) -# zeros_like @handle_frontend_test( - fn_tree="tensorflow.zeros_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), - dtype=helpers.get_dtypes("numeric", full=False), - test_with_out=st.just(False), + fn_tree="tensorflow.ensure_shape", + dtype_and_x=_array_and_shape(min_num_dims=0, max_num_dims=5), ) -def test_tensorflow_zeros_like( +def test_tensorflow_ensure_shape( + *, dtype_and_x, - dtype, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dtype=dtype[0], + x=x[0], + shape=x[1], ) @@ -742,23 +792,28 @@ def test_tensorflow_expand_dims( ) -# identity_n +# eye @handle_frontend_test( - fn_tree="tensorflow.identity_n", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), max_num_dims=5 + fn_tree="tensorflow.eye", + n_rows=helpers.ints(min_value=0, max_value=10), + n_cols=st.none() | helpers.ints(min_value=0, max_value=10), + batch_shape=st.lists( + helpers.ints(min_value=1, max_value=10), min_size=1, max_size=2 ), - test_with_out=st.just(False), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_tensorflow_identity_n( - dtype_and_x, +def test_tensorflow_eye( + *, + n_rows, + n_cols, + batch_shape, + dtype, frontend, - test_flags, backend_fw, + test_flags, fn_tree, on_device, ): - dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -766,76 +821,29 @@ def test_tensorflow_identity_n( backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - input=x, - ) - - -# Squeeze -@st.composite -def _squeeze_helper(draw): - shape = draw(st.shared(helpers.get_shape(), key="value_shape")) - valid_axes = [] - for index, axis in enumerate(shape): - if axis == 1: - valid_axes.append(index) - valid_axes.insert(0, None) - return draw(st.sampled_from(valid_axes)) - - -@handle_frontend_test( - fn_tree="tensorflow.squeeze", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - axis=_squeeze_helper(), -) -def test_tensorflow_squeeze_general( - *, - dtype_value, - axis, - on_device, - backend_fw, - fn_tree, - frontend, - test_flags, -): - dtype, xs = dtype_value - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=xs[0], - axis=axis, + num_rows=n_rows, + num_columns=n_cols, + batch_shape=batch_shape, + dtype=dtype[0], ) -# concat +# full @handle_frontend_test( - fn_tree="tensorflow.concat", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - 
num_arrays=st.integers(min_value=1, max_value=4), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - shared_dtype=True, - ), - test_with_out=st.just(False), + fn_tree="tensorflow.fill", + shape=helpers.get_shape(), + input_fill_dtype=_input_fill_and_dtype(), ) -def test_tensorflow_concat( - *, - dtype_input_axis, - on_device, - fn_tree, +def test_tensorflow_fill( + shape, + input_fill_dtype, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): - input_dtype, x, axis = dtype_input_axis + input_dtype, _, fill, dtype_to_cast = input_fill_dtype helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -843,304 +851,304 @@ def test_tensorflow_concat( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - values=x, - axis=axis, + rtol=1e-05, + dims=shape, + value=fill, ) -# cond +# foldl @handle_frontend_test( - fn_tree="tensorflow.cond", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + fn_tree="tensorflow.foldl", + fn=st.sampled_from( + [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + ], + ), + initializer=st.one_of(st.none(), st.floats(min_value=-1000, max_value=1000)), + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", full=False), + min_value=-1000, + max_value=1000, + max_dim_size=10, + max_num_dims=4, min_dim_size=1, + min_num_dims=1, ), - pred_cond=st.booleans(), - var=st.integers(min_value=1, max_value=100), - test_with_out=st.just(False), + parallel_iterations=st.just(10), + swap_memory=st.booleans(), + name=st.none(), ) -def test_tensorflow_cond( - *, - dtype_and_x, - pred_cond, - var, - test_flags, - on_device, - fn_tree, +def test_tensorflow_foldl( + *, + fn, + initializer, + dtype_and_values, frontend, backend_fw, + fn_tree, + test_flags, + parallel_iterations, + swap_memory, + name, ): - _test_true_fn = lambda: var + var - - _test_false_fn = lambda: var * var - - input_dtype, _ = dtype_and_x + dtype, elems = dtype_and_values + elems = np.atleast_1d(elems) helpers.test_frontend_function( - input_dtypes=input_dtype, - test_flags=test_flags, - frontend=frontend, + input_dtypes=dtype, + fn=fn, + elems=elems, + initializer=initializer, backend_to_test=backend_fw, + parallel_iterations=parallel_iterations, + swap_memory=swap_memory, + name=name, + frontend=frontend, fn_tree=fn_tree, - on_device=on_device, - pred=pred_cond, - true_fn=_test_true_fn, - false_fn=_test_false_fn, + test_flags=test_flags, ) -# zeros +# gather @handle_frontend_test( - fn_tree="tensorflow.zeros", - input=helpers.get_shape( - allow_none=False, - min_num_dims=0, - max_num_dims=10, - min_dim_size=0, + fn_tree="tensorflow.gather", + params_indices_axis_batch_dims=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("valid"), + indices_dtypes=["int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, max_dim_size=10, + indices_same_dims=True, ), - dtype=helpers.get_dtypes("valid", full=False), ) -def test_tensorflow_zeros( +def test_tensorflow_gather( *, - input, - dtype, + params_indices_axis_batch_dims, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): + input_dtypes, params, indices, axis, batch_dims = params_indices_axis_batch_dims helpers.test_frontend_function( - shape=input, - input_dtypes=dtype, + input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + params=params, + indices=indices, + axis=axis, + 
batch_dims=batch_dims, ) -# shape +# gather_nd @handle_frontend_test( - fn_tree="tensorflow.shape", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - output_dtype=st.sampled_from(["int32", "int64"]), + fn_tree="tensorflow.gather_nd", + params_indices_axis_batch_dims=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("valid"), + indices_dtypes=["int64"], + min_num_dims=5, + max_num_dims=10, + min_dim_size=1, + max_dim_size=5, + indices_same_dims=False, + ), ) -def test_tensorflow_shape( +def test_tensorflow_gather_nd( *, - dtype_and_x, - output_dtype, + params_indices_axis_batch_dims, on_device, backend_fw, fn_tree, frontend, test_flags, ): - ( - input_dtype, - x, - ) = dtype_and_x + input_dtypes, params, indices, axis, batch_dims = params_indices_axis_batch_dims helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - out_type=output_dtype, + params=params, + indices=indices, + batch_dims=batch_dims, ) +# identity @handle_frontend_test( - fn_tree="tensorflow.shape_n", + fn_tree="tensorflow.identity", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, max_num_dims=5 + available_dtypes=helpers.get_dtypes("numeric"), ), - output_dtype=st.sampled_from(["int32", "int64"]), + test_with_out=st.just(False), ) -def test_tensorflow_shape_n( - *, +def test_tensorflow_identity( dtype_and_x, - output_dtype, - on_device, - fn_tree, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): - input_dtype, input = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input, - out_type=output_dtype, + input=x[0], ) +# identity_n @handle_frontend_test( - fn_tree="tensorflow.ensure_shape", - dtype_and_x=_array_and_shape(min_num_dims=0, max_num_dims=5), + fn_tree="tensorflow.identity_n", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), max_num_dims=5 + ), + test_with_out=st.just(False), ) -def test_tensorflow_ensure_shape( - *, +def test_tensorflow_identity_n( dtype_and_x, - fn_tree, frontend, - backend_fw, test_flags, + backend_fw, + fn_tree, + on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, test_flags=test_flags, backend_to_test=backend_fw, fn_tree=fn_tree, - x=x[0], - shape=x[1], + on_device=on_device, + input=x, ) -# range +# is_tensor @handle_frontend_test( - fn_tree="tensorflow.range", - start=helpers.ints(min_value=-50, max_value=0), - limit=helpers.ints(min_value=1, max_value=50), - delta=helpers.ints(min_value=1, max_value=5), - dtype=helpers.get_dtypes("float"), - test_with_out=st.just(False), + fn_tree="tensorflow.is_tensor", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), ) -def test_tensorflow_range( +def test_tensorflow_is_tensor( *, - start, - limit, - delta, - dtype, - on_device, - fn_tree, - frontend, + dtype_and_x, backend_fw, + frontend, test_flags, + fn_tree, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[], - on_device=on_device, - fn_tree=fn_tree, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, 
test_flags=test_flags, - start=start, - limit=limit, - delta=delta, - dtype=dtype[0], + fn_tree=fn_tree, + x=x[0], ) -# sort +# linspace @handle_frontend_test( - fn_tree="tensorflow.sort", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - min_axis=-1, - max_axis=0, - ), - descending=st.sampled_from(["ASCENDING", "DESCENDING"]), + fn_tree="tensorflow.linspace", + dtype_and_params=_linspace_helper(), + num=helpers.ints(min_value=2, max_value=10), + axis=helpers.ints(min_value=-1, max_value=0), ) -def test_tensorflow_sort( +def test_tensorflow_linspace( *, - dtype_input_axis, - descending, + dtype_and_params, + num, + axis, on_device, + backend_fw, fn_tree, frontend, - backend_fw, test_flags, ): - input_dtype, input, axis = dtype_input_axis + dtype, start, stop = dtype_and_params helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - values=input[0], + start=start, + stop=stop, + num=num, axis=axis, - direction=descending, + on_device=on_device, ) -# searchsorted +# no_op @handle_frontend_test( - fn_tree="tensorflow.searchsorted", - dtype_x_v=helpers.dtype_and_values( + fn_tree="tensorflow.no_op", + dtype=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - shared_dtype=True, - min_num_dims=1, - max_num_dims=1, - num_arrays=2, ), - side=st.sampled_from(["left", "right"]), - out_type=st.sampled_from(["int32", "int64"]), + test_with_out=st.just(False), ) -def test_tensorflow_searchsorted( - dtype_x_v, - side, - out_type, +def test_tensorflow_no_op( + *, + dtype, frontend, backend_fw, test_flags, fn_tree, - on_device, ): - input_dtypes, xs = dtype_x_v helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - sorted_sequence=np.sort(xs[0]), - values=xs[1], - side=side, - out_type=out_type, ) -# stack @handle_frontend_test( - fn_tree="tensorflow.stack", + fn_tree="tensorflow.norm", + aliases=["tensorflow.norm"], dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays"), - shape=helpers.get_shape(min_num_dims=1), - shared_dtype=True, - valid_axis=True, - allow_neg_axes=True, - force_int_axis=True, + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=3, + max_num_dims=5, + min_dim_size=1, + max_dim_size=4, + min_axis=-3, + max_axis=2, ), + ord=st.sampled_from([1, 2, np.inf]), + keepdims=st.booleans(), ) -def test_tensorflow_stack( +def test_tensorflow_norm( + *, dtype_values_axis, - on_device, - fn_tree, - frontend, + ord, + keepdims, backend_fw, + frontend, test_flags, + fn_tree, + on_device, ): - input_dtype, values, axis = dtype_values_axis + input_dtype, x, axis = dtype_values_axis helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1148,142 +1156,107 @@ def test_tensorflow_stack( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - values=values, + tensor=x[0], + ord=ord, axis=axis, + keepdims=keepdims, ) -# is_tensor +# one_hot @handle_frontend_test( - fn_tree="tensorflow.is_tensor", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + fn_tree="tensorflow.one_hot", + 
dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=1, + min_value=0, + max_value=10, + ), ) -def test_tensorflow_is_tensor( +def test_tensorflow_one_hot( *, dtype_and_x, - backend_fw, frontend, - test_flags, + backend_fw, fn_tree, + test_flags, + on_device, ): input_dtype, x = dtype_and_x + depth = 10 helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=["uint8", "int32", "int64"], + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, - x=x[0], + on_device=on_device, + indices=x[0], + depth=depth, ) -# gather +# ones @handle_frontend_test( - fn_tree="tensorflow.gather", - params_indices_axis_batch_dims=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("valid"), - indices_dtypes=["int64"], + fn_tree="tensorflow.ones", + shape=helpers.get_shape( + allow_none=False, min_num_dims=1, max_num_dims=5, min_dim_size=1, max_dim_size=10, - indices_same_dims=True, ), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), ) -def test_tensorflow_gather( - *, - params_indices_axis_batch_dims, - on_device, - fn_tree, +def test_tensorflow_ones( + shape, + dtype, frontend, backend_fw, test_flags, + fn_tree, + on_device, ): - input_dtypes, params, indices, axis, batch_dims = params_indices_axis_batch_dims helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - params=params, - indices=indices, - axis=axis, - batch_dims=batch_dims, + shape=shape, + dtype=dtype[0], ) -# gather_nd +# ones_like @handle_frontend_test( - fn_tree="tensorflow.gather_nd", - params_indices_axis_batch_dims=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("valid"), - indices_dtypes=["int64"], - min_num_dims=5, - max_num_dims=10, - min_dim_size=1, - max_dim_size=5, - indices_same_dims=False, - ), + fn_tree="tensorflow.ones_like", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), ) -def test_tensorflow_gather_nd( - *, - params_indices_axis_batch_dims, - on_device, - backend_fw, - fn_tree, +def test_tensorflow_ones_like( + dtype_and_x, + dtype, frontend, + backend_fw, test_flags, + fn_tree, + on_device, ): - input_dtypes, params, indices, axis, batch_dims = params_indices_axis_batch_dims + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - params=params, - indices=indices, - batch_dims=batch_dims, - ) - - -@st.composite -def _pad_helper(draw): - mode = draw( - st.sampled_from( - [ - "CONSTANT", - "REFLECT", - "SYMMETRIC", - ] - ) - ) - dtype, input, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ret_shape=True, - min_num_dims=1, - min_value=-100, - max_value=100, - ) - ) - ndim = len(shape) - min_dim = min(shape) - paddings = draw( - st.lists( - st.tuples( - st.integers(min_value=0, max_value=min_dim - 1), - st.integers(min_value=0, max_value=min_dim - 1), - ), - min_size=ndim, - max_size=ndim, - ) + input=x[0], + dtype=dtype[0], ) - constant_values = draw(st.integers(min_value=0, max_value=4)) - return dtype, input[0], paddings, mode, constant_values # pad @@ -1317,37 +1290,57 @@ def 
test_tensorflow_pad( ) -# transpose -@st.composite -def _get_perm_helper(draw): - shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="shape")) - dimensions = [x for x in range(len(shape))] - perm = draw(st.permutations(dimensions)) - return perm +# range +@handle_frontend_test( + fn_tree="tensorflow.range", + start=helpers.ints(min_value=-50, max_value=0), + limit=helpers.ints(min_value=1, max_value=50), + delta=helpers.ints(min_value=1, max_value=5), + dtype=helpers.get_dtypes("float"), + test_with_out=st.just(False), +) +def test_tensorflow_range( + *, + start, + limit, + delta, + dtype, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + helpers.test_frontend_function( + input_dtypes=[], + on_device=on_device, + fn_tree=fn_tree, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + start=start, + limit=limit, + delta=delta, + dtype=dtype[0], + ) +# rank @handle_frontend_test( - fn_tree="tensorflow.transpose", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - perm=_get_perm_helper(), - conjugate=st.booleans(), + fn_tree="tensorflow.rank", + dtype_and_x=_matrix_rank_helper(), test_with_out=st.just(False), ) -def test_tensorflow_transpose( +def test_tensorflow_rank( *, dtype_and_x, - perm, - conjugate, - frontend, backend_fw, - test_flags, - fn_tree, on_device, + fn_tree, + frontend, + test_flags, ): - dtype, x = dtype_and_x + dtype, x, _ = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -1355,445 +1348,337 @@ def test_tensorflow_transpose( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x[0], - perm=perm, - conjugate=conjugate, - ) - - -@st.composite -def _strided_slice_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - ret_shape=True, - ), - ) - ndims = len(shape) - masks = draw( - st.lists( - st.integers(min_value=0, max_value=2**ndims - 1), min_size=5, max_size=5 - ).filter(lambda x: bin(x[2])[2:].count("1") <= min(len(shape) - 1, 1)) + input=x[0], ) - begin, end, strides = [], [], [] - for i in shape: - begin += [draw(st.integers(min_value=0, max_value=i - 1))] - end += [draw(st.integers(min_value=0, max_value=i - 1))] - if begin[-1] < end[-1]: - strides += [draw(st.integers(min_value=1, max_value=i))] - else: - strides += [draw(st.integers(max_value=-1, min_value=-i))] - ellipsis_mask = _num_to_bit_list(masks[2], ndims) - for i, v in enumerate(ellipsis_mask): - if v == 1: - skip = draw(st.integers(min_value=0, max_value=ndims)) - begin, end, strides = map( - lambda x: x[:i] + x[i + skip :] if i + skip < ndims else x[:i], - [begin, end, strides], - ) - break - return dtype, x, np.array(begin), np.array(end), np.array(strides), masks -# strided_slice +# realdiv @handle_frontend_test( - fn_tree="tensorflow.strided_slice", - dtype_x_params=_strided_slice_helper(), + fn_tree="tensorflow.realdiv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), + num_arrays=2, + min_value=-20, + max_value=20, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_tensorflow_strided_slice( +def test_tensorflow_realdiv( *, - dtype_x_params, + dtype_and_x, + test_flags, frontend, backend_fw, - test_flags, fn_tree, on_device, ): - dtype, x, begin, end, strides, masks = dtype_x_params - try: - helpers.test_frontend_function( - input_dtypes=dtype + 
3 * ["int64"] + 5 * ["int32"], - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input_=x[0], - begin=begin, - end=end, - strides=strides, - begin_mask=masks[0], - end_mask=masks[1], - ellipsis_mask=masks[2], - new_axis_mask=masks[3], - shrink_axis_mask=masks[4], - ) - except tf_errors.InvalidArgumentError: - assume(False) - except Exception as e: - if ( - hasattr(e, "message") - and "only stride 1 allowed on non-range indexing" in e.message - ): - assume(False) - raise e - - -@st.composite -def _slice_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - ret_shape=True, - ), + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + test_flags=test_flags, + frontend=frontend, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + y=x[1], ) - begin, size = [], [] - for i in shape: - begin += [draw(st.integers(min_value=0, max_value=i - 1))] - size += [draw(st.integers(min_value=0, max_value=i - begin[-1]))] - return dtype, x, np.array(begin), np.array(size) -# slice +# repeat @handle_frontend_test( - fn_tree="tensorflow.slice", - dtype_x_params=_slice_helper(), - test_with_out=st.just(False), + fn_tree="tensorflow.repeat", + dtypes_and_value_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + max_dim_size=10, + valid_axis=True, + force_int_axis=True, + ), + repeats=helpers.ints(min_value=1, max_value=5), ) -def test_tensorflow_slice( +def test_tensorflow_repeat( *, - dtype_x_params, + dtypes_and_value_and_axis, + repeats, + on_device, + fn_tree, frontend, backend_fw, test_flags, - fn_tree, - on_device, ): - dtype, x, begin, size = dtype_x_params + input_dtypes, x, axis = dtypes_and_value_and_axis + repeats = repeats helpers.test_frontend_function( - input_dtypes=dtype + 3 * ["int64"], + input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input_=x[0], - begin=begin, - size=size, - ) - - -@st.composite -def _linspace_helper(draw): - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=0, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), + input=x[0], + repeats=repeats, + axis=axis, ) - dtype = draw(st.sampled_from(["float32", "float64"])) - - # Param: start - start = draw( - helpers.array_values( - dtype=dtype, - shape=shape, - min_value=-5.0, - max_value=5.0, - ), - ) - # Param:stop - stop = draw( - helpers.array_values( - dtype=dtype, - shape=shape, - min_value=-4.0, - max_value=10.0, - ), +# reshape +@handle_frontend_test( + fn_tree="tensorflow.reshape", + input_x_shape=_reshape_helper(), + test_with_out=st.just(False), +) +def test_tensorflow_reshape( + *, + input_x_shape, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + x, x_dtype, shape = input_x_shape + helpers.test_frontend_function( + input_dtypes=x_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensor=x, + shape=shape, ) - return [dtype] * 2, start, stop - -# linspace @handle_frontend_test( - fn_tree="tensorflow.linspace", - dtype_and_params=_linspace_helper(), - num=helpers.ints(min_value=2, max_value=10), - axis=helpers.ints(min_value=-1, max_value=0), + fn_tree="tensorflow.reverse", + dtype_x_axis=reverse_helper(), ) -def 
test_tensorflow_linspace( +def test_tensorflow_reverse( *, - dtype_and_params, - num, - axis, - on_device, + dtype_x_axis, + frontend, backend_fw, fn_tree, - frontend, test_flags, + on_device, ): - dtype, start, stop = dtype_and_params + dtype, x, axis_dtype, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=dtype + axis_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, - start=start, - stop=stop, - num=num, - axis=axis, on_device=on_device, + tensor=x[0], + axis=axis[0], ) -# no_op +# roll @handle_frontend_test( - fn_tree="tensorflow.no_op", - dtype=helpers.dtype_and_values( + fn_tree="tensorflow.roll", + dtype_and_values=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + shift=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, ), - test_with_out=st.just(False), ) -def test_tensorflow_no_op( +def test_tensorflow_roll( *, - dtype, - frontend, + dtype_and_values, + shift, + axis, + on_device, + fn_tree, backend_fw, + frontend, test_flags, - fn_tree, ): + input_dtype, value = dtype_and_values + if isinstance(shift, int) and isinstance(axis, tuple): + axis = axis[0] + if isinstance(shift, tuple) and isinstance(axis, tuple): + if len(shift) != len(axis): + mn = min(len(shift), len(axis)) + shift = shift[:mn] + axis = axis[:mn] helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, + input=value[0], + shift=shift, + axis=axis, ) -# realdiv +# scan @handle_frontend_test( - fn_tree="tensorflow.realdiv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - num_arrays=2, - min_value=-20, - max_value=20, - shared_dtype=True, + fn_tree="tensorflow.scan", + dtypes_values=helpers.dtype_and_values( + available_dtypes=["float32"], num_arrays=1, min_num_dims=2, max_dim_size=3 ), test_with_out=st.just(False), ) -def test_tensorflow_realdiv( +def test_tensorflow_scan( *, - dtype_and_x, - test_flags, + dtypes_values, + on_device, + fn_tree, frontend, backend_fw, - fn_tree, - on_device, + test_flags, ): - input_dtype, x = dtype_and_x + def _test_fn(a, x): + return a + x + + x_dtype, elems = dtypes_values helpers.test_frontend_function( - input_dtypes=input_dtype, - test_flags=test_flags, + input_dtypes=x_dtype, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], - ) - - -# tile -@st.composite -def _multiple_shape_helper(draw): - input_dtype, input_array, input_shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), ret_shape=True - ) - ) - input_dims = len(input_shape) - - dt_n_multiples = draw( - helpers.dtype_and_values( - available_dtypes=["int32", "int64"], - min_value=0, - max_value=10, - shape=draw( - helpers.get_shape( - min_num_dims=1, - max_num_dims=1, - min_dim_size=input_dims, - max_dim_size=input_dims, - ) - ), - ) + fn=_test_fn, + elems=elems[0], ) - return input_dtype, input_array, dt_n_multiples -@handle_frontend_test(fn_tree="tensorflow.tile", all_arguments=_multiple_shape_helper()) -def test_tensorflow_tile( - *, - all_arguments, - 
test_flags, +# searchsorted +@handle_frontend_test( + fn_tree="tensorflow.searchsorted", + dtype_x_v=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shared_dtype=True, + min_num_dims=1, + max_num_dims=1, + num_arrays=2, + ), + side=st.sampled_from(["left", "right"]), + out_type=st.sampled_from(["int32", "int64"]), +) +def test_tensorflow_searchsorted( + dtype_x_v, + side, + out_type, frontend, + backend_fw, + test_flags, fn_tree, on_device, - backend_fw, ): - input_dtype, input_matrix, dt_and_multiples = all_arguments - dt_mul, multiples = dt_and_multiples + input_dtypes, xs = dtype_x_v helpers.test_frontend_function( - input_dtypes=input_dtype + dt_mul, - input=input_matrix[0], - multiples=multiples[0], - test_flags=test_flags, - backend_to_test=backend_fw, + input_dtypes=input_dtypes, frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + sorted_sequence=np.sort(xs[0]), + values=xs[1], + side=side, + out_type=out_type, ) -# one_hot +# shape @handle_frontend_test( - fn_tree="tensorflow.one_hot", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=1, - min_value=0, - max_value=10, - ), + fn_tree="tensorflow.shape", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + output_dtype=st.sampled_from(["int32", "int64"]), ) -def test_tensorflow_one_hot( +def test_tensorflow_shape( *, dtype_and_x, - frontend, + output_dtype, + on_device, backend_fw, fn_tree, + frontend, test_flags, - on_device, ): - input_dtype, x = dtype_and_x - depth = 10 + ( + input_dtype, + x, + ) = dtype_and_x helpers.test_frontend_function( - input_dtypes=["uint8", "int32", "int64"], - test_flags=test_flags, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - indices=x[0], - depth=depth, - ) - - -@st.composite -def _boolean_mask_helper(draw): - tensor_shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=3, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - ) - - dtype = draw(st.sampled_from(["float32", "float64"])) - - # Param: tensor - # tensor = draw( - # helpers.array_values( - # dtype=dtype, shape=tensor_shape, min_value=-5.0, max_value=5.0 - # ), - # ) - - dtype, tensor, axis = draw( - helpers.dtype_values_axis( - available_dtypes=[dtype], - shape=tensor_shape, - min_value=-5.0, - max_value=5.0, - force_int_axis=True, - valid_axis=True, - ) - ) - mask_dim = draw(helpers.ints(min_value=1, max_value=len(tensor_shape) - axis)) - mask_shape = tensor_shape[axis : mask_dim + axis] - - # Param:stop - mask = draw( - helpers.array_values( - allow_nan=False, - dtype="bool", - shape=mask_shape, - ), + input=x[0], + out_type=output_dtype, ) - return [dtype[0], "bool"], tensor, mask, axis -# boolean_mask @handle_frontend_test( - fn_tree="tensorflow.boolean_mask", - dtype_and_values=_boolean_mask_helper(), + fn_tree="tensorflow.shape_n", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, max_num_dims=5 + ), + output_dtype=st.sampled_from(["int32", "int64"]), ) -def test_tensorflow_boolean_mask( +def test_tensorflow_shape_n( *, - dtype_and_values, - test_flags, + dtype_and_x, + output_dtype, + on_device, + fn_tree, frontend, backend_fw, - fn_tree, - on_device, + test_flags, ): - input_dtype, tensor, mask, axis = dtype_and_values + input_dtype, input = dtype_and_x helpers.test_frontend_function( 
input_dtypes=input_dtype, - test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tensor=tensor[0], - mask=mask, - axis=axis, + input=input, + out_type=output_dtype, ) -# where +# size +# output_dtype not generated as tf only accepts tf dtypes @handle_frontend_test( - fn_tree="tensorflow.where", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=1, - min_value=0, - max_value=10, - min_num_dims=1, + fn_tree="tensorflow.size", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), max_num_dims=4 ), + # output_dtype=st.sampled_from(["int32", "int64"]), + test_with_out=st.just(False), ) -def test_tensorflow_where_no_xy( +def test_tensorflow_size( *, - dtype_and_input, + dtype_and_x, frontend, backend_fw, - fn_tree, test_flags, - on_device, + fn_tree, + on_device, # output_dtype ): - input_dtype, [condition] = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1801,89 +1686,65 @@ def test_tensorflow_where_no_xy( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - condition=condition, + input=x[0], + # out_type=output_dtype, ) -# where +# slice @handle_frontend_test( - fn_tree="tensorflow.where", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), - num_arrays=3, - min_value=0, - max_value=10, - min_num_dims=1, - ), - dim_remove_from_x=st.integers(), - dim_remove_from_y=st.integers(), + fn_tree="tensorflow.slice", + dtype_x_params=_slice_helper(), + test_with_out=st.just(False), ) -def test_tensorflow_where_with_xy( +def test_tensorflow_slice( *, - dtype_and_input, - dim_remove_from_x, - dim_remove_from_y, + dtype_x_params, frontend, backend_fw, - fn_tree, test_flags, + fn_tree, on_device, ): - input_dtype, [condition, x, y] = dtype_and_input - if input_dtype != ["bool", "bool", "bool"]: - return - for _ in range(min(len(x.shape) - 1, dim_remove_from_x)): - x = x[0] - for _ in range(min(len(y.shape) - 1, dim_remove_from_y)): - y = y[0] + dtype, x, begin, size = dtype_x_params helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype + 3 * ["int64"], frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - condition=condition, - x=x, - y=y, + input_=x[0], + begin=begin, + size=size, ) -# roll +# sort @handle_frontend_test( - fn_tree="tensorflow.roll", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - shift=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_tuple=True, - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_tuple=True, + fn_tree="tensorflow.sort", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + min_axis=-1, + max_axis=0, ), + descending=st.sampled_from(["ASCENDING", "DESCENDING"]), ) -def test_tensorflow_roll( +def test_tensorflow_sort( *, - dtype_and_values, - shift, - axis, + dtype_input_axis, + descending, on_device, fn_tree, - backend_fw, frontend, + backend_fw, test_flags, ): - input_dtype, value = dtype_and_values - if isinstance(shift, int) and isinstance(axis, tuple): - axis = 
axis[0] - if isinstance(shift, tuple) and isinstance(axis, tuple): - if len(shift) != len(axis): - mn = min(len(shift), len(axis)) - shift = shift[:mn] - axis = axis[:mn] + input_dtype, input, axis = dtype_input_axis helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -1891,9 +1752,9 @@ def test_tensorflow_roll( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=value[0], - shift=shift, + values=input[0], axis=axis, + direction=descending, ) @@ -1939,191 +1800,231 @@ def test_tensorflow_split( ) -# repeat @handle_frontend_test( - fn_tree="tensorflow.repeat", - dtypes_and_value_and_axis=helpers.dtype_values_axis( + fn_tree="tensorflow.squeeze", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - max_dim_size=10, - valid_axis=True, - force_int_axis=True, + shape=st.shared(helpers.get_shape(), key="value_shape"), ), - repeats=helpers.ints(min_value=1, max_value=5), + axis=_squeeze_helper(), ) -def test_tensorflow_repeat( +def test_tensorflow_squeeze_general( *, - dtypes_and_value_and_axis, - repeats, + dtype_value, + axis, on_device, + backend_fw, fn_tree, frontend, - backend_fw, test_flags, ): - input_dtypes, x, axis = dtypes_and_value_and_axis - repeats = repeats + dtype, xs = dtype_value helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - repeats=repeats, + input=xs[0], axis=axis, ) -# unstack +# stack @handle_frontend_test( - fn_tree="tensorflow.unstack", - dtypes_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=2, - max_dim_size=1, + fn_tree="tensorflow.stack", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays"), + shape=helpers.get_shape(min_num_dims=1), + shared_dtype=True, + valid_axis=True, + allow_neg_axes=True, + force_int_axis=True, ), - number_positional_args=st.just(1), - axis=st.integers(-1, 0), - test_with_out=st.just(False), ) -def test_tensorflow_unstack( - *, - dtypes_values, - axis, +def test_tensorflow_stack( + dtype_values_axis, on_device, fn_tree, - backend_fw, frontend, + backend_fw, test_flags, ): - x_dtype, x = dtypes_values - axis = axis + input_dtype, values, axis = dtype_values_axis helpers.test_frontend_function( - input_dtypes=x_dtype, - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - value=x[0], + values=values, axis=axis, ) -# reverse -@st.composite -def reverse_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=8, - ret_shape=True, - ) - ) - axis_dtype, axis = draw( - helpers.dtype_and_values( - available_dtypes=["int32", "int64"], - min_num_dims=1, - max_num_dims=1, - min_value=-(len(shape) - 1), - max_value=len(shape) - 1, - shape=(1,), +# strided_slice +@handle_frontend_test( + fn_tree="tensorflow.strided_slice", + dtype_x_params=_strided_slice_helper(), + test_with_out=st.just(False), +) +def test_tensorflow_strided_slice( + *, + dtype_x_params, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + dtype, x, begin, end, strides, masks = dtype_x_params + try: + 
helpers.test_frontend_function( + input_dtypes=dtype + 3 * ["int64"] + 5 * ["int32"], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input_=x[0], + begin=begin, + end=end, + strides=strides, + begin_mask=masks[0], + end_mask=masks[1], + ellipsis_mask=masks[2], + new_axis_mask=masks[3], + shrink_axis_mask=masks[4], ) + except tf_errors.InvalidArgumentError: + assume(False) + except Exception as e: + if ( + hasattr(e, "message") + and "only stride 1 allowed on non-range indexing" in e.message + ): + assume(False) + raise e + + +@handle_frontend_test(fn_tree="tensorflow.tile", all_arguments=_multiple_shape_helper()) +def test_tensorflow_tile( + *, + all_arguments, + test_flags, + frontend, + fn_tree, + on_device, + backend_fw, +): + input_dtype, input_matrix, dt_and_multiples = all_arguments + dt_mul, multiples = dt_and_multiples + helpers.test_frontend_function( + input_dtypes=input_dtype + dt_mul, + input=input_matrix[0], + multiples=multiples[0], + test_flags=test_flags, + backend_to_test=backend_fw, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, ) - return dtype, x, axis_dtype, axis @handle_frontend_test( - fn_tree="tensorflow.reverse", - dtype_x_axis=reverse_helper(), + fn_tree="tensorflow.transpose", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + perm=_get_perm_helper(), + conjugate=st.booleans(), + test_with_out=st.just(False), ) -def test_tensorflow_reverse( +def test_tensorflow_transpose( *, - dtype_x_axis, + dtype_and_x, + perm, + conjugate, frontend, backend_fw, - fn_tree, test_flags, + fn_tree, on_device, ): - dtype, x, axis_dtype, axis = dtype_x_axis + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype + axis_dtype, - test_flags=test_flags, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tensor=x[0], - axis=axis[0], + a=x[0], + perm=perm, + conjugate=conjugate, ) -# scan +# truncatediv @handle_frontend_test( - fn_tree="tensorflow.scan", - dtypes_values=helpers.dtype_and_values( - available_dtypes=["float32"], num_arrays=1, min_num_dims=2, max_dim_size=3 + fn_tree="tensorflow.truncatediv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-20, + max_value=20, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_scan( +def test_tensorflow_truncatediv( *, - dtypes_values, - on_device, - fn_tree, + dtype_and_x, + test_flags, frontend, backend_fw, - test_flags, + fn_tree, + on_device, ): - def _test_fn(a, x): - return a + x - - x_dtype, elems = dtypes_values + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtypes=input_dtype, + test_flags=test_flags, frontend=frontend, backend_to_test=backend_fw, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - fn=_test_fn, - elems=elems[0], + x=x[0], + y=x[1], ) +# Truncatemod @handle_frontend_test( - fn_tree="tensorflow.norm", - aliases=["tensorflow.norm"], - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=3, - max_num_dims=5, - min_dim_size=1, - max_dim_size=4, - min_axis=-3, - max_axis=2, + fn_tree="tensorflow.truncatemod", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + 
shared_dtype=True, ), - ord=st.sampled_from([1, 2, np.inf]), - keepdims=st.booleans(), + test_with_out=st.just(False), ) -def test_tensorflow_norm( +def test_tensorflow_truncatemod( *, - dtype_values_axis, - ord, - keepdims, - backend_fw, + dtype_and_x, frontend, + backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, x, axis = dtype_values_axis + input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -2131,10 +2032,8 @@ def test_tensorflow_norm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tensor=x[0], - ord=ord, - axis=axis, - keepdims=keepdims, + x=x[0], + y=x[1], ) @@ -2209,6 +2108,143 @@ def test_tensorflow_unique_with_counts( ) +@handle_frontend_test( + fn_tree="tensorflow.unravel_index", + indices=helpers.array_values( + dtype=helpers.get_dtypes("integer"), shape=(1, 2), min_value=0, max_value=49 + ), + dims=helpers.array_values( + dtype=helpers.get_dtypes("integer"), shape=(1, 2), min_value=50 + ), +) +def test_tensorflow_unravel_index( + *, indices, dims, frontend, test_flags, fn_tree, on_device, backend_fw +): + helpers.test_frontend_function( + input_dtypes=["int32"], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + indices=indices[0], + dims=dims[0], + ) + + +# unstack +@handle_frontend_test( + fn_tree="tensorflow.unstack", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=2, + max_dim_size=1, + ), + number_positional_args=st.just(1), + axis=st.integers(-1, 0), + test_with_out=st.just(False), +) +def test_tensorflow_unstack( + *, + dtypes_values, + axis, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + x_dtype, x = dtypes_values + axis = axis + helpers.test_frontend_function( + input_dtypes=x_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + value=x[0], + axis=axis, + ) + + +# where +@handle_frontend_test( + fn_tree="tensorflow.where", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=1, + min_value=0, + max_value=10, + min_num_dims=1, + ), +) +def test_tensorflow_where_no_xy( + *, + dtype_and_input, + frontend, + backend_fw, + fn_tree, + test_flags, + on_device, +): + input_dtype, [condition] = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + condition=condition, + ) + + +# where +@handle_frontend_test( + fn_tree="tensorflow.where", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("bool"), + num_arrays=3, + min_value=0, + max_value=10, + min_num_dims=1, + ), + dim_remove_from_x=st.integers(), + dim_remove_from_y=st.integers(), +) +def test_tensorflow_where_with_xy( + *, + dtype_and_input, + dim_remove_from_x, + dim_remove_from_y, + frontend, + backend_fw, + fn_tree, + test_flags, + on_device, +): + input_dtype, [condition, x, y] = dtype_and_input + if input_dtype != ["bool", "bool", "bool"]: + return + for _ in range(min(len(x.shape) - 1, dim_remove_from_x)): + x = x[0] + for _ in range(min(len(y.shape) - 1, dim_remove_from_y)): + y = y[0] + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + 
backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + condition=condition, + x=x, + y=y, + ) + + @handle_frontend_test( fn_tree="tensorflow.while_loop", dtype_and_x=helpers.dtype_and_values( @@ -2263,53 +2299,51 @@ def _test_body_fn(x): ) -# truncatediv +# zeros @handle_frontend_test( - fn_tree="tensorflow.truncatediv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-20, - max_value=20, - shared_dtype=True, + fn_tree="tensorflow.zeros", + input=helpers.get_shape( + allow_none=False, + min_num_dims=0, + max_num_dims=10, + min_dim_size=0, + max_dim_size=10, ), - test_with_out=st.just(False), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_tensorflow_truncatediv( +def test_tensorflow_zeros( *, - dtype_and_x, - test_flags, + input, + dtype, frontend, backend_fw, + test_flags, fn_tree, on_device, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, - test_flags=test_flags, + shape=input, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], ) -# Truncatemod +# zeros_like @handle_frontend_test( - fn_tree="tensorflow.truncatemod", + fn_tree="tensorflow.zeros_like", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("numeric") ), + dtype=helpers.get_dtypes("numeric", full=False), test_with_out=st.just(False), ) -def test_tensorflow_truncatemod( - *, +def test_tensorflow_zeros_like( dtype_and_x, + dtype, frontend, backend_fw, test_flags, @@ -2317,8 +2351,6 @@ def test_tensorflow_truncatemod( on_device, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, @@ -2326,63 +2358,6 @@ def test_tensorflow_truncatemod( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], - ) - - -@handle_frontend_test( - fn_tree="tensorflow.unravel_index", - indices=helpers.array_values( - dtype=helpers.get_dtypes("integer"), shape=(1, 2), min_value=0, max_value=49 - ), - dims=helpers.array_values( - dtype=helpers.get_dtypes("integer"), shape=(1, 2), min_value=50 - ), -) -def test_tensorflow_unravel_index( - *, indices, dims, frontend, test_flags, fn_tree, on_device, backend_fw -): - helpers.test_frontend_function( - input_dtypes=["int32"], - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - indices=indices[0], - dims=dims[0], + input=x[0], + dtype=dtype[0], ) - - -# @handle_frontend_test( -# fn_tree="tensorflow.zeros_initializer", -# shape=helpers.get_shape( -# allow_none=False, -# min_num_dims=1, -# max_num_dims=5, -# min_dim_size=1, -# max_dim_size=10, -# ), -# dtype=helpers.get_dtypes("valid", full=False), -# test_with_out=st.just(False), -# ) -# def test_tensorflow_zeros_initializer( -# shape, -# dtype, -# frontend, -# backend_fw, -# test_flags, -# fn_tree, -# on_device, -# ): -# helpers.test_frontend_function( -# input_dtypes=dtype, -# frontend=frontend, -# backend_to_test=backend_fw, -# test_flags=test_flags, -# fn_tree=fn_tree, -# on_device=on_device, -# shape=shape, -# dtype=dtype[0], -# ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_image/test_cropping.py 
b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_image/test_cropping.py index fea35cd71a2f0..9d7cbfb9d856c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_image/test_cropping.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_image/test_cropping.py @@ -9,6 +9,10 @@ ) +# --- Helpers --- # +# --------------- # + + @st.composite def _extract_patches_helper(draw): sizes = [ @@ -48,6 +52,10 @@ def _extract_patches_helper(draw): return dtype_x, sizes, strides, rates, padding +# --- Main --- # +# ------------ # + + # extract_patches @handle_frontend_test( fn_tree="tensorflow.image.extract_patches", diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py index a4f3f97eec93e..da5def824a5c7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_activations.py @@ -1,523 +1,523 @@ -from hypothesis import strategies as st -from types import SimpleNamespace - -try: - import tensorflow as tf -except ImportError: - tf = SimpleNamespace() -import sys - -# local -import ivy -import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test - - -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.hard_sigmoid", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_tensorflow_hard_sigmoid( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], - ) - - -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.linear", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_tensorflow_linear( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# sigmoid -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.sigmoid", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_tensorflow_sigmoid( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], - ) - - -# tanh -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.tanh", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_tensorflow_tanh( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, 
- on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], - ) - - -# softmax -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.softmax", - dtype_x_and_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, - ), - test_with_out=st.just(False), -) -def test_tensorflow_softmax( - *, - dtype_x_and_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_x_and_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - axis=axis, - ) - - -# gelu -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.gelu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, - ), - approximate=st.booleans(), - test_with_out=st.just(False), -) -def test_tensorflow_gelu( - *, dtype_and_x, approximate, on_device, fn_tree, frontend, test_flags, backend_fw -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], - approximate=approximate, - ) - - -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.elu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), - test_with_out=st.just(False), -) -def test_tensorflow_relu( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# softplus -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.softplus", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_tensorflow_softplus( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], - ) - - -# softsign -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.softsign", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), - test_with_out=st.just(False), -) -def test_tensorflow_softsign( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], - ) - - -# swish -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.swish", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), -) -def test_tensorflow_swish( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - 
input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], - ) - - -# elu -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.elu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-3, - max_value=3, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - alpha=st.one_of( - helpers.floats( - min_value=-3, - max_value=3, - ) - ), - test_with_out=st.just(False), -) -def test_tensorflow_elu( - *, - dtype_and_x, - alpha, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - alpha=alpha, - ) - - -# selu -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.selu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-3, - max_value=3, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - test_with_out=st.just(False), -) -def test_tensorflow_selu( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-03, - atol=1e-03, - x=x[0], - ) - - -# Helper function for deserialize. -def simple_test_two_function( - *, - fn_name: str, - x, - frontend: str, - fn_str: str, - dtype_data: str, - rtol_: float = None, - atol_: float = 1e-06, - ivy_submodules: list = [], - framework_submodules: list = [], -): - ivy.set_backend(frontend) - fn_ivy = ivy.functional.frontends.__dict__[frontend] - for ivy_submodule in ivy_submodules: - fn_ivy = fn_ivy.__dict__[ivy_submodule] - fn_ivy = fn_ivy.__dict__[fn_str] - - fn_framework = tf - for framework_submodule in framework_submodules: - fn_framework = fn_framework.__dict__[framework_submodule] - fn_framework = fn_framework.__dict__[fn_str] - x = ivy.array(x).to_native() - - ret_ivy = fn_ivy(fn_name)(x) - ret = fn_framework(fn_name)(x) - - ret_ivy = ivy.array(ret_ivy, dtype=dtype_data) - ret = ivy.array(ret, dtype=dtype_data) - - ret_np_flat = helpers.flatten_and_to_np(ret=ret) - frontend_ret_np_flat = helpers.flatten_and_to_np(ret=ret_ivy) - - helpers.value_test( - ret_np_flat=ret_np_flat, - ret_np_from_gt_flat=frontend_ret_np_flat, - rtol=rtol_, - atol=atol_, - ground_truth_backend=frontend, - ) - ivy.previous_backend() - - -# Helper function for deserialize. 
-def get_callable_functions( - module_name: str, -): - module = sys.modules[module_name] - fn_list = list() - for fn_name in dir(module): - obj = getattr(module, fn_name) - if callable(obj): - fn_list.append(fn_name) - return fn_list - - -# deserialize -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.deserialize", - fn_name=st.sampled_from(get_callable_functions("keras.activations")).filter( - lambda x: not x[0].isupper() - and x - not in [ - "deserialize", - "get", - "keras_export", - "serialize", - "deserialize_keras_object", - "serialize_keras_object", - "get_globals", - ] - ), - dtype_and_data=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - ), -) -def test_tensorflow_deserialize( - *, - dtype_and_data, - fn_name, - fn_tree, - frontend, -): - dtype_data, data = dtype_and_data - simple_test_two_function( - fn_name=fn_name, - x=data[0], - frontend=frontend, - fn_str="deserialize", - dtype_data=dtype_data[0], - rtol_=1e-01, - atol_=1e-01, - ivy_submodules=["keras", "activations"], - framework_submodules=["keras", "activations"], - ) - - -@handle_frontend_test( - fn_tree="tensorflow.keras.activations.get", - fn_name=st.sampled_from(get_callable_functions("keras.activations")).filter( - lambda x: not x[0].isupper() - and x - not in [ - "deserialize", - "get", - "keras_export", - "serialize", - "deserialize_keras_object", - "serialize_keras_object", - "get_globals", - ] - ), - dtype_and_data=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - ), -) -def test_tensorflow_get(fn_name, dtype_and_data): - dtype_data, data = dtype_and_data - simple_test_two_function( - fn_name=fn_name, - x=data[0], - frontend="tensorflow", - fn_str="get", - dtype_data=dtype_data[0], - rtol_=1e-01, - atol_=1e-01, - ivy_submodules=["keras", "activations"], - framework_submodules=["keras", "activations"], - ) +from hypothesis import strategies as st +from types import SimpleNamespace +import sys + +# local +import ivy +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + +try: + import tensorflow as tf +except ImportError: + tf = SimpleNamespace() + + +# Helper function for deserialize. +def get_callable_functions( + module_name: str, +): + module = sys.modules[module_name] + fn_list = list() + for fn_name in dir(module): + obj = getattr(module, fn_name) + if callable(obj): + fn_list.append(fn_name) + return fn_list + + +# Helper function for deserialize. 
+def simple_test_two_function( + *, + fn_name: str, + x, + frontend: str, + fn_str: str, + dtype_data: str, + rtol_: float = None, + atol_: float = 1e-06, + ivy_submodules: list = [], + framework_submodules: list = [], +): + ivy.set_backend(frontend) + fn_ivy = ivy.functional.frontends.__dict__[frontend] + for ivy_submodule in ivy_submodules: + fn_ivy = fn_ivy.__dict__[ivy_submodule] + fn_ivy = fn_ivy.__dict__[fn_str] + + fn_framework = tf + for framework_submodule in framework_submodules: + fn_framework = fn_framework.__dict__[framework_submodule] + fn_framework = fn_framework.__dict__[fn_str] + x = ivy.array(x).to_native() + + ret_ivy = fn_ivy(fn_name)(x) + ret = fn_framework(fn_name)(x) + + ret_ivy = ivy.array(ret_ivy, dtype=dtype_data) + ret = ivy.array(ret, dtype=dtype_data) + + ret_np_flat = helpers.flatten_and_to_np(ret=ret) + frontend_ret_np_flat = helpers.flatten_and_to_np(ret=ret_ivy) + + helpers.value_test( + ret_np_flat=ret_np_flat, + ret_np_from_gt_flat=frontend_ret_np_flat, + rtol=rtol_, + atol=atol_, + ground_truth_backend=frontend, + ) + ivy.previous_backend() + + +# deserialize +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.deserialize", + fn_name=st.sampled_from(get_callable_functions("keras.activations")).filter( + lambda x: not x[0].isupper() + and x + not in [ + "deserialize", + "get", + "keras_export", + "serialize", + "deserialize_keras_object", + "serialize_keras_object", + "get_globals", + ] + ), + dtype_and_data=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + ), +) +def test_tensorflow_deserialize( + *, + dtype_and_data, + fn_name, + fn_tree, + frontend, +): + dtype_data, data = dtype_and_data + simple_test_two_function( + fn_name=fn_name, + x=data[0], + frontend=frontend, + fn_str="deserialize", + dtype_data=dtype_data[0], + rtol_=1e-01, + atol_=1e-01, + ivy_submodules=["keras", "activations"], + framework_submodules=["keras", "activations"], + ) + + +# elu +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.elu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-3, + max_value=3, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + alpha=st.one_of( + helpers.floats( + min_value=-3, + max_value=3, + ) + ), + test_with_out=st.just(False), +) +def test_tensorflow_elu( + *, + dtype_and_x, + alpha, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + alpha=alpha, + ) + + +# gelu +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.gelu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, + ), + approximate=st.booleans(), + test_with_out=st.just(False), +) +def test_tensorflow_gelu( + *, dtype_and_x, approximate, on_device, fn_tree, frontend, test_flags, backend_fw +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + x=x[0], + approximate=approximate, + ) + + +@handle_frontend_test( + 
fn_tree="tensorflow.keras.activations.get", + fn_name=st.sampled_from(get_callable_functions("keras.activations")).filter( + lambda x: not x[0].isupper() + and x + not in [ + "deserialize", + "get", + "keras_export", + "serialize", + "deserialize_keras_object", + "serialize_keras_object", + "get_globals", + ] + ), + dtype_and_data=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + ), +) +def test_tensorflow_get(fn_name, dtype_and_data): + dtype_data, data = dtype_and_data + simple_test_two_function( + fn_name=fn_name, + x=data[0], + frontend="tensorflow", + fn_str="get", + dtype_data=dtype_data[0], + rtol_=1e-01, + atol_=1e-01, + ivy_submodules=["keras", "activations"], + framework_submodules=["keras", "activations"], + ) + + +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.hard_sigmoid", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_hard_sigmoid( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + x=x[0], + ) + + +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.linear", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_linear( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.elu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") + ), + test_with_out=st.just(False), +) +def test_tensorflow_relu( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# selu +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.selu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-3, + max_value=3, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + test_with_out=st.just(False), +) +def test_tensorflow_selu( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-03, + atol=1e-03, + x=x[0], + ) + + +# sigmoid +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.sigmoid", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_sigmoid( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = 
dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + x=x[0], + ) + + +# softmax +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.softmax", + dtype_x_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_axes_size=1, + force_int_axis=True, + valid_axis=True, + ), + test_with_out=st.just(False), +) +def test_tensorflow_softmax( + *, + dtype_x_and_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_and_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + ) + + +# softplus +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.softplus", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_softplus( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + x=x[0], + ) + + +# softsign +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.softsign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), + test_with_out=st.just(False), +) +def test_tensorflow_softsign( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + x=x[0], + ) + + +# swish +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.swish", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_swish( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + x=x[0], + ) + + +# tanh +@handle_frontend_test( + fn_tree="tensorflow.keras.activations.tanh", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_tanh( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + x=x[0], + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_metrics.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_metrics.py index eefce16b565cc..c9bbebf454108 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_metrics.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_keras/test_metrics.py @@ -8,6 +8,97 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _binary_focal_args(draw): + shape = st.tuples(st.integers(1, 10), st.integers(1, 10), st.integers(1, 10)) + common_float_dtype = helpers.get_dtypes("float", full=False) + + from_logits = draw( + helpers.dtype_and_values( + available_dtypes=draw(helpers.get_dtypes("bool")), shape=(1,) + ) + ) + + if from_logits[0]: + min_value = -10.0 + max_value = 10.0 + else: + min_value = 0.0 + max_value = 1.0 + + dtype_y_true = draw( + helpers.dtype_and_values( + available_dtypes=draw(helpers.get_dtypes("integer")), + min_value=0, + max_value=2, + exclude_max=True, + shape=draw(st.shared(shape, key="shape")), + ) + ) + dtype_y_pred = draw( + helpers.dtype_and_values( + dtype=draw(st.shared(common_float_dtype, key="float_dtype")), + min_value=min_value, + max_value=max_value, + shape=draw(st.shared(shape, key="shape")), + ) + ) + dtype_label_smoothing = draw( + helpers.dtype_and_values( + dtype=draw(st.shared(common_float_dtype, key="float_dtype")), + min_value=0.0, + max_value=1.0, + exclude_min=False, + exclude_max=False, + shape=(1,), + ) + ) + dtype_gamma = draw( + helpers.dtype_and_values( + dtype=draw(st.shared(common_float_dtype, key="float_dtype")), + min_value=0.0, + max_value=10.0, + shape=(1,), + ) + ) + # attr = Tidx:type, default = DT_INT32, allowed = [DT_INT32, DT_INT64] > [Op:Mean] + dtype_axis = draw( + helpers.dtype_and_values( + available_dtypes=[ivy.int32, ivy.int64], + min_value=-len(draw(st.shared(shape, key="shape"))), + max_value=len(draw(st.shared(shape, key="shape"))), + shape=(1,), + ) + ) + dtype_true, y_true = dtype_y_true + dtype_pred, y_pred = dtype_y_pred + dtype_gamma, gamma = dtype_gamma + dtype_from_logits, from_logits = from_logits + dtype_label_smoothing, label_smoothing = dtype_label_smoothing + dtype_axis, axis = dtype_axis + dtypes = [ + dtype_true[0], + dtype_pred[0], + dtype_gamma[0], + dtype_from_logits[0], + dtype_label_smoothing[0], + dtype_axis[0], + ] + values = [ + y_true[0], + y_pred[0], + gamma[0], + from_logits[0], + label_smoothing[0], + axis[0], + ] + return dtypes, values + + @st.composite def _dtype_pred_and_labels( draw, @@ -111,6 +202,10 @@ def _dtype_pred_and_labels( return dtype, pred, labels +# --- Main --- # +# ------------ # + + # binary_accuracy @handle_frontend_test( fn_tree="tensorflow.keras.metrics.binary_accuracy", @@ -147,124 +242,113 @@ def test_tensorflow_binary_accuracy( ) -# sparse_categorical_crossentropy +# binary_crossentropy @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.sparse_categorical_crossentropy", - y_true=st.lists(st.integers(min_value=0, max_value=4), min_size=1, max_size=1), - dtype_y_pred=helpers.dtype_and_values( + fn_tree="tensorflow.keras.metrics.binary_crossentropy", + dtype_pred_and_labels=_dtype_pred_and_labels( available_dtypes=helpers.get_dtypes("float"), - shape=(5,), - min_value=-10, - max_value=10, + min_pred_val=1e-6, + max_label_val=5, + min_dim_size=1, + min_num_dims=1, ), from_logits=st.booleans(), + label_smoothing=helpers.floats(min_value=0.0, max_value=1.0), test_with_out=st.just(False), ) -def test_tensorflow_sparse_categorical_crossentropy( +def test_tensorflow_binary_crossentropy( *, - y_true, - dtype_y_pred, + dtype_pred_and_labels, from_logits, + label_smoothing, frontend, test_flags, fn_tree, on_device, backend_fw, ): - y_true = ivy.array(y_true, dtype=ivy.int32) - 
dtype, y_pred = dtype_y_pred - y_pred = y_pred[0] - # Perform softmax on prediction if it's not a probability distribution. - if not from_logits: - y_pred = ivy.exp(y_pred) / ivy.sum(ivy.exp(y_pred)) - + input_dtype, y_pred, y_true = dtype_pred_and_labels helpers.test_frontend_function( - input_dtypes=[ivy.int32] + dtype, + input_dtypes=input_dtype[::-1], backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, + rtol=1e-1, + atol=1e-1, y_true=y_true, - y_pred=y_pred[0], + y_pred=y_pred, from_logits=from_logits, + label_smoothing=label_smoothing, ) -# log_cosh +# binary_focal_crossentropy @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.log_cosh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=False, - min_num_dims=1, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - ), + fn_tree="tensorflow.keras.metrics.binary_focal_crossentropy", + binary_focal_args=_binary_focal_args(), test_with_out=st.just(False), ) -def test_tensorflow_log_cosh( +def test_tensorflow_binary_focal_crossentropy( *, - dtype_and_x, + binary_focal_args, frontend, test_flags, fn_tree, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + dtypes, values = binary_focal_args helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y_true=x[0], - y_pred=x[1], + y_true=values[0], + y_pred=values[1], + gamma=values[2], + from_logits=values[3], + label_smoothing=values[4], + axis=values[5], ) -# binary_crossentropy +# categorical_accuracy @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.binary_crossentropy", - dtype_pred_and_labels=_dtype_pred_and_labels( - available_dtypes=helpers.get_dtypes("float"), - min_pred_val=1e-6, - max_label_val=5, - min_dim_size=1, - min_num_dims=1, + fn_tree="tensorflow.keras.metrics.categorical_accuracy", + dtype_and_y=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + ), ), - from_logits=st.booleans(), - label_smoothing=helpers.floats(min_value=0.0, max_value=1.0), test_with_out=st.just(False), ) -def test_tensorflow_binary_crossentropy( +def test_tensorflow_categorical_accuracy( *, - dtype_pred_and_labels, - from_logits, - label_smoothing, + dtype_and_y, frontend, test_flags, fn_tree, on_device, backend_fw, ): - input_dtype, y_pred, y_true = dtype_pred_and_labels + input_dtype, y = dtype_and_y helpers.test_frontend_function( - input_dtypes=input_dtype[::-1], + input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-1, - atol=1e-1, - y_true=y_true, - y_pred=y_pred, - from_logits=from_logits, - label_smoothing=label_smoothing, + y_true=y[0], + y_pred=y[1], ) @@ -317,142 +401,57 @@ def test_tensorflow_categorical_crossentropy( ) -@st.composite -def _binary_focal_args(draw): - shape = st.tuples(st.integers(1, 10), st.integers(1, 10), st.integers(1, 10)) - common_float_dtype = helpers.get_dtypes("float", full=False) - - from_logits = draw( - helpers.dtype_and_values( - available_dtypes=draw(helpers.get_dtypes("bool")), shape=(1,) - ) - ) - - if from_logits[0]: - min_value = -10.0 - max_value = 10.0 - else: - min_value = 0.0 - max_value = 1.0 - - 
dtype_y_true = draw( - helpers.dtype_and_values( - available_dtypes=draw(helpers.get_dtypes("integer")), - min_value=0, - max_value=2, - exclude_max=True, - shape=draw(st.shared(shape, key="shape")), - ) - ) - dtype_y_pred = draw( - helpers.dtype_and_values( - dtype=draw(st.shared(common_float_dtype, key="float_dtype")), - min_value=min_value, - max_value=max_value, - shape=draw(st.shared(shape, key="shape")), - ) - ) - dtype_label_smoothing = draw( - helpers.dtype_and_values( - dtype=draw(st.shared(common_float_dtype, key="float_dtype")), - min_value=0.0, - max_value=1.0, - exclude_min=False, - exclude_max=False, - shape=(1,), - ) - ) - dtype_gamma = draw( - helpers.dtype_and_values( - dtype=draw(st.shared(common_float_dtype, key="float_dtype")), - min_value=0.0, - max_value=10.0, - shape=(1,), - ) - ) - # attr = Tidx:type, default = DT_INT32, allowed = [DT_INT32, DT_INT64] > [Op:Mean] - dtype_axis = draw( - helpers.dtype_and_values( - available_dtypes=[ivy.int32, ivy.int64], - min_value=-len(draw(st.shared(shape, key="shape"))), - max_value=len(draw(st.shared(shape, key="shape"))), - shape=(1,), - ) - ) - dtype_true, y_true = dtype_y_true - dtype_pred, y_pred = dtype_y_pred - dtype_gamma, gamma = dtype_gamma - dtype_from_logits, from_logits = from_logits - dtype_label_smoothing, label_smoothing = dtype_label_smoothing - dtype_axis, axis = dtype_axis - dtypes = [ - dtype_true[0], - dtype_pred[0], - dtype_gamma[0], - dtype_from_logits[0], - dtype_label_smoothing[0], - dtype_axis[0], - ] - values = [ - y_true[0], - y_pred[0], - gamma[0], - from_logits[0], - label_smoothing[0], - axis[0], - ] - return dtypes, values - - -# binary_focal_crossentropy +# Cosine Similarity @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.binary_focal_crossentropy", - binary_focal_args=_binary_focal_args(), + fn_tree="tensorflow.keras.metrics.cosine_similarity", + d_type=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), shared_dtype=True, num_arrays=2 + ), + y_true=helpers.array_values( + dtype=ivy.int32, shape=(1, 5), min_value=1, max_value=5 + ), + y_pred=helpers.array_values( + dtype=ivy.int32, shape=(1, 5), min_value=5, max_value=10 + ), test_with_out=st.just(False), ) -def test_tensorflow_binary_focal_crossentropy( +def test_tensorflow_cosine_similarity( *, - binary_focal_args, + d_type, + y_true, + y_pred, frontend, test_flags, fn_tree, on_device, backend_fw, ): - dtypes, values = binary_focal_args helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=d_type[0], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y_true=values[0], - y_pred=values[1], - gamma=values[2], - from_logits=values[3], - label_smoothing=values[4], - axis=values[5], + y_true=y_true, + y_pred=y_pred, ) -# sparse_top_k_categorical_accuracy +# hinge @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.sparse_top_k_categorical_accuracy", + fn_tree="tensorflow.keras.metrics.hinge", dtype_pred_and_labels=_dtype_pred_and_labels( available_dtypes=helpers.get_dtypes("float"), - min_pred_val=1e-6, - max_label_val=5, - sparse_label=True, - shape=(5, 10), + label_set=[-1, 1], + min_num_dims=2, + min_dim_size=2, ), - k=st.integers(min_value=3, max_value=10), test_with_out=st.just(False), ) -def test_tensorflow_sparse_top_k_categorical_accuracy( +def test_tensorflow_hinge( *, dtype_pred_and_labels, - k, frontend, test_flags, fn_tree, @@ -467,36 +466,32 @@ def test_tensorflow_sparse_top_k_categorical_accuracy( test_flags=test_flags, 
fn_tree=fn_tree, on_device=on_device, - y_true=y_true, y_pred=y_pred, - k=k, + y_true=y_true, ) -# categorical_accuracy +# kl_divergence @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.categorical_accuracy", - dtype_and_y=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + fn_tree="tensorflow.keras.metrics.kl_divergence", + aliases=["tensorflow.keras.metrics.kld"], + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - ), + min_num_dims=1, ), - test_with_out=st.just(False), ) -def test_tensorflow_categorical_accuracy( +def test_tensorflow_kl_divergence( *, - dtype_and_y, + dtype_and_x, frontend, test_flags, fn_tree, on_device, backend_fw, ): - input_dtype, y = dtype_and_y + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -504,23 +499,25 @@ def test_tensorflow_categorical_accuracy( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y_true=y[0], - y_pred=y[1], + y_true=x[0], + y_pred=x[1], ) -# kl_divergence +# log_cosh @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.kl_divergence", - aliases=["tensorflow.keras.metrics.kld"], + fn_tree="tensorflow.keras.metrics.log_cosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - shared_dtype=True, + shared_dtype=False, min_num_dims=1, + large_abs_safety_factor=2, + small_abs_safety_factor=2, ), + test_with_out=st.just(False), ) -def test_tensorflow_kl_divergence( +def test_tensorflow_log_cosh( *, dtype_and_x, frontend, @@ -578,18 +575,20 @@ def test_tensorflow_mean_absolute_error( ) -# poisson +# mean_absolute_percentage_error @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.poisson", + fn_tree="tensorflow.keras.metrics.mean_absolute_percentage_error", + aliases=["tensorflow.keras.metrics.mape"], dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, min_num_dims=1, + large_abs_safety_factor=2, + small_abs_safety_factor=2, ), - test_with_out=st.just(False), ) -def test_tensorflow_poisson( +def test_tensorflow_mean_absolute_percentage_error( *, dtype_and_x, frontend, @@ -646,20 +645,19 @@ def test_tensorflow_mean_squared_error( ) -# mean_absolute_percentage_error +# mean_squared_logarithmic_error @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.mean_absolute_percentage_error", - aliases=["tensorflow.keras.metrics.mape"], + fn_tree="tensorflow.keras.metrics.mean_squared_logarithmic_error", + aliases=["tensorflow.keras.metrics.msle"], dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - shared_dtype=True, min_num_dims=1, - large_abs_safety_factor=2, - small_abs_safety_factor=2, + shared_dtype=True, ), + test_with_out=st.just(False), ) -def test_tensorflow_mean_absolute_percentage_error( +def test_tensorflow_metrics_mean_squared_logarithmic_error( *, dtype_and_x, frontend, @@ -681,27 +679,27 @@ def test_tensorflow_mean_absolute_percentage_error( ) -# hinge +# poisson @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.hinge", - dtype_pred_and_labels=_dtype_pred_and_labels( + fn_tree="tensorflow.keras.metrics.poisson", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - label_set=[-1, 1], - min_num_dims=2, - min_dim_size=2, + num_arrays=2, + shared_dtype=True, + min_num_dims=1, 
), test_with_out=st.just(False), ) -def test_tensorflow_hinge( +def test_tensorflow_poisson( *, - dtype_pred_and_labels, + dtype_and_x, frontend, test_flags, fn_tree, on_device, backend_fw, ): - input_dtype, y_pred, y_true = dtype_pred_and_labels + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -709,66 +707,79 @@ def test_tensorflow_hinge( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y_pred=y_pred, - y_true=y_true, + y_true=x[0], + y_pred=x[1], ) -# squared_hinge +# sparse_categorical_crossentropy @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.squared_hinge", - dtype_pred_and_labels=_dtype_pred_and_labels( + fn_tree="tensorflow.keras.metrics.sparse_categorical_crossentropy", + y_true=st.lists(st.integers(min_value=0, max_value=4), min_size=1, max_size=1), + dtype_y_pred=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - label_set=[-1, 1], - min_num_dims=2, - min_dim_size=2, + shape=(5,), + min_value=-10, + max_value=10, ), + from_logits=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_squared_hinge( +def test_tensorflow_sparse_categorical_crossentropy( *, - dtype_pred_and_labels, + y_true, + dtype_y_pred, + from_logits, frontend, test_flags, fn_tree, on_device, backend_fw, ): - input_dtype, y_pred, y_true = dtype_pred_and_labels + y_true = ivy.array(y_true, dtype=ivy.int32) + dtype, y_pred = dtype_y_pred + y_pred = y_pred[0] + # Perform softmax on prediction if it's not a probability distribution. + if not from_logits: + y_pred = ivy.exp(y_pred) / ivy.sum(ivy.exp(y_pred)) + helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[ivy.int32] + dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y_pred=y_pred, y_true=y_true, + y_pred=y_pred[0], + from_logits=from_logits, ) -# mean_squared_logarithmic_error +# sparse_top_k_categorical_accuracy @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.mean_squared_logarithmic_error", - aliases=["tensorflow.keras.metrics.msle"], - dtype_and_x=helpers.dtype_and_values( + fn_tree="tensorflow.keras.metrics.sparse_top_k_categorical_accuracy", + dtype_pred_and_labels=_dtype_pred_and_labels( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - shared_dtype=True, + min_pred_val=1e-6, + max_label_val=5, + sparse_label=True, + shape=(5, 10), ), + k=st.integers(min_value=3, max_value=10), test_with_out=st.just(False), ) -def test_tensorflow_metrics_mean_squared_logarithmic_error( +def test_tensorflow_sparse_top_k_categorical_accuracy( *, - dtype_and_x, + dtype_pred_and_labels, + k, frontend, test_flags, fn_tree, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, y_pred, y_true = dtype_pred_and_labels helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -776,43 +787,40 @@ def test_tensorflow_metrics_mean_squared_logarithmic_error( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y_true=x[0], - y_pred=x[1], + y_true=y_true, + y_pred=y_pred, + k=k, ) -# Cosine Similarity +# squared_hinge @handle_frontend_test( - fn_tree="tensorflow.keras.metrics.cosine_similarity", - d_type=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), shared_dtype=True, num_arrays=2 - ), - y_true=helpers.array_values( - dtype=ivy.int32, shape=(1, 5), min_value=1, max_value=5 - ), - y_pred=helpers.array_values( - dtype=ivy.int32, shape=(1, 
5), min_value=5, max_value=10 + fn_tree="tensorflow.keras.metrics.squared_hinge", + dtype_pred_and_labels=_dtype_pred_and_labels( + available_dtypes=helpers.get_dtypes("float"), + label_set=[-1, 1], + min_num_dims=2, + min_dim_size=2, ), test_with_out=st.just(False), ) -def test_tensorflow_cosine_similarity( +def test_tensorflow_squared_hinge( *, - d_type, - y_true, - y_pred, + dtype_pred_and_labels, frontend, test_flags, fn_tree, on_device, backend_fw, ): + input_dtype, y_pred, y_true = dtype_pred_and_labels helpers.test_frontend_function( - input_dtypes=d_type[0], + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y_true=y_true, y_pred=y_pred, + y_true=y_true, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py index 5df13046959be..059054d9d8982 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py @@ -16,6 +16,36 @@ from ivy_tests.test_ivy.test_functional.test_core.test_linalg import _matrix_rank_helper +# --- Helpers --- # +# --------------- # + + +# cholesky_solve +@st.composite +def _get_cholesky_matrix(draw): + # batch_shape, random_size, shared + input_dtype = draw( + st.shared( + st.sampled_from(draw(helpers.get_dtypes("float"))), + key="shared_dtype", + ) + ) + shared_size = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size") + ) + gen = draw( + helpers.array_values( + dtype=input_dtype, + shape=tuple([shared_size, shared_size]), + min_value=2, + max_value=5, + ).filter(lambda x: np.linalg.cond(x.tolist()) < 1 / sys.float_info.epsilon) + ) + spd = np.matmul(gen.T, gen) + np.identity(gen.shape[0]) + spd_chol = np.linalg.cholesky(spd) + return input_dtype, spd_chol + + @st.composite def _get_dtype_and_matrix(draw): arbitrary_dims = draw(helpers.get_shape(max_dim_size=5)) @@ -31,198 +61,52 @@ def _get_dtype_and_matrix(draw): ) -@handle_frontend_test( - fn_tree="tensorflow.linalg.det", - dtype_and_input=_get_dtype_and_matrix(), - test_with_out=st.just(False), -) -def test_tensorflow_det( - *, - dtype_and_input, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - ) - - -@handle_frontend_test( - fn_tree="tensorflow.linalg.eigh", - dtype_and_input=_get_dtype_and_matrix(), - test_with_out=st.just(False), -) -def test_tensorflow_eigh( - *, - dtype_and_input, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_input - assume(matrix_is_stable(x[0])) - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensor=x[0], - ) - - -@handle_frontend_test( - fn_tree="tensorflow.linalg.eigvals", - dtype_and_input=_get_dtype_and_matrix(), - test_with_out=st.just(False), -) -def test_tensorflow_eigvals( - *, - dtype_and_input, - frontend, - test_flags, - fn_tree, - on_device, - backend_fw, -): - input_dtype, x = dtype_and_input - assume(matrix_is_stable(x[0])) - if x[0].dtype == ivy.float32: - x[0] = x[0].astype("float64") - input_dtype = [ivy.float64] - ret, frontend_ret = 
helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensor=x[0], - test_values=False, - ) - - ret = ivy.to_numpy(ret) - ret = ret.round(6) - ret = np.sort(ret) - frontend_ret = frontend_ret[0].numpy() - frontend_ret = frontend_ret.round(6) - frontend_ret = np.sort(frontend_ret) - - assert_all_close( - ret_np=ret, - ret_from_gt_np=frontend_ret, - rtol=1e-06, - atol=1e-06, - ground_truth_backend=frontend, - ) - - -@handle_frontend_test( - fn_tree="tensorflow.linalg.eigvalsh", - dtype_and_input=_get_dtype_and_matrix(), - test_with_out=st.just(False), -) -def test_tensorflow_eigvalsh( - *, - dtype_and_input, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_and_input - assume(matrix_is_stable(x[0])) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensor=x[0], +@st.composite +def _get_dtype_and_matrix_and_num(draw): + arbitrary_dims = draw(helpers.get_shape(max_dim_size=5)) + random_size = draw(st.integers(min_value=1, max_value=4)) + shape = (*arbitrary_dims, random_size, random_size) + dtype_and_values = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=shape, + min_value=-10, + max_value=10, + ) ) + num_lower = draw(st.integers(min_value=-1, max_value=random_size - 1)) + num_upper = draw(st.integers(min_value=-1, max_value=random_size - 1)) + return (*dtype_and_values, num_lower, num_upper) -@handle_frontend_test( - fn_tree="tensorflow.linalg.matrix_rank", - dtype_x_hermitian_atol_rtol=_matrix_rank_helper(), - test_with_out=st.just(False), -) -def test_tensorflow_matrix_rank( - *, - dtype_x_hermitian_atol_rtol, - frontend, - test_flags, - backend_fw, - fn_tree, - on_device, -): - dtype, x, hermitian, atol, rtol = dtype_x_hermitian_atol_rtol - assume(matrix_is_stable(x, cond_limit=10)) - helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x, - tol=atol, +@st.composite +def _get_dtype_and_rank_2k_tensors(draw): + arbitrary_dims = draw(helpers.get_shape(max_dim_size=5)) + shape = arbitrary_dims + arbitrary_dims + return draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=shape, + min_value=-10, + max_value=10, + ) ) -@handle_frontend_test( - fn_tree="tensorflow.linalg.matmul", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=(3, 3), - num_arrays=2, - shared_dtype=True, - min_value=-1, - max_value=100, - ), - transpose_a=st.booleans(), - transpose_b=st.booleans(), - test_with_out=st.just(False), -) -def test_tensorflow_matmul( - *, - dtype_x, - transpose_a, - transpose_b, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype, x = dtype_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - b=x[1], - transpose_a=transpose_a, - transpose_b=transpose_b, - ) +@st.composite +def _get_dtype_and_sequence_of_arrays(draw): + array_dtype = draw(helpers.get_dtypes("float", full=False)) + arbitrary_size = draw(st.integers(min_value=2, max_value=10)) + values = [] + for i in range(arbitrary_size): + 
values.append( + draw( + helpers.array_values( + dtype=array_dtype[0], shape=helpers.get_shape(), allow_nan=True + ) + ) + ) + return array_dtype, values # solve @@ -247,6 +131,31 @@ def _get_first_matrix(draw): return input_dtype, matrix +# logdet +@st.composite +def _get_hermitian_pos_def_matrix(draw): + # batch_shape, random_size, shared + input_dtype = draw( + st.shared( + st.sampled_from(draw(helpers.get_dtypes("float"))), + key="shared_dtype", + ) + ) + shared_size = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size") + ) + gen = draw( + helpers.array_values( + dtype=input_dtype, + shape=tuple([shared_size, shared_size]), + min_value=2, + max_value=5, + ).filter(lambda x: np.linalg.cond(x.tolist()) < 1 / sys.float_info.epsilon) + ) + hpd = np.matmul(np.matrix(gen).getH(), np.matrix(gen)) + np.identity(gen.shape[0]) + return [input_dtype], hpd + + # solve @st.composite def _get_second_matrix(draw): @@ -266,43 +175,9 @@ def _get_second_matrix(draw): ) -# solve -@handle_frontend_test( - fn_tree="tensorflow.linalg.solve", - x=_get_first_matrix(), - y=_get_second_matrix(), - test_with_out=st.just(False), -) -def test_tensorflow_solve( - *, - x, - y, - frontend, - backend_fw, - test_flags, - fn_tree, - on_device, -): - input_dtype1, x1 = x - input_dtype2, x2 = y - helpers.test_frontend_function( - input_dtypes=[input_dtype1, input_dtype2], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-3, - atol=1e-3, - matrix=x1, - rhs=x2, - ) - - -# logdet @st.composite -def _get_hermitian_pos_def_matrix(draw): - # batch_shape, random_size, shared +def _get_second_matrix(draw): + # batch_shape, shared, random_size input_dtype = draw( st.shared( st.sampled_from(draw(helpers.get_dtypes("float"))), @@ -312,54 +187,76 @@ def _get_hermitian_pos_def_matrix(draw): shared_size = draw( st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size") ) - gen = draw( + return input_dtype, draw( helpers.array_values( - dtype=input_dtype, - shape=tuple([shared_size, shared_size]), - min_value=2, - max_value=5, - ).filter(lambda x: np.linalg.cond(x.tolist()) < 1 / sys.float_info.epsilon) + dtype=input_dtype, shape=tuple([shared_size, 1]), min_value=2, max_value=5 + ) ) - hpd = np.matmul(np.matrix(gen).getH(), np.matrix(gen)) + np.identity(gen.shape[0]) - return [input_dtype], hpd +# --- Main --- # +# ------------ # + + +# qr @handle_frontend_test( - fn_tree="tensorflow.linalg.logdet", - dtype_and_x=_get_hermitian_pos_def_matrix(), + fn_tree="tensorflow.linalg.qr", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + ), ) -def test_tensorflow_logdet( +def test_qr( *, dtype_and_x, frontend, - backend_fw, test_flags, fn_tree, on_device, + backend_fw, ): dtype, x = dtype_and_x - helpers.test_frontend_function( + x = np.asarray(x[0], dtype=dtype[0]) + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + ret, frontend_ret = helpers.test_frontend_function( input_dtypes=dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - matrix=x, + test_values=False, + atol=1e-03, + rtol=1e-05, + input=x, ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + assert_all_close( + ret_np=ret[0], + ret_from_gt_np=frontend_ret[0], + rtol=1e-2, + atol=1e-2, + 
ground_truth_backend=frontend, + ) -# slogdet + +# adjoint @handle_frontend_test( - fn_tree="tensorflow.linalg.slogdet", - dtype_and_x=_get_dtype_and_matrix(), + fn_tree="tensorflow.linalg.adjoint", + dtype_and_x=_get_dtype_and_matrix().filter( + lambda x: "float16" not in x[0] and "bfloat16" not in x[0] + ), # TODO : remove this filter when paddle.conj supports float16 test_with_out=st.just(False), ) -def test_tensorflow_slogdet( +def test_tensorflow_adjoint( *, dtype_and_x, - frontend, backend_fw, + frontend, test_flags, fn_tree, on_device, @@ -372,52 +269,36 @@ def test_tensorflow_slogdet( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - ) - - -# cholesky_solve -@st.composite -def _get_cholesky_matrix(draw): - # batch_shape, random_size, shared - input_dtype = draw( - st.shared( - st.sampled_from(draw(helpers.get_dtypes("float"))), - key="shared_dtype", - ) - ) - shared_size = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size") - ) - gen = draw( - helpers.array_values( - dtype=input_dtype, - shape=tuple([shared_size, shared_size]), - min_value=2, - max_value=5, - ).filter(lambda x: np.linalg.cond(x.tolist()) < 1 / sys.float_info.epsilon) + matrix=x[0], ) - spd = np.matmul(gen.T, gen) + np.identity(gen.shape[0]) - spd_chol = np.linalg.cholesky(spd) - return input_dtype, spd_chol -@st.composite -def _get_second_matrix(draw): - # batch_shape, shared, random_size - input_dtype = draw( - st.shared( - st.sampled_from(draw(helpers.get_dtypes("float"))), - key="shared_dtype", - ) - ) - shared_size = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size") - ) - return input_dtype, draw( - helpers.array_values( - dtype=input_dtype, shape=tuple([shared_size, 1]), min_value=2, max_value=5 - ) +# band_part +@handle_frontend_test( + fn_tree="tensorflow.linalg.band_part", + dtype_and_input=_get_dtype_and_matrix_and_num(), + test_with_out=st.just(False), +) +def test_tensorflow_band_part( + *, + dtype_and_input, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x, num_lower, num_upper = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + num_lower=num_lower, + num_upper=num_upper, ) @@ -453,141 +334,157 @@ def test_tensorflow_cholesky_solve( ) -# pinv @handle_frontend_test( - fn_tree="tensorflow.linalg.pinv", + fn_tree="tensorflow.linalg.det", dtype_and_input=_get_dtype_and_matrix(), + test_with_out=st.just(False), +) +def test_tensorflow_det( + *, + dtype_and_input, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ) + + +# diag +@handle_frontend_test( + fn_tree="tensorflow.linalg.diag", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["int64", "int32"], + min_num_dims=1, + max_num_dims=2, + min_dim_size=5, + max_dim_size=10, + min_value=0, + max_value=10, + ), + k=st.just(0), ) -def test_tensorflow_pinv( - *, - dtype_and_input, +def test_tensorflow_diag( + dtype_and_x, + k, frontend, backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, x = dtype_and_input + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, 
backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-3, - atol=1e-3, - a=x[0], - rcond=1e-15, + v=x[0], + k=k, ) -# tensordot @handle_frontend_test( - fn_tree="tensorflow.linalg.tensordot", - dtype_x_y_axes=_get_dtype_value1_value2_axis_for_tensordot( - available_dtypes=helpers.get_dtypes("numeric"), - ), + fn_tree="tensorflow.linalg.eigh", + dtype_and_input=_get_dtype_and_matrix(), + test_with_out=st.just(False), ) -def test_tensorflow_tensordot( +def test_tensorflow_eigh( *, - dtype_x_y_axes, - backend_fw, + dtype_and_input, frontend, + backend_fw, test_flags, fn_tree, on_device, ): - ( - dtype, - x, - y, - axes, - ) = dtype_x_y_axes + input_dtype, x = dtype_and_input + assume(matrix_is_stable(x[0])) helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=x, - b=y, - axes=axes, + tensor=x[0], ) -# norm @handle_frontend_test( - fn_tree="tensorflow.linalg.norm", - aliases=["tensorflow.norm"], - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=3, - max_num_dims=5, - min_dim_size=1, - max_dim_size=4, - min_axis=-3, - max_axis=2, - ), - ord=st.sampled_from([1, 2, np.inf]), - keepdims=st.booleans(), + fn_tree="tensorflow.linalg.eigvals", + dtype_and_input=_get_dtype_and_matrix(), + test_with_out=st.just(False), ) -def test_tensorflow_norm( +def test_tensorflow_eigvals( *, - dtype_values_axis, - ord, - keepdims, - backend_fw, + dtype_and_input, frontend, test_flags, fn_tree, on_device, + backend_fw, ): - input_dtype, x, axis = dtype_values_axis - helpers.test_frontend_function( + input_dtype, x = dtype_and_input + assume(matrix_is_stable(x[0])) + if x[0].dtype == ivy.float32: + x[0] = x[0].astype("float64") + input_dtype = [ivy.float64] + ret, frontend_ret = helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, tensor=x[0], - ord=ord, - axis=axis, - keepdims=keepdims, + test_values=False, + ) + + ret = ivy.to_numpy(ret) + ret = ret.round(6) + ret = np.sort(ret) + frontend_ret = frontend_ret[0].numpy() + frontend_ret = frontend_ret.round(6) + frontend_ret = np.sort(frontend_ret) + + assert_all_close( + ret_np=ret, + ret_from_gt_np=frontend_ret, + rtol=1e-06, + atol=1e-06, + ground_truth_backend=frontend, ) -# normalize @handle_frontend_test( - fn_tree="tensorflow.linalg.normalize", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - min_num_dims=3, - max_num_dims=5, - min_dim_size=1, - max_dim_size=4, - min_axis=-3, - max_axis=2, - ), - ord=st.sampled_from([1, 2, np.inf]), + fn_tree="tensorflow.linalg.eigvalsh", + dtype_and_input=_get_dtype_and_matrix(), test_with_out=st.just(False), ) -def test_tensorflow_normalize( +def test_tensorflow_eigvalsh( *, - dtype_values_axis, - ord, - backend_fw, + dtype_and_input, frontend, + backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, x, axis = dtype_values_axis + input_dtype, x = dtype_and_input + assume(matrix_is_stable(x[0])) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -596,59 +493,56 @@ def test_tensorflow_normalize( fn_tree=fn_tree, on_device=on_device, tensor=x[0], - ord=ord, 
- axis=axis, - atol=1e-08, ) -# l2_normalize @handle_frontend_test( - fn_tree="tensorflow.linalg.l2_normalize", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=3, - max_num_dims=5, - min_dim_size=1, - max_dim_size=4, - min_axis=-3, - max_axis=2, - ), + fn_tree="tensorflow.linalg.expm", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_value=1, + max_value=10, + shape=helpers.ints(min_value=3, max_value=3).map(lambda x: tuple([x, x])), + ).filter(lambda x: "float16" not in x[0]), + test_with_out=st.just(False), ) -def test_tensorflow_l2_normalize( +def test_tensorflow_expm( *, - dtype_values_axis, - backend_fw, + dtype_and_x, + on_device, + fn_tree, frontend, + backend_fw, test_flags, - fn_tree, - on_device, ): - input_dtype, x, axis = dtype_values_axis + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - axis=axis, + input=x[0], + atol=1, + rtol=1e-01, ) -# trace @handle_frontend_test( - fn_tree="tensorflow.linalg.trace", - dtype_and_input=_get_dtype_and_matrix(), + fn_tree="tensorflow.linalg.global_norm", + dtype_and_input=_get_dtype_and_sequence_of_arrays(), test_with_out=st.just(False), ) -def test_tensorflow_trace( +def test_tensorflow_global_norm( + *, dtype_and_input, backend_fw, frontend, test_flags, fn_tree, + on_device, ): input_dtype, x = dtype_and_input helpers.test_frontend_function( @@ -657,71 +551,75 @@ def test_tensorflow_trace( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - x=x[0], + on_device=on_device, + t_list=x, ) -# matrix_transpose +# inv @handle_frontend_test( - fn_tree="tensorflow.linalg.matrix_transpose", - dtype_and_input=helpers.dtype_and_values( + fn_tree="tensorflow.linalg.inv", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, + min_value=-100, + max_value=100, + shape=helpers.ints(min_value=1, max_value=20).map(lambda x: tuple([x, x])), + ).filter( + lambda x: "bfloat16" not in x[0] + and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon + and np.linalg.det(np.asarray(x[1][0])) != 0 ), - conjugate=st.booleans(), + adjoint=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_matrix_transpose( - dtype_and_input, - conjugate, - backend_fw, +def test_tensorflow_inv( + *, + dtype_and_x, + on_device, + fn_tree, frontend, + backend_fw, test_flags, - fn_tree, + adjoint, ): - input_dtype, x = dtype_and_input + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, + rtol=1e-01, + atol=1e-01, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - a=x[0], - conjugate=conjugate, + on_device=on_device, + input=x[0], + adjoint=adjoint, ) -@st.composite -def _get_dtype_and_sequence_of_arrays(draw): - array_dtype = draw(helpers.get_dtypes("float", full=False)) - arbitrary_size = draw(st.integers(min_value=2, max_value=10)) - values = [] - for i in range(arbitrary_size): - values.append( - draw( - helpers.array_values( - dtype=array_dtype[0], shape=helpers.get_shape(), allow_nan=True - ) - ) - ) - return array_dtype, values - - +# l2_normalize @handle_frontend_test( - fn_tree="tensorflow.linalg.global_norm", - dtype_and_input=_get_dtype_and_sequence_of_arrays(), - 
test_with_out=st.just(False), + fn_tree="tensorflow.linalg.l2_normalize", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=3, + max_num_dims=5, + min_dim_size=1, + max_dim_size=4, + min_axis=-3, + max_axis=2, + ), ) -def test_tensorflow_global_norm( +def test_tensorflow_l2_normalize( *, - dtype_and_input, + dtype_values_axis, backend_fw, frontend, test_flags, fn_tree, on_device, ): - input_dtype, x = dtype_and_input + input_dtype, x, axis = dtype_values_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -729,7 +627,8 @@ def test_tensorflow_global_norm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - t_list=x, + x=x[0], + axis=axis, ) @@ -808,113 +707,263 @@ def test_tensorflow_linalg_cross( ) +# einsum @handle_frontend_test( - fn_tree="tensorflow.linalg.svd", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - ), - full_matrices=st.booleans(), - compute_uv=st.just(True), + fn_tree="tensorflow.linalg.einsum", + eq_n_op_n_shp=helpers.einsum_helper(), + dtype=helpers.get_dtypes("numeric", full=False), ) -def test_tensorflow_svd( +def test_tensorflow_linalg_einsum( *, - dtype_and_x, + eq_n_op_n_shp, + dtype, + on_device, + fn_tree, backend_fw, - full_matrices, - compute_uv, frontend, test_flags, +): + eq, operands, dtypes = eq_n_op_n_shp + kw = {} + for i, x_ in enumerate(operands): + dtype = dtypes[i][0] + kw["x{}".format(i)] = np.array(x_).astype(dtype) + test_flags.num_positional_args = len(operands) + 1 + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + equation=eq, + **kw, + ) + + +@handle_frontend_test( + fn_tree="tensorflow.linalg.logdet", + dtype_and_x=_get_hermitian_pos_def_matrix(), +) +def test_tensorflow_logdet( + *, + dtype_and_x, + frontend, + backend_fw, + test_flags, fn_tree, on_device, ): dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=dtype[0]) - # make symmetric positive definite beforehand - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - ret, frontend_ret = helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - atol=1e-03, - rtol=1e-05, - a=x, - full_matrices=full_matrices, - compute_uv=compute_uv, + matrix=x, ) - ret = [ivy.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] - u, s, vh = ret - frontend_s, frontend_u, frontend_vh = frontend_ret - assert_all_close( - ret_np=u @ np.diag(s) @ vh, - ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T, - rtol=1e-2, - atol=1e-2, - ground_truth_backend=frontend, +@handle_frontend_test( + fn_tree="tensorflow.linalg.matmul", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=(3, 3), + num_arrays=2, + shared_dtype=True, + min_value=-1, + max_value=100, + ), + transpose_a=st.booleans(), + transpose_b=st.booleans(), + test_with_out=st.just(False), +) +def test_tensorflow_matmul( + *, + dtype_x, + transpose_a, + transpose_b, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtype, x = dtype_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + 
backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + b=x[1], + transpose_a=transpose_a, + transpose_b=transpose_b, ) -# einsum @handle_frontend_test( - fn_tree="tensorflow.linalg.einsum", - eq_n_op_n_shp=helpers.einsum_helper(), - dtype=helpers.get_dtypes("numeric", full=False), + fn_tree="tensorflow.linalg.matrix_rank", + dtype_x_hermitian_atol_rtol=_matrix_rank_helper(), + test_with_out=st.just(False), ) -def test_tensorflow_linalg_einsum( +def test_tensorflow_matrix_rank( *, - eq_n_op_n_shp, - dtype, + dtype_x_hermitian_atol_rtol, + frontend, + test_flags, + backend_fw, + fn_tree, on_device, +): + dtype, x, hermitian, atol, rtol = dtype_x_hermitian_atol_rtol + assume(matrix_is_stable(x, cond_limit=10)) + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x, + tol=atol, + ) + + +# matrix_transpose +@handle_frontend_test( + fn_tree="tensorflow.linalg.matrix_transpose", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + ), + conjugate=st.booleans(), + test_with_out=st.just(False), +) +def test_tensorflow_matrix_transpose( + dtype_and_input, + conjugate, + backend_fw, + frontend, + test_flags, + fn_tree, +): + input_dtype, x = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + a=x[0], + conjugate=conjugate, + ) + + +# norm +@handle_frontend_test( + fn_tree="tensorflow.linalg.norm", + aliases=["tensorflow.norm"], + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=3, + max_num_dims=5, + min_dim_size=1, + max_dim_size=4, + min_axis=-3, + max_axis=2, + ), + ord=st.sampled_from([1, 2, np.inf]), + keepdims=st.booleans(), +) +def test_tensorflow_norm( + *, + dtype_values_axis, + ord, + keepdims, + backend_fw, + frontend, + test_flags, fn_tree, + on_device, +): + input_dtype, x, axis = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensor=x[0], + ord=ord, + axis=axis, + keepdims=keepdims, + ) + + +# normalize +@handle_frontend_test( + fn_tree="tensorflow.linalg.normalize", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + min_num_dims=3, + max_num_dims=5, + min_dim_size=1, + max_dim_size=4, + min_axis=-3, + max_axis=2, + ), + ord=st.sampled_from([1, 2, np.inf]), + test_with_out=st.just(False), +) +def test_tensorflow_normalize( + *, + dtype_values_axis, + ord, backend_fw, frontend, test_flags, + fn_tree, + on_device, ): - eq, operands, dtypes = eq_n_op_n_shp - kw = {} - for i, x_ in enumerate(operands): - dtype = dtypes[i][0] - kw["x{}".format(i)] = np.array(x_).astype(dtype) - test_flags.num_positional_args = len(operands) + 1 + input_dtype, x, axis = dtype_values_axis helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - equation=eq, - **kw, + tensor=x[0], + ord=ord, + axis=axis, + atol=1e-08, ) -# adjoint +# pinv @handle_frontend_test( - 
fn_tree="tensorflow.linalg.adjoint", - dtype_and_x=_get_dtype_and_matrix().filter( - lambda x: "float16" not in x[0] and "bfloat16" not in x[0] - ), # TODO : remove this filter when paddle.conj supports float16 - test_with_out=st.just(False), + fn_tree="tensorflow.linalg.pinv", + dtype_and_input=_get_dtype_and_matrix(), ) -def test_tensorflow_adjoint( +def test_tensorflow_pinv( *, - dtype_and_x, - backend_fw, + dtype_and_input, frontend, + backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -922,27 +971,28 @@ def test_tensorflow_adjoint( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - matrix=x[0], + rtol=1e-3, + atol=1e-3, + a=x[0], + rcond=1e-15, ) -# diag +# Tests for tensorflow.linalg.set_diag function's frontend @handle_frontend_test( - fn_tree="tensorflow.linalg.diag", + fn_tree="tensorflow.linalg.set_diag", dtype_and_x=helpers.dtype_and_values( - available_dtypes=["int64", "int32"], - min_num_dims=1, - max_num_dims=2, - min_dim_size=5, - max_dim_size=10, - min_value=0, - max_value=10, + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=3, + min_dim_size=3, + max_dim_size=6, + min_value=-10.0, + max_value=10.0, ), - k=st.just(0), ) -def test_tensorflow_diag( +def test_tensorflow_set_diag( dtype_and_x, - k, frontend, backend_fw, test_flags, @@ -950,6 +1000,7 @@ def test_tensorflow_diag( on_device, ): dtype, x = dtype_and_x + x = ivy.squeeze(x) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -957,45 +1008,27 @@ def test_tensorflow_diag( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - v=x[0], - k=k, - ) - - -@st.composite -def _get_dtype_and_matrix_and_num(draw): - arbitrary_dims = draw(helpers.get_shape(max_dim_size=5)) - random_size = draw(st.integers(min_value=1, max_value=4)) - shape = (*arbitrary_dims, random_size, random_size) - dtype_and_values = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=shape, - min_value=-10, - max_value=10, - ) + input=x, + diagonal=x[0], ) - num_lower = draw(st.integers(min_value=-1, max_value=random_size - 1)) - num_upper = draw(st.integers(min_value=-1, max_value=random_size - 1)) - return (*dtype_and_values, num_lower, num_upper) -# band_part +# slogdet @handle_frontend_test( - fn_tree="tensorflow.linalg.band_part", - dtype_and_input=_get_dtype_and_matrix_and_num(), + fn_tree="tensorflow.linalg.slogdet", + dtype_and_x=_get_dtype_and_matrix(), test_with_out=st.just(False), ) -def test_tensorflow_band_part( +def test_tensorflow_slogdet( *, - dtype_and_input, + dtype_and_x, frontend, backend_fw, test_flags, fn_tree, on_device, ): - input_dtype, x, num_lower, num_upper = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1004,92 +1037,91 @@ def test_tensorflow_band_part( fn_tree=fn_tree, on_device=on_device, input=x[0], - num_lower=num_lower, - num_upper=num_upper, ) -# inv +# solve @handle_frontend_test( - fn_tree="tensorflow.linalg.inv", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-100, - max_value=100, - shape=helpers.ints(min_value=1, max_value=20).map(lambda x: tuple([x, x])), - ).filter( - lambda x: "bfloat16" not in x[0] - and np.linalg.cond(x[1][0]) < 1 / sys.float_info.epsilon - and np.linalg.det(np.asarray(x[1][0])) != 0 - ), 
- adjoint=st.booleans(), + fn_tree="tensorflow.linalg.solve", + x=_get_first_matrix(), + y=_get_second_matrix(), test_with_out=st.just(False), ) -def test_tensorflow_inv( +def test_tensorflow_solve( *, - dtype_and_x, - on_device, - fn_tree, + x, + y, frontend, backend_fw, test_flags, - adjoint, + fn_tree, + on_device, ): - dtype, x = dtype_and_x + input_dtype1, x1 = x + input_dtype2, x2 = y helpers.test_frontend_function( - input_dtypes=dtype, - rtol=1e-01, - atol=1e-01, - frontend=frontend, + input_dtypes=[input_dtype1, input_dtype2], backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - adjoint=adjoint, + rtol=1e-3, + atol=1e-3, + matrix=x1, + rhs=x2, ) -# qr @handle_frontend_test( - fn_tree="tensorflow.linalg.qr", + fn_tree="tensorflow.linalg.svd", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, max_value=10, shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), ), + full_matrices=st.booleans(), + compute_uv=st.just(True), ) -def test_qr( +def test_tensorflow_svd( *, dtype_and_x, + backend_fw, + full_matrices, + compute_uv, frontend, test_flags, fn_tree, on_device, - backend_fw, ): dtype, x = dtype_and_x x = np.asarray(x[0], dtype=dtype[0]) + # make symmetric positive definite beforehand x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 ret, frontend_ret = helpers.test_frontend_function( input_dtypes=dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=False, atol=1e-03, rtol=1e-05, - input=x, + a=x, + full_matrices=full_matrices, + compute_uv=compute_uv, ) ret = [ivy.to_numpy(x) for x in ret] frontend_ret = [np.asarray(x) for x in frontend_ret] + u, s, vh = ret + frontend_s, frontend_u, frontend_vh = frontend_ret + assert_all_close( - ret_np=ret[0], - ret_from_gt_np=frontend_ret[0], + ret_np=u @ np.diag(s) @ vh, + ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T, rtol=1e-2, atol=1e-2, ground_truth_backend=frontend, @@ -1131,62 +1163,55 @@ def test_tensorflow_tensor_diag( ) -# Tests for tensorflow.linalg.set_diag function's frontend +# Tests for tensorflow.linalg.tensor_diag_part function's frontend @handle_frontend_test( - fn_tree="tensorflow.linalg.set_diag", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, - max_num_dims=3, - min_dim_size=3, - max_dim_size=6, - min_value=-10.0, - max_value=10.0, - ), + fn_tree="tensorflow.linalg.tensor_diag_part", + dtype_and_input=_get_dtype_and_rank_2k_tensors(), + test_with_out=st.just(False), ) -def test_tensorflow_set_diag( - dtype_and_x, +def test_tensorflow_tensor_diag_part( + *, + dtype_and_input, frontend, - backend_fw, test_flags, fn_tree, on_device, + backend_fw, ): - dtype, x = dtype_and_x - x = ivy.squeeze(x) + dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - diagonal=x[0], + input=input[0], ) +# tensordot @handle_frontend_test( - fn_tree="tensorflow.linalg.expm", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_value=1, - max_value=10, - shape=helpers.ints(min_value=3, max_value=3).map(lambda x: tuple([x, x])), - ).filter(lambda x: "float16" not in x[0]), - 
test_with_out=st.just(False), + fn_tree="tensorflow.linalg.tensordot", + dtype_x_y_axes=_get_dtype_value1_value2_axis_for_tensordot( + available_dtypes=helpers.get_dtypes("numeric"), + ), ) -def test_tensorflow_expm( +def test_tensorflow_tensordot( *, - dtype_and_x, - on_device, - fn_tree, - frontend, + dtype_x_y_axes, backend_fw, + frontend, test_flags, + fn_tree, + on_device, ): - dtype, x = dtype_and_x + ( + dtype, + x, + y, + axes, + ) = dtype_x_y_axes helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, @@ -1194,48 +1219,31 @@ def test_tensorflow_expm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - atol=1, - rtol=1e-01, - ) - - -@st.composite -def _get_dtype_and_rank_2k_tensors(draw): - arbitrary_dims = draw(helpers.get_shape(max_dim_size=5)) - shape = arbitrary_dims + arbitrary_dims - return draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape, - min_value=-10, - max_value=10, - ) + a=x, + b=y, + axes=axes, ) -# Tests for tensorflow.linalg.tensor_diag_part function's frontend +# trace @handle_frontend_test( - fn_tree="tensorflow.linalg.tensor_diag_part", - dtype_and_input=_get_dtype_and_rank_2k_tensors(), + fn_tree="tensorflow.linalg.trace", + dtype_and_input=_get_dtype_and_matrix(), test_with_out=st.just(False), ) -def test_tensorflow_tensor_diag_part( - *, +def test_tensorflow_trace( dtype_and_input, + backend_fw, frontend, test_flags, fn_tree, - on_device, - backend_fw, ): - dtype, input = dtype_and_input + input_dtype, x = dtype_and_input helpers.test_frontend_function( - input_dtypes=dtype, - frontend=frontend, + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - input=input[0], + x=x[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py index 9f7f44cf4cf3f..db0e0a9de988f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_math.py @@ -11,36 +11,36 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# imag +# abs @handle_frontend_test( - fn_tree="tensorflow.math.imag", + fn_tree="tensorflow.math.abs", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - min_value=-20, - max_value=20, + available_dtypes=helpers.get_dtypes("numeric"), + large_abs_safety_factor=25, + small_abs_safety_factor=25, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_imag( +def test_tensorflow_abs( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, + test_flags, backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - input=x[0], + rtol=1e-02, + x=x[0], ) @@ -75,24 +75,24 @@ def test_tensorflow_accumulate_n( ) -# add +# acos @handle_frontend_test( - fn_tree="tensorflow.math.add", + fn_tree="tensorflow.math.acos", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + min_value=-1, + max_value=1, ), test_with_out=st.just(False), ) -def test_tensorflow_add( +def test_tensorflow_acos( *, dtype_and_x, + on_device, + fn_tree, frontend, 
test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -103,24 +103,27 @@ def test_tensorflow_add( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# sin +# acosh @handle_frontend_test( - fn_tree="tensorflow.math.sin", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.acosh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), + small_abs_safety_factor=3, + safety_factor_scale="log", + ), test_with_out=st.just(False), ) -def test_tensorflow_sin( +def test_tensorflow_acosh( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -130,17 +133,22 @@ def test_tensorflow_sin( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-02, x=x[0], ) -# tan +# add @handle_frontend_test( - fn_tree="tensorflow.math.tan", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.add", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_tensorflow_tan( +def test_tensorflow_add( *, dtype_and_x, frontend, @@ -158,23 +166,27 @@ def test_tensorflow_tan( fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# exp +# add_n @handle_frontend_test( - fn_tree="tensorflow.math.exp", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="tensorflow.math.add_n", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=helpers.ints(min_value=1, max_value=5), + shared_dtype=True, + ), ) -def test_tensorflow_exp( +def test_tensorflow_add_n( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -184,26 +196,27 @@ def test_tensorflow_exp( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + inputs=x, ) -# expm1 +# angle @handle_frontend_test( - fn_tree="tensorflow.math.expm1", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - test_with_out=st.just(False), + fn_tree="tensorflow.math.angle", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=["float64", "complex64", "complex128"], + ), ) -def test_tensorflow_expm1( +def test_tensorflow_angle( *, - dtype_and_x, + dtype_and_input, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -211,17 +224,18 @@ def test_tensorflow_expm1( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x[0], ) -# sqrt +# argmax @handle_frontend_test( - fn_tree="tensorflow.math.sqrt", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.argmax", + dtype_and_x=_statistical_dtype_values(function="argmax"), + output_type=st.sampled_from(["int16", "uint16", "int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_sqrt( +def test_tensorflow_argmax( *, dtype_and_x, frontend, @@ -229,8 +243,13 @@ def test_tensorflow_sqrt( fn_tree, backend_fw, on_device, + output_type, ): - 
input_dtype, x = dtype_and_x + if backend_fw in ("torch", "paddle"): + assume(output_type != "uint16") + input_dtype, x, axis = dtype_and_x + if isinstance(axis, tuple): + axis = axis[0] helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -238,21 +257,20 @@ def test_tensorflow_sqrt( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x[0], + axis=axis, + output_type=output_type, ) -# multiply +# argmin @handle_frontend_test( - fn_tree="tensorflow.math.multiply", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - ), + fn_tree="tensorflow.math.argmin", + dtype_and_x=_statistical_dtype_values(function="argmin"), + output_type=st.sampled_from(["int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_multiply( +def test_tensorflow_argmin( *, dtype_and_x, frontend, @@ -260,8 +278,11 @@ def test_tensorflow_multiply( fn_tree, backend_fw, on_device, + output_type, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_and_x + if isinstance(axis, tuple): + axis = axis[0] helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -269,29 +290,30 @@ def test_tensorflow_multiply( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + input=x[0], + axis=axis, + output_type=output_type, ) -# maximum +# asin @handle_frontend_test( - fn_tree="tensorflow.math.maximum", + fn_tree="tensorflow.math.asin", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + min_value=-1, + max_value=1, ), test_with_out=st.just(False), ) -def test_tensorflow_maximum( +def test_tensorflow_asin( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -302,21 +324,18 @@ def test_tensorflow_maximum( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# subtract +# asinh @handle_frontend_test( - fn_tree="tensorflow.math.subtract", + fn_tree="tensorflow.math.asinh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_subtract( +def test_tensorflow_asinh( *, dtype_and_x, frontend, @@ -334,21 +353,16 @@ def test_tensorflow_subtract( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# squared_difference +# atan @handle_frontend_test( - fn_tree="tensorflow.math.squared_difference", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), + fn_tree="tensorflow.math.atan", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_squared_difference( +def test_tensorflow_atan( *, dtype_and_x, frontend, @@ -366,21 +380,18 @@ def test_tensorflow_squared_difference( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# logical_not +# atan2 @handle_frontend_test( - fn_tree="tensorflow.math.logical_not", + fn_tree="tensorflow.math.atan2", dtype_and_x=helpers.dtype_and_values( - available_dtypes=tuple([ivy.bool]), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), test_with_out=st.just(False), ) -def test_tensorflow_logical_not( +def 
test_tensorflow_atan2( *, dtype_and_x, frontend, @@ -390,6 +401,7 @@ def test_tensorflow_logical_not( on_device, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -397,60 +409,65 @@ def test_tensorflow_logical_not( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + y=x[0], + x=x[1], ) -# logical_xor +# atanh @handle_frontend_test( - fn_tree="tensorflow.math.logical_xor", + fn_tree="tensorflow.math.atanh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=tuple([ivy.bool]), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float_and_complex"), ), test_with_out=st.just(False), ) -def test_tensorflow_logical_xor( +def test_tensorflow_atanh( *, dtype_and_x, - frontend, - test_flags, + on_device, fn_tree, backend_fw, - on_device, + frontend, + test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# divide +# bincount @handle_frontend_test( - fn_tree="tensorflow.math.divide", + fn_tree="tensorflow.math.bincount", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("integer"), + min_value=1, + max_value=2, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=1, + ), + key="a_s_d", + ), ), test_with_out=st.just(False), ) -def test_tensorflow_divide( +def test_tensorflow_bincount( *, dtype_and_x, + on_device, + backend_fw, + fn_tree, frontend, test_flags, - fn_tree, - backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -460,56 +477,63 @@ def test_tensorflow_divide( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + arr=x[0], + weights=None, + minlength=0, ) -# negative +# ceil @handle_frontend_test( - fn_tree="tensorflow.math.negative", + fn_tree="tensorflow.math.ceil", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of( - helpers.get_dtypes("signed_integer"), - helpers.get_dtypes("float"), - ) + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_value=-20, + max_value=20, ), test_with_out=st.just(False), ) -def test_tensorflow_negative( +def test_tensorflow_ceil( *, dtype_and_x, - frontend, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, x=x[0], ) -# logical_and +# confusion_matrix @handle_frontend_test( - fn_tree="tensorflow.math.logical_and", + fn_tree="tensorflow.math.confusion_matrix", dtype_and_x=helpers.dtype_and_values( - available_dtypes=tuple([ivy.bool]), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, + min_num_dims=1, + max_num_dims=1, + min_value=0, + max_value=4, shared_dtype=True, ), + num_classes=st.integers(min_value=5, max_value=10), test_with_out=st.just(False), ) -def test_tensorflow_logical_and( +def test_tensorflow_confusion_matrix( *, dtype_and_x, + num_classes, frontend, test_flags, fn_tree, @@ -524,31 +548,29 @@ def test_tensorflow_logical_and( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - 
y=x[1], + labels=x[0], + predictions=x[1], + num_classes=num_classes, ) -# logical_or +# conj @handle_frontend_test( - fn_tree="tensorflow.math.logical_or", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("bool"), - num_arrays=2, - shared_dtype=True, + fn_tree="tensorflow.math.conj", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), - test_with_out=st.just(False), ) -def test_tensorflow_logical_or( +def test_tensorflow_conj( *, - dtype_and_x, + dtype_and_input, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -557,29 +579,24 @@ def test_tensorflow_logical_or( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# log_sigmoid @handle_frontend_test( - fn_tree="tensorflow.math.log_sigmoid", + fn_tree="tensorflow.math.cos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=3, - small_abs_safety_factor=3, - safety_factor_scale="linear", ), test_with_out=st.just(False), ) -def test_tensorflow_log_sigmoid( +def test_tensorflow_cos( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -593,16 +610,15 @@ def test_tensorflow_log_sigmoid( ) -# log1p +# cosh @handle_frontend_test( - fn_tree="tensorflow.math.log1p", + fn_tree="tensorflow.math.cosh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("float_and_complex"), ), test_with_out=st.just(False), ) -def test_tensorflow_log1p( +def test_tensorflow_cosh( *, dtype_and_x, frontend, @@ -623,25 +639,30 @@ def test_tensorflow_log1p( ) -# reciprocal +# count_nonzero @handle_frontend_test( - fn_tree="tensorflow.math.reciprocal", - dtype_and_x=helpers.dtype_and_values( + fn_tree="tensorflow.math.count_nonzero", + dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, + valid_axis=True, + allow_neg_axes=False, ), + keepdims=st.booleans(), + dtype=helpers.get_dtypes("numeric", full=False), test_with_out=st.just(False), ) -def test_tensorflow_reciprocal( +def test_tensorflow_count_nonzero( *, - dtype_and_x, + dtype_x_axis, + dtype, + keepdims, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -649,79 +670,108 @@ def test_tensorflow_reciprocal( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-3, - atol=1e-3, - x=x[0], + input=x, + axis=axis, + keepdims=keepdims, + dtype=dtype[0], ) -# reciprocal_no_nan +# cumprod @handle_frontend_test( - fn_tree="tensorflow.math.reciprocal_no_nan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.math.cumprod", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), + exclusive=st.booleans(), + reverse=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_reciprocal_no_nan( +def test_tensorflow_cumprod( # NOQA *, - dtype_and_x, + dtype_x_axis, + exclusive, + reverse, frontend, 
test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + axis=axis, + exclusive=exclusive, + reverse=reverse, ) -# reduce_all() +# cumsum @handle_frontend_test( - fn_tree="tensorflow.math.reduce_all", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=tuple([ivy.bool]), + fn_tree="tensorflow.math.cumsum", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), + exclusive=st.booleans(), + reverse=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_reduce_all( +def test_tensorflow_cumsum( # NOQA *, - dtype_and_x, + dtype_x_axis, + exclusive, + reverse, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + frontend=frontend, on_device=on_device, - input_tensor=x[0], + rtol=1e-02, + atol=1e-02, + x=x[0], + axis=axis, + exclusive=exclusive, + reverse=reverse, ) -# reduce_any +# divide @handle_frontend_test( - fn_tree="tensorflow.math.reduce_any", + fn_tree="tensorflow.math.divide", dtype_and_x=helpers.dtype_and_values( - available_dtypes=tuple([ivy.bool]), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_reduce_any( +def test_tensorflow_divide( *, dtype_and_x, frontend, @@ -730,10 +780,7 @@ def test_tensorflow_reduce_any( backend_fw, on_device, ): - ( - input_dtype, - x, - ) = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -741,20 +788,22 @@ def test_tensorflow_reduce_any( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input_tensor=x[0], + x=x[0], + y=x[1], ) -# reduce_euclidean_norm +# divide_no_nan @handle_frontend_test( - fn_tree="tensorflow.math.reduce_euclidean_norm", + fn_tree="tensorflow.math.divide_no_nan", dtype_and_x=helpers.dtype_and_values( + num_arrays=2, available_dtypes=helpers.get_dtypes("float"), - max_num_dims=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_reduce_euclidean_norm( +def test_tensorflow_divide_no_nan( *, dtype_and_x, frontend, @@ -763,32 +812,30 @@ def test_tensorflow_reduce_euclidean_norm( backend_fw, on_device, ): - ( - input_dtype, - x, - ) = dtype_and_x + input_dtypes, xy = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - rtol=1e-01, - atol=1e-01, on_device=on_device, - input_tensor=x[0], + x=xy[0], + y=xy[1], ) -# reduce_logsumexp +# equal @handle_frontend_test( - fn_tree="tensorflow.math.reduce_logsumexp", + fn_tree="tensorflow.math.equal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_reduce_logsumexp( +def test_tensorflow_equal( *, dtype_and_x, frontend, @@ -805,18 +852,20 @@ def 
test_tensorflow_reduce_logsumexp( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input_tensor=x[0], + x=x[0], + y=x[1], ) -# argmax +# erfcinv @handle_frontend_test( - fn_tree="tensorflow.math.argmax", - dtype_and_x=_statistical_dtype_values(function="argmax"), - output_type=st.sampled_from(["int16", "uint16", "int32", "int64"]), + fn_tree="tensorflow.math.erfcinv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_tensorflow_argmax( +def test_tensorflow_erfcinv( *, dtype_and_x, frontend, @@ -824,13 +873,8 @@ def test_tensorflow_argmax( fn_tree, backend_fw, on_device, - output_type, ): - if backend_fw in ("torch", "paddle"): - assume(output_type != "uint16") - input_dtype, x, axis = dtype_and_x - if isinstance(axis, tuple): - axis = axis[0] + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -838,21 +882,17 @@ def test_tensorflow_argmax( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - axis=axis, - output_type=output_type, + x=x[0], ) -# reduce_max +# exp @handle_frontend_test( - fn_tree="tensorflow.math.reduce_max", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="tensorflow.math.exp", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_reduce_max( +def test_tensorflow_exp( *, dtype_and_x, frontend, @@ -869,19 +909,17 @@ def test_tensorflow_reduce_max( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input_tensor=x[0], + x=x[0], ) -# reduce_min +# expm1 @handle_frontend_test( - fn_tree="tensorflow.math.reduce_min", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="tensorflow.math.expm1", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_reduce_min( +def test_tensorflow_expm1( *, dtype_and_x, frontend, @@ -898,50 +936,53 @@ def test_tensorflow_reduce_min( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input_tensor=x[0], + x=x[0], ) -# reduce_prod +# floor @handle_frontend_test( - fn_tree="tensorflow.math.reduce_prod", + fn_tree="tensorflow.math.floor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_value=-20, + max_value=20, ), test_with_out=st.just(False), ) -def test_tensorflow_reduce_prod( +def test_tensorflow_floor( *, dtype_and_x, - frontend, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - input_tensor=x[0], + x=x[0], ) -# reduce_std +# floormod @handle_frontend_test( - fn_tree="tensorflow.math.reduce_std", + fn_tree="tensorflow.math.floormod", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), + test_with_out=st.just(False), ) -def test_tensorflow_reduce_std( +def test_tensorflow_floormod( *, dtype_and_x, frontend, @@ 
-951,6 +992,8 @@ def test_tensorflow_reduce_std( on_device, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -958,19 +1001,22 @@ def test_tensorflow_reduce_std( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input_tensor=x[0], + x=x[0], + y=x[1], ) -# asinh +# greater @handle_frontend_test( - fn_tree="tensorflow.math.asinh", + fn_tree="tensorflow.math.greater", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_asinh( +def test_tensorflow_greater( *, dtype_and_x, frontend, @@ -988,21 +1034,21 @@ def test_tensorflow_asinh( fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# reduce_sum +# greater_equal @handle_frontend_test( - fn_tree="tensorflow.math.reduce_sum", + fn_tree="tensorflow.math.greater_equal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=25, - small_abs_safety_factor=25, - safety_factor_scale="log", + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_reduce_sum( +def test_tensorflow_greater_equal( *, dtype_and_x, frontend, @@ -1019,110 +1065,99 @@ def test_tensorflow_reduce_sum( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - atol=1e-03, - input_tensor=x[0], + x=x[0], + y=x[1], ) -# reduce_mean @handle_frontend_test( - fn_tree="tensorflow.math.reduce_mean", + fn_tree="tensorflow.math.igamma", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + abs_smallest_val=1e-5, + min_num_dims=2, + max_num_dims=2, + min_dim_size=3, + max_dim_size=3, + min_value=2, + max_value=100, + allow_nan=False, ), test_with_out=st.just(False), ) -def test_tensorflow_reduce_mean( +def test_tensorflow_igamma( *, dtype_and_x, - frontend, - test_flags, + on_device, fn_tree, backend_fw, - on_device, + frontend, + test_flags, ): - input_dtype, x = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, - atol=1e-2, - rtol=1e-2, on_device=on_device, - input_tensor=x[0], + rtol=1e-04, + a=xs[0], + x=xs[1], ) -# reduce_variance +# imag @handle_frontend_test( - fn_tree="tensorflow.math.reduce_variance", - dtype_and_x=_statistical_dtype_values( - function="var", + fn_tree="tensorflow.math.imag", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), + min_value=-20, + max_value=20, ), test_with_out=st.just(False), - keepdims=st.booleans(), ) -def test_tensorflow_reduce_variance( +def test_tensorflow_imag( *, dtype_and_x, - frontend, test_flags, + on_device, fn_tree, + frontend, backend_fw, - on_device, - keepdims, ): - input_dtype, x, axis, ddof = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - input_tensor=x[0], - axis=axis, - atol=1e-2, rtol=1e-2, - keepdims=keepdims, + atol=1e-2, + input=x[0], ) -# scalar_mul +# 
in_top_k @handle_frontend_test( - fn_tree="tensorflow.math.scalar_mul", + fn_tree="tensorflow.math.in_top_k", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.shared( - helpers.get_dtypes("float", full=False), - key="shared_dtype", - ), - min_num_dims=1, - min_dim_size=2, - ), - scalar_val=helpers.dtype_and_values( - available_dtypes=st.shared( - helpers.get_dtypes("float", full=False), - key="shared_dtype", - ), - shape=(1,), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), + k=st.integers(min_value=0, max_value=5), test_with_out=st.just(False), ) -def test_tensorflow_scalar_mul( - *, - dtype_and_x, - scalar_val, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, +def test_tensorflow_in_top_k( + *, dtype_and_x, frontend, test_flags, backend_fw, fn_tree, on_device, k ): input_dtype, x = dtype_and_x - scalar_dtype, scalar = scalar_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1130,54 +1165,50 @@ def test_tensorflow_scalar_mul( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - scalar=scalar[0][0], - x=x[0], + targets=x[0], + pred=x[1], + k=k, ) -# divide_no_nan +# is_finite + + @handle_frontend_test( - fn_tree="tensorflow.math.divide_no_nan", + fn_tree="tensorflow.math.is_finite", dtype_and_x=helpers.dtype_and_values( - num_arrays=2, - available_dtypes=helpers.get_dtypes("float"), - shared_dtype=True, + available_dtypes=helpers.get_dtypes("numeric"), ), test_with_out=st.just(False), ) -def test_tensorflow_divide_no_nan( +def test_tensorflow_is_finite( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtypes, xy = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xy[0], - y=xy[1], + x=x[0], ) -# multiply_no_nan +# is_inf @handle_frontend_test( - fn_tree="tensorflow.math.multiply_no_nan", - dtype_and_x=helpers.dtype_and_values( - num_arrays=2, - available_dtypes=helpers.get_dtypes("float"), - shared_dtype=True, - ), + fn_tree="tensorflow.math.is_inf", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_multiply_no_nan( +def test_tensorflow_is_inf( *, dtype_and_x, frontend, @@ -1186,28 +1217,27 @@ def test_tensorflow_multiply_no_nan( backend_fw, on_device, ): - input_dtypes, xy = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xy[0], - y=xy[1], + x=x[0], ) -# erfcinv +# is_nan @handle_frontend_test( - fn_tree="tensorflow.math.erfcinv", + fn_tree="tensorflow.math.is_nan", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_erfcinv( +def test_tensorflow_is_nan( *, dtype_and_x, frontend, @@ -1228,13 +1258,15 @@ def test_tensorflow_erfcinv( ) -# is_inf +# is_non_decreasing @handle_frontend_test( - fn_tree="tensorflow.math.is_inf", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.is_non_decreasing", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), 
test_with_out=st.just(False), ) -def test_tensorflow_is_inf( +def test_tensorflow_is_non_decreasing( *, dtype_and_x, frontend, @@ -1255,15 +1287,15 @@ def test_tensorflow_is_inf( ) -# is_non_decreasing +# is_strictly_increasing @handle_frontend_test( - fn_tree="tensorflow.math.is_non_decreasing", + fn_tree="tensorflow.math.is_strictly_increasing", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_is_non_decreasing( +def test_tensorflow_is_strictly_increasing( *, dtype_and_x, frontend, @@ -1284,59 +1316,61 @@ def test_tensorflow_is_non_decreasing( ) -# is_strictly_increasing +# l2_normalize @handle_frontend_test( - fn_tree="tensorflow.math.is_strictly_increasing", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.math.l2_normalize", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=3, + max_num_dims=5, + min_dim_size=1, + max_dim_size=4, + min_axis=-3, + max_axis=2, ), - test_with_out=st.just(False), ) -def test_tensorflow_is_strictly_increasing( +def test_tensorflow_l2_normalize( *, - dtype_and_x, + dtype_values_axis, frontend, test_flags, fn_tree, - backend_fw, on_device, + backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_values_axis helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + backend_to_test=backend_fw, x=x[0], + axis=axis, ) -# count_nonzero +# less @handle_frontend_test( - fn_tree="tensorflow.math.count_nonzero", - dtype_x_axis=helpers.dtype_values_axis( + fn_tree="tensorflow.math.less", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - allow_neg_axes=False, + num_arrays=2, + shared_dtype=True, ), - keepdims=st.booleans(), - dtype=helpers.get_dtypes("numeric", full=False), test_with_out=st.just(False), ) -def test_tensorflow_count_nonzero( +def test_tensorflow_less( *, - dtype_x_axis, - dtype, - keepdims, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1344,32 +1378,24 @@ def test_tensorflow_count_nonzero( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - axis=axis, - keepdims=keepdims, - dtype=dtype[0], + x=x[0], + y=x[1], ) -# confusion_matrix +# less_equal @handle_frontend_test( - fn_tree="tensorflow.math.confusion_matrix", + fn_tree="tensorflow.math.less_equal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - min_num_dims=1, - max_num_dims=1, - min_value=0, - max_value=4, shared_dtype=True, ), - num_classes=st.integers(min_value=5, max_value=10), test_with_out=st.just(False), ) -def test_tensorflow_confusion_matrix( +def test_tensorflow_less_equal( *, dtype_and_x, - num_classes, frontend, test_flags, fn_tree, @@ -1384,160 +1410,142 @@ def test_tensorflow_confusion_matrix( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - labels=x[0], - predictions=x[1], - num_classes=num_classes, + x=x[0], + y=x[1], ) -# polyval +# log @handle_frontend_test( - fn_tree="tensorflow.math.polyval", - dtype_and_coeffs=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=1, - ), + fn_tree="tensorflow.math.log", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_num_dims=0, - max_num_dims=0, ), ) -def test_tensorflow_polyval( +def test_tensorflow_log( *, - dtype_and_coeffs, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - dtype_x, x = dtype_and_x - dtype_coeffs, coeffs = dtype_and_coeffs + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype_coeffs + dtype_x, + input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - coeffs=coeffs, - x=x, + x=x[0], ) -# unsorted_segment_mean +# log1p @handle_frontend_test( - fn_tree="tensorflow.math.unsorted_segment_mean", - data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), - segment_ids=helpers.array_values( - dtype="int32", shape=(5,), min_value=0, max_value=4 + fn_tree="tensorflow.math.log1p", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_unsorted_segment_mean( +def test_tensorflow_log1p( *, - data, - segment_ids, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=["int32", "int64"], - frontend=frontend, + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - data=data, - segment_ids=segment_ids, - num_segments=np.max(segment_ids) + 1, + x=x[0], ) -# unsorted_segment_sum +# log_sigmoid @handle_frontend_test( - fn_tree="tensorflow.math.unsorted_segment_sum", - data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), - segment_ids=helpers.array_values( - dtype=ivy.int32, shape=(5,), min_value=0, max_value=4 + fn_tree="tensorflow.math.log_sigmoid", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=3, + small_abs_safety_factor=3, + safety_factor_scale="linear", ), test_with_out=st.just(False), ) -def test_tensorflow_unsorted_segment_sum( +def test_tensorflow_log_sigmoid( *, - data, - segment_ids, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=["int32", "int64"], - frontend=frontend, + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - data=data, - segment_ids=segment_ids, - num_segments=np.max(segment_ids) + 1, + x=x[0], ) -# unsorted_segment_sqrt_n +# log_softmax @handle_frontend_test( - fn_tree="tensorflow.math.unsorted_segment_sqrt_n", - data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), - segment_ids=helpers.array_values( - dtype=ivy.int32, shape=(5,), min_value=0, max_value=4 + fn_tree="tensorflow.math.log_softmax", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, ), test_with_out=st.just(False), ) -def test_tensorflow_unsorted_segment_sqrt_n( +def test_tensorflow_log_softmax( *, - data, - segment_ids, + dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): 
+ input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[ivy.float32, ivy.int32], + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - data=data, - segment_ids=segment_ids, - num_segments=np.max(segment_ids) + 1, + logits=x[0], ) -# zero_fraction +# logical_and @handle_frontend_test( - fn_tree="tensorflow.math.zero_fraction", + fn_tree="tensorflow.math.logical_and", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - min_num_dims=1, + available_dtypes=tuple([ivy.bool]), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_zero_fraction( +def test_tensorflow_logical_and( *, dtype_and_x, frontend, @@ -1554,24 +1562,22 @@ def test_tensorflow_zero_fraction( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - value=x[0], + x=x[0], + y=x[1], ) -# truediv +# logical_not @handle_frontend_test( - fn_tree="tensorflow.math.truediv", + fn_tree="tensorflow.math.logical_not", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=tuple([ivy.bool]), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_truediv( +def test_tensorflow_logical_not( *, dtype_and_x, frontend, @@ -1589,51 +1595,20 @@ def test_tensorflow_truediv( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], - rtol=1e-2, - atol=1e-2, ) -# pow +# logical_or @handle_frontend_test( - fn_tree="tensorflow.math.pow", + fn_tree="tensorflow.math.logical_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=[ - "float16", - "float32", - "float64", - "int32", - "int64", - ], + available_dtypes=helpers.get_dtypes("bool"), num_arrays=2, - min_value=1, - max_value=7, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_pow(dtype_and_x, frontend, test_flags, backend_fw, fn_tree): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - x=x[0], - y=x[1], - ) - - -# argmin -@handle_frontend_test( - fn_tree="tensorflow.math.argmin", - dtype_and_x=_statistical_dtype_values(function="argmin"), - output_type=st.sampled_from(["int32", "int64"]), - test_with_out=st.just(False), -) -def test_tensorflow_argmin( +def test_tensorflow_logical_or( *, dtype_and_x, frontend, @@ -1641,11 +1616,8 @@ def test_tensorflow_argmin( fn_tree, backend_fw, on_device, - output_type, ): - input_dtype, x, axis = dtype_and_x - if isinstance(axis, tuple): - axis = axis[0] + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1653,23 +1625,22 @@ def test_tensorflow_argmin( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - axis=axis, - output_type=output_type, + x=x[0], + y=x[1], ) -# equal +# logical_xor @handle_frontend_test( - fn_tree="tensorflow.math.equal", + fn_tree="tensorflow.math.logical_xor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=tuple([ivy.bool]), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_equal( +def test_tensorflow_logical_xor( *, dtype_and_x, frontend, 
@@ -1691,17 +1662,17 @@ def test_tensorflow_equal( ) -# not_equal +# maximum @handle_frontend_test( - fn_tree="tensorflow.math.not_equal", + fn_tree="tensorflow.math.maximum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_not_equal( +def test_tensorflow_maximum( *, dtype_and_x, frontend, @@ -1723,18 +1694,19 @@ def test_tensorflow_not_equal( ) -# floor +# minimum @handle_frontend_test( - fn_tree="tensorflow.math.floor", + fn_tree="tensorflow.math.minimum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, min_value=-20, max_value=20, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_floor( +def test_tensorflow_minimum( *, dtype_and_x, test_flags, @@ -1752,48 +1724,53 @@ def test_tensorflow_floor( fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# ceil +# multiply @handle_frontend_test( - fn_tree="tensorflow.math.ceil", + fn_tree="tensorflow.math.multiply", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_value=-20, - max_value=20, + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_ceil( +def test_tensorflow_multiply( *, dtype_and_x, - test_flags, frontend, - backend_fw, + test_flags, fn_tree, + backend_fw, on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# round +# multiply_no_nan @handle_frontend_test( - fn_tree="tensorflow.math.round", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.multiply_no_nan", + dtype_and_x=helpers.dtype_and_values( + num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + shared_dtype=True, + ), test_with_out=st.just(False), ) -def test_tensorflow_round( +def test_tensorflow_multiply_no_nan( *, dtype_and_x, frontend, @@ -1802,95 +1779,98 @@ def test_tensorflow_round( backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtypes, xy = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x=xy[0], + y=xy[1], ) -# minimum +# negative @handle_frontend_test( - fn_tree="tensorflow.math.minimum", + fn_tree="tensorflow.math.negative", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=-20, - max_value=20, - shared_dtype=True, + available_dtypes=st.one_of( + helpers.get_dtypes("signed_integer"), + helpers.get_dtypes("float"), + ) ), test_with_out=st.just(False), ) -def test_tensorflow_minimum( +def test_tensorflow_negative( *, dtype_and_x, - test_flags, frontend, - backend_fw, + test_flags, fn_tree, + backend_fw, on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# sigmoid +# nextafter @handle_frontend_test( - 
fn_tree="tensorflow.math.sigmoid", + fn_tree="tensorflow.math.nextafter", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - num_arrays=1, - min_value=-20, - max_value=20, + available_dtypes=["float32", "float64"], + num_arrays=2, + shared_dtype=True, + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=3, ), test_with_out=st.just(False), ) -def test_tensorflow_sigmoid( +def test_tensorflow_nextafter( *, dtype_and_x, - test_flags, on_device, fn_tree, frontend, + test_flags, backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-2, - atol=1e-2, - x=x[0], + x1=x[0], + x2=x[1], ) -# tanh +# not_equal @handle_frontend_test( - fn_tree="tensorflow.math.tanh", + fn_tree="tensorflow.math.not_equal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_tanh( +def test_tensorflow_not_equal( *, dtype_and_x, frontend, @@ -1908,62 +1888,68 @@ def test_tensorflow_tanh( fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], ) -# rsqrt +# polyval @handle_frontend_test( - fn_tree="tensorflow.math.rsqrt", + fn_tree="tensorflow.math.polyval", + dtype_and_coeffs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=1, + ), dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_num_dims=0, + max_num_dims=0, ), - test_with_out=st.just(False), ) -def test_tensorflow_rsqrt( +def test_tensorflow_polyval( *, + dtype_and_coeffs, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + dtype_x, x = dtype_and_x + dtype_coeffs, coeffs = dtype_and_coeffs helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=dtype_coeffs + dtype_x, frontend=frontend, test_flags=test_flags, + backend_to_test=backend_fw, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - x=x[0], + coeffs=coeffs, + x=x, ) -# nextafter +# pow @handle_frontend_test( - fn_tree="tensorflow.math.nextafter", + fn_tree="tensorflow.math.pow", dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], + available_dtypes=[ + "float16", + "float32", + "float64", + "int32", + "int64", + ], num_arrays=2, + min_value=1, + max_value=7, shared_dtype=True, - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=3, ), test_with_out=st.just(False), ) -def test_tensorflow_nextafter( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): +def test_tensorflow_pow(dtype_and_x, frontend, test_flags, backend_fw, fn_tree): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, @@ -1971,61 +1957,57 @@ def test_tensorflow_nextafter( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x1=x[0], - x2=x[1], + x=x[0], + y=x[1], ) -# log_softmax +# real @handle_frontend_test( - fn_tree="tensorflow.math.log_softmax", + fn_tree="tensorflow.math.real", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, + available_dtypes=helpers.get_dtypes("numeric"), ), test_with_out=st.just(False), ) -def 
test_tensorflow_log_softmax( +def test_tensorflow_real( *, dtype_and_x, - on_device, - fn_tree, frontend, - test_flags, backend_fw, + test_flags, + fn_tree, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - logits=x[0], + input=x[0], ) -# abs +# reciprocal @handle_frontend_test( - fn_tree="tensorflow.math.abs", + fn_tree="tensorflow.math.reciprocal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=25, - small_abs_safety_factor=25, - safety_factor_scale="log", + num_arrays=1, ), test_with_out=st.just(False), ) -def test_tensorflow_abs( +def test_tensorflow_reciprocal( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2035,29 +2017,28 @@ def test_tensorflow_abs( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, + rtol=1e-3, + atol=1e-3, x=x[0], ) -# asin +# reciprocal_no_nan @handle_frontend_test( - fn_tree="tensorflow.math.asin", + fn_tree="tensorflow.math.reciprocal_no_nan", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-1, - max_value=1, ), test_with_out=st.just(False), ) -def test_tensorflow_asin( +def test_tensorflow_reciprocal_no_nan( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2071,24 +2052,22 @@ def test_tensorflow_asin( ) -# acos +# reduce_all() @handle_frontend_test( - fn_tree="tensorflow.math.acos", + fn_tree="tensorflow.math.reduce_all", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-1, - max_value=1, + available_dtypes=tuple([ivy.bool]), ), test_with_out=st.just(False), ) -def test_tensorflow_acos( +def test_tensorflow_reduce_all( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2098,30 +2077,31 @@ def test_tensorflow_acos( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input_tensor=x[0], ) -# acosh +# reduce_any @handle_frontend_test( - fn_tree="tensorflow.math.acosh", + fn_tree="tensorflow.math.reduce_any", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - small_abs_safety_factor=3, - safety_factor_scale="log", + available_dtypes=tuple([ivy.bool]), ), test_with_out=st.just(False), ) -def test_tensorflow_acosh( +def test_tensorflow_reduce_any( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + ( + input_dtype, + x, + ) = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2129,18 +2109,20 @@ def test_tensorflow_acosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - x=x[0], + input_tensor=x[0], ) -# square +# reduce_euclidean_norm @handle_frontend_test( - fn_tree="tensorflow.math.square", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.reduce_euclidean_norm", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + 
max_num_dims=2, + ), test_with_out=st.just(False), ) -def test_tensorflow_square( +def test_tensorflow_reduce_euclidean_norm( *, dtype_and_x, frontend, @@ -2149,27 +2131,32 @@ def test_tensorflow_square( backend_fw, on_device, ): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( + ( + input_dtype, + x, + ) = dtype_and_x + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + rtol=1e-01, + atol=1e-01, on_device=on_device, - x=x[0], + input_tensor=x[0], ) -# is_nan +# reduce_logsumexp @handle_frontend_test( - fn_tree="tensorflow.math.is_nan", + fn_tree="tensorflow.math.reduce_logsumexp", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_is_nan( +def test_tensorflow_reduce_logsumexp( *, dtype_and_x, frontend, @@ -2186,28 +2173,26 @@ def test_tensorflow_is_nan( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input_tensor=x[0], ) -# is_finite - - +# reduce_max @handle_frontend_test( - fn_tree="tensorflow.math.is_finite", + fn_tree="tensorflow.math.reduce_max", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_is_finite( +def test_tensorflow_reduce_max( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2217,17 +2202,19 @@ def test_tensorflow_is_finite( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input_tensor=x[0], ) -# atan +# reduce_mean @handle_frontend_test( - fn_tree="tensorflow.math.atan", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.reduce_mean", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_tensorflow_atan( +def test_tensorflow_reduce_mean( *, dtype_and_x, frontend, @@ -2243,26 +2230,29 @@ def test_tensorflow_atan( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + atol=1e-2, + rtol=1e-2, on_device=on_device, - x=x[0], + input_tensor=x[0], ) -# log +# reduce_min @handle_frontend_test( - fn_tree="tensorflow.math.log", + fn_tree="tensorflow.math.reduce_min", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), + test_with_out=st.just(False), ) -def test_tensorflow_log( +def test_tensorflow_reduce_min( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2272,27 +2262,26 @@ def test_tensorflow_log( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input_tensor=x[0], ) -# add_n +# reduce_prod @handle_frontend_test( - fn_tree="tensorflow.math.add_n", + fn_tree="tensorflow.math.reduce_prod", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=helpers.ints(min_value=1, max_value=5), - shared_dtype=True, + available_dtypes=helpers.get_dtypes("numeric"), ), + test_with_out=st.just(False), ) -def test_tensorflow_add_n( +def test_tensorflow_reduce_prod( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ 
-2302,21 +2291,21 @@ def test_tensorflow_add_n( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - inputs=x, + input_tensor=x[0], ) -# floormod +# reduce_std @handle_frontend_test( - fn_tree="tensorflow.math.floormod", + fn_tree="tensorflow.math.reduce_std", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", ), - test_with_out=st.just(False), ) -def test_tensorflow_floormod( +def test_tensorflow_reduce_std( *, dtype_and_x, frontend, @@ -2326,8 +2315,6 @@ def test_tensorflow_floormod( on_device, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2335,22 +2322,22 @@ def test_tensorflow_floormod( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + input_tensor=x[0], ) -# greater +# reduce_sum @handle_frontend_test( - fn_tree="tensorflow.math.greater", + fn_tree="tensorflow.math.reduce_sum", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + large_abs_safety_factor=25, + small_abs_safety_factor=25, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_greater( +def test_tensorflow_reduce_sum( *, dtype_and_x, frontend, @@ -2367,28 +2354,32 @@ def test_tensorflow_greater( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + rtol=1e-03, + atol=1e-03, + input_tensor=x[0], ) +# reduce_variance @handle_frontend_test( - fn_tree="tensorflow.math.cos", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.math.reduce_variance", + dtype_and_x=_statistical_dtype_values( + function="var", ), test_with_out=st.just(False), + keepdims=st.booleans(), ) -def test_tensorflow_cos( +def test_tensorflow_reduce_variance( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, + keepdims, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis, ddof = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2396,61 +2387,56 @@ def test_tensorflow_cos( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input_tensor=x[0], + axis=axis, + atol=1e-2, + rtol=1e-2, + keepdims=keepdims, ) -# sinh @handle_frontend_test( - fn_tree="tensorflow.math.sinh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - ), + fn_tree="tensorflow.math.rint", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_sinh( +def test_tensorflow_rint( *, dtype_and_x, frontend, test_flags, fn_tree, - backend_fw, on_device, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + backend_to_test=backend_fw, x=x[0], ) -# softmax +# round @handle_frontend_test( - fn_tree="tensorflow.math.softmax", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - allow_inf=False, - 
), + fn_tree="tensorflow.math.round", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_softmax( +def test_tensorflow_round( *, - dtype_values_axis, - on_device, - fn_tree, + dtype_and_x, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_values_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2458,27 +2444,26 @@ def test_tensorflow_softmax( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - logits=x[0], - atol=1e-02, - rtol=1e-2, - axis=axis, + x=x[0], ) -# softplus +# rsqrt @handle_frontend_test( - fn_tree="tensorflow.math.softplus", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.rsqrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_tensorflow_softplus( +def test_tensorflow_rsqrt( *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2488,30 +2473,43 @@ def test_tensorflow_softplus( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - features=x[0], + rtol=1e-02, + x=x[0], ) -# xlogy +# scalar_mul @handle_frontend_test( - fn_tree="tensorflow.math.xlogy", + fn_tree="tensorflow.math.scalar_mul", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - num_arrays=2, - shared_dtype=True, + available_dtypes=st.shared( + helpers.get_dtypes("float", full=False), + key="shared_dtype", + ), + min_num_dims=1, + min_dim_size=2, + ), + scalar_val=helpers.dtype_and_values( + available_dtypes=st.shared( + helpers.get_dtypes("float", full=False), + key="shared_dtype", + ), + shape=(1,), ), test_with_out=st.just(False), ) -def test_tensorflow_xlogy( +def test_tensorflow_scalar_mul( *, dtype_and_x, + scalar_val, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x + scalar_dtype, scalar = scalar_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2519,49 +2517,52 @@ def test_tensorflow_xlogy( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + scalar=scalar[0][0], + x=x[0], ) -# cosh +# sigmoid @handle_frontend_test( - fn_tree="tensorflow.math.cosh", + fn_tree="tensorflow.math.sigmoid", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_complex"), + num_arrays=1, + min_value=-20, + max_value=20, ), test_with_out=st.just(False), ) -def test_tensorflow_cosh( +def test_tensorflow_sigmoid( *, dtype_and_x, - frontend, test_flags, + on_device, fn_tree, + frontend, backend_fw, - on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, + rtol=1e-2, + atol=1e-2, x=x[0], ) -# atan2 +# sin @handle_frontend_test( - fn_tree="tensorflow.math.atan2", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True - ), + fn_tree="tensorflow.math.sin", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def 
test_tensorflow_atan2( +def test_tensorflow_sin( *, dtype_and_x, frontend, @@ -2571,7 +2572,6 @@ def test_tensorflow_atan2( on_device, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2579,22 +2579,19 @@ def test_tensorflow_atan2( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y=x[0], - x=x[1], + x=x[0], ) -# less_equal +# sinh @handle_frontend_test( - fn_tree="tensorflow.math.less_equal", + fn_tree="tensorflow.math.sinh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float_and_complex"), ), test_with_out=st.just(False), ) -def test_tensorflow_less_equal( +def test_tensorflow_sinh( *, dtype_and_x, frontend, @@ -2612,21 +2609,52 @@ def test_tensorflow_less_equal( fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# less +# softmax @handle_frontend_test( - fn_tree="tensorflow.math.less", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + fn_tree="tensorflow.math.softmax", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + allow_inf=False, ), test_with_out=st.just(False), ) -def test_tensorflow_less( +def test_tensorflow_softmax( + *, + dtype_values_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + logits=x[0], + atol=1e-02, + rtol=1e-2, + axis=axis, + ) + + +# softplus +@handle_frontend_test( + fn_tree="tensorflow.math.softplus", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_softplus( *, dtype_and_x, frontend, @@ -2643,28 +2671,55 @@ def test_tensorflow_less( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + features=x[0], ) -# angle +# softsign @handle_frontend_test( - fn_tree="tensorflow.math.angle", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=["float64", "complex64", "complex128"], + fn_tree="tensorflow.math.softsign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), + test_with_out=st.just(False), ) -def test_tensorflow_angle( +def test_tensorflow_softsign( *, - dtype_and_input, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + features=x[0], + ) + + +# sqrt +@handle_frontend_test( + fn_tree="tensorflow.math.sqrt", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_sqrt( + *, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2672,23 +2727,17 @@ def test_tensorflow_angle( test_flags=test_flags, fn_tree=fn_tree, 
on_device=on_device, - input=x[0], + x=x[0], ) -# zeta +# square @handle_frontend_test( - fn_tree="tensorflow.math.zeta", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=1, - num_arrays=2, - shared_dtype=True, - ), + fn_tree="tensorflow.math.square", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_zeta( +def test_tensorflow_square( *, dtype_and_x, frontend, @@ -2706,13 +2755,12 @@ def test_tensorflow_zeta( fn_tree=fn_tree, on_device=on_device, x=x[0], - q=x[1], ) -# greater_equal +# squared_difference @handle_frontend_test( - fn_tree="tensorflow.math.greater_equal", + fn_tree="tensorflow.math.squared_difference", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -2720,7 +2768,7 @@ def test_tensorflow_zeta( ), test_with_out=st.just(False), ) -def test_tensorflow_greater_equal( +def test_tensorflow_squared_difference( *, dtype_and_x, frontend, @@ -2742,19 +2790,24 @@ def test_tensorflow_greater_equal( ) -# in_top_k +# subtract @handle_frontend_test( - fn_tree="tensorflow.math.in_top_k", + fn_tree="tensorflow.math.subtract", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), - k=st.integers(min_value=0, max_value=5), test_with_out=st.just(False), ) -def test_tensorflow_in_top_k( - *, dtype_and_x, frontend, test_flags, backend_fw, fn_tree, on_device, k +def test_tensorflow_subtract( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2764,29 +2817,56 @@ def test_tensorflow_in_top_k( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - targets=x[0], - pred=x[1], - k=k, + x=x[0], + y=x[1], ) -# conj +# tan @handle_frontend_test( - fn_tree="tensorflow.math.conj", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + fn_tree="tensorflow.math.tan", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_tensorflow_tan( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# tanh +@handle_frontend_test( + fn_tree="tensorflow.math.tanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), + test_with_out=st.just(False), ) -def test_tensorflow_conj( +def test_tensorflow_tanh( *, - dtype_and_input, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2826,127 +2906,136 @@ def test_tensorflow_top_k( ) -# real +# truediv @handle_frontend_test( - fn_tree="tensorflow.math.real", + fn_tree="tensorflow.math.truediv", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_real( +def test_tensorflow_truediv( *, dtype_and_x, 
frontend, - backend_fw, test_flags, fn_tree, + backend_fw, on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - frontend=frontend, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + x=x[0], + y=x[1], + rtol=1e-2, + atol=1e-2, ) -# atanh +# unsorted_segment_mean @handle_frontend_test( - fn_tree="tensorflow.math.atanh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + fn_tree="tensorflow.math.unsorted_segment_mean", + data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), + segment_ids=helpers.array_values( + dtype="int32", shape=(5,), min_value=0, max_value=4 ), test_with_out=st.just(False), ) -def test_tensorflow_atanh( +def test_tensorflow_unsorted_segment_mean( *, - dtype_and_x, - on_device, - fn_tree, - backend_fw, + data, + segment_ids, frontend, test_flags, + fn_tree, + backend_fw, + on_device, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=["int32", "int64"], frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + data=data, + segment_ids=segment_ids, + num_segments=np.max(segment_ids) + 1, ) +# unsorted_segment_sqrt_n @handle_frontend_test( - fn_tree="tensorflow.math.rint", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.math.unsorted_segment_sqrt_n", + data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), + segment_ids=helpers.array_values( + dtype=ivy.int32, shape=(5,), min_value=0, max_value=4 + ), test_with_out=st.just(False), ) -def test_tensorflow_rint( +def test_tensorflow_unsorted_segment_sqrt_n( *, - dtype_and_x, + data, + segment_ids, frontend, test_flags, fn_tree, - on_device, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[ivy.float32, ivy.int32], + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - backend_to_test=backend_fw, - x=x[0], + data=data, + segment_ids=segment_ids, + num_segments=np.max(segment_ids) + 1, ) -# bincount +# unsorted_segment_sum @handle_frontend_test( - fn_tree="tensorflow.math.bincount", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=1, - max_value=2, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=1, - ), - key="a_s_d", - ), + fn_tree="tensorflow.math.unsorted_segment_sum", + data=helpers.array_values(dtype=ivy.int32, shape=(5, 6), min_value=1, max_value=9), + segment_ids=helpers.array_values( + dtype=ivy.int32, shape=(5,), min_value=0, max_value=4 ), test_with_out=st.just(False), ) -def test_tensorflow_bincount( +def test_tensorflow_unsorted_segment_sum( *, - dtype_and_x, - on_device, - backend_fw, - fn_tree, + data, + segment_ids, frontend, test_flags, + fn_tree, + backend_fw, + on_device, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, + input_dtypes=["int32", "int64"], frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - arr=x[0], - weights=None, - minlength=0, + data=data, + segment_ids=segment_ids, + num_segments=np.max(segment_ids) + 1, ) @@ -3014,181 +3103,91 @@ def 
test_tensorflow_xlog1py( ) +# xlogy @handle_frontend_test( - fn_tree="tensorflow.math.igamma", + fn_tree="tensorflow.math.xlogy", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float_and_complex"), num_arrays=2, shared_dtype=True, - abs_smallest_val=1e-5, - min_num_dims=2, - max_num_dims=2, - min_dim_size=3, - max_dim_size=3, - min_value=2, - max_value=100, - allow_nan=False, ), test_with_out=st.just(False), ) -def test_tensorflow_igamma( +def test_tensorflow_xlogy( *, dtype_and_x, - on_device, - fn_tree, - backend_fw, - frontend, - test_flags, -): - input_dtype, xs = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-04, - a=xs[0], - x=xs[1], - ) - - -# l2_normalize -@handle_frontend_test( - fn_tree="tensorflow.math.l2_normalize", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=3, - max_num_dims=5, - min_dim_size=1, - max_dim_size=4, - min_axis=-3, - max_axis=2, - ), -) -def test_tensorflow_l2_normalize( - *, - dtype_values_axis, - frontend, - test_flags, - fn_tree, - on_device, - backend_fw, -): - input_dtype, x, axis = dtype_values_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - backend_to_test=backend_fw, - x=x[0], - axis=axis, - ) - - -# cumsum -@handle_frontend_test( - fn_tree="tensorflow.math.cumsum", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, - ), - exclusive=st.booleans(), - reverse=st.booleans(), - test_with_out=st.just(False), -) -def test_tensorflow_cumsum( # NOQA - *, - dtype_x_axis, - exclusive, - reverse, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis = dtype_x_axis + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - frontend=frontend, on_device=on_device, - rtol=1e-02, - atol=1e-02, - x=x[0], - axis=axis, - exclusive=exclusive, - reverse=reverse, + x=xs[0], + y=xs[1], ) -# cumprod +# zero_fraction @handle_frontend_test( - fn_tree="tensorflow.math.cumprod", - dtype_x_axis=helpers.dtype_values_axis( + fn_tree="tensorflow.math.zero_fraction", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", min_num_dims=1, - min_value=-5, - max_value=5, ), - exclusive=st.booleans(), - reverse=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_cumprod( # NOQA +def test_tensorflow_zero_fraction( *, - dtype_x_axis, - exclusive, - reverse, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - axis=axis, - exclusive=exclusive, - reverse=reverse, + value=x[0], ) -# softsign +# zeta @handle_frontend_test( - 
fn_tree="tensorflow.math.softsign", + fn_tree="tensorflow.math.zeta", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=1, + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_softsign( +def test_tensorflow_zeta( *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -3198,5 +3197,6 @@ def test_tensorflow_softsign( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - features=x[0], + x=x[0], + q=x[1], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py index 78b50504e962a..2d3ac95eac8a0 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py @@ -13,44 +13,211 @@ ) -@handle_frontend_test( - fn_tree="tensorflow.nn.leaky_relu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - large_abs_safety_factor=25, - small_abs_safety_factor=25, - safety_factor_scale="log", - ), - test_with_out=st.just(False), - alpha=helpers.floats( - min_value=0, - max_value=1, - large_abs_safety_factor=25, - small_abs_safety_factor=25, - safety_factor_scale="log", - ), -) -def test_tensorflow_leaky_relu( - *, - dtype_and_x, - alpha, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, x = dtype_and_x - return helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - features=x[0], - alpha=alpha, +# --- Helpers --- # +# --------------- # + + +@st.composite +def _average_pool_args(draw): + dims = draw(st.integers(min_value=1, max_value=3)) + data_formats = ["NWC", "NHWC", "NDHWC"] + data_format = data_formats[dims - 1] + return ( + draw( + helpers.arrays_for_pooling( + min_dims=dims + 2, max_dims=dims + 2, min_side=1, max_side=4 + ) + ), + data_format, + ) + + +# sufficient_statistics +@st.composite +def _axes_value(draw): + s = draw( + helpers.get_shape( + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ) + ) + dtype_and_x = draw( + helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + shape=s, + valid_axis=True, + force_tuple_axis=True, + ) + ) + return dtype_and_x + + +@st.composite +def _batch_normalization_helper(draw): + shape1, shape2, shape3, shape4 = draw(helpers.mutually_broadcastable_shapes(4)) + shape = helpers.broadcast_shapes(shape1, shape2, shape3, shape4) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + shape=shape, + max_value=999, + min_value=-1001, + ) + ) + + _, mean = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=shape1, + min_value=-1001, + max_value=999, + ) + ) + _, variance = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=shape2, + min_value=0, + max_value=999, + ) + ) + _, offset = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=shape3, + min_value=-1001, + max_value=999, + ) + ) + _, scale = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=shape4, + min_value=-1001, + max_value=999, + ) + ) + + return x_dtype, x[0], mean[0], variance[0], offset[0], scale[0] + + +@st.composite +def 
_dropout_helper(draw): + shape = draw(helpers.get_shape(min_num_dims=1)) + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=shape, + ) + ) + noise_shape = list(shape) + if draw(st.booleans()): + noise_shape = None + else: + for i, _ in enumerate(noise_shape): + if draw(st.booleans()): + noise_shape[i] = 1 + elif draw(st.booleans()): + noise_shape[i] = None + seed = draw(helpers.ints(min_value=0, max_value=100)) + rate = draw(helpers.floats(min_value=0, max_value=0.9)) + + return ( + dtype_and_x, + noise_shape, + seed, + rate, + ) + + +@st.composite +def _generate_bias_data(draw): + data_format = draw(st.sampled_from(["NC...", "N...C", None])) + channel_dim = 1 if data_format == "NC..." else -1 + dtype, value, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=3, + ret_shape=True, + ) + ) + channel_size = shape[channel_dim] + bias = draw(helpers.array_values(dtype=dtype[0], shape=(channel_size,))) + return data_format, dtype, value, bias + + +# Normalize Moments +@st.composite +def _normalize_moments_helper(draw): + shape1, shape2, shape3 = draw(helpers.mutually_broadcastable_shapes(3)) + counts_dtype, counts = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + max_value=999, + min_value=-1001, + max_num_dims=1, + max_dim_size=1, + min_dim_size=1, + ) + ) + _, mean = draw( + helpers.dtype_and_values( + available_dtypes=counts_dtype, + shape=shape1, + min_value=1, + max_num_dims=1, + max_dim_size=1, + min_dim_size=1, + ) + ) + _, variance = draw( + helpers.dtype_and_values( + available_dtypes=counts_dtype, + shape=shape2, + min_value=1, + max_num_dims=1, + max_dim_size=1, + min_dim_size=1, + ) + ) + _, shift = draw( + helpers.dtype_and_values( + available_dtypes=counts_dtype, + shape=shape3, + min_value=1, + max_num_dims=1, + max_dim_size=1, + min_dim_size=1, + ) + ) + + return counts_dtype, counts[0], mean[0], variance[0], shift[0] + + +@st.composite +def _pool_args(draw): + dims = draw(st.integers(min_value=3, max_value=5)) + data_formats = {3: "NWC", 4: "NHWC", 5: "NDHWC"} + data_format = data_formats[dims] + pooling_type = draw(st.one_of(st.just("AVG"), st.just("MAX"))) + return ( + draw( + helpers.arrays_for_pooling( + min_dims=dims, + max_dims=dims, + min_side=1, + max_side=4, + return_dilation=True, + ) + ), + data_format, + pooling_type, + dims, ) @@ -349,6 +516,16 @@ def _x_and_filters( return dtype, x, filters, dilations, data_format, stride, padding, output_shape +@st.composite +def df(draw, data_format): + data_format = draw(data_format) + return data_format + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="tensorflow.nn.atrous_conv2d", x_f_d_df=_x_and_filters( @@ -436,13 +613,169 @@ def test_tensorflow_atrous_conv2d_transpose( ) +# average_pool @handle_frontend_test( - fn_tree="tensorflow.nn.conv1d", - x_f_d_df=_x_and_filters( - dtypes=helpers.get_dtypes("float", full=False), - data_format=st.sampled_from(["NWC"]), - padding=st.sampled_from(["VALID", "SAME"]), - stride_min=3, + fn_tree="tensorflow.nn.avg_pool", + x_k_s_p_df=_average_pool_args(), + test_with_out=st.just(False), +) +def test_tensorflow_avg_pool( + *, + x_k_s_p_df, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + (input_dtype, x, ksize, strides, padding), data_format = x_k_s_p_df + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + 
fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format, + ) + + +# test_avg_pool1d +@handle_frontend_test( + fn_tree="tensorflow.nn.avg_pool1d", + x_k_s_p_df=helpers.arrays_for_pooling( + min_dims=3, max_dims=3, min_side=1, max_side=4 + ), + test_with_out=st.just(False), +) +def test_tensorflow_avg_pool1d( + *, + x_k_s_p_df, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + (input_dtype, x, ksize, strides, padding) = x_k_s_p_df + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ksize=ksize, + strides=strides, + padding=padding, + ) + + +@handle_frontend_test( + fn_tree="tensorflow.nn.avg_pool3d", + x_k_s_p_df=helpers.arrays_for_pooling( + min_dims=5, max_dims=5, min_side=1, max_side=4 + ), + test_with_out=st.just(False), +) +def test_tensorflow_avg_pool3d( + *, + x_k_s_p_df, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x, ksize, strides, padding = x_k_s_p_df + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ksize=ksize, + strides=strides, + padding=padding, + ) + + +@handle_frontend_test( + fn_tree="tensorflow.nn.batch_normalization", + data=_batch_normalization_helper(), + eps=helpers.floats(min_value=1e-5, max_value=0.1), +) +def test_tensorflow_batch_normalization( + *, + data, + eps, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + x_dtype, x, mean, variance, offset, scale = data + helpers.test_frontend_function( + input_dtypes=x_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + rtol=1e-2, + atol=1e-2, + fn_tree=fn_tree, + on_device=on_device, + x=x, + mean=mean, + variance=variance, + offset=offset, + scale=scale, + variance_epsilon=eps, + ) + + +@handle_frontend_test( + fn_tree="tensorflow.nn.bias_add", + data=_generate_bias_data(), + test_with_out=st.just(False), +) +def test_tensorflow_bias_add( + *, + data, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + data_format, dtype, value, bias = data + helpers.test_frontend_function( + input_dtypes=dtype * 2, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + value=value[0], + bias=bias, + data_format=data_format, + ) + + +@handle_frontend_test( + fn_tree="tensorflow.nn.conv1d", + x_f_d_df=_x_and_filters( + dtypes=helpers.get_dtypes("float", full=False), + data_format=st.sampled_from(["NWC"]), + padding=st.sampled_from(["VALID", "SAME"]), + stride_min=3, stride_max=4, type="1d", ), @@ -525,38 +858,6 @@ def test_tensorflow_conv1d_transpose( ) -@handle_frontend_test( - fn_tree="tensorflow.nn.gelu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - max_value=1e04, - ), - approximate=st.booleans(), - test_with_out=st.just(False), -) -def test_tensorflow_gelu( - *, - dtype_and_x, - approximate, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - features=x[0], - approximate=approximate, - ) - - @handle_frontend_test( 
fn_tree="tensorflow.nn.conv2d", x_f_d_df=_x_and_filters( @@ -725,17 +1026,19 @@ def test_tensorflow_conv3d_transpose( ) +# convolution @handle_frontend_test( - fn_tree="tensorflow.nn.depthwise_conv2d", + fn_tree="tensorflow.nn.convolution", x_f_d_df=_x_and_filters( dtypes=helpers.get_dtypes("float", full=False), - data_format=st.sampled_from(["NHWC"]), - padding=st.sampled_from(["VALID", "SAME"]), - type="depthwise", + data_format=st.sampled_from(["NWC", "NHWC", "NDHWC"]), + padding=st.sampled_from(["SAME", "VALID"]), + dilation_max=1, + type=None, ), test_with_out=st.just(False), ) -def test_tensorflow_depthwise_conv2d( +def test_tensorflow_convolution( *, x_f_d_df, frontend, @@ -753,7 +1056,7 @@ def test_tensorflow_depthwise_conv2d( fn_tree=fn_tree, on_device=on_device, input=x, - filter=filters, + filters=filters, strides=stride, padding=padding, data_format=data_format, @@ -761,155 +1064,108 @@ def test_tensorflow_depthwise_conv2d( ) +# crelu @handle_frontend_test( - fn_tree="tensorflow.nn.separable_conv2d", - x_f_d_df=_x_and_filters( - dtypes=helpers.get_dtypes("float", full=False), - data_format=st.sampled_from(["NHWC"]), - padding=st.sampled_from(["VALID", "SAME"]), - type="separable", + fn_tree="tensorflow.nn.crelu", + dtype_x_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=4, + max_axes_size=3, + force_int_axis=True, + valid_axis=True, ), test_with_out=st.just(False), ) -def test_tensorflow_separable_conv2d( +def test_tensorflow_crelu( *, - x_f_d_df, - frontend, + dtype_x_and_axis, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): - input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df + input_dtype, x, axis = dtype_x_and_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - input=x, - depthwise_filter=filters[0], - pointwise_filter=filters[1], - strides=stride, - padding=padding, - data_format=data_format, - dilations=dilation, + features=x[0], + axis=axis, ) -@st.composite -def _batch_normalization_helper(draw): - shape1, shape2, shape3, shape4 = draw(helpers.mutually_broadcastable_shapes(4)) - shape = helpers.broadcast_shapes(shape1, shape2, shape3, shape4) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - shape=shape, - max_value=999, - min_value=-1001, - ) - ) - - _, mean = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=shape1, - min_value=-1001, - max_value=999, - ) - ) - _, variance = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=shape2, - min_value=0, - max_value=999, - ) - ) - _, offset = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=shape3, - min_value=-1001, - max_value=999, - ) - ) - _, scale = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=shape4, - min_value=-1001, - max_value=999, - ) +# ctc_unique_labels +@handle_frontend_test( + fn_tree="tensorflow.nn.ctc_unique_labels", + dtype_x=helpers.dtype_and_values( + available_dtypes=["int64", "int32"], + min_value=1, + max_value=100, + min_dim_size=1, + max_dim_size=10, + min_num_dims=2, + max_num_dims=2, + ), + test_with_out=st.just([False]), +) +def test_tensorflow_ctc_unique_labels( + *, + dtype_x, + frontend, + fn_tree, + test_flags, + on_device, + backend_fw, +): + dtype, x = dtype_x + 
helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + labels=x[0], ) - return x_dtype, x[0], mean[0], variance[0], offset[0], scale[0] - @handle_frontend_test( - fn_tree="tensorflow.nn.batch_normalization", - data=_batch_normalization_helper(), - eps=helpers.floats(min_value=1e-5, max_value=0.1), + fn_tree="tensorflow.nn.depthwise_conv2d", + x_f_d_df=_x_and_filters( + dtypes=helpers.get_dtypes("float", full=False), + data_format=st.sampled_from(["NHWC"]), + padding=st.sampled_from(["VALID", "SAME"]), + type="depthwise", + ), + test_with_out=st.just(False), ) -def test_tensorflow_batch_normalization( +def test_tensorflow_depthwise_conv2d( *, - data, - eps, + x_f_d_df, frontend, test_flags, fn_tree, backend_fw, on_device, ): - x_dtype, x, mean, variance, offset, scale = data + input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, - rtol=1e-2, - atol=1e-2, fn_tree=fn_tree, on_device=on_device, - x=x, - mean=mean, - variance=variance, - offset=offset, - scale=scale, - variance_epsilon=eps, - ) - - -@st.composite -def _dropout_helper(draw): - shape = draw(helpers.get_shape(min_num_dims=1)) - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=shape, - ) - ) - noise_shape = list(shape) - if draw(st.booleans()): - noise_shape = None - else: - for i, _ in enumerate(noise_shape): - if draw(st.booleans()): - noise_shape[i] = 1 - elif draw(st.booleans()): - noise_shape[i] = None - seed = draw(helpers.ints(min_value=0, max_value=100)) - rate = draw(helpers.floats(min_value=0, max_value=0.9)) - - return ( - dtype_and_x, - noise_shape, - seed, - rate, + input=x, + filter=filters, + strides=stride, + padding=padding, + data_format=data_format, + dilations=dilation, ) @@ -947,75 +1203,58 @@ def test_tensorflow_dropout( assert u.shape == v.shape == w.shape -# silu +# embedding_lookup @handle_frontend_test( - fn_tree="tensorflow.nn.silu", - dtype_features=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=5, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - beta=helpers.floats( - min_value=0, - max_value=3, - ), - test_with_out=st.just(False), + fn_tree="tensorflow.nn.embedding_lookup", + dtypes_indices_weights=helpers.embedding_helper(), + max_norm=st.floats(min_value=0.1, max_value=5, exclude_min=True), ) -def test_tensorflow_silu( +def test_tensorflow_embedding_lookup( *, - dtype_features, - beta, - frontend, + dtypes_indices_weights, + max_norm, test_flags, + on_device, fn_tree, + frontend, backend_fw, - on_device, ): - input_dtype, features = dtype_features + dtypes, indices, weight, _ = dtypes_indices_weights + dtypes.reverse() helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, - features=features[0], - beta=beta, + params=weight, + ids=indices, + max_norm=max_norm, + atol=1e-4, ) -# sigmoid_cross_entropy_with_logits @handle_frontend_test( - fn_tree="tensorflow.nn.sigmoid_cross_entropy_with_logits", - dtype_labels_logits=helpers.dtype_and_values( + fn_tree="tensorflow.nn.gelu", + 
dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=0, - max_value=1, - min_num_dims=1, - max_num_dims=2, - min_dim_size=1, - max_dim_size=2, - shared_dtype=True, + max_value=1e04, ), + approximate=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_sigmoid_cross_entropy_with_logits( +def test_tensorflow_gelu( *, - dtype_labels_logits, + dtype_and_x, + approximate, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, input_values = dtype_labels_logits - labels, logits = input_values + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1023,55 +1262,49 @@ def test_tensorflow_sigmoid_cross_entropy_with_logits( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - labels=labels, - logits=logits, + features=x[0], + approximate=approximate, ) -# weighted_cross_entropy_with_logits @handle_frontend_test( - fn_tree="tensorflow.nn.weighted_cross_entropy_with_logits", - dtype_labels_logits=helpers.dtype_and_values( + fn_tree="tensorflow.nn.leaky_relu", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=0, - max_value=1, min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - shared_dtype=True, - ), - pos_weight=st.one_of( - helpers.floats( - min_value=0, - max_value=3, - ) + large_abs_safety_factor=25, + small_abs_safety_factor=25, + safety_factor_scale="log", ), test_with_out=st.just(False), + alpha=helpers.floats( + min_value=0, + max_value=1, + large_abs_safety_factor=25, + small_abs_safety_factor=25, + safety_factor_scale="log", + ), ) -def test_tensorflow_weighted_cross_entropy_with_logits( +def test_tensorflow_leaky_relu( *, - dtype_labels_logits, - pos_weight, + dtype_and_x, + alpha, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, input_values = dtype_labels_logits - labels, logits = input_values - helpers.test_frontend_function( - input_dtypes=input_dtype, + dtype, x = dtype_and_x + return helpers.test_frontend_function( + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - labels=labels, - logits=logits, - pos_weight=pos_weight, + features=x[0], + alpha=alpha, ) @@ -1125,16 +1358,50 @@ def test_tensorflow_local_response_normalization( ) -@st.composite -def df(draw, data_format): - data_format = draw(data_format) - return data_format - - -# max_pool1d @handle_frontend_test( - fn_tree="tensorflow.nn.max_pool1d", - data_format=df(data_format=st.sampled_from(["NWC"])), + fn_tree="tensorflow.nn.log_poisson_loss", + dtype_target_log_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=0, + max_value=1, + min_num_dims=1, + max_num_dims=3, + shared_dtype=True, + ), + compute_full_loss=st.booleans(), + test_with_out=st.just(False), +) +def test_tensorflow_log_poisson_loss( + *, + dtype_target_log_inputs, + compute_full_loss, + test_flags, + frontend, + fn_tree, + on_device, + backend_fw, +): + input_dtype, input_values = dtype_target_log_inputs + targets, log_input = input_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + targets=targets, + log_input=log_input, + compute_full_loss=compute_full_loss, + atol=1e-2, + ) + + +# max_pool1d +@handle_frontend_test( + 
fn_tree="tensorflow.nn.max_pool1d", + data_format=df(data_format=st.sampled_from(["NWC"])), x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), test_with_out=st.just(False), ) @@ -1232,54 +1499,6 @@ def test_tensorflow_moments( ) -# Normalize Moments -@st.composite -def _normalize_moments_helper(draw): - shape1, shape2, shape3 = draw(helpers.mutually_broadcastable_shapes(3)) - counts_dtype, counts = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - max_value=999, - min_value=-1001, - max_num_dims=1, - max_dim_size=1, - min_dim_size=1, - ) - ) - _, mean = draw( - helpers.dtype_and_values( - available_dtypes=counts_dtype, - shape=shape1, - min_value=1, - max_num_dims=1, - max_dim_size=1, - min_dim_size=1, - ) - ) - _, variance = draw( - helpers.dtype_and_values( - available_dtypes=counts_dtype, - shape=shape2, - min_value=1, - max_num_dims=1, - max_dim_size=1, - min_dim_size=1, - ) - ) - _, shift = draw( - helpers.dtype_and_values( - available_dtypes=counts_dtype, - shape=shape3, - min_value=1, - max_num_dims=1, - max_dim_size=1, - min_dim_size=1, - ) - ) - - return counts_dtype, counts[0], mean[0], variance[0], shift[0] - - @handle_frontend_test( fn_tree="tensorflow.nn.normalize_moments", data=_normalize_moments_helper(), @@ -1310,72 +1529,33 @@ def test_tensorflow_normalize_moments( ) -@st.composite -def _generate_bias_data(draw): - data_format = draw(st.sampled_from(["NC...", "N...C", None])) - channel_dim = 1 if data_format == "NC..." else -1 - dtype, value, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=3, - ret_shape=True, - ) - ) - channel_size = shape[channel_dim] - bias = draw(helpers.array_values(dtype=dtype[0], shape=(channel_size,))) - return data_format, dtype, value, bias - - +# pool @handle_frontend_test( - fn_tree="tensorflow.nn.bias_add", - data=_generate_bias_data(), + fn_tree="tensorflow.nn.pool", + x_k_s_p_df=_pool_args(), test_with_out=st.just(False), ) -def test_tensorflow_bias_add( +def test_tensorflow_pool( *, - data, + x_k_s_p_df, frontend, test_flags, fn_tree, - backend_fw, on_device, -): - data_format, dtype, value, bias = data - helpers.test_frontend_function( - input_dtypes=dtype * 2, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - value=value[0], - bias=bias, - data_format=data_format, - ) - - -# convolution -@handle_frontend_test( - fn_tree="tensorflow.nn.convolution", - x_f_d_df=_x_and_filters( - dtypes=helpers.get_dtypes("float", full=False), - data_format=st.sampled_from(["NWC", "NHWC", "NDHWC"]), - padding=st.sampled_from(["SAME", "VALID"]), - dilation_max=1, - type=None, - ), - test_with_out=st.just(False), -) -def test_tensorflow_convolution( - *, - x_f_d_df, - frontend, - test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df + ( + (input_dtype, x, ksize, strides, padding, dilation), + data_format, + pooling_type, + num_dims, + ) = x_k_s_p_df + if num_dims == 3: + strides = (strides[0],) + elif num_dims == 4: + strides = (strides[0], strides[0]) + elif num_dims == 5: + strides = (strides[0], strides[0], strides[0]) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1383,9 +1563,10 @@ def test_tensorflow_convolution( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - filters=filters, - strides=stride, + input=x[0], + 
window_shape=ksize, + pooling_type=pooling_type, + strides=strides, padding=padding, data_format=data_format, dilations=dilation, @@ -1456,168 +1637,111 @@ def test_tensorflow_relu6( ) -# softmax @handle_frontend_test( - fn_tree="tensorflow.nn.softmax", - dtype_x_and_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - force_int_axis=True, - valid_axis=True, + fn_tree="tensorflow.nn.separable_conv2d", + x_f_d_df=_x_and_filters( + dtypes=helpers.get_dtypes("float", full=False), + data_format=st.sampled_from(["NHWC"]), + padding=st.sampled_from(["VALID", "SAME"]), + type="separable", ), test_with_out=st.just(False), ) -def test_tensorflow_softmax( +def test_tensorflow_separable_conv2d( *, - dtype_x_and_axis, + x_f_d_df, + frontend, test_flags, - on_device, fn_tree, - frontend, backend_fw, + on_device, ): - input_dtype, x, axis = dtype_x_and_axis + input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - logits=x[0], - axis=axis, + input=x, + depthwise_filter=filters[0], + pointwise_filter=filters[1], + strides=stride, + padding=padding, + data_format=data_format, + dilations=dilation, ) -# embedding_lookup +# sigmoid_cross_entropy_with_logits @handle_frontend_test( - fn_tree="tensorflow.nn.embedding_lookup", - dtypes_indices_weights=helpers.embedding_helper(), - max_norm=st.floats(min_value=0.1, max_value=5, exclude_min=True), + fn_tree="tensorflow.nn.sigmoid_cross_entropy_with_logits", + dtype_labels_logits=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=0, + max_value=1, + min_num_dims=1, + max_num_dims=2, + min_dim_size=1, + max_dim_size=2, + shared_dtype=True, + ), + test_with_out=st.just(False), ) -def test_tensorflow_embedding_lookup( +def test_tensorflow_sigmoid_cross_entropy_with_logits( *, - dtypes_indices_weights, - max_norm, + dtype_labels_logits, + frontend, test_flags, - on_device, fn_tree, - frontend, backend_fw, + on_device, ): - dtypes, indices, weight, _ = dtypes_indices_weights - dtypes.reverse() + input_dtype, input_values = dtype_labels_logits + labels, logits = input_values helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - params=weight, - ids=indices, - max_norm=max_norm, - atol=1e-4, + labels=labels, + logits=logits, ) -# crelu +# silu @handle_frontend_test( - fn_tree="tensorflow.nn.crelu", - dtype_x_and_axis=helpers.dtype_values_axis( + fn_tree="tensorflow.nn.silu", + dtype_features=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=4, - max_axes_size=3, - force_int_axis=True, - valid_axis=True, + min_value=0, + max_value=5, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), - test_with_out=st.just(False), -) -def test_tensorflow_crelu( - *, - dtype_x_and_axis, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - input_dtype, x, axis = dtype_x_and_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - features=x[0], - axis=axis, - ) - - -@st.composite -def _average_pool_args(draw): - dims 
= draw(st.integers(min_value=1, max_value=3)) - data_formats = ["NWC", "NHWC", "NDHWC"] - data_format = data_formats[dims - 1] - return ( - draw( - helpers.arrays_for_pooling( - min_dims=dims + 2, max_dims=dims + 2, min_side=1, max_side=4 - ) - ), - data_format, - ) - - -# average_pool -@handle_frontend_test( - fn_tree="tensorflow.nn.avg_pool", - x_k_s_p_df=_average_pool_args(), - test_with_out=st.just(False), -) -def test_tensorflow_avg_pool( - *, - x_k_s_p_df, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - (input_dtype, x, ksize, strides, padding), data_format = x_k_s_p_df - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - ksize=ksize, - strides=strides, - padding=padding, - data_format=data_format, - ) - - -@handle_frontend_test( - fn_tree="tensorflow.nn.avg_pool3d", - x_k_s_p_df=helpers.arrays_for_pooling( - min_dims=5, max_dims=5, min_side=1, max_side=4 + beta=helpers.floats( + min_value=0, + max_value=3, ), test_with_out=st.just(False), ) -def test_tensorflow_avg_pool3d( +def test_tensorflow_silu( *, - x_k_s_p_df, + dtype_features, + beta, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x, ksize, strides, padding = x_k_s_p_df + input_dtype, features = dtype_features helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1625,131 +1749,43 @@ def test_tensorflow_avg_pool3d( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - ksize=ksize, - strides=strides, - padding=padding, + atol=1e-2, + features=features[0], + beta=beta, ) -# test_avg_pool1d +# softmax @handle_frontend_test( - fn_tree="tensorflow.nn.avg_pool1d", - x_k_s_p_df=helpers.arrays_for_pooling( - min_dims=3, max_dims=3, min_side=1, max_side=4 + fn_tree="tensorflow.nn.softmax", + dtype_x_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + force_int_axis=True, + valid_axis=True, ), test_with_out=st.just(False), ) -def test_tensorflow_avg_pool1d( +def test_tensorflow_softmax( *, - x_k_s_p_df, - frontend, + dtype_x_and_axis, test_flags, - fn_tree, - backend_fw, on_device, -): - (input_dtype, x, ksize, strides, padding) = x_k_s_p_df - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - ksize=ksize, - strides=strides, - padding=padding, - ) - - -@st.composite -def _pool_args(draw): - dims = draw(st.integers(min_value=3, max_value=5)) - data_formats = {3: "NWC", 4: "NHWC", 5: "NDHWC"} - data_format = data_formats[dims] - pooling_type = draw(st.one_of(st.just("AVG"), st.just("MAX"))) - return ( - draw( - helpers.arrays_for_pooling( - min_dims=dims, - max_dims=dims, - min_side=1, - max_side=4, - return_dilation=True, - ) - ), - data_format, - pooling_type, - dims, - ) - - -# pool -@handle_frontend_test( - fn_tree="tensorflow.nn.pool", - x_k_s_p_df=_pool_args(), - test_with_out=st.just(False), -) -def test_tensorflow_pool( - *, - x_k_s_p_df, - frontend, - test_flags, fn_tree, - on_device, + frontend, backend_fw, ): - ( - (input_dtype, x, ksize, strides, padding, dilation), - data_format, - pooling_type, - num_dims, - ) = x_k_s_p_df - if num_dims == 3: - strides = (strides[0],) - elif num_dims == 4: - strides = (strides[0], strides[0]) - elif num_dims == 5: - strides = (strides[0], strides[0], 
strides[0]) + input_dtype, x, axis = dtype_x_and_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - input=x[0], - window_shape=ksize, - pooling_type=pooling_type, - strides=strides, - padding=padding, - data_format=data_format, - dilations=dilation, - ) - - -# sufficient_statistics -@st.composite -def _axes_value(draw): - s = draw( - helpers.get_shape( - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - ) - ) - dtype_and_x = draw( - helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - shape=s, - valid_axis=True, - force_tuple_axis=True, - ) + logits=x[0], + axis=axis, ) - return dtype_and_x @handle_frontend_test( @@ -1785,78 +1821,50 @@ def test_tensorflow_sufficient_statistics( ) +# weighted_cross_entropy_with_logits @handle_frontend_test( - fn_tree="tensorflow.nn.log_poisson_loss", - dtype_target_log_inputs=helpers.dtype_and_values( + fn_tree="tensorflow.nn.weighted_cross_entropy_with_logits", + dtype_labels_logits=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, min_value=0, max_value=1, min_num_dims=1, max_num_dims=3, + min_dim_size=1, + max_dim_size=3, shared_dtype=True, ), - compute_full_loss=st.booleans(), + pos_weight=st.one_of( + helpers.floats( + min_value=0, + max_value=3, + ) + ), test_with_out=st.just(False), ) -def test_tensorflow_log_poisson_loss( +def test_tensorflow_weighted_cross_entropy_with_logits( *, - dtype_target_log_inputs, - compute_full_loss, - test_flags, + dtype_labels_logits, + pos_weight, frontend, + test_flags, fn_tree, - on_device, backend_fw, -): - input_dtype, input_values = dtype_target_log_inputs - targets, log_input = input_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - targets=targets, - log_input=log_input, - compute_full_loss=compute_full_loss, - atol=1e-2, - ) - - -# ctc_unique_labels -@handle_frontend_test( - fn_tree="tensorflow.nn.ctc_unique_labels", - dtype_x=helpers.dtype_and_values( - available_dtypes=["int64", "int32"], - min_value=1, - max_value=100, - min_dim_size=1, - max_dim_size=10, - min_num_dims=2, - max_num_dims=2, - ), - test_with_out=st.just([False]), -) -def test_tensorflow_ctc_unique_labels( - *, - dtype_x, - frontend, - fn_tree, - test_flags, on_device, - backend_fw, ): - dtype, x = dtype_x + input_dtype, input_values = dtype_labels_logits + labels, logits = input_values helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - labels=x[0], + labels=labels, + logits=logits, + pos_weight=pos_weight, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_random.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_random.py index 4bd829d30f1c1..86b49d88e9fd6 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_random.py @@ -5,48 +5,82 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# random_sample +# --- Helpers --- # +# --------------- # + + +# stateless_poisson +@st.composite +def _shape_lam_dtype(draw): + dtype = draw(helpers.array_dtypes(available_dtypes=("float32", "float64"))) + common_shape = draw( + 
helpers.get_shape( + allow_none=False, + min_num_dims=2, + max_num_dims=3, + min_dim_size=1, + max_dim_size=5, + ) + ) + _, lam = draw( + helpers.dtype_and_values( + available_dtypes=dtype, min_value=0, max_value=10, shape=(common_shape[-1],) + ) + ) + return common_shape, lam, dtype + + +# --- Main --- # +# ------------ # + + +# random gamma @handle_frontend_test( - fn_tree="tensorflow.random.uniform", - shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=1, - max_value=5, + fn_tree="tensorflow.random.gamma", + dtype=helpers.array_dtypes( + available_dtypes=("float32", "float64"), + ), + shape=helpers.get_shape( + allow_none=False, min_num_dims=1, - max_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=5, + ), + alpha=st.floats( + allow_infinity=False, allow_nan=False, width=32, min_value=1, max_value=3 + ), + beta=st.floats( + allow_infinity=False, allow_nan=False, width=32, min_value=1, max_value=3 ), - minval=helpers.ints(min_value=0, max_value=3), - maxval=helpers.ints(min_value=4, max_value=10), - dtype=helpers.get_dtypes("float", full=False), seed=helpers.ints(min_value=0, max_value=10), test_with_out=st.just(False), ) -def test_tensorflow_uniform( +def test_tensorflow_gamma( + frontend, + fn_tree, + on_device, shape, - minval, - maxval, + alpha, + beta, dtype, seed, - frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtypes, shape = shape helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - shape=shape[0], - minval=minval, - maxval=maxval, + shape=shape, + alpha=alpha, + beta=beta, dtype=dtype[0], seed=seed, + test_values=False, ) @@ -94,86 +128,6 @@ def test_tensorflow_normal( ) -# random_shuffle -@handle_frontend_test( - fn_tree="tensorflow.random.shuffle", - dtype_value=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - seed=helpers.ints(min_value=0, max_value=10), - test_with_out=st.just(False), -) -def test_tensorflow_shuffle( - frontend, - fn_tree, - on_device, - dtype_value, - seed, - test_flags, - backend_fw, -): - input_dtypes, values = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - value=values[0], - seed=seed, - ) - - -# random_stateless_uniform -@handle_frontend_test( - fn_tree="tensorflow.random.stateless_uniform", - shape=helpers.dtype_and_values( - available_dtypes=("int64", "int32"), - min_value=1, - max_value=5, - min_num_dims=1, - max_num_dims=1, - max_dim_size=9, - ), - seed=helpers.dtype_and_values( - available_dtypes=("int64", "int32"), min_value=0, max_value=10, shape=[2] - ), - minmaxval=helpers.get_bounds(dtype="int32"), - dtype=helpers.array_dtypes( - available_dtypes=("int32", "int64", "float16", "float32", "float64"), - ), - test_with_out=st.just(False), -) -def test_tensorflow_stateless_uniform( - shape, - seed, - minmaxval, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - shape_input_dtypes, shape = shape - seed_input_dtypes, seed = seed - - helpers.test_frontend_function( - input_dtypes=shape_input_dtypes + seed_input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - shape=shape[0], - seed=seed[0], - 
minval=int(minmaxval[0]), - maxval=int(minmaxval[1]), - dtype=dtype[0], - ) - - # random poisson @handle_frontend_test( fn_tree="tensorflow.random.poisson", @@ -223,6 +177,36 @@ def test_tensorflow_poisson( ) +# random_shuffle +@handle_frontend_test( + fn_tree="tensorflow.random.shuffle", + dtype_value=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + seed=helpers.ints(min_value=0, max_value=10), + test_with_out=st.just(False), +) +def test_tensorflow_shuffle( + frontend, + fn_tree, + on_device, + dtype_value, + seed, + test_flags, + backend_fw, +): + input_dtypes, values = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + value=values[0], + seed=seed, + ) + + # stateless_normal @handle_frontend_test( fn_tree="tensorflow.random.stateless_normal", @@ -270,27 +254,6 @@ def test_tensorflow_stateless_normal( ) -# stateless_poisson -@st.composite -def _shape_lam_dtype(draw): - dtype = draw(helpers.array_dtypes(available_dtypes=("float32", "float64"))) - common_shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=2, - max_num_dims=3, - min_dim_size=1, - max_dim_size=5, - ) - ) - _, lam = draw( - helpers.dtype_and_values( - available_dtypes=dtype, min_value=0, max_value=10, shape=(common_shape[-1],) - ) - ) - return common_shape, lam, dtype - - @handle_frontend_test( fn_tree="tensorflow.random.stateless_poisson", shape_lam_dtype=_shape_lam_dtype(), @@ -325,51 +288,96 @@ def test_tensorflow_stateless_poisson( ) -# random gamma +# random_stateless_uniform @handle_frontend_test( - fn_tree="tensorflow.random.gamma", - dtype=helpers.array_dtypes( - available_dtypes=("float32", "float64"), - ), - shape=helpers.get_shape( - allow_none=False, + fn_tree="tensorflow.random.stateless_uniform", + shape=helpers.dtype_and_values( + available_dtypes=("int64", "int32"), + min_value=1, + max_value=5, min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=5, + max_num_dims=1, + max_dim_size=9, ), - alpha=st.floats( - allow_infinity=False, allow_nan=False, width=32, min_value=1, max_value=3 + seed=helpers.dtype_and_values( + available_dtypes=("int64", "int32"), min_value=0, max_value=10, shape=[2] ), - beta=st.floats( - allow_infinity=False, allow_nan=False, width=32, min_value=1, max_value=3 + minmaxval=helpers.get_bounds(dtype="int32"), + dtype=helpers.array_dtypes( + available_dtypes=("int32", "int64", "float16", "float32", "float64"), ), - seed=helpers.ints(min_value=0, max_value=10), test_with_out=st.just(False), ) -def test_tensorflow_gamma( +def test_tensorflow_stateless_uniform( + shape, + seed, + minmaxval, + dtype, frontend, + test_flags, fn_tree, + backend_fw, on_device, +): + shape_input_dtypes, shape = shape + seed_input_dtypes, seed = seed + + helpers.test_frontend_function( + input_dtypes=shape_input_dtypes + seed_input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + shape=shape[0], + seed=seed[0], + minval=int(minmaxval[0]), + maxval=int(minmaxval[1]), + dtype=dtype[0], + ) + + +# random_sample +@handle_frontend_test( + fn_tree="tensorflow.random.uniform", + shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=1, + max_value=5, + min_num_dims=1, + max_num_dims=1, + ), + minval=helpers.ints(min_value=0, max_value=3), + 
maxval=helpers.ints(min_value=4, max_value=10), + dtype=helpers.get_dtypes("float", full=False), + seed=helpers.ints(min_value=0, max_value=10), + test_with_out=st.just(False), +) +def test_tensorflow_uniform( shape, - alpha, - beta, + minval, + maxval, dtype, seed, + frontend, test_flags, + fn_tree, backend_fw, + on_device, ): + input_dtypes, shape = shape helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - shape=shape, - alpha=alpha, - beta=beta, + test_values=False, + shape=shape[0], + minval=minval, + maxval=maxval, dtype=dtype[0], seed=seed, - test_values=False, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py index 7482ceae775ab..bf328d73ea81a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py @@ -17,126 +17,313 @@ ) -# Acos -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Acos", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), - test_with_out=st.just(False), -) -def test_tensorflow_Acos( # NOQA - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], +# for data generation +dtype_shared = st.shared(st.sampled_from(helpers.get_dtypes("numeric")), key="dtype") + + +# --- Helpers --- # +# --------------- # + + +@st.composite +def _LinSpace_helper(draw): + shape = () + dtype = draw(st.sampled_from(["float32", "float64"])) + + # Param: start + start = draw( + helpers.array_values( + dtype=dtype, + shape=shape, + min_value=-5.0, + max_value=5.0, + ), + ) + + # Param: stop + stop = draw( + helpers.array_values( + dtype=dtype, + shape=shape, + min_value=-4.0, + max_value=10.0, + ), ) + return [dtype] * 2, start, stop -# Acosh + +# noinspection DuplicatedCode +@st.composite +def _arrays_idx_n_dtypes(draw): + num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) + num_arrays = draw( + st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") + ) + common_shape = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_dims - 1, + ) + ) + unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) + unique_dims = draw( + helpers.list_of_size( + x=helpers.ints(min_value=2, max_value=3), + size=num_arrays, + ) + ) + xs = list() + input_dtypes = draw( + helpers.array_dtypes( + available_dtypes=draw(helpers.get_dtypes("float")), shared_dtype=True + ) + ) + for ud, dt in zip(unique_dims, input_dtypes): + x = draw( + helpers.array_values( + shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], + dtype=dt, + ) + ) + xs.append(x) + return xs, input_dtypes, unique_idx + + +@st.composite +def _dtypes(draw): + return draw( + st.shared( + helpers.list_of_size( + x=st.sampled_from(draw(helpers.get_dtypes("numeric"))), + size=1, + ), + key="dtype", + ) + ) + + +@st.composite +def _fill_value(draw): + dtype = draw(_dtypes())[0] + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + if ivy_backend.is_uint_dtype(dtype): + return draw(helpers.ints(min_value=0, 
max_value=5)) + elif ivy_backend.is_int_dtype(dtype): + return draw(helpers.ints(min_value=-5, max_value=5)) + return draw(helpers.floats(min_value=-5, max_value=5)) + + +@st.composite +def _get_shared_dtype(draw): + return st.shared(st.sampled_from(draw(helpers.get_dtypes("numeric"))), key="dtype") + + +@st.composite +def _get_splits(draw, as_list=False): + """Generate valid splits, either by generating an integer that evenly divides the + axis or a list of splits that sum to the length of the axis being split.""" + shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="value_shape")) + axis = draw( + st.shared(helpers.get_axis(shape=shape, force_int=True), key="target_axis") + ) + + @st.composite + def get_int_split(draw): + if shape[axis] == 0: + return 0 + factors = [] + for i in range(1, shape[axis] + 1): + if shape[axis] % i == 0: + factors.append(i) + return draw(st.sampled_from(factors)) + + @st.composite + def get_list_split(draw): + num_or_size_splits = [] + while sum(num_or_size_splits) < shape[axis]: + split_value = draw( + helpers.ints( + min_value=1, + max_value=shape[axis] - sum(num_or_size_splits), + ) + ) + num_or_size_splits.append(split_value) + return num_or_size_splits + + if as_list: + return draw(get_list_split()) + else: + return draw(get_int_split()) + + +@st.composite +def _pad_helper(draw, return_constant_values=False): + dtype, input, shape = draw( + helpers.dtype_and_values( + min_num_dims=1, + ret_shape=True, + ) + ) + ndim = len(shape) + padding_dtype, paddings = draw( + helpers.dtype_and_values( + available_dtypes=["int32", "int64"], + shape=(ndim, 2), + min_value=0, + max_value=10, + ) + ) + + if return_constant_values: + _, constant_values = draw( + helpers.dtype_and_values( + dtype=dtype, + shape=(1,), + ) + ) + return dtype, input[0], padding_dtype, paddings[0], constant_values[0][0] + + return dtype, input[0], padding_dtype, paddings[0] + + +@st.composite +def _permute_dims_helper(draw): + shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="shape")) + dims = [x for x in range(len(shape))] + permutation = draw(st.permutations(dims)) + return permutation + + +@st.composite +def _pow_helper_shared_dtype(draw): + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + ) + ) + dtype1, dtype2 = dtype + x1, x2 = x + if "int" in dtype2: + x2 = ivy.nested_map(x2, lambda x: abs(x), include_derived={list: True}) + + if ivy.is_int_dtype(dtype2): + max_val = ivy.iinfo(dtype2).max + else: + max_val = ivy.finfo(dtype2).max + max_x1 = np.max(np.abs(x1)) + if max_x1 in [0, 1]: + max_value = None + else: + max_value = int(math.log(max_val) / math.log(max_x1)) + if abs(max_value) > abs(max_val) / 40 or max_value < 0: + max_value = None + + return [dtype1, dtype2], [x1, x2] + + +# Reshape +@st.composite +def _reshape_helper(draw): + # generate a shape s.t len(shape) > 0 + shape = draw(helpers.get_shape(min_num_dims=1)) + reshape_shape = draw(helpers.reshape_shapes(shape=shape)) + dtype = draw(helpers.array_dtypes(num_arrays=1)) + x = draw(helpers.array_values(dtype=dtype[0], shape=shape)) + return x, dtype, reshape_shape + + +@st.composite +def _squeeze_helper(draw): + shape = draw(st.shared(helpers.get_shape(), key="value_shape")) + valid_axes = [] + for index, axis in enumerate(shape): + if axis == 1: + valid_axes.append(index) + valid_axes.insert(0, None) + axis = draw(st.sampled_from(valid_axes)) + return [axis] if axis is not None else axis + + +# Reverse +@st.composite +def 
reverse_helper(draw): + dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=8, + ret_shape=True, + ) + ) + axis_dtype, axis = draw( + helpers.dtype_and_values( + available_dtypes=["bool"], + min_num_dims=1, + max_num_dims=1, + num_arrays=1, + shape=(len(shape),), + ) + ) + return dtype, x, axis_dtype, axis + + +# --- Main --- # +# ------------ # + + +# Todo: Revise strategies once reimplemented in frontend +# AccumulateNV2 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Acosh", + fn_tree="tensorflow.raw_ops.AccumulateNV2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), + shape=helpers.get_shape(min_num_dims=1), ) -def test_tensorflow_Acosh( # NOQA - *, +def test_tensorflow_AccumulateNV2( dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, + shape, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=x[0], - ) - - -# Angle -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Angle", - dtype_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("complex"), - ), - Tout=st.booleans(), - test_with_out=st.just(False), -) -def test_tensorflow_Angle( # NOQA - *, - dtype_and_xs, - Tout, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, xs = dtype_and_xs - if input_dtype[0] == "complex128": - Tout = "float64" - elif input_dtype[0] == "complex64": - Tout = "float32" if Tout else None - - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=xs[0], - Tout=Tout, + inputs=x[0], + shape=shape, ) -# ApproximateEqual +# Acos @handle_frontend_test( - fn_tree="tensorflow.raw_ops.ApproximateEqual", + fn_tree="tensorflow.raw_ops.Acos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - large_abs_safety_factor=20, - small_abs_safety_factor=20, - safety_factor_scale="log", ), - tol=st.floats(1e-05, 1e-03), test_with_out=st.just(False), ) -def test_tensorflow_ApproximateEqual( # NOQA +def test_tensorflow_Acos( # NOQA *, dtype_and_x, - tol, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -144,23 +331,19 @@ def test_tensorflow_ApproximateEqual( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], - tolerance=tol, + x=x[0], ) -# AddV2 +# Acosh @handle_frontend_test( - fn_tree="tensorflow.raw_ops.AddV2", + fn_tree="tensorflow.raw_ops.Acosh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_AddV2( # NOQA +def test_tensorflow_Acosh( # NOQA *, dtype_and_x, frontend, @@ -178,7 +361,6 @@ def test_tensorflow_AddV2( # NOQA fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) @@ -192,137 +374,41 @@ def test_tensorflow_AddV2( # NOQA ) def test_tensorflow_Add( # NOQA *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, xs = dtype_and_x - 
helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=xs[0], - y=xs[1], - ) - - -# for data generation -dtype_shared = st.shared(st.sampled_from(helpers.get_dtypes("numeric")), key="dtype") - - -@st.composite -def _get_shared_dtype(draw): - return st.shared(st.sampled_from(draw(helpers.get_dtypes("numeric"))), key="dtype") - - -# BroadcastTo -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.BroadcastTo", - array_and_shape=helpers.array_and_broadcastable_shape(_get_shared_dtype()), - test_with_out=st.just(False), -) -def test_tensorflow_BroadcastTo( # NOQA - *, - array_and_shape, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - x, to_shape = array_and_shape - helpers.test_frontend_function( - input_dtypes=[x.dtype], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x, - shape=to_shape, - ) - - -# noinspection DuplicatedCode -@st.composite -def _arrays_idx_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) - num_arrays = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") - ) - common_shape = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_dims - 1, - ) - ) - unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) - unique_dims = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_arrays, - ) - ) - xs = list() - input_dtypes = draw( - helpers.array_dtypes( - available_dtypes=draw(helpers.get_dtypes("float")), shared_dtype=True - ) - ) - for ud, dt in zip(unique_dims, input_dtypes): - x = draw( - helpers.array_values( - shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], - dtype=dt, - ) - ) - xs.append(x) - return xs, input_dtypes, unique_idx - - -# Concat -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Concat", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), - test_with_out=st.just(False), -) -def test_tensorflow_Concat( # NOQA - *, - xs_n_input_dtypes_n_unique_idx, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx + dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - concat_dim=unique_idx, - values=xs, + x=xs[0], + y=xs[1], ) -# Cos +# AddN @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Cos", + fn_tree="tensorflow.raw_ops.AddN", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + large_abs_safety_factor=8, + small_abs_safety_factor=8, + safety_factor_scale="log", + min_value=-1e04, + max_value=1e04, ), test_with_out=st.just(False), ) -def test_tensorflow_Cos( # NOQA +def test_tensorflow_AddN( # NOQA *, dtype_and_x, frontend, @@ -339,26 +425,21 @@ def test_tensorflow_Cos( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + inputs=x[0], ) -# Cross +# AddV2 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Cross", + fn_tree="tensorflow.raw_ops.AddV2", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=3, - max_dim_size=3, 
- safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_Cross( # NOQA +def test_tensorflow_AddV2( # NOQA *, dtype_and_x, frontend, @@ -367,37 +448,44 @@ def test_tensorflow_Cross( # NOQA backend_fw, on_device, ): - dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=xs[0], - b=xs[1], + x=x[0], + y=x[1], ) -# Rsqrt +# Angle @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Rsqrt", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.raw_ops.Angle", + dtype_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("complex"), ), + Tout=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Rsqrt( +def test_tensorflow_Angle( # NOQA *, - dtype_and_x, + dtype_and_xs, + Tout, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, xs = dtype_and_xs + if input_dtype[0] == "complex128": + Tout = "float64" + elif input_dtype[0] == "complex64": + Tout = "float32" if Tout else None + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -405,28 +493,36 @@ def test_tensorflow_Rsqrt( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=xs[0], + Tout=Tout, ) -# Cosh +# ApproximateEqual @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Cosh", + fn_tree="tensorflow.raw_ops.ApproximateEqual", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=20, + small_abs_safety_factor=20, + safety_factor_scale="log", ), + tol=st.floats(1e-05, 1e-03), test_with_out=st.just(False), ) -def test_tensorflow_Cosh( +def test_tensorflow_ApproximateEqual( # NOQA *, dtype_and_x, + tol, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -434,41 +530,38 @@ def test_tensorflow_Cosh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - ) - - -@st.composite -def _dtypes(draw): - return draw( - st.shared( - helpers.list_of_size( - x=st.sampled_from(draw(helpers.get_dtypes("numeric"))), - size=1, - ), - key="dtype", - ) + x=xs[0], + y=xs[1], + tolerance=tol, ) -# Div +# argmax @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Div", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + fn_tree="tensorflow.raw_ops.ArgMax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, + allow_inf=False, ), + output_type=st.sampled_from(["int16", "int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_Div( # NOQA +def test_tensorflow_ArgMax( # NOQA *, - dtype_and_x, + dtype_x_axis, + output_type, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, xs = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -476,55 +569,48 @@ def test_tensorflow_Div( # NOQA test_flags=test_flags, 
fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + input=x[0], + dimension=axis, + output_type=output_type, ) -@st.composite -def _fill_value(draw): - dtype = draw(_dtypes())[0] - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - if ivy_backend.is_uint_dtype(dtype): - return draw(helpers.ints(min_value=0, max_value=5)) - elif ivy_backend.is_int_dtype(dtype): - return draw(helpers.ints(min_value=-5, max_value=5)) - return draw(helpers.floats(min_value=-5, max_value=5)) - - -# fill +# ArgMin @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Fill", - shape=helpers.get_shape( - allow_none=False, + fn_tree="tensorflow.raw_ops.ArgMin", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, min_num_dims=1, - min_dim_size=1, + min_value=-5, + max_value=5, + allow_inf=False, ), - fill_value=_fill_value(), - dtypes=_dtypes(), + output_type=st.sampled_from(["int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_Fill( # NOQA +def test_tensorflow_ArgMin( # NOQA *, - shape, - fill_value, - dtypes, + dtype_x_axis, + output_type, frontend, test_flags, fn_tree, backend_fw, on_device, ): + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-05, - dims=shape, - value=fill_value, + input=x[0], + dimension=axis, + output_type=output_type, ) @@ -557,32 +643,87 @@ def test_tensorflow_Asin( # NOQA ) -# argmax +# Atan @handle_frontend_test( - fn_tree="tensorflow.raw_ops.ArgMax", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, - allow_inf=False, + fn_tree="tensorflow.raw_ops.Atan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), - output_type=st.sampled_from(["int16", "int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_ArgMax( # NOQA +def test_tensorflow_Atan( # NOQA *, - dtype_x_axis, - output_type, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# Atan2 +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.Atan2", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + ), + test_with_out=st.just(False), +) +def test_tensorflow_Atan2( # NOQA + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + on_device, + backend_fw, +): + input_dtype, xs = dtype_and_x + + # Assuming x and y have the same shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + y=xs[0], + x=xs[1], + ) + + +# Atanh +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.Atanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), + test_with_out=st.just(False), +) +def test_tensorflow_Atanh( # NOQA + *, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis = dtype_x_axis + dtype, x = dtype_and_x 
helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -590,149 +731,164 @@ def test_tensorflow_ArgMax( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - dimension=axis, - output_type=output_type, + x=x[0], ) -# ArgMin +# Todo: Revise strategies once reimplemented in frontend +# BandedTriangularSolve @handle_frontend_test( - fn_tree="tensorflow.raw_ops.ArgMin", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, - allow_inf=False, + fn_tree="tensorflow.raw_ops.BandedTriangularSolve", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), - output_type=st.sampled_from(["int32", "int64"]), test_with_out=st.just(False), + lower=st.booleans(), + adjoint=st.booleans(), ) -def test_tensorflow_ArgMin( # NOQA - *, - dtype_x_axis, - output_type, +def test_tensorflow_BandedTriangularSolve( + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, + lower, + adjoint, ): - dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - input=x[0], - dimension=axis, - output_type=output_type, + matrix=x[0], + rhs=x[1], + lower=lower, + adjoint=adjoint, ) -# Atan +# Todo: Revise strategies once reimplemented in frontend +# BatchMatMul @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Atan", + fn_tree="tensorflow.raw_ops.BatchMatMul", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), test_with_out=st.just(False), + adj_x=st.booleans(), + adj_y=st.booleans(), ) -def test_tensorflow_Atan( # NOQA - *, +def test_tensorflow_BatchMatMul( dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, + adj_x, + adj_y, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, x=x[0], + y=x[1], + adj_x=adj_x, + adj_y=adj_y, ) -# Atan2 +# Todo: Revise strategies once reimplemented in frontend +# BatchMatMulV2 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Atan2", + fn_tree="tensorflow.raw_ops.BatchMatMulV2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - shared_dtype=True, ), test_with_out=st.just(False), + adj_x=st.booleans(), + adj_y=st.booleans(), ) -def test_tensorflow_Atan2( # NOQA - *, +def test_tensorflow_BatchMatMulV2( dtype_and_x, frontend, test_flags, fn_tree, - on_device, backend_fw, + on_device, + adj_x, + adj_y, ): - input_dtype, xs = dtype_and_x - - # Assuming x and y have the same shape + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - y=xs[0], - x=xs[1], + x=x[0], + y=x[1], + adj_x=adj_x, + adj_y=adj_y, ) -# BitwiseAnd +# Todo: Revise strategies once reimplemented in frontend +# BatchMatMulV3 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.BitwiseAnd", + fn_tree="tensorflow.raw_ops.BatchMatMulV3", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - shared_dtype=True, ), test_with_out=st.just(False), + Tout=st.sampled_from(["float32", "float64"]), + adj_x=st.booleans(), + adj_y=st.booleans(), ) -def test_tensorflow_BitwiseAnd( # NOQA - *, +def test_tensorflow_BatchMatMulV3( dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, + Tout, + adj_x, + adj_y, ): - input_dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], + y=x[1], + Tout=Tout, + adj_x=adj_x, + adj_y=adj_y, ) -# BitwiseOr +# BitwiseAnd @handle_frontend_test( - fn_tree="tensorflow.raw_ops.BitwiseOr", + fn_tree="tensorflow.raw_ops.BitwiseAnd", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, @@ -740,7 +896,7 @@ def test_tensorflow_BitwiseAnd( # NOQA ), test_with_out=st.just(False), ) -def test_tensorflow_BitwiseOr( # NOQA +def test_tensorflow_BitwiseAnd( # NOQA *, dtype_and_x, frontend, @@ -762,9 +918,9 @@ def test_tensorflow_BitwiseOr( # NOQA ) -# BitwiseXor +# BitwiseOr @handle_frontend_test( - fn_tree="tensorflow.raw_ops.BitwiseXor", + fn_tree="tensorflow.raw_ops.BitwiseOr", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, @@ -772,7 +928,7 @@ def test_tensorflow_BitwiseOr( # NOQA ), test_with_out=st.just(False), ) -def test_tensorflow_BitwiseXor( # NOQA +def test_tensorflow_BitwiseOr( # NOQA *, dtype_and_x, frontend, @@ -794,15 +950,17 @@ def test_tensorflow_BitwiseXor( # NOQA ) -# Atanh +# BitwiseXor @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Atanh", + fn_tree="tensorflow.raw_ops.BitwiseXor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_Atanh( # NOQA +def test_tensorflow_BitwiseXor( # NOQA *, dtype_and_x, frontend, @@ -811,56 +969,55 @@ def test_tensorflow_Atanh( # NOQA backend_fw, on_device, ): - dtype, x = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x=xs[0], + y=xs[1], ) -# Tan +# BroadcastTo @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Tan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="tensorflow.raw_ops.BroadcastTo", + array_and_shape=helpers.array_and_broadcastable_shape(_get_shared_dtype()), test_with_out=st.just(False), ) -def test_tensorflow_Tan( # NOQA +def test_tensorflow_BroadcastTo( # NOQA *, - dtype_and_x, + array_and_shape, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + x, to_shape = array_and_shape helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[x.dtype], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x, + shape=to_shape, ) -# Square @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Square", + fn_tree="tensorflow.raw_ops.Ceil", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_Square( # NOQA +def test_tensorflow_Ceil( # NOQA *, dtype_and_x, frontend, @@ -882,24 +1039,30 @@ def test_tensorflow_Square( # NOQA @handle_frontend_test( - fn_tree="tensorflow.raw_ops.SquaredDifference", + fn_tree="tensorflow.raw_ops.Cholesky", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), ), test_with_out=st.just(False), ) -def test_tensorflow_SquaredDifference( +def test_tensorflow_Cholesky( # NOQA *, dtype_and_x, frontend, test_flags, fn_tree, - on_device, backend_fw, + on_device, ): dtype, x = dtype_and_x + x = x[0] + x = ( + np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + ) # make symmetric positive-definite + helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -907,247 +1070,214 @@ def test_tensorflow_SquaredDifference( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + input=x, + rtol=1e-4, + atol=1e-4, ) -@st.composite -def _squeeze_helper(draw): - shape = draw(st.shared(helpers.get_shape(), key="value_shape")) - valid_axes = [] - for index, axis in enumerate(shape): - if axis == 1: - valid_axes.append(index) - valid_axes.insert(0, None) - axis = draw(st.sampled_from(valid_axes)) - return [axis] if axis is not None else axis - - -# Squeeze +# Todo: Revise strategies once reimplemented in frontend +# Complex @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Squeeze", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), + fn_tree="tensorflow.raw_ops.Complex", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), - axis=_squeeze_helper(), test_with_out=st.just(False), + Tout=st.sampled_from(["complex64", "complex128"]), ) -def test_tensorflow_Squeeze( # NOQA - dtype_value, - axis, +def test_tensorflow_Complex( + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, + Tout, ): - dtype, xs = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - input=xs[0], - axis=axis, - ) - - -# Sign -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Sign", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=5, - small_abs_safety_factor=5, - safety_factor_scale="log", - ), + real=x[0], + imag=x[1], + Tout=Tout, + ) + + +# Concat +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.Concat", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), test_with_out=st.just(False), ) -def test_tensorflow_Sign( # NOQA +def test_tensorflow_Concat( # NOQA *, - dtype_and_x, + xs_n_input_dtypes_n_unique_idx, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, 
on_device=on_device, - x=x[0], - ) - - -@st.composite -def _get_splits(draw, as_list=False): - """Generate valid splits, either by generating an integer that evenly divides the - axis or a list of splits that sum to the length of the axis being split.""" - shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="value_shape")) - axis = draw( - st.shared(helpers.get_axis(shape=shape, force_int=True), key="target_axis") + concat_dim=unique_idx, + values=xs, ) - @st.composite - def get_int_split(draw): - if shape[axis] == 0: - return 0 - factors = [] - for i in range(1, shape[axis] + 1): - if shape[axis] % i == 0: - factors.append(i) - return draw(st.sampled_from(factors)) - - @st.composite - def get_list_split(draw): - num_or_size_splits = [] - while sum(num_or_size_splits) < shape[axis]: - split_value = draw( - helpers.ints( - min_value=1, - max_value=shape[axis] - sum(num_or_size_splits), - ) - ) - num_or_size_splits.append(split_value) - return num_or_size_splits - - if as_list: - return draw(get_list_split()) - else: - return draw(get_int_split()) - -# Split +# ConcatV2 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Split", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - axis=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", - ), - num_splits=_get_splits(), + fn_tree="tensorflow.raw_ops.ConcatV2", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), + test_with_out=st.just(False), + number_positional_args=st.just(0), ) -def test_tensorflow_Split( # NOQA - *, - dtype_and_x, - axis, - num_splits, - frontend, +def test_tensorflow_ConcatV2( + xs_n_input_dtypes_n_unique_idx, test_flags, - fn_tree, + frontend, backend_fw, - on_device, + fn_tree, ): - dtype, value = dtype_and_x + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtypes, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, - on_device=on_device, - value=value[0], - axis=axis, - num_split=num_splits, + values=xs, + axis=unique_idx, ) -# SplitV +# Conv2D @handle_frontend_test( - fn_tree="tensorflow.raw_ops.SplitV", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - axis=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", + fn_tree="tensorflow.raw_ops.Conv2D", + x_f_d_df=_x_and_filters( + dtypes=helpers.get_dtypes("float", full=False), + data_format=st.sampled_from(["NHWC"]), + padding=st.sampled_from(["SAME", "VALID", "EXPLICIT"]), + type="2d", + dilation_min=1, + dilation_max=1, ), - size_splits=_get_splits(as_list=True), test_with_out=st.just(False), + number_positional_args=st.just(0), ) -def test_tensorflow_SplitV( # NOQA +def test_tensorflow_Conv2D( *, - dtype_and_x, - axis, - size_splits, - frontend, + x_f_d_df, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): - dtype, value = dtype_and_x + input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df + channel_index = data_format.find("C") + stride = _convolution_broadcast_helper( + stride, num_spatial_dims=2, channel_index=channel_index, name="strides" + ) + 
dilation = _convolution_broadcast_helper( + dilation, num_spatial_dims=2, channel_index=channel_index, name="dilations" + ) + explicit_padding = None + if isinstance(padding, list): + explicit_padding = padding + padding = "EXPLICIT" + helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - value=value[0], - axis=axis, - size_splits=size_splits, - num_split=len(size_splits), + input=x, + filter=filters, + strides=stride, + padding=padding, + explicit_paddings=explicit_padding, + data_format=data_format, + dilations=dilation, + use_cudnn_on_gpu=True, ) -# Sqrt +# Conv3D @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Sqrt", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.raw_ops.Conv3D", + x_f_d_df=_x_and_filters( + dtypes=helpers.get_dtypes("float", full=False), + data_format=st.sampled_from(["NDHWC"]), + padding=st.sampled_from(["SAME", "VALID"]), + type="3d", + # Tensorflow backprop doesn't support dilations more than 1 on CPU + dilation_min=1, + dilation_max=1, ), test_with_out=st.just(False), + number_positional_args=st.just(0), ) -def test_tensorflow_Sqrt( # NOQA +def test_tensorflow_Conv3D( *, - dtype_and_x, - frontend, + x_f_d_df, test_flags, - fn_tree, + frontend, backend_fw, + fn_tree, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df + + # Broadcast stirdes and dilations to correct dims for the ground truth + # backend func to run correctly + stride = _convolution_broadcast_helper( + stride, num_spatial_dims=3, channel_index=4, name="strides" + ) + dilation = _convolution_broadcast_helper( + dilation, num_spatial_dims=3, channel_index=4, name="dilations" + ) + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x, + filter=filters, + strides=stride, + padding=padding, + data_format=data_format, + dilations=dilation, ) -# Tanh +# Cos @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Tanh", + fn_tree="tensorflow.raw_ops.Cos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_Tanh( # NOQA +def test_tensorflow_Cos( # NOQA *, dtype_and_x, frontend, @@ -1168,15 +1298,15 @@ def test_tensorflow_Tanh( # NOQA ) -# TanhGrad +# Cosh @handle_frontend_test( - fn_tree="tensorflow.raw_ops.TanhGrad", + fn_tree="tensorflow.raw_ops.Cosh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_TanhGrad( # NOQA +def test_tensorflow_Cosh( *, dtype_and_x, frontend, @@ -1185,166 +1315,187 @@ def test_tensorflow_TanhGrad( # NOQA backend_fw, on_device, ): - dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - y=xs[0], - dy=xs[1], + x=x[0], ) -@st.composite -def _permute_dims_helper(draw): - shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="shape")) - dims = [x for x in range(len(shape))] - permutation = 
draw(st.permutations(dims)) - return permutation - - -# Transpose +# Cross @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Transpose", + fn_tree="tensorflow.raw_ops.Cross", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=3, + max_dim_size=3, + safety_factor_scale="log", + num_arrays=2, + shared_dtype=True, ), - perm=_permute_dims_helper(), test_with_out=st.just(False), ) -def test_tensorflow_transpose( # NOQA +def test_tensorflow_Cross( # NOQA *, dtype_and_x, - perm, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + dtype, xs = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x=x[0], - perm=perm, + on_device=on_device, + a=xs[0], + b=xs[1], ) -# Maximum +# Cumprod @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Maximum", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + fn_tree="tensorflow.raw_ops.Cumprod", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), + exclusive=st.booleans(), + reverse=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Maximum( # NOQA +def test_tensorflow_Cumprod( # NOQA *, - dtype_and_x, + dtype_x_axis, + exclusive, + reverse, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], + axis=axis, + exclusive=exclusive, + reverse=reverse, ) -# Minimum +# Cumsum @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Minimum", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + fn_tree="tensorflow.raw_ops.Cumsum", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), + exclusive=st.booleans(), + reverse=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Minimum( # NOQA +def test_tensorflow_Cumsum( # NOQA *, - dtype_and_x, + dtype_x_axis, + exclusive, + reverse, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + rtol=1e-02, + atol=1e-02, + x=x[0], + axis=axis, + exclusive=exclusive, + reverse=reverse, ) -# Sub +# Todo: Revise strategies once reimplemented in frontend +# CumulativeLogsumexp @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Sub", + fn_tree="tensorflow.raw_ops.CumulativeLogsumexp", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("float"), ), + axis=st.just(0), test_with_out=st.just(False), + exclusive=st.booleans(), + 
reverse=st.booleans(), ) -def test_tensorflow_Sub( # NOQA - *, +def test_tensorflow_CumulativeLogsumexp( dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, + axis, + exclusive, + reverse, ): - dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], + axis=axis, + exclusive=exclusive, + reverse=reverse, ) -# Less +# Todo: Revise strategies once reimplemented in frontend +# DebugGradientIdentity @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Less", + fn_tree="tensorflow.raw_ops.DebugGradientIdentity", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_Less( # NOQA - *, +def test_tensorflow_DebugGradientIdentity( dtype_and_x, frontend, test_flags, @@ -1352,7 +1503,7 @@ def test_tensorflow_Less( # NOQA backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1360,22 +1511,27 @@ def test_tensorflow_Less( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + input=x[0], ) -# LessEqual @handle_frontend_test( - fn_tree="tensorflow.raw_ops.LessEqual", + fn_tree="tensorflow.raw_ops.Diag", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=[ + "float32", + "float64", + "int32", + "int64", + ], + min_num_dims=1, + max_num_dims=1, + min_value=-1e30, + max_value=1e30, ), test_with_out=st.just(False), ) -def test_tensorflow_LessEqual( # NOQA +def test_tensorflow_Diag( # NOQA *, dtype_and_x, frontend, @@ -1384,28 +1540,27 @@ def test_tensorflow_LessEqual( # NOQA backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + diagonal=x[0], ) -# Floor +# Div @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Floor", + fn_tree="tensorflow.raw_ops.Div", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), test_with_out=st.just(False), ) -def test_tensorflow_Floor( # NOQA +def test_tensorflow_Div( # NOQA *, dtype_and_x, frontend, @@ -1414,61 +1569,69 @@ def test_tensorflow_Floor( # NOQA backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + x=xs[0], + y=xs[1], ) -# FloorDiv +# Elu @handle_frontend_test( - fn_tree="tensorflow.raw_ops.FloorDiv", + fn_tree="tensorflow.raw_ops.Elu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), + min_value=-3, + max_value=3, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), + 
name=st.just(None), test_with_out=st.just(False), + number_positional_args=st.just(0), ) -def test_tensorflow_FloorDiv( # NOQA +def test_tensorflow_Elu( *, dtype_and_x, + name, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + features=x[0], + name=name, ) -# FloorMod +# Equal @handle_frontend_test( - fn_tree="tensorflow.raw_ops.FloorMod", + fn_tree="tensorflow.raw_ops.Equal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_FloorMod( # NOQA +def test_tensorflow_Equal( # NOQA *, dtype_and_x, frontend, @@ -1477,7 +1640,7 @@ def test_tensorflow_FloorMod( # NOQA backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1485,34 +1648,39 @@ def test_tensorflow_FloorMod( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], + y=x[1], ) -# FFT +# EuclideanNorm @handle_frontend_test( - fn_tree="tensorflow.raw_ops.FFT", - dtype_and_x=helpers.dtype_and_values( - min_num_dims=1, - min_dim_size=2, - large_abs_safety_factor=15, - small_abs_safety_factor=15, - safety_factor_scale="log", - available_dtypes=helpers.get_dtypes("complex"), + fn_tree="tensorflow.raw_ops.EuclideanNorm", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + max_num_dims=5, + min_dim_size=1, + max_dim_size=4, + min_axis=-3, + max_axis=2, + valid_axis=True, + allow_neg_axes=True, ), + keep_dims=st.booleans(), test_with_out=st.just(False), + number_positional_args=st.just(0), ) -def test_tensorflow_FFT( # NOQA - *, - dtype_and_x, +def test_tensorflow_EuclideanNorm( + dtype_values_axis, + keep_dims, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + dtype, values, axis = dtype_values_axis helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -1520,9 +1688,9 @@ def test_tensorflow_FFT( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - rtol=1e-02, - atol=1e-02, + input=values[0], + axis=axis, + keep_dims=keep_dims, ) @@ -1584,15 +1752,20 @@ def test_tensorflow_Expm1( # NOQA ) -# Log +# FFT @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Log", + fn_tree="tensorflow.raw_ops.FFT", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + min_dim_size=2, + large_abs_safety_factor=15, + small_abs_safety_factor=15, + safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("complex"), ), test_with_out=st.just(False), ) -def test_tensorflow_Log( # NOQA +def test_tensorflow_FFT( # NOQA *, dtype_and_x, frontend, @@ -1601,55 +1774,65 @@ def test_tensorflow_Log( # NOQA backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x[0], + rtol=1e-02, + atol=1e-02, ) -# Log1p +# fill 
@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Log1p", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - safety_factor_scale="log", + fn_tree="tensorflow.raw_ops.Fill", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + min_dim_size=1, ), + fill_value=_fill_value(), + dtypes=_dtypes(), test_with_out=st.just(False), ) -def test_tensorflow_Log1p( # NOQA +def test_tensorflow_Fill( # NOQA *, - dtype_and_x, + shape, + fill_value, + dtypes, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + rtol=1e-05, + dims=shape, + value=fill_value, ) -# Sinh +# Floor @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Sinh", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.raw_ops.Floor", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), test_with_out=st.just(False), ) -def test_tensorflow_Sinh( # NOQA +def test_tensorflow_Floor( # NOQA *, dtype_and_x, frontend, @@ -1670,20 +1853,17 @@ def test_tensorflow_Sinh( # NOQA ) -# RealDiv +# FloorDiv @handle_frontend_test( - fn_tree="tensorflow.raw_ops.RealDiv", + fn_tree="tensorflow.raw_ops.FloorDiv", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=8, - small_abs_safety_factor=8, - safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_RealDiv( # NOQA +def test_tensorflow_FloorDiv( # NOQA *, dtype_and_x, frontend, @@ -1700,144 +1880,99 @@ def test_tensorflow_RealDiv( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-03, - rtol=1e-03, x=xs[0], y=xs[1], ) -# Reshape -@st.composite -def _reshape_helper(draw): - # generate a shape s.t len(shape) > 0 - shape = draw(helpers.get_shape(min_num_dims=1)) - reshape_shape = draw(helpers.reshape_shapes(shape=shape)) - dtype = draw(helpers.array_dtypes(num_arrays=1)) - x = draw(helpers.array_values(dtype=dtype[0], shape=shape)) - return x, dtype, reshape_shape - - +# FloorMod @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Reshape", + fn_tree="tensorflow.raw_ops.FloorMod", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + ), test_with_out=st.just(False), - x_reshape=_reshape_helper(), ) -def test_tensorflow_Reshape( # NOQA +def test_tensorflow_FloorMod( # NOQA *, - x_reshape, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - x, dtype, shape = x_reshape + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - tensor=x, - shape=shape, - ) - - -# Reverse -@st.composite -def reverse_helper(draw): - dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=8, - ret_shape=True, - ) - ) - axis_dtype, axis = draw( - helpers.dtype_and_values( - available_dtypes=["bool"], - min_num_dims=1, - max_num_dims=1, - num_arrays=1, - shape=(len(shape),), - ) - ) - return dtype, x, axis_dtype, axis - - 
-@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Reverse", - dtype_x_axis=reverse_helper(), -) -def test_tensorflow_Reverse( - *, - dtype_x_axis, - frontend, - fn_tree, - test_flags, - on_device, - backend_fw, -): - dtype, x, axis_dtype, axis = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=dtype + axis_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - tensor=x[0], - dims=axis[0], + x=xs[0], + y=xs[1], ) -# ZerosLike +# Gather @handle_frontend_test( - fn_tree="tensorflow.raw_ops.ZerosLike", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="tensorflow.raw_ops.Gather", + params_indices_others=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], + disable_random_axis=True, + axis_zero=True, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), test_with_out=st.just(False), ) -def test_tensorflow_zeros_like( # NOQA +def test_tensorflow_Gather( # NOQA *, - dtype_and_x, + params_indices_others, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + dtypes, params, indices = params_indices_others helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=dtypes, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + params=params, + indices=indices, + validate_indices=True, ) -# LogSoftmax +# Greater @handle_frontend_test( - fn_tree="tensorflow.raw_ops.LogSoftmax", + fn_tree="tensorflow.raw_ops.Greater", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_LogSoftmax( # NOQA +def test_tensorflow_Greater( # NOQA *, dtype_and_x, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( @@ -1847,21 +1982,22 @@ def test_tensorflow_LogSoftmax( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - logits=x[0], + x=x[0], + y=x[1], ) -# LogicalOr +# GreaterEqual @handle_frontend_test( - fn_tree="tensorflow.raw_ops.LogicalOr", + fn_tree="tensorflow.raw_ops.GreaterEqual", dtype_and_x=helpers.dtype_and_values( - dtype=["bool", "bool"], + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_LogicalOr( # NOQA +def test_tensorflow_GreaterEqual( # NOQA *, dtype_and_x, frontend, @@ -1883,17 +2019,15 @@ def test_tensorflow_LogicalOr( # NOQA ) -# LogicalNot +# Identity @handle_frontend_test( - fn_tree="tensorflow.raw_ops.LogicalNot", + fn_tree="tensorflow.raw_ops.Identity", dtype_and_x=helpers.dtype_and_values( - dtype=["bool"], - num_arrays=1, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("numeric"), ), test_with_out=st.just(False), ) -def test_tensorflow_LogicalNot( # NOQA +def test_tensorflow_Identity( # NOQA *, dtype_and_x, frontend, @@ -1902,28 +2036,27 @@ def test_tensorflow_LogicalNot( # NOQA backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x[0], ) -# Shape +# 
IdentityN @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Shape", + fn_tree="tensorflow.raw_ops.IdentityN", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, ), test_with_out=st.just(False), ) -def test_tensorflow_Shape( # NOQA +def test_tensorflow_IdentityN( # NOQA *, dtype_and_x, frontend, @@ -1932,73 +2065,83 @@ def test_tensorflow_Shape( # NOQA backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + input=x, ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.ShapeN", + fn_tree="tensorflow.raw_ops.Igamma", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), max_num_dims=4 + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + abs_smallest_val=1e-5, + min_num_dims=2, + max_num_dims=2, + min_dim_size=3, + max_dim_size=3, + min_value=2, + max_value=100, + allow_nan=False, ), - output_dtype=st.sampled_from(["int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_ShapeN( # NOQA +def test_tensorflow_Igamma( *, dtype_and_x, - output_dtype, on_device, fn_tree, + backend_fw, frontend, test_flags, - backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input, - out_type=output_dtype, + rtol=1e-04, + a=xs[0], + x=xs[1], ) -# AddN +# Imag @handle_frontend_test( - fn_tree="tensorflow.raw_ops.AddN", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - large_abs_safety_factor=8, - small_abs_safety_factor=8, - safety_factor_scale="log", - min_value=-1e04, - max_value=1e04, + fn_tree="tensorflow.raw_ops.Imag", + dtype_and_xs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), + send_Tout=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_AddN( # NOQA +def test_tensorflow_Imag( *, - dtype_and_x, + dtype_and_xs, + send_Tout, frontend, test_flags, fn_tree, - backend_fw, on_device, + backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, xs = dtype_and_xs + if input_dtype[0] == "complex128": + send_Tout = "float64" + elif input_dtype[0] == "complex64": + send_Tout = "float32" if send_Tout else None + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2006,26 +2149,19 @@ def test_tensorflow_AddN( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - inputs=x[0], + input=xs[0], + Tout=send_Tout, ) -# Neg @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Neg", + fn_tree="tensorflow.raw_ops.Inv", dtype_and_x=helpers.dtype_and_values( - available_dtypes=[ - "float32", - "float64", - "int8", - "int16", - "int32", - "int64", - ], + available_dtypes=helpers.get_dtypes("numeric") ), test_with_out=st.just(False), ) -def test_tensorflow_Neg( # NOQA +def test_tensorflow_Inv( # NOQA *, dtype_and_x, frontend, @@ -2034,9 +2170,9 @@ def test_tensorflow_Neg( # NOQA backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, 
frontend=frontend, test_flags=test_flags, @@ -2046,17 +2182,17 @@ def test_tensorflow_Neg( # NOQA ) -# Equal @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Equal", + fn_tree="tensorflow.raw_ops.InvGrad", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_Equal( # NOQA +def test_tensorflow_InvGrad( # NOQA *, dtype_and_x, frontend, @@ -2065,30 +2201,27 @@ def test_tensorflow_Equal( # NOQA backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], + y=x[0], + dy=x[1], ) -# NotEqual @handle_frontend_test( - fn_tree="tensorflow.raw_ops.NotEqual", + fn_tree="tensorflow.raw_ops.Invert", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("integer"), ), test_with_out=st.just(False), ) -def test_tensorflow_NotEqual( # NOQA +def test_tensorflow_Invert( # NOQA *, dtype_and_x, frontend, @@ -2097,72 +2230,60 @@ def test_tensorflow_NotEqual( # NOQA backend_fw, on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# Cumsum @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Cumsum", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, + fn_tree="tensorflow.raw_ops.LeakyRelu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, - min_value=-5, - max_value=5, ), - exclusive=st.booleans(), - reverse=st.booleans(), test_with_out=st.just(False), + alpha=helpers.floats(min_value=0, max_value=1), ) -def test_tensorflow_Cumsum( # NOQA +def test_tensorflow_LeakyReLU( *, - dtype_x_axis, - exclusive, - reverse, + dtype_and_x, + alpha, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( + dtype, x = dtype_and_x + return helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - atol=1e-02, - x=x[0], - axis=axis, - exclusive=exclusive, - reverse=reverse, + features=x[0], + alpha=alpha, ) -# Relu @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Relu", + fn_tree="tensorflow.raw_ops.LeftShift", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_Relu( # NOQA +def test_tensorflow_LeftShift( # NOQA *, dtype_and_x, frontend, @@ -2171,7 +2292,7 @@ def test_tensorflow_Relu( # NOQA backend_fw, on_device, ): - dtype, x = dtype_and_x + dtype, xs = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -2179,43 +2300,31 @@ def test_tensorflow_Relu( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - 
features=x[0], + x=xs[0], + y=xs[1], ) -# MatMul +# Less @handle_frontend_test( - fn_tree="tensorflow.raw_ops.MatMul", + fn_tree="tensorflow.raw_ops.Less", dtype_and_x=helpers.dtype_and_values( - available_dtypes=[ - "float32", - "float64", - "int32", - "int64", - ], - shape=(3, 3), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", ), - transpose_a=st.booleans(), - transpose_b=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_MatMul( # NOQA +def test_tensorflow_Less( # NOQA *, dtype_and_x, - transpose_a, - transpose_b, frontend, test_flags, fn_tree, backend_fw, on_device, ): - input_dtype, x = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2223,104 +2332,81 @@ def test_tensorflow_MatMul( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, - a=x[0], - b=x[1], - transpose_a=transpose_a, - transpose_b=transpose_b, + x=xs[0], + y=xs[1], ) -# Cumprod +# LessEqual @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Cumprod", - dtype_x_axis=helpers.dtype_values_axis( + fn_tree="tensorflow.raw_ops.LessEqual", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, + num_arrays=2, + shared_dtype=True, ), - exclusive=st.booleans(), - reverse=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Cumprod( # NOQA +def test_tensorflow_LessEqual( # NOQA *, - dtype_x_axis, - exclusive, - reverse, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis = dtype_x_axis + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - axis=axis, - exclusive=exclusive, - reverse=reverse, + x=xs[0], + y=xs[1], ) -# Gather @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Gather", - params_indices_others=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], - disable_random_axis=True, - axis_zero=True, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - test_with_out=st.just(False), + fn_tree="tensorflow.raw_ops.LinSpace", + dtype_and_params=_LinSpace_helper(), + num=helpers.ints(min_value=2, max_value=10), ) -def test_tensorflow_Gather( # NOQA +def test_tensorflow_LinSpace( *, - params_indices_others, + dtype_and_params, + num, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - dtypes, params, indices = params_indices_others + dtype, start, stop = dtype_and_params helpers.test_frontend_function( - input_dtypes=dtypes, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + start=start, + stop=stop, + num=num, on_device=on_device, - params=params, - indices=indices, - validate_indices=True, ) -# Greater +# Log @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Greater", + fn_tree="tensorflow.raw_ops.Log", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_Greater( # NOQA +def 
test_tensorflow_Log( # NOQA *, dtype_and_x, frontend, @@ -2338,21 +2424,19 @@ def test_tensorflow_Greater( # NOQA fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# GreaterEqual +# Log1p @handle_frontend_test( - fn_tree="tensorflow.raw_ops.GreaterEqual", + fn_tree="tensorflow.raw_ops.Log1p", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_GreaterEqual( # NOQA +def test_tensorflow_Log1p( # NOQA *, dtype_and_x, frontend, @@ -2370,59 +2454,50 @@ def test_tensorflow_GreaterEqual( # NOQA fn_tree=fn_tree, on_device=on_device, x=x[0], - y=x[1], ) -# Mean +# LogSoftmax @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Mean", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - valid_axis=True, - force_int_axis=True, + fn_tree="tensorflow.raw_ops.LogSoftmax", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, - min_value=-10, - max_value=3, ), - keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Mean( # NOQA +def test_tensorflow_LogSoftmax( # NOQA *, - dtype_x_axis, - keep_dims, + dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - axis=axis, - keep_dims=keep_dims, - rtol=1e-02, - atol=1e-02, + logits=x[0], ) -# Identity +# LogicalNot @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Identity", + fn_tree="tensorflow.raw_ops.LogicalNot", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + dtype=["bool"], + num_arrays=1, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_Identity( # NOQA +def test_tensorflow_LogicalNot( # NOQA *, dtype_and_x, frontend, @@ -2431,27 +2506,29 @@ def test_tensorflow_Identity( # NOQA backend_fw, on_device, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + x=x[0], ) -# IdentityN +# LogicalOr @handle_frontend_test( - fn_tree="tensorflow.raw_ops.IdentityN", + fn_tree="tensorflow.raw_ops.LogicalOr", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + dtype=["bool", "bool"], + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_IdentityN( # NOQA +def test_tensorflow_LogicalOr( # NOQA *, dtype_and_x, frontend, @@ -2460,61 +2537,85 @@ def test_tensorflow_IdentityN( # NOQA backend_fw, on_device, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, + x=x[0], + y=x[1], ) +# MatMul @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Inv", + fn_tree="tensorflow.raw_ops.MatMul", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + available_dtypes=[ 
+ "float32", + "float64", + "int32", + "int64", + ], + shape=(3, 3), + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", ), + transpose_a=st.booleans(), + transpose_b=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Inv( # NOQA +def test_tensorflow_MatMul( # NOQA *, dtype_and_x, + transpose_a, + transpose_b, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + atol=1e-2, + a=x[0], + b=x[1], + transpose_a=transpose_a, + transpose_b=transpose_b, ) -# reciprocal @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Reciprocal", + fn_tree="tensorflow.raw_ops.MatrixDeterminant", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=1, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + min_value=-5, + max_value=5, ), test_with_out=st.just(False), ) -def test_tensorflow_Reciprocal( # NOQA +def test_tensorflow_MatrixDeterminant( # NOQA + *, dtype_and_x, frontend, test_flags, fn_tree, backend_fw, + on_device, ): dtype, x = dtype_and_x helpers.test_frontend_function( @@ -2523,63 +2624,70 @@ def test_tensorflow_Reciprocal( # NOQA frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - x=x[0], + on_device=on_device, + input=x[0], ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.OnesLike", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), + fn_tree="tensorflow.raw_ops.MatrixInverse", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=helpers.ints(min_value=2, max_value=10).map(lambda x: tuple([x, x])), + ).filter(lambda x: np.linalg.cond(x[1][0].tolist()) < 1 / sys.float_info.epsilon), + adjoint=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_OnesLike( # NOQA +def test_tensorflow_MatrixInverse( # NOQA *, - dtype_and_x, + dtype_x, + adjoint, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=x[0], + adjoint=adjoint, + rtol=1e-05, + atol=1e-04, ) +# Max @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Cholesky", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + fn_tree="tensorflow.raw_ops.Max", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), + keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Cholesky( # NOQA +def test_tensorflow_Max( # NOQA *, - dtype_and_x, + dtype_x_axis, + keep_dims, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x - x = x[0] - x = ( - np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - ) # make symmetric positive-definite - + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -2587,22 
+2695,23 @@ def test_tensorflow_Cholesky( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - rtol=1e-4, - atol=1e-4, + input=x[0], + axis=axis, + keep_dims=keep_dims, ) +# Maximum @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Mul", + fn_tree="tensorflow.raw_ops.Maximum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_Mul( # NOQA +def test_tensorflow_Maximum( # NOQA *, dtype_and_x, frontend, @@ -2624,21 +2733,21 @@ def test_tensorflow_Mul( # NOQA ) -# Min +# Mean @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Min", + fn_tree="tensorflow.raw_ops.Mean", dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), valid_axis=True, force_int_axis=True, min_num_dims=1, - min_value=-5, - max_value=5, + min_value=-10, + max_value=3, ), keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Min( # NOQA +def test_tensorflow_Mean( # NOQA *, dtype_x_axis, keep_dims, @@ -2659,12 +2768,14 @@ def test_tensorflow_Min( # NOQA input=x[0], axis=axis, keep_dims=keep_dims, + rtol=1e-02, + atol=1e-02, ) -# Max +# Min @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Max", + fn_tree="tensorflow.raw_ops.Min", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), valid_axis=True, @@ -2676,7 +2787,7 @@ def test_tensorflow_Min( # NOQA keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Max( # NOQA +def test_tensorflow_Min( # NOQA *, dtype_x_axis, keep_dims, @@ -2700,16 +2811,17 @@ def test_tensorflow_Max( # NOQA ) +# Minimum @handle_frontend_test( - fn_tree="tensorflow.raw_ops.LeftShift", + fn_tree="tensorflow.raw_ops.Minimum", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_LeftShift( # NOQA +def test_tensorflow_Minimum( # NOQA *, dtype_and_x, frontend, @@ -2718,9 +2830,9 @@ def test_tensorflow_LeftShift( # NOQA backend_fw, on_device, ): - dtype, xs = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, @@ -2732,16 +2844,83 @@ def test_tensorflow_LeftShift( # NOQA @handle_frontend_test( - fn_tree="tensorflow.raw_ops.MatrixDeterminant", + fn_tree="tensorflow.raw_ops.Mul", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - min_value=-5, - max_value=5, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), + test_with_out=st.just(False), +) +def test_tensorflow_Mul( # NOQA + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, xs = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=xs[0], + y=xs[1], + ) + + +# Neg +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.Neg", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=[ + "float32", + "float64", + "int8", + "int16", + "int32", 
+ "int64", + ], + ), + test_with_out=st.just(False), +) +def test_tensorflow_Neg( # NOQA + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + +# NotEqual +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.NotEqual", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_MatrixDeterminant( # NOQA +def test_tensorflow_NotEqual( # NOQA *, dtype_and_x, frontend, @@ -2750,15 +2929,16 @@ def test_tensorflow_MatrixDeterminant( # NOQA backend_fw, on_device, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + x=x[0], + y=x[1], ) @@ -2799,13 +2979,13 @@ def test_tensorflow_NthElement( # NOQA @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Invert", + fn_tree="tensorflow.raw_ops.OnesLike", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric") ), test_with_out=st.just(False), ) -def test_tensorflow_Invert( # NOQA +def test_tensorflow_OnesLike( # NOQA *, dtype_and_x, frontend, @@ -2827,162 +3007,85 @@ def test_tensorflow_Invert( # NOQA @handle_frontend_test( - fn_tree="tensorflow.raw_ops.InvGrad", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.raw_ops.Pack", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + force_int_axis=True, min_num_dims=1, - num_arrays=2, - shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_InvGrad( # NOQA - *, - dtype_and_x, +def test_tensorflow_Pack( # NOQA + dtype_x_axis, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - dtype, x = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - y=x[0], - dy=x[1], + values=x, + axis=axis, ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Ceil", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="tensorflow.raw_ops.Pad", + dtype_x_paddings=_pad_helper(), + number_positional_args=st.just(0), test_with_out=st.just(False), ) -def test_tensorflow_Ceil( # NOQA - *, - dtype_and_x, +def test_tensorflow_Pad( # NOQA + dtype_x_paddings, frontend, test_flags, fn_tree, backend_fw, - on_device, ): - input_dtype, x = dtype_and_x + dtype, x, padding_dtype, paddings = dtype_x_paddings helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype + padding_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x=x[0], + input=x, + paddings=paddings, ) +# TODO: Fails with torch backend +# ivy.exceptions.IvyBackendException: torch: constant_pad: constant_pad_nd(): argument +# 'value' (position 3) must be Number, not bfloat16 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Diag", - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=[ - "float32", - "float64", - "int32", - "int64", - ], - min_num_dims=1, - max_num_dims=1, - min_value=-1e30, - max_value=1e30, - ), + fn_tree="tensorflow.raw_ops.PadV2", + dtype_x_paddings=_pad_helper(return_constant_values=True), test_with_out=st.just(False), ) -def test_tensorflow_Diag( # NOQA - *, - dtype_and_x, +def test_tensorflow_PadV2( + dtype_x_paddings, frontend, test_flags, fn_tree, backend_fw, - on_device, ): - dtype, x = dtype_and_x + dtype, x, padding_dtype, paddings, constant_values = dtype_x_paddings helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=dtype + padding_dtype + dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - diagonal=x[0], - ) - - -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.RightShift", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, - min_value=0, - max_value=8, - ), - test_with_out=st.just(False), -) -def test_tensorflow_RightShift( # NOQA - *, - dtype_and_x, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, xs = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - x=xs[0], - y=xs[1], - ) - - -@st.composite -def _pow_helper_shared_dtype(draw): - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - ) + input=x, + paddings=paddings, + constant_values=constant_values, ) - dtype1, dtype2 = dtype - x1, x2 = x - if "int" in dtype2: - x2 = ivy.nested_map(x2, lambda x: abs(x), include_derived={list: True}) - - if ivy.is_int_dtype(dtype2): - max_val = ivy.iinfo(dtype2).max - else: - max_val = ivy.finfo(dtype2).max - max_x1 = np.max(np.abs(x1)) - if max_x1 in [0, 1]: - max_value = None - else: - max_value = int(math.log(max_val) / math.log(max_x1)) - if abs(max_value) > abs(max_val) / 40 or max_value < 0: - max_value = None - - return [dtype1, dtype2], [x1, x2] # Pow @@ -3014,7 +3117,7 @@ def test_tensorflow_Pow( # NOQA @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Sum", + fn_tree="tensorflow.raw_ops.Prod", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), valid_axis=True, @@ -3026,7 +3129,7 @@ def test_tensorflow_Pow( # NOQA keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Sum( # NOQA +def test_tensorflow_Prod( # NOQA *, dtype_x_axis, keep_dims, @@ -3050,81 +3153,52 @@ def test_tensorflow_Sum( # NOQA ) +# Todo: Revise strategies once reimplemented in frontend +# Real @handle_frontend_test( - fn_tree="tensorflow.raw_ops.TruncateDiv", + fn_tree="tensorflow.raw_ops.Real", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), + Tout=st.sampled_from(["float32", "float64"]), ) -def test_tensorflow_TruncateDiv( # NOQA - *, +def test_tensorflow_Real( dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, + Tout, ): - dtype, xs = dtype_and_x - # prevent too close to zero - assume(not np.any(np.isclose(xs[1], 0))) - - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - 
on_device=on_device, - x=xs[0], - y=xs[1], - ) - - -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.MatrixInverse", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=helpers.ints(min_value=2, max_value=10).map(lambda x: tuple([x, x])), - ).filter(lambda x: np.linalg.cond(x[1][0].tolist()) < 1 / sys.float_info.epsilon), - adjoint=st.booleans(), - test_with_out=st.just(False), -) -def test_tensorflow_MatrixInverse( # NOQA - *, - dtype_x, - adjoint, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, input=x[0], - adjoint=adjoint, - rtol=1e-05, - atol=1e-04, + Tout=Tout, ) +# RealDiv @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Relu6", + fn_tree="tensorflow.raw_ops.RealDiv", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=8, + small_abs_safety_factor=8, + safety_factor_scale="log", ), test_with_out=st.just(False), ) -def test_tensorflow_Relu6( # NOQA +def test_tensorflow_RealDiv( # NOQA *, dtype_and_x, frontend, @@ -3133,66 +3207,67 @@ def test_tensorflow_Relu6( # NOQA backend_fw, on_device, ): - dtype, x = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - features=x[0], + atol=1e-03, + rtol=1e-03, + x=xs[0], + y=xs[1], ) +# reciprocal @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Round", + fn_tree="tensorflow.raw_ops.Reciprocal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=1, ), test_with_out=st.just(False), ) -def test_tensorflow_Round( # NOQA - *, +def test_tensorflow_Reciprocal( # NOQA dtype_and_x, frontend, test_flags, fn_tree, backend_fw, - on_device, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, x=x[0], ) +# Relu @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Unpack", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - force_int_axis=True, + fn_tree="tensorflow.raw_ops.Relu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, ), test_with_out=st.just(False), ) -def test_tensorflow_Unpack( # NOQA +def test_tensorflow_Relu( # NOQA *, - dtype_x_axis, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis = dtype_x_axis + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -3200,22 +3275,19 @@ def test_tensorflow_Unpack( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - value=x[0], - num=x[0].shape[axis], - axis=axis, + features=x[0], ) -# Sigmoid @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Sigmoid", + fn_tree="tensorflow.raw_ops.Relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, 
), test_with_out=st.just(False), ) -def test_tensorflow_Sigmoid( # NOQA +def test_tensorflow_Relu6( # NOQA *, dtype_and_x, frontend, @@ -3232,28 +3304,25 @@ def test_tensorflow_Sigmoid( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + features=x[0], ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Softplus", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - ), + fn_tree="tensorflow.raw_ops.Reshape", test_with_out=st.just(False), + x_reshape=_reshape_helper(), ) -def test_tensorflow_Softplus( # NOQA +def test_tensorflow_Reshape( # NOQA *, - dtype_and_x, + x_reshape, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + x, dtype, shape = x_reshape helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -3261,20 +3330,49 @@ def test_tensorflow_Softplus( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - features=x[0], + tensor=x, + shape=shape, ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Xdivy", + fn_tree="tensorflow.raw_ops.Reverse", + dtype_x_axis=reverse_helper(), +) +def test_tensorflow_Reverse( + *, + dtype_x_axis, + frontend, + fn_tree, + test_flags, + on_device, + backend_fw, +): + dtype, x, axis_dtype, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=dtype + axis_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + tensor=x[0], + dims=axis[0], + ) + + +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.RightShift", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True, + min_value=0, + max_value=8, ), test_with_out=st.just(False), ) -def test_tensorflow_Xdivy( # NOQA +def test_tensorflow_RightShift( # NOQA *, dtype_and_x, frontend, @@ -3283,9 +3381,9 @@ def test_tensorflow_Xdivy( # NOQA backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, @@ -3297,15 +3395,13 @@ def test_tensorflow_Xdivy( # NOQA @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Xlog1py", + fn_tree="tensorflow.raw_ops.Round", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, ), test_with_out=st.just(False), ) -def test_tensorflow_Xlog1py( # NOQA +def test_tensorflow_Round( # NOQA *, dtype_and_x, frontend, @@ -3314,7 +3410,7 @@ def test_tensorflow_Xlog1py( # NOQA backend_fw, on_device, ): - input_dtype, xs = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3322,21 +3418,19 @@ def test_tensorflow_Xlog1py( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], ) +# Rsqrt @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Xlogy", + fn_tree="tensorflow.raw_ops.Rsqrt", dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float16", "float32", "float64"], - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), ) -def test_tensorflow_Xlogy( +def test_tensorflow_Rsqrt( *, dtype_and_x, frontend, @@ -3345,7 +3439,7 @@ def test_tensorflow_Xlogy( backend_fw, on_device, ): - input_dtype, xs = 
dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3353,123 +3447,90 @@ def test_tensorflow_Xlogy( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=xs[0], - y=xs[1], + x=x[0], ) +# Shape @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Pack", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - force_int_axis=True, + fn_tree="tensorflow.raw_ops.Shape", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, ), test_with_out=st.just(False), ) -def test_tensorflow_Pack( # NOQA - dtype_x_axis, - fn_tree, +def test_tensorflow_Shape( # NOQA + *, + dtype_and_x, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - values=x, - axis=axis, - ) - - -@st.composite -def _pad_helper(draw, return_constant_values=False): - dtype, input, shape = draw( - helpers.dtype_and_values( - min_num_dims=1, - ret_shape=True, - ) - ) - ndim = len(shape) - padding_dtype, paddings = draw( - helpers.dtype_and_values( - available_dtypes=["int32", "int64"], - shape=(ndim, 2), - min_value=0, - max_value=10, - ) + on_device=on_device, + input=x[0], ) - if return_constant_values: - _, constant_values = draw( - helpers.dtype_and_values( - dtype=dtype, - shape=(1,), - ) - ) - return dtype, input[0], padding_dtype, paddings[0], constant_values[0][0] - - return dtype, input[0], padding_dtype, paddings[0] - @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Pad", - dtype_x_paddings=_pad_helper(), - number_positional_args=st.just(0), + fn_tree="tensorflow.raw_ops.ShapeN", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), max_num_dims=4 + ), + output_dtype=st.sampled_from(["int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_Pad( # NOQA - dtype_x_paddings, +def test_tensorflow_ShapeN( # NOQA + *, + dtype_and_x, + output_dtype, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): - dtype, x, padding_dtype, paddings = dtype_x_paddings + input_dtype, input = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype + padding_dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - input=x, - paddings=paddings, + on_device=on_device, + input=input, + out_type=output_dtype, ) -# EuclideanNorm +# Sigmoid @handle_frontend_test( - fn_tree="tensorflow.raw_ops.EuclideanNorm", - dtype_values_axis=helpers.dtype_values_axis( + fn_tree="tensorflow.raw_ops.Sigmoid", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, - max_num_dims=5, - min_dim_size=1, - max_dim_size=4, - min_axis=-3, - max_axis=2, - valid_axis=True, - allow_neg_axes=True, + min_num_dims=1, ), - keep_dims=st.booleans(), test_with_out=st.just(False), - number_positional_args=st.just(0), ) -def test_tensorflow_EuclideanNorm( - dtype_values_axis, - keep_dims, +def test_tensorflow_Sigmoid( # NOQA + *, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, values, axis = dtype_values_axis + dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -3477,140 +3538,90 @@ 
def test_tensorflow_EuclideanNorm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=values[0], - axis=axis, - keep_dims=keep_dims, + x=x[0], ) -# ConcatV2 +# Sign @handle_frontend_test( - fn_tree="tensorflow.raw_ops.ConcatV2", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), + fn_tree="tensorflow.raw_ops.Sign", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + large_abs_safety_factor=5, + small_abs_safety_factor=5, + safety_factor_scale="log", + ), test_with_out=st.just(False), - number_positional_args=st.just(0), ) -def test_tensorflow_ConcatV2( - xs_n_input_dtypes_n_unique_idx, - test_flags, +def test_tensorflow_Sign( # NOQA + *, + dtype_and_x, frontend, - backend_fw, + test_flags, fn_tree, + backend_fw, + on_device, ): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtypes, + input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, - values=xs, - axis=unique_idx, + on_device=on_device, + x=x[0], ) - - -# Conv2D -@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Conv2D", - x_f_d_df=_x_and_filters( - dtypes=helpers.get_dtypes("float", full=False), - data_format=st.sampled_from(["NHWC"]), - padding=st.sampled_from(["SAME", "VALID", "EXPLICIT"]), - type="2d", - dilation_min=1, - dilation_max=1, - ), + + +# Sinh +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.Sinh", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), - number_positional_args=st.just(0), ) -def test_tensorflow_Conv2D( +def test_tensorflow_Sinh( # NOQA *, - x_f_d_df, - test_flags, + dtype_and_x, frontend, - backend_fw, + test_flags, fn_tree, + backend_fw, on_device, ): - input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df - channel_index = data_format.find("C") - stride = _convolution_broadcast_helper( - stride, num_spatial_dims=2, channel_index=channel_index, name="strides" - ) - dilation = _convolution_broadcast_helper( - dilation, num_spatial_dims=2, channel_index=channel_index, name="dilations" - ) - explicit_padding = None - if isinstance(padding, list): - explicit_padding = padding - padding = "EXPLICIT" - + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - filter=filters, - strides=stride, - padding=padding, - explicit_paddings=explicit_padding, - data_format=data_format, - dilations=dilation, - use_cudnn_on_gpu=True, + x=x[0], ) -# Conv3D @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Conv3D", - x_f_d_df=_x_and_filters( - dtypes=helpers.get_dtypes("float", full=False), - data_format=st.sampled_from(["NDHWC"]), - padding=st.sampled_from(["SAME", "VALID"]), - type="3d", - # Tensorflow backprop doesn't support dilations more than 1 on CPU - dilation_min=1, - dilation_max=1, + fn_tree="tensorflow.raw_ops.Size", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), max_num_dims=4 ), + output_dtype=st.sampled_from(["int32", "int64"]), test_with_out=st.just(False), - number_positional_args=st.just(0), ) -def test_tensorflow_Conv3D( - *, - x_f_d_df, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, +def test_tensorflow_Size( # NOQA + *, dtype_and_x, 
frontend, test_flags, backend_fw, fn_tree, on_device, output_dtype ): - input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df - - # Broadcast stirdes and dilations to correct dims for the ground truth - # backend func to run correctly - stride = _convolution_broadcast_helper( - stride, num_spatial_dims=3, channel_index=4, name="strides" - ) - dilation = _convolution_broadcast_helper( - dilation, num_spatial_dims=3, channel_index=4, name="dilations" - ) - + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x, - filter=filters, - strides=stride, - padding=padding, - data_format=data_format, - dilations=dilation, + input=x[0], + out_type=output_dtype, ) @@ -3643,61 +3654,63 @@ def test_tensorflow_Softmax( ) -# TODO: Fails with torch backend -# ivy.exceptions.IvyBackendException: torch: constant_pad: constant_pad_nd(): argument -# 'value' (position 3) must be Number, not bfloat16 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.PadV2", - dtype_x_paddings=_pad_helper(return_constant_values=True), + fn_tree="tensorflow.raw_ops.Softplus", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + ), test_with_out=st.just(False), ) -def test_tensorflow_PadV2( - dtype_x_paddings, +def test_tensorflow_Softplus( # NOQA + *, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, + on_device, ): - dtype, x, padding_dtype, paddings, constant_values = dtype_x_paddings + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype + padding_dtype + dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, - input=x, - paddings=paddings, - constant_values=constant_values, + on_device=on_device, + features=x[0], ) -# Elu +# Split @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Elu", + fn_tree="tensorflow.raw_ops.Split", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_value=-3, - max_value=3, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), ), - name=st.just(None), - test_with_out=st.just(False), - number_positional_args=st.just(0), + axis=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", + ), + num_splits=_get_splits(), ) -def test_tensorflow_Elu( +def test_tensorflow_Split( # NOQA *, dtype_and_x, - name, + axis, + num_splits, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x + dtype, value = dtype_and_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -3705,443 +3718,468 @@ def test_tensorflow_Elu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - features=x[0], - name=name, + value=value[0], + axis=axis, + num_split=num_splits, ) -@st.composite -def _LinSpace_helper(draw): - shape = () - dtype = draw(st.sampled_from(["float32", "float64"])) - - # Param: start - start = draw( - helpers.array_values( - dtype=dtype, - shape=shape, - min_value=-5.0, - max_value=5.0, +# SplitV +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.SplitV", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + 
shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + axis=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, ), + key="target_axis", + ), + size_splits=_get_splits(as_list=True), + test_with_out=st.just(False), +) +def test_tensorflow_SplitV( # NOQA + *, + dtype_and_x, + axis, + size_splits, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, value = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + value=value[0], + axis=axis, + size_splits=size_splits, + num_split=len(size_splits), ) - # Param: stop - stop = draw( - helpers.array_values( - dtype=dtype, - shape=shape, - min_value=-4.0, - max_value=10.0, - ), - ) - return [dtype] * 2, start, stop +# Sqrt +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.Sqrt", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), + test_with_out=st.just(False), +) +def test_tensorflow_Sqrt( # NOQA + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) +# Square @handle_frontend_test( - fn_tree="tensorflow.raw_ops.LinSpace", - dtype_and_params=_LinSpace_helper(), - num=helpers.ints(min_value=2, max_value=10), + fn_tree="tensorflow.raw_ops.Square", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), + test_with_out=st.just(False), ) -def test_tensorflow_LinSpace( +def test_tensorflow_Square( # NOQA *, - dtype_and_params, - num, - on_device, - fn_tree, + dtype_and_x, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - dtype, start, stop = dtype_and_params + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - start=start, - stop=stop, - num=num, on_device=on_device, + x=x[0], ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Roll", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - shift=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_tuple=True, - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_tuple=True, + fn_tree="tensorflow.raw_ops.SquaredDifference", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), + test_with_out=st.just(False), ) -def test_tensorflow_roll( +def test_tensorflow_SquaredDifference( *, - dtype_and_values, - shift, - axis, - on_device, - fn_tree, + dtype_and_x, frontend, test_flags, + fn_tree, + on_device, backend_fw, ): - input_dtype, value = dtype_and_values - if isinstance(shift, int) and isinstance(axis, tuple): - axis = axis[0] - if isinstance(shift, tuple) and isinstance(axis, tuple): - if len(shift) != len(axis): - mn = min(len(shift), len(axis)) - shift = shift[:mn] - axis = axis[:mn] + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, 
backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=value[0], - shift=shift, - axis=axis, + x=x[0], + y=x[1], ) -# Todo: Revise strategies once reimplemented in frontend -# CumulativeLogsumexp +# Squeeze @handle_frontend_test( - fn_tree="tensorflow.raw_ops.CumulativeLogsumexp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.raw_ops.Squeeze", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), ), - axis=st.just(0), + axis=_squeeze_helper(), test_with_out=st.just(False), - exclusive=st.booleans(), - reverse=st.booleans(), ) -def test_tensorflow_CumulativeLogsumexp( - dtype_and_x, +def test_tensorflow_Squeeze( # NOQA + dtype_value, + axis, frontend, test_flags, fn_tree, backend_fw, on_device, - axis, - exclusive, - reverse, ): - input_dtype, x = dtype_and_x + dtype, xs = dtype_value helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], + input=xs[0], axis=axis, - exclusive=exclusive, - reverse=reverse, ) -# Todo: Revise strategies once reimplemented in frontend -# Complex +# Sub @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Complex", + fn_tree="tensorflow.raw_ops.Sub", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True ), test_with_out=st.just(False), - Tout=st.sampled_from(["complex64", "complex128"]), ) -def test_tensorflow_Complex( +def test_tensorflow_Sub( # NOQA + *, dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, - Tout, ): - input_dtype, x = dtype_and_x + dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - real=x[0], - imag=x[1], - Tout=Tout, + x=xs[0], + y=xs[1], ) -# Todo: Revise strategies once reimplemented in frontend -# AccumulateNV2 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.AccumulateNV2", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="tensorflow.raw_ops.Sum", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + min_value=-5, + max_value=5, ), + keep_dims=st.booleans(), test_with_out=st.just(False), - shape=helpers.get_shape(min_num_dims=1), ) -def test_tensorflow_AccumulateNV2( - dtype_and_x, +def test_tensorflow_Sum( # NOQA + *, + dtype_x_axis, + keep_dims, frontend, test_flags, fn_tree, backend_fw, on_device, - shape, ): - input_dtype, x = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - inputs=x[0], - shape=shape, + input=x[0], + axis=axis, + keep_dims=keep_dims, ) -# Todo: Revise strategies once reimplemented in frontend -# DebugGradientIdentity @handle_frontend_test( - fn_tree="tensorflow.raw_ops.DebugGradientIdentity", + fn_tree="tensorflow.raw_ops.Svd", 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), ), - test_with_out=st.just(False), + full_matrices=st.booleans(), + compute_uv=st.just(True), ) -def test_tensorflow_DebugGradientIdentity( +def test_tensorflow_Svd( + *, dtype_and_x, + full_matrices, + compute_uv, frontend, test_flags, fn_tree, - backend_fw, on_device, + backend_fw, ): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, + dtype, x = dtype_and_x + x = np.asarray(x[0], dtype=dtype[0]) + # make symmetric positive definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + test_values=False, + input=x, + full_matrices=full_matrices, + compute_uv=compute_uv, + ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + u, s, vh = ret + frontend_s, frontend_u, frontend_vh = frontend_ret + + assert_all_close( + ret_np=u @ np.diag(s) @ vh, + ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T, + rtol=1e-2, + atol=1e-2, + ground_truth_backend=frontend, ) -# Todo: Revise strategies once reimplemented in frontend -# Real +# Tan @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Real", + fn_tree="tensorflow.raw_ops.Tan", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), test_with_out=st.just(False), - Tout=st.sampled_from(["float32", "float64"]), ) -def test_tensorflow_Real( +def test_tensorflow_Tan( # NOQA + *, dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, - Tout, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - Tout=Tout, + x=x[0], ) -# Todo: Revise strategies once reimplemented in frontend -# BandedTriangularSolve +# Tanh @handle_frontend_test( - fn_tree="tensorflow.raw_ops.BandedTriangularSolve", + fn_tree="tensorflow.raw_ops.Tanh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, ), test_with_out=st.just(False), - lower=st.booleans(), - adjoint=st.booleans(), ) -def test_tensorflow_BandedTriangularSolve( +def test_tensorflow_Tanh( # NOQA + *, dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, - lower, - adjoint, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - matrix=x[0], - rhs=x[1], - lower=lower, - adjoint=adjoint, + x=x[0], ) -# Todo: Revise strategies once reimplemented in frontend -# BatchMatMul +# TanhGrad @handle_frontend_test( - fn_tree="tensorflow.raw_ops.BatchMatMul", + fn_tree="tensorflow.raw_ops.TanhGrad", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True ), test_with_out=st.just(False), - adj_x=st.booleans(), - adj_y=st.booleans(), ) -def test_tensorflow_BatchMatMul( +def test_tensorflow_TanhGrad( # NOQA + 
*, dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, - adj_x, - adj_y, ): - input_dtype, x = dtype_and_x + dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], - adj_x=adj_x, - adj_y=adj_y, + y=xs[0], + dy=xs[1], ) -# Todo: Revise strategies once reimplemented in frontend -# BatchMatMulV2 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.BatchMatMulV2", + fn_tree="tensorflow.raw_ops.TruncateDiv", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True ), test_with_out=st.just(False), - adj_x=st.booleans(), - adj_y=st.booleans(), ) -def test_tensorflow_BatchMatMulV2( +def test_tensorflow_TruncateDiv( # NOQA + *, dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, - adj_x, - adj_y, ): - input_dtype, x = dtype_and_x + dtype, xs = dtype_and_x + # prevent too close to zero + assume(not np.any(np.isclose(xs[1], 0))) + helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], - adj_x=adj_x, - adj_y=adj_y, + x=xs[0], + y=xs[1], ) -# Todo: Revise strategies once reimplemented in frontend -# BatchMatMulV3 @handle_frontend_test( - fn_tree="tensorflow.raw_ops.BatchMatMulV3", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + fn_tree="tensorflow.raw_ops.Unpack", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, ), test_with_out=st.just(False), - Tout=st.sampled_from(["float32", "float64"]), - adj_x=st.booleans(), - adj_y=st.booleans(), ) -def test_tensorflow_BatchMatMulV3( - dtype_and_x, +def test_tensorflow_Unpack( # NOQA + *, + dtype_x_axis, frontend, test_flags, fn_tree, backend_fw, on_device, - Tout, - adj_x, - adj_y, ): - input_dtype, x = dtype_and_x + dtype, x, axis = dtype_x_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - x=x[0], - y=x[1], - Tout=Tout, - adj_x=adj_x, - adj_y=adj_y, + value=x[0], + num=x[0].shape[axis], + axis=axis, ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Size", + fn_tree="tensorflow.raw_ops.Xdivy", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), max_num_dims=4 + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), - output_dtype=st.sampled_from(["int32", "int64"]), test_with_out=st.just(False), ) -def test_tensorflow_Size( # NOQA - *, dtype_and_x, frontend, test_flags, backend_fw, fn_tree, on_device, output_dtype +def test_tensorflow_Xdivy( # NOQA + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, xs = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -4149,77 +4187,70 @@ def test_tensorflow_Size( # NOQA test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - 
out_type=output_dtype, + x=xs[0], + y=xs[1], ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Prod", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, - min_num_dims=1, - min_value=-5, - max_value=5, + fn_tree="tensorflow.raw_ops.Xlog1py", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), - keep_dims=st.booleans(), test_with_out=st.just(False), ) -def test_tensorflow_Prod( # NOQA +def test_tensorflow_Xlog1py( # NOQA *, - dtype_x_axis, - keep_dims, + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x, axis = dtype_x_axis + input_dtype, xs = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - axis=axis, - keep_dims=keep_dims, + x=xs[0], + y=xs[1], ) @handle_frontend_test( - fn_tree="tensorflow.raw_ops.LeakyRelu", + fn_tree="tensorflow.raw_ops.Xlogy", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, + available_dtypes=["float16", "float32", "float64"], + num_arrays=2, + shared_dtype=True, ), test_with_out=st.just(False), - alpha=helpers.floats(min_value=0, max_value=1), ) -def test_tensorflow_LeakyReLU( +def test_tensorflow_Xlogy( *, dtype_and_x, - alpha, frontend, test_flags, fn_tree, backend_fw, on_device, ): - dtype, x = dtype_and_x - return helpers.test_frontend_function( - input_dtypes=dtype, + input_dtype, xs = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - features=x[0], - alpha=alpha, + x=xs[0], + y=xs[1], ) @@ -4257,31 +4288,40 @@ def test_tensorflow_Zeta( ) -# Imag @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Imag", - dtype_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="tensorflow.raw_ops.Roll", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + shift=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, ), - send_Tout=st.booleans(), - test_with_out=st.just(False), ) -def test_tensorflow_Imag( +def test_tensorflow_roll( *, - dtype_and_xs, - send_Tout, + dtype_and_values, + shift, + axis, + on_device, + fn_tree, frontend, test_flags, - fn_tree, - on_device, backend_fw, ): - input_dtype, xs = dtype_and_xs - if input_dtype[0] == "complex128": - send_Tout = "float64" - elif input_dtype[0] == "complex64": - send_Tout = "float32" if send_Tout else None - + input_dtype, value = dtype_and_values + if isinstance(shift, int) and isinstance(axis, tuple): + axis = axis[0] + if isinstance(shift, tuple) and isinstance(axis, tuple): + if len(shift) != len(axis): + mn = min(len(shift), len(axis)) + shift = shift[:mn] + axis = axis[:mn] helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -4289,99 +4329,67 @@ def test_tensorflow_Imag( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=xs[0], - Tout=send_Tout, + input=value[0], + shift=shift, + axis=axis, ) +# Transpose 
@handle_frontend_test( - fn_tree="tensorflow.raw_ops.Svd", + fn_tree="tensorflow.raw_ops.Transpose", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), - full_matrices=st.booleans(), - compute_uv=st.just(True), + perm=_permute_dims_helper(), + test_with_out=st.just(False), ) -def test_tensorflow_Svd( +def test_tensorflow_transpose( # NOQA *, dtype_and_x, - full_matrices, - compute_uv, + perm, frontend, test_flags, fn_tree, - on_device, backend_fw, + on_device, ): dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=dtype[0]) - # make symmetric positive definite beforehand - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - ret, frontend_ret = helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - input=x, - full_matrices=full_matrices, - compute_uv=compute_uv, - ) - ret = [ivy.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] - - u, s, vh = ret - frontend_s, frontend_u, frontend_vh = frontend_ret - - assert_all_close( - ret_np=u @ np.diag(s) @ vh, - ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T, - rtol=1e-2, - atol=1e-2, - ground_truth_backend=frontend, + x=x[0], + perm=perm, ) +# ZerosLike @handle_frontend_test( - fn_tree="tensorflow.raw_ops.Igamma", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - abs_smallest_val=1e-5, - min_num_dims=2, - max_num_dims=2, - min_dim_size=3, - max_dim_size=3, - min_value=2, - max_value=100, - allow_nan=False, - ), + fn_tree="tensorflow.raw_ops.ZerosLike", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), test_with_out=st.just(False), ) -def test_tensorflow_Igamma( +def test_tensorflow_zeros_like( # NOQA *, dtype_and_x, - on_device, - fn_tree, - backend_fw, frontend, test_flags, + fn_tree, + backend_fw, + on_device, ): - input_dtype, xs = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, + input_dtypes=dtype, backend_to_test=backend_fw, + frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-04, - a=xs[0], - x=xs[1], + x=x[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py index 894e7d26989b7..3757776e02c47 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py @@ -6,42 +6,8 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -# kaiser_window -@handle_frontend_test( - fn_tree="tensorflow.signal.kaiser_window", - dtype_and_window_length=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer") - ), - dtype_and_beta=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), - dtype=helpers.get_dtypes("numeric"), - test_with_out=st.just(False), -) -def test_tensorflow_kaiser_window( - *, - dtype_and_window_length, - dtype_and_beta, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - window_length_dtype, 
window_length = dtype_and_window_length - beta_dtype, beta = dtype_and_beta - helpers.test_frontend_function( - input_dtypes=[window_length_dtype[0], beta_dtype[0]], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - window_length=window_length, - beta=beta, - dtype=dtype, - ) +# --- Helpers --- # +# --------------- # @st.composite @@ -65,36 +31,8 @@ def _valid_idct(draw): return dtype, x, type, n, axis, norm -# idct -@handle_frontend_test( - fn_tree="tensorflow.signal.idct", - dtype_x_and_args=_valid_idct(), - test_with_out=st.just(False), -) -def test_tensorflow_idct( - *, - dtype_x_and_args, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, x, type, n, axis, norm = dtype_x_and_args - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - type=type, - n=n, - axis=axis, - norm=norm, - atol=1e-01, - ) +# --- Main --- # +# ------------ # # dct @@ -149,6 +87,76 @@ def test_tensorflow_dct( ) +# idct +@handle_frontend_test( + fn_tree="tensorflow.signal.idct", + dtype_x_and_args=_valid_idct(), + test_with_out=st.just(False), +) +def test_tensorflow_idct( + *, + dtype_x_and_args, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, x, type, n, axis, norm = dtype_x_and_args + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + type=type, + n=n, + axis=axis, + norm=norm, + atol=1e-01, + ) + + +# kaiser_window +@handle_frontend_test( + fn_tree="tensorflow.signal.kaiser_window", + dtype_and_window_length=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer") + ), + dtype_and_beta=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") + ), + dtype=helpers.get_dtypes("numeric"), + test_with_out=st.just(False), +) +def test_tensorflow_kaiser_window( + *, + dtype_and_window_length, + dtype_and_beta, + dtype, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + window_length_dtype, window_length = dtype_and_window_length + beta_dtype, beta = dtype_and_beta + helpers.test_frontend_function( + input_dtypes=[window_length_dtype[0], beta_dtype[0]], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + window_length=window_length, + beta=beta, + dtype=dtype, + ) + + # vorbis_window @handle_frontend_test( fn_tree="tensorflow.signal.vorbis_window", diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py index 7f4da701acf2a..b76fcae600ad3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_tensor.py @@ -16,78 +16,56 @@ CLASS_TREE = "ivy.functional.frontends.tensorflow.EagerTensor" -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ), -) -def test_tensorflow_tensor_ivy_array( - dtype_x, - backend_fw, -): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = EagerTensor(data[0]) - ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=backend_fw) - helpers.value_test( - 
ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - ground_truth_backend="tensorflow", - ) - ivy.previous_backend() +# --- Helpers --- # +# --------------- # -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ), -) -def test_tensorflow_tensor_device( - dtype_x, - backend_fw, +@st.composite +def _array_and_shape( + draw, + *, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=10, ): - ivy.set_backend(backend_fw) - _, data = dtype_x - data = ivy.native_array(data[0]) - x = EagerTensor(data) - ivy.utils.assertions.check_equal(x.device, ivy.dev(data), as_array=False) - ivy.previous_backend() + if isinstance(min_dim_size, st._internal.SearchStrategy): + min_dim_size = draw(min_dim_size) + if isinstance(max_dim_size, st._internal.SearchStrategy): + max_dim_size = draw(max_dim_size) + available_dtypes = draw(helpers.get_dtypes("numeric")) + dtype = draw( + helpers.array_dtypes( + num_arrays=1, + available_dtypes=available_dtypes, + ) + ) + dtype.append("int32") + shape = draw( + st.shared( + helpers.get_shape( + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ), + key="shape", + ) + ) + array = draw( + helpers.array_values( + dtype=dtype[0], + shape=shape, + ) + ) + to_shape = [(None if draw(st.booleans()) else _) for _ in shape] -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ), -) -def test_tensorflow_tensor_dtype( - dtype_x, - backend_fw, -): - ivy.set_backend(backend_fw) - dtype, data = dtype_x - x = EagerTensor(data[0]) - ivy.utils.assertions.check_equal(x.dtype, ivy.Dtype(dtype[0]), as_array=False) - ivy.previous_backend() + return dtype, [array, to_shape] -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ), -) -def test_tensorflow_tensor_shape( - dtype_x, - backend_fw, -): - ivy.set_backend(backend_fw) - dtype, data, shape = dtype_x - x = EagerTensor(data[0]) - ivy.utils.assertions.check_equal( - x.ivy_array.shape, ivy.Shape(shape), as_array=False - ) - ivy.previous_backend() +# --- Main --- # +# ------------ # # __add__ @@ -129,31 +107,27 @@ def test_tensorflow__add__( ) +# __and__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__div__", + method_name="__and__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", ), ) -def test_tensorflow__div__( +def test_tensorflow__and__( dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, - on_device, backend_fw, + on_device, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -172,24 +146,55 @@ def test_tensorflow__div__( ) +# __array__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="get_shape", + method_name="__array__", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid", full=False), +) +def test_tensorflow__array__( + dtype_and_x, + dtype, + frontend, + backend_fw, +): + input_dtype, x = dtype_and_x + dtype[0] = 
np.dtype(dtype[0]) + ret_gt = tf.constant(x[0]).__array__(dtype[0]) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + local_importer = ivy_backend.utils.dynamic_import + function_module = local_importer.import_module( + "ivy.functional.frontends.tensorflow" + ) + ret = function_module.constant(x[0]).__array__(dtype[0]) + helpers.value_test( + ret_np_flat=ret.ravel(), + ret_np_from_gt_flat=ret_gt.ravel(), + ground_truth_backend="tensorflow", + backend=backend_fw, + ) + + +# __bool__ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="tensorflow.constant", + method_name="__bool__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - min_dim_size=1, + available_dtypes=helpers.get_dtypes("integer"), + max_dim_size=1, ), ) -def test_tensorflow_tensor_get_shape( +def test_tensorflow__bool__( dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, - on_device, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_method( @@ -211,13 +216,17 @@ def test_tensorflow_tensor_get_shape( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__eq__", + method_name="__div__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", ), ) -def test_tensorflow__eq__( +def test_tensorflow__div__( dtype_and_x, frontend, frontend_method_data, @@ -227,6 +236,8 @@ def test_tensorflow__eq__( backend_fw, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -235,7 +246,7 @@ def test_tensorflow__eq__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "y": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -248,17 +259,13 @@ def test_tensorflow__eq__( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__floordiv__", + method_name="__eq__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - shared_dtype=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", ), ) -def test_tensorflow__floordiv__( +def test_tensorflow__eq__( dtype_and_x, frontend, frontend_method_data, @@ -276,7 +283,7 @@ def test_tensorflow__floordiv__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "y": x[1], + "other": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -286,18 +293,20 @@ def test_tensorflow__floordiv__( ) -# __ge__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__ge__", + method_name="__floordiv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, shared_dtype=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", ), ) -def test_tensorflow__ge__( +def test_tensorflow__floordiv__( dtype_and_x, frontend, frontend_method_data, @@ -325,18 +334,18 @@ def test_tensorflow__ge__( ) -# __gt__ +# __ge__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__gt__", + 
method_name="__ge__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__gt__( +def test_tensorflow__ge__( dtype_and_x, frontend, frontend_method_data, @@ -364,37 +373,37 @@ def test_tensorflow__gt__( ) -# __le__ +# __getitem__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__le__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + method_name="__getitem__", + dtype_x_index=helpers.dtype_array_query( + available_dtypes=helpers.get_dtypes("valid"), + ).filter( + lambda x: ( + all(_check_query(i) for i in x[-1]) + if isinstance(x[-1], tuple) + else _check_query(x[-1]) + ) ), ) -def test_tensorflow__le__( - dtype_and_x, +def test_tensorflow__getitem__( + dtype_x_index, frontend, frontend_method_data, init_flags, method_flags, - on_device, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x, index = dtype_x_index helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "value": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "y": x[1], - }, + init_all_as_kwargs_np={"value": x}, + method_input_dtypes=[*input_dtype[1:]], + method_all_as_kwargs_np={"slice_spec": index}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -403,18 +412,18 @@ def test_tensorflow__le__( ) -# __lt__ +# __gt__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__lt__", + method_name="__gt__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__lt__( +def test_tensorflow__gt__( dtype_and_x, frontend, frontend_method_data, @@ -442,24 +451,23 @@ def test_tensorflow__lt__( ) +# __invert__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__mul__", + method_name="__invert__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("integer") ), ) -def test_tensorflow__mul__( +def test_tensorflow__invert__( dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, - on_device, backend_fw, + on_device, ): input_dtype, x = dtype_and_x helpers.test_frontend_method( @@ -469,9 +477,7 @@ def test_tensorflow__mul__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "y": x[1], - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -480,18 +486,18 @@ def test_tensorflow__mul__( ) -# __mod__ +# __le__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__mod__", + method_name="__le__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__mod__( +def test_tensorflow__le__( dtype_and_x, frontend, frontend_method_data, @@ -501,8 +507,6 @@ def test_tensorflow__mod__( backend_fw, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -521,18 +525,17 @@ def 
test_tensorflow__mod__( ) -# __sub__ +# __len__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__sub__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + method_name="__len__", + dtype_and_x=_array_and_shape( + min_num_dims=1, + max_num_dims=5, ), ) -def test_tensorflow__sub__( +def test_tensorflow__len__( dtype_and_x, frontend, frontend_method_data, @@ -549,9 +552,7 @@ def test_tensorflow__sub__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "y": x[1], - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -560,25 +561,25 @@ def test_tensorflow__sub__( ) -# __ne__ +# __lt__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__ne__", + method_name="__lt__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__ne__( +def test_tensorflow__lt__( dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, - backend_fw, on_device, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_method( @@ -589,7 +590,7 @@ def test_tensorflow__ne__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "y": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -599,18 +600,27 @@ def test_tensorflow__ne__( ) -# __radd__ +# __matmul__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__radd__", + method_name="__matmul__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=[ + "float32", + "float64", + "int32", + "int64", + ], + shape=(3, 3), num_arrays=2, shared_dtype=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", ), ) -def test_tensorflow__radd__( +def test_tensorflow__matmul__( dtype_and_x, frontend, frontend_method_data, @@ -628,7 +638,7 @@ def test_tensorflow__radd__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "x": x[1], + "y": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -638,30 +648,29 @@ def test_tensorflow__radd__( ) -# __rfloordiv__ +# __mod__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__rfloordiv__", + method_name="__mod__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", ), ) -def test_tensorflow__rfloordiv__( +def test_tensorflow__mod__( dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, - backend_fw, on_device, + backend_fw, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -670,7 +679,7 @@ def test_tensorflow__rfloordiv__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "x": x[1], + "y": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -680,25 +689,24 @@ def test_tensorflow__rfloordiv__( ) -# __rsub__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - 
method_name="__rsub__", + method_name="__mul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__rsub__( +def test_tensorflow__mul__( dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, - backend_fw, on_device, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_method( @@ -709,7 +717,7 @@ def test_tensorflow__rsub__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "x": x[1], + "y": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -719,18 +727,18 @@ def test_tensorflow__rsub__( ) -# __and__ +# __ne__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__and__", + method_name="__ne__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__and__( +def test_tensorflow__ne__( dtype_and_x, frontend, frontend_method_data, @@ -748,7 +756,7 @@ def test_tensorflow__and__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "y": x[1], + "other": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -758,18 +766,23 @@ def test_tensorflow__and__( ) -# __rand__ +# __neg__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__rand__", + method_name="__neg__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, + available_dtypes=[ + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + ], ), ) -def test_tensorflow__rand__( +def test_tensorflow__neg__( dtype_and_x, frontend, frontend_method_data, @@ -786,9 +799,7 @@ def test_tensorflow__rand__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "x": x[1], - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -797,18 +808,17 @@ def test_tensorflow__rand__( ) -# __or__ +# __nonzero__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__or__", + method_name="__nonzero__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, + max_dim_size=1, ), ) -def test_tensorflow__or__( +def test_tensorflow__nonzero__( dtype_and_x, frontend, frontend_method_data, @@ -825,9 +835,7 @@ def test_tensorflow__or__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "y": x[1], - }, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -836,18 +844,18 @@ def test_tensorflow__or__( ) -# __ror__ +# __or__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__ror__", + method_name="__or__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__ror__( +def test_tensorflow__or__( dtype_and_x, frontend, frontend_method_data, @@ -865,7 +873,7 @@ def test_tensorflow__ror__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "x": x[1], + "y": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -875,21 +883,24 @@ def test_tensorflow__ror__( ) -# __truediv__ +# __pow__ @handle_frontend_method( 
class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__truediv__", + method_name="__pow__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=[ + "float16", + "float32", + "float64", + "int32", + "int64", + ], num_arrays=2, shared_dtype=True, - large_abs_safety_factor=15, - small_abs_safety_factor=15, - safety_factor_scale="log", ), ) -def test_tensorflow__truediv__( +def test_tensorflow__pow__( dtype_and_x, frontend, frontend_method_data, @@ -899,7 +910,13 @@ def test_tensorflow__truediv__( on_device, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) + if x[1].dtype == "int32" or x[1].dtype == "int64": + if x[1].ndim == 0: + if x[1] < 0: + x[1] *= -1 + else: + x[1][(x[1] < 0).nonzero()] *= -1 + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -918,21 +935,18 @@ def test_tensorflow__truediv__( ) -# __rtruediv__ +# __radd__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__rtruediv__", + method_name="__radd__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=15, - small_abs_safety_factor=15, - safety_factor_scale="log", ), ) -def test_tensorflow__rtruediv__( +def test_tensorflow__radd__( dtype_and_x, frontend, frontend_method_data, @@ -960,17 +974,18 @@ def test_tensorflow__rtruediv__( ) -# __bool__ +# __rand__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__bool__", + method_name="__rand__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), - max_dim_size=1, + num_arrays=2, + shared_dtype=True, ), ) -def test_tensorflow__bool__( +def test_tensorflow__rand__( dtype_and_x, frontend, frontend_method_data, @@ -987,7 +1002,9 @@ def test_tensorflow__bool__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "x": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -996,17 +1013,21 @@ def test_tensorflow__bool__( ) -# __nonzero__ +# __rfloordiv__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__nonzero__", + method_name="__rfloordiv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - max_dim_size=1, + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", ), ) -def test_tensorflow__nonzero__( +def test_tensorflow__rfloordiv__( dtype_and_x, frontend, frontend_method_data, @@ -1023,7 +1044,9 @@ def test_tensorflow__nonzero__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "x": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1032,23 +1055,27 @@ def test_tensorflow__nonzero__( ) -# __neg__ +# __rmatmul__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__neg__", + method_name="__rmatmul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=[ "float32", "float64", - "int8", - "int16", "int32", "int64", ], + shape=(3, 3), + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", ), ) -def 
test_tensorflow__neg__( +def test_tensorflow__rmatmul__( dtype_and_x, frontend, frontend_method_data, @@ -1065,7 +1092,9 @@ def test_tensorflow__neg__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "x": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1074,18 +1103,20 @@ def test_tensorflow__neg__( ) -# __rxor__ +# __rmul__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__rxor__", + method_name="__rmul__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, + min_value=-100, + max_value=100, ), ) -def test_tensorflow__rxor__( +def test_tensorflow__rmul__( dtype_and_x, frontend, frontend_method_data, @@ -1113,18 +1144,18 @@ def test_tensorflow__rxor__( ) -# __xor__ +# __ror__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__xor__", + method_name="__ror__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True, ), ) -def test_tensorflow__xor__( +def test_tensorflow__ror__( dtype_and_x, frontend, frontend_method_data, @@ -1142,7 +1173,7 @@ def test_tensorflow__xor__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "y": x[1], + "x": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1152,27 +1183,14 @@ def test_tensorflow__xor__( ) -# __matmul__ +# __rpow__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__matmul__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=[ - "float32", - "float64", - "int32", - "int64", - ], - shape=(3, 3), - num_arrays=2, - shared_dtype=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", - ), + method_name="__rpow__", + dtype_and_x=_pow_helper_shared_dtype(), ) -def test_tensorflow__matmul__( +def test_tensorflow__rpow__( dtype_and_x, frontend, frontend_method_data, @@ -1190,7 +1208,7 @@ def test_tensorflow__matmul__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "y": x[1], + "x": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1200,27 +1218,18 @@ def test_tensorflow__matmul__( ) -# __rmatmul__ +# __rsub__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__rmatmul__", + method_name="__rsub__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=[ - "float32", - "float64", - "int32", - "int64", - ], - shape=(3, 3), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", ), ) -def test_tensorflow__rmatmul__( +def test_tensorflow__rsub__( dtype_and_x, frontend, frontend_method_data, @@ -1248,47 +1257,21 @@ def test_tensorflow__rmatmul__( ) -# __array__ -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="tensorflow.constant", - method_name="__array__", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - dtype=helpers.get_dtypes("valid", full=False), -) -def test_tensorflow__array__( - dtype_and_x, - dtype, - frontend, - backend_fw, -): - input_dtype, x = dtype_and_x - dtype[0] = np.dtype(dtype[0]) - ret_gt = tf.constant(x[0]).__array__(dtype[0]) - with 
BackendHandler.update_backend(backend_fw) as ivy_backend: - local_importer = ivy_backend.utils.dynamic_import - function_module = local_importer.import_module( - "ivy.functional.frontends.tensorflow" - ) - ret = function_module.constant(x[0]).__array__(dtype[0]) - helpers.value_test( - ret_np_flat=ret.ravel(), - ret_np_from_gt_flat=ret_gt.ravel(), - ground_truth_backend="tensorflow", - backend=backend_fw, - ) - - -# __invert__ +# __rtruediv__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__invert__", + method_name="__rtruediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer") + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + large_abs_safety_factor=15, + small_abs_safety_factor=15, + safety_factor_scale="log", ), ) -def test_tensorflow__invert__( +def test_tensorflow__rtruediv__( dtype_and_x, frontend, frontend_method_data, @@ -1305,7 +1288,9 @@ def test_tensorflow__invert__( "value": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "x": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1314,20 +1299,18 @@ def test_tensorflow__invert__( ) -# __rmul__ +# __rxor__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__rmul__", + method_name="__rxor__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, shared_dtype=True, - min_value=-100, - max_value=100, ), ) -def test_tensorflow__rmul__( +def test_tensorflow__rxor__( dtype_and_x, frontend, frontend_method_data, @@ -1355,14 +1338,18 @@ def test_tensorflow__rmul__( ) -# __rpow__ +# __sub__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__rpow__", - dtype_and_x=_pow_helper_shared_dtype(), + method_name="__sub__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), ) -def test_tensorflow__rpow__( +def test_tensorflow__sub__( dtype_and_x, frontend, frontend_method_data, @@ -1380,7 +1367,7 @@ def test_tensorflow__rpow__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "x": x[1], + "y": x[1], }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -1390,24 +1377,21 @@ def test_tensorflow__rpow__( ) -# __pow__ +# __truediv__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__pow__", + method_name="__truediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=[ - "float16", - "float32", - "float64", - "int32", - "int64", - ], + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, + large_abs_safety_factor=15, + small_abs_safety_factor=15, + safety_factor_scale="log", ), ) -def test_tensorflow__pow__( +def test_tensorflow__truediv__( dtype_and_x, frontend, frontend_method_data, @@ -1417,13 +1401,7 @@ def test_tensorflow__pow__( on_device, ): input_dtype, x = dtype_and_x - if x[1].dtype == "int32" or x[1].dtype == "int64": - if x[1].ndim == 0: - if x[1] < 0: - x[1] *= -1 - else: - x[1][(x[1] < 0).nonzero()] *= -1 - + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1442,23 +1420,19 @@ def test_tensorflow__pow__( ) -# __getitem__ +# __xor__ 
@handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__getitem__", - dtype_x_index=helpers.dtype_array_query( - available_dtypes=helpers.get_dtypes("valid"), - ).filter( - lambda x: ( - all(_check_query(i) for i in x[-1]) - if isinstance(x[-1], tuple) - else _check_query(x[-1]) - ) + method_name="__xor__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=True, ), ) -def test_tensorflow__getitem__( - dtype_x_index, +def test_tensorflow__xor__( + dtype_and_x, frontend, frontend_method_data, init_flags, @@ -1466,13 +1440,17 @@ def test_tensorflow__getitem__( backend_fw, on_device, ): - input_dtype, x, index = dtype_x_index + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"value": x}, - method_input_dtypes=[*input_dtype[1:]], - method_all_as_kwargs_np={"slice_spec": index}, + init_all_as_kwargs_np={ + "value": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "y": x[1], + }, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1481,75 +1459,67 @@ def test_tensorflow__getitem__( ) -@st.composite -def _array_and_shape( - draw, - *, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=10, +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ), +) +def test_tensorflow_tensor_device( + dtype_x, + backend_fw, ): - if isinstance(min_dim_size, st._internal.SearchStrategy): - min_dim_size = draw(min_dim_size) - if isinstance(max_dim_size, st._internal.SearchStrategy): - max_dim_size = draw(max_dim_size) + ivy.set_backend(backend_fw) + _, data = dtype_x + data = ivy.native_array(data[0]) + x = EagerTensor(data) + ivy.utils.assertions.check_equal(x.device, ivy.dev(data), as_array=False) + ivy.previous_backend() - available_dtypes = draw(helpers.get_dtypes("numeric")) - dtype = draw( - helpers.array_dtypes( - num_arrays=1, - available_dtypes=available_dtypes, - ) - ) - dtype.append("int32") - shape = draw( - st.shared( - helpers.get_shape( - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ), - key="shape", - ) - ) - array = draw( - helpers.array_values( - dtype=dtype[0], - shape=shape, - ) - ) - to_shape = [(None if draw(st.booleans()) else _) for _ in shape] - return dtype, [array, to_shape] +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ), +) +def test_tensorflow_tensor_dtype( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + dtype, data = dtype_x + x = EagerTensor(data[0]) + ivy.utils.assertions.check_equal(x.dtype, ivy.Dtype(dtype[0]), as_array=False) + ivy.previous_backend() @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="set_shape", - dtype_and_x=_array_and_shape( - min_num_dims=0, - max_num_dims=5, + method_name="get_shape", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + min_dim_size=1, ), ) -def test_tensorflow_tensor_set_shape( +def test_tensorflow_tensor_get_shape( dtype_and_x, frontend, frontend_method_data, init_flags, method_flags, - backend_fw, on_device, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - 
init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"value": x[0]}, + init_all_as_kwargs_np={ + "value": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"shape": x[1]}, + method_all_as_kwargs_np={}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1558,17 +1528,38 @@ def test_tensorflow_tensor_set_shape( ) -# __len__ +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ), +) +def test_tensorflow_tensor_ivy_array( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = EagerTensor(data[0]) + ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=backend_fw) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + ground_truth_backend="tensorflow", + ) + ivy.previous_backend() + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="tensorflow.constant", - method_name="__len__", + method_name="set_shape", dtype_and_x=_array_and_shape( - min_num_dims=1, + min_num_dims=0, max_num_dims=5, ), ) -def test_tensorflow__len__( +def test_tensorflow_tensor_set_shape( dtype_and_x, frontend, frontend_method_data, @@ -1579,16 +1570,33 @@ def test_tensorflow__len__( ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "value": x[0], - }, + init_all_as_kwargs_np={"value": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"shape": x[1]}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, on_device=on_device, ) + + +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ), +) +def test_tensorflow_tensor_shape( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + dtype, data, shape = dtype_x + x = EagerTensor(data[0]) + ivy.utils.assertions.check_equal( + x.ivy_array.shape, ivy.Shape(shape), as_array=False + ) + ivy.previous_backend() diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py index 701c2b85c36e5..217a2b2edaf8a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py @@ -12,81 +12,164 @@ ) -# helpers +# --- Helpers --- # +# --------------- # + + @st.composite -def _get_dtype_and_square_matrix(draw): - dim_size = draw(helpers.ints(min_value=2, max_value=5)) +def _generate_chain_matmul_dtype_and_arrays(draw): dtype = draw(helpers.get_dtypes("float", full=True)) - dtype = [ + input_dtype = [ draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) ] - mat = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size, dim_size), min_value=0, max_value=10 + matrices_dims = draw( + st.lists(st.integers(min_value=2, max_value=10), min_size=4, max_size=4) + ) + shape_1 = (matrices_dims[0], matrices_dims[1]) + shape_2 = (matrices_dims[1], matrices_dims[2]) + shape_3 = (matrices_dims[2], matrices_dims[3]) + + matrix_1 = draw( + helpers.dtype_and_values( + shape=shape_1, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) + ) 
+ matrix_2 = draw( + helpers.dtype_and_values( + shape=shape_2, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) + ) + matrix_3 = draw( + helpers.dtype_and_values( + shape=shape_3, + dtype=input_dtype, + min_value=-10, + max_value=10, ) ) - return dtype, mat + + return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]] @st.composite -def _get_dtype_input_and_vectors(draw, with_input=False, same_size=False): +def _get_dtype_and_3dbatch_matrices(draw, with_input=False, input_3d=False): dim_size1 = draw(helpers.ints(min_value=2, max_value=5)) - dim_size2 = dim_size1 if same_size else draw(helpers.ints(min_value=2, max_value=5)) + dim_size2 = draw(helpers.ints(min_value=2, max_value=5)) + shared_size = draw(helpers.ints(min_value=2, max_value=5)) dtype = draw(helpers.get_dtypes("float", full=True)) dtype = [ draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) ] - vec1 = draw( + batch_size = draw(helpers.ints(min_value=2, max_value=4)) + mat1 = draw( helpers.array_values( - dtype=dtype[0], shape=(dim_size1,), min_value=2, max_value=5 + dtype=dtype[0], + shape=(batch_size, dim_size1, shared_size), + min_value=2, + max_value=5, ) ) - vec2 = draw( + mat2 = draw( helpers.array_values( - dtype=dtype[0], shape=(dim_size2,), min_value=2, max_value=5 + dtype=dtype[0], + shape=(batch_size, shared_size, dim_size2), + min_value=2, + max_value=5, ) ) if with_input: + if input_3d: + input = draw( + helpers.array_values( + dtype=dtype[0], + shape=(batch_size, dim_size1, dim_size2), + min_value=2, + max_value=5, + ) + ) + return dtype, input, mat1, mat2 input = draw( helpers.array_values( dtype=dtype[0], shape=(dim_size1, dim_size2), min_value=2, max_value=5 ) ) - return dtype, input, vec1, vec2 - return dtype, vec1, vec2 + return dtype, input, mat1, mat2 + return dtype, mat1, mat2 @st.composite -def _get_dtype_input_and_matrices(draw, with_input=False): - dim_size1 = draw(helpers.ints(min_value=2, max_value=5)) - dim_size2 = draw(helpers.ints(min_value=2, max_value=5)) +def _get_dtype_and_matrices(draw): + dim1 = draw(helpers.ints(min_value=2, max_value=7)) + dim2 = draw(helpers.ints(min_value=2, max_value=7)) + dtype = draw(helpers.get_dtypes("float", full=False)) + + matr1 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim1, dim2), min_value=2, max_value=10 + ) + ) + matr2 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim1, dim2), min_value=2, max_value=10 + ) + ) + + return dtype, matr1, matr2 + + +# helpers +@st.composite +def _get_dtype_and_square_matrix(draw): + dim_size = draw(helpers.ints(min_value=2, max_value=5)) + dtype = draw(helpers.get_dtypes("float", full=True)) + dtype = [ + draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) + ] + mat = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size, dim_size), min_value=0, max_value=10 + ) + ) + return dtype, mat + + +@st.composite +def _get_dtype_input_and_mat_vec(draw, *, with_input=False): + dim_size = draw(helpers.ints(min_value=2, max_value=5)) shared_size = draw(helpers.ints(min_value=2, max_value=5)) dtype = draw(helpers.get_dtypes("float", full=True)) dtype = [ draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) ] - mat1 = draw( + + mat = draw( helpers.array_values( - dtype=dtype[0], shape=(dim_size1, shared_size), min_value=2, max_value=5 + dtype=dtype[0], shape=(dim_size, shared_size), min_value=2, max_value=5 ) ) - mat2 = draw( + vec = draw( helpers.array_values( - dtype=dtype[0], shape=(shared_size, dim_size2), 
min_value=2, max_value=5 + dtype=dtype[0], shape=(shared_size,), min_value=2, max_value=5 ) ) if with_input: input = draw( helpers.array_values( - dtype=dtype[0], shape=(dim_size1, dim_size2), min_value=2, max_value=5 + dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 ) ) - return dtype, input, mat1, mat2 - return dtype, mat1, mat2 + return dtype, input, mat, vec + return dtype, mat, vec @st.composite -def _get_dtype_and_3dbatch_matrices(draw, with_input=False, input_3d=False): +def _get_dtype_input_and_matrices(draw, with_input=False): dim_size1 = draw(helpers.ints(min_value=2, max_value=5)) dim_size2 = draw(helpers.ints(min_value=2, max_value=5)) shared_size = draw(helpers.ints(min_value=2, max_value=5)) @@ -94,34 +177,17 @@ def _get_dtype_and_3dbatch_matrices(draw, with_input=False, input_3d=False): dtype = [ draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) ] - batch_size = draw(helpers.ints(min_value=2, max_value=4)) mat1 = draw( helpers.array_values( - dtype=dtype[0], - shape=(batch_size, dim_size1, shared_size), - min_value=2, - max_value=5, + dtype=dtype[0], shape=(dim_size1, shared_size), min_value=2, max_value=5 ) ) mat2 = draw( helpers.array_values( - dtype=dtype[0], - shape=(batch_size, shared_size, dim_size2), - min_value=2, - max_value=5, + dtype=dtype[0], shape=(shared_size, dim_size2), min_value=2, max_value=5 ) ) if with_input: - if input_3d: - input = draw( - helpers.array_values( - dtype=dtype[0], - shape=(batch_size, dim_size1, dim_size2), - min_value=2, - max_value=5, - ) - ) - return dtype, input, mat1, mat2 input = draw( helpers.array_values( dtype=dtype[0], shape=(dim_size1, dim_size2), min_value=2, max_value=5 @@ -132,32 +198,35 @@ def _get_dtype_and_3dbatch_matrices(draw, with_input=False, input_3d=False): @st.composite -def _get_dtype_input_and_mat_vec(draw, *, with_input=False): - dim_size = draw(helpers.ints(min_value=2, max_value=5)) - shared_size = draw(helpers.ints(min_value=2, max_value=5)) +def _get_dtype_input_and_vectors(draw, with_input=False, same_size=False): + dim_size1 = draw(helpers.ints(min_value=2, max_value=5)) + dim_size2 = dim_size1 if same_size else draw(helpers.ints(min_value=2, max_value=5)) dtype = draw(helpers.get_dtypes("float", full=True)) dtype = [ draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) ] - - mat = draw( + vec1 = draw( helpers.array_values( - dtype=dtype[0], shape=(dim_size, shared_size), min_value=2, max_value=5 + dtype=dtype[0], shape=(dim_size1,), min_value=2, max_value=5 ) ) - vec = draw( + vec2 = draw( helpers.array_values( - dtype=dtype[0], shape=(shared_size,), min_value=2, max_value=5 + dtype=dtype[0], shape=(dim_size2,), min_value=2, max_value=5 ) ) if with_input: input = draw( helpers.array_values( - dtype=dtype[0], shape=(dim_size,), min_value=2, max_value=5 + dtype=dtype[0], shape=(dim_size1, dim_size2), min_value=2, max_value=5 ) ) - return dtype, input, mat, vec - return dtype, mat, vec + return dtype, input, vec1, vec2 + return dtype, vec1, vec2 + + +# --- Main --- # +# ------------ # # addbmm @@ -392,54 +461,12 @@ def test_torch_baddbmm( ) -@st.composite -def _generate_chain_matmul_dtype_and_arrays(draw): - dtype = draw(helpers.get_dtypes("float", full=True)) - input_dtype = [ - draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) - ] - matrices_dims = draw( - st.lists(st.integers(min_value=2, max_value=10), min_size=4, max_size=4) - ) - shape_1 = (matrices_dims[0], matrices_dims[1]) - shape_2 = (matrices_dims[1], 
matrices_dims[2]) - shape_3 = (matrices_dims[2], matrices_dims[3]) - - matrix_1 = draw( - helpers.dtype_and_values( - shape=shape_1, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - matrix_2 = draw( - helpers.dtype_and_values( - shape=shape_2, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - matrix_3 = draw( - helpers.dtype_and_values( - shape=shape_3, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - - return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]] - - -# chain_matmul +# bmm @handle_frontend_test( - fn_tree="torch.chain_matmul", - dtype_and_matrices=_generate_chain_matmul_dtype_and_arrays(), + fn_tree="torch.bmm", + dtype_and_matrices=_get_dtype_and_3dbatch_matrices(), ) -def test_torch_chain_matmul( - *, +def test_torch_bmm( dtype_and_matrices, on_device, fn_tree, @@ -447,9 +474,7 @@ def test_torch_chain_matmul( test_flags, backend_fw, ): - dtype, matrices = dtype_and_matrices - args = {f"x{i}": matrix for i, matrix in enumerate(matrices)} - test_flags.num_positional_args = len(matrices) + dtype, mat1, mat2 = dtype_and_matrices helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -457,17 +482,19 @@ def test_torch_chain_matmul( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - **args, + rtol=1e-02, + input=mat1, + mat2=mat2, ) -# bmm +# chain_matmul @handle_frontend_test( - fn_tree="torch.bmm", - dtype_and_matrices=_get_dtype_and_3dbatch_matrices(), + fn_tree="torch.chain_matmul", + dtype_and_matrices=_generate_chain_matmul_dtype_and_arrays(), ) -def test_torch_bmm( +def test_torch_chain_matmul( + *, dtype_and_matrices, on_device, fn_tree, @@ -475,7 +502,9 @@ def test_torch_bmm( test_flags, backend_fw, ): - dtype, mat1, mat2 = dtype_and_matrices + dtype, matrices = dtype_and_matrices + args = {f"x{i}": matrix for i, matrix in enumerate(matrices)} + test_flags.num_positional_args = len(matrices) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -483,9 +512,8 @@ def test_torch_bmm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - input=mat1, - mat2=mat2, + rtol=1e-03, + **args, ) @@ -530,6 +558,33 @@ def test_torch_cholesky( ) +# dot +@handle_frontend_test( + fn_tree="torch.dot", + dtype_and_vecs=_get_dtype_input_and_vectors(same_size=True), +) +def test_torch_dot( + dtype_and_vecs, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, vec1, vec2 = dtype_and_vecs + test_flags.num_positional_args = len(dtype_and_vecs) - 1 + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=vec1, + other=vec2, + ) + + # ger @handle_frontend_test( fn_tree="torch.ger", @@ -858,20 +913,32 @@ def test_torch_svd( ) -# vdot @handle_frontend_test( - fn_tree="torch.vdot", - dtype_and_vecs=_get_dtype_input_and_vectors(same_size=True), + fn_tree="torch.trapezoid", + test_with_out=st.just(False), + dtype_y_x=_get_dtype_and_matrices(), + use_x=st.booleans(), + dim=st.integers(min_value=0, max_value=1), + dx=st.floats(), ) -def test_torch_vdot( - dtype_and_vecs, +def test_torch_trapezoid( + dtype_y_x, + use_x, + dim, + dx, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, vec1, vec2 = dtype_and_vecs + dtype, y, x = dtype_y_x + if use_x: + test_flags.num_positional_args = 2 + kwargs = {"y": y, "x": x, "dim": -1} + else: + test_flags.num_positional_args = 1 + kwargs = {"y": y, 
"dx": dx, "dim": dim} helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -879,17 +946,16 @@ def test_torch_vdot( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=vec1, - other=vec2, + **kwargs, ) -# dot +# vdot @handle_frontend_test( - fn_tree="torch.dot", + fn_tree="torch.vdot", dtype_and_vecs=_get_dtype_input_and_vectors(same_size=True), ) -def test_torch_dot( +def test_torch_vdot( dtype_and_vecs, on_device, fn_tree, @@ -898,7 +964,6 @@ def test_torch_dot( backend_fw, ): dtype, vec1, vec2 = dtype_and_vecs - test_flags.num_positional_args = len(dtype_and_vecs) - 1 helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -909,60 +974,3 @@ def test_torch_dot( input=vec1, other=vec2, ) - - -@st.composite -def _get_dtype_and_matrices(draw): - dim1 = draw(helpers.ints(min_value=2, max_value=7)) - dim2 = draw(helpers.ints(min_value=2, max_value=7)) - dtype = draw(helpers.get_dtypes("float", full=False)) - - matr1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim1, dim2), min_value=2, max_value=10 - ) - ) - matr2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim1, dim2), min_value=2, max_value=10 - ) - ) - - return dtype, matr1, matr2 - - -@handle_frontend_test( - fn_tree="torch.trapezoid", - test_with_out=st.just(False), - dtype_y_x=_get_dtype_and_matrices(), - use_x=st.booleans(), - dim=st.integers(min_value=0, max_value=1), - dx=st.floats(), -) -def test_torch_trapezoid( - dtype_y_x, - use_x, - dim, - dx, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, y, x = dtype_y_x - if use_x: - test_flags.num_positional_args = 2 - kwargs = {"y": y, "x": x, "dim": -1} - else: - test_flags.num_positional_args = 1 - kwargs = {"y": y, "dx": dx, "dim": dim} - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **kwargs, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py index 77dd0318fa03e..751c569eb42a1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_comparison_ops.py @@ -7,6 +7,28 @@ from ivy_tests.test_ivy.helpers.testing_helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _topk_helper(draw): + dtype, x, axis = draw( + helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + force_int_axis=True, + valid_axis=True, + ) + ) + k = draw(st.integers(min_value=1, max_value=x[0].shape[axis])) + return dtype, x, axis, k + + +# --- Main --- # +# ------------ # + + # allclose @handle_frontend_test( fn_tree="torch.allclose", @@ -43,17 +65,55 @@ def test_torch_allclose( ) -# equal +# argsort @handle_frontend_test( - fn_tree="torch.equal", + fn_tree="torch.argsort", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + min_axis=-1, + max_axis=0, + ), + descending=st.booleans(), +) +def test_torch_argsort( + *, + dtype_input_axis, + descending, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + 
test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + dim=axis, + descending=descending, + ) + + +# eq +@handle_frontend_test( + fn_tree="torch.eq", dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=False), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, allow_inf=False, shared_dtype=True, ), ) -def test_torch_equal( +def test_torch_eq( *, dtype_and_inputs, on_device, @@ -75,17 +135,17 @@ def test_torch_equal( ) -# eq +# equal @handle_frontend_test( - fn_tree="torch.eq", + fn_tree="torch.equal", dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid", full=False), num_arrays=2, allow_inf=False, shared_dtype=True, ), ) -def test_torch_eq( +def test_torch_equal( *, dtype_and_inputs, on_device, @@ -107,31 +167,27 @@ def test_torch_eq( ) -# argsort +# fmax @handle_frontend_test( - fn_tree="torch.argsort", - dtype_input_axis=helpers.dtype_values_axis( + fn_tree="torch.fmax", + dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - min_axis=-1, - max_axis=0, + num_arrays=2, + shared_dtype=True, + min_value=-np.inf, + max_value=np.inf, ), - descending=st.booleans(), ) -def test_torch_argsort( +def test_torch_fmax( *, - dtype_input_axis, - descending, + dtype_and_inputs, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input, axis = dtype_input_axis + input_dtype, inputs = dtype_and_inputs helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -139,24 +195,23 @@ def test_torch_argsort( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - dim=axis, - descending=descending, + input=inputs[0], + other=inputs[1], ) -# greater_equal +# fmin @handle_frontend_test( - fn_tree="torch.ge", - aliases=["torch.greater_equal"], + fn_tree="torch.fmin", dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - allow_inf=False, shared_dtype=True, + min_value=-np.inf, + max_value=np.inf, ), ) -def test_torch_greater_equal( +def test_torch_fmin( *, dtype_and_inputs, on_device, @@ -211,6 +266,39 @@ def test_torch_greater( ) +# greater_equal +@handle_frontend_test( + fn_tree="torch.ge", + aliases=["torch.greater_equal"], + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + allow_inf=False, + shared_dtype=True, + ), +) +def test_torch_greater_equal( + *, + dtype_and_inputs, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, inputs = dtype_and_inputs + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=inputs[0], + other=inputs[1], + ) + + # isclose @handle_frontend_test( fn_tree="torch.isclose", @@ -277,25 +365,28 @@ def test_torch_isfinite( ) -# isinf @handle_frontend_test( - fn_tree="torch.isinf", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.isin", + dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_value=-np.inf, - max_value=np.inf, + num_arrays=2, + shared_dtype=True, ), + assume_unique=st.booleans(), + invert=st.booleans(), ) -def test_torch_isinf( +def test_torch_isin( *, - dtype_and_input, + dtype_and_inputs, + 
assume_unique, + invert, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input + input_dtype, inputs = dtype_and_inputs helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -303,20 +394,23 @@ def test_torch_isinf( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + elements=inputs[0], + test_elements=inputs[1], + assume_unique=assume_unique, + invert=invert, ) -# isreal +# isinf @handle_frontend_test( - fn_tree="torch.isreal", + fn_tree="torch.isinf", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), min_value=-np.inf, max_value=np.inf, ), ) -def test_torch_isreal( +def test_torch_isinf( *, dtype_and_input, on_device, @@ -337,16 +431,16 @@ def test_torch_isreal( ) -# isposinf +# isnan @handle_frontend_test( - fn_tree="torch.isposinf", + fn_tree="torch.isnan", dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid", full=False), min_value=-np.inf, max_value=np.inf, ), ) -def test_torch_isposinf( +def test_torch_isnan( *, dtype_and_input, on_device, @@ -397,31 +491,25 @@ def test_torch_isneginf( ) -# sort +# isposinf @handle_frontend_test( - fn_tree="torch.sort", - dtype_input_axis=helpers.dtype_values_axis( + fn_tree="torch.isposinf", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=1, - min_axis=-1, - max_axis=0, + min_value=-np.inf, + max_value=np.inf, ), - descending=st.booleans(), - stable=st.booleans(), ) -def test_torch_sort( +def test_torch_isposinf( *, - dtype_input_axis, - descending, - stable, + dtype_and_input, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input, axis = dtype_input_axis + input_dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -430,22 +518,19 @@ def test_torch_sort( fn_tree=fn_tree, on_device=on_device, input=input[0], - dim=axis, - descending=descending, - stable=stable, ) -# isnan +# isreal @handle_frontend_test( - fn_tree="torch.isnan", + fn_tree="torch.isreal", dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=False), + available_dtypes=helpers.get_dtypes("numeric"), min_value=-np.inf, max_value=np.inf, ), ) -def test_torch_isnan( +def test_torch_isreal( *, dtype_and_input, on_device, @@ -466,26 +551,31 @@ def test_torch_isnan( ) -# less_equal +# kthvalue @handle_frontend_test( - fn_tree="torch.less_equal", - aliases=["torch.le"], - dtype_and_inputs=helpers.dtype_and_values( + fn_tree="torch.kthvalue", + dtype_input_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ).filter(lambda v: len(np.unique(v[1][0])) == len(np.ravel(v[1][0]))), + k=st.integers(min_value=1), + keepdim=st.booleans(), ) -def test_torch_less_equal( +def test_torch_kthvalue( *, - dtype_and_inputs, + dtype_input_axis, + k, + keepdim, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, inputs = dtype_and_inputs + input_dtype, input, dim = dtype_input_axis + assume(k <= input[0].shape[dim]) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -493,8 +583,10 @@ def test_torch_less_equal( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=inputs[0], - 
other=inputs[1], + input=input[0], + k=k, + dim=dim, + keepdim=keepdim, ) @@ -530,83 +622,17 @@ def test_torch_less( ) -# not_equal -@handle_frontend_test( - fn_tree="torch.not_equal", - aliases=["torch.ne"], - dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=False), - num_arrays=2, - shared_dtype=True, - ), -) -def test_torch_not_equal( - *, - dtype_and_inputs, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, inputs = dtype_and_inputs - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=inputs[0], - other=inputs[1], - ) - - -@handle_frontend_test( - fn_tree="torch.isin", - dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), - assume_unique=st.booleans(), - invert=st.booleans(), -) -def test_torch_isin( - *, - dtype_and_inputs, - assume_unique, - invert, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, inputs = dtype_and_inputs - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - elements=inputs[0], - test_elements=inputs[1], - assume_unique=assume_unique, - invert=invert, - ) - - +# less_equal @handle_frontend_test( - fn_tree="torch.minimum", + fn_tree="torch.less_equal", + aliases=["torch.le"], dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, ), ) -def test_torch_minimum( +def test_torch_less_equal( *, dtype_and_inputs, on_device, @@ -628,18 +654,16 @@ def test_torch_minimum( ) -# fmax +# maximum @handle_frontend_test( - fn_tree="torch.fmax", + fn_tree="torch.maximum", dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, - min_value=-np.inf, - max_value=np.inf, ), ) -def test_torch_fmax( +def test_torch_maximum( *, dtype_and_inputs, on_device, @@ -661,18 +685,15 @@ def test_torch_fmax( ) -# fmin @handle_frontend_test( - fn_tree="torch.fmin", + fn_tree="torch.minimum", dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, shared_dtype=True, - min_value=-np.inf, - max_value=np.inf, ), ) -def test_torch_fmin( +def test_torch_minimum( *, dtype_and_inputs, on_device, @@ -724,16 +745,17 @@ def test_torch_msort( ) -# maximum +# not_equal @handle_frontend_test( - fn_tree="torch.maximum", + fn_tree="torch.not_equal", + aliases=["torch.ne"], dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid", full=False), num_arrays=2, shared_dtype=True, ), ) -def test_torch_maximum( +def test_torch_not_equal( *, dtype_and_inputs, on_device, @@ -755,31 +777,31 @@ def test_torch_maximum( ) -# kthvalue +# sort @handle_frontend_test( - fn_tree="torch.kthvalue", + fn_tree="torch.sort", dtype_input_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ).filter(lambda v: len(np.unique(v[1][0])) == len(np.ravel(v[1][0]))), - k=st.integers(min_value=1), - keepdim=st.booleans(), + min_dim_size=1, + min_axis=-1, + max_axis=0, + ), + descending=st.booleans(), + stable=st.booleans(), ) -def 
test_torch_kthvalue( +def test_torch_sort( *, dtype_input_axis, - k, - keepdim, + descending, + stable, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input, dim = dtype_input_axis - assume(k <= input[0].shape[dim]) + input_dtype, input, axis = dtype_input_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -788,24 +810,10 @@ def test_torch_kthvalue( fn_tree=fn_tree, on_device=on_device, input=input[0], - k=k, - dim=dim, - keepdim=keepdim, - ) - - -@st.composite -def _topk_helper(draw): - dtype, x, axis = draw( - helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - force_int_axis=True, - valid_axis=True, - ) + dim=axis, + descending=descending, + stable=stable, ) - k = draw(st.integers(min_value=1, max_value=x[0].shape[axis])) - return dtype, x, axis, k # topk diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py index 25f72027836f0..fcbd20246b6b3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py @@ -9,6 +9,38 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test, BackendHandler +# --- Helpers --- # +# --------------- # + + +@st.composite +def _as_strided_helper(draw): + x_dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + ret_shape=True, + ) + ) + ndim = len(shape) + numel = x[0].size + offset = draw(st.integers(min_value=0, max_value=numel - 1)) + numel = numel - offset + size = draw( + helpers.get_shape( + min_num_dims=ndim, + max_num_dims=ndim, + ).filter(lambda s: math.prod(s) <= numel) + ) + stride = draw( + helpers.get_shape( + min_num_dims=ndim, + max_num_dims=ndim, + ).filter(lambda s: all(numel // s_i >= size[i] for i, s_i in enumerate(s))) + ) + return x_dtype, x, size, stride, offset + + # Helper functions @@ -29,6 +61,47 @@ def _fill_value(draw): return ret +@st.composite +def _get_dtype_buffer_count_offset(draw): + dtype, value = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ) + ) + value = np.array(value) + length = value.size + value = value.tobytes() + + offset = draw(helpers.ints(min_value=0, max_value=length - 1)) + count = draw(helpers.ints(min_value=-(2**30), max_value=length - offset)) + if count == 0: + count = -1 + offset = offset * np.dtype(dtype[0]).itemsize + + return dtype, value, count, offset + + +@st.composite +def _heaviside_helper(draw): + input_dtype, data = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ) + ) + _, values = draw( + helpers.dtype_and_values( + available_dtypes=input_dtype, + shape=helpers.get_shape( + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + ), + ) + ) + return input_dtype, data, values + + @st.composite def _start_stop_step(draw): start = draw(helpers.ints(min_value=0, max_value=50)) @@ -40,23 +113,19 @@ def _start_stop_step(draw): return start, stop, step -# full +# --- Main --- # +# ------------ # + + +# arange @handle_frontend_test( - fn_tree="torch.full", - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - fill_value=_fill_value(), - dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"), + fn_tree="torch.arange", + start_stop_step=_start_stop_step(), + 
dtype=helpers.get_dtypes("float", full=False), ) -def test_torch_full( +def test_torch_arange( *, - shape, - fill_value, + start_stop_step, dtype, on_device, fn_tree, @@ -64,27 +133,65 @@ def test_torch_full( test_flags, backend_fw, ): + start, stop, step = start_stop_step helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=[], backend_to_test=backend_fw, - on_device=on_device, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - size=shape, - fill_value=fill_value, + on_device=on_device, + start=start, + end=stop, + step=step, + out=None, dtype=dtype[0], device=on_device, ) -# ones_like +# as_strided @handle_frontend_test( - fn_tree="torch.ones_like", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - dtype=helpers.get_dtypes("numeric", full=False), + fn_tree="torch.as_strided", + dtype_x_and_other=_as_strided_helper(), ) -def test_torch_ones_like( +def test_torch_as_strided( + *, + dtype_x_and_other, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + x_dtype, x, size, stride, offset = dtype_x_and_other + try: + helpers.test_frontend_function( + input_dtypes=x_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + size=size, + stride=stride, + storage_offset=offset, + ) + except Exception as e: + if hasattr(e, "message") and "out of bounds for storage of size" in e.message: + assume(False) + else: + raise e + + +# as_tensor +@handle_frontend_test( + fn_tree="torch.as_tensor", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid", full=False), +) +def test_torch_as_tensor( *, dtype_and_x, dtype, @@ -102,29 +209,23 @@ def test_torch_ones_like( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + data=input[0], dtype=dtype[0], device=on_device, ) -# ones +# asarray @handle_frontend_test( - fn_tree="torch.ones", - size=helpers.ints(min_value=1, max_value=3), - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, + fn_tree="torch.asarray", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") ), dtype=helpers.get_dtypes("numeric", full=False), ) -def test_torch_ones( +def test_torch_asarray( *, - shape, - size, + dtype_and_x, dtype, on_device, fn_tree, @@ -132,30 +233,23 @@ def test_torch_ones( test_flags, backend_fw, ): - dims = {} - size = (size,) - if shape is None: - i = 0 - for x_ in size: - dims[f"x{i}"] = x_ - i += 1 + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **dims, - size=shape, + obj=x[0], dtype=dtype[0], device=on_device, ) -# zeros +# empty @handle_frontend_test( - fn_tree="torch.zeros", + fn_tree="torch.empty", size=helpers.ints(min_value=1, max_value=3), shape=helpers.get_shape( allow_none=False, @@ -164,9 +258,9 @@ def test_torch_ones( min_dim_size=1, max_dim_size=10, ), - dtype=helpers.get_dtypes("numeric", full=False), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_zeros( +def test_torch_empty( *, size, shape, @@ -185,25 +279,27 @@ def test_torch_zeros( dims[f"x{i}"] = x_ i += 1 helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=[], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, 
fn_tree=fn_tree, + on_device=on_device, **dims, size=shape, dtype=dtype[0], + test_values=False, device=on_device, ) -# zeros_like +# empty_like @handle_frontend_test( - fn_tree="torch.zeros_like", + fn_tree="torch.empty_like", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - dtype=helpers.get_dtypes("numeric", full=False), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_zeros_like( +def test_torch_empty_like( *, dtype_and_x, dtype, @@ -213,109 +309,154 @@ def test_torch_zeros_like( test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, inputs = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - input=input[0], + input=inputs[0], dtype=dtype[0], device=on_device, + test_values=False, ) -# empty +# from_dlpack @handle_frontend_test( - fn_tree="torch.empty", - size=helpers.ints(min_value=1, max_value=3), - shape=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, + fn_tree="torch.from_dlpack", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") ), - dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_empty( +def test_torch_from_dlpack( *, - size, - shape, - dtype, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dims = {} - size = (size,) - if shape is None: - i = 0 - for x_ in size: - dims[f"x{i}"] = x_ - i += 1 + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[], + ext_tensor=x[0], backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - **dims, - size=shape, - dtype=dtype[0], - test_values=False, - device=on_device, ) -# arange +# from_numpy @handle_frontend_test( - fn_tree="torch.arange", - start_stop_step=_start_stop_step(), - dtype=helpers.get_dtypes("float", full=False), + fn_tree="torch.from_numpy", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), ) -def test_torch_arange( +def test_torch_from_numpy( *, - start_stop_step, - dtype, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - start, stop, step = start_stop_step + dtype, input = dtype_and_x helpers.test_frontend_function( - input_dtypes=[], + input_dtypes=dtype, backend_to_test=backend_fw, + on_device=on_device, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + data=input[0], + ) + + +@handle_frontend_test( + fn_tree="torch.frombuffer", + dtype_buffer_count_offset=_get_dtype_buffer_count_offset(), +) +def test_torch_frombuffer( + dtype_buffer_count_offset, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtype, buffer, count, offset = dtype_buffer_count_offset + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, on_device=on_device, - start=start, - end=stop, - step=step, - out=None, + frontend=frontend, + fn_tree=fn_tree, + buffer=buffer, + dtype=input_dtype[0], + count=count, + offset=offset, + ) + + +# full +@handle_frontend_test( + fn_tree="torch.full", + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + fill_value=_fill_value(), + dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"), +) +def test_torch_full( + *, + shape, + 
fill_value, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + on_device=on_device, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + size=shape, + fill_value=fill_value, dtype=dtype[0], device=on_device, ) -# range +# full_like @handle_frontend_test( - fn_tree="torch.range", - start_stop_step=_start_stop_step(), - dtype=helpers.get_dtypes("float", full=False), - number_positional_args=st.just(3), + fn_tree="torch.full_like", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.shared( + helpers.get_dtypes("numeric", full=False), key="dtype" + ) + ), + fill_value=_fill_value(), + dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"), ) -def test_torch_range( +def test_torch_full_like( *, - start_stop_step, + dtype_and_x, + fill_value, dtype, on_device, fn_tree, @@ -323,19 +464,46 @@ def test_torch_range( test_flags, backend_fw, ): - start, stop, step = start_stop_step + input_dtype, inputs = dtype_and_x helpers.test_frontend_function( - input_dtypes=[], + input_dtypes=input_dtype, backend_to_test=backend_fw, + on_device=on_device, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - start=start, - end=stop, - step=step, + input=inputs[0], + fill_value=fill_value, dtype=dtype[0], device=on_device, + test_values=False, + ) + + +# heaviside +@handle_frontend_test( + fn_tree="torch.heaviside", + dtype_and_input=_heaviside_helper(), +) +def test_torch_heaviside( + *, + dtype_and_input, + test_flags, + fn_tree, + backend_fw, + on_device, + frontend, +): + input_dtype, data, values = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + input=data[0], + values=values[0], + on_device=on_device, ) @@ -411,15 +579,23 @@ def test_torch_logspace( ) -# empty_like +# ones @handle_frontend_test( - fn_tree="torch.empty_like", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - dtype=helpers.get_dtypes("valid", full=False), + fn_tree="torch.ones", + size=helpers.ints(min_value=1, max_value=3), + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + dtype=helpers.get_dtypes("numeric", full=False), ) -def test_torch_empty_like( +def test_torch_ones( *, - dtype_and_x, + shape, + size, dtype, on_device, fn_tree, @@ -427,35 +603,36 @@ def test_torch_empty_like( test_flags, backend_fw, ): - input_dtype, inputs = dtype_and_x + dims = {} + size = (size,) + if shape is None: + i = 0 + for x_ in size: + dims[f"x{i}"] = x_ + i += 1 helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - input=inputs[0], + on_device=on_device, + **dims, + size=shape, dtype=dtype[0], device=on_device, - test_values=False, ) -# full_like +# ones_like @handle_frontend_test( - fn_tree="torch.full_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.shared( - helpers.get_dtypes("numeric", full=False), key="dtype" - ) - ), - fill_value=_fill_value(), - dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"), + fn_tree="torch.ones_like", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + dtype=helpers.get_dtypes("numeric", full=False), ) -def 
test_torch_full_like( +def test_torch_ones_like( *, dtype_and_x, - fill_value, dtype, on_device, fn_tree, @@ -463,31 +640,30 @@ def test_torch_full_like( test_flags, backend_fw, ): - input_dtype, inputs = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - on_device=on_device, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - input=inputs[0], - fill_value=fill_value, + on_device=on_device, + input=input[0], dtype=dtype[0], device=on_device, - test_values=False, ) -# as_tensor +# range @handle_frontend_test( - fn_tree="torch.as_tensor", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - dtype=helpers.get_dtypes("valid", full=False), + fn_tree="torch.range", + start_stop_step=_start_stop_step(), + dtype=helpers.get_dtypes("float", full=False), + number_positional_args=st.just(3), ) -def test_torch_as_tensor( +def test_torch_range( *, - dtype_and_x, + start_stop_step, dtype, on_device, fn_tree, @@ -495,46 +671,22 @@ def test_torch_as_tensor( test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + start, stop, step = start_stop_step helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - data=input[0], + start=start, + end=stop, + step=step, dtype=dtype[0], device=on_device, ) -# from_numpy -@handle_frontend_test( - fn_tree="torch.from_numpy", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), -) -def test_torch_from_numpy( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, input = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - on_device=on_device, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - data=input[0], - ) - - # tensor @handle_frontend_test( fn_tree="torch.tensor", @@ -565,128 +717,23 @@ def test_torch_tensor( ) -@st.composite -def _heaviside_helper(draw): - input_dtype, data = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ) - ) - _, values = draw( - helpers.dtype_and_values( - available_dtypes=input_dtype, - shape=helpers.get_shape( - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - ), - ) - ) - return input_dtype, data, values - - -@st.composite -def _as_strided_helper(draw): - x_dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - ret_shape=True, - ) - ) - ndim = len(shape) - numel = x[0].size - offset = draw(st.integers(min_value=0, max_value=numel - 1)) - numel = numel - offset - size = draw( - helpers.get_shape( - min_num_dims=ndim, - max_num_dims=ndim, - ).filter(lambda s: math.prod(s) <= numel) - ) - stride = draw( - helpers.get_shape( - min_num_dims=ndim, - max_num_dims=ndim, - ).filter(lambda s: all(numel // s_i >= size[i] for i, s_i in enumerate(s))) - ) - return x_dtype, x, size, stride, offset - - -# as_strided -@handle_frontend_test( - fn_tree="torch.as_strided", - dtype_x_and_other=_as_strided_helper(), -) -def test_torch_as_strided( - *, - dtype_x_and_other, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - x_dtype, x, size, stride, offset = dtype_x_and_other - try: - helpers.test_frontend_function( - input_dtypes=x_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - 
fn_tree=fn_tree, - on_device=on_device, - input=x[0], - size=size, - stride=stride, - storage_offset=offset, - ) - except Exception as e: - if hasattr(e, "message") and "out of bounds for storage of size" in e.message: - assume(False) - else: - raise e - - -# heaviside -@handle_frontend_test( - fn_tree="torch.heaviside", - dtype_and_input=_heaviside_helper(), -) -def test_torch_heaviside( - *, - dtype_and_input, - test_flags, - fn_tree, - backend_fw, - on_device, - frontend, -): - input_dtype, data, values = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - input=data[0], - values=values[0], - on_device=on_device, - ) - - -# asarray +# zeros @handle_frontend_test( - fn_tree="torch.asarray", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + fn_tree="torch.zeros", + size=helpers.ints(min_value=1, max_value=3), + shape=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, ), dtype=helpers.get_dtypes("numeric", full=False), ) -def test_torch_asarray( +def test_torch_zeros( *, - dtype_and_x, + size, + shape, dtype, on_device, fn_tree, @@ -694,90 +741,51 @@ def test_torch_asarray( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + dims = {} + size = (size,) + if shape is None: + i = 0 + for x_ in size: + dims[f"x{i}"] = x_ + i += 1 helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - obj=x[0], + **dims, + size=shape, dtype=dtype[0], device=on_device, ) -# from_dlpack +# zeros_like @handle_frontend_test( - fn_tree="torch.from_dlpack", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), + fn_tree="torch.zeros_like", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + dtype=helpers.get_dtypes("numeric", full=False), ) -def test_torch_from_dlpack( +def test_torch_zeros_like( *, dtype_and_x, + dtype, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( - ext_tensor=x[0], - backend_to_test=backend_fw, input_dtypes=input_dtype, + backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - ) - - -@st.composite -def _get_dtype_buffer_count_offset(draw): - dtype, value = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ) - ) - value = np.array(value) - length = value.size - value = value.tobytes() - - offset = draw(helpers.ints(min_value=0, max_value=length - 1)) - count = draw(helpers.ints(min_value=-(2**30), max_value=length - offset)) - if count == 0: - count = -1 - offset = offset * np.dtype(dtype[0]).itemsize - - return dtype, value, count, offset - - -@handle_frontend_test( - fn_tree="torch.frombuffer", - dtype_buffer_count_offset=_get_dtype_buffer_count_offset(), -) -def test_torch_frombuffer( - dtype_buffer_count_offset, - test_flags, - frontend, - backend_fw, - fn_tree, - on_device, -): - input_dtype, buffer, count, offset = dtype_buffer_count_offset - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - on_device=on_device, - frontend=frontend, - fn_tree=fn_tree, - buffer=buffer, - dtype=input_dtype[0], - 
count=count, - offset=offset, + input=input[0], + dtype=dtype[0], + device=on_device, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_func_wrapper.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_func_wrapper.py index 21cc873744df2..fad1aacd51cfd 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_func_wrapper.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_func_wrapper.py @@ -14,6 +14,10 @@ import ivy.functional.frontends.torch as torch_frontend +# --- Helpers --- # +# --------------- # + + def _fn(*args, dtype=None, check_default=False): if ( check_default @@ -31,6 +35,15 @@ def _fn(*args, dtype=None, check_default=False): return args[0] +# --- Main --- # +# ------------ # + + +@numpy_to_torch_style_args +def mocked_func(dim=None, keepdim=None, input=None, other=None): + return dim, keepdim, input, other + + @given( dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False) @@ -66,6 +79,29 @@ def test_torch_inputs_to_ivy_arrays(dtype_and_x, backend_fw): ivy.previous_backend() +@given( + dim=st.integers(), + keepdim=st.booleans(), + input=st.lists(st.integers()), + other=st.integers(), +) +def test_torch_numpy_to_torch_style_args(dim, keepdim, input, other): + # PyTorch-style keyword arguments + assert (dim, keepdim, input, other) == mocked_func( + dim=dim, keepdim=keepdim, input=input, other=other + ) + + # NumPy-style keyword arguments + assert (dim, keepdim, input, other) == mocked_func( + axis=dim, keepdims=keepdim, x=input, x2=other + ) + + # Mixed-style keyword arguments + assert (dim, keepdim, input, other) == mocked_func( + axis=dim, keepdim=keepdim, input=input, x2=other + ) + + @given( dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False) @@ -159,31 +195,3 @@ def test_torch_to_ivy_arrays_and_back(dtype_and_x, dtype, backend_fw): assert ivy.default_float_dtype_stack == ivy.default_int_dtype_stack == [] ivy.previous_backend() - - -@numpy_to_torch_style_args -def mocked_func(dim=None, keepdim=None, input=None, other=None): - return dim, keepdim, input, other - - -@given( - dim=st.integers(), - keepdim=st.booleans(), - input=st.lists(st.integers()), - other=st.integers(), -) -def test_torch_numpy_to_torch_style_args(dim, keepdim, input, other): - # PyTorch-style keyword arguments - assert (dim, keepdim, input, other) == mocked_func( - dim=dim, keepdim=keepdim, input=input, other=other - ) - - # NumPy-style keyword arguments - assert (dim, keepdim, input, other) == mocked_func( - axis=dim, keepdims=keepdim, x=input, x2=other - ) - - # Mixed-style keyword arguments - assert (dim, keepdim, input, other) == mocked_func( - axis=dim, keepdim=keepdim, input=input, x2=other - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py index fccc6cf264573..a97d4ee6ea268 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py @@ -1,1596 +1,1604 @@ -# global -from hypothesis import strategies as st -import math - - -# local -import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test -from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits -from 
ivy_tests.test_ivy.test_functional.test_core.test_searching import ( - _broadcastable_trio, -) -from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import ( # noqa - _get_splits, -) - - -# noinspection DuplicatedCode -@st.composite -def _arrays_idx_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) - num_arrays = draw( - st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays") - ) - common_shape = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_dims - 1, - ) - ) - unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) - unique_dims = draw( - helpers.list_of_size( - x=helpers.ints(min_value=2, max_value=3), - size=num_arrays, - ) - ) - xs = list() - input_dtypes = draw( - helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("float"))) - ) - for ud, dt in zip(unique_dims, input_dtypes): - x = draw( - helpers.array_values( - shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:], - dtype=dt, - ) - ) - xs.append(x) - return xs, input_dtypes, unique_idx - - -# noinspection DuplicatedCode -@st.composite -def _array_idxes_n_dtype(draw, **kwargs): - num_dims = draw(helpers.ints(min_value=1, max_value=4)) - dtype, x = draw( - helpers.dtype_and_values( - **kwargs, min_num_dims=num_dims, max_num_dims=num_dims, shared_dtype=True - ) - ) - idxes = draw( - st.lists( - helpers.ints(min_value=0, max_value=num_dims - 1), - min_size=num_dims, - max_size=num_dims, - unique=True, - ) - ) - return x, idxes, dtype - - -# adjoint -@handle_frontend_test( - fn_tree="torch.adjoint", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("real_and_complex"), - min_num_dims=2, - min_dim_size=2, - ), -) -def test_torch_adjoint( - *, - dtype_and_values, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - ) - - -# cat -@handle_frontend_test( - fn_tree="torch.cat", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), -) -def test_torch_cat( - *, - xs_n_input_dtypes_n_unique_idx, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensors=xs, - dim=unique_idx, - ) - - -# concat -@handle_frontend_test( - fn_tree="torch.concat", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), -) -def test_torch_concat( - *, - xs_n_input_dtypes_n_unique_idx, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensors=xs, - dim=unique_idx, - ) - - -# gather -@handle_frontend_test( - fn_tree="torch.gather", - params_indices_others=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("valid"), - indices_dtypes=["int64"], - indices_same_dims=True, - ), -) -def test_torch_gather( - *, - params_indices_others, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - 
input_dtypes, input, indices, axis, batch_dims = params_indices_others - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input, - dim=axis, - index=indices, - ) - - -# nonzero -@handle_frontend_test( - fn_tree="torch.nonzero", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ), - as_tuple=st.booleans(), -) -def test_torch_nonzero( - *, - dtype_and_values, - as_tuple, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, input = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input[0], - as_tuple=as_tuple, - ) - - -# permute -@handle_frontend_test( - fn_tree="torch.permute", - dtype_values_axis=_array_idxes_n_dtype( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_torch_permute( - *, - dtype_values_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - x, idxes, dtype = dtype_values_axis - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dims=tuple(idxes), - ) - - -# swapdims -@handle_frontend_test( - fn_tree="torch.swapdims", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - ), - dim0=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - force_int=True, - ), - dim1=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - force_int=True, - ), -) -def test_torch_swapdims( - *, - dtype_and_values, - dim0, - dim1, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - dim0=dim0, - dim1=dim1, - ) - - -# reshape -@st.composite -def dtypes_x_reshape(draw): - shape = draw(helpers.get_shape(min_num_dims=1)) - dtypes, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=shape, - ) - ) - shape = draw( - helpers.get_shape(min_num_dims=1).filter( - lambda s: math.prod(s) == math.prod(shape) - ) - ) - return dtypes, x, shape - - -@handle_frontend_test( - fn_tree="torch.reshape", - dtypes_x_reshape=dtypes_x_reshape(), -) -def test_torch_reshape( - *, - dtypes_x_reshape, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, shape = dtypes_x_reshape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - shape=shape, - ) - - -# stack -@handle_frontend_test( - fn_tree="torch.stack", - dtype_value_shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ).filter(lambda axis: 
isinstance(axis, int)), -) -def test_torch_stack( - *, - dtype_value_shape, - dim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value_shape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensors=value, - dim=dim, - ) - - -# transpose -@handle_frontend_test( - fn_tree="torch.transpose", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - ), - dim0=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - force_int=True, - ), - dim1=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - force_int=True, - ), -) -def test_torch_transpose( - *, - dtype_and_values, - dim0, - dim1, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - dim0=dim0, - dim1=dim1, - ) - - -# t -@handle_frontend_test( - fn_tree="torch.t", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(max_num_dims=2), key="shape"), - ), -) -def test_torch_t( - *, - dtype_and_values, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - ) - - -# squeeze -@handle_frontend_test( - fn_tree="torch.squeeze", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - max_size=1, - ).filter(lambda axis: isinstance(axis, int)), -) -def test_torch_squeeze( - *, - dtype_and_values, - dim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - dim=dim, - ) - - -# swapaxes -@handle_frontend_test( - fn_tree="torch.swapaxes", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - ), - axis0=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - force_int=True, - ), - axis1=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - force_int=True, - ), -) -def test_torch_swapaxes( - *, - dtype_and_values, - axis0, - axis1, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - axis0=axis0, - axis1=axis1, - ) - - -@st.composite -def _chunk_helper(draw): - dtype, x, 
shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - ret_shape=True, - ) - ) - axis = draw(helpers.get_axis(shape=shape, force_int=True)) - if shape[axis] == 0: - chunks = 0 - else: - factors = [] - for i in range(1, shape[axis] + 1): - if shape[axis] % i == 0: - factors.append(i) - chunks = draw(st.sampled_from(factors)) - return dtype, x, axis, chunks - - -# chunk -@handle_frontend_test( - fn_tree="torch.chunk", - x_dim_chunks=_chunk_helper(), - test_with_out=st.just(False), -) -def test_torch_chunk( - *, - x_dim_chunks, - fn_tree, - on_device, - frontend, - test_flags, - backend_fw, -): - dtype, x, axis, chunks = x_dim_chunks - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - chunks=chunks, - dim=axis, - ) - - -# tile -@handle_frontend_test( - fn_tree="torch.tile", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=False, - force_tuple=True, - ), -) -def test_torch_tile( - *, - dtype_value, - dim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - dims=dim, - ) - - -# unsqueeze -@handle_frontend_test( - fn_tree="torch.unsqueeze", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, - ), -) -def test_torch_unsqueeze( - *, - dtype_value, - dim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - dim=dim, - ) - - -@handle_frontend_test( - fn_tree="torch.argwhere", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_torch_argwhere( - *, - dtype_and_values, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, input = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input[0], - ) - - -# movedim -@handle_frontend_test( - fn_tree="torch.movedim", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - ), - source=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - destination=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - 
min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - test_with_out=st.just(False), -) -def test_torch_movedim( - *, - dtype_and_input, - source, - destination, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - source=source, - destination=destination, - ) - - -# moveaxis -@handle_frontend_test( - fn_tree="torch.moveaxis", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - ), - source=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - destination=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - test_with_out=st.just(False), -) -def test_torch_moveaxis( - *, - dtype_and_input, - source, - destination, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - source=source, - destination=destination, - ) - - -# hstack -@handle_frontend_test( - fn_tree="torch.hstack", - dtype_value_shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_torch_hstack( - *, - dtype_value_shape, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value_shape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensors=value, - ) - - -# dstack -@handle_frontend_test( - fn_tree="torch.dstack", - dtype_value_shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_torch_dstack( - *, - dtype_value_shape, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value_shape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensors=value, - ) - - -# index_select -@handle_frontend_test( - fn_tree="torch.index_select", - params_indices_others=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("valid"), - indices_dtypes=["int64"], - max_num_dims=1, - indices_same_dims=True, - ), -) -def test_torch_index_select( - *, - params_indices_others, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtypes, input, indices, axis, batch_dims = params_indices_others - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input, - dim=axis, - 
index=indices, - ) - - -# take_along_dim -@handle_frontend_test( - fn_tree="torch.take_along_dim", - dtype_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - indices_same_dims=True, - ), -) -def test_torch_take_along_dim( - *, - dtype_indices_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtypes, value, indices, axis, _ = dtype_indices_axis - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value, - indices=indices, - dim=axis, - ) - - -# vstack -@handle_frontend_test( - fn_tree="torch.vstack", - dtype_value_shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_torch_vstack( - *, - dtype_value_shape, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value_shape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensors=value, - ) - - -# split -@handle_frontend_test( - fn_tree="torch.split", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - split_size_or_sections=_get_splits( - allow_none=False, min_num_dims=1, allow_array_indices=False - ), - dim=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", - ), -) -def test_torch_split( - *, - dtype_value, - split_size_or_sections, - dim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensor=value[0], - split_size_or_sections=split_size_or_sections, - dim=dim, - ) - - -# tensor_split -@handle_frontend_test( - fn_tree="torch.tensor_split", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=1, allow_none=False, allow_array_indices=False - ), - axis=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", - ), - number_positional_args=st.just(2), - test_with_out=st.just(False), -) -def test_torch_tensor_split( - *, - dtype_value, - indices_or_sections, - axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - indices_or_sections=indices_or_sections, - dim=axis, - ) - - -# unbind -@handle_frontend_test( - fn_tree="torch.unbind", - dtype_value_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), -) -def test_torch_unbind( - *, - dtype_value_axis, - on_device, - fn_tree, - frontend, - 
test_flags, - backend_fw, -): - input_dtypes, value, axis = dtype_value_axis - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - dim=axis, - ) - - -# dsplit -@handle_frontend_test( - fn_tree="torch.dsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=3, - axis=2, - allow_none=False, - allow_array_indices=False, - is_mod_split=True, - ), -) -def test_torch_dsplit( - *, - dtype_value, - indices_or_sections, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - indices_or_sections=indices_or_sections, - ) - - -# hsplit -@handle_frontend_test( - fn_tree="torch.hsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=1, - axis=1, - allow_none=False, - allow_array_indices=False, - is_mod_split=True, - ), -) -def test_torch_hsplit( - *, - dtype_value, - indices_or_sections, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - indices_or_sections=indices_or_sections, - ) - - -# vsplit -@handle_frontend_test( - fn_tree="torch.vsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=2, - axis=0, - allow_none=False, - allow_array_indices=False, - is_mod_split=True, - ), -) -def test_torch_vsplit( - *, - dtype_value, - indices_or_sections, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - indices_or_sections=indices_or_sections, - ) - - -# row_stack -@handle_frontend_test( - fn_tree="torch.row_stack", - dtype_value_shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=st.integers(1, 5), - ), -) -def test_torch_row_stack( - *, - dtype_value_shape, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_value_shape - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensors=value, - ) - - -@handle_frontend_test( - fn_tree="torch.where", - broadcastables=_broadcastable_trio(), - only_cond=st.booleans(), -) -def test_torch_where( - *, - broadcastables, - only_cond, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - cond, xs, dtypes = broadcastables - - if only_cond: - helpers.test_frontend_function( 
- input_dtypes=[dtypes[0]], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - condition=xs[0], - ) - - else: - helpers.test_frontend_function( - input_dtypes=["bool"] + dtypes, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - condition=cond, - input=xs[0], - other=xs[1], - backend_to_test=backend_fw, - ) - - -@handle_frontend_test( - fn_tree="torch.conj", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - ), -) -def test_torch_conj( - on_device, - frontend, - *, - dtype_and_x, - fn_tree, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - ) - - -@st.composite -def _arrays_dim_idx_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) - num_arrays = 2 - common_shape = draw( - helpers.lists( - x=helpers.ints(min_value=2, max_value=3), - min_size=num_dims - 1, - max_size=num_dims - 1, - ) - ) - _dim = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) - unique_dims = draw( - helpers.lists( - x=helpers.ints(min_value=2, max_value=3), - min_size=num_arrays, - max_size=num_arrays, - ) - ) - - min_dim = min(unique_dims) - max_dim = max(unique_dims) - _idx = draw( - helpers.array_values( - shape=min_dim, - dtype="int64", - min_value=0, - max_value=max_dim, - exclude_min=False, - ) - ) - - xs = list() - available_input_types = draw(helpers.get_dtypes("numeric")) - available_input_types.remove("float16") # half summation unstable in backends - input_dtypes = draw( - helpers.array_dtypes( - available_dtypes=available_input_types, - num_arrays=num_arrays, - shared_dtype=True, - ) - ) - for ud, dt in zip(unique_dims, input_dtypes): - x = draw( - helpers.array_values( - shape=common_shape[:_dim] + [ud] + common_shape[_dim:], - dtype=dt, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ) - ) - xs.append(x) - return xs, input_dtypes, _dim, _idx - - -# index_add -@handle_frontend_test( - fn_tree="torch.index_add", - xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(), - alpha=st.integers(min_value=1, max_value=2), -) -def test_torch_index_add( - *, - xs_dtypes_dim_idx, - alpha, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - xs, input_dtypes, axis, indices = xs_dtypes_dim_idx - if xs[0].shape[axis] < xs[1].shape[axis]: - source, input = xs - else: - input, source = xs - helpers.test_frontend_function( - input_dtypes=[input_dtypes[0], "int64", input_dtypes[1]], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-03, - input=input, - dim=axis, - index=indices, - source=source, - alpha=alpha, - ) - - -# index_copy -@handle_frontend_test( - fn_tree="torch.index_copy", - xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(), -) -def test_torch_index_copy( - *, - xs_dtypes_dim_idx, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - xs, input_dtypes, axis, indices = xs_dtypes_dim_idx - if xs[0].shape[axis] < xs[1].shape[axis]: - source, input = xs - else: - input, source = xs - helpers.test_frontend_function( - input_dtypes=[input_dtypes[0], "int64", input_dtypes[1]], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - 
fn_tree=fn_tree, - on_device=on_device, - input=input, - dim=axis, - index=indices, - source=source, - ) - - -@st.composite -def _dtypes_input_mask(draw): - _shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) - _mask = draw(helpers.array_values(dtype=helpers.get_dtypes("bool"), shape=_shape)) - _dtype, _x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - shape=_shape, - ) - ) - - return _dtype, _x, _mask - - -@handle_frontend_test( - fn_tree="torch.masked_select", - dtype_input_mask=_dtypes_input_mask(), -) -def test_torch_masked_select( - *, - dtype_input_mask, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - ( - input_dtype, - x, - mask, - ) = dtype_input_mask - - helpers.test_frontend_function( - input_dtypes=input_dtype + ["bool"], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - mask=mask, - ) - - -@handle_frontend_test( - fn_tree="torch.take", - dtype_and_x=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes(), indices_dtypes=["int64"] - ), -) -def test_torch_take( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtypes, xs, indices, _, _ = dtype_and_x - helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=xs, - index=indices, - ) - - -@st.composite -def _dtype_input_dim_start_length(draw): - _shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) - _dtype, _x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=1, - shape=_shape, - ) - ) - _dim = draw( - helpers.get_axis( - shape=_shape, - force_int=True, - ), - ) - _start = draw(helpers.ints(min_value=1, max_value=_shape[_dim])) - - _length = draw(helpers.ints(min_value=0, max_value=_shape[_dim] - _start)) - - return _dtype, _x, _dim, _start, _length - - -@handle_frontend_test( - fn_tree="torch.narrow", - dtype_input_dim_start_length=_dtype_input_dim_start_length(), -) -def test_torch_narrow( - *, - dtype_input_dim_start_length, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - (input_dtype, x, dim, start, length) = dtype_input_dim_start_length - - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=dim, - start=start, - length=length, - ) - - -@st.composite -def _dtype_input_idx_axis(draw): - dtype_x_axis_shape = draw( - helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - force_int_axis=True, - ret_shape=True, - valid_axis=True, - min_num_dims=2, - ) - ) - - input_dtype, x, axis, shape = dtype_x_axis_shape - max_idx = 0 - if shape: - max_idx = shape[axis] - 1 - idx = draw(helpers.ints(min_value=0, max_value=max_idx)) - x = x[0] - - return input_dtype, x, idx, axis - - -@handle_frontend_test( - fn_tree="torch.select", - dtype_x_idx_axis=_dtype_input_idx_axis(), -) -def test_torch_select( - *, - dtype_x_idx_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, idx, axis = dtype_x_idx_axis - - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x, - dim=axis, - index=idx, - ) +# 
global
+from hypothesis import strategies as st
+import math
+
+
+# local
+import ivy_tests.test_ivy.helpers as helpers
+from ivy_tests.test_ivy.helpers import handle_frontend_test
+from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits
+from ivy_tests.test_ivy.test_functional.test_core.test_searching import (
+    _broadcastable_trio,
+)
+from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import (  # noqa
+    _get_splits,
+)
+
+
+# --- Helpers --- #
+# --------------- #
+
+
+# noinspection DuplicatedCode
+@st.composite
+def _array_idxes_n_dtype(draw, **kwargs):
+    num_dims = draw(helpers.ints(min_value=1, max_value=4))
+    dtype, x = draw(
+        helpers.dtype_and_values(
+            **kwargs, min_num_dims=num_dims, max_num_dims=num_dims, shared_dtype=True
+        )
+    )
+    idxes = draw(
+        st.lists(
+            helpers.ints(min_value=0, max_value=num_dims - 1),
+            min_size=num_dims,
+            max_size=num_dims,
+            unique=True,
+        )
+    )
+    return x, idxes, dtype
+
+
+@st.composite
+def _arrays_dim_idx_n_dtypes(draw):
+    num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims"))
+    num_arrays = 2
+    common_shape = draw(
+        helpers.lists(
+            x=helpers.ints(min_value=2, max_value=3),
+            min_size=num_dims - 1,
+            max_size=num_dims - 1,
+        )
+    )
+    _dim = draw(helpers.ints(min_value=0, max_value=num_dims - 1))
+    unique_dims = draw(
+        helpers.lists(
+            x=helpers.ints(min_value=2, max_value=3),
+            min_size=num_arrays,
+            max_size=num_arrays,
+        )
+    )
+
+    min_dim = min(unique_dims)
+    max_dim = max(unique_dims)
+    _idx = draw(
+        helpers.array_values(
+            shape=min_dim,
+            dtype="int64",
+            min_value=0,
+            max_value=max_dim,
+            exclude_min=False,
+        )
+    )
+
+    xs = list()
+    available_input_types = draw(helpers.get_dtypes("numeric"))
+    available_input_types.remove("float16")  # half summation unstable in backends
+    input_dtypes = draw(
+        helpers.array_dtypes(
+            available_dtypes=available_input_types,
+            num_arrays=num_arrays,
+            shared_dtype=True,
+        )
+    )
+    for ud, dt in zip(unique_dims, input_dtypes):
+        x = draw(
+            helpers.array_values(
+                shape=common_shape[:_dim] + [ud] + common_shape[_dim:],
+                dtype=dt,
+                large_abs_safety_factor=2.5,
+                small_abs_safety_factor=2.5,
+                safety_factor_scale="log",
+            )
+        )
+        xs.append(x)
+    return xs, input_dtypes, _dim, _idx
+
+
+# noinspection DuplicatedCode
+@st.composite
+def _arrays_idx_n_dtypes(draw):
+    num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims"))
+    num_arrays = draw(
+        st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays")
+    )
+    common_shape = draw(
+        helpers.list_of_size(
+            x=helpers.ints(min_value=2, max_value=3),
+            size=num_dims - 1,
+        )
+    )
+    unique_idx = draw(helpers.ints(min_value=0, max_value=num_dims - 1))
+    unique_dims = draw(
+        helpers.list_of_size(
+            x=helpers.ints(min_value=2, max_value=3),
+            size=num_arrays,
+        )
+    )
+    xs = list()
+    input_dtypes = draw(
+        helpers.array_dtypes(available_dtypes=draw(helpers.get_dtypes("float")))
+    )
+    for ud, dt in zip(unique_dims, input_dtypes):
+        x = draw(
+            helpers.array_values(
+                shape=common_shape[:unique_idx] + [ud] + common_shape[unique_idx:],
+                dtype=dt,
+            )
+        )
+        xs.append(x)
+    return xs, input_dtypes, unique_idx
+
+
+@st.composite
+def _chunk_helper(draw):
+    dtype, x, shape = draw(
+        helpers.dtype_and_values(
+            available_dtypes=helpers.get_dtypes("float"),
+            min_num_dims=1,
+            ret_shape=True,
+        )
+    )
+    axis = draw(helpers.get_axis(shape=shape, force_int=True))
+    if shape[axis] == 0:
+        chunks = 0
+    else:
+        factors = []
+        for i in range(1, shape[axis] + 1):
+            if
shape[axis] % i == 0: + factors.append(i) + chunks = draw(st.sampled_from(factors)) + return dtype, x, axis, chunks + + +@st.composite +def _dtype_input_dim_start_length(draw): + _shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) + _dtype, _x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + shape=_shape, + ) + ) + _dim = draw( + helpers.get_axis( + shape=_shape, + force_int=True, + ), + ) + _start = draw(helpers.ints(min_value=1, max_value=_shape[_dim])) + + _length = draw(helpers.ints(min_value=0, max_value=_shape[_dim] - _start)) + + return _dtype, _x, _dim, _start, _length + + +@st.composite +def _dtype_input_idx_axis(draw): + dtype_x_axis_shape = draw( + helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + force_int_axis=True, + ret_shape=True, + valid_axis=True, + min_num_dims=2, + ) + ) + + input_dtype, x, axis, shape = dtype_x_axis_shape + max_idx = 0 + if shape: + max_idx = shape[axis] - 1 + idx = draw(helpers.ints(min_value=0, max_value=max_idx)) + x = x[0] + + return input_dtype, x, idx, axis + + +@st.composite +def _dtypes_input_mask(draw): + _shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) + _mask = draw(helpers.array_values(dtype=helpers.get_dtypes("bool"), shape=_shape)) + _dtype, _x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + shape=_shape, + ) + ) + + return _dtype, _x, _mask + + +# reshape +@st.composite +def dtypes_x_reshape(draw): + shape = draw(helpers.get_shape(min_num_dims=1)) + dtypes, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=shape, + ) + ) + shape = draw( + helpers.get_shape(min_num_dims=1).filter( + lambda s: math.prod(s) == math.prod(shape) + ) + ) + return dtypes, x, shape + + +# --- Main --- # +# ------------ # + + +# adjoint +@handle_frontend_test( + fn_tree="torch.adjoint", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("real_and_complex"), + min_num_dims=2, + min_dim_size=2, + ), +) +def test_torch_adjoint( + *, + dtype_and_values, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + +@handle_frontend_test( + fn_tree="torch.argwhere", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_torch_argwhere( + *, + dtype_and_values, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, input = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + ) + + +# cat +@handle_frontend_test( + fn_tree="torch.cat", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), +) +def test_torch_cat( + *, + xs_n_input_dtypes_n_unique_idx, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensors=xs, + dim=unique_idx, + ) + + +# chunk +@handle_frontend_test( + 
fn_tree="torch.chunk", + x_dim_chunks=_chunk_helper(), + test_with_out=st.just(False), +) +def test_torch_chunk( + *, + x_dim_chunks, + fn_tree, + on_device, + frontend, + test_flags, + backend_fw, +): + dtype, x, axis, chunks = x_dim_chunks + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + chunks=chunks, + dim=axis, + ) + + +# concat +@handle_frontend_test( + fn_tree="torch.concat", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), +) +def test_torch_concat( + *, + xs_n_input_dtypes_n_unique_idx, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensors=xs, + dim=unique_idx, + ) + + +@handle_frontend_test( + fn_tree="torch.conj", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), + ), +) +def test_torch_conj( + on_device, + frontend, + *, + dtype_and_x, + fn_tree, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ) + + +# dsplit +@handle_frontend_test( + fn_tree="torch.dsplit", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=3, + axis=2, + allow_none=False, + allow_array_indices=False, + is_mod_split=True, + ), +) +def test_torch_dsplit( + *, + dtype_value, + indices_or_sections, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + indices_or_sections=indices_or_sections, + ) + + +# dstack +@handle_frontend_test( + fn_tree="torch.dstack", + dtype_value_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_torch_dstack( + *, + dtype_value_shape, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value_shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensors=value, + ) + + +# gather +@handle_frontend_test( + fn_tree="torch.gather", + params_indices_others=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("valid"), + indices_dtypes=["int64"], + indices_same_dims=True, + ), +) +def test_torch_gather( + *, + params_indices_others, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, input, indices, axis, batch_dims = params_indices_others + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input, + dim=axis, + index=indices, + ) + + +# hsplit +@handle_frontend_test( + fn_tree="torch.hsplit", + 
dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=1, + axis=1, + allow_none=False, + allow_array_indices=False, + is_mod_split=True, + ), +) +def test_torch_hsplit( + *, + dtype_value, + indices_or_sections, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + indices_or_sections=indices_or_sections, + ) + + +# hstack +@handle_frontend_test( + fn_tree="torch.hstack", + dtype_value_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_torch_hstack( + *, + dtype_value_shape, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value_shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensors=value, + ) + + +# index_add +@handle_frontend_test( + fn_tree="torch.index_add", + xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(), + alpha=st.integers(min_value=1, max_value=2), +) +def test_torch_index_add( + *, + xs_dtypes_dim_idx, + alpha, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + xs, input_dtypes, axis, indices = xs_dtypes_dim_idx + if xs[0].shape[axis] < xs[1].shape[axis]: + source, input = xs + else: + input, source = xs + helpers.test_frontend_function( + input_dtypes=[input_dtypes[0], "int64", input_dtypes[1]], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-03, + input=input, + dim=axis, + index=indices, + source=source, + alpha=alpha, + ) + + +# index_copy +@handle_frontend_test( + fn_tree="torch.index_copy", + xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(), +) +def test_torch_index_copy( + *, + xs_dtypes_dim_idx, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + xs, input_dtypes, axis, indices = xs_dtypes_dim_idx + if xs[0].shape[axis] < xs[1].shape[axis]: + source, input = xs + else: + input, source = xs + helpers.test_frontend_function( + input_dtypes=[input_dtypes[0], "int64", input_dtypes[1]], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input, + dim=axis, + index=indices, + source=source, + ) + + +# index_select +@handle_frontend_test( + fn_tree="torch.index_select", + params_indices_others=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("valid"), + indices_dtypes=["int64"], + max_num_dims=1, + indices_same_dims=True, + ), +) +def test_torch_index_select( + *, + params_indices_others, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, input, indices, axis, batch_dims = params_indices_others + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input, + dim=axis, + index=indices, + ) + + +@handle_frontend_test( + fn_tree="torch.masked_select", + dtype_input_mask=_dtypes_input_mask(), +) +def test_torch_masked_select( + *, + dtype_input_mask, + on_device, + fn_tree, + frontend, + 
test_flags, + backend_fw, +): + ( + input_dtype, + x, + mask, + ) = dtype_input_mask + + helpers.test_frontend_function( + input_dtypes=input_dtype + ["bool"], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + mask=mask, + ) + + +# moveaxis +@handle_frontend_test( + fn_tree="torch.moveaxis", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + ), + source=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + destination=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + test_with_out=st.just(False), +) +def test_torch_moveaxis( + *, + dtype_and_input, + source, + destination, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + source=source, + destination=destination, + ) + + +# movedim +@handle_frontend_test( + fn_tree="torch.movedim", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + ), + source=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + destination=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + test_with_out=st.just(False), +) +def test_torch_movedim( + *, + dtype_and_input, + source, + destination, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + source=source, + destination=destination, + ) + + +@handle_frontend_test( + fn_tree="torch.narrow", + dtype_input_dim_start_length=_dtype_input_dim_start_length(), +) +def test_torch_narrow( + *, + dtype_input_dim_start_length, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + (input_dtype, x, dim, start, length) = dtype_input_dim_start_length + + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=dim, + start=start, + length=length, + ) + + +# nonzero +@handle_frontend_test( + fn_tree="torch.nonzero", + dtype_and_values=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("numeric"), + ), + as_tuple=st.booleans(), +) +def test_torch_nonzero( + *, + dtype_and_values, + as_tuple, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, input = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + as_tuple=as_tuple, + ) + + +# permute +@handle_frontend_test( + fn_tree="torch.permute", + dtype_values_axis=_array_idxes_n_dtype( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_torch_permute( + *, + dtype_values_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + x, idxes, dtype = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dims=tuple(idxes), + ) + + +@handle_frontend_test( + fn_tree="torch.reshape", + dtypes_x_reshape=dtypes_x_reshape(), +) +def test_torch_reshape( + *, + dtypes_x_reshape, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, shape = dtypes_x_reshape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + shape=shape, + ) + + +# row_stack +@handle_frontend_test( + fn_tree="torch.row_stack", + dtype_value_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=st.integers(1, 5), + ), +) +def test_torch_row_stack( + *, + dtype_value_shape, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value_shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensors=value, + ) + + +@handle_frontend_test( + fn_tree="torch.select", + dtype_x_idx_axis=_dtype_input_idx_axis(), +) +def test_torch_select( + *, + dtype_x_idx_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, idx, axis = dtype_x_idx_axis + + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x, + dim=axis, + index=idx, + ) + + +# split +@handle_frontend_test( + fn_tree="torch.split", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + split_size_or_sections=_get_splits( + allow_none=False, min_num_dims=1, allow_array_indices=False + ), + dim=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", + ), +) +def test_torch_split( + *, + dtype_value, + split_size_or_sections, + dim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensor=value[0], + split_size_or_sections=split_size_or_sections, + dim=dim, + ) + + +# squeeze +@handle_frontend_test( + fn_tree="torch.squeeze", + 
dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + max_size=1, + ).filter(lambda axis: isinstance(axis, int)), +) +def test_torch_squeeze( + *, + dtype_and_values, + dim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + dim=dim, + ) + + +# stack +@handle_frontend_test( + fn_tree="torch.stack", + dtype_value_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=st.shared(helpers.ints(min_value=2, max_value=4), key="num_arrays"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ).filter(lambda axis: isinstance(axis, int)), +) +def test_torch_stack( + *, + dtype_value_shape, + dim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value_shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensors=value, + dim=dim, + ) + + +# swapaxes +@handle_frontend_test( + fn_tree="torch.swapaxes", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + ), + axis0=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + force_int=True, + ), + axis1=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + force_int=True, + ), +) +def test_torch_swapaxes( + *, + dtype_and_values, + axis0, + axis1, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + axis0=axis0, + axis1=axis1, + ) + + +# swapdims +@handle_frontend_test( + fn_tree="torch.swapdims", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + ), + dim0=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + force_int=True, + ), + dim1=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + force_int=True, + ), +) +def test_torch_swapdims( + *, + dtype_and_values, + dim0, + dim1, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + dim0=dim0, + dim1=dim1, + ) + + +# t +@handle_frontend_test( + fn_tree="torch.t", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(max_num_dims=2), key="shape"), + ), +) +def test_torch_t( + *, + dtype_and_values, + on_device, + 
fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + +@handle_frontend_test( + fn_tree="torch.take", + dtype_and_x=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes(), indices_dtypes=["int64"] + ), +) +def test_torch_take( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, xs, indices, _, _ = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=xs, + index=indices, + ) + + +# take_along_dim +@handle_frontend_test( + fn_tree="torch.take_along_dim", + dtype_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + indices_same_dims=True, + ), +) +def test_torch_take_along_dim( + *, + dtype_indices_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, value, indices, axis, _ = dtype_indices_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value, + indices=indices, + dim=axis, + ) + + +# tensor_split +@handle_frontend_test( + fn_tree="torch.tensor_split", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=1, allow_none=False, allow_array_indices=False + ), + axis=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", + ), + number_positional_args=st.just(2), + test_with_out=st.just(False), +) +def test_torch_tensor_split( + *, + dtype_value, + indices_or_sections, + axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + indices_or_sections=indices_or_sections, + dim=axis, + ) + + +# tile +@handle_frontend_test( + fn_tree="torch.tile", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=False, + force_tuple=True, + ), +) +def test_torch_tile( + *, + dtype_value, + dim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + dims=dim, + ) + + +# transpose +@handle_frontend_test( + fn_tree="torch.transpose", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + ), + dim0=helpers.get_axis( + 
shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + force_int=True, + ), + dim1=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + force_int=True, + ), +) +def test_torch_transpose( + *, + dtype_and_values, + dim0, + dim1, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + dim0=dim0, + dim1=dim1, + ) + + +# unbind +@handle_frontend_test( + fn_tree="torch.unbind", + dtype_value_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), +) +def test_torch_unbind( + *, + dtype_value_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, value, axis = dtype_value_axis + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + dim=axis, + ) + + +# unsqueeze +@handle_frontend_test( + fn_tree="torch.unsqueeze", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), +) +def test_torch_unsqueeze( + *, + dtype_value, + dim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + dim=dim, + ) + + +# vsplit +@handle_frontend_test( + fn_tree="torch.vsplit", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=2, + axis=0, + allow_none=False, + allow_array_indices=False, + is_mod_split=True, + ), +) +def test_torch_vsplit( + *, + dtype_value, + indices_or_sections, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + indices_or_sections=indices_or_sections, + ) + + +# vstack +@handle_frontend_test( + fn_tree="torch.vstack", + dtype_value_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_torch_vstack( + *, + dtype_value_shape, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_value_shape + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensors=value, + ) + + +@handle_frontend_test( + fn_tree="torch.where", + broadcastables=_broadcastable_trio(), + only_cond=st.booleans(), +) +def test_torch_where( + *, + broadcastables, + only_cond, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + cond, xs, dtypes = broadcastables + + if only_cond: + 
helpers.test_frontend_function(
+            input_dtypes=[dtypes[0]],
+            backend_to_test=backend_fw,
+            frontend=frontend,
+            test_flags=test_flags,
+            fn_tree=fn_tree,
+            on_device=on_device,
+            condition=xs[0],
+        )
+
+    else:
+        helpers.test_frontend_function(
+            input_dtypes=["bool"] + dtypes,
+            frontend=frontend,
+            test_flags=test_flags,
+            fn_tree=fn_tree,
+            on_device=on_device,
+            condition=cond,
+            input=xs[0],
+            other=xs[1],
+            backend_to_test=backend_fw,
+        )
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
index 496668ac8e255..577967c5f3e1d 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
@@ -18,6 +18,48 @@
 from ivy_tests.test_ivy.test_functional.test_core.test_linalg import _matrix_rank_helper
 
 
+# --- Helpers --- #
+# --------------- #
+
+
+@st.composite
+def _generate_multi_dot_dtype_and_arrays(draw):
+    input_dtype = [draw(st.sampled_from(draw(helpers.get_dtypes("valid"))))]
+    matrices_dims = draw(
+        st.lists(st.integers(min_value=2, max_value=10), min_size=4, max_size=4)
+    )
+    shape_1 = (matrices_dims[0], matrices_dims[1])
+    shape_2 = (matrices_dims[1], matrices_dims[2])
+    shape_3 = (matrices_dims[2], matrices_dims[3])
+
+    matrix_1 = draw(
+        helpers.dtype_and_values(
+            shape=shape_1,
+            dtype=input_dtype,
+            min_value=-10,
+            max_value=10,
+        )
+    )
+    matrix_2 = draw(
+        helpers.dtype_and_values(
+            shape=shape_2,
+            dtype=input_dtype,
+            min_value=-10,
+            max_value=10,
+        )
+    )
+    matrix_3 = draw(
+        helpers.dtype_and_values(
+            shape=shape_3,
+            dtype=input_dtype,
+            min_value=-10,
+            max_value=10,
+        )
+    )
+
+    return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]]
+
+
 # helpers
 @st.composite
 def _get_dtype_and_matrix(
@@ -46,35 +88,175 @@ def _get_dtype_and_matrix(
     return draw(ret)
 
 
-# vector_norm
+@st.composite
+def _get_dtype_and_symmetrix_matrix(draw):
+    input_dtype = draw(st.shared(st.sampled_from(draw(helpers.get_dtypes("valid")))))
+    random_size = draw(helpers.ints(min_value=2, max_value=4))
+    batch_shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=3))
+    num_independnt_vals = int((random_size**2) / 2 + random_size / 2)
+    array_vals_flat = np.array(
+        draw(
+            helpers.array_values(
+                dtype=input_dtype,
+                shape=tuple(list(batch_shape) + [num_independnt_vals]),
+                min_value=2,
+                max_value=5,
+            )
+        )
+    )
+    array_vals = np.zeros(batch_shape + (random_size, random_size))
+    c = 0
+    for i in range(random_size):
+        for j in range(random_size):
+            if j < i:
+                continue
+            array_vals[..., i, j] = array_vals_flat[..., c]
+            array_vals[..., j, i] = array_vals_flat[..., c]
+            c += 1
+    return [input_dtype], array_vals
+
+
+# tensorsolve
+@st.composite
+def _get_solve_matrices(draw):
+    # batch_shape, random_size, shared
+
+    # float16 causes a crash when filtering out matrices
+    # for which `np.linalg.cond` is large.
+ input_dtype_strategy = st.shared( + st.sampled_from(draw(helpers.get_dtypes("valid"))), + key="shared_dtype", + ) + input_dtype = draw(input_dtype_strategy) + + dim = draw(helpers.ints(min_value=2, max_value=5)) + + first_matrix = draw( + helpers.array_values( + dtype=input_dtype, + shape=(dim, dim, dim, dim), + min_value=1.2, + max_value=5, + ).filter( + lambda x: np.linalg.cond(x.reshape((dim**2, dim**2))) + < 1 / sys.float_info.epsilon + ) + ) + + second_matrix = draw( + helpers.array_values( + dtype=input_dtype, + shape=(dim, dim), + min_value=1.2, + max_value=3, + ).filter( + lambda x: np.linalg.cond(x.reshape((dim, dim))) < 1 / sys.float_info.epsilon + ) + ) + + return input_dtype, first_matrix, second_matrix + + +# tensorinv +@st.composite +def _tensorinv_helper(draw): + def factors(x): + result = [ + 1, + ] + i = 2 + while i * i <= x: + if x % i == 0: + result.append(i) + if x // i != i: + result.append(x // i) + i += 1 + result.append(x) + return np.array(result) + + ind = draw(helpers.ints(min_value=1, max_value=6)) + product_half = draw(helpers.ints(min_value=2, max_value=25)) + factors_list = factors(product_half) + shape = () + while len(shape) < ind and ind > 2: + while np.prod(shape) < product_half: + a = factors_list[np.random.randint(len(factors_list))] + shape += (a,) + if np.prod(shape) > product_half or len(shape) > ind: + shape = () + while len(shape) < ind and shape != (): + shape += (1,) + if np.prod(shape) == product_half: + shape += shape[::-1] + break + if ind == 1 and shape == (): + shape += (product_half, product_half) + if ind == 2 and shape == (): + shape += (1, product_half, product_half, 1) + shape_cor = () + for i in shape: + shape_cor += (int(i),) + shape_draw = (product_half, product_half) + dtype, input = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=shape_draw, + ).filter(lambda x: np.linalg.cond(x[1]) < 1 / sys.float_info.epsilon) + ) + input[0] = input[0].reshape(shape_cor) + return dtype, input[0], ind + + +# vander +@st.composite +def _vander_helper(draw): + # generate input matrix of shape (*, n) and where '*' is one or more + # batch dimensions + N = draw(helpers.ints(min_value=2, max_value=5)) + if draw(helpers.floats(min_value=0, max_value=1.0)) < 0.5: + N = None + + shape = draw( + helpers.get_shape( + min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ) + ) + x = draw( + helpers.dtype_and_values( + available_dtypes=draw(helpers.get_dtypes("valid")), + shape=shape, + min_value=-10, + max_value=10, + ) + ) + + return *x, N + + +# --- Main --- # +# ------------ # + + @handle_frontend_test( - fn_tree="torch.linalg.vector_norm", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - abs_smallest_val=1e04, - ), - kd=st.booleans(), - ord=st.one_of( - helpers.ints(min_value=0, max_value=5), - helpers.floats(min_value=1.0, max_value=5.0), - st.sampled_from((float("inf"), -float("inf"))), - ), - dtype=helpers.get_dtypes("valid", full=False), + fn_tree="torch.linalg.cholesky", + aliases=["torch.cholesky"], + dtype_and_x=_get_dtype_and_matrix(square=True), + upper=st.booleans(), ) -def test_torch_vector_norm( +def test_torch_cholesky( *, - dtype_values_axis, - kd, - ord, - dtype, + dtype_and_x, + upper, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, axis = dtype_values_axis + dtype, x = dtype_and_x + x = np.asarray(x[0], dtype=dtype[0]) + x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric 
positive-definite + helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -82,23 +264,21 @@ def test_torch_vector_norm( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - ord=ord, - dim=axis, - keepdim=kd, - dtype=dtype[0], + rtol=1e-01, + input=x, + upper=upper, ) -# inv @handle_frontend_test( - fn_tree="torch.linalg.inv", - aliases=["torch.inverse"], - dtype_and_x=_get_dtype_and_matrix(square=True, invertible=True, batch=True), + fn_tree="torch.linalg.cholesky_ex", + dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), + upper=st.booleans(), ) -def test_torch_inv( +def test_torch_cholesky_ex( *, dtype_and_x, + upper, on_device, fn_tree, frontend, @@ -106,7 +286,8 @@ def test_torch_inv( backend_fw, ): dtype, x = dtype_and_x - test_flags.num_positional_args = 1 + x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite + helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -114,26 +295,19 @@ def test_torch_inv( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - atol=1e-03, - A=x[0], + rtol=1e-01, + input=x, + upper=upper, ) -# inv_ex -# TODO: Test for singular matrices @handle_frontend_test( - fn_tree="torch.linalg.inv_ex", + fn_tree="torch.linalg.cond", dtype_and_x=_get_dtype_and_matrix(square=True, invertible=True, batch=True), + p=st.sampled_from([None, "fro", "nuc", np.inf, -np.inf, 1, -1, 2, -2]), ) -def test_torch_inv_ex( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, +def test_torch_cond( + *, dtype_and_x, p, on_device, fn_tree, frontend, backend_fw, test_flags ): dtype, x = dtype_and_x helpers.test_frontend_function( @@ -143,38 +317,49 @@ def test_torch_inv_ex( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - atol=1e-02, - A=x[0], + test_values=False, + input=x[0], + rtol=1e-2, + atol=1e-3, + p=p, ) -# pinv -# TODO: add testing for hermitian +# cross @handle_frontend_test( - fn_tree="torch.linalg.pinv", - dtype_and_input=_get_dtype_and_matrix(batch=True), + fn_tree="torch.linalg.cross", + dtype_input_other_dim=dtype_value1_value2_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=3, + max_dim_size=3, + min_value=-1e3, + max_value=1e3, + abs_smallest_val=0.01, + large_abs_safety_factor=2, + safety_factor_scale="log", + ), ) -def test_torch_pinv( - *, - dtype_and_input, +def test_torch_cross( + dtype_input_other_dim, frontend, test_flags, fn_tree, backend_fw, - on_device, ): - input_dtype, x = dtype_and_input + dtype, input, other, dim = dtype_input_other_dim helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - input=x[0], - atol=1e-02, - rtol=1e-02, + rtol=1e-2, + atol=1e-3, + input=input, + other=other, + dim=dim, ) @@ -206,12 +391,13 @@ def test_torch_det( ) -# qr +# eig +# TODO: Test for all valid dtypes once ivy.eig supports complex data types @handle_frontend_test( - fn_tree="torch.linalg.qr", - dtype_and_input=_get_dtype_and_matrix(batch=True), + fn_tree="torch.linalg.eig", + dtype_and_input=_get_dtype_and_matrix(dtype="float", square=True), ) -def test_torch_qr( +def test_torch_eig( *, dtype_and_input, frontend, @@ -221,7 +407,11 @@ def test_torch_qr( on_device, ): input_dtype, x = dtype_and_input - ivy.set_backend(backend_fw) + x = np.asarray(x[0], dtype=input_dtype[0]) + x = 
np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + if x.dtype == ivy.float32: + x = x.astype("float64") + input_dtype = [ivy.float64] ret, frontend_ret = helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -229,52 +419,67 @@ def test_torch_qr( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - A=x[0], test_values=False, + input=x, ) - ret = [ivy.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] + ret = [ivy.to_numpy(x).astype("float64") for x in ret] + frontend_ret = [np.asarray(x, dtype=np.float64) for x in frontend_ret] - q, r = ret - frontend_q, frontend_r = frontend_ret + l, v = ret + front_l, front_v = frontend_ret assert_all_close( - ret_np=q @ r, - ret_from_gt_np=frontend_q @ frontend_r, + ret_np=v @ np.diag(l) @ np.linalg.inv(v), + ret_from_gt_np=front_v @ np.diag(front_l) @ np.linalg.inv(front_v), rtol=1e-2, atol=1e-2, ground_truth_backend=frontend, ) - ivy.previous_backend() -# slogdet +# eigh +# TODO: Test for all valid dtypes @handle_frontend_test( - fn_tree="torch.linalg.slogdet", - aliases=["torch.slogdet"], - dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), + fn_tree="torch.linalg.eigh", + dtype_and_x=_get_dtype_and_matrix(dtype="float", square=True, invertible=True), + UPLO=st.sampled_from(("L", "U")), ) -def test_torch_slogdet( +def test_torch_eigh( *, dtype_and_x, + UPLO, + on_device, fn_tree, frontend, - on_device, test_flags, backend_fw, ): dtype, x = dtype_and_x - test_flags.num_positional_args = len(x) - helpers.test_frontend_function( + x = np.array(x[0], dtype=dtype[0]) + # make symmetric positive-definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + + ret, frontend_ret = helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-4, - atol=1e-4, - A=x[0], + test_values=False, + a=x, + UPLO=UPLO, + ) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + L, Q = ret + frontend_L, frontend_Q = frontend_ret + + assert_all_close( + ret_np=Q @ np.diag(L) @ Q.T, + ret_from_gt_np=frontend_Q @ np.diag(frontend_L) @ frontend_Q.T, + atol=1e-02, ) @@ -340,39 +545,11 @@ def test_torch_eigvals( assert_all_close( ret_np=ret_modulus, - ret_from_gt_np=frontend_ret_modulus, - rtol=1e-2, - atol=1e-2, - ground_truth_backend=frontend, - ) - - -@st.composite -def _get_dtype_and_symmetrix_matrix(draw): - input_dtype = draw(st.shared(st.sampled_from(draw(helpers.get_dtypes("valid"))))) - random_size = draw(helpers.ints(min_value=2, max_value=4)) - batch_shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=3)) - num_independnt_vals = int((random_size**2) / 2 + random_size / 2) - array_vals_flat = np.array( - draw( - helpers.array_values( - dtype=input_dtype, - shape=tuple(list(batch_shape) + [num_independnt_vals]), - min_value=2, - max_value=5, - ) - ) - ) - array_vals = np.zeros(batch_shape + (random_size, random_size)) - c = 0 - for i in range(random_size): - for j in range(random_size): - if j < i: - continue - array_vals[..., i, j] = array_vals_flat[..., c] - array_vals[..., j, i] = array_vals_flat[..., c] - c += 1 - return [input_dtype], array_vals + ret_from_gt_np=frontend_ret_modulus, + rtol=1e-2, + atol=1e-2, + ground_truth_backend=frontend, + ) # eigvalsh @@ -406,15 +583,23 @@ def test_torch_eigvalsh( ) +# inv @handle_frontend_test( - fn_tree="torch.linalg.cond", + fn_tree="torch.linalg.inv", + 
aliases=["torch.inverse"], dtype_and_x=_get_dtype_and_matrix(square=True, invertible=True, batch=True), - p=st.sampled_from([None, "fro", "nuc", np.inf, -np.inf, 1, -1, 2, -2]), ) -def test_torch_cond( - *, dtype_and_x, p, on_device, fn_tree, frontend, backend_fw, test_flags +def test_torch_inv( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, ): dtype, x = dtype_and_x + test_flags.num_positional_args = 1 helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -422,25 +607,21 @@ def test_torch_cond( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - input=x[0], - rtol=1e-2, - atol=1e-3, - p=p, + rtol=1e-03, + atol=1e-03, + A=x[0], ) -# matrix_power +# inv_ex +# TODO: Test for singular matrices @handle_frontend_test( - fn_tree="torch.linalg.matrix_power", - aliases=["torch.matrix_power"], + fn_tree="torch.linalg.inv_ex", dtype_and_x=_get_dtype_and_matrix(square=True, invertible=True, batch=True), - n=helpers.ints(min_value=2, max_value=5), ) -def test_torch_matrix_power( +def test_torch_inv_ex( *, dtype_and_x, - n, on_device, fn_tree, frontend, @@ -448,7 +629,6 @@ def test_torch_matrix_power( backend_fw, ): dtype, x = dtype_and_x - test_flags.num_positional_args = len(x) + 1 helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -456,9 +636,72 @@ def test_torch_matrix_power( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, + rtol=1e-03, + atol=1e-02, A=x[0], - n=n, + ) + + +# lu_factor +@handle_frontend_test( + fn_tree="torch.linalg.lu_factor", + input_dtype_and_input=_get_dtype_and_matrix(batch=True), +) +def test_torch_lu_factor( + *, + input_dtype_and_input, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, input = input_dtype_and_input + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-03, + atol=1e-02, + A=input[0], + ) + + +@handle_frontend_test( + fn_tree="torch.linalg.matmul", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=(3, 3), + num_arrays=2, + shared_dtype=True, + min_value=-1e04, + max_value=1e04, + ), +) +def test_torch_matmul( + *, + dtype_x, + frontend, + fn_tree, + on_device, + test_flags, + backend_fw, +): + input_dtype, x = dtype_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + test_flags=test_flags, + input=x[0], + other=x[1], + rtol=1e-03, + atol=1e-03, ) @@ -539,80 +782,35 @@ def test_torch_matrix_norm( ) -# cross +# matrix_power @handle_frontend_test( - fn_tree="torch.linalg.cross", - dtype_input_other_dim=dtype_value1_value2_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=3, - max_dim_size=3, - min_value=-1e3, - max_value=1e3, - abs_smallest_val=0.01, - large_abs_safety_factor=2, - safety_factor_scale="log", - ), + fn_tree="torch.linalg.matrix_power", + aliases=["torch.matrix_power"], + dtype_and_x=_get_dtype_and_matrix(square=True, invertible=True, batch=True), + n=helpers.ints(min_value=2, max_value=5), ) -def test_torch_cross( - dtype_input_other_dim, - frontend, - test_flags, +def test_torch_matrix_power( + *, + dtype_and_x, + n, + on_device, fn_tree, - backend_fw, -): - dtype, input, other, dim = dtype_input_other_dim - 
helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - rtol=1e-2, - atol=1e-3, - input=input, - other=other, - dim=dim, - ) - - -# vecdot -@handle_frontend_test( - fn_tree="torch.linalg.vecdot", - dtype_input_other_dim=dtype_value1_value2_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=3, - max_dim_size=3, - min_value=-1e3, - max_value=1e3, - abs_smallest_val=0.01, - large_abs_safety_factor=2, - safety_factor_scale="log", - ), -) -def test_torch_vecdot( - dtype_input_other_dim, frontend, test_flags, - fn_tree, backend_fw, ): - dtype, input, other, dim = dtype_input_other_dim - test_flags.num_positional_args = len(dtype_input_other_dim) - 2 + dtype, x = dtype_and_x + test_flags.num_positional_args = len(x) + 1 helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - rtol=1e-2, - atol=1e-3, - x=input, - y=other, - dim=dim, + on_device=on_device, + rtol=1e-01, + A=x[0], + n=n, ) @@ -646,93 +844,37 @@ def test_torch_matrix_rank( @handle_frontend_test( - fn_tree="torch.linalg.cholesky", - aliases=["torch.cholesky"], - dtype_and_x=_get_dtype_and_matrix(square=True), - upper=st.booleans(), + fn_tree="torch.linalg.multi_dot", + dtype_x=_generate_multi_dot_dtype_and_arrays(), ) -def test_torch_cholesky( - *, - dtype_and_x, - upper, +def test_torch_multi_dot( + dtype_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=dtype[0]) - x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite - + dtype, x = dtype_x helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-01, - input=x, - upper=upper, - ) - - -# svd -@handle_frontend_test( - fn_tree="torch.linalg.svd", - dtype_and_x=_get_dtype_and_matrix(square=True), - full_matrices=st.booleans(), -) -def test_torch_svd( - *, - dtype_and_x, - full_matrices, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=dtype[0]) - # make symmetric positive definite beforehand - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - ret, frontend_ret = helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - atol=1e-03, - rtol=1e-05, - A=x, - full_matrices=full_matrices, - ) - ret = [ivy.to_numpy(x) for x in ret] - frontend_ret = [np.asarray(x) for x in frontend_ret] - - u, s, vh = ret - frontend_u, frontend_s, frontend_vh = frontend_ret - - assert_all_close( - ret_np=u @ np.diag(s) @ vh, - ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh, - rtol=1e-2, - atol=1e-2, - ground_truth_backend=frontend, + test_flags=test_flags, + on_device=on_device, + frontend=frontend, + fn_tree=fn_tree, + test_values=True, + tensors=x, ) -# eig -# TODO: Test for all valid dtypes once ivy.eig supports complex data types +# pinv +# TODO: add testing for hermitian @handle_frontend_test( - fn_tree="torch.linalg.eig", - dtype_and_input=_get_dtype_and_matrix(dtype="float", square=True), + fn_tree="torch.linalg.pinv", + dtype_and_input=_get_dtype_and_matrix(batch=True), ) -def test_torch_eig( +def test_torch_pinv( *, dtype_and_input, 
frontend, @@ -742,97 +884,78 @@ def test_torch_eig( on_device, ): input_dtype, x = dtype_and_input - x = np.asarray(x[0], dtype=input_dtype[0]) - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - if x.dtype == ivy.float32: - x = x.astype("float64") - input_dtype = [ivy.float64] - ret, frontend_ret = helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - input=x, - ) - ret = [ivy.to_numpy(x).astype("float64") for x in ret] - frontend_ret = [np.asarray(x, dtype=np.float64) for x in frontend_ret] - - l, v = ret - front_l, front_v = frontend_ret - - assert_all_close( - ret_np=v @ np.diag(l) @ np.linalg.inv(v), - ret_from_gt_np=front_v @ np.diag(front_l) @ np.linalg.inv(front_v), - rtol=1e-2, - atol=1e-2, - ground_truth_backend=frontend, + input=x[0], + atol=1e-02, + rtol=1e-02, ) -# eigh -# TODO: Test for all valid dtypes +# qr @handle_frontend_test( - fn_tree="torch.linalg.eigh", - dtype_and_x=_get_dtype_and_matrix(dtype="float", square=True, invertible=True), - UPLO=st.sampled_from(("L", "U")), + fn_tree="torch.linalg.qr", + dtype_and_input=_get_dtype_and_matrix(batch=True), ) -def test_torch_eigh( +def test_torch_qr( *, - dtype_and_x, - UPLO, - on_device, - fn_tree, + dtype_and_input, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - dtype, x = dtype_and_x - x = np.array(x[0], dtype=dtype[0]) - # make symmetric positive-definite beforehand - x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 - + input_dtype, x = dtype_and_input + ivy.set_backend(backend_fw) ret, frontend_ret = helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + A=x[0], test_values=False, - a=x, - UPLO=UPLO, ) ret = [ivy.to_numpy(x) for x in ret] frontend_ret = [np.asarray(x) for x in frontend_ret] - L, Q = ret - frontend_L, frontend_Q = frontend_ret + q, r = ret + frontend_q, frontend_r = frontend_ret assert_all_close( - ret_np=Q @ np.diag(L) @ Q.T, - ret_from_gt_np=frontend_Q @ np.diag(frontend_L) @ frontend_Q.T, - atol=1e-02, + ret_np=q @ r, + ret_from_gt_np=frontend_q @ frontend_r, + rtol=1e-2, + atol=1e-2, + ground_truth_backend=frontend, ) + ivy.previous_backend() -# svdvals +# slogdet @handle_frontend_test( - fn_tree="torch.linalg.svdvals", - dtype_and_x=_get_dtype_and_matrix(batch=True), + fn_tree="torch.linalg.slogdet", + aliases=["torch.slogdet"], + dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), ) -def test_torch_svdvals( +def test_torch_slogdet( *, dtype_and_x, - on_device, fn_tree, frontend, + on_device, test_flags, backend_fw, ): dtype, x = dtype_and_x + test_flags.num_positional_args = len(x) helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -840,6 +963,8 @@ def test_torch_svdvals( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-4, + atol=1e-4, A=x[0], ) @@ -888,168 +1013,139 @@ def test_torch_solve( ) -# tensorinv -@st.composite -def _tensorinv_helper(draw): - def factors(x): - result = [ - 1, - ] - i = 2 - while i * i <= x: - if x % i == 0: - result.append(i) - if x // i != i: - result.append(x // i) - i += 1 - result.append(x) - return np.array(result) - - ind = draw(helpers.ints(min_value=1, max_value=6)) - product_half = draw(helpers.ints(min_value=2, max_value=25)) - factors_list = factors(product_half) - shape 
= () - while len(shape) < ind and ind > 2: - while np.prod(shape) < product_half: - a = factors_list[np.random.randint(len(factors_list))] - shape += (a,) - if np.prod(shape) > product_half or len(shape) > ind: - shape = () - while len(shape) < ind and shape != (): - shape += (1,) - if np.prod(shape) == product_half: - shape += shape[::-1] - break - if ind == 1 and shape == (): - shape += (product_half, product_half) - if ind == 2 and shape == (): - shape += (1, product_half, product_half, 1) - shape_cor = () - for i in shape: - shape_cor += (int(i),) - shape_draw = (product_half, product_half) - dtype, input = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape_draw, - ).filter(lambda x: np.linalg.cond(x[1]) < 1 / sys.float_info.epsilon) - ) - input[0] = input[0].reshape(shape_cor) - return dtype, input[0], ind - - +# solve_ex @handle_frontend_test( - fn_tree="torch.linalg.tensorinv", dtype_input_ind=_tensorinv_helper() + fn_tree="torch.linalg.solve_ex", + dtype_and_data=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x + 1])), + safety_factor_scale="log", + small_abs_safety_factor=6, + ).filter( + lambda x: np.linalg.cond(x[1][0][:, :-1]) < 1 / sys.float_info.epsilon + and np.linalg.det(x[1][0][:, :-1]) != 0 + and np.linalg.cond(x[1][0][:, -1].reshape(-1, 1)) < 1 / sys.float_info.epsilon + ), + left=st.booleans(), + check_errors=st.booleans(), ) -def test_torch_tensorinv( +def test_torch_solve_ex( *, - dtype_input_ind, + dtype_and_data, + left, + check_errors, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, ind = dtype_input_ind + input_dtype, data = dtype_and_data + input = data[0][:, :-1] + other = data[0][:, -1].reshape(-1, 1) helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=[input_dtype[0], input_dtype[0]], backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-04, - atol=1e-03, - input=x, - ind=ind, + A=input, + B=other, + left=left, + check_errors=check_errors, ) -# tensorsolve -@st.composite -def _get_solve_matrices(draw): - # batch_shape, random_size, shared - - # float16 causes a crash when filtering out matrices - # for which `np.linalg.cond` is large. 
- input_dtype_strategy = st.shared( - st.sampled_from(draw(helpers.get_dtypes("valid"))), - key="shared_dtype", +# svd +@handle_frontend_test( + fn_tree="torch.linalg.svd", + dtype_and_x=_get_dtype_and_matrix(square=True), + full_matrices=st.booleans(), +) +def test_torch_svd( + *, + dtype_and_x, + full_matrices, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + x = np.asarray(x[0], dtype=dtype[0]) + # make symmetric positive definite beforehand + x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3 + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + atol=1e-03, + rtol=1e-05, + A=x, + full_matrices=full_matrices, ) - input_dtype = draw(input_dtype_strategy) - - dim = draw(helpers.ints(min_value=2, max_value=5)) + ret = [ivy.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] - first_matrix = draw( - helpers.array_values( - dtype=input_dtype, - shape=(dim, dim, dim, dim), - min_value=1.2, - max_value=5, - ).filter( - lambda x: np.linalg.cond(x.reshape((dim**2, dim**2))) - < 1 / sys.float_info.epsilon - ) - ) + u, s, vh = ret + frontend_u, frontend_s, frontend_vh = frontend_ret - second_matrix = draw( - helpers.array_values( - dtype=input_dtype, - shape=(dim, dim), - min_value=1.2, - max_value=3, - ).filter( - lambda x: np.linalg.cond(x.reshape((dim, dim))) < 1 / sys.float_info.epsilon - ) + assert_all_close( + ret_np=u @ np.diag(s) @ vh, + ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh, + rtol=1e-2, + atol=1e-2, + ground_truth_backend=frontend, ) - return input_dtype, first_matrix, second_matrix - +# svdvals @handle_frontend_test( - fn_tree="torch.linalg.tensorsolve", - a_and_b=_get_solve_matrices(), + fn_tree="torch.linalg.svdvals", + dtype_and_x=_get_dtype_and_matrix(batch=True), ) -def test_torch_tensorsolve( +def test_torch_svdvals( *, - a_and_b, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, A, B = a_and_b - test_flags.num_positional_args = len(a_and_b) - 1 + dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[input_dtype], + input_dtypes=dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-3, - rtol=1e-3, - A=A, - B=B, + A=x[0], ) -# lu_factor @handle_frontend_test( - fn_tree="torch.linalg.lu_factor", - input_dtype_and_input=_get_dtype_and_matrix(batch=True), + fn_tree="torch.linalg.tensorinv", dtype_input_ind=_tensorinv_helper() ) -def test_torch_lu_factor( +def test_torch_tensorinv( *, - input_dtype_and_input, + dtype_input_ind, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, input = input_dtype_and_input + dtype, x, ind = dtype_input_ind helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -1057,72 +1153,41 @@ def test_torch_lu_factor( frontend=frontend, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - atol=1e-02, - A=input[0], + rtol=1e-04, + atol=1e-03, + input=x, + ind=ind, ) @handle_frontend_test( - fn_tree="torch.linalg.matmul", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=(3, 3), - num_arrays=2, - shared_dtype=True, - min_value=-1e04, - max_value=1e04, - ), + fn_tree="torch.linalg.tensorsolve", + a_and_b=_get_solve_matrices(), ) -def test_torch_matmul( +def 
test_torch_tensorsolve( *, - dtype_x, - frontend, - fn_tree, + a_and_b, on_device, + fn_tree, + frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, A, B = a_and_b + test_flags.num_positional_args = len(a_and_b) - 1 helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[input_dtype], backend_to_test=backend_fw, + test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, - test_flags=test_flags, - input=x[0], - other=x[1], - rtol=1e-03, - atol=1e-03, - ) - - -# vander -@st.composite -def _vander_helper(draw): - # generate input matrix of shape (*, n) and where '*' is one or more - # batch dimensions - N = draw(helpers.ints(min_value=2, max_value=5)) - if draw(helpers.floats(min_value=0, max_value=1.0)) < 0.5: - N = None - - shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ) - ) - x = draw( - helpers.dtype_and_values( - available_dtypes=draw(helpers.get_dtypes("valid")), - shape=shape, - min_value=-10, - max_value=10, - ) + atol=1e-3, + rtol=1e-3, + A=A, + B=B, ) - return *x, N - @handle_frontend_test( fn_tree="torch.linalg.vander", @@ -1151,133 +1216,74 @@ def test_torch_vander( ) -@st.composite -def _generate_multi_dot_dtype_and_arrays(draw): - input_dtype = [draw(st.sampled_from(draw(helpers.get_dtypes("valid"))))] - matrices_dims = draw( - st.lists(st.integers(min_value=2, max_value=10), min_size=4, max_size=4) - ) - shape_1 = (matrices_dims[0], matrices_dims[1]) - shape_2 = (matrices_dims[1], matrices_dims[2]) - shape_3 = (matrices_dims[2], matrices_dims[3]) - - matrix_1 = draw( - helpers.dtype_and_values( - shape=shape_1, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - matrix_2 = draw( - helpers.dtype_and_values( - shape=shape_2, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - matrix_3 = draw( - helpers.dtype_and_values( - shape=shape_3, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - - return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]] - - -@handle_frontend_test( - fn_tree="torch.linalg.multi_dot", - dtype_x=_generate_multi_dot_dtype_and_arrays(), -) -def test_torch_multi_dot( - dtype_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, x = dtype_x - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - on_device=on_device, - frontend=frontend, - fn_tree=fn_tree, - test_values=True, - tensors=x, - ) - - -# solve_ex +# vecdot @handle_frontend_test( - fn_tree="torch.linalg.solve_ex", - dtype_and_data=helpers.dtype_and_values( + fn_tree="torch.linalg.vecdot", + dtype_input_other_dim=dtype_value1_value2_axis( available_dtypes=helpers.get_dtypes("valid"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x + 1])), + min_num_dims=1, + max_num_dims=5, + min_dim_size=3, + max_dim_size=3, + min_value=-1e3, + max_value=1e3, + abs_smallest_val=0.01, + large_abs_safety_factor=2, safety_factor_scale="log", - small_abs_safety_factor=6, - ).filter( - lambda x: np.linalg.cond(x[1][0][:, :-1]) < 1 / sys.float_info.epsilon - and np.linalg.det(x[1][0][:, :-1]) != 0 - and np.linalg.cond(x[1][0][:, -1].reshape(-1, 1)) < 1 / sys.float_info.epsilon ), - left=st.booleans(), - check_errors=st.booleans(), ) -def test_torch_solve_ex( - *, - dtype_and_data, - left, - check_errors, - on_device, - fn_tree, +def test_torch_vecdot( + dtype_input_other_dim, frontend, test_flags, + fn_tree, 
backend_fw, ): - input_dtype, data = dtype_and_data - input = data[0][:, :-1] - other = data[0][:, -1].reshape(-1, 1) + dtype, input, other, dim = dtype_input_other_dim + test_flags.num_positional_args = len(dtype_input_other_dim) - 2 helpers.test_frontend_function( - input_dtypes=[input_dtype[0], input_dtype[0]], + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - A=input, - B=other, - left=left, - check_errors=check_errors, + rtol=1e-2, + atol=1e-3, + x=input, + y=other, + dim=dim, ) +# vector_norm @handle_frontend_test( - fn_tree="torch.linalg.cholesky_ex", - dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), - upper=st.booleans(), + fn_tree="torch.linalg.vector_norm", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + abs_smallest_val=1e04, + ), + kd=st.booleans(), + ord=st.one_of( + helpers.ints(min_value=0, max_value=5), + helpers.floats(min_value=1.0, max_value=5.0), + st.sampled_from((float("inf"), -float("inf"))), + ), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_cholesky_ex( +def test_torch_vector_norm( *, - dtype_and_x, - upper, + dtype_values_axis, + kd, + ord, + dtype, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x = dtype_and_x - x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite - + dtype, x, axis = dtype_values_axis helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -1285,7 +1291,9 @@ def test_torch_cholesky_ex( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-01, - input=x, - upper=upper, + input=x[0], + ord=ord, + dim=axis, + keepdim=kd, + dtype=dtype[0], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py index 3d9f7f7ab7b24..72c290ebecf3d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py @@ -1,1782 +1,1790 @@ -# global -import math - -import numpy as np -from hypothesis import assume, strategies as st -import hypothesis.extra.numpy as nph - -# local -import ivy -import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test -from ivy_tests.test_ivy.test_functional.test_core.test_linalg import ( - _get_dtype_value1_value2_axis_for_tensordot, -) - - -# helpers -@st.composite -def _get_repeat_interleaves_args( - draw, *, available_dtypes, valid_axis, max_num_dims, max_dim_size -): - values_dtype, values, axis, shape = draw( - helpers.dtype_values_axis( - available_dtypes=available_dtypes, - valid_axis=valid_axis, - force_int_axis=True, - shape=draw( - helpers.get_shape( - allow_none=False, - min_num_dims=0, - max_num_dims=max_num_dims, - min_dim_size=1, - max_dim_size=max_dim_size, - ) - ), - ret_shape=True, - ) - ) - - if axis is None: - generate_repeats_as_integer = draw(st.booleans()) - num_repeats = 1 if generate_repeats_as_integer else math.prod(tuple(shape)) - else: - num_repeats = shape[axis] - - repeats_dtype, repeats = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=0, - max_value=10, - shape=[num_repeats], - ) - ) - - # Output size is an optional parameter accepted by Torch for optimisation - use_output_size = draw(st.booleans()) - output_size = np.sum(repeats) if use_output_size 
else None - - return [values_dtype, repeats_dtype], values, repeats, axis, output_size - - -# atleast_1d -@handle_frontend_test( - fn_tree="torch.atleast_1d", - dtype_and_tensors=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=st.integers(min_value=1, max_value=5), - ), - test_with_out=st.just(False), -) -def test_torch_atleast_1d( - *, - dtype_and_tensors, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtypes, tensors = dtype_and_tensors - if isinstance(dtypes, list): # If more than one value was generated - args = { - f"x{i}": np.array(tensor, dtype=dtypes[i]) - for i, tensor in enumerate(tensors) - } - else: # If exactly one value was generated - args = {"x0": np.array(tensors, dtype=dtypes)} - test_flags.num_positional_args = len(tensors) - helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **args, - ) - - -# flip -@handle_frontend_test( - fn_tree="torch.flip", - dtype_and_values=helpers.dtype_and_values( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - available_dtypes=helpers.get_dtypes("float"), - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_tuple=True, - ), -) -def test_torch_flip( - *, - dtype_and_values, - axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - dims=axis, - ) - - -# roll -@handle_frontend_test( - fn_tree="torch.roll", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - shift=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_tuple=True, - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_tuple=True, - ), -) -def test_torch_roll( - *, - dtype_and_values, - shift, - axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - if isinstance(shift, int) and isinstance(axis, tuple): - axis = axis[0] - if isinstance(shift, tuple) and isinstance(axis, tuple): - if len(shift) != len(axis): - mn = min(len(shift), len(axis)) - shift = shift[:mn] - axis = axis[:mn] - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - shifts=shift, - dims=axis, - ) - - -# meshgrid -@handle_frontend_test( - fn_tree="torch.meshgrid", - dtypes_and_tensors=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=st.integers(min_value=2, max_value=5), - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - max_dim_size=5, - shared_dtype=True, - ), - indexing=st.sampled_from(["ij", "xy"]), -) -def test_torch_meshgrid( - *, - dtypes_and_tensors, - indexing, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtypes, tensors = dtypes_and_tensors - kwargs = { - f"tensor{i}": np.array(tensor, dtype=dtypes[i]) - for i, tensor in enumerate(tensors) - } - kwargs["indexing"] = indexing - test_flags.num_positional_args = len(tensors) - 
helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **kwargs, - ) - - -# fliplr -@handle_frontend_test( - fn_tree="torch.fliplr", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=helpers.get_shape(min_num_dims=2), - ), -) -def test_torch_fliplr( - *, - dtype_and_values, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - ) - - -# flipud -@handle_frontend_test( - fn_tree="torch.flipud", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=helpers.get_shape(min_num_dims=1), - ), -) -def test_torch_flipud( - *, - dtype_and_values, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - ) - - -# cumsum -@handle_frontend_test( - fn_tree="torch.cumsum", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=5, - valid_axis=True, - allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - ), - dtype=helpers.get_dtypes("numeric", none=True, full=False), -) -def test_torch_cumsum( - *, - dtype_x_axis, - dtype, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_x_axis - # ToDo: set as_variable_flags as the parameter generated by test_torch_cumsum once - # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 - if ivy.current_backend_str() == "torch": - test_flags.as_variable = [False] - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - dtype=dtype[0], - ) - - -@st.composite -def dims_and_offset(draw, shape): - shape_actual = draw(shape) - dim1 = draw(helpers.get_axis(shape=shape, force_int=True)) - dim2 = draw(helpers.get_axis(shape=shape, force_int=True)) - offset = draw( - st.integers(min_value=-shape_actual[dim1], max_value=shape_actual[dim1]) - ) - return dim1, dim2, offset - - -@handle_frontend_test( - fn_tree="torch.diagonal", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - ), - dims_and_offset=dims_and_offset( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape") - ), -) -def test_torch_diagonal( - *, - dtype_and_values, - dims_and_offset, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, value = dtype_and_values - dim1, dim2, offset = dims_and_offset - input = value[0] - num_dims = len(np.shape(input)) - assume(dim1 != dim2) - if dim1 < 0: - assume(dim1 + num_dims != dim2) - if dim2 < 0: - assume(dim1 != dim2 + num_dims) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - 
on_device=on_device, - input=input, - offset=offset, - dim1=dim1, - dim2=dim2, - ) - - -@handle_frontend_test( - fn_tree="torch.cartesian_prod", - dtype_and_tensors=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=st.integers(min_value=1, max_value=5), - min_num_dims=1, - max_num_dims=1, - max_dim_size=5, - shared_dtype=True, - ), -) -def test_torch_cartesian_prod( - *, - dtype_and_tensors, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtypes, tensors = dtype_and_tensors - if isinstance(dtypes, list): # If more than one value was generated - args = { - f"x{i}": np.array(tensor, dtype=dtypes[i]) - for i, tensor in enumerate(tensors) - } - else: # If exactly one value was generated - args = {"x0": np.array(tensors, dtype=dtypes)} - test_flags.num_positional_args = len(tensors) - helpers.test_frontend_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **args, - ) - - -@handle_frontend_test( - fn_tree="torch.triu", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, # Torch requires this. - ), - diagonal=st.integers(min_value=-100, max_value=100), -) -def test_torch_triu( - *, - dtype_and_values, - diagonal, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, values = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=values[0], - diagonal=diagonal, - ) - - -# cummax -@handle_frontend_test( - fn_tree="torch.cummax", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=2, - min_value=-100, - max_value=100, - valid_axis=True, - allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - ), - dtype=helpers.get_dtypes("float", none=True, full=False), -) -def test_torch_cummax( - *, - dtype_x_axis, - dtype, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_x_axis - if ivy.current_backend_str() == "torch": - test_flags.as_variable = [False] - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - ) - - -# cumprod -@handle_frontend_test( - fn_tree="torch.cumprod", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_value=-100, - max_value=100, - valid_axis=True, - allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - ), - dtype=helpers.get_dtypes("numeric", none=True, full=False), -) -def test_torch_cumprod( - *, - dtype_x_axis, - dtype, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_x_axis - # ToDo: set as_variable_flags as the parameter generated by test_torch_cumsum once - # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 - if ivy.current_backend_str() == "torch": - test_flags.as_variable = [False] - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - dtype=dtype[0], - ) - - -# trace -@handle_frontend_test( - 
fn_tree="torch.trace", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=st.shared(helpers.get_shape(min_num_dims=2, max_num_dims=2), key="shape"), - ), -) -def test_torch_trace( - *, - dtype_and_values, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - ) - - -# tril_indices -@handle_frontend_test( - fn_tree="torch.tril_indices", - row=st.integers(min_value=1, max_value=10), - col=st.integers(min_value=1, max_value=10), - offset=st.integers(min_value=-8, max_value=8), - dtype=helpers.get_dtypes("integer", full=False), -) -def test_torch_tril_indices( - *, - row, - col, - offset, - dtype, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - helpers.test_frontend_function( - input_dtypes=[ivy.int32], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - row=row, - col=col, - offset=offset, - dtype=dtype[0], - ) - - -@handle_frontend_test( - fn_tree="torch.triu_indices", - row=st.integers(min_value=1, max_value=100), - col=st.integers(min_value=1, max_value=100), - offset=st.integers(min_value=-10, max_value=10), -) -def test_torch_triu_indices( - *, - row, - col, - offset, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - helpers.test_frontend_function( - input_dtypes=["int32"], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - row=row, - col=col, - offset=offset, - ) - - -# tril -@handle_frontend_test( - fn_tree="torch.tril", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, # Torch requires this. 
- ), - diagonal=st.integers(min_value=-100, max_value=100), -) -def test_torch_tril( - *, - dtype_and_values, - diagonal, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, values = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=values[0], - diagonal=diagonal, - ) - - -@handle_frontend_test( - fn_tree="torch.flatten", - dtype_input_axes=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - min_num_dims=1, - min_axes_size=2, - max_axes_size=2, - ), -) -def test_torch_flatten( - *, - dtype_input_axes, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, input, axes = dtype_input_axes - if isinstance(axes, int): - start_dim = axes - end_dim = -1 - else: - start_dim = axes[0] - end_dim = axes[1] - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input[0], - start_dim=start_dim, - end_dim=end_dim, - ) - - -# renorm -@handle_frontend_test( - fn_tree="torch.renorm", - dtype_and_values=helpers.dtype_and_values( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - available_dtypes=helpers.get_dtypes("numeric"), - max_value=1e4, - min_value=-1e4, - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - force_int=True, - ), - p=st.floats( - min_value=0.5, - exclude_min=True, - max_value=5, - ), # Non-positive norms aren't supported in backends. - # Small positive norms cause issues due to finite-precision. - maxnorm=st.floats(min_value=0), # Norms are positive semi-definite -) -def test_torch_renorm( - *, - dtype_and_values, - p, - dim, - maxnorm, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, values = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-02, - input=values[0], - p=p, - dim=dim, - maxnorm=maxnorm, - ) - - -# logcumsumexp -@handle_frontend_test( - fn_tree="torch.logcumsumexp", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=st.shared(helpers.get_shape(), key="shape"), - max_value=100, - min_value=-100, - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), force_int=True - ), -) -def test_torch_logcumsumexp( - *, - dtype_and_input, - dim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, input = dtype_and_input - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - input=input[0], - dim=dim, - ) - - -# repeat_interleave -@handle_frontend_test( - fn_tree="torch.repeat_interleave", - dtype_values_repeats_axis_output_size=_get_repeat_interleaves_args( - available_dtypes=helpers.get_dtypes("valid"), - valid_axis=True, - max_num_dims=4, - max_dim_size=4, - ), -) -def test_torch_repeat_interleave( - *, - dtype_values_repeats_axis_output_size, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, values, repeats, axis, output_size = dtype_values_repeats_axis_output_size - - helpers.test_frontend_function( - input_dtypes=dtype[0], - 
backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=values[0], - repeats=repeats[0], - dim=axis, - output_size=output_size, - ) - - -# ravel -@handle_frontend_test( - fn_tree="torch.ravel", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - ), -) -def test_torch_ravel( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=np.asarray(x[0], dtype=input_dtype[0]), - ) - - -# rot90 -@handle_frontend_test( - fn_tree="torch.rot90", - dtype_and_x=helpers.dtype_and_values( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - available_dtypes=helpers.get_dtypes("numeric"), - ), - dims=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - min_size=2, - max_size=2, - unique=True, - allow_neg=False, - force_tuple=True, - ), - k=st.integers(min_value=-10, max_value=10), -) -def test_torch_rot90( - *, - dtype_and_x, - dims, - k, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - k=k, - dims=dims, - ) - - -# vander -@handle_frontend_test( - fn_tree="torch.vander", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.tuples( - st.integers(min_value=1, max_value=5), - ), - min_num_dims=0, - max_num_dims=5, - ), - N=st.integers(min_value=1, max_value=10) | st.none(), - increasing=st.booleans(), -) -def test_torch_vander( - *, - dtype_and_x, - N, - increasing, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - x=np.asarray(x[0], dtype=input_dtype[0]), - N=N, - increasing=increasing, - ) - - -# lcm -@handle_frontend_test( - fn_tree="torch.lcm", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=False, - ), -) -def test_torch_lcm( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - other=x[1], - ) - - -# einsum -@handle_frontend_test( - fn_tree="torch.einsum", - eq_n_op_n_shp=helpers.einsum_helper(), - dtype=helpers.get_dtypes("numeric", full=False), -) -def test_torch_einsum( - *, - eq_n_op_n_shp, - dtype, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - eq, operands, dtypes = eq_n_op_n_shp - kw = {} - for i, x_ in enumerate(operands): - dtype = dtypes[i][0] - kw["x{}".format(i)] = np.array(x_).astype(dtype) - test_flags.num_positional_args = len(operands) + 1 - helpers.test_frontend_function( - input_dtypes=dtypes, - 
backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - equation=eq, - **kw, - ) - - -# cross -@st.composite -def dtype_value1_value2_axis( - draw, - available_dtypes, - abs_smallest_val=None, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - min_num_dims=1, - max_num_dims=10, - min_dim_size=1, - max_dim_size=10, - specific_dim_size=3, - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", -): - # For cross product, a dim with size 3 is required - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - axis = draw(helpers.ints(min_value=0, max_value=len(shape))) - # make sure there is a dim with specific dim size - shape = list(shape) - shape = shape[:axis] + [specific_dim_size] + shape[axis:] - shape = tuple(shape) - - dtype = draw(st.sampled_from(draw(available_dtypes))) - - values = [] - for i in range(2): - values.append( - draw( - helpers.array_values( - dtype=dtype, - shape=shape, - abs_smallest_val=abs_smallest_val, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, - ) - ) - ) - - value1, value2 = values[0], values[1] - return [dtype], value1, value2, axis - - -@handle_frontend_test( - fn_tree="torch.cross", - dtype_input_other_dim=dtype_value1_value2_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=10, - min_dim_size=3, - max_dim_size=3, - min_value=-1e5, - max_value=1e5, - abs_smallest_val=0.01, - large_abs_safety_factor=2, - safety_factor_scale="log", - ), -) -def test_torch_cross( - dtype_input_other_dim, - frontend, - test_flags, - fn_tree, - backend_fw, -): - dtype, input, other, dim = dtype_input_other_dim - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - rtol=1e-2, - atol=1e-2, - input=input, - other=other, - dim=dim, - ) - - -# gcd -@handle_frontend_test( - fn_tree="torch.gcd", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - num_arrays=2, - shared_dtype=True, - ), -) -def test_torch_gcd( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - other=x[1], - ) - - -@handle_frontend_test( - fn_tree="torch.tensordot", - dtype_values_and_axes=_get_dtype_value1_value2_axis_for_tensordot( - helpers.get_dtypes(kind="float"), - min_value=-10, - max_value=10, - ), -) -def test_torch_tensordot( - dtype_values_and_axes, - test_flags, - frontend, - backend_fw, - fn_tree, -): - dtype, a, b, dims = dtype_values_and_axes - if ivy.current_backend_str() == "paddle": - # Paddle only supports ndim from 0 to 9 - assume(a.shape[0] < 10) - assume(b.shape[0] < 10) - - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - 
frontend=frontend, - fn_tree=fn_tree, - a=a, - b=b, - rtol=1e-2, - atol=1e-2, - dims=dims, - ) - - -# diff -@handle_frontend_test( - fn_tree="torch.diff", - dtype_n_x_n_axis=helpers.dtype_values_axis( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), - n=st.integers(min_value=0, max_value=5), - dtype_prepend=helpers.dtype_and_values( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - max_num_dims=1, - ), - dtype_append=helpers.dtype_and_values( - available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), - min_num_dims=1, - max_num_dims=1, - ), -) -def test_torch_diff( - *, - dtype_n_x_n_axis, - n, - dtype_prepend, - dtype_append, - test_flags, - frontend, - backend_fw, - fn_tree, -): - input_dtype, x, axis = dtype_n_x_n_axis - _, prepend = dtype_prepend - _, append = dtype_append - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - input=x[0], - n=n, - dim=axis, - prepend=prepend[0], - append=append[0], - ) - - -@handle_frontend_test( - fn_tree="torch.broadcast_shapes", - shapes=nph.mutually_broadcastable_shapes( - num_shapes=4, min_dims=1, max_dims=5, min_side=1, max_side=5 - ), - test_with_out=st.just(False), -) -def test_torch_broadcast_shapes( - *, - shapes, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - shape, _ = shapes - shapes = {f"shape{i}": shape[i] for i in range(len(shape))} - test_flags.num_positional_args = len(shapes) - ret, frontend_ret = helpers.test_frontend_function( - input_dtypes=["int64"], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **shapes, - test_values=False, - ) - assert ret == frontend_ret - - -@st.composite -def _get_input_and_broadcast_shape(draw): - # Determine the dimensionality of the tensor, ranging from scalar (0D) to 3D. - num_dims = draw(st.integers(min_value=0, max_value=3)) - - # Generate the dimensions of the tensor. - dims = [draw(st.integers(min_value=1, max_value=5)) for _ in range(num_dims)] - - # Make Tensor. - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), shape=dims - ) - ) - - # Define the broadcast shape dimension - broadcast_num_dims = draw(st.integers(min_value=num_dims, max_value=3)) - - # Construct the broadcast shape. 
- if broadcast_num_dims == num_dims: - shape = tuple(dims) - else: - shape_components = [ - draw(st.integers(min_value=1, max_value=5)) - for _ in range(broadcast_num_dims - num_dims) - ] - shape = tuple(shape_components) + tuple(dims) - - return x_dtype, x, shape - - -@handle_frontend_test( - fn_tree="torch.broadcast_to", - array_and_shape=_get_input_and_broadcast_shape(), - test_with_out=st.just(False), -) -def test_torch_broadcast_to( - *, - array_and_shape, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, array, shape = array_and_shape - test_flags.num_positional_args = 2 - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - array=array[0], - shape=shape, - ) - - -# atleast_2d -@handle_frontend_test( - fn_tree="torch.atleast_2d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=1, max_value=10), - ), - test_with_out=st.just(False), -) -def test_torch_atleast_2d( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, arrays = dtype_and_x - arys = {} - for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): - arys["arrs{}".format(i)] = array - test_flags.num_positional_args = len(arys) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **arys, - ) - - -@handle_frontend_test( - fn_tree="torch.searchsorted", - dtype_x_v=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shared_dtype=True, - min_num_dims=1, - max_num_dims=1, - num_arrays=2, - ), - side=st.sampled_from(["left", "right"]), - out_int32=st.booleans(), - right=st.just(False), - test_with_out=st.just(False), -) -def test_torch_searchsorted( - dtype_x_v, - side, - out_int32, - right, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtypes, xs = dtype_x_v - use_sorter = st.booleans() - if use_sorter: - sorter = np.argsort(xs[0]) - sorter = np.array(sorter, dtype=np.int64) - else: - xs[0] = np.sort(xs[0]) - sorter = None - helpers.test_frontend_function( - input_dtypes=input_dtypes + ["int64"], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - sorted_sequence=xs[0], - values=xs[1], - side=side, - out_int32=out_int32, - right=right, - sorter=sorter, - ) - - -# atleast_3d -@handle_frontend_test( - fn_tree="torch.atleast_3d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=1, max_value=10), - ), - test_with_out=st.just(False), -) -def test_torch_atleast_3d( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, arrays = dtype_and_x - arys = {} - for i, array in enumerate(arrays): - arys["arrs{}".format(i)] = array - test_flags.num_positional_args = len(arys) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **arys, - ) - - -# diag -@handle_frontend_test( - fn_tree="torch.diag", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"), 
- ), - diagonal=st.integers(min_value=-100, max_value=100), -) -def test_torch_diag( - *, - dtype_and_values, - diagonal, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, values = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=values[0], - diagonal=diagonal, - ) - - -# clone -@handle_frontend_test( - fn_tree="torch.clone", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_torch_clone( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - ) - - -@st.composite -def _get_dtype_value1_value2_cov( - draw, - available_dtypes, - min_num_dims, - max_num_dims, - min_dim_size, - max_dim_size, - abs_smallest_val=None, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", -): - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - - dtype = draw(st.sampled_from(draw(available_dtypes))) - - values = [] - for i in range(1): - values.append( - draw( - helpers.array_values( - dtype=dtype, - shape=shape, - abs_smallest_val=abs_smallest_val, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, - ) - ) - ) - - value1 = values[0] - - correction = draw(helpers.ints(min_value=0, max_value=1)) - - fweights = draw( - helpers.array_values( - dtype="int64", - shape=shape[1], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - ) - ) - - aweights = draw( - helpers.array_values( - dtype="float64", - shape=shape[1], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - small_abs_safety_factor=1, - ) - ) - - return [dtype], value1, correction, fweights, aweights - - -# cov -@handle_frontend_test( - fn_tree="torch.cov", - dtype_x1_corr_cov=_get_dtype_value1_value2_cov( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=2, - max_dim_size=5, - min_value=1, - max_value=1e10, - abs_smallest_val=0.01, - large_abs_safety_factor=2, - safety_factor_scale="log", - ), - test_with_out=st.just(False), -) -def test_torch_cov( - dtype_x1_corr_cov, - test_flags, - frontend, - fn_tree, - on_device, - backend_fw, -): - dtype, x1, correction, fweights, aweights = dtype_x1_corr_cov - helpers.test_frontend_function( - input_dtypes=["float64", "int64", "float64"], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, - input=x1, - correction=correction, - fweights=fweights, - aweights=aweights, - ) - - -@handle_frontend_test( - fn_tree="torch.block_diag", - dtype_and_tensors=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=st.integers(min_value=1, max_value=10), - 
min_num_dims=0, - max_num_dims=2, - allow_inf=True, - ), - test_with_out=st.just(False), -) -def test_torch_block_diag( - *, - dtype_and_tensors, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtypes, tensors = dtype_and_tensors - if isinstance(dtypes, list): # If more than one value was generated - args = {f"x{i}": np.array(t, dtype=dtypes[i]) for i, t in enumerate(tensors)} - else: # If exactly one value was generated - args = {"x0": np.array(tensors, dtype=dtypes)} - test_flags.num_positional_args = len(tensors) - helpers.test_frontend_function( - input_dtypes=dtypes, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - backend_to_test=backend_fw, - **args, - ) - - -# view_as_real -@handle_frontend_test( - fn_tree="torch.view_as_real", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - ), -) -def test_torch_view_as_real( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=np.asarray(x[0], dtype=input_dtype[0]), - ) - - -@st.composite -def complex_strategy( - draw, min_num_dims=0, max_num_dims=5, min_dim_size=1, max_dim_size=10 -): - shape = draw( - st.lists( - helpers.ints(min_value=min_dim_size, max_value=max_dim_size), - min_size=min_num_dims, - max_size=max_num_dims, - ) - ) - shape = list(shape) - shape.append(2) - return tuple(shape) - - -# view_as_complex -@handle_frontend_test( - fn_tree="torch.view_as_complex", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(complex_strategy()), - ), -) -def test_torch_view_as_complex( - *, - dtype_and_values, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, value = dtype_and_values - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=value[0], - ) - - -# corrcoef -@handle_frontend_test( - fn_tree="torch.corrcoef", - dtypes_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_num_dims=2, - max_num_dims=2, - min_dim_size=2, - max_dim_size=2, - min_value=1, - ), - test_with_out=st.just(False), -) -def test_torch_corrcoef( - dtypes_and_x, - frontend, - fn_tree, - on_device, - test_flags, - backend_fw, -): - input_dtypes, x = dtypes_and_x - helpers.test_frontend_function( - input_dtypes=["float64"], - frontend=frontend, - fn_tree=fn_tree, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - input=x[0], - ) - - -# kron -@handle_frontend_test( - fn_tree="torch.kron", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=2 - ), -) -def test_torch_kron( - dtype_and_x, - frontend, - fn_tree, - test_flags, - backend_fw, - on_device, -): - input_dtypes, x = dtype_and_x - input, label = x[0], x[1] - helpers.test_frontend_function( - input_dtypes=["float32"], - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input, - other=label, - ) +# global +import math + +import numpy as np +from hypothesis import assume, strategies as st +import hypothesis.extra.numpy as nph + +# local +import ivy +import ivy_tests.test_ivy.helpers as 
helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.test_functional.test_core.test_linalg import ( + _get_dtype_value1_value2_axis_for_tensordot, +) + + +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_dtype_value1_value2_cov( + draw, + available_dtypes, + min_num_dims, + max_num_dims, + min_dim_size, + max_dim_size, + abs_smallest_val=None, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", +): + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + + dtype = draw(st.sampled_from(draw(available_dtypes))) + + values = [] + for i in range(1): + values.append( + draw( + helpers.array_values( + dtype=dtype, + shape=shape, + abs_smallest_val=abs_smallest_val, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + ) + ) + ) + + value1 = values[0] + + correction = draw(helpers.ints(min_value=0, max_value=1)) + + fweights = draw( + helpers.array_values( + dtype="int64", + shape=shape[1], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + ) + ) + + aweights = draw( + helpers.array_values( + dtype="float64", + shape=shape[1], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + small_abs_safety_factor=1, + ) + ) + + return [dtype], value1, correction, fweights, aweights + + +@st.composite +def _get_input_and_broadcast_shape(draw): + # Determine the dimensionality of the tensor, ranging from scalar (0D) to 3D. + num_dims = draw(st.integers(min_value=0, max_value=3)) + + # Generate the dimensions of the tensor. + dims = [draw(st.integers(min_value=1, max_value=5)) for _ in range(num_dims)] + + # Make Tensor. + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), shape=dims + ) + ) + + # Define the broadcast shape dimension + broadcast_num_dims = draw(st.integers(min_value=num_dims, max_value=3)) + + # Construct the broadcast shape. 
+ if broadcast_num_dims == num_dims: + shape = tuple(dims) + else: + shape_components = [ + draw(st.integers(min_value=1, max_value=5)) + for _ in range(broadcast_num_dims - num_dims) + ] + shape = tuple(shape_components) + tuple(dims) + + return x_dtype, x, shape + + +# helpers +@st.composite +def _get_repeat_interleaves_args( + draw, *, available_dtypes, valid_axis, max_num_dims, max_dim_size +): + values_dtype, values, axis, shape = draw( + helpers.dtype_values_axis( + available_dtypes=available_dtypes, + valid_axis=valid_axis, + force_int_axis=True, + shape=draw( + helpers.get_shape( + allow_none=False, + min_num_dims=0, + max_num_dims=max_num_dims, + min_dim_size=1, + max_dim_size=max_dim_size, + ) + ), + ret_shape=True, + ) + ) + + if axis is None: + generate_repeats_as_integer = draw(st.booleans()) + num_repeats = 1 if generate_repeats_as_integer else math.prod(tuple(shape)) + else: + num_repeats = shape[axis] + + repeats_dtype, repeats = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=0, + max_value=10, + shape=[num_repeats], + ) + ) + + # Output size is an optional parameter accepted by Torch for optimisation + use_output_size = draw(st.booleans()) + output_size = np.sum(repeats) if use_output_size else None + + return [values_dtype, repeats_dtype], values, repeats, axis, output_size + + +@st.composite +def complex_strategy( + draw, min_num_dims=0, max_num_dims=5, min_dim_size=1, max_dim_size=10 +): + shape = draw( + st.lists( + helpers.ints(min_value=min_dim_size, max_value=max_dim_size), + min_size=min_num_dims, + max_size=max_num_dims, + ) + ) + shape = list(shape) + shape.append(2) + return tuple(shape) + + +@st.composite +def dims_and_offset(draw, shape): + shape_actual = draw(shape) + dim1 = draw(helpers.get_axis(shape=shape, force_int=True)) + dim2 = draw(helpers.get_axis(shape=shape, force_int=True)) + offset = draw( + st.integers(min_value=-shape_actual[dim1], max_value=shape_actual[dim1]) + ) + return dim1, dim2, offset + + +# cross +@st.composite +def dtype_value1_value2_axis( + draw, + available_dtypes, + abs_smallest_val=None, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + min_num_dims=1, + max_num_dims=10, + min_dim_size=1, + max_dim_size=10, + specific_dim_size=3, + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", +): + # For cross product, a dim with size 3 is required + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + axis = draw(helpers.ints(min_value=0, max_value=len(shape))) + # make sure there is a dim with specific dim size + shape = list(shape) + shape = shape[:axis] + [specific_dim_size] + shape[axis:] + shape = tuple(shape) + + dtype = draw(st.sampled_from(draw(available_dtypes))) + + values = [] + for i in range(2): + values.append( + draw( + helpers.array_values( + dtype=dtype, + shape=shape, + abs_smallest_val=abs_smallest_val, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + ) + ) + ) + + value1, value2 = values[0], values[1] + return [dtype], value1, value2, axis + + +# --- Main --- # +# ------------ # + + +# atleast_1d +@handle_frontend_test( + 
fn_tree="torch.atleast_1d", + dtype_and_tensors=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=st.integers(min_value=1, max_value=5), + ), + test_with_out=st.just(False), +) +def test_torch_atleast_1d( + *, + dtype_and_tensors, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, tensors = dtype_and_tensors + if isinstance(dtypes, list): # If more than one value was generated + args = { + f"x{i}": np.array(tensor, dtype=dtypes[i]) + for i, tensor in enumerate(tensors) + } + else: # If exactly one value was generated + args = {"x0": np.array(tensors, dtype=dtypes)} + test_flags.num_positional_args = len(tensors) + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **args, + ) + + +# atleast_2d +@handle_frontend_test( + fn_tree="torch.atleast_2d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=helpers.ints(min_value=1, max_value=10), + ), + test_with_out=st.just(False), +) +def test_torch_atleast_2d( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, arrays = dtype_and_x + arys = {} + for i, (array, idtype) in enumerate(zip(arrays, input_dtype)): + arys["arrs{}".format(i)] = array + test_flags.num_positional_args = len(arys) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **arys, + ) + + +# atleast_3d +@handle_frontend_test( + fn_tree="torch.atleast_3d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=helpers.ints(min_value=1, max_value=10), + ), + test_with_out=st.just(False), +) +def test_torch_atleast_3d( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, arrays = dtype_and_x + arys = {} + for i, array in enumerate(arrays): + arys["arrs{}".format(i)] = array + test_flags.num_positional_args = len(arys) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **arys, + ) + + +@handle_frontend_test( + fn_tree="torch.block_diag", + dtype_and_tensors=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=st.integers(min_value=1, max_value=10), + min_num_dims=0, + max_num_dims=2, + allow_inf=True, + ), + test_with_out=st.just(False), +) +def test_torch_block_diag( + *, + dtype_and_tensors, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, tensors = dtype_and_tensors + if isinstance(dtypes, list): # If more than one value was generated + args = {f"x{i}": np.array(t, dtype=dtypes[i]) for i, t in enumerate(tensors)} + else: # If exactly one value was generated + args = {"x0": np.array(tensors, dtype=dtypes)} + test_flags.num_positional_args = len(tensors) + helpers.test_frontend_function( + input_dtypes=dtypes, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + backend_to_test=backend_fw, + **args, + ) + + +@handle_frontend_test( + fn_tree="torch.broadcast_shapes", + shapes=nph.mutually_broadcastable_shapes( + num_shapes=4, min_dims=1, max_dims=5, min_side=1, max_side=5 + ), + test_with_out=st.just(False), +) +def test_torch_broadcast_shapes( + *, + 
shapes, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + shape, _ = shapes + shapes = {f"shape{i}": shape[i] for i in range(len(shape))} + test_flags.num_positional_args = len(shapes) + ret, frontend_ret = helpers.test_frontend_function( + input_dtypes=["int64"], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **shapes, + test_values=False, + ) + assert ret == frontend_ret + + +@handle_frontend_test( + fn_tree="torch.broadcast_to", + array_and_shape=_get_input_and_broadcast_shape(), + test_with_out=st.just(False), +) +def test_torch_broadcast_to( + *, + array_and_shape, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, array, shape = array_and_shape + test_flags.num_positional_args = 2 + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + array=array[0], + shape=shape, + ) + + +@handle_frontend_test( + fn_tree="torch.cartesian_prod", + dtype_and_tensors=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=st.integers(min_value=1, max_value=5), + min_num_dims=1, + max_num_dims=1, + max_dim_size=5, + shared_dtype=True, + ), +) +def test_torch_cartesian_prod( + *, + dtype_and_tensors, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, tensors = dtype_and_tensors + if isinstance(dtypes, list): # If more than one value was generated + args = { + f"x{i}": np.array(tensor, dtype=dtypes[i]) + for i, tensor in enumerate(tensors) + } + else: # If exactly one value was generated + args = {"x0": np.array(tensors, dtype=dtypes)} + test_flags.num_positional_args = len(tensors) + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **args, + ) + + +# clone +@handle_frontend_test( + fn_tree="torch.clone", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_torch_clone( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ) + + +# corrcoef +@handle_frontend_test( + fn_tree="torch.corrcoef", + dtypes_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + max_dim_size=2, + min_value=1, + ), + test_with_out=st.just(False), +) +def test_torch_corrcoef( + dtypes_and_x, + frontend, + fn_tree, + on_device, + test_flags, + backend_fw, +): + input_dtypes, x = dtypes_and_x + helpers.test_frontend_function( + input_dtypes=["float64"], + frontend=frontend, + fn_tree=fn_tree, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + input=x[0], + ) + + +# cov +@handle_frontend_test( + fn_tree="torch.cov", + dtype_x1_corr_cov=_get_dtype_value1_value2_cov( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + max_dim_size=5, + min_value=1, + max_value=1e10, + abs_smallest_val=0.01, + large_abs_safety_factor=2, + safety_factor_scale="log", + ), + test_with_out=st.just(False), +) +def test_torch_cov( + 
dtype_x1_corr_cov, + test_flags, + frontend, + fn_tree, + on_device, + backend_fw, +): + dtype, x1, correction, fweights, aweights = dtype_x1_corr_cov + helpers.test_frontend_function( + input_dtypes=["float64", "int64", "float64"], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + input=x1, + correction=correction, + fweights=fweights, + aweights=aweights, + ) + + +@handle_frontend_test( + fn_tree="torch.cross", + dtype_input_other_dim=dtype_value1_value2_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=10, + min_dim_size=3, + max_dim_size=3, + min_value=-1e5, + max_value=1e5, + abs_smallest_val=0.01, + large_abs_safety_factor=2, + safety_factor_scale="log", + ), +) +def test_torch_cross( + dtype_input_other_dim, + frontend, + test_flags, + fn_tree, + backend_fw, +): + dtype, input, other, dim = dtype_input_other_dim + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + rtol=1e-2, + atol=1e-2, + input=input, + other=other, + dim=dim, + ) + + +# cummax +@handle_frontend_test( + fn_tree="torch.cummax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=2, + min_value=-100, + max_value=100, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, + ), + dtype=helpers.get_dtypes("float", none=True, full=False), +) +def test_torch_cummax( + *, + dtype_x_axis, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + if ivy.current_backend_str() == "torch": + test_flags.as_variable = [False] + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + ) + + +# cumprod +@handle_frontend_test( + fn_tree="torch.cumprod", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_value=-100, + max_value=100, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, + ), + dtype=helpers.get_dtypes("numeric", none=True, full=False), +) +def test_torch_cumprod( + *, + dtype_x_axis, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + # ToDo: set as_variable_flags as the parameter generated by test_torch_cumsum once + # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 + if ivy.current_backend_str() == "torch": + test_flags.as_variable = [False] + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + dtype=dtype[0], + ) + + +# cumsum +@handle_frontend_test( + fn_tree="torch.cumsum", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=5, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, + ), + dtype=helpers.get_dtypes("numeric", none=True, full=False), +) +def test_torch_cumsum( + *, + dtype_x_axis, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, 
x, axis = dtype_x_axis + # ToDo: set as_variable_flags as the parameter generated by test_torch_cumsum once + # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 + if ivy.current_backend_str() == "torch": + test_flags.as_variable = [False] + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + dtype=dtype[0], + ) + + +# diag +@handle_frontend_test( + fn_tree="torch.diag", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"), + ), + diagonal=st.integers(min_value=-100, max_value=100), +) +def test_torch_diag( + *, + dtype_and_values, + diagonal, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, values = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=values[0], + diagonal=diagonal, + ) + + +@handle_frontend_test( + fn_tree="torch.diagonal", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + ), + dims_and_offset=dims_and_offset( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape") + ), +) +def test_torch_diagonal( + *, + dtype_and_values, + dims_and_offset, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + dim1, dim2, offset = dims_and_offset + input = value[0] + num_dims = len(np.shape(input)) + assume(dim1 != dim2) + if dim1 < 0: + assume(dim1 + num_dims != dim2) + if dim2 < 0: + assume(dim1 != dim2 + num_dims) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input, + offset=offset, + dim1=dim1, + dim2=dim2, + ) + + +# diff +@handle_frontend_test( + fn_tree="torch.diff", + dtype_n_x_n_axis=helpers.dtype_values_axis( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), + n=st.integers(min_value=0, max_value=5), + dtype_prepend=helpers.dtype_and_values( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + max_num_dims=1, + ), + dtype_append=helpers.dtype_and_values( + available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"), + min_num_dims=1, + max_num_dims=1, + ), +) +def test_torch_diff( + *, + dtype_n_x_n_axis, + n, + dtype_prepend, + dtype_append, + test_flags, + frontend, + backend_fw, + fn_tree, +): + input_dtype, x, axis = dtype_n_x_n_axis + _, prepend = dtype_prepend + _, append = dtype_append + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + input=x[0], + n=n, + dim=axis, + prepend=prepend[0], + append=append[0], + ) + + +# einsum +@handle_frontend_test( + fn_tree="torch.einsum", + eq_n_op_n_shp=helpers.einsum_helper(), + dtype=helpers.get_dtypes("numeric", full=False), +) +def test_torch_einsum( + *, + eq_n_op_n_shp, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + eq, operands, dtypes = eq_n_op_n_shp + kw = {} + for i, 
x_ in enumerate(operands): + dtype = dtypes[i][0] + kw["x{}".format(i)] = np.array(x_).astype(dtype) + test_flags.num_positional_args = len(operands) + 1 + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + equation=eq, + **kw, + ) + + +@handle_frontend_test( + fn_tree="torch.flatten", + dtype_input_axes=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + min_num_dims=1, + min_axes_size=2, + max_axes_size=2, + ), +) +def test_torch_flatten( + *, + dtype_input_axes, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, input, axes = dtype_input_axes + if isinstance(axes, int): + start_dim = axes + end_dim = -1 + else: + start_dim = axes[0] + end_dim = axes[1] + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + start_dim=start_dim, + end_dim=end_dim, + ) + + +# flip +@handle_frontend_test( + fn_tree="torch.flip", + dtype_and_values=helpers.dtype_and_values( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + available_dtypes=helpers.get_dtypes("float"), + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, + ), +) +def test_torch_flip( + *, + dtype_and_values, + axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + dims=axis, + ) + + +# fliplr +@handle_frontend_test( + fn_tree="torch.fliplr", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=helpers.get_shape(min_num_dims=2), + ), +) +def test_torch_fliplr( + *, + dtype_and_values, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + +# flipud +@handle_frontend_test( + fn_tree="torch.flipud", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=helpers.get_shape(min_num_dims=1), + ), +) +def test_torch_flipud( + *, + dtype_and_values, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + +# gcd +@handle_frontend_test( + fn_tree="torch.gcd", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + num_arrays=2, + shared_dtype=True, + ), +) +def test_torch_gcd( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + 
on_device=on_device, + input=x[0], + other=x[1], + ) + + +# kron +@handle_frontend_test( + fn_tree="torch.kron", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), num_arrays=2 + ), +) +def test_torch_kron( + dtype_and_x, + frontend, + fn_tree, + test_flags, + backend_fw, + on_device, +): + input_dtypes, x = dtype_and_x + input, label = x[0], x[1] + helpers.test_frontend_function( + input_dtypes=["float32"], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input, + other=label, + ) + + +# lcm +@handle_frontend_test( + fn_tree="torch.lcm", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=True, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=False, + ), +) +def test_torch_lcm( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + other=x[1], + ) + + +# logcumsumexp +@handle_frontend_test( + fn_tree="torch.logcumsumexp", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.shared(helpers.get_shape(), key="shape"), + max_value=100, + min_value=-100, + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), force_int=True + ), +) +def test_torch_logcumsumexp( + *, + dtype_and_input, + dim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, input = dtype_and_input + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-2, + atol=1e-2, + input=input[0], + dim=dim, + ) + + +# meshgrid +@handle_frontend_test( + fn_tree="torch.meshgrid", + dtypes_and_tensors=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=st.integers(min_value=2, max_value=5), + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=5, + shared_dtype=True, + ), + indexing=st.sampled_from(["ij", "xy"]), +) +def test_torch_meshgrid( + *, + dtypes_and_tensors, + indexing, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, tensors = dtypes_and_tensors + kwargs = { + f"tensor{i}": np.array(tensor, dtype=dtypes[i]) + for i, tensor in enumerate(tensors) + } + kwargs["indexing"] = indexing + test_flags.num_positional_args = len(tensors) + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **kwargs, + ) + + +# ravel +@handle_frontend_test( + fn_tree="torch.ravel", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + ), +) +def test_torch_ravel( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=np.asarray(x[0], dtype=input_dtype[0]), + ) + + +# renorm +@handle_frontend_test( + fn_tree="torch.renorm", + 
dtype_and_values=helpers.dtype_and_values( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + available_dtypes=helpers.get_dtypes("numeric"), + max_value=1e4, + min_value=-1e4, + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + force_int=True, + ), + p=st.floats( + min_value=0.5, + exclude_min=True, + max_value=5, + ), # Non-positive norms aren't supported in backends. + # Small positive norms cause issues due to finite-precision. + maxnorm=st.floats(min_value=0), # Norms are positive semi-definite +) +def test_torch_renorm( + *, + dtype_and_values, + p, + dim, + maxnorm, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, values = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + atol=1e-02, + input=values[0], + p=p, + dim=dim, + maxnorm=maxnorm, + ) + + +# repeat_interleave +@handle_frontend_test( + fn_tree="torch.repeat_interleave", + dtype_values_repeats_axis_output_size=_get_repeat_interleaves_args( + available_dtypes=helpers.get_dtypes("valid"), + valid_axis=True, + max_num_dims=4, + max_dim_size=4, + ), +) +def test_torch_repeat_interleave( + *, + dtype_values_repeats_axis_output_size, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, values, repeats, axis, output_size = dtype_values_repeats_axis_output_size + + helpers.test_frontend_function( + input_dtypes=dtype[0], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=values[0], + repeats=repeats[0], + dim=axis, + output_size=output_size, + ) + + +# roll +@handle_frontend_test( + fn_tree="torch.roll", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + shift=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, + ), +) +def test_torch_roll( + *, + dtype_and_values, + shift, + axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, value = dtype_and_values + if isinstance(shift, int) and isinstance(axis, tuple): + axis = axis[0] + if isinstance(shift, tuple) and isinstance(axis, tuple): + if len(shift) != len(axis): + mn = min(len(shift), len(axis)) + shift = shift[:mn] + axis = axis[:mn] + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + shifts=shift, + dims=axis, + ) + + +# rot90 +@handle_frontend_test( + fn_tree="torch.rot90", + dtype_and_x=helpers.dtype_and_values( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + available_dtypes=helpers.get_dtypes("numeric"), + ), + dims=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + min_size=2, + max_size=2, + unique=True, + allow_neg=False, + force_tuple=True, + ), + k=st.integers(min_value=-10, max_value=10), +) +def test_torch_rot90( + *, + dtype_and_x, + dims, + k, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + 
frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + k=k, + dims=dims, + ) + + +@handle_frontend_test( + fn_tree="torch.searchsorted", + dtype_x_v=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shared_dtype=True, + min_num_dims=1, + max_num_dims=1, + num_arrays=2, + ), + side=st.sampled_from(["left", "right"]), + out_int32=st.booleans(), + right=st.just(False), + test_with_out=st.just(False), +) +def test_torch_searchsorted( + dtype_x_v, + side, + out_int32, + right, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtypes, xs = dtype_x_v + use_sorter = st.booleans() + if use_sorter: + sorter = np.argsort(xs[0]) + sorter = np.array(sorter, dtype=np.int64) + else: + xs[0] = np.sort(xs[0]) + sorter = None + helpers.test_frontend_function( + input_dtypes=input_dtypes + ["int64"], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + sorted_sequence=xs[0], + values=xs[1], + side=side, + out_int32=out_int32, + right=right, + sorter=sorter, + ) + + +@handle_frontend_test( + fn_tree="torch.tensordot", + dtype_values_and_axes=_get_dtype_value1_value2_axis_for_tensordot( + helpers.get_dtypes(kind="float"), + min_value=-10, + max_value=10, + ), +) +def test_torch_tensordot( + dtype_values_and_axes, + test_flags, + frontend, + backend_fw, + fn_tree, +): + dtype, a, b, dims = dtype_values_and_axes + if ivy.current_backend_str() == "paddle": + # Paddle only supports ndim from 0 to 9 + assume(a.shape[0] < 10) + assume(b.shape[0] < 10) + + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + a=a, + b=b, + rtol=1e-2, + atol=1e-2, + dims=dims, + ) + + +# trace +@handle_frontend_test( + fn_tree="torch.trace", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.shared(helpers.get_shape(min_num_dims=2, max_num_dims=2), key="shape"), + ), +) +def test_torch_trace( + *, + dtype_and_values, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + +# tril +@handle_frontend_test( + fn_tree="torch.tril", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, # Torch requires this. 
+ ), + diagonal=st.integers(min_value=-100, max_value=100), +) +def test_torch_tril( + *, + dtype_and_values, + diagonal, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, values = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=values[0], + diagonal=diagonal, + ) + + +# tril_indices +@handle_frontend_test( + fn_tree="torch.tril_indices", + row=st.integers(min_value=1, max_value=10), + col=st.integers(min_value=1, max_value=10), + offset=st.integers(min_value=-8, max_value=8), + dtype=helpers.get_dtypes("integer", full=False), +) +def test_torch_tril_indices( + *, + row, + col, + offset, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + helpers.test_frontend_function( + input_dtypes=[ivy.int32], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + row=row, + col=col, + offset=offset, + dtype=dtype[0], + ) + + +@handle_frontend_test( + fn_tree="torch.triu", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, # Torch requires this. + ), + diagonal=st.integers(min_value=-100, max_value=100), +) +def test_torch_triu( + *, + dtype_and_values, + diagonal, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, values = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=values[0], + diagonal=diagonal, + ) + + +@handle_frontend_test( + fn_tree="torch.triu_indices", + row=st.integers(min_value=1, max_value=100), + col=st.integers(min_value=1, max_value=100), + offset=st.integers(min_value=-10, max_value=10), +) +def test_torch_triu_indices( + *, + row, + col, + offset, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + helpers.test_frontend_function( + input_dtypes=["int32"], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + row=row, + col=col, + offset=offset, + ) + + +# vander +@handle_frontend_test( + fn_tree="torch.vander", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.tuples( + st.integers(min_value=1, max_value=5), + ), + min_num_dims=0, + max_num_dims=5, + ), + N=st.integers(min_value=1, max_value=10) | st.none(), + increasing=st.booleans(), +) +def test_torch_vander( + *, + dtype_and_x, + N, + increasing, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=np.asarray(x[0], dtype=input_dtype[0]), + N=N, + increasing=increasing, + ) + + +# view_as_complex +@handle_frontend_test( + fn_tree="torch.view_as_complex", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(complex_strategy()), + ), +) +def test_torch_view_as_complex( + *, + dtype_and_values, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, value = dtype_and_values + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + 
test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + ) + + +# view_as_real +@handle_frontend_test( + fn_tree="torch.view_as_real", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + ), +) +def test_torch_view_as_real( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=np.asarray(x[0], dtype=input_dtype[0]), + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py index be32c522fb53f..d2e95a4269d04 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_convolution_functions.py @@ -11,6 +11,114 @@ ) +# --- Helpers --- # +# --------------- # + + +@st.composite +def _fold_helper(draw, dim=2): + stride, padding, dilation, kernel_size = draw(_fold_unfold_helper(dim)) + strides = [stride] * dim if isinstance(stride, int) else stride + paddings = [padding] * dim if isinstance(padding, int) else padding + dilations = [dilation] * dim if isinstance(dilation, int) else dilation + kernel_sizes = [kernel_size] * dim if isinstance(kernel_size, int) else kernel_size + output_shape = () + for i in range(dim): + min_dim = kernel_sizes[i] + (kernel_sizes[i] - 1) * (dilations[i] - 1) + output_shape = output_shape + (draw(st.integers(min_dim, 15)),) + batch_size = draw(st.integers(1, 5)) + n_channels = draw(st.integers(1, 3)) + x_shape = [ + (output_shape[i] + 2 * paddings[i] - dilations[i] * (kernel_sizes[i] - 1) - 1) + // strides[i] + + 1 + for i in range(2) + ] + x_shape = (batch_size, n_channels * math.prod(kernel_sizes), math.prod(x_shape)) + dtype, [vals] = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=x_shape, + min_value=0.0, + max_value=1.0, + ) + ) + if vals.shape[0] == 1: # un-batched inputs are also supported + vals = draw(st.one_of(st.just(vals), st.just(ivy.squeeze(vals, axis=0)))) + return dtype, vals, kernel_size, output_shape, dilation, stride, padding + + +@st.composite +def _fold_unfold_helper(draw, dim): + stride = draw( + st.one_of( + st.lists(st.integers(min_value=1, max_value=3), min_size=dim, max_size=dim), + st.integers(min_value=1, max_value=3), + ) + ) + padding = draw( + st.one_of( + st.integers(min_value=1, max_value=3), + st.lists(st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim), + ) + ) + dilation = draw( + st.one_of( + st.lists(st.integers(min_value=1, max_value=3), min_size=dim, max_size=dim), + st.integers(min_value=1, max_value=3), + ) + ) + kernel_size = draw( + st.one_of( + st.integers(min_value=1, max_value=5), + helpers.get_shape( + min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5 + ), + ) + ) + return stride, padding, dilation, kernel_size + + +def _output_shape( + dims, dilation, stride, padding, output_padding, input_shape, weight_shape +): + dilation, stride, padding, output_padding = map( + lambda x: [x] * dims if isinstance(x, int) else x, + [dilation, stride, padding, output_padding], + ) + return [ + (input_shape[2 + i] - 1) * stride[i] + - 2 * padding[i] + + dilation[i] * (weight_shape[2 + i] - 1) + 
+ output_padding[i] + + 1 + for i in range(dims) + ] + + +@st.composite +def _unfold_helper(draw, dim=2): + stride, padding, dilation, kernel_size = draw(_fold_unfold_helper(dim)) + dilations = [dilation] * dim if isinstance(dilation, int) else dilation + kernel_sizes = [kernel_size] * dim if isinstance(kernel_size, int) else kernel_size + x_dim = [] + for i in range(dim): + min_x = kernel_sizes[i] + (kernel_sizes[i] - 1) * (dilations[i] - 1) + x_dim.append(draw(st.integers(min_x, 15))) + batch_size = draw(st.integers(1, 5)) + input_channels = draw(st.integers(1, 3)) + x_shape = (batch_size, input_channels) + tuple(x_dim) + dtype, [vals] = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=x_shape, + min_value=0.0, + max_value=1.0, + ) + ) + return dtype, vals, kernel_size, dilation, stride, padding + + @st.composite def _x_and_filters(draw, dim: int = 2, transpose: bool = False): if not isinstance(dim, int): @@ -137,6 +245,10 @@ def _x_and_filters(draw, dim: int = 2, transpose: bool = False): return dtype, vals, filters, bias, dilations, strides, padding, fc +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="torch.nn.functional.conv1d", dtype_vals=_x_and_filters(dim=1), @@ -232,23 +344,6 @@ def test_torch_conv3d( ) -def _output_shape( - dims, dilation, stride, padding, output_padding, input_shape, weight_shape -): - dilation, stride, padding, output_padding = map( - lambda x: [x] * dims if isinstance(x, int) else x, - [dilation, stride, padding, output_padding], - ) - return [ - (input_shape[2 + i] - 1) * stride[i] - - 2 * padding[i] - + dilation[i] * (weight_shape[2 + i] - 1) - + output_padding[i] - + 1 - for i in range(dims) - ] - - @handle_frontend_test( fn_tree="torch.nn.functional.conv_transpose1d", dtype_vals=_x_and_filters(dim=1, transpose=True), @@ -372,65 +467,11 @@ def test_torch_conv_tranpose3d( ) -@st.composite -def _fold_unfold_helper(draw, dim): - stride = draw( - st.one_of( - st.lists(st.integers(min_value=1, max_value=3), min_size=dim, max_size=dim), - st.integers(min_value=1, max_value=3), - ) - ) - padding = draw( - st.one_of( - st.integers(min_value=1, max_value=3), - st.lists(st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim), - ) - ) - dilation = draw( - st.one_of( - st.lists(st.integers(min_value=1, max_value=3), min_size=dim, max_size=dim), - st.integers(min_value=1, max_value=3), - ) - ) - kernel_size = draw( - st.one_of( - st.integers(min_value=1, max_value=5), - helpers.get_shape( - min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5 - ), - ) - ) - return stride, padding, dilation, kernel_size - - -@st.composite -def _unfold_helper(draw, dim=2): - stride, padding, dilation, kernel_size = draw(_fold_unfold_helper(dim)) - dilations = [dilation] * dim if isinstance(dilation, int) else dilation - kernel_sizes = [kernel_size] * dim if isinstance(kernel_size, int) else kernel_size - x_dim = [] - for i in range(dim): - min_x = kernel_sizes[i] + (kernel_sizes[i] - 1) * (dilations[i] - 1) - x_dim.append(draw(st.integers(min_x, 15))) - batch_size = draw(st.integers(1, 5)) - input_channels = draw(st.integers(1, 3)) - x_shape = (batch_size, input_channels) + tuple(x_dim) - dtype, [vals] = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=x_shape, - min_value=0.0, - max_value=1.0, - ) - ) - return dtype, vals, kernel_size, dilation, stride, padding - - @handle_frontend_test( - fn_tree="torch.nn.functional.unfold", - dtype_vals=_unfold_helper(), + 
fn_tree="torch.nn.functional.fold", + dtype_vals=_fold_helper(), ) -def test_torch_unfold( +def test_torch_fold( *, dtype_vals, on_device, @@ -439,7 +480,7 @@ def test_torch_unfold( test_flags, backend_fw, ): - dtype, vals, kernel_shape, dilations, strides, padding = dtype_vals + dtype, vals, kernel_shape, output_shape, dilations, strides, padding = dtype_vals helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -448,6 +489,7 @@ def test_torch_unfold( fn_tree=fn_tree, on_device=on_device, input=vals, + output_size=output_shape, kernel_size=kernel_shape, dilation=dilations, padding=padding, @@ -455,44 +497,11 @@ def test_torch_unfold( ) -@st.composite -def _fold_helper(draw, dim=2): - stride, padding, dilation, kernel_size = draw(_fold_unfold_helper(dim)) - strides = [stride] * dim if isinstance(stride, int) else stride - paddings = [padding] * dim if isinstance(padding, int) else padding - dilations = [dilation] * dim if isinstance(dilation, int) else dilation - kernel_sizes = [kernel_size] * dim if isinstance(kernel_size, int) else kernel_size - output_shape = () - for i in range(dim): - min_dim = kernel_sizes[i] + (kernel_sizes[i] - 1) * (dilations[i] - 1) - output_shape = output_shape + (draw(st.integers(min_dim, 15)),) - batch_size = draw(st.integers(1, 5)) - n_channels = draw(st.integers(1, 3)) - x_shape = [ - (output_shape[i] + 2 * paddings[i] - dilations[i] * (kernel_sizes[i] - 1) - 1) - // strides[i] - + 1 - for i in range(2) - ] - x_shape = (batch_size, n_channels * math.prod(kernel_sizes), math.prod(x_shape)) - dtype, [vals] = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=x_shape, - min_value=0.0, - max_value=1.0, - ) - ) - if vals.shape[0] == 1: # un-batched inputs are also supported - vals = draw(st.one_of(st.just(vals), st.just(ivy.squeeze(vals, axis=0)))) - return dtype, vals, kernel_size, output_shape, dilation, stride, padding - - @handle_frontend_test( - fn_tree="torch.nn.functional.fold", - dtype_vals=_fold_helper(), + fn_tree="torch.nn.functional.unfold", + dtype_vals=_unfold_helper(), ) -def test_torch_fold( +def test_torch_unfold( *, dtype_vals, on_device, @@ -501,7 +510,7 @@ def test_torch_fold( test_flags, backend_fw, ): - dtype, vals, kernel_shape, output_shape, dilations, strides, padding = dtype_vals + dtype, vals, kernel_shape, dilations, strides, padding = dtype_vals helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -510,7 +519,6 @@ def test_torch_fold( fn_tree=fn_tree, on_device=on_device, input=vals, - output_size=output_shape, kernel_size=kernel_shape, dilation=dilations, padding=padding, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_dropout_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_dropout_functions.py index 5ed7f7887cbb7..16d6c9dc1a321 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_dropout_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_dropout_functions.py @@ -8,7 +8,7 @@ @handle_frontend_test( - fn_tree="torch.nn.functional.dropout", + fn_tree="torch.nn.functional.alpha_dropout", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -22,7 +22,7 @@ training=st.booleans(), test_inplace=st.just(False), ) -def test_torch_dropout( +def test_torch_alpha_dropout( *, dtype_and_x, prob, @@ -54,7 +54,7 @@ def test_torch_dropout( 
@handle_frontend_test( - fn_tree="torch.nn.functional.alpha_dropout", + fn_tree="torch.nn.functional.dropout", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -68,7 +68,7 @@ def test_torch_dropout( training=st.booleans(), test_inplace=st.just(False), ) -def test_torch_alpha_dropout( +def test_torch_dropout( *, dtype_and_x, prob, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py index 2a2a3dd0f59fe..9bd49f60c5fe1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_linear_functions.py @@ -6,6 +6,10 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + @st.composite def _x_and_linear(draw, dtypes): dtype = draw(dtypes) @@ -34,6 +38,10 @@ def _x_and_linear(draw, dtypes): return dtype, x, weight, bias +# --- Main --- # +# ------------ # + + # linear @handle_frontend_test( fn_tree="torch.nn.functional.linear", diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_loss_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_loss_functions.py index a52e6c705bbb9..f22c21e8c49c3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_loss_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_loss_functions.py @@ -10,72 +10,6 @@ ) -# cross_entropy -@handle_frontend_test( - fn_tree="torch.nn.functional.cross_entropy", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, - ), - dtype_and_target=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0.0, - max_value=1.0, - allow_inf=False, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - ), - dtype_and_weights=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - ), - size_average=st.booleans(), - reduce=st.booleans(), - reduction=st.sampled_from(["mean", "none", "sum"]), - label_smoothing=helpers.floats(min_value=0, max_value=0.49), -) -def test_torch_cross_entropy( - *, - dtype_and_input, - dtype_and_target, - dtype_and_weights, - size_average, - reduce, - reduction, - label_smoothing, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - inputs_dtype, input = dtype_and_input - target_dtype, target = dtype_and_target - weights_dtype, weights = dtype_and_weights - helpers.test_frontend_function( - input_dtypes=inputs_dtype + target_dtype + weights_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input[0], - target=target[0], - weight=weights[0], - size_average=size_average, - reduce=reduce, - reduction=reduction, - label_smoothing=label_smoothing, - ) - - # binary_cross_entropy @handle_frontend_test( fn_tree="torch.nn.functional.binary_cross_entropy", @@ -303,116 +237,172 @@ def test_torch_cosine_embedding_loss( ivy.previous_backend() -# mse_loss +# cross_entropy @handle_frontend_test( - fn_tree="torch.nn.functional.mse_loss", - dtype_and_true=helpers.dtype_and_values( + 
fn_tree="torch.nn.functional.cross_entropy", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, + min_num_dims=2, + max_num_dims=2, + min_dim_size=1, + ), + dtype_and_target=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0.0, max_value=1.0, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="linear", allow_inf=False, - exclude_min=True, - exclude_max=True, min_num_dims=1, max_num_dims=1, min_dim_size=2, ), - dtype_and_pred=helpers.dtype_and_values( + dtype_and_weights=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=0.0, - max_value=1.0, - large_abs_safety_factor=2, - small_abs_safety_factor=2, - safety_factor_scale="linear", allow_inf=False, - exclude_min=True, - exclude_max=True, min_num_dims=1, max_num_dims=1, min_dim_size=2, ), size_average=st.booleans(), reduce=st.booleans(), - reduction=st.sampled_from(["mean"]), - test_with_out=st.just(False), + reduction=st.sampled_from(["mean", "none", "sum"]), + label_smoothing=helpers.floats(min_value=0, max_value=0.49), ) -def test_torch_mse_loss( +def test_torch_cross_entropy( *, - dtype_and_true, - dtype_and_pred, + dtype_and_input, + dtype_and_target, + dtype_and_weights, size_average, reduce, reduction, + label_smoothing, on_device, fn_tree, frontend, test_flags, backend_fw, ): - pred_dtype, pred = dtype_and_pred - true_dtype, true = dtype_and_true + inputs_dtype, input = dtype_and_input + target_dtype, target = dtype_and_target + weights_dtype, weights = dtype_and_weights helpers.test_frontend_function( - input_dtypes=[pred_dtype[0], true_dtype[0]], + input_dtypes=inputs_dtype + target_dtype + weights_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - input=pred[0], - target=true[0], + on_device=on_device, + input=input[0], + target=target[0], + weight=weights[0], size_average=size_average, reduce=reduce, reduction=reduction, + label_smoothing=label_smoothing, + ) + + +# gaussian_nll_loss +@handle_frontend_test( + fn_tree="torch.nn.functional.gaussian_nll_loss", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=3, + min_value=0.01, + max_value=1.0, + allow_inf=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + full=st.booleans(), + eps=st.floats( + min_value=0.0, + max_value=1.0, + allow_nan=False, + allow_infinity=False, + ), + reduction=st.sampled_from(["mean", "sum"]), +) +def test_torch_gaussian_nll_loss( + *, + dtype_and_input, + full, + eps, + reduction, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + inputs_dtype, input = dtype_and_input + helpers.test_frontend_function( + input_dtypes=inputs_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, on_device=on_device, + input=input[0], + target=input[1], + var=input[2], + full=full, + eps=eps, + reduction=reduction, + atol=1e-2, + rtol=1e-2, ) -# smooth_l1_loss @handle_frontend_test( - fn_tree="torch.nn.functional.smooth_l1_loss", + fn_tree="torch.nn.functional.hinge_embedding_loss", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, + min_value=-100, + max_value=100, allow_inf=False, - shared_dtype=True, ), + margin=st.floats(min_value=-10, max_value=10), size_average=st.booleans(), reduce=st.booleans(), reduction=st.sampled_from(["none", "mean", "sum"]), - 
beta=st.sampled_from([1.0, 0.5, 0.1, 0.0]), test_with_out=st.just(False), ) -def test_torch_smooth_l1_loss( +def test_torch_hinge_embedding_loss( *, dtype_and_x, + margin, size_average, reduce, reduction, - beta, - frontend, test_flags, fn_tree, backend_fw, + frontend, on_device, ): input_dtype, x = dtype_and_x - pred_dtype, pred = input_dtype[0], x[0] - true_dtype, true = input_dtype[1], x[1] + input, target = x + helpers.test_frontend_function( - input_dtypes=[pred_dtype, true_dtype], + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=pred, - target=true, + input=input, + target=target, + margin=margin, size_average=size_average, reduce=reduce, reduction=reduction, - beta=beta, + atol=1e-5, + rtol=1e-5, ) @@ -457,6 +447,57 @@ def test_torch_huber_loss( ) +# kl_div +@handle_frontend_test( + fn_tree="torch.nn.functional.kl_div", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, + shared_dtype=True, + min_value=0, + max_value=10, + min_num_dims=0, + max_num_dims=10, + min_dim_size=0, + max_dim_size=10, + num_arrays=2, + ), + size_average=st.booleans(), + reduce=st.booleans(), + reduction=st.sampled_from(["none", "mean", "sum", "batchmean"]), + log_target=st.booleans(), + test_with_out=st.just(False), +) +def test_torch_kl_div( + *, + dtype_and_inputs, + size_average, + reduce, + reduction, + log_target, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + inputs_dtype, inputs = dtype_and_inputs + helpers.test_frontend_function( + input_dtypes=inputs_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=inputs[0], + target=inputs[1], + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) + + # l1_loss @handle_frontend_test( fn_tree="torch.nn.functional.l1_loss", @@ -501,102 +542,97 @@ def test_torch_l1_loss( ) -# nll_loss +# margin ranking loss @handle_frontend_test( - fn_tree="torch.nn.functional.nll_loss", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0.01, - max_value=1.0, - allow_inf=False, - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - ), - dtype_and_target=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=0.0, - max_value=1.0, - allow_inf=False, - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - ), - dtype_and_weights=helpers.dtype_and_values( + fn_tree="torch.nn.functional.margin_ranking_loss", + dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, allow_inf=False, - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, + shared_dtype=True, ), + margin=st.floats(), size_average=st.booleans(), reduce=st.booleans(), - reduction=st.sampled_from(["mean", "none", "sum"]), + reduction=st.sampled_from(["none", "mean", "sum"]), + test_with_out=st.just(False), ) -def test_torch_nll_loss( +def test_torch_margin_ranking_loss( *, - dtype_and_input, - dtype_and_target, - dtype_and_weights, + dtype_and_inputs, + margin, size_average, reduce, reduction, - on_device, - fn_tree, - frontend, test_flags, + fn_tree, backend_fw, + frontend, + on_device, ): - inputs_dtype, input = dtype_and_input - target_dtype, target = dtype_and_target - weights_dtype, weights = dtype_and_weights + input_dtype, x = 
dtype_and_inputs + input1_dtype, input1 = input_dtype[0], x[0] + input2_dtype, input2 = input_dtype[1], x[1] + tar_dtype, tar = input_dtype[2], x[2] helpers.test_frontend_function( - input_dtypes=inputs_dtype + target_dtype + weights_dtype, + input_dtypes=[input1_dtype, input2_dtype, tar_dtype], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - target=target[0], - weight=weights[0], + input1=input1, + input2=input2, + target=tar, + margin=margin, size_average=size_average, reduce=reduce, reduction=reduction, ) -# gaussian_nll_loss +# mse_loss @handle_frontend_test( - fn_tree="torch.nn.functional.gaussian_nll_loss", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=3, - min_value=0.01, + fn_tree="torch.nn.functional.mse_loss", + dtype_and_true=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0.0, max_value=1.0, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="linear", allow_inf=False, + exclude_min=True, + exclude_max=True, min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, + max_num_dims=1, + min_dim_size=2, ), - full=st.booleans(), - eps=st.floats( + dtype_and_pred=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_value=0.0, max_value=1.0, - allow_nan=False, - allow_infinity=False, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="linear", + allow_inf=False, + exclude_min=True, + exclude_max=True, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, ), - reduction=st.sampled_from(["mean", "sum"]), + size_average=st.booleans(), + reduce=st.booleans(), + reduction=st.sampled_from(["mean"]), + test_with_out=st.just(False), ) -def test_torch_gaussian_nll_loss( +def test_torch_mse_loss( *, - dtype_and_input, - full, - eps, + dtype_and_true, + dtype_and_pred, + size_average, + reduce, reduction, on_device, fn_tree, @@ -604,163 +640,171 @@ def test_torch_gaussian_nll_loss( test_flags, backend_fw, ): - inputs_dtype, input = dtype_and_input + pred_dtype, pred = dtype_and_pred + true_dtype, true = dtype_and_true helpers.test_frontend_function( - input_dtypes=inputs_dtype, + input_dtypes=[pred_dtype[0], true_dtype[0]], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - input=input[0], - target=input[1], - var=input[2], - full=full, - eps=eps, + input=pred[0], + target=true[0], + size_average=size_average, + reduce=reduce, reduction=reduction, - atol=1e-2, - rtol=1e-2, + on_device=on_device, ) -# soft margin loss +# multilabel_margin_loss + + @handle_frontend_test( - fn_tree="torch.nn.functional.soft_margin_loss", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="torch.nn.functional.multilabel_margin_loss", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, allow_inf=False, shared_dtype=True, + min_num_dims=1, ), size_average=st.booleans(), reduce=st.booleans(), reduction=st.sampled_from(["none", "mean", "sum"]), test_with_out=st.just(False), ) -def test_torch_soft_margin_loss( +def test_torch_multilabel_margin_loss( *, - dtype_and_x, + dtype_and_inputs, + reduction, size_average, reduce, - reduction, - frontend, test_flags, fn_tree, - backend_fw, + frontend, on_device, ): - input_dtype, x = dtype_and_x - pred_dtype, pred = input_dtype[0], x[0] - tar_dtype, tar 
= input_dtype[1], x[1] + input_dtype, x = dtype_and_inputs helpers.test_frontend_function( - input_dtypes=[pred_dtype, tar_dtype], - backend_to_test=backend_fw, + input_dtypes=input_dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=pred, - target=tar, + input=x[0], + target=x[1], + reduction=reduction, size_average=size_average, reduce=reduce, - reduction=reduction, ) -# kl_div +# multilabel soft margin loss @handle_frontend_test( - fn_tree="torch.nn.functional.kl_div", + fn_tree="torch.nn.functional.multilabel_soft_margin_loss", dtype_and_inputs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, allow_inf=False, shared_dtype=True, - min_value=0, - max_value=10, - min_num_dims=0, - max_num_dims=10, - min_dim_size=0, - max_dim_size=10, - num_arrays=2, + min_num_dims=1, ), size_average=st.booleans(), reduce=st.booleans(), - reduction=st.sampled_from(["none", "mean", "sum", "batchmean"]), - log_target=st.booleans(), + reduction=st.sampled_from(["none", "mean", "sum"]), test_with_out=st.just(False), ) -def test_torch_kl_div( +def test_torch_multilabel_soft_margin_loss( *, dtype_and_inputs, size_average, reduce, reduction, - log_target, - frontend, test_flags, fn_tree, backend_fw, + frontend, on_device, ): - inputs_dtype, inputs = dtype_and_inputs + input_dtype, x = dtype_and_inputs helpers.test_frontend_function( - input_dtypes=inputs_dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=inputs[0], - target=inputs[1], + input=x[0], + target=x[1], size_average=size_average, reduce=reduce, reduction=reduction, - log_target=log_target, ) -# margin ranking loss +# nll_loss @handle_frontend_test( - fn_tree="torch.nn.functional.margin_ranking_loss", - dtype_and_inputs=helpers.dtype_and_values( + fn_tree="torch.nn.functional.nll_loss", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, + min_value=0.01, + max_value=1.0, allow_inf=False, - shared_dtype=True, + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + ), + dtype_and_target=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=0.0, + max_value=1.0, + allow_inf=False, + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + ), + dtype_and_weights=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, ), - margin=st.floats(), size_average=st.booleans(), reduce=st.booleans(), - reduction=st.sampled_from(["none", "mean", "sum"]), - test_with_out=st.just(False), + reduction=st.sampled_from(["mean", "none", "sum"]), ) -def test_torch_margin_ranking_loss( +def test_torch_nll_loss( *, - dtype_and_inputs, - margin, + dtype_and_input, + dtype_and_target, + dtype_and_weights, size_average, reduce, reduction, - test_flags, + on_device, fn_tree, - backend_fw, frontend, - on_device, + test_flags, + backend_fw, ): - input_dtype, x = dtype_and_inputs - input1_dtype, input1 = input_dtype[0], x[0] - input2_dtype, input2 = input_dtype[1], x[1] - tar_dtype, tar = input_dtype[2], x[2] + inputs_dtype, input = dtype_and_input + target_dtype, target = dtype_and_target + weights_dtype, weights = dtype_and_weights helpers.test_frontend_function( - input_dtypes=[input1_dtype, input2_dtype, tar_dtype], + input_dtypes=inputs_dtype + target_dtype + weights_dtype, 
backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input1=input1, - input2=input2, - target=tar, - margin=margin, + input=input[0], + target=target[0], + weight=weights[0], size_average=size_average, reduce=reduce, reduction=reduction, @@ -818,133 +862,125 @@ def test_torch_poisson_nll_loss( ) +# smooth_l1_loss @handle_frontend_test( - fn_tree="torch.nn.functional.hinge_embedding_loss", + fn_tree="torch.nn.functional.smooth_l1_loss", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - min_value=-100, - max_value=100, allow_inf=False, + shared_dtype=True, ), - margin=st.floats(min_value=-10, max_value=10), size_average=st.booleans(), reduce=st.booleans(), reduction=st.sampled_from(["none", "mean", "sum"]), + beta=st.sampled_from([1.0, 0.5, 0.1, 0.0]), test_with_out=st.just(False), ) -def test_torch_hinge_embedding_loss( +def test_torch_smooth_l1_loss( *, dtype_and_x, - margin, size_average, reduce, reduction, + beta, + frontend, test_flags, fn_tree, backend_fw, - frontend, on_device, ): input_dtype, x = dtype_and_x - input, target = x - + pred_dtype, pred = input_dtype[0], x[0] + true_dtype, true = input_dtype[1], x[1] helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[pred_dtype, true_dtype], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input, - target=target, - margin=margin, + input=pred, + target=true, size_average=size_average, reduce=reduce, reduction=reduction, - atol=1e-5, - rtol=1e-5, + beta=beta, ) -# triplet margin loss +# soft margin loss @handle_frontend_test( - fn_tree="torch.nn.functional.triplet_margin_loss", - dtype_and_inputs=helpers.dtype_and_values( + fn_tree="torch.nn.functional.soft_margin_loss", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, + num_arrays=2, allow_inf=False, shared_dtype=True, - min_value=0.0, - max_value=1.0, - min_num_dims=1, - max_num_dims=2, - min_dim_size=1, ), - margin=st.floats(), - p=st.integers(min_value=0, max_value=2), - swap=st.booleans(), size_average=st.booleans(), reduce=st.booleans(), reduction=st.sampled_from(["none", "mean", "sum"]), test_with_out=st.just(False), ) -def test_torch_triplet_margin_loss( +def test_torch_soft_margin_loss( *, - dtype_and_inputs, - margin, - p, - swap, + dtype_and_x, size_average, reduce, reduction, + frontend, test_flags, fn_tree, backend_fw, - frontend, on_device, ): - input_dtype, x = dtype_and_inputs - anchor_dtype, anchor = input_dtype[0], x[0] - positive_dtype, positive = input_dtype[1], x[1] - negative_dtype, negative = input_dtype[2], x[2] + input_dtype, x = dtype_and_x + pred_dtype, pred = input_dtype[0], x[0] + tar_dtype, tar = input_dtype[1], x[1] helpers.test_frontend_function( - input_dtypes=[anchor_dtype, positive_dtype, negative_dtype], + input_dtypes=[pred_dtype, tar_dtype], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - anchor=anchor, - positive=positive, - negative=negative, - margin=margin, - p=p, - swap=swap, + input=pred, + target=tar, size_average=size_average, reduce=reduce, reduction=reduction, ) -# multilabel soft margin loss +# triplet margin loss @handle_frontend_test( - fn_tree="torch.nn.functional.multilabel_soft_margin_loss", + fn_tree="torch.nn.functional.triplet_margin_loss", dtype_and_inputs=helpers.dtype_and_values( 
available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + num_arrays=3, allow_inf=False, shared_dtype=True, + min_value=0.0, + max_value=1.0, min_num_dims=1, + max_num_dims=2, + min_dim_size=1, ), + margin=st.floats(), + p=st.integers(min_value=0, max_value=2), + swap=st.booleans(), size_average=st.booleans(), reduce=st.booleans(), reduction=st.sampled_from(["none", "mean", "sum"]), test_with_out=st.just(False), ) -def test_torch_multilabel_soft_margin_loss( +def test_torch_triplet_margin_loss( *, dtype_and_inputs, + margin, + p, + swap, size_average, reduce, reduction, @@ -955,15 +991,22 @@ def test_torch_multilabel_soft_margin_loss( on_device, ): input_dtype, x = dtype_and_inputs + anchor_dtype, anchor = input_dtype[0], x[0] + positive_dtype, positive = input_dtype[1], x[1] + negative_dtype, negative = input_dtype[2], x[2] helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=[anchor_dtype, positive_dtype, negative_dtype], backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - target=x[1], + anchor=anchor, + positive=positive, + negative=negative, + margin=margin, + p=p, + swap=swap, size_average=size_average, reduce=reduce, reduction=reduction, @@ -1023,46 +1066,3 @@ def test_torch_triplet_margin_with_distance_loss( swap=swap, reduction=reduction, ) - - -# multilabel_margin_loss - - -@handle_frontend_test( - fn_tree="torch.nn.functional.multilabel_margin_loss", - dtype_and_inputs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - allow_inf=False, - shared_dtype=True, - min_num_dims=1, - ), - size_average=st.booleans(), - reduce=st.booleans(), - reduction=st.sampled_from(["none", "mean", "sum"]), - test_with_out=st.just(False), -) -def test_torch_multilabel_margin_loss( - *, - dtype_and_inputs, - reduction, - size_average, - reduce, - test_flags, - fn_tree, - frontend, - on_device, -): - input_dtype, x = dtype_and_inputs - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - target=x[1], - reduction=reduction, - size_average=size_average, - reduce=reduce, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py index dd423523c004b..4c4614c137fc3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_non_linear_activation_functions.py @@ -8,6 +8,14 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +def _filter_dtypes(input_dtype): + assume(("bfloat16" not in input_dtype) and ("float16" not in input_dtype)) + + @st.composite def _generate_prelu_arrays(draw): arr_size = draw(helpers.ints(min_value=2, max_value=5)) @@ -24,125 +32,252 @@ def _generate_prelu_arrays(draw): return dtype, input_weight -def _filter_dtypes(input_dtype): - assume(("bfloat16" not in input_dtype) and ("float16" not in input_dtype)) +@st.composite +def _glu_arrays(draw): + dtype = draw(helpers.get_dtypes("float", index=1, full=False)) + shape = draw(st.shared(helpers.ints(min_value=1, max_value=5))) + shape = shape * 2 + input = draw(helpers.array_values(dtype=dtype[0], shape=(shape, shape))) + 
dim = draw(st.shared(helpers.get_axis(shape=(shape,), force_int=True))) + return dtype, input, dim -# sigmoid -@handle_frontend_test( - fn_tree="torch.nn.functional.sigmoid", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), -) -def test_torch_sigmoid( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - atol=1e-2, - input=x[0], +@st.composite +def _x_and_scaled_attention(draw, dtypes): + dtype = draw(dtypes) + num_queries = draw(helpers.ints(min_value=2, max_value=4)) + num_keys = draw(helpers.ints(min_value=2, max_value=4)) + feat_dim = draw(helpers.ints(min_value=2, max_value=4)) + batch_size = draw(helpers.ints(min_value=1, max_value=2)) + q_shape = (batch_size,) + (num_queries,) + (feat_dim,) + k_shape = (batch_size,) + (num_keys,) + (feat_dim,) + v_shape = (batch_size,) + (num_keys,) + (feat_dim,) + mask_shape = (batch_size,) + (num_queries,) + (num_keys,) + + query = draw( + helpers.array_values( + dtype=dtype[0], + shape=q_shape, + min_value=0, + max_value=1e2, + large_abs_safety_factor=7, + small_abs_safety_factor=7, + safety_factor_scale="linear", + ) ) + key = draw( + helpers.array_values( + dtype=dtype[0], + shape=k_shape, + min_value=0, + max_value=1e2, + large_abs_safety_factor=7, + small_abs_safety_factor=7, + safety_factor_scale="linear", + ) + ) + value = draw( + helpers.array_values( + dtype=dtype[0], + shape=v_shape, + min_value=0, + max_value=1e2, + large_abs_safety_factor=7, + small_abs_safety_factor=7, + safety_factor_scale="linear", + ) + ) + mask = draw( + helpers.array_values( + dtype="bool", + shape=mask_shape, + ) + | st.none() + ) + return dtype, query, key, value, mask -# softmax -@handle_frontend_test( - fn_tree="torch.nn.functional.softmax", - dtype_x_and_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, - ), - dtypes=helpers.get_dtypes("float", full=False), -) -def test_torch_softmax( - *, - dtype_x_and_axis, - dtypes, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_x_and_axis - ivy.set_backend(backend_fw) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - _stacklevel=3, - dtype=ivy.as_ivy_dtype(dtypes[0]), +@st.composite +def mha_forward_args(draw, dtypes): + dtype = draw(dtypes) + embed_dim = draw(helpers.ints(min_value=2, max_value=4)) + batch_size = draw(helpers.ints(min_value=1, max_value=2)) * 3 + seq_len = draw(helpers.ints(min_value=2, max_value=4)) + shape = ( + seq_len, + batch_size, + embed_dim, ) - ivy.previous_backend() + heads = draw(helpers.ints(min_value=1, max_value=4)) + head_dim = embed_dim // heads + if head_dim * heads != embed_dim: + heads = 1 + head_dim = embed_dim -# gelu -@handle_frontend_test( - fn_tree="torch.nn.functional.gelu", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - max_value=1e04, - ), - approximate=st.sampled_from(["none", "tanh"]), -) -def test_torch_gelu( - *, - dtype_and_x, - approximate, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = 
dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - rtol=1e-02, - atol=1e-02, - approximate=approximate, + if dtype[0] == "float32": + is_causal = False + else: + is_causal = draw(helpers.array_bools(size=1))[0] + + q = draw( + helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) + ) + k = draw( + helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) + ) + v = draw( + helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) + ) + in_proj_weight = draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim * 3, embed_dim), + ) + ) + in_proj_bias = draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim * 3,), + ) ) + if random.randint(0, 1) == 0: + use_separate_proj_weight = True + q_proj_weight = draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim, embed_dim), + ) + ) + k_proj_weight = draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim, embed_dim), + ) + ) + v_proj_weight = draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim, embed_dim), + ) + ) + else: + use_separate_proj_weight = False + q_proj_weight = None + k_proj_weight = None + v_proj_weight = None -# leaky_relu + out_proj_weight = draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim, embed_dim), + ) + ) + out_proj_bias = draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim,), + ) + ) + bias_k = random.choice( + [ + draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(embed_dim,), + ) + ), + None, + ] + ) + bias_v = bias_k + + if bias_k is None: + static_k = random.choice( + [ + draw( + helpers.array_values( + dtype=dtype[0], + min_value=0.1, + max_value=1, + shape=(batch_size * heads, seq_len, head_dim), + ) + ), + None, + ] + ) + static_v = static_k + else: + static_k = None + static_v = None + + attn_mask = ivy.ones((seq_len, seq_len), dtype=dtype[0]) + key_padding_mask = random.choice( + [ + ivy.random_normal(shape=(seq_len, seq_len), dtype=dtype[0]) > 0, + None, + ] + ) + + return ( + dtype, + q, + k, + v, + heads, + use_separate_proj_weight, + embed_dim, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + q_proj_weight, + k_proj_weight, + v_proj_weight, + bias_k, + bias_v, + static_k, + static_v, + attn_mask, + key_padding_mask, + is_causal, + ) + + +# --- Main --- # +# ------------ # + + +# celu @handle_frontend_test( - fn_tree="torch.nn.functional.leaky_relu", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.nn.functional.celu", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - alpha=st.floats(min_value=0.0, max_value=1.0, exclude_min=True), + alpha=helpers.floats(min_value=0.1, max_value=1.0, exclude_min=True), test_inplace=st.booleans(), test_with_out=st.just(False), ) -def test_torch_leaky_relu( +def test_torch_celu( *, - dtype_and_x, + dtype_and_input, alpha, on_device, fn_tree, @@ -150,7 +285,8 @@ def test_torch_leaky_relu( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_input + _filter_dtypes(input_dtype) helpers.test_frontend_function( 
input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -158,30 +294,33 @@ def test_torch_leaky_relu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - rtol=1e-02, - atol=1e-02, - negative_slope=alpha, + input=input[0], + alpha=alpha, ) -# tanh +# elu @handle_frontend_test( - fn_tree="torch.nn.functional.tanh", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.nn.functional.elu", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), + alpha=helpers.floats(min_value=0.1, max_value=1.0, exclude_min=True), + test_inplace=st.booleans(), + test_with_out=st.just(False), ) -def test_torch_tanh( +def test_torch_elu( *, - dtype_and_x, + dtype_and_input, + alpha, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_input + _filter_dtypes(input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -189,28 +328,31 @@ def test_torch_tanh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, - input=x[0], + input=input[0], + alpha=alpha, ) -# logsigmoid +# elu_ @handle_frontend_test( - fn_tree="torch.nn.functional.logsigmoid", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.nn.functional.elu_", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), + alpha=helpers.floats(min_value=0.1, max_value=1.0, exclude_min=True), ) -def test_torch_logsigmoid( +def test_torch_elu_( *, - dtype_and_x, + dtype_and_input, + alpha, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_input + _filter_dtypes(input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -218,34 +360,32 @@ def test_torch_logsigmoid( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + test_values=False, + input=input[0], + alpha=alpha, ) -# softmin +# gelu @handle_frontend_test( - fn_tree="torch.nn.functional.softmin", - dtype_x_and_axis=helpers.dtype_values_axis( + fn_tree="torch.nn.functional.gelu", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, + max_value=1e04, ), - dtypes=helpers.get_dtypes("float", full=False), + approximate=st.sampled_from(["none", "tanh"]), ) -def test_torch_softmin( +def test_torch_gelu( *, - dtype_x_and_axis, - dtypes, + dtype_and_x, + approximate, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x, axis = dtype_x_and_axis - ivy.set_backend(backend_fw) + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -254,35 +394,28 @@ def test_torch_softmin( fn_tree=fn_tree, on_device=on_device, input=x[0], - dim=axis, - dtype=ivy.as_ivy_dtype(dtypes[0]), + rtol=1e-02, + atol=1e-02, + approximate=approximate, ) - ivy.previous_backend() -# threshold +# glu @handle_frontend_test( - fn_tree="torch.nn.functional.threshold", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), - threshold=helpers.floats(min_value=0.0, max_value=1.0), - value=helpers.ints(min_value=5, max_value=20), - test_with_out=st.just(False), - test_inplace=st.booleans(), + fn_tree="torch.nn.functional.glu", + dtype_input_dim=_glu_arrays(), ) -def test_torch_threshold( +def test_torch_glu( *, - dtype_and_input, - threshold, - value, + 
dtype_input_dim, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input + input_dtype, input, dim = dtype_input_dim + _filter_dtypes(input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -291,34 +424,37 @@ def test_torch_threshold( fn_tree=fn_tree, on_device=on_device, input=input[0], - threshold=threshold, - value=value, + dim=dim, ) -# threshold_ +# gumbel_softmax @handle_frontend_test( - fn_tree="torch.nn.functional.threshold_", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.nn.functional.gumbel_softmax", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - threshold=helpers.floats(min_value=0.0, max_value=1.0), - value=helpers.ints(min_value=5, max_value=20), + tau=st.floats(min_value=0), + hard=st.booleans(), + eps=st.floats(min_value=0, max_value=1), + dim=st.integers(), test_with_out=st.just(False), test_inplace=st.booleans(), ) -def test_torch_threshold_( +def test_torch_gumbel_softmax( *, - dtype_and_input, - threshold, - value, + dtype_and_x, + tau, + hard, + eps, + dim, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -326,24 +462,27 @@ def test_torch_threshold_( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - threshold=threshold, - value=value, + test_values=False, + logits=x[0], + tau=tau, + hard=hard, + eps=eps, + dim=dim, ) -# relu6 +# hardshrink @handle_frontend_test( - fn_tree="torch.nn.functional.relu6", + fn_tree="torch.nn.functional.hardshrink", dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), ), - test_inplace=st.booleans(), - test_with_out=st.just(False), + lambd=helpers.floats(min_value=0, max_value=1, exclude_min=True), ) -def test_torch_relu6( +def test_torch_hardshrink( *, dtype_and_input, + lambd, on_device, fn_tree, frontend, @@ -360,23 +499,22 @@ def test_torch_relu6( fn_tree=fn_tree, on_device=on_device, input=input[0], + lambd=lambd, ) -# elu +# hardsigmoid @handle_frontend_test( - fn_tree="torch.nn.functional.elu", + fn_tree="torch.nn.functional.hardsigmoid", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - alpha=helpers.floats(min_value=0.1, max_value=1.0, exclude_min=True), - test_inplace=st.booleans(), test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_torch_elu( +def test_torch_hardsigmoid( *, dtype_and_input, - alpha, on_device, fn_tree, frontend, @@ -393,22 +531,22 @@ def test_torch_elu( fn_tree=fn_tree, on_device=on_device, input=input[0], - alpha=alpha, ) -# elu_ +# hardswish @handle_frontend_test( - fn_tree="torch.nn.functional.elu_", + fn_tree="torch.nn.functional.hardswish", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", ), - alpha=helpers.floats(min_value=0.1, max_value=1.0, exclude_min=True), + test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_torch_elu_( +def test_torch_hardswish( *, dtype_and_input, - alpha, on_device, fn_tree, frontend, @@ -424,34 +562,32 @@ def test_torch_elu_( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, input=input[0], - alpha=alpha, ) -# celu +# hardtanh @handle_frontend_test( - 
fn_tree="torch.nn.functional.celu", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.nn.functional.hardtanh", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - alpha=helpers.floats(min_value=0.1, max_value=1.0, exclude_min=True), - test_inplace=st.booleans(), + max_val=st.floats(min_value=0, max_value=1, exclude_min=True), test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_torch_celu( +def test_torch_hardtanh( *, - dtype_and_input, - alpha, + dtype_and_x, + max_val, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + input_dtype, x = dtype_and_x + max_min = max_val, -max_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -459,31 +595,34 @@ def test_torch_celu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - alpha=alpha, + input=x[0], + min_val=max_min[1], + max_val=max_min[0], ) -# mish +# hardtanh_ @handle_frontend_test( - fn_tree="torch.nn.functional.mish", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.nn.functional.hardtanh_", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - test_inplace=st.booleans(), + max_val=st.floats(min_value=0, max_value=1, exclude_min=True), test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_torch_mish( +def test_torch_hardtanh_( *, - dtype_and_input, + dtype_and_x, + max_val, + on_device, fn_tree, frontend, - on_device, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + input_dtype, x = dtype_and_x + max_min = max_val, -max_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -491,147 +630,183 @@ def test_torch_mish( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + test_values=False, + input=x[0], + min_val=max_min[1], + max_val=max_min[0], ) -# relu +# leaky_relu @handle_frontend_test( - fn_tree="torch.nn.functional.relu", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.nn.functional.leaky_relu", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), + alpha=st.floats(min_value=0.0, max_value=1.0, exclude_min=True), test_inplace=st.booleans(), test_with_out=st.just(False), ) -def test_torch_relu( - dtype_and_input, +def test_torch_leaky_relu( + *, + dtype_and_x, + alpha, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): - input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - input=input[0], + on_device=on_device, + input=x[0], + rtol=1e-02, + atol=1e-02, + negative_slope=alpha, ) -# relu_ +# leaky_relu_ +# ToDo test for value test once inplace testing implemented @handle_frontend_test( - fn_tree="torch.nn.functional.relu_", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.nn.functional.leaky_relu_", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), + alpha=st.floats(min_value=0, max_value=1, exclude_min=True), test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_torch_relu_( - dtype_and_input, +def test_torch_leaky_relu_( + *, + dtype_and_x, + alpha, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): - 
input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - input=input[0], + on_device=on_device, + test_values=False, + input=x[0], + negative_slope=alpha, ) -# selu +# local_response_norm @handle_frontend_test( - fn_tree="torch.nn.functional.selu", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.nn.functional.local_response_norm", + dtype_x_and_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + force_int_axis=True, + valid_axis=True, ), - test_inplace=st.booleans(), - test_with_out=st.just(False), + size=helpers.ints(min_value=3, max_value=10), + alpha=helpers.floats(min_value=1e-4, max_value=1e-3), + beta=helpers.floats(min_value=0.5, max_value=2.0), + k=helpers.ints(min_value=0, max_value=1), ) -def test_torch_selu( +def test_torch_local_response_norm( *, - dtype_and_input, + dtype_x_and_axis, + size, + alpha, + beta, + k, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + dtype, x, axis = dtype_x_and_axis + _filter_dtypes(dtype) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], + size=size, + alpha=alpha, + beta=beta, + k=k, ) -# prelu +# log_softmax @handle_frontend_test( - fn_tree="torch.nn.functional.prelu", - dtype_input_and_weight=_generate_prelu_arrays(), + fn_tree="torch.nn.functional.log_softmax", + dtype_x_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_axes_size=1, + force_int_axis=True, + valid_axis=True, + ), + dtypes=helpers.get_dtypes("float", none=False, full=False), ) -def test_torch_prelu( +def test_torch_log_softmax( *, - dtype_input_and_weight, + dtype_x_and_axis, + dtypes, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, inputs = dtype_input_and_weight - _filter_dtypes(dtype) + input_dtype, x, axis = dtype_x_and_axis + ivy.set_backend(backend_fw) helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=inputs[0], - weight=inputs[1], + input=x[0], + dim=axis, + _stacklevel=3, + dtype=ivy.as_ivy_dtype(dtypes[0]), ) + ivy.previous_backend() -# rrelu +# logsigmoid @handle_frontend_test( - fn_tree="torch.nn.functional.rrelu", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.nn.functional.logsigmoid", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - lower=helpers.floats(min_value=0, max_value=0.5, exclude_min=True), - upper=helpers.floats(min_value=0.5, max_value=1.0, exclude_min=True), - test_with_out=st.just(False), - test_inplace=st.booleans(), ) -def test_torch_rrelu( +def test_torch_logsigmoid( *, - dtype_and_input, - lower, - upper, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -639,35 +814,30 @@ def test_torch_rrelu( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - lower=lower, - upper=upper, + 
input=x[0], ) -# rrelu_ +# mish @handle_frontend_test( - fn_tree="torch.nn.functional.rrelu_", + fn_tree="torch.nn.functional.mish", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - lower=helpers.floats(min_value=0, max_value=0.5, exclude_min=True), - upper=helpers.floats(min_value=0.5, max_value=1.0, exclude_min=True), - test_with_out=st.just(False), test_inplace=st.booleans(), + test_with_out=st.just(False), ) -def test_torch_rrelu_( +def test_torch_mish( *, dtype_and_input, - lower, - upper, - on_device, fn_tree, frontend, + on_device, test_flags, backend_fw, ): input_dtype, input = dtype_and_input + _filter_dtypes(input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -675,162 +845,210 @@ def test_torch_rrelu_( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, input=input[0], - lower=lower, - upper=upper, ) -# hardshrink +# multi_head_attention_forward @handle_frontend_test( - fn_tree="torch.nn.functional.hardshrink", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="torch.nn.functional.multi_head_attention_forward", + dtype_mha_args=mha_forward_args( + dtypes=helpers.get_dtypes("valid"), ), - lambd=helpers.floats(min_value=0, max_value=1, exclude_min=True), + add_zero_attn=st.just(False), + dropout_p=st.sampled_from([0.0, 0.1, 0.2]), + training=st.booleans(), + need_weights=st.booleans(), + average_attn_weights=st.booleans(), + test_with_out=st.just(False), ) -def test_torch_hardshrink( +def test_torch_multi_head_attention_forward( *, - dtype_and_input, - lambd, on_device, fn_tree, frontend, test_flags, + dtype_mha_args, + add_zero_attn, + dropout_p, + training, + need_weights, + average_attn_weights, backend_fw, ): - input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + ( + dtype, + q, + k, + v, + heads, + use_separate_proj_weight, + embed_dim, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + q_proj_weight, + k_proj_weight, + v_proj_weight, + bias_k, + bias_v, + static_k, + static_v, + attn_mask, + key_padding_mask, + is_causal, + ) = dtype_mha_args + helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - lambd=lambd, + test_values=not training or dropout_p == 0.0, + query=q, + key=k, + value=v, + embed_dim_to_check=embed_dim, + num_heads=heads, + in_proj_weight=in_proj_weight, + in_proj_bias=in_proj_bias, + bias_k=bias_k, + bias_v=bias_v, + add_zero_attn=add_zero_attn, + dropout_p=dropout_p, + out_proj_weight=out_proj_weight, + out_proj_bias=out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attn_weights, + is_causal=is_causal, ) -# softsign +# normalize @handle_frontend_test( - fn_tree="torch.nn.functional.softsign", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + fn_tree="torch.nn.functional.normalize", + dtype_x_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_axes_size=1, + force_int_axis=True, + valid_axis=True, ), + 
p=helpers.ints(min_value=2, max_value=5), ) -def test_torch_softsign( +def test_torch_normalize( *, - dtype_and_input, + dtype_x_and_axis, + p, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input + dtype, x, axis = dtype_x_and_axis + _filter_dtypes(dtype) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], + p=p, + dim=axis, + eps=1e-12, ) -# softshrink +# prelu @handle_frontend_test( - fn_tree="torch.nn.functional.softshrink", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), - lambd=helpers.floats(min_value=0, max_value=1, exclude_min=True), + fn_tree="torch.nn.functional.prelu", + dtype_input_and_weight=_generate_prelu_arrays(), ) -def test_torch_softshrink( +def test_torch_prelu( *, - dtype_and_input, - lambd, + dtype_input_and_weight, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + dtype, inputs = dtype_input_and_weight + _filter_dtypes(dtype) helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - lambd=lambd, + input=inputs[0], + weight=inputs[1], ) -# silu +# relu @handle_frontend_test( - fn_tree="torch.nn.functional.silu", + fn_tree="torch.nn.functional.relu", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - test_with_out=st.just(False), test_inplace=st.booleans(), + test_with_out=st.just(False), ) -def test_torch_silu( - *, +def test_torch_relu( dtype_and_input, - on_device, - fn_tree, frontend, test_flags, + fn_tree, backend_fw, ): input_dtype, input = dtype_and_input + _filter_dtypes(input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - rtol=1e-2, - atol=1e-2, input=input[0], ) -@st.composite -def _glu_arrays(draw): - dtype = draw(helpers.get_dtypes("float", index=1, full=False)) - shape = draw(st.shared(helpers.ints(min_value=1, max_value=5))) - shape = shape * 2 - input = draw(helpers.array_values(dtype=dtype[0], shape=(shape, shape))) - dim = draw(st.shared(helpers.get_axis(shape=(shape,), force_int=True))) - return dtype, input, dim - - -# glu +# relu6 @handle_frontend_test( - fn_tree="torch.nn.functional.glu", - dtype_input_dim=_glu_arrays(), + fn_tree="torch.nn.functional.relu6", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), + test_inplace=st.booleans(), + test_with_out=st.just(False), ) -def test_torch_glu( +def test_torch_relu6( *, - dtype_input_dim, + dtype_and_input, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input, dim = dtype_input_dim + input_dtype, input = dtype_and_input _filter_dtypes(input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, @@ -840,59 +1058,52 @@ def test_torch_glu( fn_tree=fn_tree, on_device=on_device, input=input[0], - dim=dim, ) -# log_softmax +# relu_ @handle_frontend_test( - fn_tree="torch.nn.functional.log_softmax", - dtype_x_and_axis=helpers.dtype_values_axis( + fn_tree="torch.nn.functional.relu_", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - 
min_num_dims=1, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, ), - dtypes=helpers.get_dtypes("float", none=False, full=False), + test_with_out=st.just(False), ) -def test_torch_log_softmax( - *, - dtype_x_and_axis, - dtypes, - on_device, - fn_tree, +def test_torch_relu_( + dtype_and_input, frontend, test_flags, + fn_tree, backend_fw, ): - input_dtype, x, axis = dtype_x_and_axis - ivy.set_backend(backend_fw) + input_dtype, input = dtype_and_input + _filter_dtypes(input_dtype) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - _stacklevel=3, - dtype=ivy.as_ivy_dtype(dtypes[0]), + input=input[0], ) - ivy.previous_backend() -# tanhshrink +# rrelu @handle_frontend_test( - fn_tree="torch.nn.functional.tanhshrink", + fn_tree="torch.nn.functional.rrelu", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), + lower=helpers.floats(min_value=0, max_value=0.5, exclude_min=True), + upper=helpers.floats(min_value=0.5, max_value=1.0, exclude_min=True), + test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_torch_tanhshrink( +def test_torch_rrelu( *, dtype_and_input, + lower, + upper, on_device, fn_tree, frontend, @@ -908,31 +1119,34 @@ def test_torch_tanhshrink( fn_tree=fn_tree, on_device=on_device, input=input[0], + lower=lower, + upper=upper, ) -# leaky_relu_ -# ToDo test for value test once inplace testing implemented +# rrelu_ @handle_frontend_test( - fn_tree="torch.nn.functional.leaky_relu_", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.nn.functional.rrelu_", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - alpha=st.floats(min_value=0, max_value=1, exclude_min=True), + lower=helpers.floats(min_value=0, max_value=0.5, exclude_min=True), + upper=helpers.floats(min_value=0.5, max_value=1.0, exclude_min=True), test_with_out=st.just(False), test_inplace=st.booleans(), ) -def test_torch_leaky_relu_( +def test_torch_rrelu_( *, - dtype_and_x, - alpha, + dtype_and_input, + lower, + upper, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -941,53 +1155,63 @@ def test_torch_leaky_relu_( fn_tree=fn_tree, on_device=on_device, test_values=False, - input=x[0], - negative_slope=alpha, + input=input[0], + lower=lower, + upper=upper, ) -# hardswish +# scaled_dot_product_attention @handle_frontend_test( - fn_tree="torch.nn.functional.hardswish", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - safety_factor_scale="log", + fn_tree="torch.nn.functional.scaled_dot_product_attention", + dtype_q_k_v_mask=_x_and_scaled_attention( + dtypes=helpers.get_dtypes("float"), ), - test_with_out=st.just(False), - test_inplace=st.booleans(), + dropout_p=st.floats(min_value=0, max_value=0.99), + is_causal=st.booleans(), ) -def test_torch_hardswish( +def test_torch_scaled_dot_product_attention( *, - dtype_and_input, + dtype_q_k_v_mask, + dropout_p, + is_causal, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_input - _filter_dtypes(input_dtype) + (dtype, query, key, value, mask) = dtype_q_k_v_mask + is_causal = is_causal if mask is None else False helpers.test_frontend_function( - input_dtypes=input_dtype, + 
input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + test_values=dropout_p == 0.0, + rtol=1e-05, + atol=1e-05, + query=query, + key=key, + value=value, + attn_mask=mask, + dropout_p=dropout_p, + is_causal=is_causal, ) -# hardsigmoid +# selu @handle_frontend_test( - fn_tree="torch.nn.functional.hardsigmoid", + fn_tree="torch.nn.functional.selu", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - test_with_out=st.just(False), test_inplace=st.booleans(), + test_with_out=st.just(False), ) -def test_torch_hardsigmoid( +def test_torch_selu( *, dtype_and_input, on_device, @@ -1009,20 +1233,16 @@ def test_torch_hardsigmoid( ) -# hardtanh +# sigmoid @handle_frontend_test( - fn_tree="torch.nn.functional.hardtanh", + fn_tree="torch.nn.functional.sigmoid", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - max_val=st.floats(min_value=0, max_value=1, exclude_min=True), - test_with_out=st.just(False), - test_inplace=st.booleans(), ) -def test_torch_hardtanh( +def test_torch_sigmoid( *, dtype_and_x, - max_val, on_device, fn_tree, frontend, @@ -1030,7 +1250,6 @@ def test_torch_hardtanh( backend_fw, ): input_dtype, x = dtype_and_x - max_min = max_val, -max_val helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1038,34 +1257,30 @@ def test_torch_hardtanh( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-2, input=x[0], - min_val=max_min[1], - max_val=max_min[0], ) -# hardtanh_ +# silu @handle_frontend_test( - fn_tree="torch.nn.functional.hardtanh_", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.nn.functional.silu", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - max_val=st.floats(min_value=0, max_value=1, exclude_min=True), test_with_out=st.just(False), test_inplace=st.booleans(), ) -def test_torch_hardtanh_( +def test_torch_silu( *, - dtype_and_x, - max_val, + dtype_and_input, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x - max_min = max_val, -max_val + input_dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1073,16 +1288,15 @@ def test_torch_hardtanh_( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - input=x[0], - min_val=max_min[1], - max_val=max_min[0], + rtol=1e-2, + atol=1e-2, + input=input[0], ) -# normalize +# softmax @handle_frontend_test( - fn_tree="torch.nn.functional.normalize", + fn_tree="torch.nn.functional.softmax", dtype_x_and_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, @@ -1090,76 +1304,71 @@ def test_torch_hardtanh_( force_int_axis=True, valid_axis=True, ), - p=helpers.ints(min_value=2, max_value=5), + dtypes=helpers.get_dtypes("float", full=False), ) -def test_torch_normalize( +def test_torch_softmax( *, dtype_x_and_axis, - p, + dtypes, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, axis = dtype_x_and_axis - _filter_dtypes(dtype) + input_dtype, x, axis = dtype_x_and_axis + ivy.set_backend(backend_fw) helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - p=p, dim=axis, - eps=1e-12, + _stacklevel=3, + dtype=ivy.as_ivy_dtype(dtypes[0]), ) + 
ivy.previous_backend() -# local_response_norm +# softmin @handle_frontend_test( - fn_tree="torch.nn.functional.local_response_norm", + fn_tree="torch.nn.functional.softmin", dtype_x_and_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, + min_num_dims=1, + max_axes_size=1, force_int_axis=True, valid_axis=True, ), - size=helpers.ints(min_value=3, max_value=10), - alpha=helpers.floats(min_value=1e-4, max_value=1e-3), - beta=helpers.floats(min_value=0.5, max_value=2.0), - k=helpers.ints(min_value=0, max_value=1), + dtypes=helpers.get_dtypes("float", full=False), ) -def test_torch_local_response_norm( +def test_torch_softmin( *, dtype_x_and_axis, - size, - alpha, - beta, - k, + dtypes, on_device, fn_tree, frontend, test_flags, backend_fw, ): - dtype, x, axis = dtype_x_and_axis - _filter_dtypes(dtype) + input_dtype, x, axis = dtype_x_and_axis + ivy.set_backend(backend_fw) helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - size=size, - alpha=alpha, - beta=beta, - k=k, + dim=axis, + dtype=ivy.as_ivy_dtype(dtypes[0]), ) + ivy.previous_backend() # softplus @@ -1197,190 +1406,76 @@ def test_torch_softplus( ) -@st.composite -def mha_forward_args(draw, dtypes): - dtype = draw(dtypes) - embed_dim = draw(helpers.ints(min_value=2, max_value=4)) - batch_size = draw(helpers.ints(min_value=1, max_value=2)) * 3 - seq_len = draw(helpers.ints(min_value=2, max_value=4)) - shape = ( - seq_len, - batch_size, - embed_dim, - ) - - heads = draw(helpers.ints(min_value=1, max_value=4)) - head_dim = embed_dim // heads - if head_dim * heads != embed_dim: - heads = 1 - head_dim = embed_dim - - if dtype[0] == "float32": - is_causal = False - else: - is_causal = draw(helpers.array_bools(size=1))[0] - - q = draw( - helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) - ) - k = draw( - helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) - ) - v = draw( - helpers.array_values(dtype=dtype[0], shape=shape, min_value=0.1, max_value=1) - ) - in_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim * 3, embed_dim), - ) - ) - in_proj_bias = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim * 3,), - ) - ) - - if random.randint(0, 1) == 0: - use_separate_proj_weight = True - q_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - k_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - v_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - else: - use_separate_proj_weight = False - q_proj_weight = None - k_proj_weight = None - v_proj_weight = None - - out_proj_weight = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim, embed_dim), - ) - ) - out_proj_bias = draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim,), - ) - ) - bias_k = random.choice( - [ - draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(embed_dim,), - ) - ), - None, - ] - ) - bias_v = bias_k - - if bias_k is None: - static_k = random.choice( - [ - 
draw( - helpers.array_values( - dtype=dtype[0], - min_value=0.1, - max_value=1, - shape=(batch_size * heads, seq_len, head_dim), - ) - ), - None, - ] - ) - static_v = static_k - else: - static_k = None - static_v = None - - attn_mask = ivy.ones((seq_len, seq_len), dtype=dtype[0]) - key_padding_mask = random.choice( - [ - ivy.random_normal(shape=(seq_len, seq_len), dtype=dtype[0]) > 0, - None, - ] - ) - - return ( - dtype, - q, - k, - v, - heads, - use_separate_proj_weight, - embed_dim, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - q_proj_weight, - k_proj_weight, - v_proj_weight, - bias_k, - bias_v, - static_k, - static_v, - attn_mask, - key_padding_mask, - is_causal, +# softshrink +@handle_frontend_test( + fn_tree="torch.nn.functional.softshrink", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), + lambd=helpers.floats(min_value=0, max_value=1, exclude_min=True), +) +def test_torch_softshrink( + *, + dtype_and_input, + lambd, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_input + _filter_dtypes(input_dtype) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + lambd=lambd, ) -# gumbel_softmax +# softsign @handle_frontend_test( - fn_tree="torch.nn.functional.gumbel_softmax", + fn_tree="torch.nn.functional.softsign", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), +) +def test_torch_softsign( + *, + dtype_and_input, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + ) + + +# tanh +@handle_frontend_test( + fn_tree="torch.nn.functional.tanh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), - tau=st.floats(min_value=0), - hard=st.booleans(), - eps=st.floats(min_value=0, max_value=1), - dim=st.integers(), - test_with_out=st.just(False), - test_inplace=st.booleans(), ) -def test_torch_gumbel_softmax( +def test_torch_tanh( *, dtype_and_x, - tau, - hard, - eps, - dim, on_device, fn_tree, frontend, @@ -1395,193 +1490,106 @@ def test_torch_gumbel_softmax( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - logits=x[0], - tau=tau, - hard=hard, - eps=eps, - dim=dim, + atol=1e-2, + input=x[0], ) -# multi_head_attention_forward +# tanhshrink @handle_frontend_test( - fn_tree="torch.nn.functional.multi_head_attention_forward", - dtype_mha_args=mha_forward_args( - dtypes=helpers.get_dtypes("valid"), + fn_tree="torch.nn.functional.tanhshrink", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), - add_zero_attn=st.just(False), - dropout_p=st.sampled_from([0.0, 0.1, 0.2]), - training=st.booleans(), - need_weights=st.booleans(), - average_attn_weights=st.booleans(), - test_with_out=st.just(False), ) -def test_torch_multi_head_attention_forward( +def test_torch_tanhshrink( *, + dtype_and_input, on_device, fn_tree, frontend, test_flags, - dtype_mha_args, - add_zero_attn, - dropout_p, - training, - need_weights, - average_attn_weights, backend_fw, ): - ( - dtype, - q, - k, - v, - heads, - use_separate_proj_weight, - embed_dim, - 
in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - q_proj_weight, - k_proj_weight, - v_proj_weight, - bias_k, - bias_v, - static_k, - static_v, - attn_mask, - key_padding_mask, - is_causal, - ) = dtype_mha_args - + input_dtype, input = dtype_and_input helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=not training or dropout_p == 0.0, - query=q, - key=k, - value=v, - embed_dim_to_check=embed_dim, - num_heads=heads, - in_proj_weight=in_proj_weight, - in_proj_bias=in_proj_bias, - bias_k=bias_k, - bias_v=bias_v, - add_zero_attn=add_zero_attn, - dropout_p=dropout_p, - out_proj_weight=out_proj_weight, - out_proj_bias=out_proj_bias, - training=training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - use_separate_proj_weight=use_separate_proj_weight, - q_proj_weight=q_proj_weight, - k_proj_weight=k_proj_weight, - v_proj_weight=v_proj_weight, - static_k=static_k, - static_v=static_v, - average_attn_weights=average_attn_weights, - is_causal=is_causal, + input=input[0], ) -@st.composite -def _x_and_scaled_attention(draw, dtypes): - dtype = draw(dtypes) - num_queries = draw(helpers.ints(min_value=2, max_value=4)) - num_keys = draw(helpers.ints(min_value=2, max_value=4)) - feat_dim = draw(helpers.ints(min_value=2, max_value=4)) - batch_size = draw(helpers.ints(min_value=1, max_value=2)) - q_shape = (batch_size,) + (num_queries,) + (feat_dim,) - k_shape = (batch_size,) + (num_keys,) + (feat_dim,) - v_shape = (batch_size,) + (num_keys,) + (feat_dim,) - mask_shape = (batch_size,) + (num_queries,) + (num_keys,) - - query = draw( - helpers.array_values( - dtype=dtype[0], - shape=q_shape, - min_value=0, - max_value=1e2, - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", - ) - ) - key = draw( - helpers.array_values( - dtype=dtype[0], - shape=k_shape, - min_value=0, - max_value=1e2, - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", - ) - ) - value = draw( - helpers.array_values( - dtype=dtype[0], - shape=v_shape, - min_value=0, - max_value=1e2, - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", - ) - ) - mask = draw( - helpers.array_values( - dtype="bool", - shape=mask_shape, - ) - | st.none() +# threshold +@handle_frontend_test( + fn_tree="torch.nn.functional.threshold", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), + threshold=helpers.floats(min_value=0.0, max_value=1.0), + value=helpers.ints(min_value=5, max_value=20), + test_with_out=st.just(False), + test_inplace=st.booleans(), +) +def test_torch_threshold( + *, + dtype_and_input, + threshold, + value, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + threshold=threshold, + value=value, ) - return dtype, query, key, value, mask -# scaled_dot_product_attention +# threshold_ @handle_frontend_test( - fn_tree="torch.nn.functional.scaled_dot_product_attention", - dtype_q_k_v_mask=_x_and_scaled_attention( - dtypes=helpers.get_dtypes("float"), + fn_tree="torch.nn.functional.threshold_", + dtype_and_input=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("float"), ), - dropout_p=st.floats(min_value=0, max_value=0.99), - is_causal=st.booleans(), + threshold=helpers.floats(min_value=0.0, max_value=1.0), + value=helpers.ints(min_value=5, max_value=20), + test_with_out=st.just(False), + test_inplace=st.booleans(), ) -def test_torch_scaled_dot_product_attention( +def test_torch_threshold_( *, - dtype_q_k_v_mask, - dropout_p, - is_causal, + dtype_and_input, + threshold, + value, on_device, fn_tree, frontend, test_flags, backend_fw, ): - (dtype, query, key, value, mask) = dtype_q_k_v_mask - is_causal = is_causal if mask is None else False + input_dtype, input = dtype_and_input helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=dropout_p == 0.0, - rtol=1e-05, - atol=1e-05, - query=query, - key=key, + input=input[0], + threshold=threshold, value=value, - attn_mask=mask, - dropout_p=dropout_p, - is_causal=is_causal, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_norms.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_norms.py index a6956ed2ccf5e..a0b09c01ac767 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_norms.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_norms.py @@ -6,6 +6,99 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _generate_data_layer_norm( + draw, + *, + available_dtypes, + large_abs_safety_factor=100, + small_abs_safety_factor=100, + safety_factor_scale="log", + min_num_dims=1, + max_num_dims=5, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, + ret_shape=True, + abs_smallest_val=1000, + allow_inf=False, + allow_nan=False, + exclude_min=True, + exclude_max=True, + min_value=-1000, + max_value=1000, + shared_dtype=False, + min_dim_size=1, + max_dim_size=3, + group=False, +): + results = draw( + helpers.dtype_values_axis( + available_dtypes=available_dtypes, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + valid_axis=valid_axis, + allow_neg_axes=allow_neg_axes, + max_axes_size=max_axes_size, + force_int_axis=force_int_axis, + ret_shape=ret_shape, + ) + ) + + dtype, values, axis, shape = results + + if group: + channel_size = shape[1] + group_list = [*range(1, max_dim_size)] + group_list = list(filter(lambda x: (channel_size % x == 0), group_list)) + group_size = draw(st.sampled_from(group_list)) + weight_shape = [shape[1]] + bias_shape = [shape[1]] + else: + weight_shape = shape[axis:] + bias_shape = shape[axis:] + + arg_dict = { + "available_dtypes": dtype, + "abs_smallest_val": abs_smallest_val, + "min_value": min_value, + "max_value": max_value, + "large_abs_safety_factor": large_abs_safety_factor, + "small_abs_safety_factor": small_abs_safety_factor, + "allow_inf": allow_inf, + "allow_nan": allow_nan, + "exclude_min": exclude_min, + "exclude_max": exclude_max, + "min_num_dims": min_num_dims, + "max_num_dims": max_num_dims, + "shared_dtype": shared_dtype, + "ret_shape": False, + } + + results_weight = draw(helpers.dtype_and_values(shape=weight_shape, **arg_dict)) + results_bias 
= draw(helpers.dtype_and_values(shape=bias_shape, **arg_dict)) + results_new_std = draw(helpers.dtype_and_values(shape=shape, **arg_dict)) + + _, weight_values = results_weight + _, bias_values = results_bias + _, new_std_values = results_new_std + + axis = shape[axis:] + if group: + return dtype, values, weight_values, bias_values, group_size + return dtype, values, axis, weight_values, bias_values, new_std_values + + @st.composite def _instance_and_batch_norm_helper(draw, *, min_num_dims=1, min_dim_size=1): x_dtype, x, shape = draw( @@ -56,6 +149,10 @@ def _instance_and_batch_norm_helper(draw, *, min_num_dims=1, min_dim_size=1): return x_dtype, x[-1], others[0], others[1], mean[0], variance[0], momentum, eps +# --- Main --- # +# ------------ # + + @handle_frontend_test( fn_tree="torch.nn.functional.batch_norm", data=_instance_and_batch_norm_helper(min_num_dims=2, min_dim_size=2), @@ -90,6 +187,45 @@ def test_torch_batch_norm( ) +# group_norm +@handle_frontend_test( + fn_tree="torch.nn.functional.group_norm", + dtype_x_and_axis=_generate_data_layer_norm( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=3, + min_dim_size=2, + max_dim_size=4, + group=True, + ), + epsilon=st.floats(min_value=0.01, max_value=0.1), + test_with_out=st.just(False), +) +def test_torch_group_norm( + dtype_x_and_axis, + epsilon, + frontend, + test_flags, + fn_tree, + backend_fw, +): + dtype, x, weight, bias, group_size = dtype_x_and_axis + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + atol=1e-1, + rtol=1e-1, + input=x[0], + num_groups=group_size, + weight=weight[0], + bias=bias[0], + eps=epsilon, + ) + + @handle_frontend_test( fn_tree="torch.nn.functional.instance_norm", data=_instance_and_batch_norm_helper(min_num_dims=3, min_dim_size=2), @@ -124,95 +260,6 @@ def test_torch_instance_norm( ) -@st.composite -def _generate_data_layer_norm( - draw, - *, - available_dtypes, - large_abs_safety_factor=100, - small_abs_safety_factor=100, - safety_factor_scale="log", - min_num_dims=1, - max_num_dims=5, - valid_axis=True, - allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - ret_shape=True, - abs_smallest_val=1000, - allow_inf=False, - allow_nan=False, - exclude_min=True, - exclude_max=True, - min_value=-1000, - max_value=1000, - shared_dtype=False, - min_dim_size=1, - max_dim_size=3, - group=False, -): - results = draw( - helpers.dtype_values_axis( - available_dtypes=available_dtypes, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - valid_axis=valid_axis, - allow_neg_axes=allow_neg_axes, - max_axes_size=max_axes_size, - force_int_axis=force_int_axis, - ret_shape=ret_shape, - ) - ) - - dtype, values, axis, shape = results - - if group: - channel_size = shape[1] - group_list = [*range(1, max_dim_size)] - group_list = list(filter(lambda x: (channel_size % x == 0), group_list)) - group_size = draw(st.sampled_from(group_list)) - weight_shape = [shape[1]] - bias_shape = [shape[1]] - else: - weight_shape = shape[axis:] - bias_shape = shape[axis:] - - arg_dict = { - "available_dtypes": dtype, - "abs_smallest_val": abs_smallest_val, - "min_value": min_value, - "max_value": max_value, - "large_abs_safety_factor": large_abs_safety_factor, - "small_abs_safety_factor": 
small_abs_safety_factor, - "allow_inf": allow_inf, - "allow_nan": allow_nan, - "exclude_min": exclude_min, - "exclude_max": exclude_max, - "min_num_dims": min_num_dims, - "max_num_dims": max_num_dims, - "shared_dtype": shared_dtype, - "ret_shape": False, - } - - results_weight = draw(helpers.dtype_and_values(shape=weight_shape, **arg_dict)) - results_bias = draw(helpers.dtype_and_values(shape=bias_shape, **arg_dict)) - results_new_std = draw(helpers.dtype_and_values(shape=shape, **arg_dict)) - - _, weight_values = results_weight - _, bias_values = results_bias - _, new_std_values = results_new_std - - axis = shape[axis:] - if group: - return dtype, values, weight_values, bias_values, group_size - return dtype, values, axis, weight_values, bias_values, new_std_values - - @handle_frontend_test( fn_tree="torch.nn.functional.layer_norm", dtype_x_and_axis=_generate_data_layer_norm( @@ -246,42 +293,3 @@ def test_torch_layer_norm( bias=bias[0], eps=epsilon, ) - - -# group_norm -@handle_frontend_test( - fn_tree="torch.nn.functional.group_norm", - dtype_x_and_axis=_generate_data_layer_norm( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=3, - min_dim_size=2, - max_dim_size=4, - group=True, - ), - epsilon=st.floats(min_value=0.01, max_value=0.1), - test_with_out=st.just(False), -) -def test_torch_group_norm( - dtype_x_and_axis, - epsilon, - frontend, - test_flags, - fn_tree, - backend_fw, -): - dtype, x, weight, bias, group_size = dtype_x_and_axis - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - atol=1e-1, - rtol=1e-1, - input=x[0], - num_groups=group_size, - weight=weight[0], - bias=bias[0], - eps=epsilon, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py index cb7c27a1ce5a7..be4dbfd541d3d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py @@ -7,6 +7,22 @@ import math +def calculate_same_padding(kernel_size, stride, shape): + padding = tuple( + [ + max( + 0, + math.ceil(((shape[i] - 1) * stride[i] + kernel_size[i] - shape[i]) / 2), + ) + for i in range(len(kernel_size)) + ] + ) + if all([kernel_size[i] / 2 >= padding[i] for i in range(len(kernel_size))]): + if is_same_padding(padding, stride, kernel_size, shape): + return padding + return [0] * len(shape) + + def is_same_padding(padding, stride, kernel_size, input_shape): output_shape = tuple( [ @@ -22,20 +38,132 @@ def is_same_padding(padding, stride, kernel_size, input_shape): ) -def calculate_same_padding(kernel_size, stride, shape): - padding = tuple( - [ - max( - 0, - math.ceil(((shape[i] - 1) * stride[i] + kernel_size[i] - shape[i]) / 2), - ) - for i in range(len(kernel_size)) - ] +# adaptive_avg_pool1d +@handle_frontend_test( + fn_tree="torch.nn.functional.adaptive_avg_pool1d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=3, + min_dim_size=5, + max_value=100, + min_value=-100, + ), + output_size=helpers.ints(min_value=1, max_value=10), + test_with_out=st.just(False), +) +def test_torch_adaptive_avg_pool1d( + *, + dtype_and_x, + output_size, + on_device, + frontend, + test_flags, + fn_tree, + backend_fw, +): + input_dtype, 
x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + output_size=output_size, + atol=1e-2, + ) + + +# adaptive_avg_pool2d +@handle_frontend_test( + fn_tree="torch.nn.functional.adaptive_avg_pool2d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + max_num_dims=4, + min_dim_size=5, + max_value=100, + min_value=-100, + ), + output_size=st.one_of( + st.tuples( + helpers.ints(min_value=1, max_value=10), + helpers.ints(min_value=1, max_value=10), + ), + helpers.ints(min_value=1, max_value=10), + ), + test_with_out=st.just(False), +) +def test_torch_adaptive_avg_pool2d( + *, + dtype_and_x, + output_size, + on_device, + frontend, + test_flags, + fn_tree, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + output_size=output_size, + atol=1e-2, + ) + + +# adaptive_max_pool2d +@handle_frontend_test( + fn_tree="torch.nn.functional.adaptive_max_pool2d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + max_num_dims=4, + min_dim_size=5, + # Setting max and min value because this operation in paddle is not + # numerically stable + max_value=100, + min_value=-100, + ), + output_size=st.one_of( + st.tuples( + helpers.ints(min_value=1, max_value=10), + helpers.ints(min_value=1, max_value=10), + ), + helpers.ints(min_value=1, max_value=10), + ), + test_with_out=st.just(False), +) +def test_torch_adaptive_max_pool2d( + *, + dtype_and_x, + output_size, + on_device, + frontend, + test_flags, + fn_tree, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + output_size=output_size, + atol=1e-2, ) - if all([kernel_size[i] / 2 >= padding[i] for i in range(len(kernel_size))]): - if is_same_padding(padding, stride, kernel_size, shape): - return padding - return [0] * len(shape) # avg_pool1d @@ -192,21 +320,22 @@ def test_torch_avg_pool3d( ) -# max_pool1d +# avg_pool1d @handle_frontend_test( - fn_tree="torch.nn.functional.max_pool1d", + fn_tree="torch.nn.functional.lp_pool1d", dtype_x_k_s=helpers.arrays_for_pooling( min_dims=3, max_dims=3, min_side=1, max_side=3, - only_explicit_padding=True, data_format="channel_first", ), + norm_type=helpers.ints(min_value=1, max_value=6), test_with_out=st.just(False), ) -def test_torch_max_pool1d( +def test_torch_lp_pool1d( dtype_x_k_s, + norm_type, *, test_flags, frontend, @@ -214,8 +343,8 @@ def test_torch_max_pool1d( fn_tree, on_device, ): - input_dtype, x, kernel_size, stride, padding = dtype_x_k_s - padding = (padding[0][0],) + input_dtype, x, kernel_size, stride, _ = dtype_x_k_s + helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -224,30 +353,29 @@ def test_torch_max_pool1d( fn_tree=fn_tree, on_device=on_device, input=x[0], - kernel_size=kernel_size, - stride=stride, - padding=padding, + norm_type=norm_type if norm_type > 0 else 1, + kernel_size=kernel_size[0], + stride=stride[0], + ceil_mode=False, ) -# max_pool2d +# avg_pool2d @handle_frontend_test( - 
fn_tree="torch.nn.functional.max_pool2d", - x_k_s_p=helpers.arrays_for_pooling( + fn_tree="torch.nn.functional.lp_pool2d", + dtype_x_k_s=helpers.arrays_for_pooling( min_dims=4, max_dims=4, min_side=1, max_side=4, - only_explicit_padding=True, - return_dilation=True, data_format="channel_first", ), + norm_type=helpers.ints(min_value=1, max_value=6), test_with_out=st.just(False), - ceil_mode=st.just(True), ) -def test_torch_max_pool2d( - x_k_s_p, - ceil_mode, +def test_torch_lp_pool2d( + dtype_x_k_s, + norm_type, *, test_flags, frontend, @@ -255,169 +383,37 @@ def test_torch_max_pool2d( fn_tree, on_device, ): - dtype, x, kernel, stride, pad, dilation = x_k_s_p - pad = (pad[0][0], pad[1][0]) - - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - frontend=frontend, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - kernel_size=kernel, - stride=stride, - padding=pad, - dilation=dilation, - ceil_mode=ceil_mode, - ) - - -# adaptive_avg_pool1d -@handle_frontend_test( - fn_tree="torch.nn.functional.adaptive_avg_pool1d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=3, - min_dim_size=5, - max_value=100, - min_value=-100, - ), - output_size=helpers.ints(min_value=1, max_value=10), - test_with_out=st.just(False), -) -def test_torch_adaptive_avg_pool1d( - *, - dtype_and_x, - output_size, - on_device, - frontend, - test_flags, - fn_tree, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - output_size=output_size, - atol=1e-2, - ) - - -# adaptive_avg_pool2d -@handle_frontend_test( - fn_tree="torch.nn.functional.adaptive_avg_pool2d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, - max_num_dims=4, - min_dim_size=5, - max_value=100, - min_value=-100, - ), - output_size=st.one_of( - st.tuples( - helpers.ints(min_value=1, max_value=10), - helpers.ints(min_value=1, max_value=10), - ), - helpers.ints(min_value=1, max_value=10), - ), - test_with_out=st.just(False), -) -def test_torch_adaptive_avg_pool2d( - *, - dtype_and_x, - output_size, - on_device, - frontend, - test_flags, - fn_tree, - backend_fw, -): - input_dtype, x = dtype_and_x + input_dtype, x, kernel_size, stride, _ = dtype_x_k_s helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - output_size=output_size, - atol=1e-2, - ) - - -# adaptive_max_pool2d -@handle_frontend_test( - fn_tree="torch.nn.functional.adaptive_max_pool2d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, - max_num_dims=4, - min_dim_size=5, - # Setting max and min value because this operation in paddle is not - # numerically stable - max_value=100, - min_value=-100, - ), - output_size=st.one_of( - st.tuples( - helpers.ints(min_value=1, max_value=10), - helpers.ints(min_value=1, max_value=10), - ), - helpers.ints(min_value=1, max_value=10), - ), - test_with_out=st.just(False), -) -def test_torch_adaptive_max_pool2d( - *, - dtype_and_x, - output_size, - on_device, - frontend, - test_flags, - fn_tree, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - 
backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - output_size=output_size, - atol=1e-2, + norm_type=norm_type if norm_type > 0 else 1, + kernel_size=kernel_size, + stride=stride[0], + ceil_mode=False, ) -# avg_pool1d +# max_pool1d @handle_frontend_test( - fn_tree="torch.nn.functional.lp_pool1d", + fn_tree="torch.nn.functional.max_pool1d", dtype_x_k_s=helpers.arrays_for_pooling( min_dims=3, max_dims=3, min_side=1, max_side=3, + only_explicit_padding=True, data_format="channel_first", ), - norm_type=helpers.ints(min_value=1, max_value=6), test_with_out=st.just(False), ) -def test_torch_lp_pool1d( +def test_torch_max_pool1d( dtype_x_k_s, - norm_type, *, test_flags, frontend, @@ -425,8 +421,8 @@ def test_torch_lp_pool1d( fn_tree, on_device, ): - input_dtype, x, kernel_size, stride, _ = dtype_x_k_s - + input_dtype, x, kernel_size, stride, padding = dtype_x_k_s + padding = (padding[0][0],) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -435,29 +431,30 @@ def test_torch_lp_pool1d( fn_tree=fn_tree, on_device=on_device, input=x[0], - norm_type=norm_type if norm_type > 0 else 1, - kernel_size=kernel_size[0], - stride=stride[0], - ceil_mode=False, + kernel_size=kernel_size, + stride=stride, + padding=padding, ) -# avg_pool2d +# max_pool2d @handle_frontend_test( - fn_tree="torch.nn.functional.lp_pool2d", - dtype_x_k_s=helpers.arrays_for_pooling( + fn_tree="torch.nn.functional.max_pool2d", + x_k_s_p=helpers.arrays_for_pooling( min_dims=4, max_dims=4, min_side=1, max_side=4, + only_explicit_padding=True, + return_dilation=True, data_format="channel_first", ), - norm_type=helpers.ints(min_value=1, max_value=6), test_with_out=st.just(False), + ceil_mode=st.just(True), ) -def test_torch_lp_pool2d( - dtype_x_k_s, - norm_type, +def test_torch_max_pool2d( + x_k_s_p, + ceil_mode, *, test_flags, frontend, @@ -465,17 +462,20 @@ def test_torch_lp_pool2d( fn_tree, on_device, ): - input_dtype, x, kernel_size, stride, _ = dtype_x_k_s + dtype, x, kernel, stride, pad, dilation = x_k_s_p + pad = (pad[0][0], pad[1][0]) + helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, backend_to_test=backend_fw, test_flags=test_flags, frontend=frontend, fn_tree=fn_tree, on_device=on_device, input=x[0], - norm_type=norm_type if norm_type > 0 else 1, - kernel_size=kernel_size, - stride=stride[0], - ceil_mode=False, + kernel_size=kernel, + stride=stride, + padding=pad, + dilation=dilation, + ceil_mode=ceil_mode, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_sparse_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_sparse_functions.py index ae8b3edcec6a3..24dfb8dc61ac4 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_sparse_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_sparse_functions.py @@ -8,6 +8,28 @@ inf = float("inf") +# --- Helpers --- # +# --------------- # + + +@st.composite +def get_dtype_num_classes(draw): + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=1, + min_value=1, + max_value=10, + max_num_dims=0, + ) + ) + input_dtype, x = dtype_and_x + print(max(x)) + num_classes = draw(st.integers(min_value=max(x) + 1, max_value=10)) + + return (num_classes, dtype_and_x) + + # embedding @handle_frontend_test( fn_tree="torch.nn.functional.embedding", @@ -47,24 
+69,6 @@ def test_torch_embedding( ) -@st.composite -def get_dtype_num_classes(draw): - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=1, - min_value=1, - max_value=10, - max_num_dims=0, - ) - ) - input_dtype, x = dtype_and_x - print(max(x)) - num_classes = draw(st.integers(min_value=max(x) + 1, max_value=10)) - - return (num_classes, dtype_and_x) - - # one_hot @handle_frontend_test( fn_tree="torch.nn.functional.one_hot", diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py index 2172ee14f7807..3e534c3852bb6 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_vision_functions.py @@ -9,75 +9,51 @@ ) -# pixel_shuffle -@handle_frontend_test( - fn_tree="torch.nn.functional.pixel_shuffle", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - min_num_dims=4, - max_num_dims=4, - min_dim_size=1, - ), - factor=helpers.ints(min_value=1), -) -def test_torch_pixel_shuffle( - *, - dtype_and_x, - factor, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - assume(ivy.shape(x[0])[1] % (factor**2) == 0) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - upscale_factor=factor, - ) +# --- Helpers --- # +# --------------- # -@handle_frontend_test( - fn_tree="torch.nn.functional.pixel_unshuffle", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - min_num_dims=4, - max_num_dims=4, - min_dim_size=1, - ), - factor=helpers.ints(min_value=1), -) -def test_torch_pixel_unshuffle( - *, - dtype_and_x, - factor, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - assume((ivy.shape(x[0])[2] % factor == 0) & (ivy.shape(x[0])[3] % factor == 0)) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - downscale_factor=factor, - ) +@st.composite +def _affine_grid_helper(draw): + align_corners = draw(st.booleans()) + dims = draw(st.integers(4, 5)) + if dims == 4: + size = draw( + st.tuples( + st.integers(1, 20), + st.integers(1, 20), + st.integers(2, 20), + st.integers(2, 20), + ) + ) + theta_dtype, theta = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=1, + shape=(size[0], 2, 3), + ) + ) + return theta_dtype, theta[0], size, align_corners + else: + size = draw( + st.tuples( + st.integers(1, 20), + st.integers(1, 20), + st.integers(2, 20), + st.integers(2, 20), + st.integers(2, 20), + ) + ) + theta_dtype, theta = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=1, + shape=(size[0], 3, 4), + ) + ) + return theta_dtype, theta[0], size, align_corners @st.composite @@ -136,20 +112,25 @@ def _pad_helper(draw): return dtype, input[0], padding, value, mode +# --- Main --- # +# ------------ # + + @handle_frontend_test( - fn_tree="torch.nn.functional.pad", - 
dtype_and_input_and_other=_pad_helper(), + fn_tree="torch.nn.functional.affine_grid", + dtype_and_input_and_other=_affine_grid_helper(), ) -def test_torch_pad( +def test_torch_affine_grid( *, dtype_and_input_and_other, on_device, + backend_fw, fn_tree, frontend, test_flags, - backend_fw, ): - dtype, input, padding, value, mode = dtype_and_input_and_other + dtype, theta, size, align_corners = dtype_and_input_and_other + helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, @@ -157,10 +138,9 @@ def test_torch_pad( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input, - pad=padding, - mode=mode, - value=value, + theta=theta, + size=size, + align_corners=align_corners, ) @@ -208,11 +188,10 @@ def test_torch_interpolate( @handle_frontend_test( - fn_tree="torch.nn.functional.upsample", - dtype_and_input_and_other=_interp_args(), - number_positional_args=st.just(2), + fn_tree="torch.nn.functional.pad", + dtype_and_input_and_other=_pad_helper(), ) -def test_torch_upsample( +def test_torch_pad( *, dtype_and_input_and_other, on_device, @@ -221,7 +200,45 @@ def test_torch_upsample( test_flags, backend_fw, ): - input_dtype, x, mode, size, align_corners = dtype_and_input_and_other + dtype, input, padding, value, mode = dtype_and_input_and_other + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input, + pad=padding, + mode=mode, + value=value, + ) + + +# pixel_shuffle +@handle_frontend_test( + fn_tree="torch.nn.functional.pixel_shuffle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + min_num_dims=4, + max_num_dims=4, + min_dim_size=1, + ), + factor=helpers.ints(min_value=1), +) +def test_torch_pixel_shuffle( + *, + dtype_and_x, + factor, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + assume(ivy.shape(x[0])[1] % (factor**2) == 0) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -230,18 +247,51 @@ def test_torch_upsample( fn_tree=fn_tree, on_device=on_device, input=x[0], - size=size, - mode=mode, - align_corners=align_corners, + upscale_factor=factor, ) @handle_frontend_test( - fn_tree="torch.nn.functional.upsample_nearest", - dtype_and_input_and_other=_interp_args(mode="nearest"), + fn_tree="torch.nn.functional.pixel_unshuffle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + min_num_dims=4, + max_num_dims=4, + min_dim_size=1, + ), + factor=helpers.ints(min_value=1), +) +def test_torch_pixel_unshuffle( + *, + dtype_and_x, + factor, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + assume((ivy.shape(x[0])[2] % factor == 0) & (ivy.shape(x[0])[3] % factor == 0)) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + downscale_factor=factor, + ) + + +@handle_frontend_test( + fn_tree="torch.nn.functional.upsample", + dtype_and_input_and_other=_interp_args(), number_positional_args=st.just(2), ) -def test_torch_upsample_nearest( +def test_torch_upsample( *, dtype_and_input_and_other, on_device, @@ -250,7 +300,7 @@ def test_torch_upsample_nearest( test_flags, backend_fw, ): - input_dtype, x, _, size, _ = dtype_and_input_and_other + 
input_dtype, x, mode, size, align_corners = dtype_and_input_and_other helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -260,6 +310,8 @@ def test_torch_upsample_nearest( on_device=on_device, input=x[0], size=size, + mode=mode, + align_corners=align_corners, ) @@ -290,72 +342,28 @@ def test_torch_upsample_bilinear( ) -@st.composite -def _affine_grid_helper(draw): - align_corners = draw(st.booleans()) - dims = draw(st.integers(4, 5)) - if dims == 4: - size = draw( - st.tuples( - st.integers(1, 20), - st.integers(1, 20), - st.integers(2, 20), - st.integers(2, 20), - ) - ) - theta_dtype, theta = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=1, - shape=(size[0], 2, 3), - ) - ) - return theta_dtype, theta[0], size, align_corners - else: - size = draw( - st.tuples( - st.integers(1, 20), - st.integers(1, 20), - st.integers(2, 20), - st.integers(2, 20), - st.integers(2, 20), - ) - ) - theta_dtype, theta = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=1, - shape=(size[0], 3, 4), - ) - ) - return theta_dtype, theta[0], size, align_corners - - @handle_frontend_test( - fn_tree="torch.nn.functional.affine_grid", - dtype_and_input_and_other=_affine_grid_helper(), + fn_tree="torch.nn.functional.upsample_nearest", + dtype_and_input_and_other=_interp_args(mode="nearest"), + number_positional_args=st.just(2), ) -def test_torch_affine_grid( +def test_torch_upsample_nearest( *, dtype_and_input_and_other, on_device, - backend_fw, fn_tree, frontend, test_flags, + backend_fw, ): - dtype, theta, size, align_corners = dtype_and_input_and_other - + input_dtype, x, _, size, _ = dtype_and_input_and_other helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - theta=theta, + input=x[0], size=size, - align_corners=align_corners, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py index 774fea99453ce..a502749812473 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py @@ -12,22 +12,108 @@ ) -# add +# --- Helpers --- # +# --------------- # + + +# float_power_helper +@st.composite +def _float_power_helper(draw, *, available_dtypes=None): + if available_dtypes is None: + available_dtypes = helpers.get_dtypes("numeric") + dtype1, x1 = draw( + helpers.dtype_and_values( + available_dtypes=available_dtypes, + small_abs_safety_factor=16, + large_abs_safety_factor=16, + safety_factor_scale="log", + ) + ) + dtype2 = draw(helpers.get_dtypes("numeric")) + if ivy.is_int_dtype(dtype2[0]): + min_value = 0 + else: + min_value = -10 + dtype2, x2 = draw( + helpers.dtype_and_values( + min_value=min_value, + max_value=10, + dtype=dtype2, + ) + ) + return (dtype1[0], dtype2[0]), (x1[0], x2[0]) + + +@st.composite +def _get_clip_inputs(draw): + shape = draw( + helpers.get_shape( + min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ) + ) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=shape, + ) + ) + min = draw(st.booleans()) + if min: + max = draw(st.booleans()) + min = draw( + helpers.array_values( + dtype=x_dtype[0], shape=shape, min_value=0, max_value=25 + ) + ) + max = ( + 
draw( + helpers.array_values( + dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 + ) + ) + if max + else None + ) + else: + min = None + max = draw( + helpers.array_values( + dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 + ) + ) + return x_dtype, x, min, max + + +@st.composite +def _masked_fill_helper(draw): + cond, xs, dtypes = draw(_broadcastable_trio()) + if ivy.is_uint_dtype(dtypes[0]): + fill_value = draw(helpers.ints(min_value=0, max_value=5)) + elif ivy.is_int_dtype(dtypes[0]): + fill_value = draw(helpers.ints(min_value=-5, max_value=5)) + else: + fill_value = draw(helpers.floats(min_value=-5, max_value=5)) + return dtypes[0], xs[0], cond, fill_value + + +# --- Main --- # +# ------------ # + + +# abs @handle_frontend_test( - fn_tree="torch.add", + fn_tree="torch.abs", + aliases=["torch.absolute"], dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("numeric", full=False), large_abs_safety_factor=2.5, small_abs_safety_factor=2.5, safety_factor_scale="log", ), - alpha=st.integers(min_value=1, max_value=5), ) -def test_torch_add( +def test_torch_abs( *, dtype_and_x, - alpha, on_device, fn_tree, frontend, @@ -42,21 +128,16 @@ def test_torch_add( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, input=x[0], - other=x[1], - alpha=alpha, ) -# tan +# absolute @handle_frontend_test( - fn_tree="torch.tan", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + fn_tree="torch.absolute", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_torch_tan( +def test_torch_absolute( *, dtype_and_x, on_device, @@ -65,7 +146,7 @@ def test_torch_tan( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -73,19 +154,19 @@ def test_torch_tan( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + input=input[0], ) -# atan +# acos @handle_frontend_test( - fn_tree="torch.atan", - aliases=["torch.arctan"], + fn_tree="torch.acos", + aliases=["torch.arccos"], dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_atan( +def test_torch_acos( *, dtype_and_x, on_device, @@ -106,14 +187,14 @@ def test_torch_atan( ) -# tanh +# acosh @handle_frontend_test( - fn_tree="torch.tanh", + fn_tree="torch.acosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tanh( +def test_torch_acosh( *, dtype_and_x, on_device, @@ -134,20 +215,22 @@ def test_torch_tanh( ) -# abs +# add @handle_frontend_test( - fn_tree="torch.abs", - aliases=["torch.absolute"], + fn_tree="torch.add", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric", full=False), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, large_abs_safety_factor=2.5, small_abs_safety_factor=2.5, safety_factor_scale="log", ), + alpha=st.integers(min_value=1, max_value=5), ) -def test_torch_abs( +def test_torch_add( *, dtype_and_x, + alpha, on_device, fn_tree, frontend, @@ -162,52 +245,69 @@ def test_torch_abs( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-03, input=x[0], + other=x[1], + alpha=alpha, ) -# cos +# addcdiv @handle_frontend_test( - fn_tree="torch.cos", + fn_tree="torch.addcdiv", dtype_and_x=helpers.dtype_and_values( 
available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), + value=st.floats(min_value=-100, max_value=100), ) -def test_torch_cos( - *, +def test_torch_addcdiv( dtype_and_x, - on_device, - fn_tree, + value, frontend, test_flags, + fn_tree, backend_fw, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[2], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, input=x[0], + tensor1=x[1], + tensor2=x[2], + value=value, + atol=1e-03, + out=None, ) -# sin +# addcmul @handle_frontend_test( - fn_tree="torch.sin", + fn_tree="torch.addcmul", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + min_value=-1e4, + max_value=1e4, + shared_dtype=True, ), + value=st.floats(min_value=-10, max_value=10), ) -def test_torch_sin( - *, +def test_torch_addcmul( dtype_and_x, - on_device, - fn_tree, + value, frontend, test_flags, + fn_tree, backend_fw, ): input_dtype, x = dtype_and_x @@ -217,29 +317,32 @@ def test_torch_sin( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, + atol=1e-2, input=x[0], + tensor1=x[1], + tensor2=x[2], + value=value, + out=None, ) -# acos +# angle @handle_frontend_test( - fn_tree="torch.acos", - aliases=["torch.arccos"], - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="torch.angle", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=["float64", "complex64", "complex128"], ), ) -def test_torch_acos( +def test_torch_angle( *, - dtype_and_x, - on_device, - fn_tree, + dtype_and_input, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -251,14 +354,14 @@ def test_torch_acos( ) -# sinh +# arccos @handle_frontend_test( - fn_tree="torch.sinh", + fn_tree="torch.arccos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_sinh( +def test_torch_arccos( *, dtype_and_x, on_device, @@ -279,14 +382,14 @@ def test_torch_sinh( ) -# acosh +# arccosh @handle_frontend_test( - fn_tree="torch.acosh", + fn_tree="torch.arccosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_acosh( +def test_torch_arccosh( *, dtype_and_x, on_device, @@ -307,14 +410,14 @@ def test_torch_acosh( ) -# arccos +# arcsin @handle_frontend_test( - fn_tree="torch.arccos", + fn_tree="torch.arcsin", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_arccos( +def test_torch_arcsin( *, dtype_and_x, on_device, @@ -335,30 +438,23 @@ def test_torch_arccos( ) -# subtract +# arctan @handle_frontend_test( - fn_tree="torch.subtract", - aliases=["torch.sub"], - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", + fn_tree="torch.arctan", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), - alpha=st.integers(min_value=1, max_value=5), ) -def test_torch_subtract( +def test_torch_arctan( *, - dtype_and_x, - alpha, - on_device, - fn_tree, + dtype_and_input, 
frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -367,19 +463,18 @@ def test_torch_subtract( fn_tree=fn_tree, on_device=on_device, input=x[0], - other=x[1], - alpha=alpha, ) -# exp +# arctan2 @handle_frontend_test( - fn_tree="torch.exp", + fn_tree="torch.arctan2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), ) -def test_torch_exp( +def test_torch_arctan2( *, dtype_and_x, on_device, @@ -397,18 +492,18 @@ def test_torch_exp( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) -# asin +# arctanh @handle_frontend_test( - fn_tree="torch.asin", - aliases=["torch.arcsin"], + fn_tree="torch.arctanh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_asin( +def test_torch_arctanh( *, dtype_and_x, on_device, @@ -429,14 +524,15 @@ def test_torch_asin( ) -# arccosh +# asin @handle_frontend_test( - fn_tree="torch.arccosh", + fn_tree="torch.asin", + aliases=["torch.arcsin"], dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_arccosh( +def test_torch_asin( *, dtype_and_x, on_device, @@ -457,14 +553,15 @@ def test_torch_arccosh( ) -# arcsin +# asinh @handle_frontend_test( - fn_tree="torch.arcsin", + fn_tree="torch.asinh", + aliases=["torch.arcsinh"], dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_arcsin( +def test_torch_asinh( *, dtype_and_x, on_device, @@ -485,15 +582,15 @@ def test_torch_arcsin( ) -# asinh +# atan @handle_frontend_test( - fn_tree="torch.asinh", - aliases=["torch.arcsinh"], + fn_tree="torch.atan", + aliases=["torch.arctan"], dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_asinh( +def test_torch_atan( *, dtype_and_x, on_device, @@ -514,14 +611,16 @@ def test_torch_asinh( ) -# cosh +# atan2 @handle_frontend_test( - fn_tree="torch.cosh", + fn_tree="torch.atan2", + aliases=["torch.arctan2"], dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), ) -def test_torch_cosh( +def test_torch_atan2( *, dtype_and_x, on_device, @@ -539,6 +638,7 @@ def test_torch_cosh( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) @@ -571,14 +671,15 @@ def test_torch_atanh( ) -# arctanh +# bitwise_and @handle_frontend_test( - fn_tree="torch.arctanh", + fn_tree="torch.bitwise_and", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=st.just(("bool",)) | helpers.get_dtypes("integer"), + num_arrays=2, ), ) -def test_torch_arctanh( +def test_torch_bitwise_and( *, dtype_and_x, on_device, @@ -596,17 +697,19 @@ def test_torch_arctanh( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) -# log2 @handle_frontend_test( - fn_tree="torch.log2", + fn_tree="torch.bitwise_left_shift", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + array_api_dtypes=True, ), ) -def test_torch_log2( +def test_torch_bitwise_left_shift( *, dtype_and_x, on_device, @@ -615,7 +718,12 @@ def test_torch_log2( test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, x = dtype_and_x + # negative shifts will throw an exception + # shifts >= dtype 
witdth produce backend-defined behavior + x[1] = np.asarray( + np.clip(x[1], 0, np.iinfo(input_dtype[1]).bits - 1), dtype=input_dtype[1] + ) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -623,18 +731,19 @@ def test_torch_log2( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], + other=x[1], ) -# square @handle_frontend_test( - fn_tree="torch.square", + fn_tree="torch.bitwise_not", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + num_arrays=1, ), ) -def test_torch_square( +def test_torch_bitwise_not( *, dtype_and_x, on_device, @@ -643,7 +752,7 @@ def test_torch_square( test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -651,20 +760,18 @@ def test_torch_square( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], ) -# atan2 @handle_frontend_test( - fn_tree="torch.atan2", - aliases=["torch.arctan2"], + fn_tree="torch.bitwise_xor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=st.just(("bool",)) | helpers.get_dtypes("integer"), num_arrays=2, ), ) -def test_torch_atan2( +def test_torch_bitwise_or( *, dtype_and_x, on_device, @@ -686,15 +793,15 @@ def test_torch_atan2( ) -# arctan2 @handle_frontend_test( - fn_tree="torch.arctan2", + fn_tree="torch.bitwise_right_shift", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, + array_api_dtypes=True, ), ) -def test_torch_arctan2( +def test_torch_bitwise_right_shift( *, dtype_and_x, on_device, @@ -704,6 +811,11 @@ def test_torch_arctan2( backend_fw, ): input_dtype, x = dtype_and_x + # negative shifts will throw an exception + # shifts >= dtype witdth produce backend-defined behavior + x[1] = np.asarray( + np.clip(x[1], 0, np.iinfo(input_dtype[1]).bits - 1), dtype=input_dtype[1] + ) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -716,15 +828,14 @@ def test_torch_arctan2( ) -# negative @handle_frontend_test( - fn_tree="torch.negative", + fn_tree="torch.bitwise_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=st.just(("bool",)) | helpers.get_dtypes("integer"), num_arrays=2, ), ) -def test_torch_negative( +def test_torch_bitwise_xor( *, dtype_and_x, on_device, @@ -742,18 +853,18 @@ def test_torch_negative( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) -# bitwise_and +# ceil @handle_frontend_test( - fn_tree="torch.bitwise_and", + fn_tree="torch.ceil", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.just(("bool",)) | helpers.get_dtypes("integer"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_bitwise_and( +def test_torch_ceil( *, dtype_and_x, on_device, @@ -771,84 +882,83 @@ def test_torch_bitwise_and( fn_tree=fn_tree, on_device=on_device, input=x[0], - other=x[1], ) +# clamp @handle_frontend_test( - fn_tree="torch.bitwise_not", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - num_arrays=1, - ), + fn_tree="torch.clamp", + aliases=["torch.clip"], + input_and_ranges=_get_clip_inputs(), ) -def 
test_torch_bitwise_not( +def test_torch_clamp( *, - dtype_and_x, + input_and_ranges, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + x_dtype, x, min, max = input_and_ranges helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=x_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], + min=min, + max=max, ) +# clip @handle_frontend_test( - fn_tree="torch.bitwise_or", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.just(("bool",)) | helpers.get_dtypes("integer"), - num_arrays=2, - ), + fn_tree="torch.clip", + input_and_ranges=_get_clip_inputs(), ) -def test_torch_bitwise_xor( +def test_torch_clip( *, - dtype_and_x, + input_and_ranges, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + x_dtype, x, min, max = input_and_ranges helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=x_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - other=x[1], + min=min, + max=max, ) +# conj_physical @handle_frontend_test( - fn_tree="torch.bitwise_xor", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.just(("bool",)) | helpers.get_dtypes("integer"), - num_arrays=2, + fn_tree="torch.conj_physical", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch_bitwise_or( +def test_torch_conj_physical( *, - dtype_and_x, - on_device, - fn_tree, + dtype_and_input, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -857,19 +967,22 @@ def test_torch_bitwise_or( fn_tree=fn_tree, on_device=on_device, input=x[0], - other=x[1], ) +# copysign @handle_frontend_test( - fn_tree="torch.bitwise_left_shift", + fn_tree="torch.copysign", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - array_api_dtypes=True, + min_num_dims=1, + min_value=-100, + max_value=100, + shared_dtype=True, ), ) -def test_torch_bitwise_left_shift( +def test_torch_copysign( *, dtype_and_x, on_device, @@ -879,11 +992,6 @@ def test_torch_bitwise_left_shift( backend_fw, ): input_dtype, x = dtype_and_x - # negative shifts will throw an exception - # shifts >= dtype witdth produce backend-defined behavior - x[1] = np.asarray( - np.clip(x[1], 0, np.iinfo(input_dtype[1]).bits - 1), dtype=input_dtype[1] - ) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -891,20 +999,20 @@ def test_torch_bitwise_left_shift( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-03, input=x[0], other=x[1], ) +# cos @handle_frontend_test( - fn_tree="torch.bitwise_right_shift", + fn_tree="torch.cos", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - array_api_dtypes=True, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_bitwise_right_shift( +def test_torch_cos( *, dtype_and_x, on_device, @@ -914,11 +1022,6 @@ def test_torch_bitwise_right_shift( backend_fw, ): input_dtype, x = dtype_and_x - # negative shifts will throw an exception - # shifts >= dtype witdth produce backend-defined behavior - x[1] = np.asarray( - np.clip(x[1], 0, 
np.iinfo(input_dtype[1]).bits - 1), dtype=input_dtype[1] - ) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -927,18 +1030,17 @@ def test_torch_bitwise_right_shift( fn_tree=fn_tree, on_device=on_device, input=x[0], - other=x[1], ) -# log10 +# cosh @handle_frontend_test( - fn_tree="torch.log10", + fn_tree="torch.cosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_log10( +def test_torch_cosh( *, dtype_and_x, on_device, @@ -959,15 +1061,16 @@ def test_torch_log10( ) -# trunc +# deg2rad @handle_frontend_test( - fn_tree="torch.trunc", - aliases=["torch.fix"], + fn_tree="torch.deg2rad", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=-1000, + max_value=1000, ), ) -def test_torch_trunc( +def test_torch_deg2rad( *, dtype_and_x, on_device, @@ -988,21 +1091,34 @@ def test_torch_trunc( ) -# sqrt +# div @handle_frontend_test( - fn_tree="torch.sqrt", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="torch.div", + aliases=["torch.divide"], + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ), + rounding_mode=st.sampled_from(["floor", "trunc"]) | st.none(), ) -def test_torch_sqrt( +def test_torch_div( *, dtype_and_x, + rounding_mode, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) + + # Absolute tolerance is 1, + # due to flooring can cause absolute error of 1 due to precision helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1010,18 +1126,21 @@ def test_torch_sqrt( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + atol=1, + input=x[0], + other=x[1], + rounding_mode=rounding_mode, ) -# real +# erf @handle_frontend_test( - fn_tree="torch.real", + fn_tree="torch.erf", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_real( +def test_torch_erf( *, dtype_and_x, on_device, @@ -1030,7 +1149,7 @@ def test_torch_real( test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1038,18 +1157,18 @@ def test_torch_real( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], ) -# sign +# erfc @handle_frontend_test( - fn_tree="torch.sign", + fn_tree="torch.erfc", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_sign( +def test_torch_erfc( *, dtype_and_x, on_device, @@ -1070,12 +1189,14 @@ def test_torch_sign( ) -# absolute +# exp @handle_frontend_test( - fn_tree="torch.absolute", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="torch.exp", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_torch_absolute( +def test_torch_exp( *, dtype_and_x, on_device, @@ -1084,7 +1205,7 @@ def test_torch_absolute( test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( 
input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1092,18 +1213,18 @@ def test_torch_absolute( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], ) -# logical not +# exp2 @handle_frontend_test( - fn_tree="torch.logical_not", + fn_tree="torch.exp2", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=1 + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_logical_not( +def test_torch_exp2( *, dtype_and_x, on_device, @@ -1124,14 +1245,14 @@ def test_torch_logical_not( ) -# logical and +# expm1 @handle_frontend_test( - fn_tree="torch.logical_and", + fn_tree="torch.expm1", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_logical_and( +def test_torch_expm1( *, dtype_and_x, on_device, @@ -1149,27 +1270,32 @@ def test_torch_logical_and( fn_tree=fn_tree, on_device=on_device, input=x[0], - other=x[1], ) -# logical or +# flipud @handle_frontend_test( - fn_tree="torch.logical_or", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 + fn_tree="torch.flipud", + dtype_and_m=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), ) -def test_torch_logical_or( +def test_torch_flipud( *, - dtype_and_x, + dtype_and_m, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, m = dtype_and_m helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1177,20 +1303,15 @@ def test_torch_logical_or( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - other=x[1], + input=m[0], ) -# logical xor @handle_frontend_test( - fn_tree="torch.logical_xor", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 - ), + fn_tree="torch.float_power", + dtype_and_x=_float_power_helper(), ) -def test_torch_logical_xor( - *, +def test_torch_float_power( dtype_and_x, on_device, fn_tree, @@ -1199,6 +1320,8 @@ def test_torch_logical_xor( backend_fw, ): input_dtype, x = dtype_and_x + # Making sure zero to the power of negative doesn't occur + assume(not np.any(np.isclose(x[0], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1206,19 +1329,20 @@ def test_torch_logical_xor( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-03, input=x[0], - other=x[1], + exponent=x[1], ) -# ceil +# floor @handle_frontend_test( - fn_tree="torch.ceil", + fn_tree="torch.floor", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_ceil( +def test_torch_floor( *, dtype_and_x, on_device, @@ -1236,147 +1360,122 @@ def test_torch_ceil( fn_tree=fn_tree, on_device=on_device, input=x[0], + out=None, ) -# round +# floor_divide @handle_frontend_test( - fn_tree="torch.round", + fn_tree="torch.floor_divide", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), - decimals=st.integers(min_value=0, max_value=5), ) -def test_torch_round( +def test_torch_floor_divide( + *, dtype_and_x, - decimals, - frontend, test_flags, + on_device, fn_tree, + frontend, backend_fw, ): 
input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, + on_device=on_device, + atol=1, input=x[0], - decimals=decimals, - ) - - -@st.composite -def _get_clip_inputs(draw): - shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ) - ) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=shape, - ) + other=x[1], + out=None, ) - min = draw(st.booleans()) - if min: - max = draw(st.booleans()) - min = draw( - helpers.array_values( - dtype=x_dtype[0], shape=shape, min_value=0, max_value=25 - ) - ) - max = ( - draw( - helpers.array_values( - dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 - ) - ) - if max - else None - ) - else: - min = None - max = draw( - helpers.array_values( - dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 - ) - ) - return x_dtype, x, min, max -# clamp +# fmod @handle_frontend_test( - fn_tree="torch.clamp", - aliases=["torch.clip"], - input_and_ranges=_get_clip_inputs(), + fn_tree="torch.fmod", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_num_dims=1, + min_value=-100, + max_value=100, + shared_dtype=True, + ), ) -def test_torch_clamp( +def test_torch_fmod( *, - input_and_ranges, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - x_dtype, x, min, max = input_and_ranges + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - min=min, - max=max, + atol=1e-03, + x1=x[0], + x2=x[1], ) -# clip +# frac @handle_frontend_test( - fn_tree="torch.clip", - input_and_ranges=_get_clip_inputs(), + fn_tree="torch.frac", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_torch_clip( +def test_torch_frac( *, - input_and_ranges, + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - x_dtype, x, min, max = input_and_ranges + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=x_dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=x[0], - min=min, - max=max, ) -# mul @handle_frontend_test( - fn_tree="torch.mul", - aliases=["torch.multiply"], + fn_tree="torch.frexp", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + shared_dtype=True, + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=1, ), ) -def test_torch_mul( +def test_torch_frexp( dtype_and_x, on_device, fn_tree, @@ -1384,7 +1483,7 @@ def test_torch_mul( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1392,67 +1491,72 @@ def test_torch_mul( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - input=x[0], - other=x[1], + input=input[0], ) -# div +# gradient @handle_frontend_test( - fn_tree="torch.div", - aliases=["torch.divide"], - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", + fn_tree="torch.gradient", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + force_int_axis=True, + min_num_dims=1, + max_num_dims=3, + min_dim_size=2, + max_dim_size=4, + valid_axis=True, ), - rounding_mode=st.sampled_from(["floor", "trunc"]) | st.none(), + spacing=helpers.ints( + min_value=-3, + max_value=3, + ), + test_with_out=st.just(False), ) -def test_torch_div( +def test_torch_gradient( *, - dtype_and_x, - rounding_mode, + dtype_input_axis, + spacing, + test_flags, on_device, fn_tree, - frontend, - test_flags, backend_fw, + frontend, ): - input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) - - # Absolute tolerance is 1, - # due to flooring can cause absolute error of 1 due to precision + input_dtype, x, dim = dtype_input_axis helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - frontend=frontend, test_flags=test_flags, + frontend=frontend, fn_tree=fn_tree, on_device=on_device, - atol=1, input=x[0], - other=x[1], - rounding_mode=rounding_mode, + spacing=spacing, + dim=dim, ) -# reciprocal +# hypot @handle_frontend_test( - fn_tree="torch.reciprocal", + fn_tree="torch.hypot", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=1, + num_arrays=2, + shared_dtype=True, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, ), ) -def test_torch_reciprocal( +def test_torch_hypot( + *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): input_dtype, x = dtype_and_x @@ -1462,22 +1566,21 @@ def test_torch_reciprocal( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, + atol=1e-2, input=x[0], + other=x[1], ) -# remainder +# i0 @handle_frontend_test( - fn_tree="torch.remainder", + fn_tree="torch.i0", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("float"), num_arrays=1 ), ) -def test_torch_remainder( +def test_torch_i0( *, dtype_and_x, on_device, @@ -1487,7 +1590,6 @@ def test_torch_remainder( backend_fw, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1495,35 +1597,32 @@ def test_torch_remainder( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1, + atol=1e-03, input=x[0], - other=x[1], ) -# flipud +# igamma @handle_frontend_test( - fn_tree="torch.flipud", - dtype_and_m=helpers.dtype_and_values( + fn_tree="torch.igamma", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-100, + num_arrays=2, + shared_dtype=True, + min_value=2, max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, ), + test_with_out=st.just(False), ) -def test_torch_flipud( - *, - dtype_and_m, +def test_torch_igamma( + dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw, ): - input_dtype, m = dtype_and_m + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1531,20 +1630,20 @@ def test_torch_flipud( test_flags=test_flags, fn_tree=fn_tree, 
on_device=on_device, - input=m[0], + rtol=1e-04, + input=x[0], + other=x[1], ) -# deg2rad +# imag @handle_frontend_test( - fn_tree="torch.deg2rad", + fn_tree="torch.imag", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-1000, - max_value=1000, + available_dtypes=helpers.get_dtypes("complex"), ), ) -def test_torch_deg2rad( +def test_torch_imag( *, dtype_and_x, on_device, @@ -1553,7 +1652,7 @@ def test_torch_deg2rad( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1561,23 +1660,22 @@ def test_torch_deg2rad( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + input=input[0], ) -# true_divide +# ldexp @handle_frontend_test( - fn_tree="torch.true_divide", + fn_tree="torch.ldexp", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_true_divide( - *, +def test_torch_ldexp( dtype_and_x, on_device, fn_tree, @@ -1586,7 +1684,6 @@ def test_torch_true_divide( backend_fw, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1594,28 +1691,37 @@ def test_torch_true_divide( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-03, input=x[0], other=x[1], ) -# floor +# lerp @handle_frontend_test( - fn_tree="torch.floor", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + fn_tree="torch.lerp", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", mixed_fn_compos=False), + num_arrays=3, + shared_dtype=True, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + min_value=-1e3, + max_value=1e3, ), ) -def test_torch_floor( +def test_torch_lerp( *, - dtype_and_x, - on_device, - fn_tree, + dtype_and_input, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, inputs = dtype_and_input + start, end, weight = inputs helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1623,61 +1729,53 @@ def test_torch_floor( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], - out=None, + input=start, + end=end, + weight=weight, ) -# floor_divide @handle_frontend_test( - fn_tree="torch.floor_divide", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.lgamma", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", ), ) -def test_torch_floor_divide( +def test_torch_lgamma( *, - dtype_and_x, + dtype_and_input, + frontend, test_flags, - on_device, fn_tree, - frontend, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) + input_dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1, - input=x[0], - other=x[1], - out=None, + input=input[0], ) -# log1p 
+# log @handle_frontend_test( - fn_tree="torch.log1p", + fn_tree="torch.log", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-1e4, - max_value=1e4, ), ) -def test_torch_log1p( +def test_torch_log( + *, dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): input_dtype, x = dtype_and_x @@ -1687,64 +1785,50 @@ def test_torch_log1p( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, input=x[0], - out=None, ) -# addcdiv +# log10 @handle_frontend_test( - fn_tree="torch.addcdiv", + fn_tree="torch.log10", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, ), - value=st.floats(min_value=-100, max_value=100), ) -def test_torch_addcdiv( +def test_torch_log10( + *, dtype_and_x, - value, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[2], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, + on_device=on_device, input=x[0], - tensor1=x[1], - tensor2=x[2], - value=value, - atol=1e-03, - out=None, ) -# addcmul +# log1p @handle_frontend_test( - fn_tree="torch.addcmul", + fn_tree="torch.log1p", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, min_value=-1e4, max_value=1e4, - shared_dtype=True, ), - value=st.floats(min_value=-10, max_value=10), ) -def test_torch_addcmul( +def test_torch_log1p( dtype_and_x, - value, frontend, test_flags, fn_tree, @@ -1757,26 +1841,20 @@ def test_torch_addcmul( frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - atol=1e-2, input=x[0], - tensor1=x[1], - tensor2=x[2], - value=value, out=None, ) +# log2 @handle_frontend_test( - fn_tree="torch.pow", + fn_tree="torch.log2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", ), ) -def test_torch_pow( +def test_torch_log2( + *, dtype_and_x, on_device, fn_tree, @@ -1784,7 +1862,7 @@ def test_torch_pow( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1792,45 +1870,24 @@ def test_torch_pow( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - input=x[0], - exponent=x[1], - ) - - -# float_power_helper -@st.composite -def _float_power_helper(draw, *, available_dtypes=None): - if available_dtypes is None: - available_dtypes = helpers.get_dtypes("numeric") - dtype1, x1 = draw( - helpers.dtype_and_values( - available_dtypes=available_dtypes, - small_abs_safety_factor=16, - large_abs_safety_factor=16, - safety_factor_scale="log", - ) - ) - dtype2 = draw(helpers.get_dtypes("numeric")) - if ivy.is_int_dtype(dtype2[0]): - min_value = 0 - else: - min_value = -10 - dtype2, x2 = draw( - helpers.dtype_and_values( - min_value=min_value, - max_value=10, - dtype=dtype2, - ) + input=input[0], ) - return (dtype1[0], dtype2[0]), (x1[0], x2[0]) +# logaddexp @handle_frontend_test( - fn_tree="torch.float_power", - dtype_and_x=_float_power_helper(), + fn_tree="torch.logaddexp", + dtype_and_x=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_num_dims=1, + min_value=-100, + max_value=100, + shared_dtype=True, + ), ) -def test_torch_float_power( +def test_torch_logaddexp( + *, dtype_and_x, on_device, fn_tree, @@ -1839,8 +1896,6 @@ def test_torch_float_power( backend_fw, ): input_dtype, x = dtype_and_x - # Making sure zero to the power of negative doesn't occur - assume(not np.any(np.isclose(x[0], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1848,15 +1903,15 @@ def test_torch_float_power( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - input=x[0], - exponent=x[1], + atol=1e-03, + x1=x[0], + x2=x[1], ) -# logaddexp +# logaddexp2 @handle_frontend_test( - fn_tree="torch.logaddexp", + fn_tree="torch.logaddexp2", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, @@ -1866,7 +1921,7 @@ def test_torch_float_power( shared_dtype=True, ), ) -def test_torch_logaddexp( +def test_torch_logaddexp2( *, dtype_and_x, on_device, @@ -1883,20 +1938,20 @@ def test_torch_logaddexp( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-03, + atol=1e-02, x1=x[0], x2=x[1], ) -# exp2 +# logical and @handle_frontend_test( - fn_tree="torch.exp2", + fn_tree="torch.logical_and", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 ), ) -def test_torch_exp2( +def test_torch_logical_and( *, dtype_and_x, on_device, @@ -1914,17 +1969,18 @@ def test_torch_exp2( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) -# log +# logical not @handle_frontend_test( - fn_tree="torch.log", + fn_tree="torch.logical_not", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=1 ), ) -def test_torch_log( +def test_torch_logical_not( *, dtype_and_x, on_device, @@ -1945,14 +2001,14 @@ def test_torch_log( ) -# rsqrt +# logical or @handle_frontend_test( - fn_tree="torch.rsqrt", + fn_tree="torch.logical_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 ), ) -def test_torch_rsqrt( +def test_torch_logical_or( *, dtype_and_x, on_device, @@ -1970,17 +2026,18 @@ def test_torch_rsqrt( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) -# expm1 +# logical xor @handle_frontend_test( - fn_tree="torch.expm1", + fn_tree="torch.logical_xor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 ), ) -def test_torch_expm1( +def test_torch_logical_xor( *, dtype_and_x, on_device, @@ -1998,31 +2055,34 @@ def test_torch_expm1( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) -# logaddexp2 @handle_frontend_test( - fn_tree="torch.logaddexp2", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.logit", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, min_num_dims=1, - min_value=-100, - max_value=100, - shared_dtype=True, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + min_value=-10, + max_value=10, ), + eps=st.sampled_from([1e-05, -1e-05, None]), ) -def test_torch_logaddexp2( +def test_torch_logit( *, - dtype_and_x, - on_device, - fn_tree, + dtype_and_input, + eps, frontend, test_flags, + fn_tree, backend_fw, 
+ on_device, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2030,21 +2090,48 @@ def test_torch_logaddexp2( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-02, - x1=x[0], - x2=x[1], + input=input[0], + eps=eps, + out=None, ) -# i0 +# masked_fill @handle_frontend_test( - fn_tree="torch.i0", + fn_tree="torch.masked_fill", + x_mask_val=_masked_fill_helper(), +) +def test_torch_masked_fill( + *, x_mask_val, on_device, fn_tree, frontend, test_flags, backend_fw +): + dtype, x, mask, val = x_mask_val + helpers.test_frontend_function( + input_dtypes=[dtype], + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-03, + input=x, + mask=mask, + value=val, + ) + + +# mul +@handle_frontend_test( + fn_tree="torch.mul", + aliases=["torch.multiply"], dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), num_arrays=1 + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_i0( - *, +def test_torch_mul( dtype_and_x, on_device, fn_tree, @@ -2060,27 +2147,34 @@ def test_torch_i0( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-03, + rtol=1e-03, input=x[0], + other=x[1], ) -# rad2deg @handle_frontend_test( - fn_tree="torch.rad2deg", + fn_tree="torch.nan_to_num", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - max_dim_size=3, - max_num_dims=3, - min_dim_size=1, min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=True, + allow_inf=True, ), + nan=st.floats(min_value=-100.0, max_value=100.0), + posinf=st.just(None) | st.floats(min_value=5e100, max_value=5e100), + neginf=st.just(None) | st.floats(min_value=-5e100, max_value=-5e100), + test_with_out=st.just(False), ) -def test_torch_rad2deg( +def test_torch_nan_to_num( *, dtype_and_x, + nan, + posinf, + neginf, on_device, fn_tree, frontend, @@ -2096,18 +2190,21 @@ def test_torch_rad2deg( fn_tree=fn_tree, on_device=on_device, input=x[0], + nan=nan, + posinf=posinf, + neginf=neginf, ) # negative @handle_frontend_test( - fn_tree="torch.positive", + fn_tree="torch.negative", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, ), ) -def test_torch_positive( +def test_torch_negative( *, dtype_and_x, on_device, @@ -2128,23 +2225,25 @@ def test_torch_positive( ) -# frac +# nextafter @handle_frontend_test( - fn_tree="torch.frac", - dtype_and_x=helpers.dtype_and_values( + fn_tree="torch.nextafter", + dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, ), ) -def test_torch_frac( +def test_torch_nextafter( *, - dtype_and_x, - on_device, - fn_tree, + dtype_and_input, frontend, test_flags, + fn_tree, backend_fw, + on_device, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2153,22 +2252,19 @@ def test_torch_frac( fn_tree=fn_tree, on_device=on_device, input=x[0], + other=x[1], ) -# xlogy +# negative @handle_frontend_test( - fn_tree="torch.xlogy", + fn_tree="torch.positive", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), 
num_arrays=2, - min_num_dims=1, - min_value=-100, - max_value=100, - shared_dtype=True, ), ) -def test_torch_xlogy( +def test_torch_positive( *, dtype_and_x, on_device, @@ -2185,26 +2281,21 @@ def test_torch_xlogy( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-03, input=x[0], - other=x[1], ) -# copysign @handle_frontend_test( - fn_tree="torch.copysign", + fn_tree="torch.pow", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - min_num_dims=1, - min_value=-100, - max_value=100, - shared_dtype=True, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), ) -def test_torch_copysign( - *, +def test_torch_pow( dtype_and_x, on_device, fn_tree, @@ -2220,20 +2311,26 @@ def test_torch_copysign( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-03, + rtol=1e-03, input=x[0], - other=x[1], + exponent=x[1], ) -# sinc +# rad2deg @handle_frontend_test( - fn_tree="torch.sinc", + fn_tree="torch.rad2deg", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + max_dim_size=3, + max_num_dims=3, + min_dim_size=1, + min_num_dims=1, ), ) -def test_torch_sinc( +def test_torch_rad2deg( *, dtype_and_x, on_device, @@ -2254,20 +2351,14 @@ def test_torch_sinc( ) -# hypot +# real @handle_frontend_test( - fn_tree="torch.hypot", + fn_tree="torch.real", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch_hypot( +def test_torch_real( *, dtype_and_x, on_device, @@ -2276,7 +2367,7 @@ def test_torch_hypot( test_flags, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2284,86 +2375,48 @@ def test_torch_hypot( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, - input=x[0], - other=x[1], + input=input[0], ) -# sigmoid +# reciprocal @handle_frontend_test( - fn_tree="torch.sigmoid", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.reciprocal", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=1, ), ) -def test_torch_sigmoid( - *, - dtype_and_input, +def test_torch_reciprocal( + dtype_and_x, frontend, test_flags, fn_tree, backend_fw, - on_device, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, input=x[0], ) -# lerp +# remainder @handle_frontend_test( - fn_tree="torch.lerp", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", mixed_fn_compos=False), - num_arrays=3, - shared_dtype=True, + fn_tree="torch.remainder", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, large_abs_safety_factor=2.5, small_abs_safety_factor=2.5, safety_factor_scale="log", - min_value=-1e3, - max_value=1e3, - ), -) -def test_torch_lerp( - *, - dtype_and_input, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - input_dtype, inputs = dtype_and_input - start, end, weight = inputs - helpers.test_frontend_function( - input_dtypes=input_dtype, - 
backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=start, - end=end, - weight=weight, - ) - - -# signbit -@handle_frontend_test( - fn_tree="torch.signbit", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch_signbit( +def test_torch_remainder( *, dtype_and_x, on_device, @@ -2373,6 +2426,7 @@ def test_torch_signbit( backend_fw, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2380,55 +2434,57 @@ def test_torch_signbit( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1, input=x[0], + other=x[1], ) -# angle +# round @handle_frontend_test( - fn_tree="torch.angle", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=["float64", "complex64", "complex128"], + fn_tree="torch.round", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), + decimals=st.integers(min_value=0, max_value=5), ) -def test_torch_angle( - *, - dtype_and_input, +def test_torch_round( + dtype_and_x, + decimals, frontend, test_flags, fn_tree, backend_fw, - on_device, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, input=x[0], + decimals=decimals, ) -# arctan +# rsqrt @handle_frontend_test( - fn_tree="torch.arctan", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.rsqrt", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_arctan( +def test_torch_rsqrt( *, - dtype_and_input, + dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtype, x = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2440,14 +2496,20 @@ def test_torch_arctan( ) -# conj_physical @handle_frontend_test( - fn_tree="torch.conj_physical", + fn_tree="torch.sgn", dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float_and_complex"), + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + abs_smallest_val=1e-10, + min_value=-10, + max_value=10, ), ) -def test_torch_conj_physical( +def test_torch_sgn( *, dtype_and_input, frontend, @@ -2456,7 +2518,7 @@ def test_torch_conj_physical( backend_fw, on_device, ): - input_dtype, x = dtype_and_input + input_dtype, input = dtype_and_input helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2464,20 +2526,19 @@ def test_torch_conj_physical( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + input=input[0], + out=None, ) -# nextafter +# sigmoid @handle_frontend_test( - fn_tree="torch.nextafter", + fn_tree="torch.sigmoid", dtype_and_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, ), ) -def test_torch_nextafter( +def test_torch_sigmoid( *, dtype_and_input, frontend, @@ -2495,23 +2556,17 @@ def test_torch_nextafter( fn_tree=fn_tree, on_device=on_device, input=x[0], - other=x[1], ) -# fmod +# sign @handle_frontend_test( - fn_tree="torch.fmod", + fn_tree="torch.sign", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - min_value=-100, - max_value=100, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch_fmod( +def test_torch_sign( *, dtype_and_x, on_device, @@ -2528,20 +2583,18 @@ def test_torch_fmod( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-03, - x1=x[0], - x2=x[1], + input=x[0], ) -# imag +# signbit @handle_frontend_test( - fn_tree="torch.imag", + fn_tree="torch.signbit", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("complex"), + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch_imag( +def test_torch_signbit( *, dtype_and_x, on_device, @@ -2550,7 +2603,7 @@ def test_torch_imag( test_flags, backend_fw, ): - input_dtype, input = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2558,34 +2611,27 @@ def test_torch_imag( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], ) +# sin @handle_frontend_test( - fn_tree="torch.logit", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.sin", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - min_value=-10, - max_value=10, ), - eps=st.sampled_from([1e-05, -1e-05, None]), ) -def test_torch_logit( +def test_torch_sin( *, - dtype_and_input, - eps, + dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtype, input = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2593,20 +2639,18 @@ def test_torch_logit( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], - eps=eps, - out=None, + input=x[0], ) -# erf +# sinc @handle_frontend_test( - fn_tree="torch.erf", + fn_tree="torch.sinc", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_erf( +def test_torch_sinc( *, dtype_and_x, on_device, @@ -2627,14 +2671,14 @@ def test_torch_erf( ) -# erfc +# sinh @handle_frontend_test( - fn_tree="torch.erfc", + fn_tree="torch.sinh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_erfc( +def test_torch_sinh( *, dtype_and_x, on_device, @@ -2655,19 +2699,13 @@ def test_torch_erfc( ) +# sqrt @handle_frontend_test( - fn_tree="torch.frexp", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - shared_dtype=True, - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=1, - ), + fn_tree="torch.sqrt", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_torch_frexp( +def test_torch_sqrt( + *, dtype_and_x, on_device, fn_tree, @@ -2687,29 +2725,23 @@ def test_torch_frexp( ) +# square @handle_frontend_test( - fn_tree="torch.sgn", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - abs_smallest_val=1e-10, - min_value=-10, - max_value=10, + fn_tree="torch.square", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_sgn( +def test_torch_square( *, - dtype_and_input, + dtype_and_x, + on_device, + fn_tree, 
frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtype, input = dtype_and_input + input_dtype, input = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2718,32 +2750,26 @@ def test_torch_sgn( fn_tree=fn_tree, on_device=on_device, input=input[0], - out=None, ) +# subtract @handle_frontend_test( - fn_tree="torch.nan_to_num", + fn_tree="torch.subtract", + aliases=["torch.sub"], dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=True, - allow_inf=True, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), - nan=st.floats(min_value=-100.0, max_value=100.0), - posinf=st.just(None) | st.floats(min_value=5e100, max_value=5e100), - neginf=st.just(None) | st.floats(min_value=-5e100, max_value=-5e100), - test_with_out=st.just(False), + alpha=st.integers(min_value=1, max_value=5), ) -def test_torch_nan_to_num( +def test_torch_subtract( *, dtype_and_x, - nan, - posinf, - neginf, + alpha, on_device, fn_tree, frontend, @@ -2759,60 +2785,48 @@ def test_torch_nan_to_num( fn_tree=fn_tree, on_device=on_device, input=x[0], - nan=nan, - posinf=posinf, - neginf=neginf, + other=x[1], + alpha=alpha, ) -@st.composite -def _masked_fill_helper(draw): - cond, xs, dtypes = draw(_broadcastable_trio()) - if ivy.is_uint_dtype(dtypes[0]): - fill_value = draw(helpers.ints(min_value=0, max_value=5)) - elif ivy.is_int_dtype(dtypes[0]): - fill_value = draw(helpers.ints(min_value=-5, max_value=5)) - else: - fill_value = draw(helpers.floats(min_value=-5, max_value=5)) - return dtypes[0], xs[0], cond, fill_value - - -# masked_fill +# tan @handle_frontend_test( - fn_tree="torch.masked_fill", - x_mask_val=_masked_fill_helper(), + fn_tree="torch.tan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_torch_masked_fill( - *, x_mask_val, on_device, fn_tree, frontend, test_flags, backend_fw +def test_torch_tan( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, ): - dtype, x, mask, val = x_mask_val + input_dtype, x = dtype_and_x helpers.test_frontend_function( - input_dtypes=[dtype], + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, - input=x, - mask=mask, - value=val, + input=x[0], ) -# igamma +# tanh @handle_frontend_test( - fn_tree="torch.igamma", + fn_tree="torch.tanh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=2, - max_value=100, ), - test_with_out=st.just(False), ) -def test_torch_igamma( +def test_torch_tanh( + *, dtype_and_x, on_device, fn_tree, @@ -2828,24 +2842,23 @@ def test_torch_igamma( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-04, input=x[0], - other=x[1], ) -# ldexp +# true_divide @handle_frontend_test( - fn_tree="torch.ldexp", + fn_tree="torch.true_divide", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), ) -def test_torch_ldexp( +def test_torch_true_divide( + *, dtype_and_x, 
on_device, fn_tree, @@ -2854,6 +2867,7 @@ def test_torch_ldexp( backend_fw, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2861,28 +2875,29 @@ def test_torch_ldexp( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-03, input=x[0], other=x[1], ) +# trunc @handle_frontend_test( - fn_tree="torch.lgamma", - dtype_and_input=helpers.dtype_and_values( + fn_tree="torch.trunc", + aliases=["torch.fix"], + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_lgamma( +def test_torch_trunc( *, - dtype_and_input, + dtype_and_x, + on_device, + fn_tree, frontend, test_flags, - fn_tree, backend_fw, - on_device, ): - input_dtype, input = dtype_and_input + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2890,47 +2905,40 @@ def test_torch_lgamma( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=input[0], + input=x[0], ) -# gradient +# xlogy @handle_frontend_test( - fn_tree="torch.gradient", - dtype_input_axis=helpers.dtype_values_axis( + fn_tree="torch.xlogy", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - force_int_axis=True, + num_arrays=2, min_num_dims=1, - max_num_dims=3, - min_dim_size=2, - max_dim_size=4, - valid_axis=True, - ), - spacing=helpers.ints( - min_value=-3, - max_value=3, + min_value=-100, + max_value=100, + shared_dtype=True, ), - test_with_out=st.just(False), ) -def test_torch_gradient( +def test_torch_xlogy( *, - dtype_input_axis, - spacing, - test_flags, + dtype_and_x, on_device, fn_tree, - backend_fw, frontend, + test_flags, + backend_fw, ): - input_dtype, x, dim = dtype_input_axis + input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - test_flags=test_flags, frontend=frontend, + test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + atol=1e-03, input=x[0], - spacing=spacing, - dim=dim, + other=x[1], ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_random_sampling.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_random_sampling.py index fb6d48e0e7bc4..7d021c4db6a05 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_random_sampling.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_random_sampling.py @@ -8,6 +8,10 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test +# --- Helpers --- # +# --------------- # + + @st.composite def _pop_size_num_samples_replace_n_probs(draw): prob_dtype = draw(helpers.get_dtypes("float", full=False)) @@ -26,6 +30,73 @@ def _pop_size_num_samples_replace_n_probs(draw): return prob_dtype, batch_size, num_samples, replace, probs +# --- Main --- # +# ------------ # + + +@handle_frontend_test( + fn_tree="torch.bernoulli", + dtype_and_probs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", full=False), + min_value=0, + max_value=1, + min_num_dims=0, + ), +) +def test_torch_bernoulli( + dtype_and_probs, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtype, probs = dtype_and_probs + + def call(): + return helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + input=probs[0], + ) + + ret = call() + + if not ivy.exists(ret): + return + + ret_np, 
ret_from_np = ret + ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) + ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) + for u, v in zip(ret_np, ret_from_np): + assert u.dtype == v.dtype + assert u.shape == v.shape + + +@handle_frontend_test( + fn_tree="torch.manual_seed", + seed=st.integers(min_value=0, max_value=2**32 - 1), +) +def test_torch_manual_seed( + *, + seed, + fn_tree, + frontend, + test_flags, + backend_fw, +): + # just test calling the function + frontend_fw = importlib.import_module(fn_tree[25 : fn_tree.rfind(".")]) + split_index = fn_tree.rfind(".") + _, fn_name = fn_tree[:split_index], fn_tree[split_index + 1 :] + frontend_fw.__dict__[fn_name](seed) + + # multinomial @handle_frontend_test( fn_tree="torch.multinomial", @@ -70,22 +141,61 @@ def call(): @handle_frontend_test( - fn_tree="torch.manual_seed", - seed=st.integers(min_value=0, max_value=2**32 - 1), + fn_tree="torch.normal", + dtype_and_mean=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1000, + max_value=1000, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + ), + dtype_and_std=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=1000, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + ), ) -def test_torch_manual_seed( +def test_torch_normal( *, - seed, + dtype_and_mean, + dtype_and_std, + on_device, fn_tree, frontend, test_flags, backend_fw, ): - # just test calling the function - frontend_fw = importlib.import_module(fn_tree[25 : fn_tree.rfind(".")]) - split_index = fn_tree.rfind(".") - _, fn_name = fn_tree[:split_index], fn_tree[split_index + 1 :] - frontend_fw.__dict__[fn_name](seed) + mean_dtype, mean = dtype_and_mean + _, std = dtype_and_std + + def call(): + return helpers.test_frontend_function( + input_dtypes=mean_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + mean=mean[0], + std=std[0], + ) + + ret = call() + + if not ivy.exists(ret): + return + + ret_np, ret_from_np = ret + ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) + ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) + for u, v in zip(ret_np, ret_from_np): + assert u.dtype == v.dtype + assert u.shape == v.shape @handle_frontend_test( @@ -135,36 +245,29 @@ def call(): assert u.shape == v.shape -# randint @handle_frontend_test( - fn_tree="torch.randint", - low=helpers.ints(min_value=0, max_value=10), - high=helpers.ints(min_value=11, max_value=20), - size=helpers.get_shape(), - dtype=helpers.get_dtypes("integer"), + fn_tree="torch.rand", + dtype=helpers.get_dtypes("float", full=False), + size=helpers.get_shape( + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), ) -def test_torch_randint( - *, - low, - high, - size, - dtype, - frontend, - test_flags, - fn_tree, - backend_fw, -): +def test_torch_rand(*, dtype, size, frontend, fn_tree, test_flags, backend_fw): + size = {f"size{i}": size[i] for i in range(len(size))} + test_flags.num_positional_args = len(size) + def call(): - helpers.test_frontend_function( + return helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, test_values=False, fn_tree=fn_tree, test_flags=test_flags, - low=low, - high=high, - size=size, + **size, ) ret = call() @@ -181,28 +284,31 @@ def call(): @handle_frontend_test( - fn_tree="torch.rand", + fn_tree="torch.rand_like", 
dtype=helpers.get_dtypes("float", full=False), - size=helpers.get_shape( + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, - max_num_dims=5, + max_num_dims=10, min_dim_size=1, max_dim_size=10, ), ) -def test_torch_rand(*, dtype, size, frontend, fn_tree, test_flags, backend_fw): - size = {f"size{i}": size[i] for i in range(len(size))} - test_flags.num_positional_args = len(size) +def test_torch_rand_like( + dtype_and_x, dtype, *, frontend, fn_tree, test_flags, backend_fw +): + input_dtype, input = dtype_and_x def call(): return helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, backend_to_test=backend_fw, frontend=frontend, test_values=False, fn_tree=fn_tree, test_flags=test_flags, - **size, + input=input[0], + dtype=dtype[0], ) ret = call() @@ -218,49 +324,36 @@ def call(): assert u.shape == v.shape +# randint @handle_frontend_test( - fn_tree="torch.normal", - dtype_and_mean=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-1000, - max_value=1000, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - ), - dtype_and_std=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=1000, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - ), + fn_tree="torch.randint", + low=helpers.ints(min_value=0, max_value=10), + high=helpers.ints(min_value=11, max_value=20), + size=helpers.get_shape(), + dtype=helpers.get_dtypes("integer"), ) -def test_torch_normal( +def test_torch_randint( *, - dtype_and_mean, - dtype_and_std, - on_device, - fn_tree, + low, + high, + size, + dtype, frontend, test_flags, + fn_tree, backend_fw, ): - mean_dtype, mean = dtype_and_mean - _, std = dtype_and_std - def call(): - return helpers.test_frontend_function( - input_dtypes=mean_dtype, + helpers.test_frontend_function( + input_dtypes=dtype, backend_to_test=backend_fw, frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, test_values=False, - mean=mean[0], - std=std[0], + fn_tree=fn_tree, + test_flags=test_flags, + low=low, + high=high, + size=size, ) ret = call() @@ -277,18 +370,20 @@ def call(): @handle_frontend_test( - fn_tree="torch.rand_like", - dtype=helpers.get_dtypes("float", full=False), + fn_tree="torch.randint_like", + dtype=helpers.get_dtypes("signed_integer", full=False), dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("signed_integer"), min_num_dims=1, max_num_dims=10, min_dim_size=1, max_dim_size=10, ), + low=helpers.ints(min_value=0, max_value=10), + high=helpers.ints(min_value=11, max_value=20), ) -def test_torch_rand_like( - dtype_and_x, dtype, *, frontend, fn_tree, test_flags, backend_fw +def test_torch_randint_like( + dtype_and_x, low, high, *, dtype, frontend, fn_tree, test_flags, backend_fw ): input_dtype, input = dtype_and_x @@ -301,6 +396,8 @@ def call(): fn_tree=fn_tree, test_flags=test_flags, input=input[0], + low=low, + high=high, dtype=dtype[0], ) @@ -396,50 +493,6 @@ def call(): assert u.shape == v.shape -@handle_frontend_test( - fn_tree="torch.bernoulli", - dtype_and_probs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", full=False), - min_value=0, - max_value=1, - min_num_dims=0, - ), -) -def test_torch_bernoulli( - dtype_and_probs, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - dtype, probs = dtype_and_probs - - def call(): - return helpers.test_frontend_function( - 
input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=False, - input=probs[0], - ) - - ret = call() - - if not ivy.exists(ret): - return - - ret_np, ret_from_np = ret - ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) - ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) - for u, v in zip(ret_np, ret_from_np): - assert u.dtype == v.dtype - assert u.shape == v.shape - - # randperm @handle_frontend_test( fn_tree="torch.randperm", @@ -481,51 +534,6 @@ def call(): assert u.shape == v.shape -@handle_frontend_test( - fn_tree="torch.randint_like", - dtype=helpers.get_dtypes("signed_integer", full=False), - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("signed_integer"), - min_num_dims=1, - max_num_dims=10, - min_dim_size=1, - max_dim_size=10, - ), - low=helpers.ints(min_value=0, max_value=10), - high=helpers.ints(min_value=11, max_value=20), -) -def test_torch_randint_like( - dtype_and_x, low, high, *, dtype, frontend, fn_tree, test_flags, backend_fw -): - input_dtype, input = dtype_and_x - - def call(): - return helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_values=False, - fn_tree=fn_tree, - test_flags=test_flags, - input=input[0], - low=low, - high=high, - dtype=dtype[0], - ) - - ret = call() - - if not ivy.exists(ret): - return - - ret_np, ret_from_np = ret - ret_np = helpers.flatten_and_to_np(ret=ret_np, backend=backend_fw) - ret_from_np = helpers.flatten_and_to_np(ret=ret_from_np, backend=backend_fw) - for u, v in zip(ret_np, ret_from_np): - assert u.dtype == v.dtype - assert u.shape == v.shape - - # set_rng_state @handle_frontend_test( fn_tree="torch.set_rng_state", diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py index 194d727b8568c..c9505a8e5407e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py @@ -1,1032 +1,1040 @@ -# global -from hypothesis import strategies as st - -# local -import ivy -import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_frontend_test -from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( - _statistical_dtype_values, - _get_castable_dtype, -) -from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical import ( # noqa - _quantile_helper, -) - - -@handle_frontend_test( - fn_tree="torch.dist", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_value=-1e04, - max_value=1e04, - allow_inf=False, - ), - p=helpers.floats(min_value=1.0, max_value=10.0), -) -def test_torch_dist( - *, - dtype_and_input, - p, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input = dtype_and_input - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input[0], - other=input[1], - p=p, - ) - - -@handle_frontend_test( - fn_tree="torch.argmax", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - min_num_dims=1, - min_axis=-1, - max_axis=0, - ), - 
keepdims=st.booleans(), -) -def test_torch_argmax( - *, - dtype_input_axis, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.argmin", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - min_value=1, - max_value=5, - valid_axis=True, - allow_neg_axes=True, - ), - keepdims=st.booleans(), -) -def test_torch_argmin( - *, - dtype_input_axis, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.amax", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_axis=-1, - max_axis=0, - ), - keepdims=st.booleans(), -) -def test_torch_amax( - *, - dtype_input_axis, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.amin", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_axis=-1, - max_axis=0, - ), - keepdims=st.booleans(), -) -def test_torch_amin( - *, - dtype_input_axis, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.all", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - allow_inf=False, - ), - keepdims=st.booleans(), -) -def test_torch_all( - *, - dtype_input_axis, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.any", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - allow_inf=False, - ), - keepdims=st.booleans(), -) -def test_torch_any( - *, - dtype_input_axis, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, 
- backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.sum", - dtype_and_x=_get_castable_dtype( - min_value=-1e04, - max_value=1e04, - ), - keepdims=st.booleans(), -) -def test_torch_sum( - *, - dtype_and_x, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis, castable_dtype = dtype_and_x - if test_flags.as_variable: - castable_dtype = input_dtype - input_dtype = [input_dtype] - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - dtype=castable_dtype, - ) - - -@handle_frontend_test( - fn_tree="torch.mean", - dtype_and_x=_statistical_dtype_values( - function="mean", - min_value=-1e04, - max_value=1e04, - ), - keepdims=st.booleans(), - dtypes=helpers.get_dtypes("float_and_complex", none=True, full=False), -) -def test_torch_mean( - *, - dtype_and_x, - keepdims, - dtypes, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - dtype=dtypes[0], - atol=1e-2, - ) - - -@handle_frontend_test( - fn_tree="torch.nanmean", - dtype_and_x=_statistical_dtype_values( - function="nanmean", - min_value=-1e04, - max_value=1e04, - ), - keepdims=st.booleans(), -) -def test_torch_nanmean( - *, - dtype_and_x, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.median", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - ), - keepdim=st.booleans(), -) -def test_torch_median( - *, - dtype_input_axis, - keepdim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input, dim = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=input[0], - dim=dim, - keepdim=keepdim, - ) - - -@handle_frontend_test( - fn_tree="torch.std", - dtype_and_x=_statistical_dtype_values(function="std"), - keepdims=st.booleans(), -) -def test_torch_std( - *, - dtype_and_x, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis, correction = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - unbiased=bool(correction), - keepdim=keepdims, - ) - - -# prod -@handle_frontend_test( - fn_tree="torch.prod", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - valid_axis=True, - 
allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", - ), - dtype=helpers.get_dtypes("numeric", none=True, full=False), - keepdims=st.booleans(), -) -def test_torch_prod( - *, - dtype_x_axis, - dtype, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_x_axis - # ToDo: set as_variable_flags as the parameter generated by test_torch_prod once - # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 - if ivy.current_backend_str() == "torch": - test_flags.as_variable = [False] - - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - dtype=dtype[0], - ) - - -@handle_frontend_test( - fn_tree="torch.var", - dtype_and_x=_statistical_dtype_values( - function="var", - min_value=-1e04, - max_value=1e04, - ), - keepdims=st.booleans(), -) -def test_torch_var( - *, - dtype_and_x, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis, correction = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - unbiased=bool(correction), - keepdim=keepdims, - ) - - -# min -@handle_frontend_test( - fn_tree="torch.min", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - num_arrays=st.integers(min_value=1, max_value=2), - ), - keepdim=st.booleans(), -) -def test_torch_min( - *, - dtype_input_axis, - keepdim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input, axis = dtype_input_axis - inputs = {f"input{i}": input[i] for i in range(len(input))} - kwargs = {"dim": axis, "keepdim": keepdim} if len(inputs) == 1 else {} - test_flags.num_positional_args = len(input) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **inputs, - **kwargs, - ) - - -# moveaxis -@handle_frontend_test( - fn_tree="torch.moveaxis", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - ), - source=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - destination=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), -) -def test_torch_moveaxis( - *, - dtype_and_a, - source, - destination, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, a = dtype_and_a - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - 
input=a[0], - source=source, - destination=destination, - ) - - -@handle_frontend_test( - fn_tree="torch.max", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - num_arrays=st.integers(min_value=1, max_value=2), - ), - keepdim=st.booleans(), -) -def test_torch_max( - *, - dtype_input_axis, - keepdim, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input, axis = dtype_input_axis - inputs = {f"input{i}": input[i] for i in range(len(input))} - kwargs = {"dim": axis, "keepdim": keepdim} if len(inputs) == 1 else {} - test_flags.num_positional_args = len(input) - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - **inputs, - **kwargs, - ) - - -@handle_frontend_test( - fn_tree="torch.std_mean", - dtype_and_x=_statistical_dtype_values( - function="std_mean", - min_value=-1e04, - max_value=1e04, - ), - keepdims=st.booleans(), -) -def test_torch_std_mean( - *, - dtype_and_x, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis, correction = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - unbiased=bool(correction), - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.var_mean", - dtype_and_x=_statistical_dtype_values( - function="var_mean", - min_value=-1e04, - max_value=1e04, - ), - keepdims=st.booleans(), -) -def test_torch_var_mean( - *, - dtype_and_x, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis, correction = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - unbiased=bool(correction), - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.aminmax", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_axis=-1, - max_axis=0, - ), - keepdims=st.booleans(), -) -def test_torch_aminmax( - *, - dtype_input_axis, - keepdims, - test_flags, - on_device, - fn_tree, - frontend, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.quantile", - dtype_and_x=_quantile_helper(), - keepdims=st.booleans(), -) -def test_torch_quantile( - *, - dtype_and_x, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis, interpolation, q = dtype_and_x - if type(axis) is tuple: - axis = axis[0] - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - q=q, - dim=axis, - keepdim=keepdims, - interpolation=interpolation[0], - ) - - -@handle_frontend_test( - fn_tree="torch.count_nonzero", - dtype_input_axis=helpers.dtype_values_axis( - 
available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - min_num_dims=1, - min_axis=-1, - max_axis=0, - ), -) -def test_torch_count_nonzero( - *, - dtype_input_axis, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - ) - - -@handle_frontend_test( - fn_tree="torch.logsumexp", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-50, - max_value=50, - min_num_dims=1, - max_num_dims=5, - valid_axis=True, - force_int_axis=True, - ), - keepdims=st.booleans(), -) -def test_torch_logsumexp( - *, - dtype_input_axis, - keepdims, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - dim=axis, - keepdim=keepdims, - ) - - -@handle_frontend_test( - fn_tree="torch.unique", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - valid_axis=True, - ), - return_inverse=st.booleans(), - return_counts=st.booleans(), - sorted=st.booleans(), -) -def test_torch_unique( - *, - dtype_x_axis, - return_inverse, - return_counts, - sorted, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtypes, x, axis = dtype_x_axis - - helpers.test_frontend_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - sorted=sorted, - return_inverse=return_inverse, - return_counts=return_counts, - dim=axis, - ) - - -@st.composite -def _get_axis_and_p(draw, kind="valid"): - p = draw(st.sampled_from(["fro", "nuc", 1, 2, -1, -2, float("inf"), -float("inf")])) - if p == "fro" or p == "nuc": - max_axes_size = 2 - min_axes_size = 2 - else: - min_axes_size = 1 - max_axes_size = 5 - dtype_x_axis = draw( - helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes(kind), - min_num_dims=2, - valid_axis=True, - min_value=-1e04, - max_value=1e04, - min_axes_size=min_axes_size, - max_axes_size=max_axes_size, - large_abs_safety_factor=2, - safety_factor_scale="log", - force_int_axis=True, - ) - ) - - input_dtype, x, axis = dtype_x_axis - if type(input_dtype[0]) == str: - if "complex" in input_dtype[0]: - kind = "complex" - if "float" in input_dtype[0]: - kind = "float" - else: - if input_dtype[0].is_complex_dtype: - kind = "complex" - if input_dtype[0].is_float_dtype: - kind = "float" - - dtype = draw(helpers.get_dtypes(kind, full=False)) - dtype = dtype[0] - if ivy.can_cast(input_dtype[0], dtype): - dtype = ivy.promote_types(input_dtype[0], dtype) - else: - dtype = input_dtype[0] - - return p, dtype_x_axis, dtype - - -# norm -@handle_frontend_test( - fn_tree="torch.norm", - p_dtype_x_axis=_get_axis_and_p(), - keepdim=st.booleans(), -) -def test_torch_norm( - *, - p_dtype_x_axis, - keepdim, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - p, x_dtype, x, axis, dtype = p_dtype_x_axis - - helpers.test_frontend_function( - backend_to_test=backend_fw, - input_dtypes=[x_dtype], - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - 
on_device=on_device, - rtol=1e-01, - atol=1e-08, - input=x, - p=p, - dim=axis, - keepdim=keepdim, - out=None, - dtype=dtype, - ) - - -# known bug of returning empty tensors when ret_inv or ret_counts is passed positionally -# https://github.com/pytorch/pytorch/issues/68610 -# ToDo: activate test_values when this is resolved -@handle_frontend_test( - fn_tree="torch.unique_consecutive", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=2, - force_int_axis=True, - valid_axis=True, - ), - ret_inv=st.booleans(), - ret_counts=st.booleans(), -) -def test_torch_unique_consecutive( - *, - dtype_x_axis, - ret_inv, - ret_counts, - frontend, - test_flags, - fn_tree, - backend_fw, - on_device, -): - dtype, x, axis = dtype_x_axis - helpers.test_frontend_function( - input_dtypes=dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - input=x[0], - return_inverse=ret_inv, - return_counts=ret_counts, - dim=axis, - test_values=False, - ) +# global +from hypothesis import strategies as st + +# local +import ivy +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test +from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( + _statistical_dtype_values, + _get_castable_dtype, +) +from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical import ( # noqa + _quantile_helper, +) + + +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_axis_and_p(draw, kind="valid"): + p = draw(st.sampled_from(["fro", "nuc", 1, 2, -1, -2, float("inf"), -float("inf")])) + if p == "fro" or p == "nuc": + max_axes_size = 2 + min_axes_size = 2 + else: + min_axes_size = 1 + max_axes_size = 5 + dtype_x_axis = draw( + helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes(kind), + min_num_dims=2, + valid_axis=True, + min_value=-1e04, + max_value=1e04, + min_axes_size=min_axes_size, + max_axes_size=max_axes_size, + large_abs_safety_factor=2, + safety_factor_scale="log", + force_int_axis=True, + ) + ) + + input_dtype, x, axis = dtype_x_axis + if type(input_dtype[0]) == str: + if "complex" in input_dtype[0]: + kind = "complex" + if "float" in input_dtype[0]: + kind = "float" + else: + if input_dtype[0].is_complex_dtype: + kind = "complex" + if input_dtype[0].is_float_dtype: + kind = "float" + + dtype = draw(helpers.get_dtypes(kind, full=False)) + dtype = dtype[0] + if ivy.can_cast(input_dtype[0], dtype): + dtype = ivy.promote_types(input_dtype[0], dtype) + else: + dtype = input_dtype[0] + + return p, dtype_x_axis, dtype + + +# --- Main --- # +# ------------ # + + +@handle_frontend_test( + fn_tree="torch.all", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + allow_inf=False, + ), + keepdims=st.booleans(), +) +def test_torch_all( + *, + dtype_input_axis, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.amax", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_axis=-1, + 
max_axis=0, + ), + keepdims=st.booleans(), +) +def test_torch_amax( + *, + dtype_input_axis, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.amin", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), + keepdims=st.booleans(), +) +def test_torch_amin( + *, + dtype_input_axis, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.aminmax", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), + keepdims=st.booleans(), +) +def test_torch_aminmax( + *, + dtype_input_axis, + keepdims, + test_flags, + on_device, + fn_tree, + frontend, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.any", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_axis=-1, + max_axis=0, + min_num_dims=1, + allow_inf=False, + ), + keepdims=st.booleans(), +) +def test_torch_any( + *, + dtype_input_axis, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.argmax", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), + keepdims=st.booleans(), +) +def test_torch_argmax( + *, + dtype_input_axis, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.argmin", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + min_value=1, + max_value=5, + valid_axis=True, + allow_neg_axes=True, + ), + keepdims=st.booleans(), +) +def test_torch_argmin( + *, + dtype_input_axis, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + 
helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.count_nonzero", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + min_axis=-1, + max_axis=0, + ), +) +def test_torch_count_nonzero( + *, + dtype_input_axis, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + ) + + +@handle_frontend_test( + fn_tree="torch.dist", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_value=-1e04, + max_value=1e04, + allow_inf=False, + ), + p=helpers.floats(min_value=1.0, max_value=10.0), +) +def test_torch_dist( + *, + dtype_and_input, + p, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + other=input[1], + p=p, + ) + + +@handle_frontend_test( + fn_tree="torch.logsumexp", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-50, + max_value=50, + min_num_dims=1, + max_num_dims=5, + valid_axis=True, + force_int_axis=True, + ), + keepdims=st.booleans(), +) +def test_torch_logsumexp( + *, + dtype_input_axis, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.max", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + num_arrays=st.integers(min_value=1, max_value=2), + ), + keepdim=st.booleans(), +) +def test_torch_max( + *, + dtype_input_axis, + keepdim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input, axis = dtype_input_axis + inputs = {f"input{i}": input[i] for i in range(len(input))} + kwargs = {"dim": axis, "keepdim": keepdim} if len(inputs) == 1 else {} + test_flags.num_positional_args = len(input) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **inputs, + **kwargs, + ) + + +@handle_frontend_test( + fn_tree="torch.mean", + dtype_and_x=_statistical_dtype_values( + function="mean", + min_value=-1e04, + max_value=1e04, + ), + keepdims=st.booleans(), + dtypes=helpers.get_dtypes("float_and_complex", none=True, full=False), +) +def test_torch_mean( + *, + dtype_and_x, + keepdims, + dtypes, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + 
input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + dtype=dtypes[0], + atol=1e-2, + ) + + +@handle_frontend_test( + fn_tree="torch.median", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), + keepdim=st.booleans(), +) +def test_torch_median( + *, + dtype_input_axis, + keepdim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input, dim = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + dim=dim, + keepdim=keepdim, + ) + + +# min +@handle_frontend_test( + fn_tree="torch.min", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + num_arrays=st.integers(min_value=1, max_value=2), + ), + keepdim=st.booleans(), +) +def test_torch_min( + *, + dtype_input_axis, + keepdim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input, axis = dtype_input_axis + inputs = {f"input{i}": input[i] for i in range(len(input))} + kwargs = {"dim": axis, "keepdim": keepdim} if len(inputs) == 1 else {} + test_flags.num_positional_args = len(input) + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + **inputs, + **kwargs, + ) + + +# moveaxis +@handle_frontend_test( + fn_tree="torch.moveaxis", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + ), + source=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + destination=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), +) +def test_torch_moveaxis( + *, + dtype_and_a, + source, + destination, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, a = dtype_and_a + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=a[0], + source=source, + destination=destination, + ) + + +@handle_frontend_test( + fn_tree="torch.nanmean", + dtype_and_x=_statistical_dtype_values( + function="nanmean", + min_value=-1e04, + max_value=1e04, + ), + keepdims=st.booleans(), +) +def test_torch_nanmean( + *, + dtype_and_x, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + ) + + +# norm 
+@handle_frontend_test( + fn_tree="torch.norm", + p_dtype_x_axis=_get_axis_and_p(), + keepdim=st.booleans(), +) +def test_torch_norm( + *, + p_dtype_x_axis, + keepdim, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + p, x_dtype, x, axis, dtype = p_dtype_x_axis + + helpers.test_frontend_function( + backend_to_test=backend_fw, + input_dtypes=[x_dtype], + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-01, + atol=1e-08, + input=x, + p=p, + dim=axis, + keepdim=keepdim, + out=None, + dtype=dtype, + ) + + +# prod +@handle_frontend_test( + fn_tree="torch.prod", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", + ), + dtype=helpers.get_dtypes("numeric", none=True, full=False), + keepdims=st.booleans(), +) +def test_torch_prod( + *, + dtype_x_axis, + dtype, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + # ToDo: set as_variable_flags as the parameter generated by test_torch_prod once + # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 + if ivy.current_backend_str() == "torch": + test_flags.as_variable = [False] + + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + dtype=dtype[0], + ) + + +@handle_frontend_test( + fn_tree="torch.quantile", + dtype_and_x=_quantile_helper(), + keepdims=st.booleans(), +) +def test_torch_quantile( + *, + dtype_and_x, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis, interpolation, q = dtype_and_x + if type(axis) is tuple: + axis = axis[0] + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + q=q, + dim=axis, + keepdim=keepdims, + interpolation=interpolation[0], + ) + + +@handle_frontend_test( + fn_tree="torch.std", + dtype_and_x=_statistical_dtype_values(function="std"), + keepdims=st.booleans(), +) +def test_torch_std( + *, + dtype_and_x, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis, correction = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + unbiased=bool(correction), + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.std_mean", + dtype_and_x=_statistical_dtype_values( + function="std_mean", + min_value=-1e04, + max_value=1e04, + ), + keepdims=st.booleans(), +) +def test_torch_std_mean( + *, + dtype_and_x, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis, correction = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + unbiased=bool(correction), + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.sum", + 
dtype_and_x=_get_castable_dtype( + min_value=-1e04, + max_value=1e04, + ), + keepdims=st.booleans(), +) +def test_torch_sum( + *, + dtype_and_x, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis, castable_dtype = dtype_and_x + if test_flags.as_variable: + castable_dtype = input_dtype + input_dtype = [input_dtype] + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + keepdim=keepdims, + dtype=castable_dtype, + ) + + +@handle_frontend_test( + fn_tree="torch.unique", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + valid_axis=True, + ), + return_inverse=st.booleans(), + return_counts=st.booleans(), + sorted=st.booleans(), +) +def test_torch_unique( + *, + dtype_x_axis, + return_inverse, + return_counts, + sorted, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtypes, x, axis = dtype_x_axis + + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=axis, + ) + + +# known bug of returning empty tensors when ret_inv or ret_counts is passed positionally +# https://github.com/pytorch/pytorch/issues/68610 +# ToDo: activate test_values when this is resolved +@handle_frontend_test( + fn_tree="torch.unique_consecutive", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=2, + force_int_axis=True, + valid_axis=True, + ), + ret_inv=st.booleans(), + ret_counts=st.booleans(), +) +def test_torch_unique_consecutive( + *, + dtype_x_axis, + ret_inv, + ret_counts, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + return_inverse=ret_inv, + return_counts=ret_counts, + dim=axis, + test_values=False, + ) + + +@handle_frontend_test( + fn_tree="torch.var", + dtype_and_x=_statistical_dtype_values( + function="var", + min_value=-1e04, + max_value=1e04, + ), + keepdims=st.booleans(), +) +def test_torch_var( + *, + dtype_and_x, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis, correction = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + unbiased=bool(correction), + keepdim=keepdims, + ) + + +@handle_frontend_test( + fn_tree="torch.var_mean", + dtype_and_x=_statistical_dtype_values( + function="var_mean", + min_value=-1e04, + max_value=1e04, + ), + keepdims=st.booleans(), +) +def test_torch_var_mean( + *, + dtype_and_x, + keepdims, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x, axis, correction = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + dim=axis, + 
unbiased=bool(correction), + keepdim=keepdims, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index ba83b4dea4975..04e3b112fbcef 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -16,11 +16,6 @@ _get_axis_and_p, ) -try: - import torch -except ImportError: - torch = SimpleNamespace() - import ivy from hypothesis import strategies as st, given, assume @@ -54,9 +49,92 @@ _quantile_helper, ) +try: + import torch +except ImportError: + torch = SimpleNamespace() + CLASS_TREE = "ivy.functional.frontends.torch.Tensor" +# --- Helpers --- # +# --------------- # + + +@st.composite +def _array_idxes_n_dtype(draw, **kwargs): + num_dims = draw(helpers.ints(min_value=1, max_value=4)) + dtype, x = draw( + helpers.dtype_and_values( + **kwargs, min_num_dims=num_dims, max_num_dims=num_dims, shared_dtype=True + ) + ) + idxes = draw( + st.lists( + helpers.ints(min_value=0, max_value=num_dims - 1), + min_size=num_dims, + max_size=num_dims, + unique=True, + ) + ) + return x, idxes, dtype + + +@st.composite +def _arrays_dim_idx_n_dtypes(draw): + num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) + num_arrays = 2 + common_shape = draw( + helpers.lists( + x=helpers.ints(min_value=2, max_value=3), + min_size=num_dims - 1, + max_size=num_dims - 1, + ) + ) + _dim = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) + unique_dims = draw( + helpers.lists( + x=helpers.ints(min_value=2, max_value=3), + min_size=num_arrays, + max_size=num_arrays, + ) + ) + + min_dim = min(unique_dims) + max_dim = max(unique_dims) + _idx = draw( + helpers.array_values( + shape=min_dim, + dtype="int64", + min_value=0, + max_value=max_dim, + exclude_min=False, + ) + ) + + xs = list() + available_input_types = draw(helpers.get_dtypes("numeric")) + input_dtypes = draw( + helpers.array_dtypes( + available_dtypes=available_input_types, + num_arrays=num_arrays, + shared_dtype=True, + ) + ) + for ud, dt in zip(unique_dims, input_dtypes): + x = draw( + helpers.array_values( + shape=common_shape[:_dim] + [ud] + common_shape[_dim:], + dtype=dt, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ) + ) + xs.append(x) + return xs, input_dtypes, _dim, _idx + + # Helper functions @st.composite def _dtypes(draw): @@ -74,231 +152,315 @@ def _dtypes(draw): @st.composite -def _requires_grad(draw): - dtype = draw(_dtypes())[0] - if ivy.is_int_dtype(dtype) or ivy.is_uint_dtype(dtype): - return draw(st.just(False)) - return draw(st.booleans()) +def _expand_helper(draw): + num_dims = draw(st.integers(min_value=1, max_value=10)) + shape = draw( + helpers.get_shape(min_num_dims=num_dims, max_num_dims=num_dims).filter( + lambda x: any(i == 1 for i in x) + ) + ) + new_shape = draw( + helpers.get_shape(min_num_dims=num_dims, max_num_dims=num_dims).filter( + lambda x: all(x[i] == v if v != 1 else True for i, v in enumerate(shape)) + ) + ) + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=shape, + ) + ) + return dtype, x, new_shape -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_ivy_array( - dtype_x, - backend_fw, +@st.composite +def _fill_value_and_size( + draw, + *, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + 
max_dim_size=10, ): - _, data = dtype_x - ivy.set_backend(backend_fw) - x = Tensor(data[0]) - x.ivy_array = data[0] - ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=backend_fw) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend="torch", + if isinstance(min_dim_size, st._internal.SearchStrategy): + min_dim_size = draw(min_dim_size) + if isinstance(max_dim_size, st._internal.SearchStrategy): + max_dim_size = draw(max_dim_size) + + available_dtypes = draw(helpers.get_dtypes("numeric")) + dtype = draw( + helpers.array_dtypes( + num_arrays=1, + available_dtypes=available_dtypes, + ) + ) + array = draw( + helpers.array_values( + dtype=dtype[0], + shape=(1,), + ) ) + dtype.append("int32") + size = draw( + st.shared( + helpers.get_shape( + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ), + key="shape", + ) + ) + fill_value = draw(helpers.ints()) if "int" in dtype[0] else draw(helpers.floats()) + return dtype, [array, size, fill_value] -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_device( - dtype_x, - backend_fw, -): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal( - x.device, ivy.dev(ivy.array(data[0])), as_array=False + +@st.composite +def _get_clamp_inputs(draw): + shape = draw( + helpers.get_shape( + min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ) ) - ivy.previous_backend() + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=shape, + ) + ) + min = draw(st.booleans()) + if min: + max = draw(st.booleans()) + min = draw( + helpers.array_values( + dtype=x_dtype[0], shape=shape, min_value=0, max_value=25 + ) + ) + max = ( + draw( + helpers.array_values( + dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 + ) + ) + if max + else None + ) + else: + min = None + max = draw( + helpers.array_values( + dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 + ) + ) + return x_dtype, x, min, max -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_dtype(dtype_x, backend_fw): - ivy.set_backend(backend_fw) - dtype, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal(x.dtype, dtype[0], as_array=False) - ivy.previous_backend() +@st.composite +def _get_clip_min_inputs(draw): + shape = draw( + helpers.get_shape( + min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ) + ) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=shape, + ) + ) + + min = draw( + helpers.array_values(dtype=x_dtype[0], shape=shape, min_value=0, max_value=25) + ) + return x_dtype, x, min -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_shape(dtype_x, backend_fw): - ivy.set_backend(backend_fw) - dtype, data, shape = dtype_x - x = Tensor(data[0]) - ivy.utils.assertions.check_equal( - x.ivy_array.shape, ivy.Shape(shape), as_array=False + +@st.composite +def 
_get_dtype_and_multiplicative_matrices(draw): + return draw( + st.one_of( + _get_dtype_input_and_matrices(), + _get_dtype_and_3dbatch_matrices(), + ) ) - ivy.previous_backend() -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("complex", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_real(dtype_x, backend_fw): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal(x.real, ivy.real(data[0])) - ivy.previous_backend() +@st.composite +def _get_dtype_and_multiplicative_matrices(draw): + return draw( + st.one_of( + _get_dtype_input_and_matrices(), + _get_dtype_and_3dbatch_matrices(), + ) + ) -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("complex", prune_function=False) - ), -) -def test_torch_tensor_imag(dtype_x, backend_fw): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal(x.imag, ivy.imag(data[0])) - ivy.previous_backend() +@st.composite +def _get_dtype_input_and_vectors(draw, with_input=False, same_size=False): + dim_size1 = draw(helpers.ints(min_value=2, max_value=5)) + dim_size2 = dim_size1 if same_size else draw(helpers.ints(min_value=2, max_value=5)) + dtype = draw(helpers.get_dtypes("float", full=True)) + dtype = [ + draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) + ] + vec1 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size1,), min_value=2, max_value=5 + ) + ) + vec2 = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size2,), min_value=2, max_value=5 + ) + ) + if with_input: + input = draw( + helpers.array_values( + dtype=dtype[0], shape=(dim_size1, dim_size2), min_value=2, max_value=5 + ) + ) + return dtype, input, vec1, vec2 + return dtype, vec1, vec2 -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ret_shape=True, - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_ndim(dtype_x, backend_fw): - ivy.set_backend(backend_fw) - dtype, data, shape = dtype_x - x = Tensor(data[0]) - ivy.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) - ivy.previous_backend() +@st.composite +def _masked_fill_helper(draw): + cond, xs, dtypes = draw(_broadcastable_trio()) + if ivy.is_uint_dtype(dtypes[0]): + fill_value = draw(helpers.ints(min_value=0, max_value=5)) + elif ivy.is_int_dtype(dtypes[0]): + fill_value = draw(helpers.ints(min_value=-5, max_value=5)) + else: + fill_value = draw(helpers.floats(min_value=-5, max_value=5)) + return dtypes[0], xs[0], cond, fill_value -def test_torch_tensor_grad(backend_fw): - ivy.set_backend(backend_fw) - x = Tensor(ivy.array([1.0, 2.0, 3.0])) - grads = ivy.array([1.0, 2.0, 3.0]) - x._grads = grads - assert ivy.array_equal(x.grad, grads) - ivy.previous_backend() +@st.composite +def _repeat_helper(draw): + shape = draw( + helpers.get_shape( + min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 + ) + ) + input_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=shape, + ) + ) -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ), - requires_grad=st.booleans(), -) -def test_torch_tensor_requires_grad(dtype_x, requires_grad, backend_fw): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0], 
requires_grad=requires_grad) - ivy.utils.assertions.check_equal(x.requires_grad, requires_grad, as_array=False) - x.requires_grad = not requires_grad - ivy.utils.assertions.check_equal(x.requires_grad, not requires_grad, as_array=False) - ivy.previous_backend() + repeats = draw(st.lists(st.integers(min_value=1, max_value=5), min_size=len(shape))) + return input_dtype, x, repeats -@given( - requires_grad=st.booleans(), -) -def test_torch_tensor_is_leaf(requires_grad, backend_fw): - ivy.set_backend(backend_fw) - x = Tensor(ivy.array([3.0]), requires_grad=requires_grad) - ivy.utils.assertions.check_equal(x.is_leaf, True, as_array=False) - y = x.pow(2) - ivy.utils.assertions.check_equal(y.is_leaf, not requires_grad, as_array=False) - z = y.detach() - ivy.utils.assertions.check_equal(z.is_leaf, True, as_array=False) - ivy.previous_backend() +@st.composite +def _requires_grad(draw): + dtype = draw(_dtypes())[0] + if ivy.is_int_dtype(dtype) or ivy.is_uint_dtype(dtype): + return draw(st.just(False)) + return draw(st.booleans()) + + +@st.composite +def _to_helper(draw): + dtype_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + large_abs_safety_factor=3, + ) + ) + input_dtype, x = dtype_x + arg = draw(st.sampled_from(["tensor", "dtype", "device"])) + if arg == "tensor": + method_num_positional_args = 1 + method_all_as_kwargs_np = {"other": x[1]} + elif arg == "dtype": + method_num_positional_args = 1 + dtype = draw(helpers.get_dtypes("valid", full=False))[0] + method_all_as_kwargs_np = {"dtype": dtype} + else: + method_num_positional_args = 0 + device = draw(st.just("cpu")) + dtype = draw(helpers.get_dtypes("valid", full=False, none=True))[0] + method_all_as_kwargs_np = {"dtype": dtype, "device": device} + return input_dtype, x, method_num_positional_args, method_all_as_kwargs_np + + +@st.composite +def _unfold_args(draw): + values_dtype, values, axis, shape = draw( + helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + force_int_axis=True, + shape=draw( + helpers.get_shape( + allow_none=False, + min_num_dims=1, + min_dim_size=1, + ) + ), + ret_shape=True, + ) + ) + size = draw( + st.integers( + min_value=1, + max_value=max(shape[axis] - 1, 1), + ) + ) + step = draw( + st.integers( + min_value=1, + max_value=size, + ) + ) + return values_dtype, values, axis, size, step -def test_torch_tensor_grad_fn(backend_fw): - ivy.set_backend(backend_fw) - x = Tensor(ivy.array([3.0]), requires_grad=True) - ivy.utils.assertions.check_equal(x.grad_fn, None, as_array=False) - y = x.pow(2) - ivy.utils.assertions.check_equal(y.grad_fn, "PowBackward", as_array=False) - ivy.utils.assertions.check_equal( - y.grad_fn.next_functions[0], "AccumulateGrad", as_array=False +# diagonal +@st.composite +def dims_and_offset(draw, shape): + shape_actual = draw(shape) + dim1 = draw(helpers.get_axis(shape=shape, force_int=True)) + dim2 = draw(helpers.get_axis(shape=shape, force_int=True)) + offset = draw( + st.integers(min_value=-shape_actual[dim1], max_value=shape_actual[dim1]) ) - z = y.detach() - ivy.utils.assertions.check_equal(z.grad_fn, None, as_array=False) - ivy.previous_backend() + return dim1, dim2, offset -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False), - ), - requires_grad=st.booleans(), -) -def test_torch_tensor__requires_grad( - dtype_x, - requires_grad, - backend_fw, -): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0]) - assert not x._requires_grad - 
x.requires_grad_() - assert x._requires_grad - x.requires_grad_(requires_grad) - assert x._requires_grad == requires_grad - ivy.previous_backend() +# --- Main --- # +# ------------ # -# chunk -@pytest.mark.skip("Testing takes a lot of time") +# __add__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="chunk", - dtype_x_dim=helpers.dtype_values_axis( + method_name="__add__", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, + num_arrays=2, min_value=-1e04, max_value=1e04, - force_int_axis=True, - valid_axis=True, - ), - chunks=st.integers( - min_value=1, - max_value=5, + allow_inf=False, ), ) -def test_torch_tensor_chunk( - dtype_x_dim, - chunks, - frontend, +def test_torch___add__( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, dim = dtype_x_dim + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -307,8 +469,7 @@ def test_torch_tensor_chunk( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "chunks": chunks, - "dim": dim, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -318,24 +479,21 @@ def test_torch_tensor_chunk( ) -# any +# __and__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="any", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + method_name="__and__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + num_arrays=2, min_value=-1e04, max_value=1e04, - valid_axis=True, - force_int_axis=True, + allow_inf=False, ), - keepdim=st.booleans(), ) -def test_torch_tensor_any( - dtype_input_axis, - keepdim, +def test_torch___and__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -343,7 +501,7 @@ def test_torch_tensor_any( on_device, backend_fw, ): - input_dtype, x, axis = dtype_input_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -352,8 +510,7 @@ def test_torch_tensor_any( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdim, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -363,24 +520,53 @@ def test_torch_tensor_any( ) -# all @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="all", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, + method_name="__array_wrap__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + ), +) +def test_torch___array_wrap__( + dtype_and_x, + backend_fw, + frontend, +): + input_dtypes, x = dtype_and_x + if x[1].dtype == "bfloat16": + return + if x[0].dtype == "bfloat16": + ret_gt = torch.tensor(x[0].tolist(), dtype=torch.bfloat16).__array_wrap__(x[1]) + else: + ret_gt = torch.tensor(x[0]).__array_wrap__(x[1]) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + local_importer = ivy_backend.utils.dynamic_import + function_module = local_importer.import_module("ivy.functional.frontends.torch") + ret = function_module.tensor(x[0]).__array_wrap__(x[1]) + assert isinstance(ret, function_module.Tensor) + helpers.value_test( + 
ret_np_flat=np.array(ret.ivy_array).ravel(), + ret_np_from_gt_flat=ret_gt.numpy().ravel(), + ground_truth_backend="torch", + backend=backend_fw, + ) + + +# __bool__ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="__bool__", + dtype_and_x=helpers.dtype_and_values( + max_dim_size=1, min_value=-1e04, max_value=1e04, - valid_axis=True, - force_int_axis=True, ), - keepdim=st.booleans(), ) -def test_torch_tensor_all( - dtype_input_axis, - keepdim, +def test_torch___bool__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -388,18 +574,15 @@ def test_torch_tensor_all( on_device, backend_fw, ): - input_dtype, x, axis = dtype_input_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdim, - }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -408,11 +591,11 @@ def test_torch_tensor_all( ) -# add +# __eq__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="add", + method_name="__eq__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, @@ -420,19 +603,58 @@ def test_torch_tensor_all( max_value=1e04, allow_inf=False, ), - alpha=st.floats(min_value=-1e04, max_value=1e04, allow_infinity=False), ) -def test_torch_tensor_add( +def test_torch___eq__( dtype_and_x, - alpha, + frontend_method_data, + init_flags, + method_flags, frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="__floordiv__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ), +) +def test_torch___floordiv__( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -442,36 +664,69 @@ def test_torch_tensor_add( method_input_dtypes=input_dtype, method_all_as_kwargs_np={ "other": x[1], - "alpha": alpha, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, + on_device=on_device, + atol_=1, + ) + + +# __getitem__ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="__getitem__", + dtype_x_index=helpers.dtype_array_query( + available_dtypes=helpers.get_dtypes("valid"), + allow_neg_step=False, + ), +) +def test_torch___getitem__( + dtype_x_index, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x, index = dtype_x_index + helpers.test_frontend_method( + init_input_dtypes=[input_dtype[0]], + 
backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x}, + method_input_dtypes=[*input_dtype[1:]], + method_all_as_kwargs_np={"query": index}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# divide +# __gt__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="divide", + method_name="__gt__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, min_value=-1e04, max_value=1e04, allow_inf=False, ), ) -def test_torch_tensor_divide( +def test_torch___gt__( dtype_and_x, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): @@ -490,346 +745,201 @@ def test_torch_tensor_divide( init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# addmm +# __invert__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addmm", - dtype_and_matrices=_get_dtype_input_and_matrices(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + method_name="__invert__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=1, ), ) -def test_torch_tensor_addmm( - dtype_and_matrices, - beta, - alpha, - frontend, +def test_torch___invert__( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, mat1, mat2 = dtype_and_matrices + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "mat1": mat1, - "mat2": mat2, - "beta": beta, - "alpha": alpha, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# addmm_ +# __long__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addmm_", - dtype_and_matrices=_get_dtype_input_and_matrices(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + method_name="__long__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_addmm_( - dtype_and_matrices, - beta, - alpha, - frontend, +def test_torch___long__( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, mat1, mat2 = dtype_and_matrices + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "mat1": mat1, - "mat2": mat2, - "beta": beta, - "alpha": alpha, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, 
method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# addmv +# __lt__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addmv", - dtype_and_matrices=_get_dtype_input_and_mat_vec(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + method_name="__lt__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_addmv( - dtype_and_matrices, - beta, - alpha, - frontend, +def test_torch___lt__( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, mat, vec = dtype_and_matrices + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "mat": mat, - "vec": vec, - "beta": beta, - "alpha": alpha, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# addmv_ +# __matmul__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addmv_", - dtype_and_matrices=_get_dtype_input_and_mat_vec(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), + method_name="__matmul__", + dtype_tensor1_tensor2=_get_dtype_and_multiplicative_matrices(), ) -def test_torch_tensor_addmv_( - dtype_and_matrices, - beta, - alpha, - frontend, +def test_torch___matmul__( + dtype_tensor1_tensor2, frontend_method_data, init_flags, method_flags, - on_device, - backend_fw, -): - input_dtype, x, mat, vec = dtype_and_matrices - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - init_all_as_kwargs_np={ - "data": x, - }, - method_input_dtypes=input_dtype, - backend_to_test=backend_fw, - method_all_as_kwargs_np={ - "mat": mat, - "vec": vec, - "beta": beta, - "alpha": alpha, - }, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - atol_=1e-02, - on_device=on_device, - ) - - -# addbmm -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="addbmm", - dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), -) -def test_torch_tensor_addbmm( - dtype_and_matrices, - beta, - alpha, frontend, - frontend_method_data, - init_flags, - method_flags, on_device, backend_fw, ): - input_dtype, x, batch1, batch2 = dtype_and_matrices + dtype, tensor1, tensor2 = dtype_tensor1_tensor2 helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, - }, - method_input_dtypes=input_dtype, - 
method_all_as_kwargs_np={ - "batch1": batch1, - "batch2": batch2, - "beta": beta, - "alpha": alpha, + "data": tensor1, }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": tensor2}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# addbmm_ +# __mod__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addbmm_", - dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + method_name="__mod__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), ) -def test_torch_tensor_addbmm_( - dtype_and_matrices, - beta, - alpha, - frontend, +def test_torch___mod__( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, batch1, batch2 = dtype_and_matrices + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "batch1": batch1, - "batch2": batch2, - "beta": beta, - "alpha": alpha, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# sub +# __mul__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sub", + method_name="__mul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, @@ -837,15 +947,13 @@ def test_torch_tensor_addbmm_( max_value=1e04, allow_inf=False, ), - alpha=st.floats(min_value=-1e04, max_value=1e04, allow_infinity=False), ) -def test_torch_tensor_sub( +def test_torch___mul__( dtype_and_x, - alpha, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): @@ -859,43 +967,35 @@ def test_torch_tensor_sub( method_input_dtypes=input_dtype, method_all_as_kwargs_np={ "other": x[1], - "alpha": alpha, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# new_ones +# __ne__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="new_ones", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - size=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, + method_name="__ne__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), - dtypes=_dtypes(), - requires_grad=_requires_grad(), ) -def test_torch_tensor_new_ones( +def test_torch___ne__( dtype_and_x, - size, - dtypes, - requires_grad, - on_device, frontend_method_data, init_flags, method_flags, frontend, + on_device, backend_fw, ): input_dtype, x = dtype_and_x @@ -905,12 +1005,9 @@ def test_torch_tensor_new_ones( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=dtypes, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "size": size, - "dtype": dtypes[0], - 
"requires_grad": requires_grad, - "device": on_device, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -920,32 +1017,25 @@ def test_torch_tensor_new_ones( ) -# new_zeros +# __neg__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="new_zeros", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - size=helpers.get_shape( - allow_none=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, + method_name="__neg__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), - dtypes=_dtypes(), - requires_grad=_requires_grad(), ) -def test_torch_tensor_new_zeros( +def test_torch___neg__( dtype_and_x, - size, - dtypes, - requires_grad, - on_device, frontend_method_data, init_flags, method_flags, frontend, + on_device, backend_fw, ): input_dtype, x = dtype_and_x @@ -955,13 +1045,8 @@ def test_torch_tensor_new_zeros( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=dtypes, - method_all_as_kwargs_np={ - "size": size, - "dtype": dtypes[0], - "requires_grad": requires_grad, - "device": on_device, - }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -970,23 +1055,21 @@ def test_torch_tensor_new_zeros( ) +# __or__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="reshape", - dtype_x=helpers.dtype_and_values( + method_name="__or__", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - shape=helpers.reshape_shapes( - shape=st.shared(helpers.get_shape(), key="value_shape") + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), - unpack_shape=st.booleans(), ) -def test_torch_tensor_reshape( - dtype_x, - shape, - unpack_shape, +def test_torch___or__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -994,16 +1077,7 @@ def test_torch_tensor_reshape( on_device, backend_fw, ): - input_dtype, x = dtype_x - shape = { - "shape": shape, - } - if unpack_shape: - method_flags.num_positional_args = len(shape["shape"]) + 1 - i = 0 - for x_ in shape["shape"]: - shape["x{}".format(i)] = x_ - i += 1 + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1011,7 +1085,9 @@ def test_torch_tensor_reshape( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np=shape, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1020,17 +1096,18 @@ def test_torch_tensor_reshape( ) -# reshape_as +# __pow__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="reshape_as", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 + method_name="__pow__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_torch_tensor_reshape_as( - dtype_x, +def test_torch___pow__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1038,7 +1115,10 @@ def test_torch_tensor_reshape_as( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x + dtype = input_dtype[0] + if "int" 
in dtype: + x[1] = ivy.abs(x[1]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1047,7 +1127,7 @@ def test_torch_tensor_reshape_as( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "exponent": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1057,17 +1137,20 @@ def test_torch_tensor_reshape_as( ) -# sin +# __radd__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sin", + method_name="__radd__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, allow_inf=False, ), ) -def test_torch_tensor_sin( +def test_torch___radd__( dtype_and_x, frontend_method_data, init_flags, @@ -1084,7 +1167,9 @@ def test_torch_tensor_sin( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1093,17 +1178,20 @@ def test_torch_tensor_sin( ) -# arcsin +# __rmul__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arcsin", + method_name="__rmul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, allow_inf=False, ), ) -def test_torch_tensor_arcsin( +def test_torch___rmul__( dtype_and_x, frontend_method_data, init_flags, @@ -1120,7 +1208,9 @@ def test_torch_tensor_arcsin( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1129,20 +1219,19 @@ def test_torch_tensor_arcsin( ) -# sum +# __rpow__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sum", - dtype_x_dim=_get_castable_dtype( - min_value=-1e04, - max_value=1e04, + method_name="__rpow__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=1, ), - keepdim=st.booleans(), ) -def test_torch_tensor_sum( - dtype_x_dim, - keepdim, +def test_torch___rpow__( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1150,10 +1239,10 @@ def test_torch_tensor_sum( on_device, backend_fw, ): - input_dtype, x, dim, castable_dtype = dtype_x_dim - if method_flags.as_variable: - castable_dtype = input_dtype - input_dtype = [input_dtype] + input_dtype, x = dtype_and_x + dtype = input_dtype[0] + if "int" in dtype: + x[0] = ivy.abs(x[0]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1162,9 +1251,7 @@ def test_torch_tensor_sum( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": dim, - "keepdim": keepdim, - "dtype": castable_dtype, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1174,17 +1261,17 @@ def test_torch_tensor_sum( ) -# atan +# __rsub__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="atan", + method_name="__rsub__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_torch_tensor_atan( +def test_torch___rsub__( dtype_and_x, frontend_method_data, init_flags, @@ -1201,7 +1288,9 @@ def 
test_torch_tensor_atan( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1210,18 +1299,18 @@ def test_torch_tensor_atan( ) -# atan2 +# __setitem__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="atan2", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - ), + method_name="__setitem__", + dtypes_x_index_val=helpers.dtype_array_query_val( + available_dtypes=helpers.get_dtypes("valid"), + allow_neg_step=False, + ).filter(lambda x: x[0][0] == x[0][-1]), ) -def test_torch_tensor_atan2( - dtype_and_x, +def test_torch___setitem__( + dtypes_x_index_val, frontend_method_data, init_flags, method_flags, @@ -1229,17 +1318,13 @@ def test_torch_tensor_atan2( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, index, val = dtypes_x_index_val helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + init_all_as_kwargs_np={"data": x}, + method_input_dtypes=[*input_dtype[1:]], + method_all_as_kwargs_np={"key": index, "value": val}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1248,17 +1333,20 @@ def test_torch_tensor_atan2( ) -# sin_ +# __sub__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sin_", + method_name="__sub__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, allow_inf=False, ), ) -def test_torch_tensor_sin_( +def test_torch___sub__( dtype_and_x, frontend_method_data, init_flags, @@ -1275,7 +1363,9 @@ def test_torch_tensor_sin_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1284,17 +1374,21 @@ def test_torch_tensor_sin_( ) -# cos +# __truediv__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cos", + method_name="__truediv__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + shared_dtype=True, + num_arrays=2, + min_value=-1e04, + max_value=1e04, allow_inf=False, ), ) -def test_torch_tensor_cos( +def test_torch___truediv__( dtype_and_x, frontend_method_data, init_flags, @@ -1311,7 +1405,9 @@ def test_torch_tensor_cos( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1320,101 +1416,155 @@ def test_torch_tensor_cos( ) -# cos_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cos_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, - ), + method_name="__array__", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_tensor_cos_( +def test_torch__array__( dtype_and_x, + dtype, + frontend, + backend_fw, +): + 
input_dtype, x = dtype_and_x + if x[0].dtype == "bfloat16": + return + dtype[0] = np.dtype(dtype[0]) + ret_gt = torch.tensor(x[0]).__array__(dtype[0]) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + local_importer = ivy_backend.utils.dynamic_import + function_module = local_importer.import_module("ivy.functional.frontends.torch") + ret = function_module.tensor(x[0]).__array__(dtype[0]) + + helpers.value_test( + ret_np_flat=ret.ravel(), + ret_np_from_gt_flat=ret_gt.ravel(), + ground_truth_backend="torch", + backend=backend_fw, + ) + + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="baddbmm_", + dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), +) +def test_torch_baddbmm_( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, - backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, batch1, batch2 = dtype_and_matrices helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": list(x[0]) if type(x[0]) == int else x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "batch1": batch1, + "batch2": batch2, + "beta": beta, + "alpha": alpha, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# sinh +# index_fill @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sinh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + method_name="index_fill", + dtype_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + first_dimension_only=True, + indices_same_dims=False, ), + value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_sinh( - dtype_and_x, +def test_torch_index_fill( + dtype_indices_axis, + value, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, x, indices, axis, _ = dtype_indices_axis + if indices.ndim != 1: + indices = ivy.flatten(indices) helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtypes[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], + init_all_as_kwargs_np={"data": x}, + method_input_dtypes=[input_dtypes[1]], + method_all_as_kwargs_np={ + "dim": axis, + "index": indices, + "value": value, }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# sinh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sinh_", + method_name="sinc", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, ), ) -def test_torch_tensor_sinh_( +def test_torch_instance_sinc( + 
*, dtype_and_x, + frontend, + backend_fw, frontend_method_data, init_flags, method_flags, - frontend, on_device, - backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, @@ -1424,36 +1574,34 @@ def test_torch_tensor_sinh_( init_flags=init_flags, method_flags=method_flags, frontend=frontend, + backend_to_test=backend_fw, on_device=on_device, ) -# cosh +# isnan @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cosh", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + method_name="isnan", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch_tensor_cosh( - dtype_and_x, +def test_torch_isnan( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, @@ -1464,17 +1612,16 @@ def test_torch_tensor_cosh( ) -# cosh_ +# rsqrt_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cosh_", + method_name="rsqrt_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, ), ) -def test_torch_tensor_cosh_( +def test_torch_rsqrt_( dtype_and_x, frontend_method_data, init_flags, @@ -1497,63 +1644,42 @@ def test_torch_tensor_cosh_( method_flags=method_flags, frontend=frontend, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, ) -# view -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="view", +@given( dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - shape=helpers.reshape_shapes( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape") + available_dtypes=helpers.get_dtypes("valid", prune_function=False), ), + requires_grad=st.booleans(), ) -def test_torch_tensor_view( +def test_torch_tensor__requires_grad( dtype_x, - shape, - frontend_method_data, - init_flags, - method_flags, - frontend, - on_device, + requires_grad, backend_fw, ): - input_dtype, x = dtype_x - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "size": shape, - }, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, - ) + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0]) + assert not x._requires_grad + x.requires_grad_() + assert x._requires_grad + x.requires_grad_(requires_grad) + assert x._requires_grad == requires_grad + ivy.previous_backend() +# abs @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="float", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="abs", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_float( - dtype_x, +def 
test_torch_tensor_abs( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -1561,7 +1687,7 @@ def test_torch_tensor_float( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -1578,26 +1704,28 @@ def test_torch_tensor_float( ) +# abs_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="double", + method_name="abs_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_double( +def test_torch_tensor_abs_( dtype_and_x, frontend_method_data, init_flags, method_flags, frontend, - backend_fw, on_device, + backend_fw, ): input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, @@ -1607,22 +1735,21 @@ def test_torch_tensor_double( init_flags=init_flags, method_flags=method_flags, frontend=frontend, - backend_to_test=backend_fw, on_device=on_device, ) -# asinh +# acos @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="asinh", + method_name="acos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), allow_inf=False, ), ) -def test_torch_tensor_asinh( +def test_torch_tensor_acos( dtype_and_x, frontend_method_data, init_flags, @@ -1644,23 +1771,22 @@ def test_torch_tensor_asinh( init_flags=init_flags, method_flags=method_flags, frontend=frontend, - rtol_=1e-2, - atol_=1e-2, on_device=on_device, ) -# asinh_ +# acos_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="asinh_", + method_name="acos_", dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, ), ) -def test_torch_tensor_asinh_( +def test_torch_tensor_acos_( dtype_and_x, frontend_method_data, init_flags, @@ -1676,29 +1802,27 @@ def test_torch_tensor_asinh_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - rtol_=1e-2, - atol_=1e-2, on_device=on_device, ) -# tan +# acosh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tan", + method_name="acosh", dtype_and_x=helpers.dtype_and_values( + min_value=1.0, available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, ), ) -def test_torch_tensor_tan( +def test_torch_tensor_acosh( dtype_and_x, frontend_method_data, init_flags, @@ -1714,7 +1838,7 @@ def test_torch_tensor_tan( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1724,17 +1848,17 @@ def test_torch_tensor_tan( ) -# tanh +# acosh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tanh", + method_name="acosh_", dtype_and_x=helpers.dtype_and_values( + min_value=1.0, available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, ), ) -def test_torch_tensor_tanh( +def test_torch_tensor_acosh_( dtype_and_x, frontend_method_data, init_flags, @@ -1750,7 +1874,7 @@ def test_torch_tensor_tanh( init_all_as_kwargs_np={ "data": x[0], }, - 
method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -1760,22 +1884,27 @@ def test_torch_tensor_tanh( ) -# tanh_ +# add @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tanh_", + method_name="add", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, allow_inf=False, ), + alpha=st.floats(min_value=-1e04, max_value=1e04, allow_infinity=False), ) -def test_torch_tensor_tanh_( +def test_torch_tensor_add( dtype_and_x, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): @@ -1787,26 +1916,30 @@ def test_torch_tensor_tanh_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + "alpha": alpha, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# asin +# add_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="asin", + method_name="add_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_torch_tensor_asin( +def test_torch_tensor_add_( dtype_and_x, frontend_method_data, init_flags, @@ -1817,13 +1950,15 @@ def test_torch_tensor_asin( ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -1832,500 +1967,686 @@ def test_torch_tensor_asin( ) -# amax +# addbmm @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="amax", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, + method_name="addbmm", + dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), - keepdim=st.booleans(), ) -def test_torch_tensor_amax( - dtype_x_axis, - keepdim, +def test_torch_tensor_addbmm( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x, batch1, batch2 = dtype_and_matrices helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdim, + "batch1": batch1, + "batch2": batch2, + "beta": beta, + "alpha": alpha, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# abs +# addbmm_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="abs", - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="addbmm_", + dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch_tensor_abs( - dtype_and_x, +def test_torch_tensor_addbmm_( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, batch1, batch2 = dtype_and_matrices helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "batch1": batch1, + "batch2": batch2, + "beta": beta, + "alpha": alpha, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# abs_ +# addcdiv @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="abs_", + method_name="addcdiv", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), + value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_abs_( +def test_torch_tensor_addcdiv( dtype_and_x, + value, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[2], 0))) + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "tensor1": x[1], + "tensor2": x[2], + "value": value, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, + atol_=1e-03, ) -# amin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="amin", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - force_int_axis=True, + method_name="addcdiv_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), - keepdim=st.booleans(), + value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_amin( - dtype_x_axis, - keepdim, +def test_torch_tensor_addcdiv_( + dtype_and_x, + value, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x, axis = dtype_x_axis + input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[2], 0))) + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdim, + "tensor1": x[1], + "tensor2": x[2], + "value": value, }, 
frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, + atol_=1e-03, ) -# aminmax +# addcmul @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="aminmax", - dtype_input_axis=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + method_name="addcmul", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), + value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_aminmax( - dtype_input_axis, +def test_torch_tensor_addcmul( + dtype_and_x, + value, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_input_axis + input_dtype, x = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "tensor1": x[1], + "tensor2": x[2], + "value": value, }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, + atol_=1e-02, ) -# bernoulli +# addcmul_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bernoulli", + method_name="addcmul_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), - test_with_out=st.just(True), + value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_bernoulli( +def test_torch_tensor_addcmul_( dtype_and_x, + value, frontend, frontend_method_data, init_flags, method_flags, + on_device, backend_fw, ): input_dtype, x = dtype_and_x + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "input": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"generator": x[1], "out": x[2]}, + method_all_as_kwargs_np={ + "tensor1": x[1], + "tensor2": x[2], + "value": value, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + on_device=on_device, + atol_=1e-02, ) -# contiguous +# addmm @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="contiguous", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + method_name="addmm", + dtype_and_matrices=_get_dtype_input_and_matrices(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch_tensor_contiguous( - dtype_and_x, +def test_torch_tensor_addmm( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, mat1, mat2 = dtype_and_matrices 
helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "mat1": mat1, + "mat2": mat2, + "beta": beta, + "alpha": alpha, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# log +# addmm_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="log", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + method_name="addmm_", + dtype_and_matrices=_get_dtype_input_and_matrices(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch_tensor_log( - dtype_and_x, +def test_torch_tensor_addmm_( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, mat1, mat2 = dtype_and_matrices helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "mat1": mat1, + "mat2": mat2, + "beta": beta, + "alpha": alpha, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# log2_ +# addmv @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="log2_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - allow_inf=False, + method_name="addmv", + dtype_and_matrices=_get_dtype_input_and_mat_vec(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch_tensor_log2_( - dtype_and_x, +def test_torch_tensor_addmv( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, mat, vec = dtype_and_matrices helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "mat": mat, + "vec": vec, + "beta": beta, + "alpha": alpha, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# log_ +# addmv_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="log_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + method_name="addmv_", + dtype_and_matrices=_get_dtype_input_and_mat_vec(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + 
allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch_tensor_log_( - dtype_and_x, +def test_torch_tensor_addmv_( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, mat, vec = dtype_and_matrices helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + backend_to_test=backend_fw, + method_all_as_kwargs_np={ + "mat": mat, + "vec": vec, + "beta": beta, + "alpha": alpha, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# log2 +# addr @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="log2", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + method_name="addr", + dtype_and_vecs=_get_dtype_input_and_vectors(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch_tensor_log2( - dtype_and_x, +def test_torch_tensor_addr( + dtype_and_vecs, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, input, vec1, vec2 = dtype_and_vecs helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": input, + }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={ + "vec1": vec1, + "vec2": vec2, + "beta": beta, + "alpha": alpha, }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# __bool__ +# addr_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__bool__", - dtype_and_x=helpers.dtype_and_values( - max_dim_size=1, - min_value=-1e04, - max_value=1e04, + method_name="addr_", + dtype_and_vecs=_get_dtype_input_and_vectors(with_input=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch___bool__( - dtype_and_x, +def test_torch_tensor_addr_( + dtype_and_vecs, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, input, vec1, vec2 = dtype_and_vecs helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": input, + }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={ + "vec1": vec1, + "vec2": vec2, + "beta": beta, + "alpha": alpha, }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# __add__ 
@handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__add__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="adjoint", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("real_and_complex"), + min_num_dims=2, + min_dim_size=2, ), ) -def test_torch___add__( - dtype_and_x, +def test_torch_tensor_adjoint( + dtype_and_values, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, values = dtype_and_values + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": values[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, - frontend_method_data=frontend_method_data, + method_all_as_kwargs_np={}, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) -# arcsinh +# all @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arcsinh", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), + method_name="all", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_value=-1e04, + max_value=1e04, + valid_axis=True, + force_int_axis=True, ), + keepdim=st.booleans(), ) -def test_torch_tensor_arcsinh( - dtype_and_x, +def test_torch_tensor_all( + dtype_input_axis, + keepdim, frontend_method_data, init_flags, method_flags, @@ -2333,15 +2654,18 @@ def test_torch_tensor_arcsinh( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "dim": axis, + "keepdim": keepdim, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2350,19 +2674,21 @@ def test_torch_tensor_arcsinh( ) -# arcsinh_ +# amax @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arcsinh_", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), + method_name="amax", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, ), + keepdim=st.booleans(), ) -def test_torch_tensor_arcsinh_( - dtype_and_x, +def test_torch_tensor_amax( + dtype_x_axis, + keepdim, frontend_method_data, init_flags, method_flags, @@ -2370,15 +2696,18 @@ def test_torch_tensor_arcsinh_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "dim": axis, + "keepdim": keepdim, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2387,20 +2716,21 @@ def test_torch_tensor_arcsinh_( ) -# 
__long__ +# amin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__long__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="amin", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + valid_axis=True, + force_int_axis=True, ), + keepdim=st.booleans(), ) -def test_torch___long__( - dtype_and_x, +def test_torch_tensor_amin( + dtype_x_axis, + keepdim, frontend_method_data, init_flags, method_flags, @@ -2408,7 +2738,7 @@ def test_torch___long__( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2416,7 +2746,10 @@ def test_torch___long__( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "dim": axis, + "keepdim": keepdim, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2425,21 +2758,17 @@ def test_torch___long__( ) -# __radd__ +# aminmax @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__radd__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="aminmax", + dtype_input_axis=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch___radd__( - dtype_and_x, +def test_torch_tensor_aminmax( + dtype_input_axis, frontend_method_data, init_flags, method_flags, @@ -2447,7 +2776,7 @@ def test_torch___radd__( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_input_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2455,9 +2784,7 @@ def test_torch___radd__( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2466,62 +2793,58 @@ def test_torch___radd__( ) -# __sub__ +# angle @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__sub__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="angle", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=["float64", "complex64", "complex128"], ), ) -def test_torch___sub__( - dtype_and_x, +def test_torch_tensor_angle( + dtype_and_values, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, - backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, values = dtype_and_values + helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": values[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, - frontend_method_data=frontend_method_data, + method_all_as_kwargs_np={}, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) -# __mul__ +# any @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__mul__", - dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + method_name="any", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, min_value=-1e04, max_value=1e04, - allow_inf=False, + valid_axis=True, + force_int_axis=True, ), + keepdim=st.booleans(), ) -def test_torch___mul__( - dtype_and_x, +def test_torch_tensor_any( + dtype_input_axis, + keepdim, frontend_method_data, init_flags, method_flags, @@ -2529,7 +2852,7 @@ def test_torch___mul__( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -2538,7 +2861,8 @@ def test_torch___mul__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": axis, + "keepdim": keepdim, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2548,60 +2872,62 @@ def test_torch___mul__( ) -@st.composite -def _get_dtype_and_multiplicative_matrices(draw): - return draw( - st.one_of( - _get_dtype_input_and_matrices(), - _get_dtype_and_3dbatch_matrices(), - ) - ) +# write test for torch instance apply_ -# __matmul__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__matmul__", - dtype_tensor1_tensor2=_get_dtype_and_multiplicative_matrices(), + method_name="apply_", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + ), ) -def test_torch___matmul__( - dtype_tensor1_tensor2, +def test_torch_tensor_apply_( + dtype_and_values, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - dtype, tensor1, tensor2 = dtype_tensor1_tensor2 + def func(x): + return x + 1 + + input_dtype, values = dtype_and_values + helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": tensor1, + "data": values[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "callable": func, }, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": tensor2}, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) -# __rsub__ +# arccos @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__rsub__", + method_name="arccos", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + min_value=-1.0, + max_value=1.0, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch___rsub__( +def test_torch_tensor_arccos( dtype_and_x, frontend_method_data, init_flags, @@ -2617,10 +2943,8 @@ def test_torch___rsub__( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2629,20 +2953,18 @@ def test_torch___rsub__( ) -# __rmul__ +# arccos_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__rmul__", + method_name="arccos_", dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, ), ) -def 
test_torch___rmul__( +def test_torch_tensor_arccos_( dtype_and_x, frontend_method_data, init_flags, @@ -2658,10 +2980,8 @@ def test_torch___rmul__( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2670,21 +2990,18 @@ def test_torch___rmul__( ) -# __truediv__ +# arccosh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__truediv__", + method_name="arccosh", dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, available_dtypes=helpers.get_dtypes("float"), - shared_dtype=True, - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, ), ) -def test_torch___truediv__( +def test_torch_tensor_arccosh( dtype_and_x, frontend_method_data, init_flags, @@ -2700,10 +3017,8 @@ def test_torch___truediv__( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2712,19 +3027,18 @@ def test_torch___truediv__( ) +# arccosh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__floordiv__", + method_name="arccosh_", dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", ), ) -def test_torch___floordiv__( +def test_torch_tensor_arccosh_( dtype_and_x, frontend_method_data, init_flags, @@ -2734,40 +3048,33 @@ def test_torch___floordiv__( backend_fw, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - atol_=1, ) -# remainder +# arcsin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="remainder", + method_name="arcsin", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - shared_dtype=True, - num_arrays=2, + allow_inf=False, ), ) -def test_torch_tensor_remainder( +def test_torch_tensor_arcsin( dtype_and_x, frontend_method_data, init_flags, @@ -2784,9 +3091,7 @@ def test_torch_tensor_remainder( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2795,22 +3100,18 @@ def test_torch_tensor_remainder( ) -# remainder_ +# arcsin_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="remainder_", + method_name="arcsin_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-1e04, - max_value=1e04, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - shared_dtype=True, - num_arrays=2, + 
min_value=-1.0, + max_value=1.0, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_remainder_( +def test_torch_tensor_arcsin_( dtype_and_x, frontend_method_data, init_flags, @@ -2826,10 +3127,8 @@ def test_torch_tensor_remainder_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2838,41 +3137,19 @@ def test_torch_tensor_remainder_( ) -@st.composite -def _to_helper(draw): - dtype_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - large_abs_safety_factor=3, - ) - ) - input_dtype, x = dtype_x - arg = draw(st.sampled_from(["tensor", "dtype", "device"])) - if arg == "tensor": - method_num_positional_args = 1 - method_all_as_kwargs_np = {"other": x[1]} - elif arg == "dtype": - method_num_positional_args = 1 - dtype = draw(helpers.get_dtypes("valid", full=False))[0] - method_all_as_kwargs_np = {"dtype": dtype} - else: - method_num_positional_args = 0 - device = draw(st.just("cpu")) - dtype = draw(helpers.get_dtypes("valid", full=False, none=True))[0] - method_all_as_kwargs_np = {"dtype": dtype, "device": device} - return input_dtype, x, method_num_positional_args, method_all_as_kwargs_np - - -# to +# arcsinh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="to", - args_kwargs=_to_helper(), + method_name="arcsinh", + dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_torch_tensor_to( - args_kwargs, +def test_torch_tensor_arcsinh( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -2880,16 +3157,15 @@ def test_torch_tensor_to( on_device, backend_fw, ): - input_dtype, x, method_num_positional_args, method_all_as_kwargs_np = args_kwargs - method_flags.num_positional_args = method_num_positional_args + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np=method_all_as_kwargs_np, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -2898,17 +3174,18 @@ def test_torch_tensor_to( ) -# arctan +# arcsinh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arctan", + method_name="arcsinh_", dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, ), ) -def test_torch_tensor_arctan( +def test_torch_tensor_arcsinh_( dtype_and_x, frontend_method_data, init_flags, @@ -2924,7 +3201,7 @@ def test_torch_tensor_arctan( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -2934,17 +3211,17 @@ def test_torch_tensor_arctan( ) -# arctan_ +# arctan @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arctan_", + method_name="arctan", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), allow_inf=False, ), ) -def test_torch_tensor_arctan_( +def test_torch_tensor_arctan( dtype_and_x, 
frontend_method_data, init_flags, @@ -3046,17 +3323,17 @@ def test_torch_tensor_arctan2_( ) -# acos +# arctan_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="acos", + method_name="arctan_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), allow_inf=False, ), ) -def test_torch_tensor_acos( +def test_torch_tensor_arctan_( dtype_and_x, frontend_method_data, init_flags, @@ -3082,16 +3359,18 @@ def test_torch_tensor_acos( ) -# floor +# arctanh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="floor", + method_name="arctanh", dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_floor( +def test_torch_tensor_arctanh( dtype_and_x, frontend_method_data, init_flags, @@ -3107,7 +3386,7 @@ def test_torch_tensor_floor( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -3117,17 +3396,18 @@ def test_torch_tensor_floor( ) -# new_tensor +# arctanh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="new_tensor", + method_name="arctanh_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + min_value=-1.0, + max_value=1.0, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_new_tensor( +def test_torch_tensor_arctanh_( dtype_and_x, frontend_method_data, init_flags, @@ -3138,50 +3418,13 @@ def test_torch_tensor_new_tensor( ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=[input_dtype[1]], - method_all_as_kwargs_np={ - "data": x[1], - "dtype": input_dtype[1], - }, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, - ) - - -# __getitem__ -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="__getitem__", - dtype_x_index=helpers.dtype_array_query( - available_dtypes=helpers.get_dtypes("valid"), - allow_neg_step=False, - ), -) -def test_torch___getitem__( - dtype_x_index, - frontend_method_data, - init_flags, - method_flags, - frontend, - on_device, - backend_fw, -): - input_dtype, x, index = dtype_x_index - helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x}, - method_input_dtypes=[*input_dtype[1:]], - method_all_as_kwargs_np={"query": index}, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3190,18 +3433,28 @@ def test_torch___getitem__( ) -# __setitem__ +# argmax @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__setitem__", - dtypes_x_index_val=helpers.dtype_array_query_val( - available_dtypes=helpers.get_dtypes("valid"), - allow_neg_step=False, - ).filter(lambda x: x[0][0] == x[0][-1]), + method_name="argmax", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + max_num_dims=3, + 
min_dim_size=1, + max_dim_size=3, + min_value=1, + max_value=5, + valid_axis=True, + allow_neg_axes=True, + ), + keepdim=st.booleans(), ) -def test_torch___setitem__( - dtypes_x_index_val, +def test_torch_tensor_argmax( + dtype_input_axis, + keepdim, frontend_method_data, init_flags, method_flags, @@ -3209,13 +3462,18 @@ def test_torch___setitem__( on_device, backend_fw, ): - input_dtype, x, index, val = dtypes_x_index_val + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x}, - method_input_dtypes=[*input_dtype[1:]], - method_all_as_kwargs_np={"key": index, "value": val}, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "dim": axis, + "keepdim": keepdim, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3224,19 +3482,28 @@ def test_torch___setitem__( ) -# view_as +# argmin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="view_as", - dtype_x=helpers.dtype_and_values( + method_name="argmin", + dtype_input_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - shape=st.shared(helpers.get_shape(), key="value_shape"), - num_arrays=2, + force_int_axis=True, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + min_value=1, + max_value=5, + valid_axis=True, + allow_neg_axes=True, ), + keepdim=st.booleans(), ) -def test_torch_tensor_view_as( - dtype_x, +def test_torch_tensor_argmin( + dtype_input_axis, + keepdim, frontend_method_data, init_flags, method_flags, @@ -3244,7 +3511,7 @@ def test_torch_tensor_view_as( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3253,7 +3520,8 @@ def test_torch_tensor_view_as( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": axis, + "keepdim": keepdim, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -3263,24 +3531,28 @@ def test_torch_tensor_view_as( ) -# unsqueeze +# argsort @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="unsqueeze", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + method_name="argsort", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + force_int_axis=True, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + min_value=1, + max_value=5, + valid_axis=True, + allow_neg_axes=True, ), + descending=st.booleans(), ) -def test_torch_tensor_unsqueeze( - dtype_value, - dim, +def test_torch_tensor_argsort( + dtype_input_axis, + descending, frontend_method_data, init_flags, method_flags, @@ -3288,7 +3560,7 @@ def test_torch_tensor_unsqueeze( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3297,7 +3569,8 @@ def test_torch_tensor_unsqueeze( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": dim, + "dim": axis, + "descending": descending, 
}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -3307,24 +3580,17 @@ def test_torch_tensor_unsqueeze( ) -# unsqueeze_ +# argwhere @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="unsqueeze_", - dtype_value=helpers.dtype_and_values( + method_name="argwhere", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, ), ) -def test_torch_tensor_unsqueeze_( - dtype_value, - dim, +def test_torch_tensor_argwhere( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -3332,7 +3598,7 @@ def test_torch_tensor_unsqueeze_( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3340,9 +3606,7 @@ def test_torch_tensor_unsqueeze_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dim": dim, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3351,64 +3615,52 @@ def test_torch_tensor_unsqueeze_( ) -# ravel @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="ravel", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), + method_name="as_strided", + dtype_x_and_other=_as_strided_helper(), ) -def test_torch_tensor_ravel( - dtype_value, +def test_torch_tensor_as_strided( + dtype_x_and_other, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x, size, stride, offset = dtype_x_and_other helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "size": size, + "stride": stride, + "storage_offset": offset, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# split +# asin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="split", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - split_size=_get_splits(allow_none=False, min_num_dims=1, allow_array_indices=False), - dim=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", + method_name="asin", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_split( - dtype_value, - split_size, - dim, +def test_torch_tensor_asin( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -3416,7 +3668,7 @@ def test_torch_tensor_split( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3424,10 +3676,7 @@ def test_torch_tensor_split( "data": x[0], }, 
method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "split_size": split_size, - "dim": dim, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3436,31 +3685,19 @@ def test_torch_tensor_split( ) -# tensor_split +# asin_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tensor_split", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=1, allow_none=False, allow_array_indices=False - ), - dim=st.shared( - helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_int=True, - ), - key="target_axis", + method_name="asin_", + dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, + available_dtypes=helpers.get_dtypes("float"), ), - method_num_positional_args=st.just(1), ) -def test_torch_tensor_tensor_split( - dtype_value, - indices_or_sections, - dim, +def test_torch_tensor_asin_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -3468,7 +3705,7 @@ def test_torch_tensor_tensor_split( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3476,10 +3713,7 @@ def test_torch_tensor_tensor_split( "data": x[0], }, method_input_dtypes=[], - method_all_as_kwargs_np={ - "indices_or_sections": indices_or_sections, - "dim": dim, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3488,26 +3722,18 @@ def test_torch_tensor_tensor_split( ) -# vsplit +# asinh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="vsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=2, - axis=0, - allow_none=False, - allow_array_indices=False, - is_mod_split=True, + method_name="asinh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_vsplit( - dtype_value, - indices_or_sections, +def test_torch_tensor_asinh( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -3515,43 +3741,37 @@ def test_torch_tensor_vsplit( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={"indices_or_sections": indices_or_sections}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + rtol_=1e-2, + atol_=1e-2, on_device=on_device, ) -# hsplit +# asinh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="hsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=1, - axis=1, - allow_none=False, - allow_array_indices=False, - is_mod_split=True, + 
method_name="asinh_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_hsplit( - dtype_value, - indices_or_sections, +def test_torch_tensor_asinh_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -3559,43 +3779,37 @@ def test_torch_tensor_hsplit( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={"indices_or_sections": indices_or_sections}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + rtol_=1e-2, + atol_=1e-2, on_device=on_device, ) -# dsplit +# atan @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="dsplit", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), - ), - indices_or_sections=_get_splits( - min_num_dims=3, - axis=2, - allow_none=False, - allow_array_indices=False, - is_mod_split=True, + method_name="atan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_dsplit( - dtype_value, - indices_or_sections, +def test_torch_tensor_atan( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -3603,15 +3817,15 @@ def test_torch_tensor_dsplit( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={"indices_or_sections": indices_or_sections}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3620,16 +3834,17 @@ def test_torch_tensor_dsplit( ) -# detach +# atan2 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="detach", + method_name="atan2", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), ) -def test_torch_tensor_detach( +def test_torch_tensor_atan2( dtype_and_x, frontend_method_data, init_flags, @@ -3646,7 +3861,9 @@ def test_torch_tensor_detach( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3655,16 +3872,17 @@ def test_torch_tensor_detach( ) -# detach_ +# atan2_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="detach_", + method_name="atan2_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, ), ) -def test_torch_tensor_detach_( +def test_torch_tensor_atan2_( dtype_and_x, frontend_method_data, init_flags, @@ -3681,7 +3899,9 @@ def test_torch_tensor_detach_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, 
frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3690,16 +3910,17 @@ def test_torch_tensor_detach_( ) -# dim +# atan_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="dim", + method_name="atan_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_dim( +def test_torch_tensor_atan_( dtype_and_x, frontend_method_data, init_flags, @@ -3725,16 +3946,18 @@ def test_torch_tensor_dim( ) -# ndimension +# atanh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="ndimension", + method_name="atanh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + min_value=-1.0, + max_value=1.0, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_ndimension( +def test_torch_tensor_atanh( dtype_and_x, frontend_method_data, init_flags, @@ -3760,58 +3983,18 @@ def test_torch_tensor_ndimension( ) -@st.composite -def _fill_value_and_size( - draw, - *, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, -): - if isinstance(min_dim_size, st._internal.SearchStrategy): - min_dim_size = draw(min_dim_size) - if isinstance(max_dim_size, st._internal.SearchStrategy): - max_dim_size = draw(max_dim_size) - - available_dtypes = draw(helpers.get_dtypes("numeric")) - dtype = draw( - helpers.array_dtypes( - num_arrays=1, - available_dtypes=available_dtypes, - ) - ) - array = draw( - helpers.array_values( - dtype=dtype[0], - shape=(1,), - ) - ) - dtype.append("int32") - size = draw( - st.shared( - helpers.get_shape( - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ), - key="shape", - ) - ) - fill_value = draw(helpers.ints()) if "int" in dtype[0] else draw(helpers.floats()) - - return dtype, [array, size, fill_value] - - -# new_full +# atanh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="new_full", - dtype_and_x=_fill_value_and_size(max_num_dims=3), + method_name="atanh_", + dtype_and_x=helpers.dtype_and_values( + min_value=-1.0, + max_value=1.0, + available_dtypes=helpers.get_dtypes("float"), + ), ) -def test_torch_tensor_new_full( +def test_torch_tensor_atanh_( dtype_and_x, frontend_method_data, init_flags, @@ -3822,16 +4005,13 @@ def test_torch_tensor_new_full( ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[input_dtype[1]], - method_all_as_kwargs_np={ - "size": x[1], - "fill_value": x[2], - }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -3840,126 +4020,172 @@ def test_torch_tensor_new_full( ) -# new_empty (not actually intuitive for testing) +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", prune_function=False), + num_arrays=3, + min_value=-1e3, + max_value=1e3, + ).filter(lambda x: all(dt == "float32" for dt in x[0])), +) +def test_torch_tensor_backward( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + if ivy.current_backend_str() == "numpy": + ivy.warnings.warn("Gradient calculation unavailable for numpy backend") + return + if 
ivy.current_backend_str() == "paddle": + ivy.warnings.warn("torch.Tensor.backward() unavailable for paddle backend") + return + _, values = dtype_x + x = Tensor(values[0], requires_grad=True) + y = Tensor(values[1], requires_grad=True) + z = Tensor(values[2], requires_grad=True) + a = x + y.pow(2) + b = z * a + c = b.sum() + c.backward() + x_torch = torch.tensor(values[0], requires_grad=True, dtype=torch.float32) + y_torch = torch.tensor(values[1], requires_grad=True, dtype=torch.float32) + z_torch = torch.tensor(values[2], requires_grad=True, dtype=torch.float32) + a_torch = x_torch + y_torch.pow(2) + b_torch = z_torch * a_torch + c_torch = b_torch.sum() + c_torch.backward() + helpers.assertions.value_test( + ret_np_flat=helpers.flatten_and_to_np( + ret=x._grads.ivy_array, backend=backend_fw + ), + ret_np_from_gt_flat=helpers.flatten_and_to_np( + ret=ivy.to_ivy(x_torch.grad.numpy()), backend=backend_fw + ), + rtol=1e-3, + atol=1e-3, + backend="torch", + ) + helpers.assertions.value_test( + ret_np_flat=helpers.flatten_and_to_np( + ret=y._grads.ivy_array, backend=backend_fw + ), + ret_np_from_gt_flat=helpers.flatten_and_to_np( + ret=ivy.to_ivy(y_torch.grad.numpy()), backend=backend_fw + ), + rtol=1e-3, + atol=1e-3, + backend="torch", + ) + helpers.assertions.value_test( + ret_np_flat=helpers.flatten_and_to_np( + ret=z._grads.ivy_array, backend=backend_fw + ), + ret_np_from_gt_flat=helpers.flatten_and_to_np( + ret=ivy.to_ivy(z_torch.grad.numpy()), backend=backend_fw + ), + rtol=1e-3, + atol=1e-3, + backend="torch", + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="new_empty", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + method_name="baddbmm", + dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), - size=helpers.get_shape( - min_num_dims=1, - max_num_dims=3, + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, ), ) -def test_torch_tensor_new_empty( - dtype_and_x, - size, +def test_torch_tensor_baddbmm( + dtype_and_matrices, + beta, + alpha, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, batch1, batch2 = dtype_and_matrices helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x, - }, - method_input_dtypes=[ivy.int32], + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "size": size, + "batch1": batch1, + "batch2": batch2, + "beta": beta, + "alpha": alpha, }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -@st.composite -def _expand_helper(draw): - num_dims = draw(st.integers(min_value=1, max_value=10)) - shape = draw( - helpers.get_shape(min_num_dims=num_dims, max_num_dims=num_dims).filter( - lambda x: any(i == 1 for i in x) - ) - ) - new_shape = draw( - helpers.get_shape(min_num_dims=num_dims, max_num_dims=num_dims).filter( - lambda x: all(x[i] == v if v != 1 else True for i, v in enumerate(shape)) - ) - ) - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape, - ) - ) 
- return dtype, x, new_shape - - +# bernoulli @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="expand", - dtype_x_shape=_expand_helper(), - unpack_shape=st.booleans(), + method_name="bernoulli", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), + test_with_out=st.just(True), ) -def test_torch_tensor_expand( - dtype_x_shape, - unpack_shape, +def test_torch_tensor_bernoulli( + dtype_and_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, - on_device, backend_fw, ): - input_dtype, x, shape = dtype_x_shape - if unpack_shape: - method_flags.num_positional_args = len(shape) + 1 - size = {} - i = 0 - for x_ in shape: - size["x{}".format(i)] = x_ - i += 1 - else: - size = { - "size": shape, - } + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "input": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np=size, + method_all_as_kwargs_np={"generator": x[1], "out": x[2]}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - on_device=on_device, ) -# expand_as +# bitwise_and @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="expand_as", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 + method_name="bitwise_and", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, ), ) -def test_torch_tensor_expand_as( - dtype_x, +def test_torch_tensor_bitwise_and( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -3967,7 +4193,7 @@ def test_torch_tensor_expand_as( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -3986,46 +4212,18 @@ def test_torch_tensor_expand_as( ) -@st.composite -def _unfold_args(draw): - values_dtype, values, axis, shape = draw( - helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - force_int_axis=True, - shape=draw( - helpers.get_shape( - allow_none=False, - min_num_dims=1, - min_dim_size=1, - ) - ), - ret_shape=True, - ) - ) - size = draw( - st.integers( - min_value=1, - max_value=max(shape[axis] - 1, 1), - ) - ) - step = draw( - st.integers( - min_value=1, - max_value=size, - ) - ) - return values_dtype, values, axis, size, step - - -# unfold +# bitwise_and_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="unfold", - dtype_values_args=_unfold_args(), + method_name="bitwise_and_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + ), ) -def test_torch_tensor_unfold( - dtype_values_args, +def test_torch_tensor_bitwise_and_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -4033,19 +4231,16 @@ def test_torch_tensor_unfold( on_device, backend_fw, ): - input_dtype, x, axis, size, step = dtype_values_args - print(axis, size, step) + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dimension": axis, - "size": size, - "step": step, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, 
@@ -4055,17 +4250,17 @@ def test_torch_tensor_unfold( ) -# __mod__ +# bitwise_left_shift @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__mod__", + method_name="bitwise_left_shift", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, ), ) -def test_torch___mod__( +def test_torch_tensor_bitwise_left_shift( dtype_and_x, frontend_method_data, init_flags, @@ -4093,16 +4288,17 @@ def test_torch___mod__( ) -# long +# bitwise_not @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="long", + method_name="bitwise_not", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, ), ) -def test_torch_tensor_long( +def test_torch_tensor_bitwise_not( dtype_and_x, frontend_method_data, init_flags, @@ -4119,118 +4315,62 @@ def test_torch_tensor_long( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + method_all_as_kwargs_np={}, frontend=frontend, on_device=on_device, ) -# max +# bitwise_not_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="max", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="bitwise_not_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, ), ) -def test_torch_tensor_max( - dtype_x, +def test_torch_tensor_bitwise_not_( + dtype_and_x, frontend_method_data, init_flags, method_flags, frontend, - on_device, - backend_fw, -): - input_dtype, x = dtype_x - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, - ) - - -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_is_quantized( - dtype_x, - backend_fw, -): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal( - x.is_quantized, "q" in ivy.dtype(ivy.array(data[0])), as_array=False - ) - ivy.previous_backend() - - -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_is_cuda( - dtype_x, - backend_fw, -): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal( - x.is_cuda, "gpu" in ivy.dev(ivy.array(data[0])), as_array=False - ) - ivy.previous_backend() - - -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", prune_function=False) - ).filter(lambda x: "bfloat16" not in x[0]), -) -def test_torch_tensor_is_meta( - dtype_x, - backend_fw, -): - ivy.set_backend(backend_fw) - _, data = dtype_x - x = Tensor(data[0]) - x.ivy_array = data[0] - ivy.utils.assertions.check_equal( - x.is_meta, "meta" in ivy.dev(ivy.array(data[0])), as_array=False + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( 
+ init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + method_all_as_kwargs_np={}, + frontend=frontend, + on_device=on_device, ) - ivy.previous_backend() -# logical_and +# bitwise_or @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="logical_and", + method_name="bitwise_or", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, ), ) -def test_torch_tensor_logical_and( +def test_torch_tensor_bitwise_or( dtype_and_x, frontend_method_data, init_flags, @@ -4258,16 +4398,17 @@ def test_torch_tensor_logical_and( ) -# logical_not +# bitwise_or_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="logical_not", + method_name="bitwise_or_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), num_arrays=1 + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, ), ) -def test_torch_tensor_logical_not( +def test_torch_tensor_bitwise_or_( dtype_and_x, frontend_method_data, init_flags, @@ -4284,7 +4425,9 @@ def test_torch_tensor_logical_not( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4293,17 +4436,18 @@ def test_torch_tensor_logical_not( ) -# logical_or +# bitwise right shift @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="logical_or", + method_name="bitwise_right_shift", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), num_arrays=2, + shared_dtype=True, ), ) -def test_torch_tensor_logical_or( +def test_torch_tensor_bitwise_right_shift( dtype_and_x, frontend_method_data, init_flags, @@ -4313,6 +4457,11 @@ def test_torch_tensor_logical_or( backend_fw, ): input_dtype, x = dtype_and_x + # negative shifts will throw an exception + # shifts >= dtype witdth produce backend-defined behavior + x[1] = np.asarray( + np.clip(x[1], 0, np.iinfo(input_dtype[1]).bits - 1), dtype=input_dtype[1] + ) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -4331,17 +4480,17 @@ def test_torch_tensor_logical_or( ) -# bitwise_not +# bitwise_xor @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_not", + method_name="bitwise_xor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, ), ) -def test_torch_tensor_bitwise_not( +def test_torch_tensor_bitwise_xor( dtype_and_x, frontend_method_data, init_flags, @@ -4358,26 +4507,28 @@ def test_torch_tensor_bitwise_not( "data": x[0], }, method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - method_all_as_kwargs_np={}, frontend=frontend, on_device=on_device, ) -# bitwise_and +# bitwise_xor_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_and", + method_name="bitwise_xor_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=st.one_of(st.just(("bool",)), 
helpers.get_dtypes("integer")), num_arrays=2, ), ) -def test_torch_tensor_bitwise_and( +def test_torch_tensor_bitwise_xor_( dtype_and_x, frontend_method_data, init_flags, @@ -4405,17 +4556,16 @@ def test_torch_tensor_bitwise_and( ) -# bitwise_or +# bool @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_or", + method_name="bool", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("integer"), ), ) -def test_torch_tensor_bitwise_or( +def test_torch_tensor_bool( dtype_and_x, frontend_method_data, init_flags, @@ -4432,9 +4582,7 @@ def test_torch_tensor_bitwise_or( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4443,17 +4591,16 @@ def test_torch_tensor_bitwise_or( ) -# bitwise_or_ +# byte @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_or_", + method_name="byte", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, ), ) -def test_torch_tensor_bitwise_or_( +def test_torch_tensor_byte( dtype_and_x, frontend_method_data, init_flags, @@ -4470,9 +4617,7 @@ def test_torch_tensor_bitwise_or_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4481,17 +4626,16 @@ def test_torch_tensor_bitwise_or_( ) -# bitwise_left_shift +# ceil @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_left_shift", + method_name="ceil", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_bitwise_left_shift( +def test_torch_tensor_ceil( dtype_and_x, frontend_method_data, init_flags, @@ -4508,9 +4652,7 @@ def test_torch_tensor_bitwise_left_shift( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4519,17 +4661,16 @@ def test_torch_tensor_bitwise_left_shift( ) -# add_ +# ceil_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="add_", + method_name="ceil_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_add_( +def test_torch_tensor_ceil_( dtype_and_x, frontend_method_data, init_flags, @@ -4540,15 +4681,13 @@ def test_torch_tensor_add_( ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4557,73 +4696,88 @@ def test_torch_tensor_add_( ) -# subtract_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="subtract_", - dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - ), + method_name="cholesky", + dtype_and_x=_get_dtype_and_matrix(square=True), + upper=st.booleans(), ) -def test_torch_tensor_subtract_( +def test_torch_tensor_cholesky( dtype_and_x, + upper, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): input_dtype, x = dtype_and_x + x = x[0] + # make symmetric positive-definite + x = np.matmul(x.swapaxes(-1, -2), x) + np.identity(x.shape[-1]) * 1e-3 + helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "upper": upper, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, + rtol_=1e-2, ) -# arccos_ +# chunk +@pytest.mark.skip("Testing takes a lot of time") @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arccos_", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, + method_name="chunk", + dtype_x_dim=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + min_value=-1e04, + max_value=1e04, + force_int_axis=True, + valid_axis=True, + ), + chunks=st.integers( + min_value=1, + max_value=5, ), ) -def test_torch_tensor_arccos_( - dtype_and_x, +def test_torch_tensor_chunk( + dtype_x_dim, + chunks, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, dim = dtype_x_dim helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "chunks": chunks, + "dim": dim, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4632,35 +4786,31 @@ def test_torch_tensor_arccos_( ) -# arccos +# clamp @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arccos", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="clamp", + dtype_and_x_min_max=_get_clamp_inputs(), ) -def test_torch_tensor_arccos( - dtype_and_x, +def test_torch_tensor_clamp( + dtype_and_x_min_max, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, min, max = dtype_and_x_min_max helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"min": min, "max": max}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4669,35 +4819,31 @@ def test_torch_tensor_arccos( ) -# acos_ +# clamp_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="acos_", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="clamp_", + dtype_and_x_min_max=_get_clamp_inputs(), ) -def test_torch_tensor_acos_( - dtype_and_x, +def 
test_torch_tensor_clamp_( + dtype_and_x_min_max, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, min, max = dtype_and_x_min_max helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"min": min, "max": max}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4706,35 +4852,31 @@ def test_torch_tensor_acos_( ) -# copy_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="copy_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - ), + method_name="clamp_min", + input_and_ranges=_get_clip_min_inputs(), ) -def test_torch_tensor_copy_( - dtype_and_x, +def test_torch_tensor_clamp_min( + input_and_ranges, frontend_method_data, init_flags, - method_flags, + backend_fw, frontend, on_device, - backend_fw, + method_flags, ): - input_dtype, x = dtype_and_x + x_dtype, x, min = input_and_ranges helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=x_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=x_dtype, method_all_as_kwargs_np={ - "other": x[1], + "min": min, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -4744,35 +4886,31 @@ def test_torch_tensor_copy_( ) -# asin_ +# clip @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="asin_", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="clip", + input_and_ranges=_get_clamp_inputs(), ) -def test_torch_tensor_asin_( - dtype_and_x, +def test_torch_tensor_clip( + input_and_ranges, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, min, max = input_and_ranges helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"min": min, "max": max}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4781,35 +4919,31 @@ def test_torch_tensor_asin_( ) -# arcsin_ +# clip_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arcsin_", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="clip_", + input_and_ranges=_get_clamp_inputs(), ) -def test_torch_tensor_arcsin_( - dtype_and_x, +def test_torch_tensor_clip_( + input_and_ranges, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, min, max = input_and_ranges helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"min": min, "max": max}, frontend_method_data=frontend_method_data, init_flags=init_flags, 
method_flags=method_flags, @@ -4818,17 +4952,17 @@ def test_torch_tensor_arcsin_( ) -# atan_ +# clone @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="atan_", + method_name="clone", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, ), ) -def test_torch_tensor_atan_( +def test_torch_tensor_clone( dtype_and_x, frontend_method_data, init_flags, @@ -4844,7 +4978,7 @@ def test_torch_tensor_atan_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -4854,17 +4988,15 @@ def test_torch_tensor_atan_( ) -# tan_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tan_", + method_name="conj", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + available_dtypes=helpers.get_dtypes("float_and_complex") ), ) -def test_torch_tensor_tan_( +def test_torch_tensor_conj( dtype_and_x, frontend_method_data, init_flags, @@ -4880,7 +5012,7 @@ def test_torch_tensor_tan_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -4890,18 +5022,17 @@ def test_torch_tensor_tan_( ) -# atanh +# contiguous @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="atanh", + method_name="contiguous", dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_atanh( +def test_torch_tensor_contiguous( dtype_and_x, frontend_method_data, init_flags, @@ -4917,7 +5048,7 @@ def test_torch_tensor_atanh( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -4927,18 +5058,17 @@ def test_torch_tensor_atanh( ) -# atanh_ +# copy_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="atanh_", + method_name="copy_", dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, ), ) -def test_torch_tensor_atanh_( +def test_torch_tensor_copy_( dtype_and_x, frontend_method_data, init_flags, @@ -4954,8 +5084,10 @@ def test_torch_tensor_atanh_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -4964,18 +5096,18 @@ def test_torch_tensor_atanh_( ) -# arctanh +# copysign @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arctanh", + method_name="copysign", dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + num_arrays=2, ), ) -def test_torch_tensor_arctanh( +def test_torch_tensor_copysign( dtype_and_x, frontend_method_data, init_flags, @@ -4991,8 +5123,10 @@ def test_torch_tensor_arctanh( init_all_as_kwargs_np={ "data": x[0], }, - 
method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5001,18 +5135,17 @@ def test_torch_tensor_arctanh( ) -# arctanh_ +# cos @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arctanh_", + method_name="cos", dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_arctanh_( +def test_torch_tensor_cos( dtype_and_x, frontend_method_data, init_flags, @@ -5028,7 +5161,7 @@ def test_torch_tensor_arctanh_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5038,20 +5171,17 @@ def test_torch_tensor_arctanh_( ) -# pow +# cos_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="pow", + method_name="cos_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, + available_dtypes=helpers.get_dtypes("float"), allow_inf=False, ), ) -def test_torch_tensor_pow( +def test_torch_tensor_cos_( dtype_and_x, frontend_method_data, init_flags, @@ -5061,19 +5191,14 @@ def test_torch_tensor_pow( backend_fw, ): input_dtype, x = dtype_and_x - dtype = input_dtype[0] - if "int" in dtype: - x[1] = ivy.abs(x[1]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": list(x[0]) if type(x[0]) == int else x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "exponent": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5082,17 +5207,17 @@ def test_torch_tensor_pow( ) -# pow_ +# cosh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="pow_", + method_name="cosh", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_pow_( +def test_torch_tensor_cosh( dtype_and_x, frontend_method_data, init_flags, @@ -5102,9 +5227,6 @@ def test_torch_tensor_pow_( backend_fw, ): input_dtype, x = dtype_and_x - dtype = input_dtype[0] - if "int" in dtype: - x[1] = ivy.abs(x[1]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -5112,9 +5234,7 @@ def test_torch_tensor_pow_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "exponent": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5123,17 +5243,17 @@ def test_torch_tensor_pow_( ) -# __pow__ +# cosh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__pow__", + method_name="cosh_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch___pow__( +def test_torch_tensor_cosh_( dtype_and_x, frontend_method_data, init_flags, @@ -5143,9 +5263,6 @@ def test_torch___pow__( backend_fw, ): input_dtype, x = dtype_and_x - 
dtype = input_dtype[0] - if "int" in dtype: - x[1] = ivy.abs(x[1]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -5153,30 +5270,35 @@ def test_torch___pow__( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "exponent": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, + rtol_=1e-2, + atol_=1e-2, ) -# __rpow__ +# count_nonzero @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__rpow__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=1, + method_name="count_nonzero", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, ), ) -def test_torch___rpow__( - dtype_and_x, +def test_torch_tensor_count_nonzero( + dtype_value, + dim, frontend_method_data, init_flags, method_flags, @@ -5184,10 +5306,7 @@ def test_torch___rpow__( on_device, backend_fw, ): - input_dtype, x = dtype_and_x - dtype = input_dtype[0] - if "int" in dtype: - x[0] = ivy.abs(x[0]) + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -5195,9 +5314,7 @@ def test_torch___rpow__( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={"dim": dim}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5206,19 +5323,26 @@ def test_torch___rpow__( ) -# arccosh_ +# cross @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arccosh_", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), + method_name="cross", + dtype_input_other_dim=dtype_value1_value2_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=10, + min_dim_size=3, + max_dim_size=3, + min_value=-1e10, + max_value=1e10, + abs_smallest_val=0.01, + large_abs_safety_factor=2, + safety_factor_scale="log", ), ) -def test_torch_tensor_arccosh_( - dtype_and_x, +def test_torch_tensor_cross( + dtype_input_other_dim, frontend_method_data, init_flags, method_flags, @@ -5226,45 +5350,48 @@ def test_torch_tensor_arccosh_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, input, other, dim = dtype_input_other_dim helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": input, + }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={ + "other": other, + "dim": dim, }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - ) - - -# argmax -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="argmax", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - min_value=1, - max_value=5, - valid_axis=True, - 
allow_neg_axes=True, + rtol_=1e-2, + atol_=1e-2, + ) + + +# cumprod +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="cumprod", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), - keepdim=st.booleans(), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), + dtypes=_dtypes(), ) -def test_torch_tensor_argmax( - dtype_input_axis, - keepdim, +def test_torch_tensor_cumprod( + dtype_value, + dim, + dtypes, frontend_method_data, init_flags, method_flags, @@ -5272,17 +5399,17 @@ def test_torch_tensor_argmax( on_device, backend_fw, ): - input_dtype, x, axis = dtype_input_axis + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=dtypes, method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdim, + "dim": dim, + "dtype": dtypes[0], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5292,28 +5419,26 @@ def test_torch_tensor_argmax( ) -# argmin +# cumsum @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="argmin", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - min_value=1, - max_value=5, - valid_axis=True, - allow_neg_axes=True, + method_name="cumsum", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), - keepdim=st.booleans(), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), + dtypes=_dtypes(), ) -def test_torch_tensor_argmin( - dtype_input_axis, - keepdim, +def test_torch_tensor_cumsum( + dtype_value, + dim, + dtypes, frontend_method_data, init_flags, method_flags, @@ -5321,17 +5446,17 @@ def test_torch_tensor_argmin( on_device, backend_fw, ): - input_dtype, x, axis = dtype_input_axis + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=dtypes, method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdim, + "dim": dim, + "dtype": dtypes[0], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5341,28 +5466,24 @@ def test_torch_tensor_argmin( ) -# argsort +# cumsum_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="argsort", - dtype_input_axis=helpers.dtype_values_axis( + method_name="cumsum_", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - force_int_axis=True, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - min_value=1, - max_value=5, - valid_axis=True, - allow_neg_axes=True, + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, ), - descending=st.booleans(), ) -def test_torch_tensor_argsort( - dtype_input_axis, - descending, +def test_torch_tensor_cumsum_( + dtype_value, + dim, frontend_method_data, init_flags, method_flags, @@ 
-5370,7 +5491,7 @@ def test_torch_tensor_argsort( on_device, backend_fw, ): - input_dtype, x, axis = dtype_input_axis + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -5379,8 +5500,8 @@ def test_torch_tensor_argsort( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, - "descending": descending, + "dim": dim, + "dtype": input_dtype[0], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5390,18 +5511,14 @@ def test_torch_tensor_argsort( ) -# arccosh +# det @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="arccosh", - dtype_and_x=helpers.dtype_and_values( - min_value=-1.0, - max_value=1.0, - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="det", + dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), ) -def test_torch_tensor_arccosh( +def test_torch_tensor_det( dtype_and_x, frontend_method_data, init_flags, @@ -5415,9 +5532,9 @@ def test_torch_tensor_arccosh( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, - method_input_dtypes=[], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5427,16 +5544,16 @@ def test_torch_tensor_arccosh( ) -# ceil +# detach @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="ceil", + method_name="detach", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch_tensor_ceil( +def test_torch_tensor_detach( dtype_and_x, frontend_method_data, init_flags, @@ -5462,16 +5579,16 @@ def test_torch_tensor_ceil( ) -# argwhere +# detach_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="argwhere", + method_name="detach_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch_tensor_argwhere( +def test_torch_tensor_detach_( dtype_and_x, frontend_method_data, init_flags, @@ -5497,23 +5614,38 @@ def test_torch_tensor_argwhere( ) -# size +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_device( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal( + x.device, ivy.dev(ivy.array(data[0])), as_array=False + ) + ivy.previous_backend() + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="size", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - force_int=True, + method_name="diag", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"), ), + diagonal=st.integers(min_value=-100, max_value=100), ) -def test_torch_tensor_size( - dtype_and_x, - dim, +def test_torch_tensor_diag( + dtype_and_values, + diagonal, frontend_method_data, init_flags, method_flags, @@ -5521,16 +5653,16 @@ def test_torch_tensor_size( on_device, backend_fw, ): - input_dtype, x 
= dtype_and_x + input_dtype, values = dtype_and_values helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": values[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": dim, + "diagonal": diagonal, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5540,60 +5672,66 @@ def test_torch_tensor_size( ) -# min @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="min", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="diagonal", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), + ), + dims_and_offset=dims_and_offset( + shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape") ), ) -def test_torch_tensor_min( - dtype_x, +def test_torch_tensor_diagonal( + dtype_and_values, + dims_and_offset, frontend, frontend_method_data, + backend_fw, init_flags, method_flags, on_device, - backend_fw, ): - input_dtype, x = dtype_x + input_dtype, value = dtype_and_values + dim1, dim2, offset = dims_and_offset + input = value[0] + num_dims = len(np.shape(input)) + assume(dim1 != dim2) + if dim1 < 0: + assume(dim1 + num_dims != dim2) + if dim2 < 0: + assume(dim1 != dim2 + num_dims) helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], + init_input_dtypes=[input_dtype[0]], + init_all_as_kwargs_np={"data": input}, + method_input_dtypes=[input_dtype[0]], + method_all_as_kwargs_np={ + "offset": offset, + "dim1": dim1, + "dim2": dim2, }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + frontend=frontend, frontend_method_data=frontend_method_data, + backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -@st.composite -def _get_dtype_and_multiplicative_matrices(draw): - return draw( - st.one_of( - _get_dtype_input_and_matrices(), - _get_dtype_and_3dbatch_matrices(), - ) - ) - - -# matmul +# dim @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="matmul", - dtype_tensor1_tensor2=_get_dtype_and_multiplicative_matrices(), + method_name="dim", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), ) -def test_torch_tensor_matmul( - dtype_tensor1_tensor2, +def test_torch_tensor_dim( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -5601,15 +5739,15 @@ def test_torch_tensor_matmul( on_device, backend_fw, ): - dtype, tensor1, tensor2 = dtype_tensor1_tensor2 + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": tensor1, + "data": x[0], }, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": tensor2}, + method_input_dtypes=[], + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5618,64 +5756,42 @@ def test_torch_tensor_matmul( ) -@st.composite -def _array_idxes_n_dtype(draw, **kwargs): - num_dims = draw(helpers.ints(min_value=1, max_value=4)) - dtype, x = draw( - helpers.dtype_and_values( - **kwargs, min_num_dims=num_dims, max_num_dims=num_dims, shared_dtype=True - ) - ) - idxes = draw( - st.lists( - helpers.ints(min_value=0, 
max_value=num_dims - 1), - min_size=num_dims, - max_size=num_dims, - unique=True, - ) - ) - return x, idxes, dtype - - -# permute +# div @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="permute", - dtype_values_axis=_array_idxes_n_dtype( - available_dtypes=helpers.get_dtypes("float"), + method_name="div", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), + rounding_mode=st.sampled_from(["floor", "trunc"]) | st.none(), ) -def test_torch_tensor_permute( - dtype_values_axis, +def test_torch_tensor_div( + dtype_and_x, + rounding_mode, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - x, idxes, dtype = dtype_values_axis - unpack_dims = True - if unpack_dims: - method_flags.num_positional_args = len(idxes) + 1 - dims = {} - i = 0 - for x_ in idxes: - dims["x{}".format(i)] = x_ - i += 1 - else: - dims = { - "dims": tuple(idxes), - } + input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) + helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + "rounding_mode": rounding_mode, }, - method_input_dtypes=dtype, - method_all_as_kwargs_np=dims, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5684,21 +5800,23 @@ def test_torch_tensor_permute( ) -# mean +# div_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="mean", - dtype_and_x=_statistical_dtype_values( - function="mean", - min_value=-1e04, - max_value=1e04, + method_name="div_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), - keepdims=st.booleans(), + rounding_mode=st.sampled_from(["floor", "trunc"]) | st.none(), ) -def test_torch_tensor_mean( +def test_torch_tensor_div_( dtype_and_x, - keepdims, + rounding_mode, frontend, frontend_method_data, init_flags, @@ -5706,17 +5824,17 @@ def test_torch_tensor_mean( on_device, backend_fw, ): - input_dtype, x, axis = dtype_and_x + input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdims, + "other": x[1], + "rounding_mode": rounding_mode, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5726,19 +5844,21 @@ def test_torch_tensor_mean( ) -# nanmean +# divide @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="nanmean", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="divide", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, min_value=-1e04, max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_nanmean( - dtype_x, +def test_torch_tensor_divide( + dtype_and_x, frontend, frontend_method_data, init_flags, @@ -5746,7 +5866,7 @@ def 
test_torch_tensor_nanmean( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -5754,39 +5874,38 @@ def test_torch_tensor_nanmean( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, ) -# median @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="median", - dtype_input_axis=helpers.dtype_values_axis( + method_name="dot", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, + num_arrays=2, + shape=(1,), ), - keepdim=st.booleans(), ) -def test_torch_tensor_median( - dtype_input_axis, - keepdim, - frontend, +def test_torch_tensor_dot( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, axis = dtype_input_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -5795,8 +5914,7 @@ def test_torch_tensor_median( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdim, + "tensor": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -5806,78 +5924,60 @@ def test_torch_tensor_median( ) -# transpose @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="transpose", - dtype_value=helpers.dtype_and_values( + method_name="double", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim0=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, - ), - dim1=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, ), ) -def test_torch_tensor_transpose( - dtype_value, - dim0, - dim1, +def test_torch_tensor_double( + dtype_and_x, frontend_method_data, init_flags, method_flags, frontend, - on_device, backend_fw, + on_device, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"dim0": dim0, "dim1": dim1}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + backend_to_test=backend_fw, on_device=on_device, ) -# transpose_ +# dsplit @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="transpose_", + method_name="dsplit", dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim0=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), ), - dim1=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + indices_or_sections=_get_splits( + min_num_dims=3, + axis=2, + allow_none=False, + 
allow_array_indices=False, + is_mod_split=True, ), ) -def test_torch_tensor_transpose_( +def test_torch_tensor_dsplit( dtype_value, - dim0, - dim1, + indices_or_sections, frontend_method_data, init_flags, method_flags, @@ -5892,11 +5992,8 @@ def test_torch_tensor_transpose_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dim0": dim0, - "dim1": dim1, - }, + method_input_dtypes=[], + method_all_as_kwargs_np={"indices_or_sections": indices_or_sections}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5905,17 +6002,33 @@ def test_torch_tensor_transpose_( ) -# t +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_dtype(dtype_x, backend_fw): + ivy.set_backend(backend_fw) + dtype, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal(x.dtype, dtype[0], as_array=False) + ivy.previous_backend() + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="t", + method_name="eq_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=helpers.get_shape(min_num_dims=2, max_num_dims=2), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_t( +def test_torch_tensor_eq_( dtype_and_x, frontend_method_data, init_flags, @@ -5932,7 +6045,9 @@ def test_torch_tensor_t( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -5941,34 +6056,30 @@ def test_torch_tensor_t( ) -# flatten +# equal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="flatten", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), - ), - axes=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - min_size=2, - max_size=2, - unique=False, - force_tuple=True, + method_name="equal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + min_num_dims=1, + min_value=-1e04, + max_value=1e04, ), ) -def test_torch_tensor_flatten( - dtype_value, - axes, +def test_torch_tensor_equal( + dtype_and_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -5977,37 +6088,29 @@ def test_torch_tensor_flatten( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "start_dim": axes[0], - "end_dim": axes[1], + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-04, + rtol_=1e-04, on_device=on_device, ) -# cumsum +# erf @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cumsum", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), 
key="shape"), - allow_neg=True, - force_int=True, + method_name="erf", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), - dtypes=_dtypes(), ) -def test_torch_tensor_cumsum( - dtype_value, - dim, - dtypes, +def test_torch_tensor_erf( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6015,18 +6118,15 @@ def test_torch_tensor_cumsum( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=dtypes, - method_all_as_kwargs_np={ - "dim": dim, - "dtype": dtypes[0], - }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6035,24 +6135,17 @@ def test_torch_tensor_cumsum( ) -# cumsum_ +# erf_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cumsum_", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + method_name="erf_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_cumsum_( - dtype_value, - dim, +def test_torch_tensor_erf_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6060,7 +6153,7 @@ def test_torch_tensor_cumsum_( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6068,10 +6161,7 @@ def test_torch_tensor_cumsum_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dim": dim, - "dtype": input_dtype[0], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6080,26 +6170,17 @@ def test_torch_tensor_cumsum_( ) -# sort +# exp @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sort", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + method_name="exp", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), - descending=st.booleans(), ) -def test_torch_tensor_sort( - dtype_value, - dim, - descending, +def test_torch_tensor_exp( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6107,7 +6188,7 @@ def test_torch_tensor_sort( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6115,10 +6196,7 @@ def test_torch_tensor_sort( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dim": dim, - "descending": descending, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6127,17 +6205,17 @@ def test_torch_tensor_sort( ) -# sigmoid +# exp_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sigmoid", - 
dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="exp_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch_tensor_sigmoid( - dtype_x, +def test_torch_tensor_exp_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6145,7 +6223,7 @@ def test_torch_tensor_sigmoid( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6162,17 +6240,16 @@ def test_torch_tensor_sigmoid( ) -# sigmoid @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sigmoid_", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="expand", + dtype_x_shape=_expand_helper(), + unpack_shape=st.booleans(), ) -def test_torch_tensor_sigmoid_( - dtype_x, +def test_torch_tensor_expand( + dtype_x_shape, + unpack_shape, frontend_method_data, init_flags, method_flags, @@ -6180,7 +6257,18 @@ def test_torch_tensor_sigmoid_( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x, shape = dtype_x_shape + if unpack_shape: + method_flags.num_positional_args = len(shape) + 1 + size = {} + i = 0 + for x_ in shape: + size["x{}".format(i)] = x_ + i += 1 + else: + size = { + "size": shape, + } helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6188,7 +6276,7 @@ def test_torch_tensor_sigmoid_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np=size, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6197,23 +6285,17 @@ def test_torch_tensor_sigmoid_( ) -# softmax +# expand_as @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="softmax", - dtype_x_and_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_axes_size=1, - force_int_axis=True, - valid_axis=True, + method_name="expand_as", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 ), - dtype=helpers.get_dtypes("float", full=False), ) -def test_torch_tensor_softmax( - dtype_x_and_axis, - dtype, +def test_torch_tensor_expand_as( + dtype_x, frontend_method_data, init_flags, method_flags, @@ -6221,7 +6303,7 @@ def test_torch_tensor_softmax( on_device, backend_fw, ): - input_dtype, x, axis = dtype_x_and_axis + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6230,8 +6312,7 @@ def test_torch_tensor_softmax( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, - "dtype": dtype[0], + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -6241,36 +6322,17 @@ def test_torch_tensor_softmax( ) -@st.composite -def _repeat_helper(draw): - shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ) - ) - - input_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape, - ) - ) - - repeats = draw(st.lists(st.integers(min_value=1, max_value=5), min_size=len(shape))) - return input_dtype, x, repeats - - -# repeat +# expm1 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="repeat", - dtype_x_repeats=_repeat_helper(), - unpack_repeat=st.booleans(), + 
method_name="expm1", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), ) -def test_torch_tensor_repeat( - dtype_x_repeats, - unpack_repeat, +def test_torch_tensor_expm1( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6278,14 +6340,7 @@ def test_torch_tensor_repeat( on_device, backend_fw, ): - input_dtype, x, repeats = dtype_x_repeats - repeat = { - "repeats": repeats, - } - if unpack_repeat: - method_flags.num_positional_args = len(repeat["repeats"]) + 1 - for i, x_ in enumerate(repeat["repeats"]): - repeat["x{}".format(i)] = x_ + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6293,7 +6348,7 @@ def test_torch_tensor_repeat( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np=repeat, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6302,20 +6357,17 @@ def test_torch_tensor_repeat( ) -# unbind +# expm1_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="unbind", - dtype_value_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, + method_name="expm1_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_unbind( - dtype_value_axis, +def test_torch_tensor_expm1_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6323,17 +6375,15 @@ def test_torch_tensor_unbind( on_device, backend_fw, ): - input_dtypes, x, axis = dtype_value_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtypes, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtypes, - method_all_as_kwargs_np={ - "dim": axis, - }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6342,21 +6392,19 @@ def test_torch_tensor_unbind( ) -# __eq__ +# fill_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__eq__", + method_name="fill_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, ), + value=helpers.floats(min_value=1, max_value=10), ) -def test_torch___eq__( +def test_torch_tensor_fill_( dtype_and_x, + value, frontend_method_data, init_flags, method_flags, @@ -6373,7 +6421,7 @@ def test_torch___eq__( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "value": value, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -6383,18 +6431,18 @@ def test_torch___eq__( ) -# inverse +# fix @handle_frontend_method( class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="inverse", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - ).filter(lambda s: s[1][0].shape[-1] == s[1][0].shape[-2]), + init_tree="torch.tensor", + method_name="fix", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), ) -def test_torch_tensor_inverse( - dtype_and_x, +def test_torch_tensor_fix( + dtype_value, frontend_method_data, init_flags, 
method_flags, @@ -6402,7 +6450,7 @@ def test_torch_tensor_inverse( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6419,28 +6467,26 @@ def test_torch_tensor_inverse( ) -# neg +# fix_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="neg", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="fix_", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), ) -def test_torch_tensor_neg( - dtype_and_x, - frontend, +def test_torch_tensor_fix_( + dtype_value, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6457,28 +6503,34 @@ def test_torch_tensor_neg( ) -# neg_ +# flatten @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="neg_", - dtype_and_x=helpers.dtype_and_values( + method_name="flatten", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - min_value=-1e04, - max_value=1e04, - allow_inf=False, + shape=st.shared(helpers.get_shape(), key="shape"), + ), + axes=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + min_size=2, + max_size=2, + unique=False, + force_tuple=True, ), ) -def test_torch_tensor_neg_( - dtype_and_x, - frontend, +def test_torch_tensor_flatten( + dtype_value, + axes, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6486,7 +6538,10 @@ def test_torch_tensor_neg_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "start_dim": axes[0], + "end_dim": axes[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6495,20 +6550,17 @@ def test_torch_tensor_neg_( ) -# __neg__ +# flip @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__neg__", - dtype_and_x=helpers.dtype_and_values( + method_name="flip", + dtype_values_axis=_array_idxes_n_dtype( available_dtypes=helpers.get_dtypes("float"), - min_value=-1e04, - max_value=1e04, - allow_inf=False, ), ) -def test_torch___neg__( - dtype_and_x, +def test_torch_tensor_flip( + dtype_values_axis, frontend_method_data, init_flags, method_flags, @@ -6516,15 +6568,17 @@ def test_torch___neg__( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + x, idxes, dtype = dtype_values_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=dtype, + method_all_as_kwargs_np={ + "dims": idxes, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6533,16 +6587,17 @@ def test_torch___neg__( ) -# int +# fliplr @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="int", + method_name="fliplr", 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, ), ) -def test_torch_tensor_int( +def test_torch_tensor_fliplr( dtype_and_x, frontend_method_data, init_flags, @@ -6551,14 +6606,14 @@ def test_torch_tensor_int( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -6568,17 +6623,16 @@ def test_torch_tensor_int( ) -# half @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="half", - dtype_and_x=helpers.dtype_and_values( + method_name="float", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch_tensor_half( - dtype_and_x, +def test_torch_tensor_float( + dtype_x, frontend_method_data, init_flags, method_flags, @@ -6586,7 +6640,7 @@ def test_torch_tensor_half( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6603,16 +6657,16 @@ def test_torch_tensor_half( ) -# bool +# floor @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bool", + method_name="floor", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_bool( +def test_torch_tensor_floor( dtype_and_x, frontend_method_data, init_flags, @@ -6638,19 +6692,16 @@ def test_torch_tensor_bool( ) -# type @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="type", + method_name="floor_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), ), - dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_tensor_type( +def test_torch_tensor_floor_( dtype_and_x, - dtype, frontend_method_data, init_flags, method_flags, @@ -6666,9 +6717,7 @@ def test_torch_tensor_type( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dtype": dtype[0], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6677,17 +6726,17 @@ def test_torch_tensor_type( ) -# type_as +# fmin @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="type_as", + method_name="fmin", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), num_arrays=2, ), ) -def test_torch_tensor_type_as( +def test_torch_tensor_fmin( dtype_and_x, frontend_method_data, init_flags, @@ -6715,21 +6764,26 @@ def test_torch_tensor_type_as( ) -# byte +# fmod @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="byte", + method_name="fmod", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_num_dims=1, + min_value=-100, + max_value=100, ), ) -def test_torch_tensor_byte( +def test_torch_tensor_fmod( dtype_and_x, + frontend, 
frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): @@ -6737,38 +6791,37 @@ def test_torch_tensor_byte( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# ne +# fmod_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="ne", + method_name="fmod_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + shared_dtype=True, + min_num_dims=1, + min_value=-100, + max_value=100, ), ) -def test_torch_tensor_ne( +def test_torch_tensor_fmod_( dtype_and_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): @@ -6776,85 +6829,81 @@ def test_torch_tensor_ne( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={"other": x[1]}, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# squeeze @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="squeeze", - dtype_value_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + method_name="gather", + params_indices_others=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("valid"), + indices_dtypes=["int64"], + indices_same_dims=True, ), ) -def test_torch_tensor_squeeze( - dtype_value_axis, +def test_torch_tensor_gather( + params_indices_others, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x, axis = dtype_value_axis + input_dtypes, x, indices, axis, batch_dims = params_indices_others helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtypes[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, + init_all_as_kwargs_np={"data": x}, + method_input_dtypes=[input_dtypes[1]], method_all_as_kwargs_np={ "dim": axis, + "index": indices, }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# squeeze_ +# gcd @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="squeeze_", - dtype_value_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), + method_name="gcd", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=-100, + max_value=100, min_num_dims=1, - valid_axis=True, - force_int_axis=True, - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + num_arrays=2, + shared_dtype=True, ), ) -def 
test_torch_tensor_squeeze_( - dtype_value_axis, +def test_torch_tensor_gcd( + dtype_and_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x, axis = dtype_value_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6863,27 +6912,54 @@ def test_torch_tensor_squeeze_( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dim": axis, + "other": x[1], }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# flip +def test_torch_tensor_grad(backend_fw): + ivy.set_backend(backend_fw) + x = Tensor(ivy.array([1.0, 2.0, 3.0])) + grads = ivy.array([1.0, 2.0, 3.0]) + x._grads = grads + assert ivy.array_equal(x.grad, grads) + ivy.previous_backend() + + +def test_torch_tensor_grad_fn(backend_fw): + ivy.set_backend(backend_fw) + x = Tensor(ivy.array([3.0]), requires_grad=True) + ivy.utils.assertions.check_equal(x.grad_fn, None, as_array=False) + y = x.pow(2) + ivy.utils.assertions.check_equal(y.grad_fn, "PowBackward", as_array=False) + ivy.utils.assertions.check_equal( + y.grad_fn.next_functions[0], "AccumulateGrad", as_array=False + ) + z = y.detach() + ivy.utils.assertions.check_equal(z.grad_fn, None, as_array=False) + ivy.previous_backend() + + +# greater @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="flip", - dtype_values_axis=_array_idxes_n_dtype( - available_dtypes=helpers.get_dtypes("float"), + method_name="greater", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_flip( - dtype_values_axis, +def test_torch_tensor_greater( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6891,16 +6967,16 @@ def test_torch_tensor_flip( on_device, backend_fw, ): - x, idxes, dtype = dtype_values_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=dtype, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "dims": idxes, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -6910,17 +6986,20 @@ def test_torch_tensor_flip( ) -# fliplr +# greater_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="fliplr", + method_name="greater_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_fliplr( +def test_torch_tensor_greater_( dtype_and_x, frontend_method_data, init_flags, @@ -6929,15 +7008,17 @@ def test_torch_tensor_fliplr( on_device, backend_fw, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -6946,20 +7027,21 @@ def 
test_torch_tensor_fliplr( ) -# tril +# greater_equal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tril", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, # Torch requires this. + method_name="greater_equal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), - diagonal=st.integers(min_value=-100, max_value=100), ) -def test_torch_tensor_tril( - dtype_and_values, - diagonal, +def test_torch_tensor_greater_equal( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -6967,7 +7049,7 @@ def test_torch_tensor_tril( on_device, backend_fw, ): - input_dtype, x = dtype_and_values + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -6976,7 +7058,7 @@ def test_torch_tensor_tril( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "diagonal": diagonal, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -6986,20 +7068,21 @@ def test_torch_tensor_tril( ) -# tril_ +# greater_equal_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tril_", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, # Torch requires this. + method_name="greater_equal_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), - diagonal=st.integers(min_value=-100, max_value=100), ) -def test_torch_tensor_tril_( - dtype_and_values, - diagonal, +def test_torch_tensor_greater_equal_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -7007,7 +7090,7 @@ def test_torch_tensor_tril_( on_device, backend_fw, ): - input_dtype, x = dtype_and_values + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -7016,7 +7099,7 @@ def test_torch_tensor_tril_( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "diagonal": diagonal, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -7026,29 +7109,31 @@ def test_torch_tensor_tril_( ) -# sqrt +# half @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sqrt", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + method_name="half", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch_tensor_sqrt( - dtype_x, - frontend, +def test_torch_tensor_half( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, @@ -7059,17 +7144,17 @@ def test_torch_tensor_sqrt( ) -# sqrt_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sqrt_", - dtype_x=helpers.dtype_and_values( + method_name="heaviside", + dtype_and_values=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + 
num_arrays=2, ), ) -def test_torch_tensor_sqrt_( - dtype_x, +def test_torch_tensor_heaviside( + dtype_and_values, frontend, frontend_method_data, init_flags, @@ -7077,35 +7162,45 @@ def test_torch_tensor_sqrt_( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, values = dtype_and_values helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": values[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, - frontend_method_data=frontend_method_data, + method_all_as_kwargs_np={ + "values": values[1], + }, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) -# index_select +# hsplit @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="index_select", - params_indices_others=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("valid"), - indices_dtypes=["int64"], - max_num_dims=1, - indices_same_dims=True, + method_name="hsplit", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=1, + axis=1, + allow_none=False, + allow_array_indices=False, + is_mod_split=True, ), ) -def test_torch_tensor_index_select( - params_indices_others, +def test_torch_tensor_hsplit( + dtype_value, + indices_or_sections, frontend_method_data, init_flags, method_flags, @@ -7113,18 +7208,15 @@ def test_torch_tensor_index_select( on_device, backend_fw, ): - input_dtypes, input, indices, axis, batch_dims = params_indices_others + input_dtype, x = dtype_value helpers.test_frontend_method( - init_input_dtypes=[input_dtypes[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": input, - }, - method_input_dtypes=[input_dtypes[1]], - method_all_as_kwargs_np={ - "dim": axis, - "index": indices, + "data": x[0], }, + method_input_dtypes=[], + method_all_as_kwargs_np={"indices_or_sections": indices_or_sections}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7133,70 +7225,28 @@ def test_torch_tensor_index_select( ) -@st.composite -def _arrays_dim_idx_n_dtypes(draw): - num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) - num_arrays = 2 - common_shape = draw( - helpers.lists( - x=helpers.ints(min_value=2, max_value=3), - min_size=num_dims - 1, - max_size=num_dims - 1, - ) - ) - _dim = draw(helpers.ints(min_value=0, max_value=num_dims - 1)) - unique_dims = draw( - helpers.lists( - x=helpers.ints(min_value=2, max_value=3), - min_size=num_arrays, - max_size=num_arrays, - ) - ) - - min_dim = min(unique_dims) - max_dim = max(unique_dims) - _idx = draw( - helpers.array_values( - shape=min_dim, - dtype="int64", - min_value=0, - max_value=max_dim, - exclude_min=False, - ) - ) - - xs = list() - available_input_types = draw(helpers.get_dtypes("numeric")) - input_dtypes = draw( - helpers.array_dtypes( - available_dtypes=available_input_types, - num_arrays=num_arrays, - shared_dtype=True, - ) - ) - for ud, dt in zip(unique_dims, input_dtypes): - x = draw( - helpers.array_values( - shape=common_shape[:_dim] + [ud] + common_shape[_dim:], - dtype=dt, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ) - ) - xs.append(x) - return xs, input_dtypes, _dim, 
_idx +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("complex", prune_function=False) + ), +) +def test_torch_tensor_imag(dtype_x, backend_fw): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal(x.imag, ivy.imag(data[0])) + ivy.previous_backend() -# index_add @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="index_add_", + method_name="index_add", xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(), alpha=st.integers(min_value=1, max_value=2), ) -def test_torch_tensor_index_add_( +def test_torch_tensor_index_add( *, xs_dtypes_dim_idx, alpha, @@ -7234,14 +7284,15 @@ def test_torch_tensor_index_add_( ) +# index_add @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="index_add", + method_name="index_add_", xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(), alpha=st.integers(min_value=1, max_value=2), ) -def test_torch_tensor_index_add( +def test_torch_tensor_index_add_( *, xs_dtypes_dim_idx, alpha, @@ -7279,104 +7330,39 @@ def test_torch_tensor_index_add( ) -@st.composite -def _get_clamp_inputs(draw): - shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ) - ) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=shape, - ) - ) - min = draw(st.booleans()) - if min: - max = draw(st.booleans()) - min = draw( - helpers.array_values( - dtype=x_dtype[0], shape=shape, min_value=0, max_value=25 - ) - ) - max = ( - draw( - helpers.array_values( - dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 - ) - ) - if max - else None - ) - else: - min = None - max = draw( - helpers.array_values( - dtype=x_dtype[0], shape=shape, min_value=26, max_value=50 - ) - ) - return x_dtype, x, min, max - - -# clamp +# index_select @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="clamp", - dtype_and_x_min_max=_get_clamp_inputs(), + method_name="index_select", + params_indices_others=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("valid"), + indices_dtypes=["int64"], + max_num_dims=1, + indices_same_dims=True, + ), ) -def test_torch_tensor_clamp( - dtype_and_x_min_max, - frontend, +def test_torch_tensor_index_select( + params_indices_others, frontend_method_data, init_flags, method_flags, - on_device, - backend_fw, -): - input_dtype, x, min, max = dtype_and_x_min_max - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"min": min, "max": max}, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, - ) - - -# clamp_ -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="clamp_", - dtype_and_x_min_max=_get_clamp_inputs(), -) -def test_torch_tensor_clamp_( - dtype_and_x_min_max, frontend, - frontend_method_data, - init_flags, - method_flags, on_device, backend_fw, ): - input_dtype, x, min, max = dtype_and_x_min_max + input_dtypes, input, indices, axis, batch_dims = params_indices_others helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtypes[0]], backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": input, + }, + 
method_input_dtypes=[input_dtypes[1]], + method_all_as_kwargs_np={ + "dim": axis, + "index": indices, }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"min": min, "max": max}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7385,56 +7371,55 @@ def test_torch_tensor_clamp_( ) -# clip @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="clip", - input_and_ranges=_get_clamp_inputs(), + method_name="bmm", + dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), ) -def test_torch_tensor_clip( - input_and_ranges, +def test_torch_tensor_instance_bmm( + dtype_and_matrices, + backend_fw, frontend, frontend_method_data, init_flags, method_flags, on_device, - backend_fw, ): - input_dtype, x, min, max = input_and_ranges + input_dtype, _, x, mat2 = dtype_and_matrices helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"min": min, "max": max}, + method_all_as_kwargs_np={"mat2": mat2}, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, + backend_to_test=backend_fw, ) -# clip_ +# int @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="clip_", - input_and_ranges=_get_clamp_inputs(), + method_name="int", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + ), ) -def test_torch_tensor_clip_( - input_and_ranges, - frontend, +def test_torch_tensor_int( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, min, max = input_and_ranges + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -7442,7 +7427,7 @@ def test_torch_tensor_clip_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"min": min, "max": max}, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7451,20 +7436,17 @@ def test_torch_tensor_clip_( ) -# __gt__ +# inverse @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__gt__", + method_name="inverse", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, - ), + min_num_dims=2, + ).filter(lambda s: s[1][0].shape[-1] == s[1][0].shape[-2]), ) -def test_torch___gt__( +def test_torch_tensor_inverse( dtype_and_x, frontend_method_data, init_flags, @@ -7481,9 +7463,7 @@ def test_torch___gt__( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7492,20 +7472,16 @@ def test_torch___gt__( ) -# __ne__ +# is_complex @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__ne__", + method_name="is_complex", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch___ne__( +def 
test_torch_tensor_is_complex( dtype_and_x, frontend_method_data, init_flags, @@ -7518,13 +7494,9 @@ def test_torch___ne__( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7533,20 +7505,87 @@ def test_torch___ne__( ) -# __lt__ +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_is_cuda( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal( + x.is_cuda, "gpu" in ivy.dev(ivy.array(data[0])), as_array=False + ) + ivy.previous_backend() + + +@given( + requires_grad=st.booleans(), +) +def test_torch_tensor_is_leaf(requires_grad, backend_fw): + ivy.set_backend(backend_fw) + x = Tensor(ivy.array([3.0]), requires_grad=requires_grad) + ivy.utils.assertions.check_equal(x.is_leaf, True, as_array=False) + y = x.pow(2) + ivy.utils.assertions.check_equal(y.is_leaf, not requires_grad, as_array=False) + z = y.detach() + ivy.utils.assertions.check_equal(z.is_leaf, True, as_array=False) + ivy.previous_backend() + + +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_is_meta( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal( + x.is_meta, "meta" in ivy.dev(ivy.array(data[0])), as_array=False + ) + ivy.previous_backend() + + +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_is_quantized( + dtype_x, + backend_fw, +): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal( + x.is_quantized, "q" in ivy.dtype(ivy.array(data[0])), as_array=False + ) + ivy.previous_backend() + + +# isinf @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__lt__", + method_name="isinf", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch___lt__( +def test_torch_tensor_isinf( dtype_and_x, frontend_method_data, init_flags, @@ -7559,13 +7598,9 @@ def test_torch___lt__( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7574,20 +7609,16 @@ def test_torch___lt__( ) -# __or__ +# isreal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__or__", + method_name="isreal", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch___or__( +def test_torch_tensor_isreal( dtype_and_x, frontend_method_data, init_flags, @@ -7600,13 +7631,9 @@ def test_torch___or__( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7615,58 +7642,51 @@ def test_torch___or__( ) -# where -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="where", - broadcastables=_broadcastable_trio(), +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), ) -def test_torch_tensor_where( - broadcastables, - frontend_method_data, - init_flags, - method_flags, - frontend, - on_device, +def test_torch_tensor_ivy_array( + dtype_x, backend_fw, ): - cond, xs, dtypes = broadcastables - helpers.test_frontend_method( - init_input_dtypes=dtypes, - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": xs[0], - }, - method_input_dtypes=["bool", dtypes[1]], - method_all_as_kwargs_np={ - "condition": cond, - "other": xs[1], - }, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, + _, data = dtype_x + ivy.set_backend(backend_fw) + x = Tensor(data[0]) + x.ivy_array = data[0] + ret = helpers.flatten_and_to_np(ret=x.ivy_array.data, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np(ret=data[0], backend=backend_fw) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend="torch", ) -# clone +# lcm @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="clone", + method_name="lcm", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + shared_dtype=True, ), ) -def test_torch_tensor_clone( +def test_torch_tensor_lcm( dtype_and_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): @@ -7678,26 +7698,31 @@ def test_torch_tensor_clone( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# __invert__ +# less @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__invert__", + method_name="less", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=1, + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch___invert__( +def test_torch_tensor_less( dtype_and_x, frontend_method_data, init_flags, @@ -7714,7 +7739,9 @@ def test_torch___invert__( "data": x[0], }, method_input_dtypes=input_dtype, - 
method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7723,17 +7750,20 @@ def test_torch___invert__( ) -# acosh +# less_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="acosh", + method_name="less_", dtype_and_x=helpers.dtype_and_values( - min_value=1.0, - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_acosh( +def test_torch_tensor_less_( dtype_and_x, frontend_method_data, init_flags, @@ -7749,55 +7779,9 @@ def test_torch_tensor_acosh( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, - frontend=frontend, - on_device=on_device, - ) - - -@st.composite -def _masked_fill_helper(draw): - cond, xs, dtypes = draw(_broadcastable_trio()) - if ivy.is_uint_dtype(dtypes[0]): - fill_value = draw(helpers.ints(min_value=0, max_value=5)) - elif ivy.is_int_dtype(dtypes[0]): - fill_value = draw(helpers.ints(min_value=-5, max_value=5)) - else: - fill_value = draw(helpers.floats(min_value=-5, max_value=5)) - return dtypes[0], xs[0], cond, fill_value - - -# masked_fill -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="masked_fill", - x_mask_val=_masked_fill_helper(), -) -def test_torch_tensor_masked_fill( - x_mask_val, - frontend_method_data, - init_flags, - method_flags, - frontend, - on_device, - backend_fw, -): - dtype, x, mask, val = x_mask_val - helpers.test_frontend_method( - init_input_dtypes=[dtype], - backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x, - }, - method_input_dtypes=["bool", dtype], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "mask": mask, - "value": val, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -7807,17 +7791,20 @@ def test_torch_tensor_masked_fill( ) -# acosh_ +# less_equal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="acosh_", + method_name="less_equal", dtype_and_x=helpers.dtype_and_values( - min_value=1.0, - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_acosh_( +def test_torch_tensor_less_equal( dtype_and_x, frontend_method_data, init_flags, @@ -7833,8 +7820,10 @@ def test_torch_tensor_acosh_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7843,16 +7832,20 @@ def test_torch_tensor_acosh_( ) -# numpy +# less_equal_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="numpy", + method_name="less_equal_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_numpy( +def test_torch_tensor_less_equal_( dtype_and_x, frontend_method_data, init_flags, @@ -7862,40 +7855,35 @@ def test_torch_tensor_numpy( backend_fw, ): input_dtype, x = 
dtype_and_x - ret, frontend_ret = helpers.test_frontend_method( + helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - test_values=False, - ) - # manual testing required as function return is numpy frontend - helpers.value_test( - ret_np_flat=helpers.flatten_and_to_np(ret=ret), - ret_np_from_gt_flat=frontend_ret[0], - ground_truth_backend="torch", ) -# atan2_ +# log @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="atan2_", + method_name="log", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + allow_inf=False, ), ) -def test_torch_tensor_atan2_( +def test_torch_tensor_log( dtype_and_x, frontend_method_data, init_flags, @@ -7912,9 +7900,7 @@ def test_torch_tensor_atan2_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7923,17 +7909,17 @@ def test_torch_tensor_atan2_( ) -# bitwise_not_ +# log10 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_not_", + method_name="log10", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_bitwise_not_( +def test_torch_tensor_log10( dtype_and_x, frontend_method_data, init_flags, @@ -7950,26 +7936,26 @@ def test_torch_tensor_bitwise_not_( "data": x[0], }, method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - method_all_as_kwargs_np={}, frontend=frontend, on_device=on_device, ) -# bitwise_and_ +# log10_ tests @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_and_", + method_name="log10_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_bitwise_and_( +def test_torch_tensor_log10_( dtype_and_x, frontend_method_data, init_flags, @@ -7986,9 +7972,7 @@ def test_torch_tensor_bitwise_and_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -7997,96 +7981,84 @@ def test_torch_tensor_bitwise_and_( ) -# __and__ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="__and__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="log1p", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + max_value=1e37, ), ) -def test_torch___and__( - dtype_and_x, +def test_torch_tensor_log1p( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): 
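+    # the max_value=1e37 cap in the value strategy above is presumably there so that
+    # x + 1 stays finite in float32 and log1p does not overflow in the torch ground-truth run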
- input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# bitwise_xor +# log1p_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_xor", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - num_arrays=2, + method_name="log1p_", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + max_value=1e37, ), ) -def test_torch_tensor_bitwise_xor( - dtype_and_x, +def test_torch_tensor_log1p_( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# bitwise_xor_ +# log2 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_xor_", + method_name="log2", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_bitwise_xor_( +def test_torch_tensor_log2( dtype_and_x, frontend_method_data, init_flags, @@ -8103,9 +8075,7 @@ def test_torch_tensor_bitwise_xor_( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8114,26 +8084,18 @@ def test_torch_tensor_bitwise_xor_( ) -# cumprod +# log2_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cumprod", - dtype_value=helpers.dtype_and_values( + method_name="log2_", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + allow_inf=False, ), - dtypes=_dtypes(), ) -def test_torch_tensor_cumprod( - dtype_value, - dim, - dtypes, +def test_torch_tensor_log2_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -8141,18 +8103,15 @@ def test_torch_tensor_cumprod( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=dtypes, - method_all_as_kwargs_np={ - "dim": dim, - "dtype": dtypes[0], - }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, 
init_flags=init_flags, method_flags=method_flags, @@ -8161,17 +8120,17 @@ def test_torch_tensor_cumprod( ) -# relu +# log_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="relu", + method_name="log_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), allow_inf=False, ), ) -def test_torch_tensor_relu( +def test_torch_tensor_log_( dtype_and_x, frontend_method_data, init_flags, @@ -8197,17 +8156,21 @@ def test_torch_tensor_relu( ) -# fmin +# logaddexp @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="fmin", + method_name="logaddexp", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), num_arrays=2, + min_num_dims=1, + min_value=-100, + max_value=100, + shared_dtype=True, ), ) -def test_torch_tensor_fmin( +def test_torch_tensor_logaddexp( dtype_and_x, frontend_method_data, init_flags, @@ -8235,18 +8198,15 @@ def test_torch_tensor_fmin( ) -# msort +# logdet @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="msort", - dtype_value=helpers.dtype_and_values( - available_dtypes=["float32", "float64", "int32", "int64"], - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), + method_name="logdet", + dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), ) -def test_torch_tensor_msort( - dtype_value, +def test_torch_tensor_logdet( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -8254,12 +8214,14 @@ def test_torch_tensor_msort( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x + dtype, x = dtype_and_x + x = np.matmul(x.T, x) + np.identity(x.shape[0]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, @@ -8271,24 +8233,18 @@ def test_torch_tensor_msort( ) -# count_nonzero +# logical_and @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="count_nonzero", - dtype_value=helpers.dtype_and_values( + method_name="logical_and", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), - dim=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=True, - force_int=True, + num_arrays=2, ), ) -def test_torch_tensor_count_nonzero( - dtype_value, - dim, +def test_torch_tensor_logical_and( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -8296,7 +8252,7 @@ def test_torch_tensor_count_nonzero( on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -8304,7 +8260,9 @@ def test_torch_tensor_count_nonzero( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"dim": dim}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8313,16 +8271,16 @@ def test_torch_tensor_count_nonzero( ) -# exp +# logical_not @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="exp", + method_name="logical_not", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid"), num_arrays=1 ), ) -def 
test_torch_tensor_exp( +def test_torch_tensor_logical_not( dtype_and_x, frontend_method_data, init_flags, @@ -8348,16 +8306,18 @@ def test_torch_tensor_exp( ) -# exp_ +# logical_not_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="exp_", + method_name="logical_not_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + large_abs_safety_factor=12, ), ) -def test_torch_tensor_exp_( +def test_torch_tensor_logical_not_( dtype_and_x, frontend_method_data, init_flags, @@ -8383,16 +8343,17 @@ def test_torch_tensor_exp_( ) -# expm1 +# logical_or @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="expm1", + method_name="logical_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, ), ) -def test_torch_tensor_expm1( +def test_torch_tensor_logical_or( dtype_and_x, frontend_method_data, init_flags, @@ -8409,7 +8370,9 @@ def test_torch_tensor_expm1( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8418,16 +8381,16 @@ def test_torch_tensor_expm1( ) -# expm1_ +# long @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="expm1_", + method_name="long", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("integer"), ), ) -def test_torch_tensor_expm1_( +def test_torch_tensor_long( dtype_and_x, frontend_method_data, init_flags, @@ -8443,8 +8406,44 @@ def test_torch_tensor_expm1_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + +# masked_fill +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="masked_fill", + x_mask_val=_masked_fill_helper(), +) +def test_torch_tensor_masked_fill( + x_mask_val, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + dtype, x, mask, val = x_mask_val + helpers.test_frontend_method( + init_input_dtypes=[dtype], + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x, + }, + method_input_dtypes=["bool", dtype], + method_all_as_kwargs_np={ + "mask": mask, + "value": val, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8453,18 +8452,15 @@ def test_torch_tensor_expm1_( ) -# mul +# matmul @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="mul", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - ), + method_name="matmul", + dtype_tensor1_tensor2=_get_dtype_and_multiplicative_matrices(), ) -def test_torch_tensor_mul( - dtype_and_x, +def test_torch_tensor_matmul( + dtype_tensor1_tensor2, frontend_method_data, init_flags, method_flags, @@ -8472,17 +8468,15 @@ def test_torch_tensor_mul( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + dtype, tensor1, tensor2 = dtype_tensor1_tensor2 helpers.test_frontend_method( - 
init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], + "data": tensor1, }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": tensor2}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8491,17 +8485,17 @@ def test_torch_tensor_mul( ) -# ceil_ +# max @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="ceil_", - dtype_and_x=helpers.dtype_and_values( + method_name="max", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_ceil_( - dtype_and_x, +def test_torch_tensor_max( + dtype_x, frontend_method_data, init_flags, method_flags, @@ -8509,7 +8503,7 @@ def test_torch_tensor_ceil_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -8526,27 +8520,29 @@ def test_torch_tensor_ceil_( ) -# mul_ +# mean @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="mul_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, + method_name="mean", + dtype_and_x=_statistical_dtype_values( + function="mean", + min_value=-1e04, + max_value=1e04, ), + keepdims=st.booleans(), ) -def test_torch_tensor_mul_( +def test_torch_tensor_mean( dtype_and_x, + keepdims, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -8555,7 +8551,8 @@ def test_torch_tensor_mul_( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": axis, + "keepdim": keepdims, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -8565,26 +8562,30 @@ def test_torch_tensor_mul_( ) -# trunc +# median @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="trunc", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + method_name="median", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, ), + keepdim=st.booleans(), ) -def test_torch_tensor_trunc( - dtype_value, +def test_torch_tensor_median( + dtype_input_axis, + keepdim, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x, axis = dtype_input_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -8592,7 +8593,10 @@ def test_torch_tensor_trunc( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "dim": axis, + "keepdim": keepdim, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8601,26 +8605,25 @@ def test_torch_tensor_trunc( ) -# trunc_ +# min @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="trunc_", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - 
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + method_name="min", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_trunc_( - dtype_value, +def test_torch_tensor_min( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -8637,18 +8640,15 @@ def test_torch_tensor_trunc_( ) -# fix +# mm @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="fix", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - ), + method_name="mm", + dtype_xy=_get_dtype_input_and_matrices(), ) -def test_torch_tensor_fix( - dtype_value, +def test_torch_tensor_mm( + dtype_xy, frontend_method_data, init_flags, method_flags, @@ -8656,15 +8656,17 @@ def test_torch_tensor_fix( on_device, backend_fw, ): - input_dtype, x = dtype_value + dtype, x, y = dtype_xy helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, + }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={ + "mat2": y, }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8673,55 +8675,96 @@ def test_torch_tensor_fix( ) -# fix_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="fix_", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + method_name="movedim", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + ), + source=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + destination=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, ), ) -def test_torch_tensor_fix_( - dtype_value, +def test_torch_tensor_movedim( + dtype_and_input, + source, + destination, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_value + input_dtype, value = dtype_and_input helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": value[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "source": source, + "destination": destination, + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# round +# msort @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="round", - 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="msort", + dtype_value=helpers.dtype_and_values( + available_dtypes=["float32", "float64", "int32", "int64"], + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), - decimals=st.integers(min_value=0, max_value=5), ) -def test_torch_tensor_round( - dtype_and_x, - decimals, +def test_torch_tensor_msort( + dtype_value, frontend_method_data, init_flags, method_flags, @@ -8729,7 +8772,7 @@ def test_torch_tensor_round( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -8737,9 +8780,7 @@ def test_torch_tensor_round( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "decimals": decimals, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8748,19 +8789,18 @@ def test_torch_tensor_round( ) -# round_ +# mul @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="round_", + method_name="mul", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), - decimals=st.integers(min_value=0, max_value=5), ) -def test_torch_tensor_round_( +def test_torch_tensor_mul( dtype_and_x, - decimals, frontend_method_data, init_flags, method_flags, @@ -8777,7 +8817,7 @@ def test_torch_tensor_round_( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "decimals": decimals, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -8787,26 +8827,19 @@ def test_torch_tensor_round_( ) -# cross +# mul_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cross", - dtype_input_other_dim=dtype_value1_value2_axis( + method_name="mul_", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=10, - min_dim_size=3, - max_dim_size=3, - min_value=-1e10, - max_value=1e10, - abs_smallest_val=0.01, - large_abs_safety_factor=2, - safety_factor_scale="log", + num_arrays=2, + shared_dtype=True, ), -) -def test_torch_tensor_cross( - dtype_input_other_dim, +) +def test_torch_tensor_mul_( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -8814,36 +8847,36 @@ def test_torch_tensor_cross( on_device, backend_fw, ): - dtype, input, other, dim = dtype_input_other_dim + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": input, + "data": x[0], }, - method_input_dtypes=dtype, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": other, - "dim": dim, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, ) -# det +# multiply @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="det", - dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), + method_name="multiply", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + ), ) -def test_torch_tensor_det( +def test_torch_tensor_multiply( dtype_and_x, frontend_method_data, init_flags, @@ -8857,10 +8890,12 @@ 
def test_torch_tensor_det( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8869,17 +8904,17 @@ def test_torch_tensor_det( ) -# reciprocal +# multiply_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="reciprocal", + method_name="multiply_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_torch_tensor_reciprocal( +def test_torch_tensor_multiply_( dtype_and_x, frontend_method_data, init_flags, @@ -8896,7 +8931,9 @@ def test_torch_tensor_reciprocal( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -8905,26 +8942,27 @@ def test_torch_tensor_reciprocal( ) -# reciprocal_ +# nanmean @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="reciprocal_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=1, + method_name="nanmean", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1e04, + max_value=1e04, ), ) -def test_torch_tensor_reciprocal_( - dtype_and_x, +def test_torch_tensor_nanmean( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -8941,56 +8979,65 @@ def test_torch_tensor_reciprocal_( ) -# fill_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="fill_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), - value=helpers.floats(min_value=1, max_value=10), + method_name="narrow", + dtype_input_dim_start_length=_dtype_input_dim_start_length(), ) -def test_torch_tensor_fill_( - dtype_and_x, - value, +def test_torch_tensor_narrow( + dtype_input_dim_start_length, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + (input_dtype, x, dim, start, length) = dtype_input_dim_start_length helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "value": value, + "dim": dim, + "start": start, + "length": length, }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# nonzero +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_ndim(dtype_x, backend_fw): + ivy.set_backend(backend_fw) + dtype, data, shape = dtype_x + x = Tensor(data[0]) + ivy.utils.assertions.check_equal(x.ndim, data[0].ndim, as_array=False) + ivy.previous_backend() + + +# ndimension 
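+# (ndimension() appears to mirror Tensor.dim()/ndim and takes no arguments, which is
+# why the test below passes an empty method_all_as_kwargs_np)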
@handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="nonzero", - dtype_and_values=helpers.dtype_and_values( + method_name="ndimension", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch_tensor_nonzero( - dtype_and_values, +def test_torch_tensor_ndimension( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -8998,14 +9045,14 @@ def test_torch_tensor_nonzero( on_device, backend_fw, ): - input_dtype, x = dtype_and_values + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -9015,15 +9062,21 @@ def test_torch_tensor_nonzero( ) -# mm +# ne @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="mm", - dtype_xy=_get_dtype_input_and_matrices(), + method_name="ne", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, + ), ) -def test_torch_tensor_mm( - dtype_xy, +def test_torch_tensor_ne( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -9031,16 +9084,16 @@ def test_torch_tensor_mm( on_device, backend_fw, ): - dtype, x, y = dtype_xy + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, - method_input_dtypes=dtype, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "mat2": y, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -9050,17 +9103,20 @@ def test_torch_tensor_mm( ) -# square +# neg @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="square", - dtype_x=helpers.dtype_and_values( + method_name="neg", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_square( - dtype_x, +def test_torch_tensor_neg( + dtype_and_x, frontend, frontend_method_data, init_flags, @@ -9068,11 +9124,13 @@ def test_torch_tensor_square( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, @@ -9083,22 +9141,24 @@ def test_torch_tensor_square( ) -# log10 +# neg_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="log10", + method_name="neg_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), + min_value=-1e04, + max_value=1e04, allow_inf=False, ), ) -def test_torch_tensor_log10( +def test_torch_tensor_neg_( dtype_and_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): @@ -9119,18 +9179,22 @@ def test_torch_tensor_log10( ) -# log10_ tests +# new_empty (not actually intuitive for testing) @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - 
method_name="log10_", + method_name="new_empty", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, + available_dtypes=helpers.get_dtypes("numeric"), + ), + size=helpers.get_shape( + min_num_dims=1, + max_num_dims=3, ), ) -def test_torch_tensor_log10_( +def test_torch_tensor_new_empty( dtype_and_x, + size, frontend_method_data, init_flags, method_flags, @@ -9140,13 +9204,15 @@ def test_torch_tensor_log10_( ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": x, + }, + method_input_dtypes=[ivy.int32], + method_all_as_kwargs_np={ + "size": size, }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -9155,17 +9221,14 @@ def test_torch_tensor_log10_( ) -# zero_ tests +# new_full @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="zero_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_inf=False, - ), + method_name="new_full", + dtype_and_x=_fill_value_and_size(max_num_dims=3), ) -def test_torch_tensor_zero_( +def test_torch_tensor_new_full( dtype_and_x, frontend_method_data, init_flags, @@ -9176,13 +9239,16 @@ def test_torch_tensor_zero_( ): input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=[input_dtype[1]], + method_all_as_kwargs_np={ + "size": x[1], + "fill_value": x[2], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -9191,25 +9257,32 @@ def test_torch_tensor_zero_( ) -# short +# new_ones @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="short", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="new_ones", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + size=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, ), + dtypes=_dtypes(), + requires_grad=_requires_grad(), ) -def test_torch_tensor_short( +def test_torch_tensor_new_ones( dtype_and_x, + size, + dtypes, + requires_grad, + on_device, frontend_method_data, init_flags, method_flags, frontend, - on_device, backend_fw, ): input_dtype, x = dtype_and_x @@ -9219,8 +9292,13 @@ def test_torch_tensor_short( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=dtypes, + method_all_as_kwargs_np={ + "size": size, + "dtype": dtypes[0], + "requires_grad": requires_grad, + "device": on_device, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -9229,50 +9307,36 @@ def test_torch_tensor_short( ) -# prod +# new_tensor @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="prod", - dtype_x_axis=helpers.dtype_values_axis( + method_name="new_tensor", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - 
valid_axis=True, - allow_neg_axes=False, - max_axes_size=1, - force_int_axis=True, - large_abs_safety_factor=10, - small_abs_safety_factor=10, - safety_factor_scale="log", + num_arrays=2, ), - dtype=helpers.get_dtypes("float", none=True, full=False), - keepdims=st.booleans(), ) -def test_torch_tensor_prod( - dtype_x_axis, - dtype, - keepdims, - frontend, +def test_torch_tensor_new_tensor( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, axis = dtype_x_axis - if ivy.current_backend_str() == "torch": - init_flags.as_variable = [False] - method_flags.as_variable = [False] + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=[input_dtype[1]], method_all_as_kwargs_np={ - "dim": axis, - "keepdim": keepdims, - "dtype": dtype[0], + "data": x[1], + "dtype": input_dtype[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -9282,41 +9346,47 @@ def test_torch_tensor_prod( ) -# div +# new_zeros @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="div", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", + method_name="new_zeros", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + size=helpers.get_shape( + allow_none=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, ), - rounding_mode=st.sampled_from(["floor", "trunc"]) | st.none(), + dtypes=_dtypes(), + requires_grad=_requires_grad(), ) -def test_torch_tensor_div( +def test_torch_tensor_new_zeros( dtype_and_x, - rounding_mode, - frontend, + size, + dtypes, + requires_grad, + on_device, frontend_method_data, init_flags, method_flags, - on_device, + frontend, backend_fw, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) - helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=dtypes, method_all_as_kwargs_np={ - "other": x[1], - "rounding_mode": rounding_mode, + "size": size, + "dtype": dtypes[0], + "requires_grad": requires_grad, + "device": on_device, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -9326,42 +9396,33 @@ def test_torch_tensor_div( ) -# div_ +# nonzero @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="div_", - dtype_and_x=helpers.dtype_and_values( + method_name="nonzero", + dtype_and_values=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", ), - rounding_mode=st.sampled_from(["floor", "trunc"]) | st.none(), ) -def test_torch_tensor_div_( - dtype_and_x, - rounding_mode, - frontend, +def test_torch_tensor_nonzero( + dtype_and_values, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) - + input_dtype, x = dtype_and_values helpers.test_frontend_method( 
init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - "rounding_mode": rounding_mode, + init_all_as_kwargs_np={ + "data": x[0], }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -9370,21 +9431,19 @@ def test_torch_tensor_div_( ) -# true_divide_ +# norm @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="true_divide_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ), + method_name="norm", + p_dtype_x_axis=_get_axis_and_p(), + keepdim=st.booleans(), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_tensor_true_divide_( - dtype_and_x, +def test_torch_tensor_norm( + p_dtype_x_axis, + keepdim, + dtype, frontend, frontend_method_data, init_flags, @@ -9392,21 +9451,23 @@ def test_torch_tensor_true_divide_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) - + p, values = p_dtype_x_axis + input_dtype, x, axis = values helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "p": p, + "dim": axis, + "keepdim": keepdim, + "dtype": dtype[0], }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) @@ -9466,24 +9527,18 @@ def call(): assert u.shape == v.shape -# addcdiv +# not_equal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addcdiv", + method_name="not_equal", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, ), - value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_addcdiv( +def test_torch_tensor_not_equal( dtype_and_x, - value, frontend, frontend_method_data, init_flags, @@ -9492,144 +9547,191 @@ def test_torch_tensor_addcdiv( backend_fw, ): input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[2], 0))) - helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "tensor1": x[1], - "tensor2": x[2], - "value": value, + "other": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, + atol_=1e-02, on_device=on_device, - atol_=1e-03, ) -# addcmul +# numpy @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addcmul", + method_name="numpy", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), ), - value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_addcmul( +def 
test_torch_tensor_numpy( dtype_and_x, - value, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): input_dtype, x = dtype_and_x - - helpers.test_frontend_method( + ret, frontend_ret = helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "tensor1": x[1], - "tensor2": x[2], - "value": value, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=[], + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + test_values=False, + ) + # manual testing required as function return is numpy frontend + helpers.value_test( + ret_np_flat=helpers.flatten_and_to_np(ret=ret), + ret_np_from_gt_flat=frontend_ret[0], + ground_truth_backend="torch", + ) + + +# permute +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="permute", + dtype_values_axis=_array_idxes_n_dtype( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_torch_tensor_permute( + dtype_values_axis, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + x, idxes, dtype = dtype_values_axis + unpack_dims = True + if unpack_dims: + method_flags.num_positional_args = len(idxes) + 1 + dims = {} + i = 0 + for x_ in idxes: + dims["x{}".format(i)] = x_ + i += 1 + else: + dims = { + "dims": tuple(idxes), + } + helpers.test_frontend_method( + init_input_dtypes=dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], }, + method_input_dtypes=dtype, + method_all_as_kwargs_np=dims, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - atol_=1e-02, ) -# addcmul_ +# pow @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addcmul_", + method_name="pow", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), - value=st.floats(min_value=-100, max_value=100), ) -def test_torch_tensor_addcmul_( +def test_torch_tensor_pow( dtype_and_x, - value, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): input_dtype, x = dtype_and_x - + dtype = input_dtype[0] + if "int" in dtype: + x[1] = ivy.abs(x[1]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "tensor1": x[1], - "tensor2": x[2], - "value": value, + "exponent": x[1], }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - atol_=1e-02, ) -# sign +# pow_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sign", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="pow_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def 
test_torch_tensor_sign( - dtype_x, - frontend, +def test_torch_tensor_pow_( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x = dtype_and_x + dtype = input_dtype[0] + if "int" in dtype: + x[1] = ivy.abs(x[1]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "exponent": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -9638,17 +9740,30 @@ def test_torch_tensor_sign( ) -# sign_ +# prod @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sign_", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="prod", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + valid_axis=True, + allow_neg_axes=False, + max_axes_size=1, + force_int_axis=True, + large_abs_safety_factor=10, + small_abs_safety_factor=10, + safety_factor_scale="log", ), + dtype=helpers.get_dtypes("float", none=True, full=False), + keepdims=st.booleans(), ) -def test_torch_tensor_sign_( - dtype_x, +def test_torch_tensor_prod( + dtype_x_axis, + dtype, + keepdims, frontend, frontend_method_data, init_flags, @@ -9656,30 +9771,38 @@ def test_torch_tensor_sign_( on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x, axis = dtype_x_axis + if ivy.current_backend_str() == "torch": + init_flags.as_variable = [False] + method_flags.as_variable = [False] helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=[input_dtype], - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "dim": axis, + "keepdim": keepdims, + "dtype": dtype[0], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - on_device=on_device, frontend=frontend, + on_device=on_device, ) -# std @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="std", - dtype_and_x=_statistical_dtype_values(function="std"), + method_name="quantile", + dtype_and_x=_quantile_helper().filter(lambda x: "bfloat16" not in x[0]), + keepdims=st.booleans(), ) -def test_torch_tensor_std( +def test_torch_tensor_quantile( dtype_and_x, + keepdims, frontend, frontend_method_data, init_flags, @@ -9687,7 +9810,9 @@ def test_torch_tensor_std( on_device, backend_fw, ): - input_dtype, x, _, _ = dtype_and_x + input_dtype, x, axis, interpolation, q = dtype_and_x + if type(axis) is tuple: + axis = axis[0] helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -9695,73 +9820,86 @@ def test_torch_tensor_std( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "q": q, + "dim": axis, + "keepdim": keepdims, + "interpolation": interpolation[0], + }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, ) -# fmod +# ravel @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="fmod", - dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - min_value=-100, - max_value=100, + method_name="ravel", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), ) -def test_torch_tensor_fmod( - dtype_and_x, - frontend, +def test_torch_tensor_ravel( + dtype_value, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, - frontend=frontend, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# fmod_ +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("complex", prune_function=False) + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_real(dtype_x, backend_fw): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0]) + x.ivy_array = data[0] + ivy.utils.assertions.check_equal(x.real, ivy.real(data[0])) + ivy.previous_backend() + + +# reciprocal @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="fmod_", + method_name="reciprocal", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - min_value=-100, - max_value=100, + min_value=1, ), ) -def test_torch_tensor_fmod_( +def test_torch_tensor_reciprocal( dtype_and_x, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): @@ -9769,72 +9907,30 @@ def test_torch_tensor_fmod_( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"other": x[1]}, - frontend=frontend, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - on_device=on_device, - ) - - -# topk -# TODO: add value test after the stable sorting is added to torch -# https://github.com/pytorch/pytorch/issues/88184 -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="topk", - dtype_x_axis_k=_topk_helper(), - largest=st.booleans(), - sorted=st.booleans(), -) -def test_torch_tensor_topk( - dtype_x_axis_k, - largest, - sorted, - frontend, - frontend_method_data, - init_flags, - method_flags, - on_device, - backend_fw, -): - input_dtype, input, axis, k = dtype_x_axis_k - helpers.test_frontend_method( - init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": input[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "k": k, - "dim": axis, - "largest": largest, - "sorted": sorted, - }, frontend=frontend, - frontend_method_data=frontend_method_data, - init_flags=init_flags, - method_flags=method_flags, on_device=on_device, - test_values=False, ) -# bitwise right shift +# reciprocal_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bitwise_right_shift", + method_name="reciprocal_", 
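+    # the value strategy below sets min_value=1, presumably to keep reciprocal_ away
+    # from division by zero and from very large results near zero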
dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, + min_value=1, ), ) -def test_torch_tensor_bitwise_right_shift( +def test_torch_tensor_reciprocal_( dtype_and_x, frontend_method_data, init_flags, @@ -9844,11 +9940,6 @@ def test_torch_tensor_bitwise_right_shift( backend_fw, ): input_dtype, x = dtype_and_x - # negative shifts will throw an exception - # shifts >= dtype witdth produce backend-defined behavior - x[1] = np.asarray( - np.clip(x[1], 0, np.iinfo(input_dtype[1]).bits - 1), dtype=input_dtype[1] - ) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -9856,9 +9947,7 @@ def test_torch_tensor_bitwise_right_shift( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -9867,14 +9956,17 @@ def test_torch_tensor_bitwise_right_shift( ) -# logdet +# relu @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="logdet", - dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), + method_name="relu", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, + ), ) -def test_torch_tensor_logdet( +def test_torch_tensor_relu( dtype_and_x, frontend_method_data, init_flags, @@ -9884,13 +9976,11 @@ def test_torch_tensor_logdet( backend_fw, ): input_dtype, x = dtype_and_x - dtype, x = dtype_and_x - x = np.matmul(x.T, x) + np.identity(x.shape[0]) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, @@ -9902,17 +9992,20 @@ def test_torch_tensor_logdet( ) -# multiply +# remainder @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="multiply", + method_name="remainder", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + shared_dtype=True, num_arrays=2, ), ) -def test_torch_tensor_multiply( +def test_torch_tensor_remainder( dtype_and_x, frontend_method_data, init_flags, @@ -9940,17 +10033,22 @@ def test_torch_tensor_multiply( ) -# multiply_ +# remainder_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="multiply_", + method_name="remainder_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("valid"), + min_value=-1e04, + max_value=1e04, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + shared_dtype=True, num_arrays=2, ), ) -def test_torch_tensor_multiply_( +def test_torch_tensor_remainder_( dtype_and_x, frontend_method_data, init_flags, @@ -9978,58 +10076,81 @@ def test_torch_tensor_multiply_( ) -# norm +# repeat @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="norm", - p_dtype_x_axis=_get_axis_and_p(), - keepdim=st.booleans(), - dtype=helpers.get_dtypes("valid", full=False), + method_name="repeat", + dtype_x_repeats=_repeat_helper(), + unpack_repeat=st.booleans(), ) -def test_torch_tensor_norm( - p_dtype_x_axis, - keepdim, - dtype, - frontend, +def test_torch_tensor_repeat( + dtype_x_repeats, + unpack_repeat, frontend_method_data, init_flags, 
method_flags, + frontend, on_device, backend_fw, ): - p, values = p_dtype_x_axis - input_dtype, x, axis = values + input_dtype, x, repeats = dtype_x_repeats + repeat = { + "repeats": repeats, + } + if unpack_repeat: + method_flags.num_positional_args = len(repeat["repeats"]) + 1 + for i, x_ in enumerate(repeat["repeats"]): + repeat["x{}".format(i)] = x_ helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "p": p, - "dim": axis, - "keepdim": keepdim, - "dtype": dtype[0], + init_all_as_kwargs_np={ + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np=repeat, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# isinf +@given( + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ), + requires_grad=st.booleans(), +) +def test_torch_tensor_requires_grad(dtype_x, requires_grad, backend_fw): + ivy.set_backend(backend_fw) + _, data = dtype_x + x = Tensor(data[0], requires_grad=requires_grad) + ivy.utils.assertions.check_equal(x.requires_grad, requires_grad, as_array=False) + x.requires_grad = not requires_grad + ivy.utils.assertions.check_equal(x.requires_grad, not requires_grad, as_array=False) + ivy.previous_backend() + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="isinf", - dtype_and_x=helpers.dtype_and_values( + method_name="reshape", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + shape=helpers.reshape_shapes( + shape=st.shared(helpers.get_shape(), key="value_shape") ), + unpack_shape=st.booleans(), ) -def test_torch_tensor_isinf( - dtype_and_x, +def test_torch_tensor_reshape( + dtype_x, + shape, + unpack_shape, frontend_method_data, init_flags, method_flags, @@ -10037,13 +10158,24 @@ def test_torch_tensor_isinf( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x + shape = { + "shape": shape, + } + if unpack_shape: + method_flags.num_positional_args = len(shape["shape"]) + 1 + i = 0 + for x_ in shape["shape"]: + shape["x{}".format(i)] = x_ + i += 1 helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np=shape, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -10052,17 +10184,17 @@ def test_torch_tensor_isinf( ) -# is_complex +# reshape_as @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="is_complex", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + method_name="reshape_as", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), num_arrays=2 ), ) -def test_torch_tensor_is_complex( - dtype_and_x, +def test_torch_tensor_reshape_as( + dtype_x, frontend_method_data, init_flags, method_flags, @@ -10070,13 +10202,17 @@ def test_torch_tensor_is_complex( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - 
init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -10085,17 +10221,19 @@ def test_torch_tensor_is_complex( ) -# isreal +# round @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="isreal", + method_name="round", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float"), ), + decimals=st.integers(min_value=0, max_value=5), ) -def test_torch_tensor_isreal( +def test_torch_tensor_round( dtype_and_x, + decimals, frontend_method_data, init_flags, method_flags, @@ -10107,9 +10245,13 @@ def test_torch_tensor_isreal( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "decimals": decimals, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -10118,19 +10260,19 @@ def test_torch_tensor_isreal( ) -# copysign +# round_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="copysign", + method_name="round_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - num_arrays=2, ), + decimals=st.integers(min_value=0, max_value=5), ) -def test_torch_tensor_copysign( +def test_torch_tensor_round_( dtype_and_x, + decimals, frontend_method_data, init_flags, method_flags, @@ -10147,7 +10289,7 @@ def test_torch_tensor_copysign( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "decimals": decimals, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -10157,22 +10299,21 @@ def test_torch_tensor_copysign( ) -# not_equal +# rsqrt @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="not_equal", + method_name="rsqrt", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_not_equal( +def test_torch_tensor_rsqrt( dtype_and_x, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): @@ -10184,114 +10325,80 @@ def test_torch_tensor_not_equal( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -@st.composite -def _get_dtype_input_and_vectors(draw, with_input=False, same_size=False): - dim_size1 = draw(helpers.ints(min_value=2, max_value=5)) - dim_size2 = dim_size1 if same_size else draw(helpers.ints(min_value=2, max_value=5)) - dtype = draw(helpers.get_dtypes("float", full=True)) - dtype = [ - draw(st.sampled_from(tuple(set(dtype).difference({"bfloat16", "float16"})))) - ] - vec1 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size1,), min_value=2, max_value=5 - ) - ) - vec2 = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size2,), min_value=2, max_value=5 - ) +@given( + dtype_x=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("valid", prune_function=False), + ret_shape=True, + ).filter(lambda x: "bfloat16" not in x[0]), +) +def test_torch_tensor_shape(dtype_x, backend_fw): + ivy.set_backend(backend_fw) + dtype, data, shape = dtype_x + x = Tensor(data[0]) + ivy.utils.assertions.check_equal( + x.ivy_array.shape, ivy.Shape(shape), as_array=False ) - if with_input: - input = draw( - helpers.array_values( - dtype=dtype[0], shape=(dim_size1, dim_size2), min_value=2, max_value=5 - ) - ) - return dtype, input, vec1, vec2 - return dtype, vec1, vec2 - + ivy.previous_backend() -# addr -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="addr", - dtype_and_vecs=_get_dtype_input_and_vectors(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + +# short +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="short", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=-1e04, + max_value=1e04, + allow_inf=False, ), ) -def test_torch_tensor_addr( - dtype_and_vecs, - beta, - alpha, - frontend, +def test_torch_tensor_short( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - dtype, input, vec1, vec2 = dtype_and_vecs + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": input, - }, - method_input_dtypes=dtype, - method_all_as_kwargs_np={ - "vec1": vec1, - "vec2": vec2, - "beta": beta, - "alpha": alpha, + "data": x[0], }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) -# logical_not_ +# sigmoid @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="logical_not_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=1, - large_abs_safety_factor=12, + method_name="sigmoid", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_logical_not_( - dtype_and_x, +def test_torch_tensor_sigmoid( + dtype_x, frontend_method_data, init_flags, method_flags, @@ -10299,7 +10406,7 @@ def test_torch_tensor_logical_not_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -10316,17 +10423,17 @@ def test_torch_tensor_logical_not_( ) -# rsqrt +# sigmoid @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="rsqrt", - dtype_and_x=helpers.dtype_and_values( + method_name="sigmoid_", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_rsqrt( - dtype_and_x, +def test_torch_tensor_sigmoid_( + dtype_x, frontend_method_data, init_flags, method_flags, @@ -10334,7 +10441,7 @@ def test_torch_tensor_rsqrt( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -10351,31 +10458,29 @@ def 
test_torch_tensor_rsqrt( ) -# rsqrt_ +# sign @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="rsqrt_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + method_name="sign", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_torch_rsqrt_( - dtype_and_x, +def test_torch_tensor_sign( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, @@ -10386,26 +10491,55 @@ def test_torch_rsqrt_( ) -# equal +# sign_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="equal", + method_name="sign_", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_torch_tensor_sign_( + dtype_x, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x = dtype_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=[input_dtype], + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + frontend=frontend, + ) + + +# sin +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="sin", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - min_num_dims=1, - min_value=-1e04, - max_value=1e04, + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_equal( +def test_torch_tensor_sin( dtype_and_x, - frontend, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): @@ -10417,29 +10551,26 @@ def test_torch_tensor_equal( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-04, - rtol_=1e-04, on_device=on_device, ) -# erf +# sin_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="erf", + method_name="sin_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_erf( +def test_torch_tensor_sin_( dtype_and_x, frontend_method_data, init_flags, @@ -10465,16 +10596,17 @@ def test_torch_tensor_erf( ) -# erf_ +# sinh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="erf_", + method_name="sinh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_erf_( +def test_torch_tensor_sinh( dtype_and_x, frontend_method_data, init_flags, @@ -10500,20 +10632,17 @@ def test_torch_tensor_erf_( ) -# greater +# sinh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="greater", + method_name="sinh_", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, + available_dtypes=helpers.get_dtypes("float"), allow_inf=False, ), ) -def test_torch_tensor_greater( +def test_torch_tensor_sinh_( dtype_and_x, frontend_method_data, init_flags, @@ -10530,9 +10659,7 @@ def test_torch_tensor_greater( "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -10541,21 +10668,23 @@ def test_torch_tensor_greater( ) -# greater_ +# size @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="greater_", + method_name="size", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_int=True, ), ) -def test_torch_tensor_greater_( +def test_torch_tensor_size( dtype_and_x, + dim, frontend_method_data, init_flags, method_flags, @@ -10572,7 +10701,7 @@ def test_torch_tensor_greater_( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": dim, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -10582,21 +10711,23 @@ def test_torch_tensor_greater_( ) -# greater_equal +# softmax @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="greater_equal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="softmax", + dtype_x_and_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_axes_size=1, + force_int_axis=True, + valid_axis=True, ), + dtype=helpers.get_dtypes("float", full=False), ) -def test_torch_tensor_greater_equal( - dtype_and_x, +def test_torch_tensor_softmax( + dtype_x_and_axis, + dtype, frontend_method_data, init_flags, method_flags, @@ -10604,7 +10735,7 @@ def test_torch_tensor_greater_equal( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_x_and_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -10613,7 +10744,8 @@ def test_torch_tensor_greater_equal( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": axis, + "dtype": dtype[0], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -10623,21 +10755,26 @@ def test_torch_tensor_greater_equal( ) -# greater_equal_ +# sort @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="greater_equal_", - dtype_and_x=helpers.dtype_and_values( + method_name="sort", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, ), + descending=st.booleans(), ) -def test_torch_tensor_greater_equal_( - dtype_and_x, +def test_torch_tensor_sort( + dtype_value, + dim, + descending, frontend_method_data, init_flags, method_flags, @@ -10645,7 +10782,7 @@ def 
test_torch_tensor_greater_equal_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -10654,7 +10791,8 @@ def test_torch_tensor_greater_equal_( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": dim, + "descending": descending, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -10664,21 +10802,28 @@ def test_torch_tensor_greater_equal_( ) -# less +# split @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="less", - dtype_and_x=helpers.dtype_and_values( + method_name="split", + dtype_value=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + split_size=_get_splits(allow_none=False, min_num_dims=1, allow_array_indices=False), + dim=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", ), ) -def test_torch_tensor_less( - dtype_and_x, +def test_torch_tensor_split( + dtype_value, + split_size, + dim, frontend_method_data, init_flags, method_flags, @@ -10686,7 +10831,7 @@ def test_torch_tensor_less( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -10695,7 +10840,8 @@ def test_torch_tensor_less( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "split_size": split_size, + "dim": dim, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -10705,39 +10851,31 @@ def test_torch_tensor_less( ) -# less_ +# sqrt @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="less_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="sqrt", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ), ) -def test_torch_tensor_less_( - dtype_and_x, +def test_torch_tensor_sqrt( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -10746,39 +10884,31 @@ def test_torch_tensor_less_( ) -# less_equal +# sqrt_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="less_equal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="sqrt_", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_less_equal( - dtype_and_x, +def test_torch_tensor_sqrt_( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = 
dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -10787,39 +10917,31 @@ def test_torch_tensor_less_equal( ) -# less_equal_ +# square @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="less_equal_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="square", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ), ) -def test_torch_tensor_less_equal_( - dtype_and_x, +def test_torch_tensor_square( + dtype_x, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "other": x[1], - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -10828,75 +10950,62 @@ def test_torch_tensor_less_equal_( ) -# addr_ +# squeeze @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addr_", - dtype_and_vecs=_get_dtype_input_and_vectors(with_input=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + method_name="squeeze", + dtype_value_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), ) -def test_torch_tensor_addr_( - dtype_and_vecs, - beta, - alpha, - frontend, +def test_torch_tensor_squeeze( + dtype_value_axis, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - dtype, input, vec1, vec2 = dtype_and_vecs + input_dtype, x, axis = dtype_value_axis helpers.test_frontend_method( - init_input_dtypes=dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": input, + "data": x[0], }, - method_input_dtypes=dtype, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "vec1": vec1, - "vec2": vec2, - "beta": beta, - "alpha": alpha, + "dim": axis, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - atol_=1e-02, on_device=on_device, ) +# squeeze_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="eq_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e04, - max_value=1e04, - allow_inf=False, + method_name="squeeze_", + dtype_value_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + shape=st.shared(helpers.get_shape(min_num_dims=1), 
key="shape"), ), ) -def test_torch_tensor_eq_( - dtype_and_x, +def test_torch_tensor_squeeze_( + dtype_value_axis, frontend_method_data, init_flags, method_flags, @@ -10904,7 +11013,7 @@ def test_torch_tensor_eq_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_value_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -10913,7 +11022,7 @@ def test_torch_tensor_eq_( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": axis, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -10923,20 +11032,15 @@ def test_torch_tensor_eq_( ) +# std @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="var", - dtype_and_x=_statistical_dtype_values( - function="var", - min_value=-1e04, - max_value=1e04, - ), - keepdim=st.booleans(), + method_name="std", + dtype_and_x=_statistical_dtype_values(function="std"), ) -def test_torch_tensor_var( +def test_torch_tensor_std( dtype_and_x, - keepdim, frontend, frontend_method_data, init_flags, @@ -10944,21 +11048,19 @@ def test_torch_tensor_var( on_device, backend_fw, ): - input_dtype, x, axis, correction = dtype_and_x + input_dtype, x, _, _ = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dim": axis, - "correction": int(correction), - "keepdim": keepdim, + init_all_as_kwargs_np={ + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -10966,11 +11068,16 @@ def test_torch_tensor_var( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="narrow", - dtype_input_dim_start_length=_dtype_input_dim_start_length(), + method_name="stride", + dtype_value_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), ) -def test_torch_tensor_narrow( - dtype_input_dim_start_length, +def test_torch_tensor_stride( + dtype_value_axis, frontend, frontend_method_data, init_flags, @@ -10978,17 +11085,13 @@ def test_torch_tensor_narrow( on_device, backend_fw, ): - (input_dtype, x, dim, start, length) = dtype_input_dim_start_length + input_dtype, x, axis = dtype_value_axis helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "dim": dim, - "start": start, - "length": length, - }, + method_all_as_kwargs_np={"dim": axis}, frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -10997,14 +11100,23 @@ def test_torch_tensor_narrow( ) +# sub @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="as_strided", - dtype_x_and_other=_as_strided_helper(), + method_name="sub", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e04, + max_value=1e04, + allow_inf=False, + ), + alpha=st.floats(min_value=-1e04, max_value=1e04, allow_infinity=False), ) -def test_torch_tensor_as_strided( - dtype_x_and_other, +def test_torch_tensor_sub( + dtype_and_x, + alpha, frontend, frontend_method_data, init_flags, @@ 
-11012,174 +11124,214 @@ def test_torch_tensor_as_strided( on_device, backend_fw, ): - input_dtype, x, size, stride, offset = dtype_x_and_other + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "size": size, - "stride": stride, - "storage_offset": offset, + "other": x[1], + "alpha": alpha, }, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, + atol_=1e-02, on_device=on_device, ) +# subtract_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="stride", - dtype_value_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - valid_axis=True, - force_int_axis=True, + method_name="subtract_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, ), ) -def test_torch_tensor_stride( - dtype_value_axis, - frontend, +def test_torch_tensor_subtract_( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, axis = dtype_value_axis + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=[input_dtype[0]], backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"dim": axis}, - frontend=frontend, + method_all_as_kwargs_np={ + "other": x[1], + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) +# sum @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="log1p", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - max_value=1e37, + method_name="sum", + dtype_x_dim=_get_castable_dtype( + min_value=-1e04, + max_value=1e04, ), + keepdim=st.booleans(), ) -def test_torch_tensor_log1p( - dtype_x, - frontend, +def test_torch_tensor_sum( + dtype_x_dim, + keepdim, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_x + input_dtype, x, dim, castable_dtype = dtype_x_dim + if method_flags.as_variable: + castable_dtype = input_dtype + input_dtype = [input_dtype] helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, - frontend=frontend, + method_all_as_kwargs_np={ + "dim": dim, + "keepdim": keepdim, + "dtype": castable_dtype, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) -# log1p_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="log1p_", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - max_value=1e37, + method_name="svd", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), ), + some=st.booleans(), + compute_uv=st.booleans(), ) -def 
test_torch_tensor_log1p_( - dtype_x, +def test_torch_tensor_svd( + dtype_and_x, + some, + compute_uv, frontend, + backend_fw, frontend_method_data, init_flags, method_flags, on_device, - backend_fw, ): - input_dtype, x = dtype_x - helpers.test_frontend_method( + input_dtype, x = dtype_and_x + x = np.asarray(x[0], dtype=input_dtype[0]) + + ret, frontend_ret = helpers.test_frontend_method( init_input_dtypes=input_dtype, - backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x, + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, - frontend=frontend, + method_all_as_kwargs_np={ + "some": some, + "compute_uv": compute_uv, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, + backend_to_test=backend_fw, on_device=on_device, + test_values=False, ) + with helpers.update_backend(backend_fw) as ivy_backend: + ret = [ivy_backend.to_numpy(x) for x in ret] + frontend_ret = [np.asarray(x) for x in frontend_ret] + + u, s, vh = ret + frontend_u, frontend_s, frontend_vh = frontend_ret + + if compute_uv: + helpers.assert_all_close( + ret_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T, + ret_from_gt_np=u @ np.diag(s) @ vh, + rtol=1e-2, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=frontend, + ) + else: + helpers.assert_all_close( + ret_np=frontend_s, + ret_from_gt_np=s, + rtol=1e-2, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=frontend, + ) +# t @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="baddbmm", - dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + method_name="t", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=helpers.get_shape(min_num_dims=2, max_num_dims=2), ), ) -def test_torch_tensor_baddbmm( - dtype_and_matrices, - beta, - alpha, - frontend, +def test_torch_tensor_t( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x, batch1, batch2 = dtype_and_matrices + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "batch1": batch1, - "batch2": batch2, - "beta": beta, - "alpha": alpha, + init_all_as_kwargs_np={ + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -11187,91 +11339,93 @@ def test_torch_tensor_baddbmm( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="baddbmm_", - dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), - beta=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, - ), - alpha=st.floats( - min_value=-5, - max_value=5, - allow_nan=False, - allow_subnormal=False, - allow_infinity=False, + method_name="take_along_dim", + dtype_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + 
indices_dtypes=["int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + indices_same_dims=True, ), ) -def test_torch_baddbmm_( - dtype_and_matrices, - beta, - alpha, - frontend, +def test_torch_tensor_take_along_dim( + dtype_indices_axis, frontend_method_data, init_flags, method_flags, + frontend, on_device, + backend_fw, ): - input_dtype, x, batch1, batch2 = dtype_and_matrices + input_dtypes, value, indices, axis, _ = dtype_indices_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, - init_all_as_kwargs_np={"data": x[0]}, - method_input_dtypes=input_dtype, + init_input_dtypes=[input_dtypes[0]], + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": value, + }, + method_input_dtypes=[input_dtypes[1]], method_all_as_kwargs_np={ - "batch1": batch1, - "batch2": batch2, - "beta": beta, - "alpha": alpha, + "indices": indices, + "dim": axis, }, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) +# tan @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="bmm", - dtype_and_matrices=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), + method_name="tan", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, + ), ) -def test_torch_tensor_instance_bmm( - dtype_and_matrices, - backend_fw, - frontend, +def test_torch_tensor_tan( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, + backend_fw, ): - input_dtype, _, x, mat2 = dtype_and_matrices + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, - init_all_as_kwargs_np={"data": x}, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"mat2": mat2}, - frontend=frontend, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, - backend_to_test=backend_fw, ) +# tan_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="floor_", + method_name="tan_", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_floor_( +def test_torch_tensor_tan_( dtype_and_x, frontend_method_data, init_flags, @@ -11287,7 +11441,7 @@ def test_torch_tensor_floor_( init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -11297,19 +11451,18 @@ def test_torch_tensor_floor_( ) +# tanh @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="diag", - dtype_and_values=helpers.dtype_and_values( + method_name="tanh", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"), + allow_inf=False, ), - diagonal=st.integers(min_value=-100, max_value=100), ) -def test_torch_tensor_diag( - dtype_and_values, - diagonal, +def test_torch_tensor_tanh( + dtype_and_x, frontend_method_data, init_flags, method_flags, @@ -11317,17 +11470,15 @@ def test_torch_tensor_diag( on_device, backend_fw, ): - input_dtype, values = dtype_and_values + input_dtype, x = dtype_and_x helpers.test_frontend_method( 
init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": values[0], + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "diagonal": diagonal, - }, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -11336,100 +11487,90 @@ def test_torch_tensor_diag( ) -# diagonal -@st.composite -def dims_and_offset(draw, shape): - shape_actual = draw(shape) - dim1 = draw(helpers.get_axis(shape=shape, force_int=True)) - dim2 = draw(helpers.get_axis(shape=shape, force_int=True)) - offset = draw( - st.integers(min_value=-shape_actual[dim1], max_value=shape_actual[dim1]) - ) - return dim1, dim2, offset - - +# tanh_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="diagonal", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"), - ), - dims_and_offset=dims_and_offset( - shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape") + method_name="tanh_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), ) -def test_torch_tensor_diagonal( - dtype_and_values, - dims_and_offset, - frontend, +def test_torch_tensor_tanh_( + dtype_and_x, frontend_method_data, - backend_fw, init_flags, method_flags, + frontend, on_device, + backend_fw, ): - input_dtype, value = dtype_and_values - dim1, dim2, offset = dims_and_offset - input = value[0] - num_dims = len(np.shape(input)) - assume(dim1 != dim2) - if dim1 < 0: - assume(dim1 + num_dims != dim2) - if dim2 < 0: - assume(dim1 != dim2 + num_dims) + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=[input_dtype[0]], - init_all_as_kwargs_np={"data": input}, - method_input_dtypes=[input_dtype[0]], - method_all_as_kwargs_np={ - "offset": offset, - "dim1": dim1, - "dim2": dim2, + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, - backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) +# tensor_split @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="gather", - params_indices_others=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("valid"), - indices_dtypes=["int64"], - indices_same_dims=True, + method_name="tensor_split", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=1, allow_none=False, allow_array_indices=False + ), + dim=st.shared( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_int=True, + ), + key="target_axis", ), + method_num_positional_args=st.just(1), ) -def test_torch_tensor_gather( - params_indices_others, - frontend, +def test_torch_tensor_tensor_split( + dtype_value, + indices_or_sections, + dim, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtypes, x, indices, axis, batch_dims = params_indices_others + input_dtype, x = dtype_value helpers.test_frontend_method( - init_input_dtypes=[input_dtypes[0]], + init_input_dtypes=input_dtype, 
backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x}, - method_input_dtypes=[input_dtypes[1]], + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=[], method_all_as_kwargs_np={ - "dim": axis, - "index": indices, + "indices_or_sections": indices_or_sections, + "dim": dim, }, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -11437,141 +11578,98 @@ def test_torch_tensor_gather( @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="take_along_dim", - dtype_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - indices_same_dims=True, + method_name="tile", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="shape"), + ), + reps=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=False, ), ) -def test_torch_tensor_take_along_dim( - dtype_indices_axis, +def test_torch_tensor_tile( + dtype_and_values, + reps, + frontend, frontend_method_data, init_flags, method_flags, - frontend, on_device, backend_fw, ): - input_dtypes, value, indices, axis, _ = dtype_indices_axis + input_dtype, values = dtype_and_values + if isinstance(reps, tuple): + method_flags.num_positional_args = len(reps) + else: + method_flags.num_positional_args = 1 helpers.test_frontend_method( - init_input_dtypes=[input_dtypes[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": value, + "data": values[0], }, - method_input_dtypes=[input_dtypes[1]], + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "indices": indices, - "dim": axis, + "reps": reps, }, - frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) +# to @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="movedim", - dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - ), - source=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - destination=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), + method_name="to", + args_kwargs=_to_helper(), ) -def test_torch_tensor_movedim( - dtype_and_input, - source, - destination, - frontend, +def test_torch_tensor_to( + args_kwargs, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, value = dtype_and_input + input_dtype, x, method_num_positional_args, method_all_as_kwargs_np = args_kwargs + method_flags.num_positional_args = method_num_positional_args helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": value[0]}, - 
method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "source": source, - "destination": destination, + init_all_as_kwargs_np={ + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np=method_all_as_kwargs_np, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) +# topk +# TODO: add value test after the stable sorting is added to torch +# https://github.com/pytorch/pytorch/issues/88184 @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="addcdiv_", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=3, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, - ), - value=st.floats(min_value=-100, max_value=100), + method_name="topk", + dtype_x_axis_k=_topk_helper(), + largest=st.booleans(), + sorted=st.booleans(), ) -def test_torch_tensor_addcdiv_( - dtype_and_x, - value, +def test_torch_tensor_topk( + dtype_x_axis_k, + largest, + sorted, frontend, frontend_method_data, init_flags, @@ -11579,118 +11677,140 @@ def test_torch_tensor_addcdiv_( on_device, backend_fw, ): - input_dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[2], 0))) - + input_dtype, input, axis, k = dtype_x_axis_k helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={"data": input[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "tensor1": x[1], - "tensor2": x[2], - "value": value, + "k": k, + "dim": axis, + "largest": largest, + "sorted": sorted, }, + frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend=frontend, on_device=on_device, - atol_=1e-03, + test_values=False, ) +# transpose @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="cholesky", - dtype_and_x=_get_dtype_and_matrix(square=True), - upper=st.booleans(), + method_name="transpose", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim0=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), + dim1=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), ) -def test_torch_tensor_cholesky( - dtype_and_x, - upper, - frontend, +def test_torch_tensor_transpose( + dtype_value, + dim0, + dim1, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x - x = x[0] - # make symmetric positive-definite - x = np.matmul(x.swapaxes(-1, -2), x) + np.identity(x.shape[-1]) * 1e-3 - + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x, + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "upper": upper, - }, + method_all_as_kwargs_np={"dim0": dim0, "dim1": dim1}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, on_device=on_device, - rtol_=1e-2, ) +# transpose_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="heaviside", - dtype_and_values=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + method_name="transpose_", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + dim0=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), + dim1=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, ), ) -def test_torch_tensor_heaviside( - dtype_and_values, - frontend, +def test_torch_tensor_transpose_( + dtype_value, + dim0, + dim1, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, values = dtype_and_values + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": values[0], + "data": x[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "values": values[1], + "dim0": dim0, + "dim1": dim1, }, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) +# tril @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="dot", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shape=(1,), + method_name="tril", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, # Torch requires this. ), + diagonal=st.integers(min_value=-100, max_value=100), ) -def test_torch_tensor_dot( - dtype_and_x, +def test_torch_tensor_tril( + dtype_and_values, + diagonal, frontend_method_data, init_flags, method_flags, @@ -11698,7 +11818,7 @@ def test_torch_tensor_dot( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_and_values helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -11707,7 +11827,7 @@ def test_torch_tensor_dot( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "tensor": x[1], + "diagonal": diagonal, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -11717,66 +11837,61 @@ def test_torch_tensor_dot( ) +# tril_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="tile", + method_name="tril_", dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), key="shape"), - ), - reps=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="shape"), - allow_neg=False, + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, # Torch requires this. 
), + diagonal=st.integers(min_value=-100, max_value=100), ) -def test_torch_tensor_tile( +def test_torch_tensor_tril_( dtype_and_values, - reps, - frontend, + diagonal, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, values = dtype_and_values - if isinstance(reps, tuple): - method_flags.num_positional_args = len(reps) - else: - method_flags.num_positional_args = 1 + input_dtype, x = dtype_and_values helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": values[0], + "data": x[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "reps": reps, + "diagonal": diagonal, }, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) -# write test for torch instance apply_ - - +# true_divide_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="apply_", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, + method_name="true_divide_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", ), ) -def test_torch_tensor_apply_( - dtype_and_values, +def test_torch_tensor_true_divide_( + dtype_and_x, frontend, frontend_method_data, init_flags, @@ -11784,148 +11899,110 @@ def test_torch_tensor_apply_( on_device, backend_fw, ): - def func(x): - return x + 1 - - input_dtype, values = dtype_and_values + input_dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": values[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "callable": func, + "other": x[1], }, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) -@given( - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", prune_function=False), - num_arrays=3, - min_value=-1e3, - max_value=1e3, - ).filter(lambda x: all(dt == "float32" for dt in x[0])), +# trunc +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="trunc", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), ) -def test_torch_tensor_backward( - dtype_x, +def test_torch_tensor_trunc( + dtype_value, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, backend_fw, ): - ivy.set_backend(backend_fw) - if ivy.current_backend_str() == "numpy": - ivy.warnings.warn("Gradient calculation unavailable for numpy backend") - return - if ivy.current_backend_str() == "paddle": - ivy.warnings.warn("torch.Tensor.backward() unavailable for paddle backend") - return - _, values = dtype_x - x = Tensor(values[0], requires_grad=True) - y = Tensor(values[1], requires_grad=True) - z = Tensor(values[2], requires_grad=True) - a = x + y.pow(2) - b = z * a - c = b.sum() - c.backward() - x_torch = torch.tensor(values[0], requires_grad=True, dtype=torch.float32) - y_torch = torch.tensor(values[1], requires_grad=True, 
dtype=torch.float32) - z_torch = torch.tensor(values[2], requires_grad=True, dtype=torch.float32) - a_torch = x_torch + y_torch.pow(2) - b_torch = z_torch * a_torch - c_torch = b_torch.sum() - c_torch.backward() - helpers.assertions.value_test( - ret_np_flat=helpers.flatten_and_to_np( - ret=x._grads.ivy_array, backend=backend_fw - ), - ret_np_from_gt_flat=helpers.flatten_and_to_np( - ret=ivy.to_ivy(x_torch.grad.numpy()), backend=backend_fw - ), - rtol=1e-3, - atol=1e-3, - backend="torch", - ) - helpers.assertions.value_test( - ret_np_flat=helpers.flatten_and_to_np( - ret=y._grads.ivy_array, backend=backend_fw - ), - ret_np_from_gt_flat=helpers.flatten_and_to_np( - ret=ivy.to_ivy(y_torch.grad.numpy()), backend=backend_fw - ), - rtol=1e-3, - atol=1e-3, - backend="torch", - ) - helpers.assertions.value_test( - ret_np_flat=helpers.flatten_and_to_np( - ret=z._grads.ivy_array, backend=backend_fw - ), - ret_np_from_gt_flat=helpers.flatten_and_to_np( - ret=ivy.to_ivy(z_torch.grad.numpy()), backend=backend_fw - ), - rtol=1e-3, - atol=1e-3, - backend="torch", + input_dtype, x = dtype_value + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, ) -# angle +# trunc_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="angle", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=["float64", "complex64", "complex128"], + method_name="trunc_", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), ), ) -def test_torch_tensor_angle( - dtype_and_values, - frontend, +def test_torch_tensor_trunc_( + dtype_value, frontend_method_data, init_flags, method_flags, + frontend, on_device, + backend_fw, ): - input_dtype, values = dtype_and_values - + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": values[0], + "data": x[0], }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) -# logaddexp +# type @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="logaddexp", + method_name="type", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=1, - min_value=-100, - max_value=100, - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), ), + dtype=helpers.get_dtypes("valid", full=False), ) -def test_torch_tensor_logaddexp( +def test_torch_tensor_type( dtype_and_x, + dtype, frontend_method_data, init_flags, method_flags, @@ -11942,7 +12019,7 @@ def test_torch_tensor_logaddexp( }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dtype": dtype[0], }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -11952,53 +12029,58 @@ def test_torch_tensor_logaddexp( ) +# type_as @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="adjoint", - dtype_and_values=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("real_and_complex"), - min_num_dims=2, - min_dim_size=2, + method_name="type_as", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, ), ) -def test_torch_tensor_adjoint( - dtype_and_values, - frontend, +def test_torch_tensor_type_as( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, values = dtype_and_values - + input_dtype, x = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": values[0], + "data": x[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "other": x[1], + }, + frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, - frontend_method_data=frontend_method_data, frontend=frontend, on_device=on_device, ) +# unbind @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="conj", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex") + method_name="unbind", + dtype_value_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, ), ) -def test_torch_tensor_conj( - dtype_and_x, +def test_torch_tensor_unbind( + dtype_value_axis, frontend_method_data, init_flags, method_flags, @@ -12006,15 +12088,17 @@ def test_torch_tensor_conj( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtypes, x, axis = dtype_value_axis helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=input_dtypes, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "dim": axis, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -12023,124 +12107,123 @@ def test_torch_tensor_conj( ) +# unfold @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="svd", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), - ), - some=st.booleans(), - compute_uv=st.booleans(), + method_name="unfold", + dtype_values_args=_unfold_args(), ) -def test_torch_tensor_svd( - dtype_and_x, - some, - compute_uv, - frontend, - backend_fw, +def test_torch_tensor_unfold( + dtype_values_args, frontend_method_data, init_flags, method_flags, + frontend, on_device, + backend_fw, ): - input_dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=input_dtype[0]) - - ret, frontend_ret = helpers.test_frontend_method( + input_dtype, x, axis, size, step = dtype_values_args + print(axis, size, step) + helpers.test_frontend_method( init_input_dtypes=input_dtype, + backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x, }, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "some": some, - "compute_uv": compute_uv, + "dimension": axis, + "size": size, + "step": step, }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - backend_to_test=backend_fw, on_device=on_device, - test_values=False, ) - with helpers.update_backend(backend_fw) as ivy_backend: - ret = [ivy_backend.to_numpy(x) for x in ret] - frontend_ret 
= [np.asarray(x) for x in frontend_ret] - - u, s, vh = ret - frontend_u, frontend_s, frontend_vh = frontend_ret - - if compute_uv: - helpers.assert_all_close( - ret_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T, - ret_from_gt_np=u @ np.diag(s) @ vh, - rtol=1e-2, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=frontend, - ) - else: - helpers.assert_all_close( - ret_np=frontend_s, - ret_from_gt_np=s, - rtol=1e-2, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=frontend, - ) -@st.composite -def _get_clip_min_inputs(draw): - shape = draw( - helpers.get_shape( - min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10 - ) - ) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape, - ) - ) - - min = draw( - helpers.array_values(dtype=x_dtype[0], shape=shape, min_value=0, max_value=25) +# unsqueeze +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="unsqueeze", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), +) +def test_torch_tensor_unsqueeze( + dtype_value, + dim, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_value + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "dim": dim, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, ) - return x_dtype, x, min - +# unsqueeze_ @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="clamp_min", - input_and_ranges=_get_clip_min_inputs(), + method_name="unsqueeze_", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="shape"), + ), + dim=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), ) -def test_torch_tensor_clamp_min( - input_and_ranges, +def test_torch_tensor_unsqueeze_( + dtype_value, + dim, frontend_method_data, init_flags, - backend_fw, + method_flags, frontend, on_device, - method_flags, + backend_fw, ): - x_dtype, x, min = input_and_ranges + input_dtype, x = dtype_value helpers.test_frontend_method( - init_input_dtypes=x_dtype, + init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=x_dtype, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "min": min, + "dim": dim, }, frontend_method_data=frontend_method_data, init_flags=init_flags, @@ -12150,25 +12233,20 @@ def test_torch_tensor_clamp_min( ) -# gcd @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="gcd", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - num_arrays=2, - shared_dtype=True, + method_name="var", + dtype_and_x=_statistical_dtype_values( + function="var", + min_value=-1e04, + max_value=1e04, ), + keepdim=st.booleans(), ) -def test_torch_tensor_gcd( +def test_torch_tensor_var( 
dtype_and_x, + keepdim, frontend, frontend_method_data, init_flags, @@ -12176,16 +12254,16 @@ def test_torch_tensor_gcd( on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x, axis, correction = dtype_and_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={ - "data": x[0], - }, + init_all_as_kwargs_np={"data": x[0]}, method_input_dtypes=input_dtype, method_all_as_kwargs_np={ - "other": x[1], + "dim": axis, + "correction": int(correction), + "keepdim": keepdim, }, frontend=frontend, frontend_method_data=frontend_method_data, @@ -12195,21 +12273,26 @@ def test_torch_tensor_gcd( ) -# isnan +# view @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="isnan", + method_name="view", dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + shape=helpers.reshape_shapes( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape") ), ) -def test_torch_isnan( +def test_torch_tensor_view( dtype_x, - frontend, + shape, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): @@ -12217,9 +12300,13 @@ def test_torch_isnan( helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x[0]}, + init_all_as_kwargs_np={ + "data": x[0], + }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={ + "size": shape, + }, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, @@ -12228,33 +12315,27 @@ def test_torch_isnan( ) -# lcm +# view_as @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="lcm", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + method_name="view_as", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=st.shared(helpers.get_shape(), key="value_shape"), num_arrays=2, - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - shared_dtype=True, ), ) -def test_torch_tensor_lcm( - dtype_and_x, - frontend, +def test_torch_tensor_view_as( + dtype_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -12265,134 +12346,126 @@ def test_torch_tensor_lcm( method_all_as_kwargs_np={ "other": x[1], }, - frontend=frontend, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) +# vsplit @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="quantile", - dtype_and_x=_quantile_helper().filter(lambda x: "bfloat16" not in x[0]), - keepdims=st.booleans(), + method_name="vsplit", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + ), + indices_or_sections=_get_splits( + min_num_dims=2, + axis=0, + allow_none=False, + allow_array_indices=False, + is_mod_split=True, + ), ) -def test_torch_tensor_quantile( - dtype_and_x, - keepdims, - frontend, +def test_torch_tensor_vsplit( + dtype_value, + indices_or_sections, frontend_method_data, init_flags, method_flags, + frontend, 
on_device, backend_fw, ): - input_dtype, x, axis, interpolation, q = dtype_and_x - if type(axis) is tuple: - axis = axis[0] + input_dtype, x = dtype_value helpers.test_frontend_method( init_input_dtypes=input_dtype, backend_to_test=backend_fw, init_all_as_kwargs_np={ "data": x[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "q": q, - "dim": axis, - "keepdim": keepdims, - "interpolation": interpolation[0], - }, - frontend=frontend, + method_input_dtypes=[], + method_all_as_kwargs_np={"indices_or_sections": indices_or_sections}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) +# where @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="sinc", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ), + method_name="where", + broadcastables=_broadcastable_trio(), ) -def test_torch_instance_sinc( - *, - dtype_and_x, - frontend, - backend_fw, +def test_torch_tensor_where( + broadcastables, frontend_method_data, init_flags, method_flags, + frontend, on_device, + backend_fw, ): - input_dtype, x = dtype_and_x + cond, xs, dtypes = broadcastables helpers.test_frontend_method( - init_input_dtypes=input_dtype, + init_input_dtypes=dtypes, + backend_to_test=backend_fw, init_all_as_kwargs_np={ - "data": x[0], + "data": xs[0], + }, + method_input_dtypes=["bool", dtypes[1]], + method_all_as_kwargs_np={ + "condition": cond, + "other": xs[1], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, frontend=frontend, - backend_to_test=backend_fw, on_device=on_device, ) -# index_fill +# zero_ tests @handle_frontend_method( class_tree=CLASS_TREE, init_tree="torch.tensor", - method_name="index_fill", - dtype_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - first_dimension_only=True, - indices_same_dims=False, + method_name="zero_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, ), - value=st.floats(min_value=-100, max_value=100), ) -def test_torch_index_fill( - dtype_indices_axis, - value, - frontend, +def test_torch_tensor_zero_( + dtype_and_x, frontend_method_data, init_flags, method_flags, + frontend, on_device, backend_fw, ): - input_dtypes, x, indices, axis, _ = dtype_indices_axis - if indices.ndim != 1: - indices = ivy.flatten(indices) + input_dtype, x = dtype_and_x helpers.test_frontend_method( - init_input_dtypes=[input_dtypes[0]], + init_input_dtypes=input_dtype, backend_to_test=backend_fw, - init_all_as_kwargs_np={"data": x}, - method_input_dtypes=[input_dtypes[1]], - method_all_as_kwargs_np={ - "dim": axis, - "index": indices, - "value": value, + init_all_as_kwargs_np={ + "data": x[0], }, - frontend=frontend, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, frontend_method_data=frontend_method_data, init_flags=init_flags, method_flags=method_flags, + frontend=frontend, on_device=on_device, ) @@ -12437,68 +12510,3 @@ def test_torch_triu_( frontend=frontend, on_device=on_device, ) - - -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="__array__", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - 
dtype=helpers.get_dtypes("valid", full=False), -) -def test_torch__array__( - dtype_and_x, - dtype, - frontend, - backend_fw, -): - input_dtype, x = dtype_and_x - if x[0].dtype == "bfloat16": - return - dtype[0] = np.dtype(dtype[0]) - ret_gt = torch.tensor(x[0]).__array__(dtype[0]) - with BackendHandler.update_backend(backend_fw) as ivy_backend: - local_importer = ivy_backend.utils.dynamic_import - function_module = local_importer.import_module("ivy.functional.frontends.torch") - ret = function_module.tensor(x[0]).__array__(dtype[0]) - - helpers.value_test( - ret_np_flat=ret.ravel(), - ret_np_from_gt_flat=ret_gt.ravel(), - ground_truth_backend="torch", - backend=backend_fw, - ) - - -@handle_frontend_method( - class_tree=CLASS_TREE, - init_tree="torch.tensor", - method_name="__array_wrap__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - ), -) -def test_torch___array_wrap__( - dtype_and_x, - backend_fw, - frontend, -): - input_dtypes, x = dtype_and_x - if x[1].dtype == "bfloat16": - return - if x[0].dtype == "bfloat16": - ret_gt = torch.tensor(x[0].tolist(), dtype=torch.bfloat16).__array_wrap__(x[1]) - else: - ret_gt = torch.tensor(x[0]).__array_wrap__(x[1]) - with BackendHandler.update_backend(backend_fw) as ivy_backend: - local_importer = ivy_backend.utils.dynamic_import - function_module = local_importer.import_module("ivy.functional.frontends.torch") - ret = function_module.tensor(x[0]).__array_wrap__(x[1]) - assert isinstance(ret, function_module.Tensor) - helpers.value_test( - ret_np_flat=np.array(ret.ivy_array).ravel(), - ret_np_from_gt_flat=ret_gt.numpy().ravel(), - ground_truth_backend="torch", - backend=backend_fw, - ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor_functions.py index 4f38fcf44ade6..61c1e2fe1e7a5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor_functions.py @@ -5,11 +5,15 @@ @handle_frontend_test( - fn_tree="torch.is_tensor", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + fn_tree="torch.is_complex", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + min_dim_size=1, + max_dim_size=1, + ), ) -def test_torch_is_tensor( - *, +def test_torch_is_complex( dtype_and_x, on_device, fn_tree, @@ -21,22 +25,22 @@ def test_torch_is_tensor( helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, - on_device=on_device, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - obj=x[0], + on_device=on_device, + input=x[0], ) @handle_frontend_test( - fn_tree="torch.numel", + fn_tree="torch.is_floating_point", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, ), ) -def test_torch_numel( +def test_torch_is_floating_point( *, dtype_and_x, on_device, @@ -46,6 +50,7 @@ def test_torch_numel( backend_fw, ): input_dtype, x = dtype_and_x + ivy.set_backend(backend_fw) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -53,19 +58,21 @@ def test_torch_numel( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=x[0], + input=ivy.asarray(x[0]), ) + ivy.previous_backend() @handle_frontend_test( - fn_tree="torch.is_floating_point", + fn_tree="torch.is_nonzero", dtype_and_x=helpers.dtype_and_values( 
available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, + min_dim_size=1, + max_dim_size=1, ), ) -def test_torch_is_floating_point( - *, +def test_torch_is_nonzero( dtype_and_x, on_device, fn_tree, @@ -74,7 +81,6 @@ def test_torch_is_floating_point( backend_fw, ): input_dtype, x = dtype_and_x - ivy.set_backend(backend_fw) helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, @@ -82,21 +88,16 @@ def test_torch_is_floating_point( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - input=ivy.asarray(x[0]), + input=x[0], ) - ivy.previous_backend() @handle_frontend_test( - fn_tree="torch.is_nonzero", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - min_dim_size=1, - max_dim_size=1, - ), + fn_tree="torch.is_tensor", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), ) -def test_torch_is_nonzero( +def test_torch_is_tensor( + *, dtype_and_x, on_device, fn_tree, @@ -108,24 +109,23 @@ def test_torch_is_nonzero( helpers.test_frontend_function( input_dtypes=input_dtype, backend_to_test=backend_fw, + on_device=on_device, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, - on_device=on_device, - input=x[0], + obj=x[0], ) @handle_frontend_test( - fn_tree="torch.is_complex", + fn_tree="torch.numel", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, - min_dim_size=1, - max_dim_size=1, ), ) -def test_torch_is_complex( +def test_torch_numel( + *, dtype_and_x, on_device, fn_tree, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py index 435bda4bfed21..f48d51d3e3487 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_utilities.py @@ -6,35 +6,6 @@ from ivy_tests.test_ivy.helpers import handle_frontend_test -@handle_frontend_test( - fn_tree="torch.result_type", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - ), - test_with_out=st.just(False), -) -def test_torch_result_type( - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, x = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - tensor=x[0], - other=x[1], - ) - - # ToDo: Fix this test after torch overide of assert is implemented # @handle_frontend_test( # fn_tree="torch._assert", @@ -101,3 +72,32 @@ def test_torch_bincount( weights=None, minlength=0, ) + + +@handle_frontend_test( + fn_tree="torch.result_type", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + ), + test_with_out=st.just(False), +) +def test_torch_result_type( + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + tensor=x[0], + other=x[1], + ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_creation.py b/ivy_tests/test_ivy/test_functional/test_core/test_creation.py index bd4c9c5cff441..3dc85ebc92191 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_creation.py +++ 
b/ivy_tests/test_ivy/test_functional/test_core/test_creation.py @@ -11,143 +11,127 @@ from ivy_tests.test_ivy.test_functional.test_core.test_dtype import astype_helper -# native_array -@handle_test( - fn_tree="functional.ivy.native_array", - dtype_and_x_and_cast_dtype=astype_helper(), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_native_array( - *, - dtype_and_x_and_cast_dtype, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtype, x, dtype = dtype_and_x_and_cast_dtype - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - dtype=dtype[0], - device=on_device, +# --- Helpers --- # +# --------------- # + + +@st.composite +def _asarray_helper(draw): + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=st.integers(min_value=1, max_value=10), + min_num_dims=0, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + shared_dtype=True, + ) + ) + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + x_list = ivy_backend.nested_map(x, lambda x: x.tolist(), shallow=False) + sh = draw(helpers.get_shape(min_num_dims=1)) + sh = ivy_backend.Shape(sh) + # np_array = x[0] + # dim = draw(helpers.get_shape(min_num_dims=1)) + # nested_values = draw( + # helpers.create_nested_input(dim, [sh, np_array, x_list[0]]) + # ) + dtype = draw( + helpers.get_castable_dtype( + draw(helpers.get_dtypes("numeric")), dtype=x_dtype[0] + ) + )[-1] + dtype = draw(st.sampled_from([dtype, None])) + x = draw( + st.sampled_from( + [ + x, + x_list, + sh, + # nested_values, + ] + ) ) + return x_dtype, x, dtype -# linspace -@handle_test( - fn_tree="functional.ivy.linspace", - dtype_and_start_stop_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e5, - max_value=1e5, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - allow_inf=False, - shared_dtype=True, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - valid_axis=True, - force_int_axis=True, - ), - dtype=helpers.get_dtypes("float", full=False), - num=helpers.ints(min_value=1, max_value=5), - endpoint=st.booleans(), -) -def test_linspace( - *, - dtype_and_start_stop_axis, - num, - endpoint, - dtype, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtypes, start_stop, axis = dtype_and_start_stop_axis - helpers.test_function( - input_dtypes=input_dtypes, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=0.8, - start=start_stop[0], - stop=start_stop[1], - num=num, - axis=axis, - endpoint=endpoint, - dtype=dtype[0], - device=on_device, +@st.composite +def _dtype_and_values(draw): + return draw( + helpers.dtype_and_values( + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + dtype=draw(helpers.get_dtypes("numeric", full=False, key="dtype")), + ) ) -# logspace -@handle_test( - fn_tree="functional.ivy.logspace", - dtype_and_start_stop_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_value=-1e5, - max_value=1e5, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - allow_inf=False, - shared_dtype=True, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - valid_axis=True, - force_int_axis=True, - ), - 
dtype=helpers.get_dtypes("float", full=False), - num=helpers.ints(min_value=1, max_value=5), - base=helpers.floats(min_value=0.1, max_value=20.0), - endpoint=st.booleans(), -) -def test_logspace( - *, - dtype_and_start_stop_axis, - dtype, - num, - base, - endpoint, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtypes, start_stop, axis = dtype_and_start_stop_axis - helpers.test_function( - input_dtypes=input_dtypes, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1, # if It's less than one it'll test for inf - atol_=0.8, - start=start_stop[0], - stop=start_stop[1], - num=num, - base=base, - axis=axis, - endpoint=endpoint, - dtype=dtype[0], - device=on_device, +@st.composite +def _dtype_indices_depth_axis(draw): + depth = draw(helpers.ints(min_value=2, max_value=100)) + dtype, indices, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_value=0, + max_value=depth - 1, + small_abs_safety_factor=4, + ret_shape=True, + ) + ) + + axis = draw(st.integers(min_value=-1, max_value=len(shape) - 1)) + return dtype, indices, depth, axis + + +@st.composite +def _fill_value(draw): + dtype = draw(helpers.get_dtypes("numeric", full=False, key="dtype"))[0] + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + if ivy_backend.is_uint_dtype(dtype): + return draw(helpers.ints(min_value=0, max_value=5)) + if ivy_backend.is_int_dtype(dtype): + return draw(helpers.ints(min_value=-5, max_value=5)) + return draw(helpers.floats(min_value=-5, max_value=5)) + + +@st.composite +def _get_dtype_buffer_count_offset(draw): + dtype, value = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ) ) + value = np.array(value) + length = value.size + value = value.tobytes() + + offset = draw(helpers.ints(min_value=0, max_value=length - 1)) + count = draw(helpers.ints(min_value=-(2**30), max_value=length - offset)) + if count == 0: + count = -1 + offset = offset * np.dtype(dtype[0]).itemsize + + return dtype, value, count, offset + + +@st.composite +def _on_off_dtype(draw): + dtype, value = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + shape=(2,), + safety_factor_scale="log", + ) + ) + [on_value, off_value] = value[0] + return on_value, off_value, dtype[0] + + +# --- Main --- # +# ------------ # # arange @@ -187,47 +171,6 @@ def test_arange( ) -@st.composite -def _asarray_helper(draw): - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=st.integers(min_value=1, max_value=10), - min_num_dims=0, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - shared_dtype=True, - ) - ) - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - x_list = ivy_backend.nested_map(x, lambda x: x.tolist(), shallow=False) - sh = draw(helpers.get_shape(min_num_dims=1)) - sh = ivy_backend.Shape(sh) - # np_array = x[0] - # dim = draw(helpers.get_shape(min_num_dims=1)) - # nested_values = draw( - # helpers.create_nested_input(dim, [sh, np_array, x_list[0]]) - # ) - dtype = draw( - helpers.get_castable_dtype( - draw(helpers.get_dtypes("numeric")), dtype=x_dtype[0] - ) - )[-1] - dtype = draw(st.sampled_from([dtype, None])) - x = draw( - st.sampled_from( - [ - x, - x_list, - sh, - # nested_values, - ] - ) - ) - return x_dtype, x, dtype - - # asarray # TODO: Fix container, instance methods and as_variable @handle_test( @@ -263,6 +206,63 @@ def 
test_asarray( ) +# copy array +@handle_test( + fn_tree="functional.ivy.copy_array", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + to_ivy_array_bool=st.booleans(), +) +def test_copy_array( + *, + test_flags, + dtype_and_x, + to_ivy_array_bool, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + # avoid enabling gradients for non-float arrays + if test_flags.as_variable[0]: + assume("float" in dtype[0]) + # smoke test + with BackendHandler.update_backend(backend_fw) as ivy_backend: + x = test_flags.apply_flags( + x, dtype, 0, backend=backend_fw, on_device=on_device + )[0] + test_flags.instance_method = ( + test_flags.instance_method if not test_flags.native_arrays[0] else False + ) + if test_flags.instance_method: + ret = x.copy_array(to_ivy_array=to_ivy_array_bool) + else: + ret = ivy_backend.copy_array(x, to_ivy_array=to_ivy_array_bool) + # type test + test_ret = ret + test_x = x + if test_flags.container[0]: + assert ivy_backend.is_ivy_container(ret) + test_ret = ret["a"] + test_x = x["a"] + if to_ivy_array_bool: + assert ivy_backend.is_ivy_array(test_ret) + else: + assert ivy_backend.is_native_array(test_ret) + # cardinality test + assert test_ret.shape == test_x.shape + # value test + x, ret = ivy_backend.to_ivy(x), ivy_backend.to_ivy(ret) + x_np, ret_np = helpers.flatten_and_to_np( + backend=backend_fw, ret=x + ), helpers.flatten_and_to_np(backend=backend_fw, ret=ret) + helpers.value_test( + backend=backend_fw, + ground_truth_backend=backend_fw, + ret_np_flat=ret_np, + ret_np_from_gt_flat=x_np, + ) + assert id(x) != id(ret) + + # empty @handle_test( fn_tree="functional.ivy.empty", @@ -373,15 +373,28 @@ def test_from_dlpack(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device) ) -@st.composite -def _fill_value(draw): - dtype = draw(helpers.get_dtypes("numeric", full=False, key="dtype"))[0] - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - if ivy_backend.is_uint_dtype(dtype): - return draw(helpers.ints(min_value=0, max_value=5)) - if ivy_backend.is_int_dtype(dtype): - return draw(helpers.ints(min_value=-5, max_value=5)) - return draw(helpers.floats(min_value=-5, max_value=5)) +@handle_test( + fn_tree="functional.ivy.frombuffer", + dtype_buffer_count_offset=_get_dtype_buffer_count_offset(), + test_instance_method=st.just(False), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_frombuffer( + dtype_buffer_count_offset, test_flags, backend_fw, fn_name, on_device +): + input_dtype, buffer, count, offset = dtype_buffer_count_offset + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + buffer=buffer, + dtype=input_dtype[0], + count=count, + offset=offset, + ) # full @@ -413,19 +426,6 @@ def test_full(*, shape, fill_value, dtypes, test_flags, backend_fw, fn_name, on_ ) -@st.composite -def _dtype_and_values(draw): - return draw( - helpers.dtype_and_values( - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - dtype=draw(helpers.get_dtypes("numeric", full=False, key="dtype")), - ) - ) - - # full_like @handle_test( fn_tree="functional.ivy.full_like", @@ -435,53 +435,219 @@ def _dtype_and_values(draw): def test_full_like( *, dtype_and_x, fill_value, test_flags, backend_fw, fn_name, on_device ): - dtype, x = dtype_and_x + dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + 
fn_name=fn_name, + x=x[0], + fill_value=fill_value, + dtype=dtype[0], + device=on_device, + ) + + +# linspace +@handle_test( + fn_tree="functional.ivy.linspace", + dtype_and_start_stop_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e5, + max_value=1e5, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + allow_inf=False, + shared_dtype=True, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + valid_axis=True, + force_int_axis=True, + ), + dtype=helpers.get_dtypes("float", full=False), + num=helpers.ints(min_value=1, max_value=5), + endpoint=st.booleans(), +) +def test_linspace( + *, + dtype_and_start_stop_axis, + num, + endpoint, + dtype, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtypes, start_stop, axis = dtype_and_start_stop_axis + helpers.test_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=0.8, + start=start_stop[0], + stop=start_stop[1], + num=num, + axis=axis, + endpoint=endpoint, + dtype=dtype[0], + device=on_device, + ) + + +# logspace +@handle_test( + fn_tree="functional.ivy.logspace", + dtype_and_start_stop_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + min_value=-1e5, + max_value=1e5, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + allow_inf=False, + shared_dtype=True, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + valid_axis=True, + force_int_axis=True, + ), + dtype=helpers.get_dtypes("float", full=False), + num=helpers.ints(min_value=1, max_value=5), + base=helpers.floats(min_value=0.1, max_value=20.0), + endpoint=st.booleans(), +) +def test_logspace( + *, + dtype_and_start_stop_axis, + dtype, + num, + base, + endpoint, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtypes, start_stop, axis = dtype_and_start_stop_axis + helpers.test_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1, # if It's less than one it'll test for inf + atol_=0.8, + start=start_stop[0], + stop=start_stop[1], + num=num, + base=base, + axis=axis, + endpoint=endpoint, + dtype=dtype[0], + device=on_device, + ) + + +# meshgrid +@handle_test( + fn_tree="functional.ivy.meshgrid", + dtype_and_arrays=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=st.integers(min_value=2, max_value=5), + min_num_dims=1, + max_num_dims=1, + shared_dtype=True, + ), + sparse=st.booleans(), + indexing=st.sampled_from(["xy", "ij"]), + test_with_out=st.just(False), +) +def test_meshgrid( + *, dtype_and_arrays, test_flags, sparse, indexing, backend_fw, fn_name, on_device +): + dtype, arrays = dtype_and_arrays + kw = {} + i = 0 + for x_ in arrays: + kw["x{}".format(i)] = x_ + i += 1 + test_flags.num_positional_args = len(arrays) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + **kw, + sparse=sparse, + indexing=indexing, + ) + + +# native_array +@handle_test( + fn_tree="functional.ivy.native_array", + dtype_and_x_and_cast_dtype=astype_helper(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_native_array( + *, + dtype_and_x_and_cast_dtype, + test_flags, + backend_fw, + fn_name, + 
on_device, +): + input_dtype, x, dtype = dtype_and_x_and_cast_dtype helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, x=x[0], - fill_value=fill_value, dtype=dtype[0], device=on_device, ) -# meshgrid +# one_hot @handle_test( - fn_tree="functional.ivy.meshgrid", - dtype_and_arrays=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=st.integers(min_value=2, max_value=5), - min_num_dims=1, - max_num_dims=1, - shared_dtype=True, - ), - sparse=st.booleans(), - indexing=st.sampled_from(["xy", "ij"]), - test_with_out=st.just(False), + fn_tree="functional.ivy.one_hot", + dtype_indices_depth_axis=_dtype_indices_depth_axis(), + on_off_dtype=_on_off_dtype(), + test_gradients=st.just(False), ) -def test_meshgrid( - *, dtype_and_arrays, test_flags, sparse, indexing, backend_fw, fn_name, on_device +def test_one_hot( + dtype_indices_depth_axis, on_off_dtype, test_flags, backend_fw, fn_name, on_device ): - dtype, arrays = dtype_and_arrays - kw = {} - i = 0 - for x_ in arrays: - kw["x{}".format(i)] = x_ - i += 1 - test_flags.num_positional_args = len(arrays) + input_dtype, indices, depth, axis = dtype_indices_depth_axis + on_value, off_value, dtype = on_off_dtype helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - **kw, - sparse=sparse, - indexing=indexing, + indices=indices[0], + depth=depth, + on_value=on_value, + off_value=off_value, + axis=axis, + dtype=dtype, ) @@ -585,6 +751,32 @@ def test_triu(*, dtype_and_x, k, test_flags, backend_fw, fn_name, on_device): ) +@handle_test( + fn_tree="functional.ivy.triu_indices", + n_rows=st.integers(min_value=0, max_value=5), + n_cols=st.integers(min_value=0, max_value=5) | st.just(None), + k=st.integers(min_value=-5, max_value=5), + input_dtype=helpers.get_dtypes("integer"), + test_with_out=st.just(False), + test_gradients=st.just(False), + test_instance_method=st.just(False), +) +def test_triu_indices( + *, n_rows, n_cols, k, input_dtype, test_flags, backend_fw, fn_name, on_device +): + input_dtype = input_dtype + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + n_rows=n_rows, + n_cols=n_cols, + k=k, + ) + + # zeros @handle_test( fn_tree="functional.ivy.zeros", @@ -635,187 +827,3 @@ def test_zeros_like(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): dtype=dtype[0], device=on_device, ) - - -# copy array -@handle_test( - fn_tree="functional.ivy.copy_array", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - to_ivy_array_bool=st.booleans(), -) -def test_copy_array( - *, - test_flags, - dtype_and_x, - to_ivy_array_bool, - backend_fw, - on_device, -): - dtype, x = dtype_and_x - # avoid enabling gradients for non-float arrays - if test_flags.as_variable[0]: - assume("float" in dtype[0]) - # smoke test - with BackendHandler.update_backend(backend_fw) as ivy_backend: - x = test_flags.apply_flags( - x, dtype, 0, backend=backend_fw, on_device=on_device - )[0] - test_flags.instance_method = ( - test_flags.instance_method if not test_flags.native_arrays[0] else False - ) - if test_flags.instance_method: - ret = x.copy_array(to_ivy_array=to_ivy_array_bool) - else: - ret = ivy_backend.copy_array(x, to_ivy_array=to_ivy_array_bool) - # type test - test_ret = ret - test_x = x 
- if test_flags.container[0]: - assert ivy_backend.is_ivy_container(ret) - test_ret = ret["a"] - test_x = x["a"] - if to_ivy_array_bool: - assert ivy_backend.is_ivy_array(test_ret) - else: - assert ivy_backend.is_native_array(test_ret) - # cardinality test - assert test_ret.shape == test_x.shape - # value test - x, ret = ivy_backend.to_ivy(x), ivy_backend.to_ivy(ret) - x_np, ret_np = helpers.flatten_and_to_np( - backend=backend_fw, ret=x - ), helpers.flatten_and_to_np(backend=backend_fw, ret=ret) - helpers.value_test( - backend=backend_fw, - ground_truth_backend=backend_fw, - ret_np_flat=ret_np, - ret_np_from_gt_flat=x_np, - ) - assert id(x) != id(ret) - - -@st.composite -def _dtype_indices_depth_axis(draw): - depth = draw(helpers.ints(min_value=2, max_value=100)) - dtype, indices, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_value=0, - max_value=depth - 1, - small_abs_safety_factor=4, - ret_shape=True, - ) - ) - - axis = draw(st.integers(min_value=-1, max_value=len(shape) - 1)) - return dtype, indices, depth, axis - - -@st.composite -def _on_off_dtype(draw): - dtype, value = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - shape=(2,), - safety_factor_scale="log", - ) - ) - [on_value, off_value] = value[0] - return on_value, off_value, dtype[0] - - -# one_hot -@handle_test( - fn_tree="functional.ivy.one_hot", - dtype_indices_depth_axis=_dtype_indices_depth_axis(), - on_off_dtype=_on_off_dtype(), - test_gradients=st.just(False), -) -def test_one_hot( - dtype_indices_depth_axis, on_off_dtype, test_flags, backend_fw, fn_name, on_device -): - input_dtype, indices, depth, axis = dtype_indices_depth_axis - on_value, off_value, dtype = on_off_dtype - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - indices=indices[0], - depth=depth, - on_value=on_value, - off_value=off_value, - axis=axis, - dtype=dtype, - ) - - -@st.composite -def _get_dtype_buffer_count_offset(draw): - dtype, value = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ) - ) - value = np.array(value) - length = value.size - value = value.tobytes() - - offset = draw(helpers.ints(min_value=0, max_value=length - 1)) - count = draw(helpers.ints(min_value=-(2**30), max_value=length - offset)) - if count == 0: - count = -1 - offset = offset * np.dtype(dtype[0]).itemsize - - return dtype, value, count, offset - - -@handle_test( - fn_tree="functional.ivy.frombuffer", - dtype_buffer_count_offset=_get_dtype_buffer_count_offset(), - test_instance_method=st.just(False), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_frombuffer( - dtype_buffer_count_offset, test_flags, backend_fw, fn_name, on_device -): - input_dtype, buffer, count, offset = dtype_buffer_count_offset - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - buffer=buffer, - dtype=input_dtype[0], - count=count, - offset=offset, - ) - - -@handle_test( - fn_tree="functional.ivy.triu_indices", - n_rows=st.integers(min_value=0, max_value=5), - n_cols=st.integers(min_value=0, max_value=5) | st.just(None), - k=st.integers(min_value=-5, max_value=5), - input_dtype=helpers.get_dtypes("integer"), - test_with_out=st.just(False), - test_gradients=st.just(False), - test_instance_method=st.just(False), -) -def test_triu_indices( - *, n_rows, 
n_cols, k, input_dtype, test_flags, backend_fw, fn_name, on_device -): - input_dtype = input_dtype - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - n_rows=n_rows, - n_cols=n_cols, - k=k, - ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_device.py b/ivy_tests/test_ivy/test_functional/test_core/test_device.py index a2a6374781b0a..3cfc21219ab29 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_device.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_device.py @@ -14,6 +14,14 @@ import subprocess from hypothesis import strategies as st, assume +# nvidia-ml-py (pynvml) is not installed in CPU Dockerfile. + +# local +import ivy +import ivy_tests.test_ivy.helpers as helpers +import ivy_tests.test_ivy.helpers.globals as test_globals +from ivy_tests.test_ivy.helpers import handle_test, BackendHandler + try: import pynvml except ImportError: @@ -22,36 +30,37 @@ " of the Ivy's device module will be limited. Please install pynvml if" " you wish to use GPUs with Ivy." ) - # nvidia-ml-py (pynvml) is not installed in CPU Dockerfile. -# local -import ivy -import ivy_tests.test_ivy.helpers as helpers -import ivy_tests.test_ivy.helpers.globals as test_globals -from ivy_tests.test_ivy.helpers import handle_test, BackendHandler +# --- Helpers --- # +# --------------- # -# Helpers # -# ------- # +# Function Splitting # -def _ram_array_and_clear_test(metric_fn, device, size=10000000): - # This function checks if the memory usage changes before, during and after - # Measure usage before creating array - before = metric_fn() - # Create an array of floats, by default with 10 million elements (40 MB) - arr = ivy.random_normal(shape=(size,), dtype="float32", device=device) - during = metric_fn() - # Check that the memory usage has increased - assert before < during +@st.composite +def _axis(draw): + max_val = draw(st.shared(helpers.ints(), key="num_dims")) + return draw(helpers.ints(min_value=0, max_value=max_val - 1)) - # Delete the array - del arr - # Measure the memory usage after the array is deleted - after = metric_fn() - # Check that the memory usage has decreased - assert during > after + +def _composition_1(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + return ivy_backend.relu().argmax() + + +def _composition_2(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + return ivy_backend.ceil() or ivy_backend.floor() + + +def _empty_dir(path, recreate=False): + # Delete the directory if it exists and create it again if recreate is True + if os.path.exists(path): + shutil.rmtree(path) + if recreate: + os.makedirs(path) def _get_possible_devices(): @@ -66,50 +75,39 @@ def _get_possible_devices(): return list(map(ivy_backend.Device, devices)) -def _empty_dir(path, recreate=False): - # Delete the directory if it exists and create it again if recreate is True - if os.path.exists(path): - shutil.rmtree(path) - if recreate: - os.makedirs(path) +def _ram_array_and_clear_test(metric_fn, device, size=10000000): + # This function checks if the memory usage changes before, during and after + + # Measure usage before creating array + before = metric_fn() + # Create an array of floats, by default with 10 million elements (40 MB) + arr = ivy.random_normal(shape=(size,), dtype="float32", device=device) + during = metric_fn() + # Check that the memory usage has increased + assert before < during + # Delete the array + del arr + # 
Measure the memory usage after the array is deleted + after = metric_fn() + # Check that the memory usage has decreased + assert during > after -# Tests # -# ------# -# Device Queries # +# --- Main --- # +# ------------ # -# dev -@handle_test( - fn_tree="functional.ivy.dev", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ), -) -def test_dev(*, dtype_and_x, test_flags, backend_fw): - dtype, x = dtype_and_x - dtype = dtype[0] - x = x[0] +def get_cpu_percent(): + output = str(subprocess.check_output(["top", "-bn1"])) + cpu_percent = float(re.search(r"%Cpu\(s\):\s+([\d.]+)\s+us", output).group(1)) + return cpu_percent - with BackendHandler.update_backend(backend_fw) as ivy_backend: - for device in _get_possible_devices(): - x = ivy_backend.array(x, device=device) - if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): - x = ivy_backend.functional.ivy.gradients._variable(x) - ret = ivy_backend.dev(x) - # type test - assert isinstance(ret, str) - # value test - assert ret == device - # array instance test - assert x.dev() == device - # container instance test - container_x = ivy_backend.Container({"a": x}) - assert container_x.dev() == device - # container static test - assert ivy_backend.Container.static_dev(container_x) == device +def get_gpu_mem_usage(backend, device="gpu:0"): + handle = backend.ivy.functional.ivy.device._get_nvml_gpu_handle(device) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + return (info.used / info.total) * 100 # as_ivy_dev @@ -172,6 +170,24 @@ def test_as_native_dev(*, dtype_and_x, test_flags, on_device, backend_fw): assert ret == device +@handle_test(fn_tree="clear_cached_mem_on_dev") +def test_clear_cached_mem_on_dev(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + devices = _get_possible_devices() + for device in devices: + # Testing on only GPU since clearing cache mem is relevant + # for only CUDA devices + if "gpu" in device: + arr = ivy_backend.random_normal( # noqa: F841 + shape=(10000, 1000), dtype="float32", device=device + ) + del arr + before = get_gpu_mem_usage(device) + ivy_backend.clear_cached_mem_on_dev(device) + after = get_gpu_mem_usage(device) + assert before > after + + # Device Allocation # # default_device @handle_test(fn_tree="functional.ivy.default_device") @@ -198,99 +214,151 @@ def test_default_device(backend_fw): assert len(ivy_backend.default_device_stack) == orig_len -# to_dev +# Tests # +# ------# + +# Device Queries # + + +# dev @handle_test( - fn_tree="functional.ivy.to_device", + fn_tree="functional.ivy.dev", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), ), - stream=helpers.ints(min_value=0, max_value=50), ) -def test_to_device( - *, - dtype_and_x, - stream, - test_flags, - backend_fw, - on_device, -): +def test_dev(*, dtype_and_x, test_flags, backend_fw): dtype, x = dtype_and_x dtype = dtype[0] x = x[0] with BackendHandler.update_backend(backend_fw) as ivy_backend: - x = ivy_backend.asarray(x) - if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): - x = ivy_backend.functional.ivy.gradients._variable(x) - - # create a dummy array for out that is broadcastable to x - out = ( - ivy_backend.zeros(ivy_backend.shape(x), device=on_device, dtype=dtype) - if test_flags.with_out - else None - ) + for device in _get_possible_devices(): + x = ivy_backend.array(x, device=device) + if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): + x = ivy_backend.functional.ivy.gradients._variable(x) - device = 
ivy_backend.dev(x) - x_on_dev = ivy_backend.to_device(x, on_device, stream=stream, out=out) - dev_from_new_x = ivy_backend.dev(x_on_dev) + ret = ivy_backend.dev(x) + # type test + assert isinstance(ret, str) + # value test + assert ret == device + # array instance test + assert x.dev() == device + # container instance test + container_x = ivy_backend.Container({"a": x}) + assert container_x.dev() == device + # container static test + assert ivy_backend.Container.static_dev(container_x) == device - if test_flags.with_out: - # should be the same array test - assert x_on_dev is out - # should be the same device - if backend_fw != "paddle": - assert ivy_backend.dev(x_on_dev, as_native=True) == ivy_backend.dev( - out, as_native=True - ) - else: - assert ivy_backend.dev(x_on_dev, as_native=False) == ivy_backend.dev( - out, as_native=False +@handle_test(fn_tree="dev_util") +def test_dev_util(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + devices = _get_possible_devices() + for device in devices: + # The internally called psutil.cpu_percent() has a unique behavior where it + # returns 0 as usage when run the second time in same line so simple + # assert psutil.cpu_percent() == ivy.dev_util(device) isn't possible + if "cpu" in device: + assert 100 >= ivy_backend.dev_util(device) >= 0 + # Comparing CPU utilization using top. Two percentiles won't be directly + # equal but absolute difference should be below a safe threshold + assert abs(get_cpu_percent() - ivy_backend.dev_util(device)) < 10 + elif "gpu" in device: + handle = ivy_backend.functional.ivy.device._get_nvml_gpu_handle(device) + assert ( + ivy_backend.dev_util(device) + == pynvml.nvmlDeviceGetUtilizationRates(handle).gpu ) - # check if native arrays are the same - # these backends do not support native inplace updates - assume(not (backend_fw in ["tensorflow", "jax"])) - - assert x_on_dev.data is out.data - - # value test - if backend_fw == "tensorflow": - assert "/" + ":".join(dev_from_new_x[1:].split(":")[-2:]) == "/" + ":".join( - device[1:].split(":")[-2:] - ) - elif backend_fw == "torch": - assert type(dev_from_new_x) == type(device) - else: - assert dev_from_new_x == device - - # array instance test - assert x.to_device(device).dev() == device - # container instance test - container_x = ivy_backend.Container({"x": x}) - assert container_x.to_device(device).dev() == device - # container static test - assert ivy_backend.Container.to_device(container_x, device).dev() == device - -# handle_soft_device_variable +# function_unsupported_devices @handle_test( - fn_tree="functional.ivy.handle_soft_device_variable", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - ), + fn_tree="functional.ivy.function_supported_devices", + func=st.sampled_from([_composition_1, _composition_2]), + expected=st.just(["cpu"]), ) -def test_handle_soft_device_variable(*, dtype_and_x, backend_fw): - dtype, x = dtype_and_x - dtype = dtype[0] +def test_function_supported_devices( + *, + func, + expected, + backend_fw, +): with BackendHandler.update_backend(backend_fw) as ivy_backend: - x = ivy_backend.to_device(x[0], "cpu") + res = ivy_backend.function_supported_devices(func) + exp = set(expected) - def fn(x, y): - return ivy_backend.add(x, y) + assert sorted(tuple(exp)) == sorted(res) - for device in _get_possible_devices(): + +# function_unsupported_devices +@handle_test( + fn_tree="functional.ivy.function_supported_devices", + func=st.sampled_from([_composition_1, 
_composition_2]), + expected=st.just(["gpu", "tpu"]), +) +def test_function_unsupported_devices( + *, + func, + expected, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + res = ivy_backend.function_unsupported_devices(func) + exp = set(expected) + + assert sorted(tuple(exp)) == sorted(res) + + +@handle_test( + fn_tree="functional.ivy.get_all_ivy_arrays_on_dev", + num=helpers.ints(min_value=0, max_value=5), +) +def test_get_all_ivy_arrays_on_dev( + *, + num, + on_device, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + arrays = [ivy_backend.array(np.random.uniform(size=2)) for _ in range(num)] + arr_ids_on_dev = [ + id(a) for a in ivy_backend.get_all_ivy_arrays_on_dev(on_device).values() + ] + for a in arrays: + assert id(a) in arr_ids_on_dev + + +@handle_test(fn_tree="gpu_is_available") +def test_gpu_is_available(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + # If gpu is available but cannot be initialised it will fail the test + if ivy_backend.gpu_is_available(): + try: + pynvml.nvmlInit() + except pynvml.NVMLError: + assert False + + +# handle_soft_device_variable +@handle_test( + fn_tree="functional.ivy.handle_soft_device_variable", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + ), +) +def test_handle_soft_device_variable(*, dtype_and_x, backend_fw): + dtype, x = dtype_and_x + dtype = dtype[0] + with BackendHandler.update_backend(backend_fw) as ivy_backend: + x = ivy_backend.to_device(x[0], "cpu") + + def fn(x, y): + return ivy_backend.add(x, y) + + for device in _get_possible_devices(): ivy_backend.set_default_device(device) out = ivy_backend.handle_soft_device_variable(x, fn=fn, y=x) @@ -298,13 +366,153 @@ def fn(x, y): assert out.device == ivy_backend.default_device() -# Function Splitting # +@handle_test(fn_tree="num_cpu_cores") +def test_num_cpu_cores(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + # using multiprocessing module too because ivy uses psutil as basis. 
+ p_cpu_cores = psutil.cpu_count() + m_cpu_cores = multiprocessing.cpu_count() + assert type(ivy_backend.num_cpu_cores()) == int + assert ivy_backend.num_cpu_cores() == p_cpu_cores + assert ivy_backend.num_cpu_cores() == m_cpu_cores -@st.composite -def _axis(draw): - max_val = draw(st.shared(helpers.ints(), key="num_dims")) - return draw(helpers.ints(min_value=0, max_value=max_val - 1)) +@handle_test( + fn_tree="functional.ivy.num_ivy_arrays_on_dev", + num=helpers.ints(min_value=0, max_value=5), +) +def test_num_ivy_arrays_on_dev( + *, + num, + on_device, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + arrays = [ + ivy_backend.array(np.random.uniform(size=2).tolist(), device=on_device) + for _ in range(num) + ] + assert ivy_backend.num_ivy_arrays_on_dev(on_device) == num + for item in arrays: + del item + + +@handle_test(fn_tree="percent_used_mem_on_dev") +def test_percent_used_mem_on_dev(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + devices = _get_possible_devices() + + for device in devices: + used = ivy_backend.percent_used_mem_on_dev(ivy_backend.Device(device)) + assert 0 <= used <= 100 + + # Same as test_used_mem_on_dev, but using percent of total memory as metric + # function + _ram_array_and_clear_test( + lambda: ivy_backend.percent_used_mem_on_dev( + device, process_specific=True + ), + device=device, + ) + + +@handle_test( + fn_tree="functional.ivy.print_all_ivy_arrays_on_dev", + num=helpers.ints(min_value=0, max_value=2), + attr_only=st.booleans(), +) +def test_print_all_ivy_arrays_on_dev( + *, + num, + attr_only, + on_device, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + arr = [ivy_backend.array(np.random.uniform(size=2)) for _ in range(num)] + + # Flush to avoid artifact + sys.stdout.flush() + # temporarily redirect output to a buffer + captured_output = io.StringIO() + sys.stdout = captured_output + + ivy_backend.print_all_ivy_arrays_on_dev(device=on_device, attr_only=attr_only) + # Flush again to make sure all data is printed + sys.stdout.flush() + written = captured_output.getvalue().splitlines() + # restore stdout + sys.stdout = sys.__stdout__ + + # Should have written same number of lines as the number of array in device + assert len(written) == num + + if attr_only: + # Check that the attribute are printed are in the format of + # (ivy.Shape(dim,...), type) + regex = r"^\(ivy.Shape\((\d+,(\d,\d*)*)\), \'\w*\'\)$" + else: + # Check that the arrays are printed are in the format of ivy.array(...) 
+ regex = r"^ivy\.array\(\[.*\]\)$" + + # Clear the array from device + for item in arr: + del item + + # Apply the regex search + assert all([re.match(regex, line) for line in written]) + + +# profiler +@handle_test( + fn_tree="functional.ivy.Profiler", +) +def test_profiler(*, backend_fw): + # ToDo: find way to prevent this test from hanging when run + # alongside other tests in parallel + + # log dir, each framework uses their own folder, + # so we can run this test in parallel + with BackendHandler.update_backend(backend_fw) as ivy_backend: + this_dir = os.path.dirname(os.path.realpath(__file__)) + log_dir = os.path.join(this_dir, "../log") + fw_log_dir = os.path.join(log_dir, backend_fw) + + # Remove old content and recreate log dir + _empty_dir(fw_log_dir, True) + + # with statement + with ivy_backend.Profiler(fw_log_dir): + a = ivy_backend.ones([10]) + b = ivy_backend.zeros([10]) + _ = a + b + + # Should have content in folder + assert len(os.listdir(fw_log_dir)) != 0, "Profiler did not log anything" + + # Remove old content and recreate log dir + _empty_dir(fw_log_dir, True) + + # Profiler should stop log + assert ( + len(os.listdir(fw_log_dir)) == 0 + ), "Profiler logged something while stopped" + + # start and stop methods + profiler = ivy_backend.Profiler(fw_log_dir) + profiler.start() + a = ivy_backend.ones([10]) + b = ivy_backend.zeros([10]) + _ = a + b + profiler.stop() + + # Should have content in folder + assert len(os.listdir(fw_log_dir)) != 0, "Profiler did not log anything" + + # Remove old content including the logging folder + _empty_dir(fw_log_dir, False) + + assert not os.path.exists(fw_log_dir), "Profiler recreated logging folder" @handle_test( @@ -431,142 +639,79 @@ def func(t0, t1): ) -# profiler +# to_dev @handle_test( - fn_tree="functional.ivy.Profiler", + fn_tree="functional.ivy.to_device", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ), + stream=helpers.ints(min_value=0, max_value=50), ) -def test_profiler(*, backend_fw): - # ToDo: find way to prevent this test from hanging when run - # alongside other tests in parallel +def test_to_device( + *, + dtype_and_x, + stream, + test_flags, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + dtype = dtype[0] + x = x[0] - # log dir, each framework uses their own folder, - # so we can run this test in parallel with BackendHandler.update_backend(backend_fw) as ivy_backend: - this_dir = os.path.dirname(os.path.realpath(__file__)) - log_dir = os.path.join(this_dir, "../log") - fw_log_dir = os.path.join(log_dir, backend_fw) + x = ivy_backend.asarray(x) + if test_flags.as_variable and ivy_backend.is_float_dtype(dtype): + x = ivy_backend.functional.ivy.gradients._variable(x) - # Remove old content and recreate log dir - _empty_dir(fw_log_dir, True) - - # with statement - with ivy_backend.Profiler(fw_log_dir): - a = ivy_backend.ones([10]) - b = ivy_backend.zeros([10]) - _ = a + b - - # Should have content in folder - assert len(os.listdir(fw_log_dir)) != 0, "Profiler did not log anything" - - # Remove old content and recreate log dir - _empty_dir(fw_log_dir, True) - - # Profiler should stop log - assert ( - len(os.listdir(fw_log_dir)) == 0 - ), "Profiler logged something while stopped" - - # start and stop methods - profiler = ivy_backend.Profiler(fw_log_dir) - profiler.start() - a = ivy_backend.ones([10]) - b = ivy_backend.zeros([10]) - _ = a + b - profiler.stop() - - # Should have content in folder - assert len(os.listdir(fw_log_dir)) != 0, "Profiler did not log anything" 
- - # Remove old content including the logging folder - _empty_dir(fw_log_dir, False) - - assert not os.path.exists(fw_log_dir), "Profiler recreated logging folder" - - -@handle_test( - fn_tree="functional.ivy.num_ivy_arrays_on_dev", - num=helpers.ints(min_value=0, max_value=5), -) -def test_num_ivy_arrays_on_dev( - *, - num, - on_device, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - arrays = [ - ivy_backend.array(np.random.uniform(size=2).tolist(), device=on_device) - for _ in range(num) - ] - assert ivy_backend.num_ivy_arrays_on_dev(on_device) == num - for item in arrays: - del item - - -@handle_test( - fn_tree="functional.ivy.get_all_ivy_arrays_on_dev", - num=helpers.ints(min_value=0, max_value=5), -) -def test_get_all_ivy_arrays_on_dev( - *, - num, - on_device, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - arrays = [ivy_backend.array(np.random.uniform(size=2)) for _ in range(num)] - arr_ids_on_dev = [ - id(a) for a in ivy_backend.get_all_ivy_arrays_on_dev(on_device).values() - ] - for a in arrays: - assert id(a) in arr_ids_on_dev + # create a dummy array for out that is broadcastable to x + out = ( + ivy_backend.zeros(ivy_backend.shape(x), device=on_device, dtype=dtype) + if test_flags.with_out + else None + ) + device = ivy_backend.dev(x) + x_on_dev = ivy_backend.to_device(x, on_device, stream=stream, out=out) + dev_from_new_x = ivy_backend.dev(x_on_dev) -@handle_test( - fn_tree="functional.ivy.print_all_ivy_arrays_on_dev", - num=helpers.ints(min_value=0, max_value=2), - attr_only=st.booleans(), -) -def test_print_all_ivy_arrays_on_dev( - *, - num, - attr_only, - on_device, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - arr = [ivy_backend.array(np.random.uniform(size=2)) for _ in range(num)] + if test_flags.with_out: + # should be the same array test + assert x_on_dev is out - # Flush to avoid artifact - sys.stdout.flush() - # temporarily redirect output to a buffer - captured_output = io.StringIO() - sys.stdout = captured_output + # should be the same device + if backend_fw != "paddle": + assert ivy_backend.dev(x_on_dev, as_native=True) == ivy_backend.dev( + out, as_native=True + ) + else: + assert ivy_backend.dev(x_on_dev, as_native=False) == ivy_backend.dev( + out, as_native=False + ) - ivy_backend.print_all_ivy_arrays_on_dev(device=on_device, attr_only=attr_only) - # Flush again to make sure all data is printed - sys.stdout.flush() - written = captured_output.getvalue().splitlines() - # restore stdout - sys.stdout = sys.__stdout__ + # check if native arrays are the same + # these backends do not support native inplace updates + assume(not (backend_fw in ["tensorflow", "jax"])) - # Should have written same number of lines as the number of array in device - assert len(written) == num + assert x_on_dev.data is out.data - if attr_only: - # Check that the attribute are printed are in the format of - # (ivy.Shape(dim,...), type) - regex = r"^\(ivy.Shape\((\d+,(\d,\d*)*)\), \'\w*\'\)$" + # value test + if backend_fw == "tensorflow": + assert "/" + ":".join(dev_from_new_x[1:].split(":")[-2:]) == "/" + ":".join( + device[1:].split(":")[-2:] + ) + elif backend_fw == "torch": + assert type(dev_from_new_x) == type(device) else: - # Check that the arrays are printed are in the format of ivy.array(...) 
- regex = r"^ivy\.array\(\[.*\]\)$" - - # Clear the array from device - for item in arr: - del item + assert dev_from_new_x == device - # Apply the regex search - assert all([re.match(regex, line) for line in written]) + # array instance test + assert x.to_device(device).dev() == device + # container instance test + container_x = ivy_backend.Container({"x": x}) + assert container_x.to_device(device).dev() == device + # container static test + assert ivy_backend.Container.to_device(container_x, device).dev() == device @handle_test(fn_tree="total_mem_on_dev") @@ -585,164 +730,6 @@ def test_total_mem_on_dev(backend_fw): assert ivy_backend.total_mem_on_dev(device) == gpu_mem.total / 1e9 -@handle_test(fn_tree="used_mem_on_dev") -def test_used_mem_on_dev(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - devices = _get_possible_devices() - - # Check that there not all memory is used - for device in devices: - assert ivy_backend.used_mem_on_dev(device) > 0 - assert ivy_backend.used_mem_on_dev(device) < ivy_backend.total_mem_on_dev( - device - ) - - _ram_array_and_clear_test( - lambda: ivy_backend.used_mem_on_dev(device, process_specific=True), - device=device, - ) - - -@handle_test(fn_tree="percent_used_mem_on_dev") -def test_percent_used_mem_on_dev(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - devices = _get_possible_devices() - - for device in devices: - used = ivy_backend.percent_used_mem_on_dev(ivy_backend.Device(device)) - assert 0 <= used <= 100 - - # Same as test_used_mem_on_dev, but using percent of total memory as metric - # function - _ram_array_and_clear_test( - lambda: ivy_backend.percent_used_mem_on_dev( - device, process_specific=True - ), - device=device, - ) - - -@handle_test(fn_tree="gpu_is_available") -def test_gpu_is_available(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - # If gpu is available but cannot be initialised it will fail the test - if ivy_backend.gpu_is_available(): - try: - pynvml.nvmlInit() - except pynvml.NVMLError: - assert False - - -@handle_test(fn_tree="num_cpu_cores") -def test_num_cpu_cores(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - # using multiprocessing module too because ivy uses psutil as basis. 
- p_cpu_cores = psutil.cpu_count() - m_cpu_cores = multiprocessing.cpu_count() - assert type(ivy_backend.num_cpu_cores()) == int - assert ivy_backend.num_cpu_cores() == p_cpu_cores - assert ivy_backend.num_cpu_cores() == m_cpu_cores - - -def _composition_1(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - return ivy_backend.relu().argmax() - - -def _composition_2(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - return ivy_backend.ceil() or ivy_backend.floor() - - -# function_unsupported_devices -@handle_test( - fn_tree="functional.ivy.function_supported_devices", - func=st.sampled_from([_composition_1, _composition_2]), - expected=st.just(["cpu"]), -) -def test_function_supported_devices( - *, - func, - expected, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.function_supported_devices(func) - exp = set(expected) - - assert sorted(tuple(exp)) == sorted(res) - - -# function_unsupported_devices -@handle_test( - fn_tree="functional.ivy.function_supported_devices", - func=st.sampled_from([_composition_1, _composition_2]), - expected=st.just(["gpu", "tpu"]), -) -def test_function_unsupported_devices( - *, - func, - expected, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.function_unsupported_devices(func) - exp = set(expected) - - assert sorted(tuple(exp)) == sorted(res) - - -def get_gpu_mem_usage(backend, device="gpu:0"): - handle = backend.ivy.functional.ivy.device._get_nvml_gpu_handle(device) - info = pynvml.nvmlDeviceGetMemoryInfo(handle) - return (info.used / info.total) * 100 - - -@handle_test(fn_tree="clear_cached_mem_on_dev") -def test_clear_cached_mem_on_dev(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - devices = _get_possible_devices() - for device in devices: - # Testing on only GPU since clearing cache mem is relevant - # for only CUDA devices - if "gpu" in device: - arr = ivy_backend.random_normal( # noqa: F841 - shape=(10000, 1000), dtype="float32", device=device - ) - del arr - before = get_gpu_mem_usage(device) - ivy_backend.clear_cached_mem_on_dev(device) - after = get_gpu_mem_usage(device) - assert before > after - - -def get_cpu_percent(): - output = str(subprocess.check_output(["top", "-bn1"])) - cpu_percent = float(re.search(r"%Cpu\(s\):\s+([\d.]+)\s+us", output).group(1)) - return cpu_percent - - -@handle_test(fn_tree="dev_util") -def test_dev_util(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - devices = _get_possible_devices() - for device in devices: - # The internally called psutil.cpu_percent() has a unique behavior where it - # returns 0 as usage when run the second time in same line so simple - # assert psutil.cpu_percent() == ivy.dev_util(device) isn't possible - if "cpu" in device: - assert 100 >= ivy_backend.dev_util(device) >= 0 - # Comparing CPU utilization using top. 
Two percentiles won't be directly - # equal but absolute difference should be below a safe threshold - assert abs(get_cpu_percent() - ivy_backend.dev_util(device)) < 10 - elif "gpu" in device: - handle = ivy_backend.functional.ivy.device._get_nvml_gpu_handle(device) - assert ( - ivy_backend.dev_util(device) - == pynvml.nvmlDeviceGetUtilizationRates(handle).gpu - ) - - @handle_test(fn_tree="tpu_is_available") def test_tpu_is_available(backend_fw): with BackendHandler.update_backend(backend_fw) as ivy_backend: @@ -759,3 +746,21 @@ def test_tpu_is_available(backend_fw): ground_truth = False assert ivy_backend.tpu_is_available() == ground_truth + + +@handle_test(fn_tree="used_mem_on_dev") +def test_used_mem_on_dev(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + devices = _get_possible_devices() + + # Check that there not all memory is used + for device in devices: + assert ivy_backend.used_mem_on_dev(device) > 0 + assert ivy_backend.used_mem_on_dev(device) < ivy_backend.total_mem_on_dev( + device + ) + + _ram_array_and_clear_test( + lambda: ivy_backend.used_mem_on_dev(device, process_specific=True), + device=device, + ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py b/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py index c00217202600b..3017d22943255 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py @@ -11,46 +11,43 @@ from ivy_tests.test_ivy.helpers import handle_test, BackendHandler -# dtype objects -@handle_test(fn_tree="functional.ivy.exists") # dummy fn_tree -def test_dtype_instances(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - assert ivy_backend.exists(ivy_backend.int8) - assert ivy_backend.exists(ivy_backend.int16) - assert ivy_backend.exists(ivy_backend.int32) - assert ivy_backend.exists(ivy_backend.int64) - assert ivy_backend.exists(ivy_backend.uint8) - if backend_fw not in ["torch", "paddle", "mxnet"]: - assert ivy_backend.exists(ivy_backend.uint16) - assert ivy_backend.exists(ivy_backend.uint32) - assert ivy_backend.exists(ivy_backend.uint64) - assert ivy_backend.exists(ivy_backend.float32) - assert ivy_backend.exists(ivy_backend.float64) - assert ivy_backend.exists(ivy_backend.complex64) - assert ivy_backend.exists(ivy_backend.complex128) - assert ivy_backend.exists(ivy_backend.bool) - - # for data generation in multiple tests dtype_shared = helpers.get_dtypes("valid", full=False, key="dtype") +# --- Helpers --- # +# --------------- # + + @st.composite -def dtypes_shared(draw, num_dtypes): - if isinstance(num_dtypes, str): - num_dtypes = draw(st.shared(helpers.ints(), key=num_dtypes)) +def _array_or_type(draw, float_or_int): + valid_dtypes = { + "float": draw(helpers.get_dtypes("float")), + "int": draw(helpers.get_dtypes("integer")), + }[float_or_int] return draw( - st.shared( - st.lists( - st.sampled_from(draw(helpers.get_dtypes("valid"))), - min_size=num_dtypes, - max_size=num_dtypes, - ), - key="dtypes", + st.sampled_from( + ( + draw( + helpers.dtype_and_values( + available_dtypes=valid_dtypes, + ) + ), + draw(st.sampled_from(valid_dtypes)), + ) ) ) +def _composition_1(): + return ivy.relu().argmax() + + +def _composition_2(): + a = ivy.floor + return ivy.ceil() or a + + # Array API Standard Function Tests # # --------------------------------- # @@ -73,6 +70,99 @@ def astype_helper(draw): return dtype, x, cast_dtype +# broadcast arrays +@st.composite +def broadcastable_arrays(draw, dtypes): + 
num_arrays = st.shared(helpers.ints(min_value=2, max_value=5), key="num_arrays") + shapes = draw(num_arrays.flatmap(helpers.mutually_broadcastable_shapes)) + dtypes = draw(dtypes) + arrays = [] + for c, (shape, dtype) in enumerate(zip(shapes, dtypes), 1): + x = draw(helpers.array_values(dtype=dtype, shape=shape), label=f"x{c}").tolist() + arrays.append(x) + return arrays + + +@st.composite +def dtypes_list(draw): + num = draw(st.one_of(helpers.ints(min_value=1, max_value=5))) + return draw( + st.lists( + st.sampled_from(ivy.valid_dtypes), + min_size=num, + max_size=num, + ) + ) + + +@st.composite +def dtypes_shared(draw, num_dtypes): + if isinstance(num_dtypes, str): + num_dtypes = draw(st.shared(helpers.ints(), key=num_dtypes)) + return draw( + st.shared( + st.lists( + st.sampled_from(draw(helpers.get_dtypes("valid"))), + min_size=num_dtypes, + max_size=num_dtypes, + ), + key="dtypes", + ) + ) + + +# --- Main --- # +# ------------ # + + +# as_ivy_dtype +@handle_test( + fn_tree="functional.ivy.as_ivy_dtype", + input_dtype=helpers.get_dtypes("valid", full=False), +) +def test_as_ivy_dtype( + *, + input_dtype, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + input_dtype = input_dtype[0] + res = ivy_backend.as_ivy_dtype(input_dtype) + if isinstance(input_dtype, str): + assert isinstance(res, str) + return + + assert isinstance(input_dtype, ivy_backend.Dtype) or isinstance( + input_dtype, str + ), f"input_dtype={input_dtype!r}, but should be str or ivy.Dtype" + assert isinstance(res, str), f"result={res!r}, but should be str" + + +# as_native_dtype +@handle_test( + fn_tree="functional.ivy.as_native_dtype", + input_dtype=helpers.get_dtypes("valid", full=False), +) +def test_as_native_dtype( + *, + input_dtype, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + input_dtype = input_dtype[0] + res = ivy_backend.as_native_dtype(input_dtype) + if isinstance(input_dtype, ivy_backend.NativeDtype): + assert isinstance(res, ivy_backend.NativeDtype) + return + + assert isinstance(input_dtype, ivy_backend.Dtype) or isinstance( + input_dtype, str + ), f"input_dtype={input_dtype!r}, but should be str or ivy.Dtype" + assert isinstance( + res, ivy_backend.NativeDtype + ), f"result={res!r}, but should be ivy.NativeDtype" + + # astype @handle_test( fn_tree="functional.ivy.astype", @@ -96,19 +186,6 @@ def test_astype( ) -# broadcast arrays -@st.composite -def broadcastable_arrays(draw, dtypes): - num_arrays = st.shared(helpers.ints(min_value=2, max_value=5), key="num_arrays") - shapes = draw(num_arrays.flatmap(helpers.mutually_broadcastable_shapes)) - dtypes = draw(dtypes) - arrays = [] - for c, (shape, dtype) in enumerate(zip(shapes, dtypes), 1): - x = draw(helpers.array_values(dtype=dtype, shape=shape), label=f"x{c}").tolist() - arrays.append(x) - return arrays - - @handle_test( fn_tree="functional.ivy.broadcast_arrays", arrays=broadcastable_arrays(dtypes_shared("num_arrays")), @@ -186,233 +263,163 @@ def test_can_cast(*, dtype_and_x, to_dtype, test_flags, backend_fw, fn_name, on_ ) -@st.composite -def _array_or_type(draw, float_or_int): - valid_dtypes = { - "float": draw(helpers.get_dtypes("float")), - "int": draw(helpers.get_dtypes("integer")), - }[float_or_int] - return draw( - st.sampled_from( - ( - draw( - helpers.dtype_and_values( - available_dtypes=valid_dtypes, - ) - ), - draw(st.sampled_from(valid_dtypes)), - ) +# closest_valid_dtypes +@handle_test( + fn_tree="functional.ivy.closest_valid_dtype", + 
input_dtype=helpers.get_dtypes("valid", full=False), +) +def test_closest_valid_dtype( + *, input_dtype, test_flags, backend_fw, fn_name, on_device +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + input_dtype = input_dtype[0] + res = ivy_backend.closest_valid_dtype(input_dtype) + assert isinstance(input_dtype, ivy_backend.Dtype) or isinstance( + input_dtype, str ) - ) + assert isinstance(res, ivy_backend.Dtype) or isinstance( + res, str + ), f"result={res!r}, but should be str or ivy.Dtype" -# finfo +# default_complex_dtype @handle_test( - fn_tree="functional.ivy.finfo", - type=_array_or_type("float"), - test_with_out=st.just(False), - as_variable_flags=st.just([False]), - native_array_flags=st.just([False]), - container_flags=st.just([False]), - test_instance_method=st.just(False), + fn_tree="functional.ivy.default_complex_dtype", + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("complex")), + as_native=st.booleans(), test_gradients=st.just(False), ) -def test_finfo(*, type, test_flags, backend_fw, fn_name, on_device): - if isinstance(type, str): - input_dtype = [type] - else: - input_dtype, x = type - type = x[0] - ret = helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - type=type, - test_values=False, - ) - if not ivy.exists(ret): - return - mach_lims, mach_lims_np = ret - assert np.allclose(mach_lims.min, mach_lims_np.min, rtol=1e-2, atol=1e-2) - assert np.allclose(mach_lims.max, mach_lims_np.max, rtol=1e-2, atol=1e-2) - assert np.allclose(mach_lims.eps, mach_lims_np.eps, rtol=1e-2, atol=1e-2) - assert mach_lims.bits == mach_lims_np.bits +def test_default_complex_dtype( + *, + dtype_x, + as_native, + backend_fw, +): + complex_dtype, x = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + res = ivy_backend.default_complex_dtype( + input=input, + complex_dtype=complex_dtype[0], + as_native=as_native, + ) + assert ( + isinstance(res, ivy_backend.Dtype) + or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) + or isinstance(res, ivy_backend.NativeDtype) + or isinstance(res, str) + ) + assert ( + ivy_backend.default_complex_dtype( + input=None, complex_dtype=None, as_native=False + ) + == ivy_backend.complex64 + ) + assert ( + ivy_backend.default_complex_dtype(complex_dtype=ivy_backend.complex64) + == ivy_backend.complex64 + ) + assert ivy_backend.default_complex_dtype() == ivy_backend.complex64 -# iinfo -@handle_test( - fn_tree="functional.ivy.iinfo", - type=_array_or_type("int"), - test_with_out=st.just(False), - as_variable_flags=st.just([False]), - native_array_flags=st.just([False]), - container_flags=st.just([False]), - test_instance_method=st.just(False), - test_gradients=st.just(False), -) -def test_iinfo(*, type, test_flags, backend_fw, fn_name, on_device): - if isinstance(type, str): - input_dtype = [type] - else: - input_dtype, x = type - type = x[0] - ret = helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - type=type, - test_values=False, - ) - - if not ivy.exists(ret): - return - mach_lims, mach_lims_np = ret - assert mach_lims.min == mach_lims_np.min - assert mach_lims.max == mach_lims_np.max - assert mach_lims.bits == mach_lims_np.bits - - -# result_type -@handle_test( - fn_tree="functional.ivy.result_type", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - 
num_arrays=st.shared(helpers.ints(min_value=2, max_value=5), key="num_arrays"), - shared_dtype=False, - ), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_result_type(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = helpers.as_lists(*dtype_and_x) - kw = {} - for i, (dtype_, x_) in enumerate(zip(dtype, x)): - kw["x{}".format(i)] = x_ - test_flags.num_positional_args = len(kw) - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - **kw, - ) - - -# Extra Ivy Function Tests # -# ------------------------ # - - -# is_hashable_dtype -@handle_test( - fn_tree="functional.ivy.is_hashable_dtype", - input_dtype=helpers.get_dtypes("valid", full=False), -) -def test_is_hashable_dtype( - *, - input_dtype, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - input_dtype = input_dtype[0] - res = ivy_backend.is_hashable_dtype(input_dtype) - assert res - - -# as_ivy_dtype +# default_dtype @handle_test( - fn_tree="functional.ivy.as_ivy_dtype", + fn_tree="functional.ivy.default_dtype", input_dtype=helpers.get_dtypes("valid", full=False), + as_native=st.booleans(), ) -def test_as_ivy_dtype( +def test_default_dtype( *, input_dtype, + as_native, backend_fw, ): with BackendHandler.update_backend(backend_fw) as ivy_backend: input_dtype = input_dtype[0] - res = ivy_backend.as_ivy_dtype(input_dtype) - if isinstance(input_dtype, str): - assert isinstance(res, str) - return - - assert isinstance(input_dtype, ivy_backend.Dtype) or isinstance( + res = ivy_backend.default_dtype(dtype=input_dtype, as_native=as_native) + assert ( + isinstance(input_dtype, ivy_backend.Dtype) + or isinstance(input_dtype, str) + or isinstance(input_dtype, ivy_backend.NativeDtype) + ) + assert isinstance(res, ivy_backend.Dtype) or isinstance( input_dtype, str ), f"input_dtype={input_dtype!r}, but should be str or ivy.Dtype" - assert isinstance(res, str), f"result={res!r}, but should be str" -# as_native_dtype +# default_float_dtype @handle_test( - fn_tree="functional.ivy.as_native_dtype", - input_dtype=helpers.get_dtypes("valid", full=False), + fn_tree="functional.ivy.default_float_dtype", + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + as_native=st.booleans(), + test_gradients=st.just(False), ) -def test_as_native_dtype( +def test_default_float_dtype( *, - input_dtype, + dtype_x, + as_native, backend_fw, ): with BackendHandler.update_backend(backend_fw) as ivy_backend: - input_dtype = input_dtype[0] - res = ivy_backend.as_native_dtype(input_dtype) - if isinstance(input_dtype, ivy_backend.NativeDtype): - assert isinstance(res, ivy_backend.NativeDtype) - return - - assert isinstance(input_dtype, ivy_backend.Dtype) or isinstance( - input_dtype, str - ), f"input_dtype={input_dtype!r}, but should be str or ivy.Dtype" - assert isinstance( - res, ivy_backend.NativeDtype - ), f"result={res!r}, but should be ivy.NativeDtype" - - -# closest_valid_dtypes -@handle_test( - fn_tree="functional.ivy.closest_valid_dtype", - input_dtype=helpers.get_dtypes("valid", full=False), -) -def test_closest_valid_dtype( - *, input_dtype, test_flags, backend_fw, fn_name, on_device -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - input_dtype = input_dtype[0] - res = ivy_backend.closest_valid_dtype(input_dtype) - assert isinstance(input_dtype, ivy_backend.Dtype) or isinstance( - input_dtype, str + float_dtype, x = dtype_x + res = 
ivy_backend.default_float_dtype( + input=input, + float_dtype=float_dtype[0], + as_native=as_native, ) - assert isinstance(res, ivy_backend.Dtype) or isinstance( - res, str - ), f"result={res!r}, but should be str or ivy.Dtype" + assert ( + isinstance(res, ivy_backend.Dtype) + or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) + or isinstance(res, ivy_backend.NativeDtype) + or isinstance(res, str) + ) + assert ( + ivy_backend.default_float_dtype( + input=None, float_dtype=None, as_native=False + ) + == ivy_backend.float32 + ) + assert ( + ivy_backend.default_float_dtype(float_dtype=ivy_backend.float16) + == ivy_backend.float16 + ) + assert ivy_backend.default_float_dtype() == ivy_backend.float32 -# default_dtype +# default_int_dtype @handle_test( - fn_tree="functional.ivy.default_dtype", - input_dtype=helpers.get_dtypes("valid", full=False), + fn_tree="functional.ivy.default_int_dtype", + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("integer")), as_native=st.booleans(), + test_gradients=st.just(False), ) -def test_default_dtype( +def test_default_int_dtype( *, - input_dtype, + dtype_x, as_native, backend_fw, ): + int_dtype, x = dtype_x with BackendHandler.update_backend(backend_fw) as ivy_backend: - input_dtype = input_dtype[0] - res = ivy_backend.default_dtype(dtype=input_dtype, as_native=as_native) + res = ivy_backend.default_int_dtype( + input=input, + int_dtype=int_dtype[0], + as_native=as_native, + ) assert ( - isinstance(input_dtype, ivy_backend.Dtype) - or isinstance(input_dtype, str) - or isinstance(input_dtype, ivy_backend.NativeDtype) + isinstance(res, ivy_backend.Dtype) + or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) + or isinstance(res, ivy_backend.NativeDtype) + or isinstance(res, str) ) - assert isinstance(res, ivy_backend.Dtype) or isinstance( - input_dtype, str - ), f"input_dtype={input_dtype!r}, but should be str or ivy.Dtype" + assert ( + ivy_backend.default_int_dtype(input=None, int_dtype=None, as_native=False) + == ivy_backend.int32 + ) + assert ( + ivy_backend.default_int_dtype(int_dtype=ivy_backend.int16) + == ivy_backend.int16 + ) + assert ivy_backend.default_int_dtype() == ivy_backend.int32 # dtype @@ -483,58 +490,251 @@ def test_dtype_bits(*, input_dtype, test_flags, backend_fw, fn_name, on_device): assert num_bits == num_bits_np -# is_bool_dtype +# dtype objects +@handle_test(fn_tree="functional.ivy.exists") # dummy fn_tree +def test_dtype_instances(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + assert ivy_backend.exists(ivy_backend.int8) + assert ivy_backend.exists(ivy_backend.int16) + assert ivy_backend.exists(ivy_backend.int32) + assert ivy_backend.exists(ivy_backend.int64) + assert ivy_backend.exists(ivy_backend.uint8) + if backend_fw not in ["torch", "paddle", "mxnet"]: + assert ivy_backend.exists(ivy_backend.uint16) + assert ivy_backend.exists(ivy_backend.uint32) + assert ivy_backend.exists(ivy_backend.uint64) + assert ivy_backend.exists(ivy_backend.float32) + assert ivy_backend.exists(ivy_backend.float64) + assert ivy_backend.exists(ivy_backend.complex64) + assert ivy_backend.exists(ivy_backend.complex128) + assert ivy_backend.exists(ivy_backend.bool) + + +# finfo @handle_test( - fn_tree="functional.ivy.is_bool_dtype", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=False) - ), + fn_tree="functional.ivy.finfo", + type=_array_or_type("float"), test_with_out=st.just(False), + as_variable_flags=st.just([False]), + 
native_array_flags=st.just([False]), + container_flags=st.just([False]), + test_instance_method=st.just(False), test_gradients=st.just(False), ) -def test_is_bool_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, +def test_finfo(*, type, test_flags, backend_fw, fn_name, on_device): + if isinstance(type, str): + input_dtype = [type] + else: + input_dtype, x = type + type = x[0] + ret = helpers.test_function( + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - dtype_in=x[0], + type=type, + test_values=False, ) - - -# is_float_dtype + if not ivy.exists(ret): + return + mach_lims, mach_lims_np = ret + assert np.allclose(mach_lims.min, mach_lims_np.min, rtol=1e-2, atol=1e-2) + assert np.allclose(mach_lims.max, mach_lims_np.max, rtol=1e-2, atol=1e-2) + assert np.allclose(mach_lims.eps, mach_lims_np.eps, rtol=1e-2, atol=1e-2) + assert mach_lims.bits == mach_lims_np.bits + + +# function_dtype_versioning @handle_test( - fn_tree="functional.ivy.is_float_dtype", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=False) + fn_tree="functional.ivy.function_unsupported_dtypes", # dummy fn_tree + func_and_version=st.just( + [ + { + "torch": { + "cumsum": { + "2.0.1": {"bfloat16", "uint8", "float16"}, + "1.12.1": set(), + } + } + }, + ], + ), +) +def test_function_dtype_versioning( + *, + func_and_version, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + for key in func_and_version: + if key != backend_fw: + continue + var = ivy_backend.backend_version + + # key --> framework + + for key1 in func_and_version[key]: + for key2 in func_and_version[key][key1]: + var["version"] = key2 + fn = getattr(ivy_backend, key1) + expected = func_and_version[key][key1][key2] + res = fn.unsupported_dtypes + if res is None: + res = set() + else: + res = set(res) + if res != expected: + raise Exception + return True + + +# function_dtype_versioning_frontend +@handle_test( + fn_tree="functional.ivy.function_unsupported_dtypes", # dummy fn_tree + func_and_version=st.just( + [ + { + "torch": { + "cumsum": { + "2.0.1": {"bfloat16", "uint8", "float16"}, + "1.12.1": set(), + } + } + }, + ], ), +) +def test_function_dtype_versioning_frontend( + *, + func_and_version, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + _import_mod = ivy_backend.utils.dynamic_import + for key in func_and_version: + if key != backend_fw: + continue + frontend = _import_mod.import_module("ivy.functional.frontends") + var = frontend.versions + + for key1 in func_and_version[key]: + for key2 in func_and_version[key][key1]: + var[backend_fw] = key2 + fn = getattr( + _import_mod.import_module( + "ivy.functional.frontends." 
+ backend_fw + ), + key1, + ) + expected = func_and_version[key][key1][key2] + res = fn.unsupported_dtypes + if res is None: + res = set() + else: + res = set(res) + if res != expected: + raise Exception + return True + + +# function_unsupported_dtypes +@handle_test( + fn_tree="functional.ivy.function_supported_dtypes", + func=st.sampled_from([_composition_1, _composition_2]), +) +def test_function_supported_dtypes(*, func, backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + res = ivy_backend.function_supported_dtypes(func) + exp = set(ivy_backend.all_dtypes).difference( + set(func.test_unsupported_dtypes[backend_fw]) + ) + assert set(tuple(exp)) == set(res) + + +# function_unsupported_dtypes +@handle_test( + fn_tree="functional.ivy.function_unsupported_dtypes", + func=st.sampled_from([_composition_2]), +) +def test_function_unsupported_dtypes(*, func, backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + res = ivy_backend.function_unsupported_dtypes(func) + exp = func.test_unsupported_dtypes[backend_fw] + assert set(tuple(exp)) == set(res) + + +# iinfo +@handle_test( + fn_tree="functional.ivy.iinfo", + type=_array_or_type("int"), test_with_out=st.just(False), + as_variable_flags=st.just([False]), + native_array_flags=st.just([False]), + container_flags=st.just([False]), + test_instance_method=st.just(False), test_gradients=st.just(False), ) -def test_is_float_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, +def test_iinfo(*, type, test_flags, backend_fw, fn_name, on_device): + if isinstance(type, str): + input_dtype = [type] + else: + input_dtype, x = type + type = x[0] + ret = helpers.test_function( + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - dtype_in=x[0], + type=type, + test_values=False, ) + if not ivy.exists(ret): + return + mach_lims, mach_lims_np = ret + assert mach_lims.min == mach_lims_np.min + assert mach_lims.max == mach_lims_np.max + assert mach_lims.bits == mach_lims_np.bits -# is_int_dtype + +# invalid_dtype @handle_test( - fn_tree="functional.ivy.is_int_dtype", + fn_tree="functional.ivy.invalid_dtype", + dtype_in=helpers.get_dtypes("valid", full=False), +) +def test_invalid_dtype( + *, + dtype_in, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + dtype_in = dtype_in[0] + res = ivy_backend.invalid_dtype(dtype_in) + invalid_dtypes = ivy_backend.invalid_dtypes + if dtype_in in invalid_dtypes: + assert res is True, ( + f"fDtype = {dtype_in!r} is a valid dtype for {backend_fw}, butresult =" + f" {res}" + ) + else: + assert res is False, ( + f"fDtype = {dtype_in!r} is not a valid dtype for {backend_fw}," + f" butresult = {res}" + ) + + +# is_bool_dtype +@handle_test( + fn_tree="functional.ivy.is_bool_dtype", dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", full=False) ), test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_is_int_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): +def test_is_bool_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_x helpers.test_function( input_dtypes=dtype, @@ -546,16 +746,16 @@ def test_is_int_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): ) -# is_uint_dtype +# is_complex_dtype @handle_test( - fn_tree="functional.ivy.is_uint_dtype", + fn_tree="functional.ivy.is_complex_dtype", 
dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", full=False) ), test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_is_uint_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): +def test_is_complex_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_x helpers.test_function( input_dtypes=dtype, @@ -567,16 +767,16 @@ def test_is_uint_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): ) -# is_complex_dtype +# is_float_dtype @handle_test( - fn_tree="functional.ivy.is_complex_dtype", + fn_tree="functional.ivy.is_float_dtype", dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", full=False) ), test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_is_complex_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): +def test_is_float_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_x helpers.test_function( input_dtypes=dtype, @@ -588,375 +788,186 @@ def test_is_complex_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device ) -# promote_types -# TODO: fix instance method +# Extra Ivy Function Tests # +# ------------------------ # + + +# is_hashable_dtype @handle_test( - fn_tree="functional.ivy.promote_types", - type1=helpers.get_dtypes("valid", full=False), - type2=helpers.get_dtypes("valid", full=False), - test_with_out=st.just(False), - test_instance_method=st.just(False), - test_gradients=st.just(False), + fn_tree="functional.ivy.is_hashable_dtype", + input_dtype=helpers.get_dtypes("valid", full=False), ) -def test_promote_types(*, type1, type2, test_flags, backend_fw, fn_name, on_device): - helpers.test_function( - input_dtypes=[], - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - type1=type1[0], - type2=type2[0], - test_values=False, - ) +def test_is_hashable_dtype( + *, + input_dtype, + backend_fw, +): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + input_dtype = input_dtype[0] + res = ivy_backend.is_hashable_dtype(input_dtype) + assert res -# type_promote_arrays -# TODO: fix container method +# is_int_dtype @handle_test( - fn_tree="functional.ivy.type_promote_arrays", - dtype_and_values=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=False, + fn_tree="functional.ivy.is_int_dtype", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", full=False) ), test_with_out=st.just(False), - container_flags=st.just([False]), test_gradients=st.just(False), ) -def test_type_promote_arrays( - *, dtype_and_values, test_flags, backend_fw, fn_name, on_device -): - types, arrays = dtype_and_values +def test_is_int_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x helpers.test_function( - input_dtypes=types, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x1=arrays[0], - x2=arrays[1], - test_values=True, + dtype_in=x[0], ) -# default_float_dtype +# is_native_dtype @handle_test( - fn_tree="functional.ivy.default_float_dtype", - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), - as_native=st.booleans(), - test_gradients=st.just(False), + fn_tree="functional.ivy.is_native_dtype", + input_dtype=helpers.get_dtypes("valid", full=False), ) -def test_default_float_dtype( - *, - dtype_x, - as_native, +def test_is_native_dtype( + input_dtype, backend_fw, ): 
with BackendHandler.update_backend(backend_fw) as ivy_backend: - float_dtype, x = dtype_x - res = ivy_backend.default_float_dtype( - input=input, - float_dtype=float_dtype[0], - as_native=as_native, - ) - assert ( - isinstance(res, ivy_backend.Dtype) - or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) - or isinstance(res, ivy_backend.NativeDtype) - or isinstance(res, str) - ) - assert ( - ivy_backend.default_float_dtype( - input=None, float_dtype=None, as_native=False - ) - == ivy_backend.float32 - ) - assert ( - ivy_backend.default_float_dtype(float_dtype=ivy_backend.float16) - == ivy_backend.float16 - ) - assert ivy_backend.default_float_dtype() == ivy_backend.float32 - + input_dtype = input_dtype[0] + if isinstance(input_dtype, str): + assert ivy_backend.is_native_dtype(input_dtype) is False -# default_int_dtype -@handle_test( - fn_tree="functional.ivy.default_int_dtype", - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("integer")), - as_native=st.booleans(), - test_gradients=st.just(False), -) -def test_default_int_dtype( - *, - dtype_x, - as_native, - backend_fw, -): - int_dtype, x = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.default_int_dtype( - input=input, - int_dtype=int_dtype[0], - as_native=as_native, - ) assert ( - isinstance(res, ivy_backend.Dtype) - or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) - or isinstance(res, ivy_backend.NativeDtype) - or isinstance(res, str) - ) - assert ( - ivy_backend.default_int_dtype(input=None, int_dtype=None, as_native=False) - == ivy_backend.int32 - ) - assert ( - ivy_backend.default_int_dtype(int_dtype=ivy_backend.int16) - == ivy_backend.int16 + ivy_backend.is_native_dtype(ivy_backend.as_native_dtype(input_dtype)) + is True ) - assert ivy_backend.default_int_dtype() == ivy_backend.int32 -# default_complex_dtype +# is_uint_dtype @handle_test( - fn_tree="functional.ivy.default_complex_dtype", - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("complex")), - as_native=st.booleans(), + fn_tree="functional.ivy.is_uint_dtype", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", full=False) + ), + test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_default_complex_dtype( - *, - dtype_x, - as_native, - backend_fw, -): - complex_dtype, x = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.default_complex_dtype( - input=input, - complex_dtype=complex_dtype[0], - as_native=as_native, - ) - assert ( - isinstance(res, ivy_backend.Dtype) - or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) - or isinstance(res, ivy_backend.NativeDtype) - or isinstance(res, str) - ) - assert ( - ivy_backend.default_complex_dtype( - input=None, complex_dtype=None, as_native=False - ) - == ivy_backend.complex64 - ) - assert ( - ivy_backend.default_complex_dtype(complex_dtype=ivy_backend.complex64) - == ivy_backend.complex64 - ) - assert ivy_backend.default_complex_dtype() == ivy_backend.complex64 - - -@st.composite -def dtypes_list(draw): - num = draw(st.one_of(helpers.ints(min_value=1, max_value=5))) - return draw( - st.lists( - st.sampled_from(ivy.valid_dtypes), - min_size=num, - max_size=num, - ) +def test_is_uint_dtype(*, dtype_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + dtype_in=x[0], ) 
-def _composition_1(): - return ivy.relu().argmax() - - -_composition_1.test_unsupported_dtypes = { - "numpy": ("bfloat16",), - "jax": ("complex64", "complex128"), - "tensorflow": ("complex64", "complex128"), - "torch": ( - "uint16", - "uint32", - "uint64", - "float16", - "complex64", - "complex128", - ), - "paddle": ( - "uint16", - "uint32", - "uint64", - "bfloat16", - "complex64", - "complex128", - ), - "mxnet": ("uint16", "uint32", "uint64", "complex64", "complex128"), -} - - -def _composition_2(): - a = ivy.floor - return ivy.ceil() or a - - -_composition_2.test_unsupported_dtypes = { - "numpy": ("bfloat16", "complex64", "complex128"), - "jax": ("complex64", "complex128"), - "tensorflow": ("complex64", "complex128"), - "torch": ("uint16", "uint32", "uint64", "float16", "complex64", "complex128"), - "paddle": ( - "uint16", - "uint32", - "uint64", - "bfloat16", - ), -} - - -# function_unsupported_dtypes -@handle_test( - fn_tree="functional.ivy.function_supported_dtypes", - func=st.sampled_from([_composition_1, _composition_2]), -) -def test_function_supported_dtypes(*, func, backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.function_supported_dtypes(func) - exp = set(ivy_backend.all_dtypes).difference( - set(func.test_unsupported_dtypes[backend_fw]) - ) - assert set(tuple(exp)) == set(res) - - -# function_unsupported_dtypes +# promote_types +# TODO: fix instance method @handle_test( - fn_tree="functional.ivy.function_unsupported_dtypes", - func=st.sampled_from([_composition_2]), + fn_tree="functional.ivy.promote_types", + type1=helpers.get_dtypes("valid", full=False), + type2=helpers.get_dtypes("valid", full=False), + test_with_out=st.just(False), + test_instance_method=st.just(False), + test_gradients=st.just(False), ) -def test_function_unsupported_dtypes(*, func, backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.function_unsupported_dtypes(func) - exp = func.test_unsupported_dtypes[backend_fw] - assert set(tuple(exp)) == set(res) +def test_promote_types(*, type1, type2, test_flags, backend_fw, fn_name, on_device): + helpers.test_function( + input_dtypes=[], + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + type1=type1[0], + type2=type2[0], + test_values=False, + ) -# function_dtype_versioning +# result_type @handle_test( - fn_tree="functional.ivy.function_unsupported_dtypes", # dummy fn_tree - func_and_version=st.just( - [ - { - "torch": { - "cumsum": { - "2.0.1": {"bfloat16", "uint8", "float16"}, - "1.12.1": set(), - } - } - }, - ], + fn_tree="functional.ivy.result_type", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=st.shared(helpers.ints(min_value=2, max_value=5), key="num_arrays"), + shared_dtype=False, ), + test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_function_dtype_versioning( - *, - func_and_version, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - for key in func_and_version: - if key != backend_fw: - continue - var = ivy_backend.backend_version - - # key --> framework - - for key1 in func_and_version[key]: - for key2 in func_and_version[key][key1]: - var["version"] = key2 - fn = getattr(ivy_backend, key1) - expected = func_and_version[key][key1][key2] - res = fn.unsupported_dtypes - if res is None: - res = set() - else: - res = set(res) - if res != expected: - raise Exception - return True +def test_result_type(*, 
dtype_and_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = helpers.as_lists(*dtype_and_x) + kw = {} + for i, (dtype_, x_) in enumerate(zip(dtype, x)): + kw["x{}".format(i)] = x_ + test_flags.num_positional_args = len(kw) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + **kw, + ) -# function_dtype_versioning_frontend +# type_promote_arrays +# TODO: fix container method @handle_test( - fn_tree="functional.ivy.function_unsupported_dtypes", # dummy fn_tree - func_and_version=st.just( - [ - { - "torch": { - "cumsum": { - "2.0.1": {"bfloat16", "uint8", "float16"}, - "1.12.1": set(), - } - } - }, - ], + fn_tree="functional.ivy.type_promote_arrays", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=False, ), + test_with_out=st.just(False), + container_flags=st.just([False]), + test_gradients=st.just(False), ) -def test_function_dtype_versioning_frontend( - *, - func_and_version, - backend_fw, +def test_type_promote_arrays( + *, dtype_and_values, test_flags, backend_fw, fn_name, on_device ): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - _import_mod = ivy_backend.utils.dynamic_import - for key in func_and_version: - if key != backend_fw: - continue - frontend = _import_mod.import_module("ivy.functional.frontends") - var = frontend.versions - - for key1 in func_and_version[key]: - for key2 in func_and_version[key][key1]: - var[backend_fw] = key2 - fn = getattr( - _import_mod.import_module( - "ivy.functional.frontends." + backend_fw - ), - key1, - ) - expected = func_and_version[key][key1][key2] - res = fn.unsupported_dtypes - if res is None: - res = set() - else: - res = set(res) - if res != expected: - raise Exception - return True + types, arrays = dtype_and_values + helpers.test_function( + input_dtypes=types, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x1=arrays[0], + x2=arrays[1], + test_values=True, + ) -# invalid_dtype +# unset_default_complex_dtype @handle_test( - fn_tree="functional.ivy.invalid_dtype", - dtype_in=helpers.get_dtypes("valid", full=False), + fn_tree="functional.ivy.unset_default_complex_dtype", + dtype=helpers.get_dtypes("complex", full=False), ) -def test_invalid_dtype( +def test_unset_default_complex_dtype( *, - dtype_in, + dtype, backend_fw, ): with BackendHandler.update_backend(backend_fw) as ivy_backend: - dtype_in = dtype_in[0] - res = ivy_backend.invalid_dtype(dtype_in) - invalid_dtypes = ivy_backend.invalid_dtypes - if dtype_in in invalid_dtypes: - assert res is True, ( - f"fDtype = {dtype_in!r} is a valid dtype for {backend_fw}, butresult =" - f" {res}" - ) - else: - assert res is False, ( - f"fDtype = {dtype_in!r} is not a valid dtype for {backend_fw}," - f" butresult = {res}" - ) + dtype = dtype[0] + stack_size_before = len(ivy_backend.default_complex_dtype_stack) + ivy_backend.set_default_complex_dtype(dtype) + ivy_backend.unset_default_complex_dtype() + stack_size_after = len(ivy_backend.default_complex_dtype_stack) + assert ( + stack_size_before == stack_size_after + ), f"Default float dtype not unset. Stack size= {stack_size_after!r}" # unset_default_dtype @@ -1022,27 +1033,6 @@ def test_unset_default_int_dtype( ), f"Default int dtype not unset. 
Stack size= {stack_size_after!r}" -# unset_default_complex_dtype -@handle_test( - fn_tree="functional.ivy.unset_default_complex_dtype", - dtype=helpers.get_dtypes("complex", full=False), -) -def test_unset_default_complex_dtype( - *, - dtype, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - dtype = dtype[0] - stack_size_before = len(ivy_backend.default_complex_dtype_stack) - ivy_backend.set_default_complex_dtype(dtype) - ivy_backend.unset_default_complex_dtype() - stack_size_after = len(ivy_backend.default_complex_dtype_stack) - assert ( - stack_size_before == stack_size_after - ), f"Default float dtype not unset. Stack size= {stack_size_after!r}" - - # valid_dtype @handle_test( fn_tree="functional.ivy.valid_dtype", @@ -1069,21 +1059,37 @@ def test_valid_dtype( ) -# is_native_dtype -@handle_test( - fn_tree="functional.ivy.is_native_dtype", - input_dtype=helpers.get_dtypes("valid", full=False), -) -def test_is_native_dtype( - input_dtype, - backend_fw, -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - input_dtype = input_dtype[0] - if isinstance(input_dtype, str): - assert ivy_backend.is_native_dtype(input_dtype) is False - - assert ( - ivy_backend.is_native_dtype(ivy_backend.as_native_dtype(input_dtype)) - is True - ) +_composition_1.test_unsupported_dtypes = { + "numpy": ("bfloat16",), + "jax": ("complex64", "complex128"), + "tensorflow": ("complex64", "complex128"), + "torch": ( + "uint16", + "uint32", + "uint64", + "float16", + "complex64", + "complex128", + ), + "paddle": ( + "uint16", + "uint32", + "uint64", + "bfloat16", + "complex64", + "complex128", + ), + "mxnet": ("uint16", "uint32", "uint64", "complex64", "complex128"), +} +_composition_2.test_unsupported_dtypes = { + "numpy": ("bfloat16", "complex64", "complex128"), + "jax": ("complex64", "complex128"), + "tensorflow": ("complex64", "complex128"), + "torch": ("uint16", "uint32", "uint64", "float16", "complex64", "complex128"), + "paddle": ( + "uint16", + "uint32", + "uint64", + "bfloat16", + ), +} diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 161ec52a7628e..a4f19e0de9668 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -12,13 +12,12 @@ from ivy_tests.test_ivy.helpers.pipeline_helper import BackendHandler import ivy_tests.test_ivy.helpers.globals as test_globals -_zero = np.asarray(0, dtype="uint8") _one = np.asarray(1, dtype="uint8") +_zero = np.asarray(0, dtype="uint8") -def not_too_close_to_zero(x): - f = np.vectorize(lambda item: item + (_one if np.isclose(item, 0) else _zero)) - return f(x) +# --- Helpers --- # +# --------------- # # this is not used yet and will be used when ``where`` argument is added @@ -36,6 +35,123 @@ def _array_with_mask(draw): return ([dtype[0], dtype2[0]], x, where) +# trapz +@st.composite +def _either_x_dx(draw): + rand = (draw(st.integers(min_value=0, max_value=1)),) + if rand == 0: + either_x_dx = draw( + helpers.dtype_and_values( + avaliable_dtypes=st.shared( + helpers.get_dtypes("float"), key="trapz_dtype" + ), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ) + ) + return rand, either_x_dx + else: + either_x_dx = draw( + st.floats(min_value=-10, max_value=10), + ) + return rand, either_x_dx + + +@st.composite +def min_max_helper(draw): + use_where = draw(st.booleans()) + if 
use_where: + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + small_abs_safety_factor=6, + large_abs_safety_factor=6, + safety_factor_scale="log", + ) + ) + else: + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_value=-1e5, + max_value=1e5, + safety_factor_scale="log", + ) + ) + return dtype_and_x, use_where + + +@st.composite +def pow_helper(draw, available_dtypes=None): + if available_dtypes is None: + available_dtypes = helpers.get_dtypes("numeric") + dtype1, x1 = draw( + helpers.dtype_and_values( + available_dtypes=available_dtypes, + small_abs_safety_factor=16, + large_abs_safety_factor=16, + safety_factor_scale="log", + ) + ) + dtype1 = dtype1[0] + + def cast_filter(dtype1_x1_dtype2): + dtype1, _, dtype2 = dtype1_x1_dtype2 + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + if ivy_backend.can_cast(dtype1, dtype2): + return True + return False + + dtype1, x1, dtype2 = draw( + helpers.get_castable_dtype(draw(available_dtypes), dtype1, x1).filter( + cast_filter + ) + ) + with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: + if ivy_backend.is_int_dtype(dtype2): + max_val = ivy_backend.iinfo(dtype2).max + else: + max_val = ivy_backend.finfo(dtype2).max + + max_x1 = np.max(np.abs(x1[0])) + if max_x1 in [0, 1]: + max_value = None + else: + max_value = int(math.log(max_val) / math.log(max_x1)) + if abs(max_value) > abs(max_val) / 40 or max_value < 0: + max_value = None + dtype2, x2 = draw( + helpers.dtype_and_values( + small_abs_safety_factor=16, + large_abs_safety_factor=16, + safety_factor_scale="log", + max_value=max_value, + dtype=[dtype2], + ) + ) + dtype2 = dtype2[0] + if "int" in dtype2: + x2 = ivy.nested_map( + x2[0], lambda x: abs(x), include_derived={list: True}, shallow=False + ) + return [dtype1, dtype2], [x1, x2] + + +# --- Main --- # +# ------------ # + + +def not_too_close_to_zero(x): + f = np.vectorize(lambda item: item + (_one if np.isclose(item, 0) else _zero)) + return f(x) + + # abs @handle_test( fn_tree="functional.ivy.abs", @@ -53,17 +169,16 @@ def test_abs(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# acosh +# acos @handle_test( - fn_tree="functional.ivy.acosh", + fn_tree="functional.ivy.acos", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=1, large_abs_safety_factor=4, small_abs_safety_factor=4, ), ) -def test_acosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_acos(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -77,16 +192,17 @@ def test_acosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# acos +# acosh @handle_test( - fn_tree="functional.ivy.acos", + fn_tree="functional.ivy.acosh", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=1, large_abs_safety_factor=4, small_abs_safety_factor=4, ), ) -def test_acos(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_acosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -129,6 +245,44 @@ def test_add(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_device): ) +# angle +@handle_test( + fn_tree="functional.ivy.angle", + dtype_and_x=helpers.dtype_and_values( + 
available_dtypes=["float64"], + min_value=-5, + max_value=5, + max_dim_size=5, + max_num_dims=5, + min_dim_size=1, + min_num_dims=1, + allow_inf=False, + allow_nan=False, + ), + deg=st.booleans(), + test_gradients=st.just(False), +) +def test_angle( + *, + dtype_and_x, + deg, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtype, z = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + z=z[0], + deg=deg, + ) + + # asin @handle_test( fn_tree="functional.ivy.asin", @@ -264,6 +418,27 @@ def test_bitwise_and(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device) ) +# bitwise_invert +@handle_test( + fn_tree="functional.ivy.bitwise_invert", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + array_api_dtypes=True, + ), + test_gradients=st.just(False), +) +def test_bitwise_invert(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + ) + + # bitwise_left_shift @handle_test( fn_tree="functional.ivy.bitwise_left_shift", @@ -306,27 +481,6 @@ def test_bitwise_left_shift(*, dtype_and_x, test_flags, backend_fw, fn_name, on_ ) -# bitwise_invert -@handle_test( - fn_tree="functional.ivy.bitwise_invert", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - array_api_dtypes=True, - ), - test_gradients=st.just(False), -) -def test_bitwise_invert(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - ) - - # bitwise_or @handle_test( fn_tree="functional.ivy.bitwise_or", @@ -460,6 +614,24 @@ def test_cosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) +@handle_test( + fn_tree="functional.ivy.deg2rad", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") + ), +) +def test_deg2rad(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + ) + + # divide @handle_test( fn_tree="functional.ivy.divide", @@ -508,6 +680,29 @@ def test_equal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) +# Extra # +# ------# + + +# erf +@handle_test( + fn_tree="functional.ivy.erf", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), +) +def test_erf(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + x=x[0], + ) + + # exp @handle_test( fn_tree="functional.ivy.exp", @@ -649,7 +844,68 @@ def test_fmin(dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# greater +# fmod +@handle_test( + fn_tree="functional.ivy.fmod", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=False, + 
large_abs_safety_factor=6, + small_abs_safety_factor=6, + safety_factor_scale="log", + ), + test_gradients=st.just(False), +) +def test_fmod(dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + # Make sure values is not too close to zero + assume(not np.any(np.isclose(x[0], 0))) + assume(not np.any(np.isclose(x[1], 0))) + # jax raises inconsistent gradients for negative numbers in x1 + if (np.any(x[0] < 0) or np.any(x[1] < 0)) and ivy.current_backend_str() == "jax": + test_flags.test_gradients = False + test_flags.as_variable = [test_flags.as_variable, False] + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x1=x[0], + x2=x[1], + ) + + +# gcd +@handle_test( + fn_tree="functional.ivy.gcd", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + shared_dtype=False, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=False, + ), + test_gradients=st.just(False), +) +def test_gcd(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x1=x[0], + x2=x[1], + ) + + +# greater @handle_test( fn_tree="functional.ivy.greater", dtype_and_x=helpers.dtype_and_values( @@ -697,6 +953,41 @@ def test_greater_equal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_devic ) +# imag +@handle_test( + fn_tree="functional.ivy.imag", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-5, + max_value=5, + max_dim_size=5, + max_num_dims=5, + min_dim_size=1, + min_num_dims=1, + allow_inf=False, + allow_nan=False, + ), + test_gradients=st.just(False), +) +def test_imag( + *, + dtype_and_x, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + val=x[0], + ) + + # isfinite @handle_test( fn_tree="functional.ivy.isfinite", @@ -770,6 +1061,54 @@ def test_isnan(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) +# isreal +@handle_test( + fn_tree="functional.ivy.isreal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("real_and_complex") + ), + test_gradients=st.just(False), +) +def test_isreal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + ) + + +# lcm +@handle_test( + fn_tree="functional.ivy.lcm", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["int16", "int32", "int64"], + num_arrays=2, + shared_dtype=False, + min_num_dims=1, + max_num_dims=3, + min_value=-100, + max_value=100, + allow_nan=False, + ), + test_gradients=st.just(False), +) +def test_lcm(dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x1=x[0], + x2=x[1], + ) + + # less @handle_test( fn_tree="functional.ivy.less", @@ -841,16 +1180,12 @@ def test_log(*, dtype_and_x, test_flags, 
backend_fw, fn_name, on_device): ) -# log1p +# log10 @handle_test( - fn_tree="functional.ivy.log1p", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - small_abs_safety_factor=2, - safety_factor_scale="log", - ), + fn_tree="functional.ivy.log10", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_log1p(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_log10(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x # avoid logging values too close to zero assume(not np.any(np.isclose(x[0], 0))) @@ -860,16 +1195,22 @@ def test_log1p(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, + rtol_=1e-2, + atol_=1e-2, x=x[0], ) -# log2 +# log1p @handle_test( - fn_tree="functional.ivy.log2", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + fn_tree="functional.ivy.log1p", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + small_abs_safety_factor=2, + safety_factor_scale="log", + ), ) -def test_log2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_log1p(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x # avoid logging values too close to zero assume(not np.any(np.isclose(x[0], 0))) @@ -879,17 +1220,16 @@ def test_log2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-2, x=x[0], ) -# log10 +# log2 @handle_test( - fn_tree="functional.ivy.log10", + fn_tree="functional.ivy.log2", dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), ) -def test_log10(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_log2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x # avoid logging values too close to zero assume(not np.any(np.isclose(x[0], 0))) @@ -900,7 +1240,6 @@ def test_log10(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): fn_name=fn_name, on_device=on_device, rtol_=1e-2, - atol_=1e-2, x=x[0], ) @@ -1045,43 +1384,70 @@ def test_logical_xor(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device) ) -# multiply +# maximum @handle_test( - fn_tree="functional.ivy.multiply", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 - ), + fn_tree="functional.ivy.maximum", + dtype_and_x_and_use_where=min_max_helper(), ) -def test_multiply(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x +def test_maximum( + *, dtype_and_x_and_use_where, test_flags, backend_fw, fn_name, on_device +): + (input_dtype, x), use_where = dtype_and_x_and_use_where + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + x1=x[0], + x2=x[1], + use_where=use_where, + ) + +# minimum +@handle_test( + fn_tree="functional.ivy.minimum", + dtype_and_x_and_use_where=min_max_helper(), +) +def test_minimum( + *, dtype_and_x_and_use_where, test_flags, backend_fw, fn_name, on_device +): + (input_dtype, x), use_where = dtype_and_x_and_use_where helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, + rtol_=1e-2, + atol_=1e-2, 
x1=x[0], x2=x[1], + use_where=use_where, ) -# negative +# multiply @handle_test( - fn_tree="functional.ivy.negative", + fn_tree="functional.ivy.multiply", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2 ), ) -def test_negative(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_multiply(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x + helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) @@ -1130,6 +1496,25 @@ def test_nan_to_num( ) +# negative +@handle_test( + fn_tree="functional.ivy.negative", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") + ), +) +def test_negative(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + ) + + # not_equal @handle_test( fn_tree="functional.ivy.not_equal", @@ -1171,62 +1556,6 @@ def test_positive(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -@st.composite -def pow_helper(draw, available_dtypes=None): - if available_dtypes is None: - available_dtypes = helpers.get_dtypes("numeric") - dtype1, x1 = draw( - helpers.dtype_and_values( - available_dtypes=available_dtypes, - small_abs_safety_factor=16, - large_abs_safety_factor=16, - safety_factor_scale="log", - ) - ) - dtype1 = dtype1[0] - - def cast_filter(dtype1_x1_dtype2): - dtype1, _, dtype2 = dtype1_x1_dtype2 - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - if ivy_backend.can_cast(dtype1, dtype2): - return True - return False - - dtype1, x1, dtype2 = draw( - helpers.get_castable_dtype(draw(available_dtypes), dtype1, x1).filter( - cast_filter - ) - ) - with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend: - if ivy_backend.is_int_dtype(dtype2): - max_val = ivy_backend.iinfo(dtype2).max - else: - max_val = ivy_backend.finfo(dtype2).max - - max_x1 = np.max(np.abs(x1[0])) - if max_x1 in [0, 1]: - max_value = None - else: - max_value = int(math.log(max_val) / math.log(max_x1)) - if abs(max_value) > abs(max_val) / 40 or max_value < 0: - max_value = None - dtype2, x2 = draw( - helpers.dtype_and_values( - small_abs_safety_factor=16, - large_abs_safety_factor=16, - safety_factor_scale="log", - max_value=max_value, - dtype=[dtype2], - ) - ) - dtype2 = dtype2[0] - if "int" in dtype2: - x2 = ivy.nested_map( - x2[0], lambda x: abs(x), include_derived={list: True}, shallow=False - ) - return [dtype1, dtype2], [x1, x2] - - # pow @handle_test( fn_tree="functional.ivy.pow", @@ -1262,6 +1591,24 @@ def test_pow(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) +@handle_test( + fn_tree="functional.ivy.rad2deg", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") + ), +) +def test_rad2deg(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + ) + + # real @handle_test( fn_tree="functional.ivy.real", @@ -1281,6 +1628,31 @@ def test_real(*, dtype_and_x, test_flags, backend_fw, fn_name, 
on_device): ) +# reciprocal +@handle_test( + fn_tree="functional.ivy.reciprocal", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + small_abs_safety_factor=4, + large_abs_safety_factor=4, + safety_factor_scale="log", + num_arrays=1, + ), +) +def test_reciprocal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + x=x[0], + ) + + # remainder @handle_test( fn_tree="functional.ivy.remainder", @@ -1397,16 +1769,14 @@ def test_sinh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# square +# sqrt @handle_test( - fn_tree="functional.ivy.square", + fn_tree="functional.ivy.sqrt", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=2, - safety_factor_scale="log", - ), + available_dtypes=helpers.get_dtypes("float"), allow_inf=False + ).filter(lambda x: x[0][0] not in ["bfloat16"]), ) -def test_square(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_sqrt(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -1414,18 +1784,22 @@ def test_square(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, + rtol_=1e-2, + atol_=1e-2, x=x[0], ) -# sqrt +# square @handle_test( - fn_tree="functional.ivy.sqrt", + fn_tree="functional.ivy.square", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), allow_inf=False - ).filter(lambda x: x[0][0] not in ["bfloat16"]), + available_dtypes=helpers.get_dtypes("numeric"), + large_abs_safety_factor=2, + safety_factor_scale="log", + ), ) -def test_sqrt(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_square(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -1433,8 +1807,6 @@ def test_sqrt(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, x=x[0], ) @@ -1505,32 +1877,6 @@ def test_tanh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# trapz -@st.composite -def _either_x_dx(draw): - rand = (draw(st.integers(min_value=0, max_value=1)),) - if rand == 0: - either_x_dx = draw( - helpers.dtype_and_values( - avaliable_dtypes=st.shared( - helpers.get_dtypes("float"), key="trapz_dtype" - ), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ) - ) - return rand, either_x_dx - else: - either_x_dx = draw( - st.floats(min_value=-10, max_value=10), - ) - return rand, either_x_dx - - @handle_test( fn_tree="functional.ivy.trapz", dtype_values_axis=helpers.dtype_values_axis( @@ -1594,162 +1940,6 @@ def test_trunc(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# Extra # -# ------# - - -# erf -@handle_test( - fn_tree="functional.ivy.erf", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), -) -def test_erf(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - 
fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - ) - - -@st.composite -def min_max_helper(draw): - use_where = draw(st.booleans()) - if use_where: - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - small_abs_safety_factor=6, - large_abs_safety_factor=6, - safety_factor_scale="log", - ) - ) - else: - dtype_and_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=-1e5, - max_value=1e5, - safety_factor_scale="log", - ) - ) - return dtype_and_x, use_where - - -# minimum -@handle_test( - fn_tree="functional.ivy.minimum", - dtype_and_x_and_use_where=min_max_helper(), -) -def test_minimum( - *, dtype_and_x_and_use_where, test_flags, backend_fw, fn_name, on_device -): - (input_dtype, x), use_where = dtype_and_x_and_use_where - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - x1=x[0], - x2=x[1], - use_where=use_where, - ) - - -# maximum -@handle_test( - fn_tree="functional.ivy.maximum", - dtype_and_x_and_use_where=min_max_helper(), -) -def test_maximum( - *, dtype_and_x_and_use_where, test_flags, backend_fw, fn_name, on_device -): - (input_dtype, x), use_where = dtype_and_x_and_use_where - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - x1=x[0], - x2=x[1], - use_where=use_where, - ) - - -# reciprocal -@handle_test( - fn_tree="functional.ivy.reciprocal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - small_abs_safety_factor=4, - large_abs_safety_factor=4, - safety_factor_scale="log", - num_arrays=1, - ), -) -def test_reciprocal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - x=x[0], - ) - - -@handle_test( - fn_tree="functional.ivy.deg2rad", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), -) -def test_deg2rad(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - ) - - -@handle_test( - fn_tree="functional.ivy.rad2deg", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") - ), -) -def test_rad2deg(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - ) - - # trunc_divide @handle_test( fn_tree="functional.ivy.trunc_divide", @@ -1776,185 +1966,3 @@ def test_trunc_divide(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device x1=x[0], x2=x[1], ) - - -# isreal -@handle_test( - fn_tree="functional.ivy.isreal", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("real_and_complex") - ), - test_gradients=st.just(False), -) -def test_isreal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - 
helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - ) - - -# fmod -@handle_test( - fn_tree="functional.ivy.fmod", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=False, - large_abs_safety_factor=6, - small_abs_safety_factor=6, - safety_factor_scale="log", - ), - test_gradients=st.just(False), -) -def test_fmod(dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - # Make sure values is not too close to zero - assume(not np.any(np.isclose(x[0], 0))) - assume(not np.any(np.isclose(x[1], 0))) - # jax raises inconsistent gradients for negative numbers in x1 - if (np.any(x[0] < 0) or np.any(x[1] < 0)) and ivy.current_backend_str() == "jax": - test_flags.test_gradients = False - test_flags.as_variable = [test_flags.as_variable, False] - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x1=x[0], - x2=x[1], - ) - - -# lcm -@handle_test( - fn_tree="functional.ivy.lcm", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["int16", "int32", "int64"], - num_arrays=2, - shared_dtype=False, - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=False, - ), - test_gradients=st.just(False), -) -def test_lcm(dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x1=x[0], - x2=x[1], - ) - - -# gcd -@handle_test( - fn_tree="functional.ivy.gcd", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - shared_dtype=False, - min_num_dims=1, - max_num_dims=3, - min_value=-100, - max_value=100, - allow_nan=False, - ), - test_gradients=st.just(False), -) -def test_gcd(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x1=x[0], - x2=x[1], - ) - - -# imag -@handle_test( - fn_tree="functional.ivy.imag", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-5, - max_value=5, - max_dim_size=5, - max_num_dims=5, - min_dim_size=1, - min_num_dims=1, - allow_inf=False, - allow_nan=False, - ), - test_gradients=st.just(False), -) -def test_imag( - *, - dtype_and_x, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - val=x[0], - ) - - -# angle -@handle_test( - fn_tree="functional.ivy.angle", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float64"], - min_value=-5, - max_value=5, - max_dim_size=5, - max_num_dims=5, - min_dim_size=1, - min_num_dims=1, - allow_inf=False, - allow_nan=False, - ), - deg=st.booleans(), - test_gradients=st.just(False), -) -def test_angle( - *, - dtype_and_x, - deg, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtype, z = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - 
on_device=on_device, - z=z[0], - deg=deg, - ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index 30bb022a7f4cf..1c4aa8fe876a6 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -5,6 +5,20 @@ import math from types import SimpleNamespace +import pytest +from hypothesis import given, assume, strategies as st +import numpy as np +from collections.abc import Sequence + +# local +import threading +import ivy + +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_test, BackendHandler +from ivy_tests.test_ivy.helpers.assertions import assert_all_close +from ivy_tests.test_ivy.test_functional.test_core.test_elementwise import pow_helper + try: import tensorflow as tf except ImportError: @@ -17,20 +31,11 @@ except ImportError: jnp = SimpleNamespace() -import pytest -from hypothesis import given, assume, strategies as st -import numpy as np -from collections.abc import Sequence - try: import torch.multiprocessing as multiprocessing except ImportError: multiprocessing = SimpleNamespace() -# local -import threading -import ivy - try: import ivy.functional.backends.jax except ImportError: @@ -46,13 +51,29 @@ except ImportError: ivy.functional.backends.torch = SimpleNamespace() -import ivy_tests.test_ivy.helpers as helpers -from ivy_tests.test_ivy.helpers import handle_test, BackendHandler -from ivy_tests.test_ivy.helpers.assertions import assert_all_close -from ivy_tests.test_ivy.test_functional.test_core.test_elementwise import pow_helper -# Helpers # -# --------# +# --- Helpers --- # +# --------------- # + + +def _composition_1(): + return ivy.relu().argmax() + + +def _composition_2(): + return ivy.ceil() or ivy.linspace() + + +def _fn1(x, y): + return ivy.matmul(x, y) + + +def _fn2(x, y): + return ivy.vecdot(x, y) + + +def _fn3(x, y): + ivy.add(x, y) def _get_shape_of_list(lst, shape=()): @@ -70,196 +91,365 @@ def _get_shape_of_list(lst, shape=()): return shape -# Tests # -# ------# +@st.composite # ToDo remove when helpers.get_dtypes supports it +def _get_valid_numeric_no_unsigned(draw): + return list( + set(draw(helpers.get_dtypes("numeric"))).difference( + draw(helpers.get_dtypes("unsigned")) + ) + ) -@given(fw_str=st.sampled_from(["numpy", "jax", "torch", "tensorflow"])) -def test_set_framework(fw_str): - ivy.set_backend(fw_str) - ivy.previous_backend() +@st.composite +def _isin_data_generation_helper(draw): + assume_unique = draw(st.booleans()) + if assume_unique: + dtype_and_x = helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + ).filter(lambda x: np.array_equal(x[1][0], np.unique(x[1][0]))) + else: + dtype_and_x = helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + ) + return assume_unique, draw(dtype_and_x) -def test_use_within_use_framework(): - with ivy.functional.backends.numpy.use: - pass - with ivy.functional.backends.jax.use: - pass - with ivy.functional.backends.tensorflow.use: - pass - with ivy.functional.backends.torch.use: - pass +# fourier_encode +# @given( +# x=helpers.dtype_and_values(ivy_np.valid_float_dtypes, min_num_dims=1), +# max_freq=helpers.dtype_and_values(ivy_np.valid_float_dtypes), +# num_bands=st.integers(min_value=1,max_value=100000), +# as_variable=st.booleans(), +# num_positional_args=st.integers(0, 3), +# 
native_array=st.booleans(), +# container=st.booleans(), +# instance_method=st.booleans(), +# ) +# def test_fourier_encode( +# x, +# max_freq, +# num_bands, +# as_variable, +# num_positional_args, +# native_array, +# container, +# instance_method, +# device, +# call, +# fw +# ): +# # smoke test +# dtype_x, x = x +# dtype_max_freq, max_freq = max_freq +# if fw == "torch" and dtype_x in ["uint16", "uint32", "uint64"]: +# return +# helpers.test_function( +# dtype_x, +# as_variable, +# False, +# num_positional_args, +# native_array, +# container, +# instance_method, +# fw, +# "fourier_encode", +# x=np.asarray(x, dtype=dtype_x), +# max_freq=np.asarray(max_freq,dtype=dtype_max_freq), +# num_bands=num_bands +# ) -# match_kwargs -@given(allow_duplicates=st.booleans()) -def test_match_kwargs(allow_duplicates): - def func_a(a, b, c=2): - pass +@st.composite +def _values_and_ndindices( + draw, + *, + array_dtypes, + indices_dtypes=helpers.get_dtypes("integer"), + allow_inf=False, + x_min_value=None, + x_max_value=None, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, +): + x_dtype, x, x_shape = draw( + helpers.dtype_and_values( + available_dtypes=array_dtypes, + allow_inf=allow_inf, + ret_shape=True, + min_value=x_min_value, + max_value=x_max_value, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + x_dtype = x_dtype[0] if isinstance(x_dtype, (list)) else x_dtype + x = x[0] if isinstance(x, (list)) else x + # indices_dims defines how far into the array to index. + indices_dims = draw( + helpers.ints( + min_value=1, + max_value=len(x_shape) - 1, + ) + ) - def func_b(a, d, e=5): - return None + # num_ndindices defines the number of elements to generate. + num_ndindices = draw( + helpers.ints( + min_value=1, + max_value=x_shape[indices_dims], + ) + ) - class ClassA: - def __init__(self, c, f, g=3): - pass + # updates_dims defines how far into the array to index. 
+ updates_dtype, updates = draw( + helpers.dtype_and_values( + available_dtypes=array_dtypes, + allow_inf=allow_inf, + shape=x_shape[indices_dims:], + num_arrays=num_ndindices, + shared_dtype=True, + ) + ) + updates_dtype = ( + updates_dtype[0] if isinstance(updates_dtype, list) else updates_dtype + ) + updates = updates[0] if isinstance(updates, list) else updates - kwargs = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4, "f": 5, "g": 6} - kwfa, kwfb, kwca = ivy.match_kwargs( - kwargs, func_a, func_b, ClassA, allow_duplicates=allow_duplicates + indices = [] + indices_dtype = draw(st.sampled_from(indices_dtypes)) + for _ in range(num_ndindices): + nd_index = [] + for j in range(indices_dims): + axis_index = draw( + helpers.ints( + min_value=0, + max_value=max(0, x_shape[j] - 1), + ) + ) + nd_index.append(axis_index) + indices.append(nd_index) + indices = np.array(indices) + return [x_dtype, indices_dtype, updates_dtype], x, indices, updates + + +@st.composite +def _vector_norm_helper(draw): + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", key="clip_vector_norm"), + min_num_dims=1, + min_value=-100, + max_value=100, + abs_smallest_val=1e-2, + safety_factor_scale="log", + ) ) - if allow_duplicates: - assert kwfa == {"a": 0, "b": 1, "c": 2} - assert kwfb == {"a": 0, "d": 3, "e": 4} - assert kwca == {"c": 2, "f": 5, "g": 6} + if ivy.is_int_dtype(dtype[0]): + max_val = ivy.iinfo(dtype[0]).max else: - assert kwfa == {"a": 0, "b": 1, "c": 2} - assert kwfb == {"d": 3, "e": 4} - assert kwca == {"f": 5, "g": 6} + max_val = ivy.finfo(dtype[0]).max + max_x = np.abs(x[0]).max() + if max_x > 1: + max_p = math.log(max_val) / math.log(max_x) + else: + max_p = math.log(max_val) + p = draw(helpers.floats(abs_smallest_val=1e-2, min_value=-max_p, max_value=max_p)) + max_norm_val = math.log(max_val / max_x) + max_norm = draw( + helpers.floats( + large_abs_safety_factor=4, + safety_factor_scale="log", + min_value=1e-2, + max_value=max_norm_val, + ) + ) + return dtype, x, max_norm, p -# get_referrers_recursive -def test_get_referrers_recursive(): - class SomeClass: - def __init__(self): - self.x = [1, 2] - self.y = [self.x] +@st.composite +def array_and_ndindices_batch_dims( + draw, + *, + array_dtypes, + indices_dtypes=helpers.get_dtypes("integer"), + allow_inf=False, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, +): + x_dtype, x, x_shape = draw( + helpers.dtype_and_values( + available_dtypes=array_dtypes, + allow_inf=allow_inf, + ret_shape=True, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) - some_obj = SomeClass() - refs = ivy.get_referrers_recursive(some_obj.x) - ref_keys = refs.keys() - assert len(ref_keys) == 3 - assert "repr" in ref_keys - assert refs["repr"] == "[1,2]" - y_id = str(id(some_obj.y)) - y_refs = refs[y_id] - assert y_refs["repr"] == "[[1,2]]" - some_obj_dict_id = str(id(some_obj.__dict__)) - assert y_refs[some_obj_dict_id] == "tracked" - dict_refs = refs[some_obj_dict_id] - assert dict_refs["repr"] == "{'x':[1,2],'y':[[1,2]]}" - some_obj_id = str(id(some_obj)) - some_obj_refs = dict_refs[some_obj_id] - assert some_obj_refs["repr"] == str(some_obj).replace(" ", "") - assert len(some_obj_refs) == 1 + batch_dims = draw( + helpers.ints( + min_value=0, + max_value=len(x_shape) - 1, + ) + ) + # indices_dims defines how far into the array to index. 
+ indices_dims = draw( + helpers.ints( + min_value=1, + max_value=max(1, len(x_shape) - batch_dims), + ) + ) + + batch_shape = x_shape[0:batch_dims] + shape_var = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims - batch_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + ndindices_shape = list(batch_shape) + list(shape_var) + [indices_dims] + ndindices = np.zeros(ndindices_shape, dtype="int32") + if len(ndindices_shape) <= 1: + enumerator = ndindices + else: + enumerator = np.zeros(ndindices_shape[0:-1], dtype="int32") + ndindices_dtype = draw(st.sampled_from(indices_dtypes)) + for idx, _ in np.ndenumerate(enumerator): + bounds = [] + for j in range(0, indices_dims): + bounds.append(x_shape[j + batch_dims] - 1) + ndindices[idx] = draw(ndindices_with_bounds(bounds=bounds)) + ndindices = np.asarray(ndindices, ndindices_dtype) + return [x_dtype[0], ndindices_dtype], x[0], ndindices, batch_dims -# array_equal +@st.composite +def ndindices_with_bounds( + draw, + *, + bounds, +): + arr = [] + for i in bounds: + x = draw( + helpers.ints( + min_value=0, + max_value=max(0, i), + ) + ) + arr.append(x) + return arr + + +# --- Main --- # +# ------------ # + + +# all_equal @handle_test( - fn_tree="functional.ivy.array_equal", + fn_tree="functional.ivy.all_equal", dtypes_and_xs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, + num_arrays=helpers.ints(min_value=2, max_value=10), + min_num_dims=1, ), + equality_matrix=st.booleans(), test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_array_equal(dtypes_and_xs, test_flags, backend_fw, fn_name, on_device): +def test_all_equal( + dtypes_and_xs, equality_matrix, test_flags, backend_fw, fn_name, on_device +): dtypes, arrays = dtypes_and_xs + kw = {} + i = 0 + for x_ in arrays: + kw["x{}".format(i)] = x_ + i += 1 + test_flags.num_positional_args = len(arrays) helpers.test_function( input_dtypes=dtypes, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - x0=arrays[0], - x1=arrays[1], + **kw, + equality_matrix=equality_matrix, ) -# get_item -# TODO: add container and array instance methods -@handle_test( - fn_tree="functional.ivy.get_item", - ground_truth_backend="numpy", - dtypes_x_query=helpers.dtype_array_query( - available_dtypes=helpers.get_dtypes("valid"), - ), - copy=st.booleans(), - test_with_out=st.just(False), - test_gradients=st.just(False), - test_instance_method=st.just(False), - container_flags=st.just([False]), +def test_arg_info(): + return + + +@given( + x_n_value=st.sampled_from( + [ + [ivy.value_is_nan, ["x", "include_infs"]], + [ivy.clip_matrix_norm, ["x", "max_norm", "p", "out"]], + ] + ) ) -def test_get_item( - dtypes_x_query, - copy, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtypes, x, query = dtypes_x_query - try: - helpers.test_function( - input_dtypes=dtypes, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x, - query=query, - copy=copy, - ) - except ivy.utils.exceptions.IvyBackendException as e: - if backend_fw == "paddle" and "only supports access to dimension 0 to 9" in e: - assume(False) - else: - raise +def test_arg_names(x_n_value): + x, value = x_n_value + ret = ivy.arg_names(x) + assert ret == value -# set_item -# TODO: add container and array instance methods +# array_equal @handle_test( - fn_tree="functional.ivy.set_item", - ground_truth_backend="numpy", - 
dtypes_x_query_val=helpers.dtype_array_query_val( + fn_tree="functional.ivy.array_equal", + dtypes_and_xs=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, ), - copy=st.booleans(), test_with_out=st.just(False), test_gradients=st.just(False), - test_instance_method=st.just(False), - container_flags=st.just([False]), ) -def test_set_item( - dtypes_x_query_val, - copy, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtypes, x, query, val = dtypes_x_query_val +def test_array_equal(dtypes_and_xs, test_flags, backend_fw, fn_name, on_device): + dtypes, arrays = dtypes_and_xs helpers.test_function( input_dtypes=dtypes, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - rtol_=1e-03, # needed only for the paddle backend - x=x, - query=query, - val=val, - copy=copy, + x0=arrays[0], + x1=arrays[1], ) -# to_numpy @handle_test( - fn_tree="functional.ivy.to_numpy", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + fn_tree="functional.ivy.assert_supports_inplace", + x_val_and_dtypes=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid") ), - copy=st.booleans(), + ground_truth_backend="numpy", test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_to_numpy(*, dtype_x, copy, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - # torch throws an exception - if ivy.current_backend_str() == "torch" and not copy: +def test_assert_supports_inplace( + x_val_and_dtypes, test_flags, backend_fw, fn_name, on_device +): + dtype, x = x_val_and_dtypes + if backend_fw in ["tensorflow", "jax", "paddle"]: return + assume("bfloat16" not in dtype) helpers.test_function( input_dtypes=dtype, test_flags=test_flags, @@ -267,141 +457,95 @@ def test_to_numpy(*, dtype_x, copy, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, x=x[0], - copy=copy, ) -# to_scalar -@handle_test( - fn_tree="functional.ivy.to_scalar", - x0_n_x1_n_res=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=1, - large_abs_safety_factor=20, - ), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_to_scalar(x0_n_x1_n_res, test_flags, backend_fw, fn_name, on_device): - dtype, x = x0_n_x1_n_res - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - ) - +def test_cache_fn(): + def func(): + return ivy.random_uniform() -# to_list -@handle_test( - fn_tree="functional.ivy.to_list", - x0_n_x1_n_res=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - large_abs_safety_factor=20, - ), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_to_list(x0_n_x1_n_res, test_flags, backend_fw, fn_name, on_device): - dtype, x = x0_n_x1_n_res - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - ) + # return a single cached_fn and then query this + cached_fn = ivy.cache_fn(func) + ret0 = cached_fn() + ret0_again = cached_fn() + ret1 = func() + assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() + assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() + assert ret0 is ret0_again + assert ret0 is not ret1 -# shape -# TODO: add container and array methods -@handle_test( - 
fn_tree="functional.ivy.shape", - x0_n_x1_n_res=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid") - ), - as_array=st.booleans(), - test_with_out=st.just(False), - test_instance_method=st.just(False), - test_gradients=st.just(False), -) -def test_shape(x0_n_x1_n_res, as_array, test_flags, backend_fw, fn_name, on_device): - dtype, x = x0_n_x1_n_res - # instance_method=False because the shape property would overwrite the shape method - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - as_array=as_array, - ) + # call ivy.cache_fn repeatedly, the new cached functions + # each use the same global dict + ret0 = ivy.cache_fn(func)() + ret0_again = ivy.cache_fn(func)() + ret1 = func() + + assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() + assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() + assert ret0 is ret0_again + assert ret0 is not ret1 -# get_num_dims +def test_cache_fn_with_args(): + def func(_): + return ivy.random_uniform() + + # return a single cached_fn and then query this + cached_fn = ivy.cache_fn(func) + ret0 = cached_fn(0) + ret0_again = cached_fn(0) + ret1 = cached_fn(1) + + assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() + assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() + assert ret0 is ret0_again + assert ret0 is not ret1 + + # call ivy.cache_fn repeatedly, the new cached functions + # each use the same global dict + ret0 = ivy.cache_fn(func)(0) + ret0_again = ivy.cache_fn(func)(0) + ret1 = ivy.cache_fn(func)(1) + + assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() + assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() + assert ret0 is ret0_again + assert ret0 is not ret1 + + +# clip_matrix_norm @handle_test( - fn_tree="functional.ivy.get_num_dims", - x0_n_x1_n_res=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid") + fn_tree="functional.ivy.clip_matrix_norm", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + min_value=-10, + max_value=10, + abs_smallest_val=1e-4, ), - as_array=st.booleans(), - test_with_out=st.just(False), - test_gradients=st.just(False), + max_norm=st.floats(min_value=0.137, max_value=1e05), + p=st.sampled_from([1, 2, float("inf"), "fro", "nuc"]), ) -def test_get_num_dims( - x0_n_x1_n_res, as_array, test_flags, backend_fw, fn_name, on_device +def test_clip_matrix_norm( + dtype_x, max_norm, p, test_flags, backend_fw, fn_name, on_device ): - dtype, x = x0_n_x1_n_res + dtype, x = dtype_x helpers.test_function( input_dtypes=dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, + rtol_=1e-2, + atol_=1e-2, x=x[0], - as_array=as_array, - ) - - -@st.composite -def _vector_norm_helper(draw): - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", key="clip_vector_norm"), - min_num_dims=1, - min_value=-100, - max_value=100, - abs_smallest_val=1e-2, - safety_factor_scale="log", - ) - ) - if ivy.is_int_dtype(dtype[0]): - max_val = ivy.iinfo(dtype[0]).max - else: - max_val = ivy.finfo(dtype[0]).max - max_x = np.abs(x[0]).max() - if max_x > 1: - max_p = math.log(max_val) / math.log(max_x) - else: - max_p = math.log(max_val) - p = draw(helpers.floats(abs_smallest_val=1e-2, min_value=-max_p, max_value=max_p)) - max_norm_val = math.log(max_val / max_x) - 
max_norm = draw( - helpers.floats( - large_abs_safety_factor=4, - safety_factor_scale="log", - min_value=1e-2, - max_value=max_norm_val, - ) + max_norm=max_norm, + p=p, ) - return dtype, x, max_norm, p # clip_vector_norm @@ -427,346 +571,259 @@ def test_clip_vector_norm( ) -# fourier_encode -# @given( -# x=helpers.dtype_and_values(ivy_np.valid_float_dtypes, min_num_dims=1), -# max_freq=helpers.dtype_and_values(ivy_np.valid_float_dtypes), -# num_bands=st.integers(min_value=1,max_value=100000), -# as_variable=st.booleans(), -# num_positional_args=st.integers(0, 3), -# native_array=st.booleans(), -# container=st.booleans(), -# instance_method=st.booleans(), -# ) -# def test_fourier_encode( -# x, -# max_freq, -# num_bands, -# as_variable, -# num_positional_args, -# native_array, -# container, -# instance_method, -# device, -# call, -# fw -# ): -# # smoke test -# dtype_x, x = x -# dtype_max_freq, max_freq = max_freq -# if fw == "torch" and dtype_x in ["uint16", "uint32", "uint64"]: -# return -# helpers.test_function( -# dtype_x, -# as_variable, -# False, -# num_positional_args, -# native_array, -# container, -# instance_method, -# fw, -# "fourier_encode", -# x=np.asarray(x, dtype=dtype_x), -# max_freq=np.asarray(max_freq,dtype=dtype_max_freq), -# num_bands=num_bands -# ) +# container types +def test_container_types(): + cont_types = ivy.container_types() + assert isinstance(cont_types, list) + for cont_type in cont_types: + assert hasattr(cont_type, "keys") + assert hasattr(cont_type, "values") + assert hasattr(cont_type, "items") -@st.composite -def _values_and_ndindices( - draw, - *, - array_dtypes, - indices_dtypes=helpers.get_dtypes("integer"), - allow_inf=False, - x_min_value=None, - x_max_value=None, - min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, -): - x_dtype, x, x_shape = draw( - helpers.dtype_and_values( - available_dtypes=array_dtypes, - allow_inf=allow_inf, - ret_shape=True, - min_value=x_min_value, - max_value=x_max_value, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - x_dtype = x_dtype[0] if isinstance(x_dtype, (list)) else x_dtype - x = x[0] if isinstance(x, (list)) else x - # indices_dims defines how far into the array to index. - indices_dims = draw( - helpers.ints( - min_value=1, - max_value=len(x_shape) - 1, - ) - ) +# Still to Add # +# ---------------# - # num_ndindices defines the number of elements to generate. - num_ndindices = draw( - helpers.ints( - min_value=1, - max_value=x_shape[indices_dims], - ) - ) - # updates_dims defines how far into the array to index. 
- updates_dtype, updates = draw( - helpers.dtype_and_values( - available_dtypes=array_dtypes, - allow_inf=allow_inf, - shape=x_shape[indices_dims:], - num_arrays=num_ndindices, - shared_dtype=True, - ) - ) - updates_dtype = ( - updates_dtype[0] if isinstance(updates_dtype, list) else updates_dtype - ) - updates = updates[0] if isinstance(updates, list) else updates - - indices = [] - indices_dtype = draw(st.sampled_from(indices_dtypes)) - for _ in range(num_ndindices): - nd_index = [] - for j in range(indices_dims): - axis_index = draw( - helpers.ints( - min_value=0, - max_value=max(0, x_shape[j] - 1), - ) - ) - nd_index.append(axis_index) - indices.append(nd_index) - indices = np.array(indices) - return [x_dtype, indices_dtype, updates_dtype], x, indices, updates +@given(fw=st.sampled_from(["torch", "tensorflow", "numpy", "jax"])) +def test_current_backend_str(fw): + ivy.set_backend(fw) + assert ivy.current_backend_str() == fw + ivy.previous_backend() -# scatter_flat +# default @handle_test( - fn_tree="functional.ivy.scatter_flat", - x=st.integers(min_value=1, max_value=10).flatmap( - lambda n: st.tuples( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=1, - min_dim_size=n, - max_dim_size=n, - ), - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=0, - max_value=max(n - 1, 0), - min_num_dims=1, - max_num_dims=1, - min_dim_size=n, - max_dim_size=n, - ).filter(lambda d_n_v: len(set(d_n_v[1][0])) == len(d_n_v[1][0])), - st.integers(min_value=n, max_value=n), - ) + fn_tree="functional.ivy.default", + x=st.one_of( + st.none(), + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + allow_inf=False, + min_num_dims=0, + min_dim_size=2, + ), + st.sampled_from([lambda *args, **kwargs: None]), + ), + default_val=st.one_of( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + allow_inf=False, + min_num_dims=0, + min_dim_size=2, + ), + st.sampled_from([lambda *args, **kwargs: None]), ), - reduction=st.sampled_from(["sum", "min", "max", "replace"]), - ground_truth_backend="tensorflow", ) -def test_scatter_flat(x, reduction, test_flags, backend_fw, fn_name, on_device): - # scatter_flat throws an error while computing gradients for tensorflow - # this has been fixed in the newer versions of tensorflow (2.10.0 onwards) - if backend_fw == "tensorflow": - grad_support_version = [2, 10, 0] - k = 0 - for number in [int(s) for s in tf.__version__.split(".") if s.isdigit()]: - if k > len(grad_support_version): - break - if number < grad_support_version[k]: - test_flags.test_gradients = False - k += 1 - (val_dtype, vals), (ind_dtype, ind), size = x - helpers.test_function( - input_dtypes=ind_dtype + val_dtype, - test_flags=test_flags, - xs_grad_idxs=[[0, 1]], - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - indices=ind[0], - updates=vals[0], - size=size, - reduction=reduction, - ) +def test_default(x, default_val, test_flags, backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + with_callable = False + if x is not None: + if hasattr(x, "__call__"): + with_callable = True + else: + x_dtype, x = x + x = x[0].tolist() if isinstance(x, list) else x + else: + if hasattr(default_val, "__call__"): + with_callable = True + else: + dv_dtype, default_val = default_val + default_val = ( + default_val[0].tolist() + if isinstance(default_val, list) + else default_val + ) + truth_val = ivy_backend.to_native(x if x is not None 
else default_val) + if with_callable: + assert ivy_backend.default(x, default_val) == truth_val + else: + assert_all_close( + np.asarray(ivy_backend.default(x, default_val)), + np.asarray(truth_val), + rtol=1e-3, + atol=1e-3, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) -# scatter_nd + +# ToDo: re-add this test once ivy.get_backend is working correctly, with the returned +# ivy handle having no dependence on the globally set ivy +# @handle_cmd_line_args +# +# def test_class_ivy_handles(device, call): +# +# if call is helpers.np_call: +# # Numpy is the conflicting framework being tested against +# pytest.skip() +# +# class ArrayGen: +# def __init__(self, ivyh): +# self._ivy = ivyh +# +# def get_array(self): +# return self._ivy.array([0.0, 1.0, 2.0], dtype="float32", device=device) +# +# # create instance +# ag = ArrayGen(ivy.get_backend()) +# +# # create array from array generator +# x = ag.get_array() +# +# # verify this is not a numpy array +# assert not isinstance(x, np.ndarray) +# +# # change global framework to numpy +# ivy.set_backend("numpy") +# +# # create another array from array generator +# x = ag.get_array() +# +# # verify this is not still a numpy array +# assert not isinstance(x, np.ndarray) + + +# einops_rearrange @handle_test( - fn_tree="functional.ivy.scatter_nd", - x=_values_and_ndindices( - # ToDo: needs support for boolean arrays - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], - x_min_value=0, - x_max_value=0, - min_num_dims=2, + fn_tree="functional.ivy.einops_rearrange", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), allow_inf=False, + min_num_dims=4, + max_num_dims=4, + min_dim_size=2, + max_dim_size=2, + min_value=-1e05, + max_value=1e05, + ).filter( + lambda x: (ivy.array([x[1][0]], dtype="float32").shape[2] % 2 == 0) + and (ivy.array([x[1][0]], dtype="float32").shape[3] % 2 == 0) + and (x[0][0] not in ["float16", "bfloat16"]) + ), + pattern_and_axes_lengths=st.sampled_from( + [ + ("b h w c -> b h w c", {}), + ("b h w c -> (b h) w c", {}), + ("b h w c -> b c h w", {}), + ("b h w c -> h (b w) c", {}), + ("b h w c -> b (c h w)", {}), + ("b (h1 h) (w1 w) c -> (b h1 w1) h w c", {"h1": 2, "w1": 2}), + ("b (h h1) (w w1) c -> b h w (c h1 w1)", {"h1": 2, "w1": 2}), + ] ), - reduction=st.sampled_from(["sum", "min", "max", "replace"]), - test_gradients=st.just(False), ) -def test_scatter_nd(x, reduction, test_flags, backend_fw, fn_name, on_device): - (val_dtype, ind_dtype, update_dtype), vals, ind, updates = x - shape = vals.shape +def test_einops_rearrange( + dtype_x, pattern_and_axes_lengths, test_flags, backend_fw, fn_name, on_device +): + pattern, axes_lengths = pattern_and_axes_lengths + dtype, x = dtype_x helpers.test_function( - input_dtypes=[ind_dtype, update_dtype], + input_dtypes=dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - indices=np.asarray(ind, dtype=ind_dtype), - updates=updates, - shape=shape, - reduction=reduction, + x=x[0], + pattern=pattern, + **axes_lengths, ) -# gather +# einops_reduce @handle_test( - fn_tree="functional.ivy.gather", - params_indices_others=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, + fn_tree="functional.ivy.einops_reduce", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + allow_inf=False, + 
min_num_dims=4, + max_num_dims=4, + min_dim_size=2, + max_dim_size=2, + min_value=-1e05, + max_value=1e05, + ).filter( + lambda x: (ivy.array([x[1][0]], dtype="float32").shape[2] % 2 == 0) + and (ivy.array([x[1][0]], dtype="float32").shape[3] % 2 == 0) + and (x[0][0] not in ["float16", "bfloat16"]) + ), + pattern_and_axes_lengths=st.sampled_from( + [ + ("b c (h1 h2) (w1 w2) -> b c h1 w1", {"h2": 2, "w2": 2}), + ] ), + floattypes=helpers.get_dtypes("float"), + reduction=st.sampled_from(["min", "max", "sum", "mean", "prod"]), ) -def test_gather(params_indices_others, test_flags, backend_fw, fn_name, on_device): - dtypes, params, indices, axis, batch_dims = params_indices_others - helpers.test_function( - input_dtypes=dtypes, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - xs_grad_idxs=[[0, 0]], - params=params, - indices=indices, - axis=axis, - batch_dims=batch_dims, - ) - - -@st.composite -def array_and_ndindices_batch_dims( - draw, - *, - array_dtypes, - indices_dtypes=helpers.get_dtypes("integer"), - allow_inf=False, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, -): - x_dtype, x, x_shape = draw( - helpers.dtype_and_values( - available_dtypes=array_dtypes, - allow_inf=allow_inf, - ret_shape=True, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - - batch_dims = draw( - helpers.ints( - min_value=0, - max_value=len(x_shape) - 1, - ) - ) - # indices_dims defines how far into the array to index. - indices_dims = draw( - helpers.ints( - min_value=1, - max_value=max(1, len(x_shape) - batch_dims), - ) - ) - - batch_shape = x_shape[0:batch_dims] - shape_var = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims - batch_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - ndindices_shape = list(batch_shape) + list(shape_var) + [indices_dims] - ndindices = np.zeros(ndindices_shape, dtype="int32") - if len(ndindices_shape) <= 1: - enumerator = ndindices - else: - enumerator = np.zeros(ndindices_shape[0:-1], dtype="int32") - ndindices_dtype = draw(st.sampled_from(indices_dtypes)) - for idx, _ in np.ndenumerate(enumerator): - bounds = [] - for j in range(0, indices_dims): - bounds.append(x_shape[j + batch_dims] - 1) - ndindices[idx] = draw(ndindices_with_bounds(bounds=bounds)) - ndindices = np.asarray(ndindices, ndindices_dtype) - return [x_dtype[0], ndindices_dtype], x[0], ndindices, batch_dims - - -@st.composite -def ndindices_with_bounds( - draw, +def test_einops_reduce( *, - bounds, + dtype_x, + pattern_and_axes_lengths, + floattypes, + reduction, + test_flags, + backend_fw, + fn_name, + on_device, ): - arr = [] - for i in bounds: - x = draw( - helpers.ints( - min_value=0, - max_value=max(0, i), - ) - ) - arr.append(x) - return arr + pattern, axes_lengths = pattern_and_axes_lengths + dtype, x = dtype_x + if (reduction in ["mean", "prod"]) and (dtype not in floattypes): + dtype = ["float32"] + # torch computes min and max differently and leads to inconsistent gradients + if backend_fw == "torch" and reduction in ["min", "max"]: + test_flags.test_gradients = False + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + rtol_=1e-1, + atol_=1e-1, + x=x[0], + pattern=pattern, + reduction=reduction, + **axes_lengths, + ) -# gather_nd +# einops_repeat @handle_test( - 
fn_tree="functional.ivy.gather_nd", - params_n_ndindices_batch_dims=array_and_ndindices_batch_dims( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], + fn_tree="functional.ivy.einops_repeat", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), allow_inf=False, + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + ), + pattern_and_axes_lengths=st.sampled_from( + [ + ("h w -> h w repeat", {"repeat": 2}), + ("h w -> (repeat h) w", {"repeat": 2}), + ("h w -> h (repeat w)", {"repeat": 2}), + ("h w -> (h h2) (w w2)", {"h2": 2, "w2": 2}), + ("h w -> w h", {}), + ] ), ) -def test_gather_nd( - params_n_ndindices_batch_dims, test_flags, backend_fw, fn_name, on_device +def test_einops_repeat( + *, dtype_x, pattern_and_axes_lengths, test_flags, backend_fw, fn_name, on_device ): - dtypes, params, ndindices, batch_dims = params_n_ndindices_batch_dims + pattern, axes_lengths = pattern_and_axes_lengths + dtype, x = dtype_x + assume("uint16" not in dtype) helpers.test_function( - input_dtypes=dtypes, + input_dtypes=dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - xs_grad_idxs=[[0, 0]], - params=params, - indices=ndindices, - batch_dims=batch_dims, + x=x[0], + pattern=pattern, + **axes_lengths, ) @@ -794,144 +851,32 @@ def test_exists(x): assert ret == y_true -# default -@handle_test( - fn_tree="functional.ivy.default", - x=st.one_of( - st.none(), - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - allow_inf=False, - min_num_dims=0, - min_dim_size=2, - ), - st.sampled_from([lambda *args, **kwargs: None]), - ), - default_val=st.one_of( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - allow_inf=False, - min_num_dims=0, - min_dim_size=2, - ), - st.sampled_from([lambda *args, **kwargs: None]), - ), -) -def test_default(x, default_val, test_flags, backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - with_callable = False - if x is not None: - if hasattr(x, "__call__"): - with_callable = True - else: - x_dtype, x = x - x = x[0].tolist() if isinstance(x, list) else x - else: - if hasattr(default_val, "__call__"): - with_callable = True - else: - dv_dtype, default_val = default_val - default_val = ( - default_val[0].tolist() - if isinstance(default_val, list) - else default_val - ) - - truth_val = ivy_backend.to_native(x if x is not None else default_val) - if with_callable: - assert ivy_backend.default(x, default_val) == truth_val - else: - assert_all_close( - np.asarray(ivy_backend.default(x, default_val)), - np.asarray(truth_val), - rtol=1e-3, - atol=1e-3, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) - - -def test_cache_fn(): - def func(): - return ivy.random_uniform() - - # return a single cached_fn and then query this - cached_fn = ivy.cache_fn(func) - ret0 = cached_fn() - ret0_again = cached_fn() - ret1 = func() - - assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() - assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() - assert ret0 is ret0_again - assert ret0 is not ret1 - - # call ivy.cache_fn repeatedly, the new cached functions - # each use the same global dict - ret0 = ivy.cache_fn(func)() - ret0_again = ivy.cache_fn(func)() - ret1 = func() - - assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() - assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() - assert ret0 is ret0_again - assert ret0 is not ret1 - - -def 
test_cache_fn_with_args(): - def func(_): - return ivy.random_uniform() - - # return a single cached_fn and then query this - cached_fn = ivy.cache_fn(func) - ret0 = cached_fn(0) - ret0_again = cached_fn(0) - ret1 = cached_fn(1) - - assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() - assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() - assert ret0 is ret0_again - assert ret0 is not ret1 +def test_explicit_ivy_framework_handles(backend_fw): + if backend_fw == "numpy": + # Numpy is the conflicting framework being tested against + pytest.skip() - # call ivy.cache_fn repeatedly, the new cached functions - # each use the same global dict - ret0 = ivy.cache_fn(func)(0) - ret0_again = ivy.cache_fn(func)(0) - ret1 = ivy.cache_fn(func)(1) + # set with explicit handle caught + ivy_exp = ivy.with_backend(backend_fw) + assert ivy_exp.current_backend_str() == backend_fw - assert ivy.to_numpy(ret0).item() == ivy.to_numpy(ret0_again).item() - assert ivy.to_numpy(ret0).item() != ivy.to_numpy(ret1).item() - assert ret0 is ret0_again - assert ret0 is not ret1 + # assert backend implemented function is accessible + assert "array" in ivy_exp.__dict__ + assert callable(ivy_exp.array) + # assert joint implemented function is also accessible + assert "cache_fn" in ivy_exp.__dict__ + assert callable(ivy_exp.cache_fn) -def test_framework_setting_with_threading(backend_fw): - if backend_fw == "jax": - # Numpy is the conflicting framework being tested against - pytest.skip() + # set global ivy to numpy + ivy.set_backend("numpy") - def thread_fn(): - x_ = jnp.array([0.0, 1.0, 2.0]) - ivy.set_backend("jax") - for _ in range(2000): - try: - ivy.mean(x_) - except TypeError: - return False - ivy.previous_backend() - return True + # assert the explicit handle is still unchanged + assert ivy.current_backend_str() == "numpy" + assert ivy_exp.current_backend_str() == backend_fw - # start jax loop thread - thread = threading.Thread(target=thread_fn) - thread.start() - time.sleep(0.01) - ivy.set_backend(backend_fw) - x = ivy.array([0.0, 1.0, 2.0]) - # start local original framework loop - for _ in range(2000): - ivy.mean(x) + # unset global ivy from numpy ivy.previous_backend() - assert not thread.join() def test_framework_setting_with_multiprocessing(backend_fw): @@ -969,200 +914,271 @@ def worker_fn(out_queue): assert output_queue.get_nowait() -def test_explicit_ivy_framework_handles(backend_fw): - if backend_fw == "numpy": +def test_framework_setting_with_threading(backend_fw): + if backend_fw == "jax": # Numpy is the conflicting framework being tested against pytest.skip() - # set with explicit handle caught - ivy_exp = ivy.with_backend(backend_fw) - assert ivy_exp.current_backend_str() == backend_fw + def thread_fn(): + x_ = jnp.array([0.0, 1.0, 2.0]) + ivy.set_backend("jax") + for _ in range(2000): + try: + ivy.mean(x_) + except TypeError: + return False + ivy.previous_backend() + return True + + # start jax loop thread + thread = threading.Thread(target=thread_fn) + thread.start() + time.sleep(0.01) + ivy.set_backend(backend_fw) + x = ivy.array([0.0, 1.0, 2.0]) + # start local original framework loop + for _ in range(2000): + ivy.mean(x) + ivy.previous_backend() + assert not thread.join() + - # assert backend implemented function is accessible - assert "array" in ivy_exp.__dict__ - assert callable(ivy_exp.array) +# function_supported_devices_and_dtypes +@pytest.mark.parametrize( + "func", + [_composition_1, _composition_2], +) +def test_function_supported_device_and_dtype(func, backend_fw): 
+ with BackendHandler.update_backend(backend_fw) as ivy_backend: + res = ivy_backend.function_supported_devices_and_dtypes(func, recurse=True) + exp = {"cpu": func.test_unsupported_devices_and_dtypes.copy()["cpu"]} + for dev in exp: + exp[dev] = tuple( + set(ivy.valid_dtypes).difference(exp[dev][ivy.current_backend_str()]) + ) - # assert joint implemented function is also accessible - assert "cache_fn" in ivy_exp.__dict__ - assert callable(ivy_exp.cache_fn) + all_key = set(res.keys()).union(set(exp.keys())) + for key in all_key: + assert key in res + assert key in exp + assert set(res[key]) == set(exp[key]) - # set global ivy to numpy - ivy.set_backend("numpy") - # assert the explicit handle is still unchanged - assert ivy.current_backend_str() == "numpy" - assert ivy_exp.current_backend_str() == backend_fw +# function_unsupported_devices_and_dtypes +@pytest.mark.parametrize( + "func", + [_composition_1, _composition_2], +) +def test_function_unsupported_devices(func, backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + res = ivy_backend.function_unsupported_devices_and_dtypes(func) + exp = func.test_unsupported_devices_and_dtypes.copy() + for dev in exp: + exp[dev] = exp[dev][backend_fw] + devs = list(exp.keys()) + for dev in devs: + if len(exp[dev]) == 0: + exp.pop(dev) - # unset global ivy from numpy - ivy.previous_backend() + all_key = set(res.keys()).union(set(exp.keys())) + for key in all_key: + assert key in res + assert key in exp + assert set(res[key]) == set(exp[key]) -# ToDo: re-add this test once ivy.get_backend is working correctly, with the returned -# ivy handle having no dependence on the globally set ivy -# @handle_cmd_line_args -# -# def test_class_ivy_handles(device, call): -# -# if call is helpers.np_call: -# # Numpy is the conflicting framework being tested against -# pytest.skip() -# -# class ArrayGen: -# def __init__(self, ivyh): -# self._ivy = ivyh -# -# def get_array(self): -# return self._ivy.array([0.0, 1.0, 2.0], dtype="float32", device=device) -# -# # create instance -# ag = ArrayGen(ivy.get_backend()) -# -# # create array from array generator -# x = ag.get_array() -# -# # verify this is not a numpy array -# assert not isinstance(x, np.ndarray) -# -# # change global framework to numpy -# ivy.set_backend("numpy") -# -# # create another array from array generator -# x = ag.get_array() -# -# # verify this is not still a numpy array -# assert not isinstance(x, np.ndarray) +# gather +@handle_test( + fn_tree="functional.ivy.gather", + params_indices_others=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), +) +def test_gather(params_indices_others, test_flags, backend_fw, fn_name, on_device): + dtypes, params, indices, axis, batch_dims = params_indices_others + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + xs_grad_idxs=[[0, 0]], + params=params, + indices=indices, + axis=axis, + batch_dims=batch_dims, + ) -# einops_rearrange +# gather_nd @handle_test( - fn_tree="functional.ivy.einops_rearrange", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + fn_tree="functional.ivy.gather_nd", + params_n_ndindices_batch_dims=array_and_ndindices_batch_dims( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], allow_inf=False, - min_num_dims=4, - 
max_num_dims=4, - min_dim_size=2, - max_dim_size=2, - min_value=-1e05, - max_value=1e05, - ).filter( - lambda x: (ivy.array([x[1][0]], dtype="float32").shape[2] % 2 == 0) - and (ivy.array([x[1][0]], dtype="float32").shape[3] % 2 == 0) - and (x[0][0] not in ["float16", "bfloat16"]) - ), - pattern_and_axes_lengths=st.sampled_from( - [ - ("b h w c -> b h w c", {}), - ("b h w c -> (b h) w c", {}), - ("b h w c -> b c h w", {}), - ("b h w c -> h (b w) c", {}), - ("b h w c -> b (c h w)", {}), - ("b (h1 h) (w1 w) c -> (b h1 w1) h w c", {"h1": 2, "w1": 2}), - ("b (h h1) (w w1) c -> b h w (c h1 w1)", {"h1": 2, "w1": 2}), - ] ), ) -def test_einops_rearrange( - dtype_x, pattern_and_axes_lengths, test_flags, backend_fw, fn_name, on_device +def test_gather_nd( + params_n_ndindices_batch_dims, test_flags, backend_fw, fn_name, on_device ): - pattern, axes_lengths = pattern_and_axes_lengths - dtype, x = dtype_x + dtypes, params, ndindices, batch_dims = params_n_ndindices_batch_dims helpers.test_function( - input_dtypes=dtype, + input_dtypes=dtypes, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - x=x[0], - pattern=pattern, - **axes_lengths, + xs_grad_idxs=[[0, 0]], + params=params, + indices=ndindices, + batch_dims=batch_dims, ) -# einops_reduce +def test_get_all_arrays_in_memory(): + return + + +# get_item +# TODO: add container and array instance methods @handle_test( - fn_tree="functional.ivy.einops_reduce", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - allow_inf=False, - min_num_dims=4, - max_num_dims=4, - min_dim_size=2, - max_dim_size=2, - min_value=-1e05, - max_value=1e05, - ).filter( - lambda x: (ivy.array([x[1][0]], dtype="float32").shape[2] % 2 == 0) - and (ivy.array([x[1][0]], dtype="float32").shape[3] % 2 == 0) - and (x[0][0] not in ["float16", "bfloat16"]) - ), - pattern_and_axes_lengths=st.sampled_from( - [ - ("b c (h1 h2) (w1 w2) -> b c h1 w1", {"h2": 2, "w2": 2}), - ] + fn_tree="functional.ivy.get_item", + ground_truth_backend="numpy", + dtypes_x_query=helpers.dtype_array_query( + available_dtypes=helpers.get_dtypes("valid"), ), - floattypes=helpers.get_dtypes("float"), - reduction=st.sampled_from(["min", "max", "sum", "mean", "prod"]), + copy=st.booleans(), + test_with_out=st.just(False), + test_gradients=st.just(False), + test_instance_method=st.just(False), + container_flags=st.just([False]), ) -def test_einops_reduce( - *, - dtype_x, - pattern_and_axes_lengths, - floattypes, - reduction, +def test_get_item( + dtypes_x_query, + copy, test_flags, backend_fw, fn_name, on_device, ): - pattern, axes_lengths = pattern_and_axes_lengths - dtype, x = dtype_x - if (reduction in ["mean", "prod"]) and (dtype not in floattypes): - dtype = ["float32"] - # torch computes min and max differently and leads to inconsistent gradients - if backend_fw == "torch" and reduction in ["min", "max"]: - test_flags.test_gradients = False + dtypes, x, query = dtypes_x_query + try: + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x, + query=query, + copy=copy, + ) + except ivy.utils.exceptions.IvyBackendException as e: + if backend_fw == "paddle" and "only supports access to dimension 0 to 9" in e: + assume(False) + else: + raise + + +# get_min_base +def test_get_min_base(): + assert ivy.min_base == 1e-5 + + +# get_min_denominator +def test_get_min_denominator(): + assert ivy.min_denominator == 1e-12 + + +# get_num_dims +@handle_test( + 
fn_tree="functional.ivy.get_num_dims", + x0_n_x1_n_res=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid") + ), + as_array=st.booleans(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_get_num_dims( + x0_n_x1_n_res, as_array, test_flags, backend_fw, fn_name, on_device +): + dtype, x = x0_n_x1_n_res helpers.test_function( input_dtypes=dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, - fn_name=fn_name, - rtol_=1e-1, - atol_=1e-1, + fn_name=fn_name, x=x[0], - pattern=pattern, - reduction=reduction, - **axes_lengths, + as_array=as_array, ) -# einops_repeat +# get_queue_timeout +@given( + x=st.floats(allow_nan=False, allow_infinity=False), +) +def test_get_queue_timeout(x): + ivy.set_queue_timeout(x) + ret = ivy.queue_timeout + assert ret == x + + +# get_referrers_recursive +def test_get_referrers_recursive(): + class SomeClass: + def __init__(self): + self.x = [1, 2] + self.y = [self.x] + + some_obj = SomeClass() + refs = ivy.get_referrers_recursive(some_obj.x) + ref_keys = refs.keys() + assert len(ref_keys) == 3 + assert "repr" in ref_keys + assert refs["repr"] == "[1,2]" + y_id = str(id(some_obj.y)) + y_refs = refs[y_id] + assert y_refs["repr"] == "[[1,2]]" + some_obj_dict_id = str(id(some_obj.__dict__)) + assert y_refs[some_obj_dict_id] == "tracked" + dict_refs = refs[some_obj_dict_id] + assert dict_refs["repr"] == "{'x':[1,2],'y':[[1,2]]}" + some_obj_id = str(id(some_obj)) + some_obj_refs = dict_refs[some_obj_id] + assert some_obj_refs["repr"] == str(some_obj).replace(" ", "") + assert len(some_obj_refs) == 1 + + +# get_tmp_dir +def test_get_tmp_dir(): + ret = ivy.tmp_dir + assert ret == "/tmp" + + +# has_nans @handle_test( - fn_tree="functional.ivy.einops_repeat", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - allow_inf=False, - min_num_dims=2, - max_num_dims=2, - min_dim_size=2, - ), - pattern_and_axes_lengths=st.sampled_from( - [ - ("h w -> h w repeat", {"repeat": 2}), - ("h w -> (repeat h) w", {"repeat": 2}), - ("h w -> h (repeat w)", {"repeat": 2}), - ("h w -> (h h2) (w w2)", {"h2": 2, "w2": 2}), - ("h w -> w h", {}), - ] + fn_tree="functional.ivy.has_nans", + x_val_and_dtypes=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_nan=True, + allow_inf=True, ), + include_infs=st.booleans(), + test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_einops_repeat( - *, dtype_x, pattern_and_axes_lengths, test_flags, backend_fw, fn_name, on_device +def test_has_nans( + *, x_val_and_dtypes, include_infs, test_flags, backend_fw, fn_name, on_device ): - pattern, axes_lengths = pattern_and_axes_lengths - dtype, x = dtype_x - assume("uint16" not in dtype) + dtype, x = x_val_and_dtypes helpers.test_function( input_dtypes=dtype, test_flags=test_flags, @@ -1170,21 +1186,10 @@ def test_einops_repeat( backend_to_test=backend_fw, fn_name=fn_name, x=x[0], - pattern=pattern, - **axes_lengths, + include_infs=include_infs, ) -# container types -def test_container_types(): - cont_types = ivy.container_types() - assert isinstance(cont_types, list) - for cont_type in cont_types: - assert hasattr(cont_type, "keys") - assert hasattr(cont_type, "values") - assert hasattr(cont_type, "items") - - def test_inplace_arrays_supported(backend_fw): with BackendHandler.update_backend(backend_fw) as ivy_backend: if backend_fw in ["numpy", "torch"]: @@ -1195,53 +1200,6 @@ def test_inplace_arrays_supported(backend_fw): raise Exception("Unrecognized 
framework") -def test_inplace_variables_supported(backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - if backend_fw in ["numpy", "torch", "tensorflow"]: - assert ivy_backend.inplace_variables_supported() - elif backend_fw in ["jax", "paddle"]: - assert not ivy_backend.inplace_variables_supported() - else: - raise Exception("Unrecognized framework") - - -# inplace_update -@handle_test( - fn_tree="functional.ivy.inplace_update", - x_val_and_dtypes=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shared_dtype=True, - ), - keep_x_dtype=st.booleans(), -) -def test_inplace_update( - x_val_and_dtypes, keep_x_dtype, test_flags, on_device, backend_fw -): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - dtype = x_val_and_dtypes[0][0] - if dtype in ivy_backend.function_unsupported_dtypes(ivy_backend.inplace_update): - return - x, val = x_val_and_dtypes[1] - x = ivy_backend.array(x.tolist(), dtype=dtype, device=on_device) - val = ivy_backend.array(val.tolist(), dtype=dtype, device=on_device) - if (not test_flags.as_variable and ivy_backend.inplace_arrays_supported()) or ( - test_flags.as_variable and ivy_backend.inplace_variables_supported() - ): - if keep_x_dtype: - x_dtype = x.dtype - x_inplace = ivy_backend.inplace_update(x, val, keep_input_dtype=True) - assert x_dtype == x_inplace.dtype - else: - x_inplace = ivy_backend.inplace_update(x, val) - assert id(x_inplace) == id(x) - x = helpers.flatten_and_to_np(backend=backend_fw, ret=x) - val = helpers.flatten_and_to_np(backend=backend_fw, ret=val) - helpers.value_test( - backend=backend_fw, ret_np_flat=x, ret_np_from_gt_flat=val - ) - - # inplace_decrement @handle_test( fn_tree="functional.ivy.inplace_decrement", @@ -1313,40 +1271,56 @@ def test_inplace_increment(x_val_and_dtypes, test_flags, on_device, backend_fw): ) -# is_ivy_array +# inplace_update @handle_test( - fn_tree="functional.ivy.is_ivy_array", + fn_tree="functional.ivy.inplace_update", x_val_and_dtypes=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid") + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), - exclusive=st.booleans(), - ground_truth_backend="numpy", - as_variable_flags=st.just([False]), - test_with_out=st.just(False), - test_gradients=st.just(False), + keep_x_dtype=st.booleans(), ) -def test_is_ivy_array( - *, x_val_and_dtypes, exclusive, test_flags, backend_fw, fn_name, on_device +def test_inplace_update( + x_val_and_dtypes, keep_x_dtype, test_flags, on_device, backend_fw ): - dtype, x = x_val_and_dtypes - # as_variable=False as the result can't be consistent across backends - if test_flags.container[0]: - # container instance methods should also not be tested - test_flags.instance_method = False - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - exclusive=exclusive, - ) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + dtype = x_val_and_dtypes[0][0] + if dtype in ivy_backend.function_unsupported_dtypes(ivy_backend.inplace_update): + return + x, val = x_val_and_dtypes[1] + x = ivy_backend.array(x.tolist(), dtype=dtype, device=on_device) + val = ivy_backend.array(val.tolist(), dtype=dtype, device=on_device) + if (not test_flags.as_variable and ivy_backend.inplace_arrays_supported()) or ( + test_flags.as_variable and ivy_backend.inplace_variables_supported() + ): + if keep_x_dtype: + x_dtype = x.dtype + 
x_inplace = ivy_backend.inplace_update(x, val, keep_input_dtype=True) + assert x_dtype == x_inplace.dtype + else: + x_inplace = ivy_backend.inplace_update(x, val) + assert id(x_inplace) == id(x) + x = helpers.flatten_and_to_np(backend=backend_fw, ret=x) + val = helpers.flatten_and_to_np(backend=backend_fw, ret=val) + helpers.value_test( + backend=backend_fw, ret_np_flat=x, ret_np_from_gt_flat=val + ) -# is_native_array +def test_inplace_variables_supported(backend_fw): + with BackendHandler.update_backend(backend_fw) as ivy_backend: + if backend_fw in ["numpy", "torch", "tensorflow"]: + assert ivy_backend.inplace_variables_supported() + elif backend_fw in ["jax", "paddle"]: + assert not ivy_backend.inplace_variables_supported() + else: + raise Exception("Unrecognized framework") + + +# is_array @handle_test( - fn_tree="functional.ivy.is_native_array", + fn_tree="functional.ivy.is_array", x_val_and_dtypes=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid") ), @@ -1356,8 +1330,8 @@ def test_is_ivy_array( test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_is_native_array( - *, x_val_and_dtypes, test_flags, exclusive, backend_fw, fn_name, on_device +def test_is_array( + x_val_and_dtypes, exclusive, test_flags, backend_fw, fn_name, on_device ): dtype, x = x_val_and_dtypes # as_variable=False as the result can't be consistent across backends @@ -1375,20 +1349,20 @@ def test_is_native_array( ) -# is_array +# is_ivy_array @handle_test( - fn_tree="functional.ivy.is_array", + fn_tree="functional.ivy.is_ivy_array", x_val_and_dtypes=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid") ), exclusive=st.booleans(), + ground_truth_backend="numpy", as_variable_flags=st.just([False]), - container_flags=st.just([False]), test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_is_array( - x_val_and_dtypes, exclusive, test_flags, backend_fw, fn_name, on_device +def test_is_ivy_array( + *, x_val_and_dtypes, exclusive, test_flags, backend_fw, fn_name, on_device ): dtype, x = x_val_and_dtypes # as_variable=False as the result can't be consistent across backends @@ -1428,118 +1402,78 @@ def test_is_ivy_container(x_val_and_dtypes, test_flags, backend_fw, fn_name, on_ ) -# all_equal -@handle_test( - fn_tree="functional.ivy.all_equal", - dtypes_and_xs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=2, max_value=10), - min_num_dims=1, - ), - equality_matrix=st.booleans(), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_all_equal( - dtypes_and_xs, equality_matrix, test_flags, backend_fw, fn_name, on_device -): - dtypes, arrays = dtypes_and_xs - kw = {} - i = 0 - for x_ in arrays: - kw["x{}".format(i)] = x_ - i += 1 - test_flags.num_positional_args = len(arrays) - helpers.test_function( - input_dtypes=dtypes, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - **kw, - equality_matrix=equality_matrix, - ) - - -# clip_matrix_norm +# is_native_array @handle_test( - fn_tree="functional.ivy.clip_matrix_norm", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - min_value=-10, - max_value=10, - abs_smallest_val=1e-4, + fn_tree="functional.ivy.is_native_array", + x_val_and_dtypes=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid") ), - max_norm=st.floats(min_value=0.137, max_value=1e05), - 
p=st.sampled_from([1, 2, float("inf"), "fro", "nuc"]), + exclusive=st.booleans(), + as_variable_flags=st.just([False]), + container_flags=st.just([False]), + test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_clip_matrix_norm( - dtype_x, max_norm, p, test_flags, backend_fw, fn_name, on_device +def test_is_native_array( + *, x_val_and_dtypes, test_flags, exclusive, backend_fw, fn_name, on_device ): - dtype, x = dtype_x + dtype, x = x_val_and_dtypes + # as_variable=False as the result can't be consistent across backends + if test_flags.container[0]: + # container instance methods should also not be tested + test_flags.instance_method = False helpers.test_function( input_dtypes=dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, x=x[0], - max_norm=max_norm, - p=p, + exclusive=exclusive, ) -# value_is_nan @handle_test( - fn_tree="functional.ivy.value_is_nan", - val_dtype=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - max_dim_size=1, - max_num_dims=1, - allow_nan=True, - allow_inf=True, - ), - include_infs=st.booleans(), + fn_tree="functional.ivy.isin", + assume_unique_and_dtype_and_x=_isin_data_generation_helper(), + invert=st.booleans(), + ground_truth_backend="numpy", test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_value_is_nan( - *, val_dtype, include_infs, test_flags, backend_fw, fn_name, on_device +def test_isin( + assume_unique_and_dtype_and_x, + invert, + test_flags, + backend_fw, + on_device, ): - dtype, val = val_dtype + assume_unique, x_and_dtype = assume_unique_and_dtype_and_x + dtypes, values = x_and_dtype + elements, test_elements = values helpers.test_function( - input_dtypes=dtype, + input_dtypes=dtypes, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, - fn_name=fn_name, - x=val[0], - include_infs=include_infs, + fn_name="isin", + elements=elements, + test_elements=test_elements, + invert=invert, + assume_unique=assume_unique, ) -# has_nans @handle_test( - fn_tree="functional.ivy.has_nans", - x_val_and_dtypes=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - allow_nan=True, - allow_inf=True, - ), - include_infs=st.booleans(), + fn_tree="functional.ivy.itemsize", + x_and_dtype=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + ground_truth_backend="numpy", + test_instance_method=st.just(False), test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_has_nans( - *, x_val_and_dtypes, include_infs, test_flags, backend_fw, fn_name, on_device -): - dtype, x = x_val_and_dtypes +def test_itemsize(x_and_dtype, test_flags, backend_fw, fn_name, on_device): + dtype, x = x_and_dtype helpers.test_function( input_dtypes=dtype, test_flags=test_flags, @@ -1547,165 +1481,181 @@ def test_has_nans( backend_to_test=backend_fw, fn_name=fn_name, x=x[0], - include_infs=include_infs, ) -# try_else_none -@given( - x=st.booleans(), -) -def test_try_else_none(x): - if x: - fn = ivy.try_else_none(lambda: True) - assert fn() is True - else: - fn = ivy.try_else_none(lambda x: x) - assert fn is None - +# match_kwargs +@given(allow_duplicates=st.booleans()) +def test_match_kwargs(allow_duplicates): + def func_a(a, b, c=2): + pass -@given( - x_n_value=st.sampled_from( - [ - [ivy.value_is_nan, ["x", "include_infs"]], - [ivy.clip_matrix_norm, ["x", "max_norm", "p", "out"]], - ] - ) -) -def test_arg_names(x_n_value): - x, value = x_n_value - ret = ivy.arg_names(x) - assert ret == 
value + def func_b(a, d, e=5): + return None + class ClassA: + def __init__(self, c, f, g=3): + pass -def _composition_1(): - return ivy.relu().argmax() + kwargs = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4, "f": 5, "g": 6} + kwfa, kwfb, kwca = ivy.match_kwargs( + kwargs, func_a, func_b, ClassA, allow_duplicates=allow_duplicates + ) + if allow_duplicates: + assert kwfa == {"a": 0, "b": 1, "c": 2} + assert kwfb == {"a": 0, "d": 3, "e": 4} + assert kwca == {"c": 2, "f": 5, "g": 6} + else: + assert kwfa == {"a": 0, "b": 1, "c": 2} + assert kwfb == {"d": 3, "e": 4} + assert kwca == {"f": 5, "g": 6} -_composition_1.test_unsupported_devices_and_dtypes = { - "cpu": { - "numpy": ("bfloat16",), - "jax": ("complex64", "complex128"), - "tensorflow": ("complex64", "complex128"), - "torch": ( - "uint16", - "uint32", - "uint64", - "float16", - "complex64", - "complex128", - ), - "paddle": ("uint16", "uint32", "uint64", "bfloat16", "complex64", "complex128"), - }, - "gpu": { - "numpy": ivy.all_dtypes, - "jax": ("complex64", "complex128"), - "tensorflow": ("complex64", "complex128"), - "torch": ("complex64", "float16", "uint16", "complex128", "uint64", "uint32"), - "paddle": ivy.all_dtypes, - }, - "tpu": { - "numpy": ivy.all_dtypes, - "jax": ivy.all_dtypes, - "tensorflow": ivy.all_dtypes, - "torch": ivy.all_dtypes, - "paddle": ivy.all_dtypes, - }, -} +def test_num_arrays_in_memory(): + return -def _composition_2(): - return ivy.ceil() or ivy.linspace() +def test_print_all_arrays_in_memory(): + return -_composition_2.test_unsupported_devices_and_dtypes = { - "cpu": { - "numpy": ("bfloat16", "complex64", "complex128"), - "jax": ("complex64", "complex128"), - "tensorflow": ("complex64", "complex128"), - "torch": ("uint16", "uint32", "uint64", "float16", "complex64", "complex128"), - "paddle": ( - "uint16", - "uint32", - "uint64", - "bfloat16", - ), - }, - "gpu": { - "numpy": ivy.all_dtypes, - "jax": ("complex64", "complex128"), - "tensorflow": ("complex64", "complex128"), - "torch": ("uint16", "uint64", "uint32", "complex128", "float16", "complex64"), - "paddle": ivy.all_dtypes, - }, - "tpu": { - "numpy": ivy.all_dtypes, - "jax": ivy.all_dtypes, - "tensorflow": ivy.all_dtypes, - "torch": ivy.all_dtypes, - "paddle": ivy.all_dtypes, - }, -} +# scatter_flat +@handle_test( + fn_tree="functional.ivy.scatter_flat", + x=st.integers(min_value=1, max_value=10).flatmap( + lambda n: st.tuples( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, + max_num_dims=1, + min_dim_size=n, + max_dim_size=n, + ), + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=0, + max_value=max(n - 1, 0), + min_num_dims=1, + max_num_dims=1, + min_dim_size=n, + max_dim_size=n, + ).filter(lambda d_n_v: len(set(d_n_v[1][0])) == len(d_n_v[1][0])), + st.integers(min_value=n, max_value=n), + ) + ), + reduction=st.sampled_from(["sum", "min", "max", "replace"]), + ground_truth_backend="tensorflow", +) +def test_scatter_flat(x, reduction, test_flags, backend_fw, fn_name, on_device): + # scatter_flat throws an error while computing gradients for tensorflow + # this has been fixed in the newer versions of tensorflow (2.10.0 onwards) + if backend_fw == "tensorflow": + grad_support_version = [2, 10, 0] + k = 0 + for number in [int(s) for s in tf.__version__.split(".") if s.isdigit()]: + if k > len(grad_support_version): + break + if number < grad_support_version[k]: + test_flags.test_gradients = False + k += 1 + (val_dtype, vals), (ind_dtype, ind), size = x + 
helpers.test_function( + input_dtypes=ind_dtype + val_dtype, + test_flags=test_flags, + xs_grad_idxs=[[0, 1]], + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + indices=ind[0], + updates=vals[0], + size=size, + reduction=reduction, + ) -# function_supported_devices_and_dtypes -@pytest.mark.parametrize( - "func", - [_composition_1, _composition_2], +# scatter_nd +@handle_test( + fn_tree="functional.ivy.scatter_nd", + x=_values_and_ndindices( + # ToDo: needs support for boolean arrays + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], + x_min_value=0, + x_max_value=0, + min_num_dims=2, + allow_inf=False, + ), + reduction=st.sampled_from(["sum", "min", "max", "replace"]), + test_gradients=st.just(False), ) -def test_function_supported_device_and_dtype(func, backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.function_supported_devices_and_dtypes(func, recurse=True) - exp = {"cpu": func.test_unsupported_devices_and_dtypes.copy()["cpu"]} - for dev in exp: - exp[dev] = tuple( - set(ivy.valid_dtypes).difference(exp[dev][ivy.current_backend_str()]) - ) - - all_key = set(res.keys()).union(set(exp.keys())) - for key in all_key: - assert key in res - assert key in exp - assert set(res[key]) == set(exp[key]) - +def test_scatter_nd(x, reduction, test_flags, backend_fw, fn_name, on_device): + (val_dtype, ind_dtype, update_dtype), vals, ind, updates = x + shape = vals.shape + helpers.test_function( + input_dtypes=[ind_dtype, update_dtype], + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + indices=np.asarray(ind, dtype=ind_dtype), + updates=updates, + shape=shape, + reduction=reduction, + ) -# function_unsupported_devices_and_dtypes -@pytest.mark.parametrize( - "func", - [_composition_1, _composition_2], -) -def test_function_unsupported_devices(func, backend_fw): - with BackendHandler.update_backend(backend_fw) as ivy_backend: - res = ivy_backend.function_unsupported_devices_and_dtypes(func) - exp = func.test_unsupported_devices_and_dtypes.copy() - for dev in exp: - exp[dev] = exp[dev][backend_fw] - devs = list(exp.keys()) - for dev in devs: - if len(exp[dev]) == 0: - exp.pop(dev) - all_key = set(res.keys()).union(set(exp.keys())) - for key in all_key: - assert key in res - assert key in exp - assert set(res[key]) == set(exp[key]) +# Tests # +# ------# -# Still to Add # -# ---------------# +@given(fw_str=st.sampled_from(["numpy", "jax", "torch", "tensorflow"])) +def test_set_framework(fw_str): + ivy.set_backend(fw_str) + ivy.previous_backend() -@given(fw=st.sampled_from(["torch", "tensorflow", "numpy", "jax"])) -def test_current_backend_str(fw): - ivy.set_backend(fw) - assert ivy.current_backend_str() == fw - ivy.previous_backend() +# set_item +# TODO: add container and array instance methods +@handle_test( + fn_tree="functional.ivy.set_item", + ground_truth_backend="numpy", + dtypes_x_query_val=helpers.dtype_array_query_val( + available_dtypes=helpers.get_dtypes("valid"), + ), + copy=st.booleans(), + test_with_out=st.just(False), + test_gradients=st.just(False), + test_instance_method=st.just(False), + container_flags=st.just([False]), +) +def test_set_item( + dtypes_x_query_val, + copy, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtypes, x, query, val = dtypes_x_query_val + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + rtol_=1e-03, # needed 
only for the paddle backend + x=x, + query=query, + val=val, + copy=copy, + ) -# get_min_denominator -def test_get_min_denominator(): - assert ivy.min_denominator == 1e-12 +# set_min_base +@given(x=st.floats(allow_nan=False, allow_infinity=False)) +def test_set_min_base(x): + ivy.set_min_base(x) + assert ivy.min_base == x # set_min_denominator @@ -1715,16 +1665,47 @@ def test_set_min_denominator(x): assert ivy.min_denominator == x -# get_min_base -def test_get_min_base(): - assert ivy.min_base == 1e-5 +# set_queue_timeout +@given( + x=st.floats(allow_nan=False, allow_infinity=False), +) +def test_set_queue_timeout(x): + ivy.set_queue_timeout(x) + ret = ivy.queue_timeout + assert ret == x -# set_min_base -@given(x=st.floats(allow_nan=False, allow_infinity=False)) -def test_set_min_base(x): - ivy.set_min_base(x) - assert ivy.min_base == x +# set_tmp_dir +def test_set_tmp_dir(): + ivy.set_tmp_dir("/new_dir") + ret = ivy.tmp_dir + assert ret == "/new_dir" + + +# shape +# TODO: add container and array methods +@handle_test( + fn_tree="functional.ivy.shape", + x0_n_x1_n_res=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid") + ), + as_array=st.booleans(), + test_with_out=st.just(False), + test_instance_method=st.just(False), + test_gradients=st.just(False), +) +def test_shape(x0_n_x1_n_res, as_array, test_flags, backend_fw, fn_name, on_device): + dtype, x = x0_n_x1_n_res + # instance_method=False because the shape property would overwrite the shape method + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + as_array=as_array, + ) # stable_divide @@ -1754,15 +1735,6 @@ def test_stable_divide(*, dtype_and_x, test_flags, backend_fw, fn_name, on_devic ) -@st.composite # ToDo remove when helpers.get_dtypes supports it -def _get_valid_numeric_no_unsigned(draw): - return list( - set(draw(helpers.get_dtypes("numeric"))).difference( - draw(helpers.get_dtypes("unsigned")) - ) - ) - - # stable_pow @handle_test( fn_tree="functional.ivy.stable_pow", @@ -1791,49 +1763,23 @@ def test_stable_pow( ) -def test_get_all_arrays_in_memory(): - return - - -def test_num_arrays_in_memory(): - return - - -def test_print_all_arrays_in_memory(): - return - - -# set_queue_timeout -@given( - x=st.floats(allow_nan=False, allow_infinity=False), -) -def test_set_queue_timeout(x): - ivy.set_queue_timeout(x) - ret = ivy.queue_timeout - assert ret == x - - -# get_queue_timeout -@given( - x=st.floats(allow_nan=False, allow_infinity=False), +@handle_test( + fn_tree="functional.ivy.strides", + x_and_dtype=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), + test_instance_method=st.just(False), + test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_get_queue_timeout(x): - ivy.set_queue_timeout(x) - ret = ivy.queue_timeout - assert ret == x - - -# get_tmp_dir -def test_get_tmp_dir(): - ret = ivy.tmp_dir - assert ret == "/tmp" - - -# set_tmp_dir -def test_set_tmp_dir(): - ivy.set_tmp_dir("/new_dir") - ret = ivy.tmp_dir - assert ret == "/new_dir" +def test_strides(x_and_dtype, test_flags, backend_fw, fn_name, on_device): + dtype, x = x_and_dtype + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + ) @handle_test( @@ -1844,37 +1790,58 @@ def test_set_tmp_dir(): test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_supports_inplace_updates( - 
x_val_and_dtypes, test_flags, backend_fw, fn_name, on_device -): - dtype, x = x_val_and_dtypes +def test_supports_inplace_updates( + x_val_and_dtypes, test_flags, backend_fw, fn_name, on_device +): + dtype, x = x_val_and_dtypes + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + test_values=False, + x=x[0], + ) + + +# to_list +@handle_test( + fn_tree="functional.ivy.to_list", + x0_n_x1_n_res=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + large_abs_safety_factor=20, + ), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_to_list(x0_n_x1_n_res, test_flags, backend_fw, fn_name, on_device): + dtype, x = x0_n_x1_n_res helpers.test_function( input_dtypes=dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - test_values=False, x=x[0], ) +# to_numpy @handle_test( - fn_tree="functional.ivy.assert_supports_inplace", - x_val_and_dtypes=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid") + fn_tree="functional.ivy.to_numpy", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), ), - ground_truth_backend="numpy", + copy=st.booleans(), test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_assert_supports_inplace( - x_val_and_dtypes, test_flags, backend_fw, fn_name, on_device -): - dtype, x = x_val_and_dtypes - if backend_fw in ["tensorflow", "jax", "paddle"]: +def test_to_numpy(*, dtype_x, copy, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + # torch throws an exception + if ivy.current_backend_str() == "torch" and not copy: return - assume("bfloat16" not in dtype) helpers.test_function( input_dtypes=dtype, test_flags=test_flags, @@ -1882,23 +1849,87 @@ def test_assert_supports_inplace( backend_to_test=backend_fw, fn_name=fn_name, x=x[0], + copy=copy, ) -def test_arg_info(): - return +# to_scalar +@handle_test( + fn_tree="functional.ivy.to_scalar", + x0_n_x1_n_res=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=1, + large_abs_safety_factor=20, + ), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_to_scalar(x0_n_x1_n_res, test_flags, backend_fw, fn_name, on_device): + dtype, x = x0_n_x1_n_res + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + ) -def _fn1(x, y): - return ivy.matmul(x, y) +# try_else_none +@given( + x=st.booleans(), +) +def test_try_else_none(x): + if x: + fn = ivy.try_else_none(lambda: True) + assert fn() is True + else: + fn = ivy.try_else_none(lambda x: x) + assert fn is None -def _fn2(x, y): - return ivy.vecdot(x, y) +def test_use_within_use_framework(): + with ivy.functional.backends.numpy.use: + pass + with ivy.functional.backends.jax.use: + pass + with ivy.functional.backends.tensorflow.use: + pass + with ivy.functional.backends.torch.use: + pass -def _fn3(x, y): - ivy.add(x, y) +# value_is_nan +@handle_test( + fn_tree="functional.ivy.value_is_nan", + val_dtype=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + max_dim_size=1, + max_num_dims=1, + allow_nan=True, + allow_inf=True, + ), + include_infs=st.booleans(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_value_is_nan( + *, val_dtype, include_infs, test_flags, 
backend_fw, fn_name, on_device +): + dtype, val = val_dtype + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=val[0], + include_infs=include_infs, + ) # vmap @@ -1973,89 +2004,61 @@ def test_vmap(func, dtype_and_arrays_and_axes, in_axes_as_cont, backend_fw): assert False, "One of the results is None while other isn't" -@st.composite -def _isin_data_generation_helper(draw): - assume_unique = draw(st.booleans()) - if assume_unique: - dtype_and_x = helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - ).filter(lambda x: np.array_equal(x[1][0], np.unique(x[1][0]))) - else: - dtype_and_x = helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=2, - shared_dtype=True, - ) - return assume_unique, draw(dtype_and_x) - - -@handle_test( - fn_tree="functional.ivy.isin", - assume_unique_and_dtype_and_x=_isin_data_generation_helper(), - invert=st.booleans(), - ground_truth_backend="numpy", - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_isin( - assume_unique_and_dtype_and_x, - invert, - test_flags, - backend_fw, - on_device, -): - assume_unique, x_and_dtype = assume_unique_and_dtype_and_x - dtypes, values = x_and_dtype - elements, test_elements = values - helpers.test_function( - input_dtypes=dtypes, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name="isin", - elements=elements, - test_elements=test_elements, - invert=invert, - assume_unique=assume_unique, - ) - - -@handle_test( - fn_tree="functional.ivy.itemsize", - x_and_dtype=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - ground_truth_backend="numpy", - test_instance_method=st.just(False), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_itemsize(x_and_dtype, test_flags, backend_fw, fn_name, on_device): - dtype, x = x_and_dtype - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - ) - - -@handle_test( - fn_tree="functional.ivy.strides", - x_and_dtype=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), - test_instance_method=st.just(False), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_strides(x_and_dtype, test_flags, backend_fw, fn_name, on_device): - dtype, x = x_and_dtype - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - ) +_composition_1.test_unsupported_devices_and_dtypes = { + "cpu": { + "numpy": ("bfloat16",), + "jax": ("complex64", "complex128"), + "tensorflow": ("complex64", "complex128"), + "torch": ( + "uint16", + "uint32", + "uint64", + "float16", + "complex64", + "complex128", + ), + "paddle": ("uint16", "uint32", "uint64", "bfloat16", "complex64", "complex128"), + }, + "gpu": { + "numpy": ivy.all_dtypes, + "jax": ("complex64", "complex128"), + "tensorflow": ("complex64", "complex128"), + "torch": ("complex64", "float16", "uint16", "complex128", "uint64", "uint32"), + "paddle": ivy.all_dtypes, + }, + "tpu": { + "numpy": ivy.all_dtypes, + "jax": ivy.all_dtypes, + "tensorflow": ivy.all_dtypes, + "torch": ivy.all_dtypes, + "paddle": ivy.all_dtypes, + }, +} +_composition_2.test_unsupported_devices_and_dtypes = { + "cpu": { + "numpy": ("bfloat16", "complex64", "complex128"), + 
"jax": ("complex64", "complex128"), + "tensorflow": ("complex64", "complex128"), + "torch": ("uint16", "uint32", "uint64", "float16", "complex64", "complex128"), + "paddle": ( + "uint16", + "uint32", + "uint64", + "bfloat16", + ), + }, + "gpu": { + "numpy": ivy.all_dtypes, + "jax": ("complex64", "complex128"), + "tensorflow": ("complex64", "complex128"), + "torch": ("uint16", "uint64", "uint32", "complex128", "float16", "complex64"), + "paddle": ivy.all_dtypes, + }, + "tpu": { + "numpy": ivy.all_dtypes, + "jax": ivy.all_dtypes, + "tensorflow": ivy.all_dtypes, + "torch": ivy.all_dtypes, + "paddle": ivy.all_dtypes, + }, +} diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_gradients.py b/ivy_tests/test_ivy/test_functional/test_core/test_gradients.py index 89ec0dbce556f..7632845ff4768 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_gradients.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_gradients.py @@ -11,6 +11,10 @@ from ivy_tests.test_ivy.helpers import handle_test, BackendHandler +# --- Helpers --- # +# --------------- # + + @st.composite def get_gradient_arguments_with_lr( draw, @@ -69,28 +73,106 @@ def get_gradient_arguments_with_lr( return dtypes, arrays, lr -# stop_gradient +# adam_step @handle_test( - fn_tree="functional.ivy.stop_gradient", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + fn_tree="functional.ivy.adam_step", + dtype_n_dcdw_n_mw_n_vw=get_gradient_arguments_with_lr( + num_arrays=3, + no_lr=True, + min_value=1e-05, + max_value=1e08, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + ), + step=helpers.ints(min_value=1, max_value=3), + beta1_n_beta2_n_epsilon=helpers.list_of_size( + x=helpers.floats(min_value=1e-1, max_value=1), + size=3, ), - preserve_type=st.booleans(), - test_instance_method=st.just(False), - test_gradients=st.just(False), ) -def test_stop_gradient( - *, dtype_and_x, preserve_type, test_flags, backend_fw, fn_name, on_device +def test_adam_step( + *, + dtype_n_dcdw_n_mw_n_vw, + step, + beta1_n_beta2_n_epsilon, + test_flags, + backend_fw, + fn_name, + on_device, ): - dtype, x = dtype_and_x + input_dtypes, [dcdw, mw, vw] = dtype_n_dcdw_n_mw_n_vw + ( + beta1, + beta2, + epsilon, + ) = beta1_n_beta2_n_epsilon helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtypes, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x[0], - preserve_type=preserve_type, + rtol_=1e-1, + atol_=1e-1, + dcdw=dcdw, + mw=mw, + vw=vw, + step=step, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + ) + + +# adam_update +@handle_test( + fn_tree="functional.ivy.adam_update", + dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr=get_gradient_arguments_with_lr( + num_arrays=4, + min_value=1e-05, + max_value=1e08, + large_abs_safety_factor=2.0, + small_abs_safety_factor=2.0, + ), + step=st.integers(min_value=1, max_value=10), + beta1_n_beta2_n_epsilon=helpers.list_of_size( + x=helpers.floats(min_value=1e-2, max_value=1), + size=3, + ), + stopgrad=st.booleans(), +) +def test_adam_update( + *, + dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr, + step, + beta1_n_beta2_n_epsilon, + stopgrad, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtypes, [w, dcdw, mw_tm1, vw_tm1], lr = dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr + beta1, beta2, epsilon = beta1_n_beta2_n_epsilon + stop_gradients = stopgrad + helpers.test_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + 
rtol_=1e-2, + atol_=1e-2, + w=w, + dcdw=dcdw, + lr=lr, + mw_tm1=mw_tm1, + vw_tm1=vw_tm1, + step=step, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + stop_gradients=stop_gradients, ) @@ -143,7 +225,7 @@ def func(xs): ) -# value_and_grad +# grad @pytest.mark.parametrize( "x", [[[4.6, 2.1, 5], [2.8, 1.3, 6.2]], [[4.6, 2.1], [5, 2.8], [1.3, 6.2]]] ) @@ -151,37 +233,71 @@ def func(xs): @pytest.mark.parametrize( "func", [lambda x: ivy.mean(ivy.square(x)), lambda x: ivy.mean(ivy.cos(x))] ) -def test_value_and_grad(x, dtype, func, backend_fw): - if backend_fw == "numpy": +@pytest.mark.parametrize("nth", [1, 2, 3]) +def test_grad(x, dtype, func, backend_fw, nth): + # ToDo: Remove skipping for paddle and jax for nth > 1 + if backend_fw == "numpy" or ( + (backend_fw == "paddle" or backend_fw == "jax") and nth > 1 + ): return + with BackendHandler.update_backend(backend_fw) as ivy_backend: - var = ivy_backend.ivy.functional.ivy.gradients._variable( - ivy_backend.array(x, dtype=dtype) - ) - fn = ivy_backend.value_and_grad(func) - value, grad = fn(var) - value_np, grad_np = helpers.flatten_and_to_np( - ret=value, backend=backend_fw - ), helpers.flatten_and_to_np(ret=grad, backend=backend_fw) + _variable_fn = ivy_backend.ivy.functional.ivy.gradients._variable + var = _variable_fn(ivy_backend.array(x, dtype=dtype)) + fn = ivy_backend.grad(func) + if nth > 1: + for _ in range(1, nth): + fn = ivy_backend.grad(fn) + grad = fn(var) + grad_np = helpers.flatten_and_to_np(ret=grad, backend=backend_fw) with BackendHandler.update_backend("tensorflow") as gt_backend: - var = gt_backend.ivy.functional.ivy.gradients._variable( - gt_backend.array(x, dtype=dtype) - ) - fn = gt_backend.value_and_grad(func) - value_gt, grad_gt = fn(var) - value_np_from_gt, grad_np_from_gt = helpers.flatten_and_to_np( - ret=value_gt, backend="tensorflow" - ), helpers.flatten_and_to_np(ret=grad_gt, backend="tensorflow") + _variable_fn = gt_backend.ivy.functional.ivy.gradients._variable + var = _variable_fn(ivy.array(x, dtype=dtype)) + fn = gt_backend.grad(func) + if nth > 1: + for _ in range(1, nth): + fn = gt_backend.grad(fn) + + grad_gt = fn(var) + grad_np_from_gt = helpers.flatten_and_to_np(ret=grad_gt, backend="tensorflow") - for value, value_from_gt in zip(value_np, value_np_from_gt): - assert value.shape == value_from_gt.shape - assert np.allclose(value, value_from_gt) for grad, grad_from_gt in zip(grad_np, grad_np_from_gt): assert grad.shape == grad_from_gt.shape assert np.allclose(grad, grad_from_gt) +# gradient_descent_update +@handle_test( + fn_tree="functional.ivy.gradient_descent_update", + dtype_n_ws_n_dcdw_n_lr=get_gradient_arguments_with_lr(num_arrays=2), + stop_gradients=st.booleans(), +) +def test_gradient_descent_update( + *, + dtype_n_ws_n_dcdw_n_lr, + stop_gradients, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtypes, [w, dcdw], lr = dtype_n_ws_n_dcdw_n_lr + helpers.test_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + w=w, + dcdw=dcdw, + lr=lr, + stop_gradients=stop_gradients, + ) + + # jac @pytest.mark.parametrize( "x", [[[4.6, 2.1, 5], [2.8, 1.3, 6.2]], [[4.6, 2.1], [5, 2.8], [1.3, 6.2]]] @@ -266,81 +382,51 @@ def test_jac(x, dtype, func_str, backend_fw): assert np.allclose(jacobian, jacobian_from_gt) -# grad -@pytest.mark.parametrize( - "x", [[[4.6, 2.1, 5], [2.8, 1.3, 6.2]], [[4.6, 2.1], [5, 2.8], [1.3, 6.2]]] -) -@pytest.mark.parametrize("dtype", ["float32", "float64"]) 
-@pytest.mark.parametrize( - "func", [lambda x: ivy.mean(ivy.square(x)), lambda x: ivy.mean(ivy.cos(x))] -) -@pytest.mark.parametrize("nth", [1, 2, 3]) -def test_grad(x, dtype, func, backend_fw, nth): - # ToDo: Remove skipping for paddle and jax for nth > 1 - if backend_fw == "numpy" or ( - (backend_fw == "paddle" or backend_fw == "jax") and nth > 1 - ): - return - - with BackendHandler.update_backend(backend_fw) as ivy_backend: - _variable_fn = ivy_backend.ivy.functional.ivy.gradients._variable - var = _variable_fn(ivy_backend.array(x, dtype=dtype)) - fn = ivy_backend.grad(func) - if nth > 1: - for _ in range(1, nth): - fn = ivy_backend.grad(fn) - grad = fn(var) - grad_np = helpers.flatten_and_to_np(ret=grad, backend=backend_fw) - - with BackendHandler.update_backend("tensorflow") as gt_backend: - _variable_fn = gt_backend.ivy.functional.ivy.gradients._variable - var = _variable_fn(ivy.array(x, dtype=dtype)) - fn = gt_backend.grad(func) - if nth > 1: - for _ in range(1, nth): - fn = gt_backend.grad(fn) - - grad_gt = fn(var) - grad_np_from_gt = helpers.flatten_and_to_np(ret=grad_gt, backend="tensorflow") - - for grad, grad_from_gt in zip(grad_np, grad_np_from_gt): - assert grad.shape == grad_from_gt.shape - assert np.allclose(grad, grad_from_gt) - - -# adam_step +# lamb_update @handle_test( - fn_tree="functional.ivy.adam_step", - dtype_n_dcdw_n_mw_n_vw=get_gradient_arguments_with_lr( - num_arrays=3, - no_lr=True, - min_value=1e-05, - max_value=1e08, - large_abs_safety_factor=2, - small_abs_safety_factor=2, + fn_tree="functional.ivy.lamb_update", + dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr=get_gradient_arguments_with_lr( + min_value=-1e5, + max_value=1e5, + num_arrays=4, ), - step=helpers.ints(min_value=1, max_value=3), - beta1_n_beta2_n_epsilon=helpers.list_of_size( - x=helpers.floats(min_value=1e-1, max_value=1), - size=3, + step=helpers.ints(min_value=1, max_value=100), + beta1_n_beta2_n_epsilon_n_lambda=helpers.list_of_size( + x=helpers.floats( + min_value=1e-2, + max_value=1.0, + ), + size=4, + ), + mtr=st.one_of( + helpers.ints(min_value=1, max_value=10), + st.floats(min_value=1e-2, max_value=10, exclude_min=True), ), + stopgrad=st.booleans(), ) -def test_adam_step( +def test_lamb_update( *, - dtype_n_dcdw_n_mw_n_vw, + dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr, step, - beta1_n_beta2_n_epsilon, + beta1_n_beta2_n_epsilon_n_lambda, + mtr, + stopgrad, test_flags, backend_fw, fn_name, on_device, ): - input_dtypes, [dcdw, mw, vw] = dtype_n_dcdw_n_mw_n_vw + input_dtypes, [w, dcdw, mw_tm1, vw_tm1], lr = dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr ( beta1, beta2, epsilon, - ) = beta1_n_beta2_n_epsilon + decay_lambda, + ) = beta1_n_beta2_n_epsilon_n_lambda + max_trust_ratio, stop_gradients = mtr, stopgrad + # ToDo: enable gradient tests for jax once the issue with jacrev is resolved + if backend_fw == "jax": + test_flags.test_gradients = False helpers.test_function( input_dtypes=input_dtypes, test_flags=test_flags, @@ -349,74 +435,17 @@ def test_adam_step( on_device=on_device, rtol_=1e-1, atol_=1e-1, + w=w, dcdw=dcdw, - mw=mw, - vw=vw, + lr=lr, + mw_tm1=mw_tm1, + vw_tm1=vw_tm1, step=step, beta1=beta1, beta2=beta2, epsilon=epsilon, - ) - - -# optimizer_update -@handle_test( - fn_tree="functional.ivy.optimizer_update", - dtype_n_ws_n_effgrad_n_lr=get_gradient_arguments_with_lr(num_arrays=2), - stop_gradients=st.booleans(), -) -def test_optimizer_update( - *, - dtype_n_ws_n_effgrad_n_lr, - stop_gradients, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtypes, [w, effective_grad], lr = 
dtype_n_ws_n_effgrad_n_lr - helpers.test_function( - input_dtypes=input_dtypes, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - w=w, - effective_grad=effective_grad, - lr=lr, - stop_gradients=stop_gradients, - ) - - -# gradient_descent_update -@handle_test( - fn_tree="functional.ivy.gradient_descent_update", - dtype_n_ws_n_dcdw_n_lr=get_gradient_arguments_with_lr(num_arrays=2), - stop_gradients=st.booleans(), -) -def test_gradient_descent_update( - *, - dtype_n_ws_n_dcdw_n_lr, - stop_gradients, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtypes, [w, dcdw], lr = dtype_n_ws_n_dcdw_n_lr - helpers.test_function( - input_dtypes=input_dtypes, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - w=w, - dcdw=dcdw, - lr=lr, + max_trust_ratio=max_trust_ratio, + decay_lambda=decay_lambda, stop_gradients=stop_gradients, ) @@ -460,121 +489,96 @@ def test_lars_update( ) -# adam_update +# optimizer_update @handle_test( - fn_tree="functional.ivy.adam_update", - dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr=get_gradient_arguments_with_lr( - num_arrays=4, - min_value=1e-05, - max_value=1e08, - large_abs_safety_factor=2.0, - small_abs_safety_factor=2.0, - ), - step=st.integers(min_value=1, max_value=10), - beta1_n_beta2_n_epsilon=helpers.list_of_size( - x=helpers.floats(min_value=1e-2, max_value=1), - size=3, - ), - stopgrad=st.booleans(), + fn_tree="functional.ivy.optimizer_update", + dtype_n_ws_n_effgrad_n_lr=get_gradient_arguments_with_lr(num_arrays=2), + stop_gradients=st.booleans(), ) -def test_adam_update( +def test_optimizer_update( *, - dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr, - step, - beta1_n_beta2_n_epsilon, - stopgrad, + dtype_n_ws_n_effgrad_n_lr, + stop_gradients, test_flags, backend_fw, fn_name, on_device, ): - input_dtypes, [w, dcdw, mw_tm1, vw_tm1], lr = dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr - beta1, beta2, epsilon = beta1_n_beta2_n_epsilon - stop_gradients = stopgrad + input_dtypes, [w, effective_grad], lr = dtype_n_ws_n_effgrad_n_lr helpers.test_function( input_dtypes=input_dtypes, - test_flags=test_flags, backend_to_test=backend_fw, + test_flags=test_flags, fn_name=fn_name, on_device=on_device, rtol_=1e-2, atol_=1e-2, w=w, - dcdw=dcdw, + effective_grad=effective_grad, lr=lr, - mw_tm1=mw_tm1, - vw_tm1=vw_tm1, - step=step, - beta1=beta1, - beta2=beta2, - epsilon=epsilon, stop_gradients=stop_gradients, ) -# lamb_update +# stop_gradient @handle_test( - fn_tree="functional.ivy.lamb_update", - dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr=get_gradient_arguments_with_lr( - min_value=-1e5, - max_value=1e5, - num_arrays=4, - ), - step=helpers.ints(min_value=1, max_value=100), - beta1_n_beta2_n_epsilon_n_lambda=helpers.list_of_size( - x=helpers.floats( - min_value=1e-2, - max_value=1.0, - ), - size=4, - ), - mtr=st.one_of( - helpers.ints(min_value=1, max_value=10), - st.floats(min_value=1e-2, max_value=10, exclude_min=True), + fn_tree="functional.ivy.stop_gradient", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric") ), - stopgrad=st.booleans(), + preserve_type=st.booleans(), + test_instance_method=st.just(False), + test_gradients=st.just(False), ) -def test_lamb_update( - *, - dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr, - step, - beta1_n_beta2_n_epsilon_n_lambda, - mtr, - stopgrad, - test_flags, - backend_fw, - fn_name, - on_device, +def test_stop_gradient( + *, dtype_and_x, preserve_type, test_flags, 
backend_fw, fn_name, on_device ): - input_dtypes, [w, dcdw, mw_tm1, vw_tm1], lr = dtype_n_ws_n_dcdw_n_mwtm1_n_vwtm1_n_lr - ( - beta1, - beta2, - epsilon, - decay_lambda, - ) = beta1_n_beta2_n_epsilon_n_lambda - max_trust_ratio, stop_gradients = mtr, stopgrad - # ToDo: enable gradient tests for jax once the issue with jacrev is resolved - if backend_fw == "jax": - test_flags.test_gradients = False + dtype, x = dtype_and_x helpers.test_function( - input_dtypes=input_dtypes, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - w=w, - dcdw=dcdw, - lr=lr, - mw_tm1=mw_tm1, - vw_tm1=vw_tm1, - step=step, - beta1=beta1, - beta2=beta2, - epsilon=epsilon, - max_trust_ratio=max_trust_ratio, - decay_lambda=decay_lambda, - stop_gradients=stop_gradients, + x=x[0], + preserve_type=preserve_type, ) + + +# value_and_grad +@pytest.mark.parametrize( + "x", [[[4.6, 2.1, 5], [2.8, 1.3, 6.2]], [[4.6, 2.1], [5, 2.8], [1.3, 6.2]]] +) +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +@pytest.mark.parametrize( + "func", [lambda x: ivy.mean(ivy.square(x)), lambda x: ivy.mean(ivy.cos(x))] +) +def test_value_and_grad(x, dtype, func, backend_fw): + if backend_fw == "numpy": + return + with BackendHandler.update_backend(backend_fw) as ivy_backend: + var = ivy_backend.ivy.functional.ivy.gradients._variable( + ivy_backend.array(x, dtype=dtype) + ) + fn = ivy_backend.value_and_grad(func) + value, grad = fn(var) + value_np, grad_np = helpers.flatten_and_to_np( + ret=value, backend=backend_fw + ), helpers.flatten_and_to_np(ret=grad, backend=backend_fw) + + with BackendHandler.update_backend("tensorflow") as gt_backend: + var = gt_backend.ivy.functional.ivy.gradients._variable( + gt_backend.array(x, dtype=dtype) + ) + fn = gt_backend.value_and_grad(func) + value_gt, grad_gt = fn(var) + value_np_from_gt, grad_np_from_gt = helpers.flatten_and_to_np( + ret=value_gt, backend="tensorflow" + ), helpers.flatten_and_to_np(ret=grad_gt, backend="tensorflow") + + for value, value_from_gt in zip(value_np, value_np_from_gt): + assert value.shape == value_from_gt.shape + assert np.allclose(value, value_from_gt) + for grad, grad_from_gt in zip(grad_np, grad_np_from_gt): + assert grad.shape == grad_from_gt.shape + assert np.allclose(grad, grad_from_gt) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py index e3a08d3ad8a7b..bc4664e6f3eac 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py @@ -13,65 +13,104 @@ ) +# --- Helpers --- # +# --------------- # + + @st.composite -def dtype_value1_value2_axis( - draw, - available_dtypes, - abs_smallest_val=None, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - min_num_dims=1, - max_num_dims=10, - min_dim_size=1, - max_dim_size=10, - specific_dim_size=3, - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", -): - # For cross product, a dim with size 3 is required - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, +def _det_helper(draw): + square = draw(helpers.ints(min_value=2, max_value=8).map(lambda x: tuple([x, x]))) + shape_prefix = draw(helpers.get_shape()) + dtype_x = draw( + helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("float"), + min_value=2, + max_value=5, + shape=shape_prefix + square, ) ) - axis = draw(helpers.ints(min_value=0, max_value=len(shape))) - # make sure there is a dim with specific dim size - shape = list(shape) - shape = shape[:axis] + [specific_dim_size] + shape[axis:] - shape = tuple(shape) + return dtype_x - dtype = draw(st.sampled_from(draw(available_dtypes))) - values = [] - for i in range(2): - values.append( +@st.composite +def _diag_helper(draw): + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + small_abs_safety_factor=2, + large_abs_safety_factor=2, + safety_factor_scale="log", + min_num_dims=1, + max_num_dims=2, + min_dim_size=1, + max_dim_size=50, + ) + ) + shape = x[0].shape + if len(shape) == 2: + k = draw(helpers.ints(min_value=-shape[0] + 1, max_value=shape[1] - 1)) + else: + k = draw(helpers.ints(min_value=0, max_value=shape[0])) + return dtype, x, k + + +@st.composite +def _get_dtype_and_matrix(draw, *, symmetric=False): + # batch_shape, shared, random_size + input_dtype = draw(st.shared(st.sampled_from(draw(helpers.get_dtypes("float"))))) + random_size = draw(helpers.ints(min_value=2, max_value=4)) + batch_shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=3)) + if symmetric: + num_independnt_vals = int((random_size**2) / 2 + random_size / 2) + array_vals_flat = np.array( draw( helpers.array_values( - dtype=dtype, - shape=shape, - abs_smallest_val=abs_smallest_val, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, + dtype=input_dtype, + shape=tuple(list(batch_shape) + [num_independnt_vals]), + min_value=2, + max_value=5, ) ) ) + array_vals = np.zeros(batch_shape + (random_size, random_size)) + c = 0 + for i in range(random_size): + for j in range(random_size): + if j < i: + continue + array_vals[..., i, j] = array_vals_flat[..., c] + array_vals[..., j, i] = array_vals_flat[..., c] + c += 1 + return [input_dtype], array_vals + return [input_dtype], draw( + helpers.array_values( + dtype=input_dtype, + shape=tuple(list(batch_shape) + [random_size, random_size]), + min_value=2, + max_value=5, + ) + ) - value1, value2 = values[0], values[1] - return [dtype], value1, value2, axis + +# vector_to_skew_symmetric_matrix +@st.composite +def _get_dtype_and_vector(draw): + # batch_shape, shared, random_size + input_dtype = draw( + st.shared( + st.sampled_from(draw(helpers.get_dtypes("numeric"))), + key="shared_dtype", + ) + ) + batch_shape = draw(helpers.get_shape(min_num_dims=2, max_num_dims=4)) + return [input_dtype], draw( + helpers.array_values( + dtype=input_dtype, + shape=tuple(list(batch_shape) + [3]), + min_value=2, + max_value=5, + ) + ) @st.composite @@ -128,44 +167,6 @@ def _get_dtype_value1_value2_axis_for_tensordot( return [dtype], value1, value2, axis -@st.composite -def _get_dtype_and_matrix(draw, *, symmetric=False): - # batch_shape, shared, random_size - input_dtype = draw(st.shared(st.sampled_from(draw(helpers.get_dtypes("float"))))) - random_size = draw(helpers.ints(min_value=2, max_value=4)) - batch_shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=3)) - if symmetric: - num_independnt_vals = int((random_size**2) / 2 + random_size / 2) - array_vals_flat = np.array( - draw( - helpers.array_values( - dtype=input_dtype, - shape=tuple(list(batch_shape) 
+ [num_independnt_vals]), - min_value=2, - max_value=5, - ) - ) - ) - array_vals = np.zeros(batch_shape + (random_size, random_size)) - c = 0 - for i in range(random_size): - for j in range(random_size): - if j < i: - continue - array_vals[..., i, j] = array_vals_flat[..., c] - array_vals[..., j, i] = array_vals_flat[..., c] - c += 1 - return [input_dtype], array_vals - return [input_dtype], draw( - helpers.array_values( - dtype=input_dtype, - shape=tuple(list(batch_shape) + [random_size, random_size]), - min_value=2, - max_value=5, - ) - ) - - @st.composite def _get_first_matrix_and_dtype(draw, *, transpose=False, conjugate=False): # batch_shape, random_size, shared @@ -241,111 +242,170 @@ def _get_second_matrix_and_dtype(draw, *, transpose=False): return [input_dtype], matrix -# vector_to_skew_symmetric_matrix @st.composite -def _get_dtype_and_vector(draw): - # batch_shape, shared, random_size - input_dtype = draw( - st.shared( - st.sampled_from(draw(helpers.get_dtypes("numeric"))), - key="shared_dtype", - ) +def _matrix_rank_helper(draw): + _batch_shape = draw( + helpers.get_shape(min_num_dims=1, max_num_dims=3, min_dim_size=1) ) - batch_shape = draw(helpers.get_shape(min_num_dims=2, max_num_dims=4)) - return [input_dtype], draw( - helpers.array_values( - dtype=input_dtype, - shape=tuple(list(batch_shape) + [3]), - min_value=2, - max_value=5, + _batch_dim = draw(st.sampled_from([(), _batch_shape])) + _matrix_dim = draw(helpers.ints(min_value=2, max_value=20)) + shape = _batch_dim + (_matrix_dim, _matrix_dim) + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=shape, + min_value=-1e05, + max_value=1e05, + abs_smallest_val=1e-05, + safety_factor_scale="log", ) ) + if np.all(np.swapaxes(x[0], -1, -2) == x[0]): + hermitian = True + else: + hermitian = False - -@handle_test( - fn_tree="functional.ivy.vector_to_skew_symmetric_matrix", - dtype_x=_get_dtype_and_vector(), -) -def test_vector_to_skew_symmetric_matrix( - *, dtype_x, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x = dtype_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - vector=x, + tol_strategy = st.one_of( + st.none(), + st.floats(allow_nan=False, allow_infinity=False), + helpers.array_values( + dtype=helpers.get_dtypes("float", prune_function=False), + shape=_batch_shape, + min_value=-1e05, + max_value=1e05, + abs_smallest_val=1e-05, + safety_factor_scale="log", + ), ) + atol = draw(tol_strategy) + rtol = draw(tol_strategy) + return dtype, x[0], hermitian, atol, rtol -# matrix_power -@handle_test( - fn_tree="functional.ivy.matrix_power", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1e-3, - max_value=20, - shape=helpers.ints(min_value=2, max_value=8).map(lambda x: tuple([x, x])), - ), - n=helpers.ints(min_value=-6, max_value=6), -) -def test_matrix_power(*, dtype_x, n, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - assume(matrix_is_stable(x[0])) - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - x=x[0], - n=n, +@st.composite +def dtype_value1_value2_axis( + draw, + available_dtypes, + abs_smallest_val=None, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + min_num_dims=1, + max_num_dims=10, + min_dim_size=1, + 
max_dim_size=10, + specific_dim_size=3, + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", +): + # For cross product, a dim with size 3 is required + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) ) + axis = draw(helpers.ints(min_value=0, max_value=len(shape))) + # make sure there is a dim with specific dim size + shape = list(shape) + shape = shape[:axis] + [specific_dim_size] + shape[axis:] + shape = tuple(shape) + dtype = draw(st.sampled_from(draw(available_dtypes))) -# matmul + values = [] + for i in range(2): + values.append( + draw( + helpers.array_values( + dtype=dtype, + shape=shape, + abs_smallest_val=abs_smallest_val, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + ) + ) + ) + + value1, value2 = values[0], values[1] + return [dtype], value1, value2, axis + + +# --- Main --- # +# ------------ # + + +# cholesky +# execute with grads error @handle_test( - fn_tree="functional.ivy.matmul", - x=_get_first_matrix_and_dtype(transpose=True), - y=_get_second_matrix_and_dtype(transpose=True), + fn_tree="functional.ivy.cholesky", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=10, + shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + ), + upper=st.booleans(), ) -def test_matmul(*, x, y, test_flags, backend_fw, fn_name, on_device): - input_dtype1, x_1, transpose_a, adjoint_a = x - input_dtype2, y_1, transpose_b, adjoint_b = y +def test_cholesky(*, dtype_x, upper, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + x = x[0] + x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite + helpers.test_function( - input_dtypes=input_dtype1 + input_dtype2, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - x1=x_1, - x2=y_1, - transpose_a=transpose_a, - transpose_b=transpose_b, - adjoint_a=adjoint_a, - adjoint_b=adjoint_b, + x=x, + upper=upper, + rtol_=1e-3, + atol_=1e-3, ) -@st.composite -def _det_helper(draw): - square = draw(helpers.ints(min_value=2, max_value=8).map(lambda x: tuple([x, x]))) - shape_prefix = draw(helpers.get_shape()) - dtype_x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=2, - max_value=5, - shape=shape_prefix + square, - ) +# cross +@handle_test( + fn_tree="functional.ivy.cross", + dtype_x1_x2_axis=dtype_value1_value2_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=3, + max_dim_size=3, + min_value=-1e5, + max_value=1e5, + abs_smallest_val=0.01, + safety_factor_scale="log", + ), +) +def test_cross(*, dtype_x1_x2_axis, test_flags, backend_fw, fn_name, on_device): + dtype, x1, x2, axis = dtype_x1_x2_axis + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + x1=x1, + x2=x2, + axis=axis, ) - return dtype_x # det @@ -368,6 +428,54 @@ def test_det(*, dtype_x, test_flags, backend_fw, fn_name, on_device): ) +# diag +@handle_test( + fn_tree="functional.ivy.diag", + 
dtype_x_k=_diag_helper(), +) +def test_diag(*, dtype_x_k, test_flags, backend_fw, fn_name, on_device): + dtype, x, k = dtype_x_k + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + k=k, + ) + + +# diagonal +@handle_test( + fn_tree="functional.ivy.diagonal", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=2, + min_dim_size=1, + max_dim_size=50, + ), + offset=helpers.ints(min_value=-10, max_value=50), + axes=st.lists( + helpers.ints(min_value=-2, max_value=1), min_size=2, max_size=2, unique=True + ).filter(lambda axes: axes[0] % 2 != axes[1] % 2), +) +def test_diagonal(*, dtype_x, offset, axes, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + offset=offset, + axis1=axes[0], + axis2=axes[1], + ) + + # eigh @handle_test( fn_tree="functional.ivy.eigh", @@ -504,284 +612,159 @@ def test_inv(*, dtype_x, adjoint, test_flags, backend_fw, fn_name, on_device): ) -# matrix_transpose +# matmul @handle_test( - fn_tree="functional.ivy.matrix_transpose", - dtype_x=_get_first_matrix_and_dtype(conjugate=True), + fn_tree="functional.ivy.matmul", + x=_get_first_matrix_and_dtype(transpose=True), + y=_get_second_matrix_and_dtype(transpose=True), ) -def test_matrix_transpose(*, dtype_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x, conjugate = dtype_x +def test_matmul(*, x, y, test_flags, backend_fw, fn_name, on_device): + input_dtype1, x_1, transpose_a, adjoint_a = x + input_dtype2, y_1, transpose_b, adjoint_b = y helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=input_dtype1 + input_dtype2, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x, - conjugate=conjugate, + rtol_=1e-1, + atol_=1e-1, + x1=x_1, + x2=y_1, + transpose_a=transpose_a, + transpose_b=transpose_b, + adjoint_a=adjoint_a, + adjoint_b=adjoint_b, ) -# outer +# matrix_norm @handle_test( - fn_tree="functional.ivy.outer", - dtype_xy=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - min_value=1, - max_value=50, - min_num_dims=1, - max_num_dims=1, + fn_tree="functional.ivy.matrix_norm", + # ground_truth_backend="numpy", + dtype_value_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + valid_axis=True, + min_axes_size=2, + max_axes_size=2, + force_tuple_axis=True, + allow_neg_axes=False, ), + kd=st.booleans(), + ord=st.sampled_from((-2, -1, 1, 2, -float("inf"), float("inf"), "fro", "nuc")), ) -def test_outer(*, dtype_xy, test_flags, backend_fw, fn_name, on_device): - types, arrays = dtype_xy +def test_matrix_norm( + *, dtype_value_axis, kd, ord, test_flags, backend_fw, fn_name, on_device +): + dtype, x, axis = dtype_value_axis + assume(matrix_is_stable(x[0], cond_limit=10)) helpers.test_function( - input_dtypes=types, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x1=arrays[0], - x2=arrays[1], + rtol_=1e-1, + atol_=1e-2, + x=x[0], + axis=axis, + keepdims=kd, + ord=ord, ) -# slogdet -# TODO: add with_out testing when testing with tuples is supported -# execute with grads error +# matrix_power @handle_test( - fn_tree="functional.ivy.slogdet", - dtype_x=_det_helper(), - 
test_with_out=st.just(False), -) -def test_slogdet(*, dtype_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_x - assume(matrix_is_stable(x[0])) - ret_grad_idxs = ( - [[1, "a"], [1, "b", "c"], [1, "b", "d"]] if test_flags.container[0] else [[1]] - ) - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - rtol_=1e-1, - atol_=1e-2, - fn_name=fn_name, - on_device=on_device, - ret_grad_idxs=ret_grad_idxs, - x=x[0], - ) - - -@handle_test( - fn_tree="functional.ivy.solve", - x=helpers.get_first_solve_matrix(adjoint=True), - y=helpers.get_second_solve_matrix(), -) -def test_solve(*, x, y, test_flags, backend_fw, fn_name, on_device): - input_dtype1, x1, adjoint = x - input_dtype2, x2 = y - helpers.test_function( - input_dtypes=[input_dtype1, input_dtype2], - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - x1=x1, - x2=x2, - adjoint=adjoint, - ) - - -# svdvals -@handle_test( - fn_tree="functional.ivy.svdvals", + fn_tree="functional.ivy.matrix_power", dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=50, - min_num_dims=2, + min_value=1e-3, + max_value=20, + shape=helpers.ints(min_value=2, max_value=8).map(lambda x: tuple([x, x])), ), - test_gradients=st.just(False), + n=helpers.ints(min_value=-6, max_value=6), ) -def test_svdvals(*, dtype_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_x +def test_matrix_power(*, dtype_x, n, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + assume(matrix_is_stable(x[0])) helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, + rtol_=1e-1, + atol_=1e-1, x=x[0], + n=n, ) -# tensordot +# matrix_rank @handle_test( - fn_tree="functional.ivy.tensordot", - dtype_x1_x2_axis=_get_dtype_value1_value2_axis_for_tensordot( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), + fn_tree="functional.ivy.matrix_rank", + dtype_x_hermitian_atol_rtol=_matrix_rank_helper(), + ground_truth_backend="numpy", ) -def test_tensordot(*, dtype_x1_x2_axis, test_flags, backend_fw, fn_name, on_device): - ( - dtype, - x1, - x2, - axis, - ) = dtype_x1_x2_axis - +def test_matrix_rank( + *, dtype_x_hermitian_atol_rtol, test_flags, backend_fw, fn_name, on_device +): + dtype, x, hermitian, atol, rtol = dtype_x_hermitian_atol_rtol + assume(matrix_is_stable(x, cond_limit=10)) helpers.test_function( input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=0.8, - atol_=0.8, - x1=x1, - x2=x2, - axes=axis, + x=x, + atol=atol, + rtol=rtol, + hermitian=hermitian, ) -# trace +# matrix_transpose @handle_test( - fn_tree="functional.ivy.trace", - dtype_x_axes=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - valid_axis=True, - min_axes_size=2, - max_axes_size=2, - min_num_dims=2, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - ), - # TODO: test for more offsets - offset=st.integers(min_value=-3, max_value=3), + fn_tree="functional.ivy.matrix_transpose", + dtype_x=_get_first_matrix_and_dtype(conjugate=True), ) -def test_trace(*, dtype_x_axes, offset, test_flags, backend_fw, fn_name, on_device): - dtype, x, axes = dtype_x_axes - axis1, 
axis2 = axes +def test_matrix_transpose(*, dtype_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x, conjugate = dtype_x helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - x=x[0], - offset=offset, - axis1=axis1, - axis2=axis2, + x=x, + conjugate=conjugate, ) -# vecdot +# outer @handle_test( - fn_tree="functional.ivy.vecdot", - dtype_x1_x2_axis=dtype_value1_value2_axis( + fn_tree="functional.ivy.outer", + dtype_xy=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=100, - small_abs_safety_factor=100, - safety_factor_scale="log", + num_arrays=2, + min_value=1, + max_value=50, min_num_dims=1, - max_num_dims=4, - min_dim_size=1, - max_dim_size=4, - ), -) -def test_vecdot(*, dtype_x1_x2_axis, test_flags, backend_fw, fn_name, on_device): - dtype, x1, x2, axis = dtype_x1_x2_axis - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=5e-1, - atol_=5e-1, - x1=x1, - x2=x2, - axis=axis, - ) - - -# vector_norm -@handle_test( - fn_tree="functional.ivy.vector_norm", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - valid_axis=True, - min_value=-1e04, - max_value=1e04, - abs_smallest_val=1e-04, - max_axes_size=2, - allow_neg_axes=True, - ), - kd=st.booleans(), - ord=st.one_of( - helpers.ints(min_value=-5, max_value=5), - helpers.floats(min_value=-5, max_value=5.0), - st.sampled_from((float("inf"), -float("inf"))), + max_num_dims=1, ), - dtype=helpers.get_dtypes("numeric", full=False, none=True), ) -def test_vector_norm( - *, dtype_values_axis, kd, ord, dtype, test_flags, backend_fw, fn_name, on_device -): - x_dtype, x, axis = dtype_values_axis - # to avoid tuple axis with only one axis as force_int_axis can't generate - # axis with two axes - if isinstance(axis, tuple) and len(axis) == 1: - axis = axis[0] +def test_outer(*, dtype_xy, test_flags, backend_fw, fn_name, on_device): + types, arrays = dtype_xy helpers.test_function( - input_dtypes=x_dtype, + input_dtypes=types, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x[0], - axis=axis, - keepdims=kd, - ord=ord, - dtype=dtype[0], - atol_=1e-08, - ) - - # Specific value test to handle cases when ord is one of {inf, -inf} - - with BackendHandler.update_backend(backend_fw) as ivy_backend: - arr = ivy_backend.array([[1.0, 2.0, 3.0], [-1.0, 2.0, 4.0]]) - arr_normed_inf = ivy_backend.vector_norm(arr, axis=0, ord=float("inf")) - arr_normed_min_inf = ivy_backend.vector_norm(arr, axis=0, ord=float("-inf")) - - with BackendHandler.update_backend(test_flags.ground_truth_backend) as gt_backend: - gt_arr_normed_inf = gt_backend.array([1.0, 2.0, 4.0]) - gt_arr_normed_min_inf = gt_backend.array([1.0, 2.0, 3.0]) - - helpers.assert_all_close( - arr_normed_inf, - gt_arr_normed_inf, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) - helpers.assert_all_close( - arr_normed_min_inf, - gt_arr_normed_min_inf, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, + x1=arrays[0], + x2=arrays[1], ) @@ -860,31 +843,80 @@ def test_qr(*, dtype_x, mode, test_flags, backend_fw, fn_name, on_device): ) -# svd +# slogdet +# TODO: add with_out testing when testing with tuples is supported +# execute with grads error @handle_test( - 
fn_tree="functional.ivy.svd", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, - max_num_dims=5, - min_dim_size=2, - max_dim_size=5, - min_value=0.1, - max_value=10.0, - ), - fm=st.booleans(), - uv=st.booleans(), + fn_tree="functional.ivy.slogdet", + dtype_x=_det_helper(), test_with_out=st.just(False), - test_gradients=st.just(False), ) -def test_svd(*, dtype_x, uv, fm, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - - results = helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, +def test_slogdet(*, dtype_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_x + assume(matrix_is_stable(x[0])) + ret_grad_idxs = ( + [[1, "a"], [1, "b", "c"], [1, "b", "d"]] if test_flags.container[0] else [[1]] + ) + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + rtol_=1e-1, + atol_=1e-2, + fn_name=fn_name, + on_device=on_device, + ret_grad_idxs=ret_grad_idxs, + x=x[0], + ) + + +@handle_test( + fn_tree="functional.ivy.solve", + x=helpers.get_first_solve_matrix(adjoint=True), + y=helpers.get_second_solve_matrix(), +) +def test_solve(*, x, y, test_flags, backend_fw, fn_name, on_device): + input_dtype1, x1, adjoint = x + input_dtype2, x2 = y + helpers.test_function( + input_dtypes=[input_dtype1, input_dtype2], + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + x1=x1, + x2=x2, + adjoint=adjoint, + ) + + +# svd +@handle_test( + fn_tree="functional.ivy.svd", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + min_value=0.1, + max_value=10.0, + ), + fm=st.booleans(), + uv=st.booleans(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_svd(*, dtype_x, uv, fm, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + + results = helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, on_device=on_device, x=x[0], compute_uv=uv, @@ -956,152 +988,146 @@ def test_svd(*, dtype_x, uv, fm, test_flags, backend_fw, fn_name, on_device): ) -# matrix_norm +# svdvals @handle_test( - fn_tree="functional.ivy.matrix_norm", - # ground_truth_backend="numpy", - dtype_value_axis=helpers.dtype_values_axis( + fn_tree="functional.ivy.svdvals", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=50, min_num_dims=2, - valid_axis=True, - min_axes_size=2, - max_axes_size=2, - force_tuple_axis=True, - allow_neg_axes=False, ), - kd=st.booleans(), - ord=st.sampled_from((-2, -1, 1, 2, -float("inf"), float("inf"), "fro", "nuc")), + test_gradients=st.just(False), ) -def test_matrix_norm( - *, dtype_value_axis, kd, ord, test_flags, backend_fw, fn_name, on_device -): - dtype, x, axis = dtype_value_axis - assume(matrix_is_stable(x[0], cond_limit=10)) +def test_svdvals(*, dtype_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_x helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, + rtol_=1e-2, atol_=1e-2, x=x[0], - axis=axis, - keepdims=kd, - ord=ord, ) -@st.composite -def _matrix_rank_helper(draw): - _batch_shape = draw( - helpers.get_shape(min_num_dims=1, 
max_num_dims=3, min_dim_size=1) - ) - _batch_dim = draw(st.sampled_from([(), _batch_shape])) - _matrix_dim = draw(helpers.ints(min_value=2, max_value=20)) - shape = _batch_dim + (_matrix_dim, _matrix_dim) - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=shape, - min_value=-1e05, - max_value=1e05, - abs_smallest_val=1e-05, - safety_factor_scale="log", - ) - ) - if np.all(np.swapaxes(x[0], -1, -2) == x[0]): - hermitian = True - else: - hermitian = False +# tensordot +@handle_test( + fn_tree="functional.ivy.tensordot", + dtype_x1_x2_axis=_get_dtype_value1_value2_axis_for_tensordot( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), +) +def test_tensordot(*, dtype_x1_x2_axis, test_flags, backend_fw, fn_name, on_device): + ( + dtype, + x1, + x2, + axis, + ) = dtype_x1_x2_axis - tol_strategy = st.one_of( - st.none(), - st.floats(allow_nan=False, allow_infinity=False), - helpers.array_values( - dtype=helpers.get_dtypes("float", prune_function=False), - shape=_batch_shape, - min_value=-1e05, - max_value=1e05, - abs_smallest_val=1e-05, - safety_factor_scale="log", - ), + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=0.8, + atol_=0.8, + x1=x1, + x2=x2, + axes=axis, ) - atol = draw(tol_strategy) - rtol = draw(tol_strategy) - return dtype, x[0], hermitian, atol, rtol -# matrix_rank +# trace @handle_test( - fn_tree="functional.ivy.matrix_rank", - dtype_x_hermitian_atol_rtol=_matrix_rank_helper(), - ground_truth_backend="numpy", + fn_tree="functional.ivy.trace", + dtype_x_axes=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + valid_axis=True, + min_axes_size=2, + max_axes_size=2, + min_num_dims=2, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + ), + # TODO: test for more offsets + offset=st.integers(min_value=-3, max_value=3), ) -def test_matrix_rank( - *, dtype_x_hermitian_atol_rtol, test_flags, backend_fw, fn_name, on_device -): - dtype, x, hermitian, atol, rtol = dtype_x_hermitian_atol_rtol - assume(matrix_is_stable(x, cond_limit=10)) +def test_trace(*, dtype_x_axes, offset, test_flags, backend_fw, fn_name, on_device): + dtype, x, axes = dtype_x_axes + axis1, axis2 = axes helpers.test_function( input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x, - atol=atol, - rtol=rtol, - hermitian=hermitian, + rtol_=1e-1, + atol_=1e-1, + x=x[0], + offset=offset, + axis1=axis1, + axis2=axis2, ) -# cholesky -# execute with grads error +# vander @handle_test( - fn_tree="functional.ivy.cholesky", - dtype_x=helpers.dtype_and_values( + fn_tree="functional.ivy.vander", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=10, - shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + shape=st.tuples( + helpers.ints(min_value=1, max_value=10), + ), + large_abs_safety_factor=15, + small_abs_safety_factor=15, + safety_factor_scale="log", ), - upper=st.booleans(), + N=st.integers(min_value=1, max_value=10) | st.none(), + increasing=st.booleans(), ) -def test_cholesky(*, dtype_x, upper, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - x = x[0] - x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite - +def test_vander( + *, dtype_and_x, N, 
increasing, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x = dtype_and_x helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x, - upper=upper, - rtol_=1e-3, - atol_=1e-3, + rtol_=1e-2, + atol_=1e-2, + x=x[0], + N=N, + increasing=increasing, ) -# cross +# vecdot @handle_test( - fn_tree="functional.ivy.cross", + fn_tree="functional.ivy.vecdot", dtype_x1_x2_axis=dtype_value1_value2_axis( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=3, - max_dim_size=3, - min_value=-1e5, - max_value=1e5, - abs_smallest_val=0.01, + large_abs_safety_factor=100, + small_abs_safety_factor=100, safety_factor_scale="log", + min_num_dims=1, + max_num_dims=4, + min_dim_size=1, + max_dim_size=4, ), ) -def test_cross(*, dtype_x1_x2_axis, test_flags, backend_fw, fn_name, on_device): +def test_vecdot(*, dtype_x1_x2_axis, test_flags, backend_fw, fn_name, on_device): dtype, x1, x2, axis = dtype_x1_x2_axis helpers.test_function( input_dtypes=dtype, @@ -1109,112 +1135,94 @@ def test_cross(*, dtype_x1_x2_axis, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, + rtol_=5e-1, + atol_=5e-1, x1=x1, x2=x2, axis=axis, ) -# diagonal +# vector_norm @handle_test( - fn_tree="functional.ivy.diagonal", - dtype_x=helpers.dtype_and_values( + fn_tree="functional.ivy.vector_norm", + dtype_values_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, - max_dim_size=50, + valid_axis=True, + min_value=-1e04, + max_value=1e04, + abs_smallest_val=1e-04, + max_axes_size=2, + allow_neg_axes=True, ), - offset=helpers.ints(min_value=-10, max_value=50), - axes=st.lists( - helpers.ints(min_value=-2, max_value=1), min_size=2, max_size=2, unique=True - ).filter(lambda axes: axes[0] % 2 != axes[1] % 2), + kd=st.booleans(), + ord=st.one_of( + helpers.ints(min_value=-5, max_value=5), + helpers.floats(min_value=-5, max_value=5.0), + st.sampled_from((float("inf"), -float("inf"))), + ), + dtype=helpers.get_dtypes("numeric", full=False, none=True), ) -def test_diagonal(*, dtype_x, offset, axes, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x +def test_vector_norm( + *, dtype_values_axis, kd, ord, dtype, test_flags, backend_fw, fn_name, on_device +): + x_dtype, x, axis = dtype_values_axis + # to avoid tuple axis with only one axis as force_int_axis can't generate + # axis with two axes + if isinstance(axis, tuple) and len(axis) == 1: + axis = axis[0] helpers.test_function( - input_dtypes=dtype, + input_dtypes=x_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, x=x[0], - offset=offset, - axis1=axes[0], - axis2=axes[1], + axis=axis, + keepdims=kd, + ord=ord, + dtype=dtype[0], + atol_=1e-08, ) + # Specific value test to handle cases when ord is one of {inf, -inf} -@st.composite -def _diag_helper(draw): - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - small_abs_safety_factor=2, - large_abs_safety_factor=2, - safety_factor_scale="log", - min_num_dims=1, - max_num_dims=2, - min_dim_size=1, - max_dim_size=50, - ) - ) - shape = x[0].shape - if len(shape) == 2: - k = draw(helpers.ints(min_value=-shape[0] + 1, max_value=shape[1] - 1)) - else: - k = draw(helpers.ints(min_value=0, max_value=shape[0])) - return dtype, 
x, k + with BackendHandler.update_backend(backend_fw) as ivy_backend: + arr = ivy_backend.array([[1.0, 2.0, 3.0], [-1.0, 2.0, 4.0]]) + arr_normed_inf = ivy_backend.vector_norm(arr, axis=0, ord=float("inf")) + arr_normed_min_inf = ivy_backend.vector_norm(arr, axis=0, ord=float("-inf")) + with BackendHandler.update_backend(test_flags.ground_truth_backend) as gt_backend: + gt_arr_normed_inf = gt_backend.array([1.0, 2.0, 4.0]) + gt_arr_normed_min_inf = gt_backend.array([1.0, 2.0, 3.0]) -# diag -@handle_test( - fn_tree="functional.ivy.diag", - dtype_x_k=_diag_helper(), -) -def test_diag(*, dtype_x_k, test_flags, backend_fw, fn_name, on_device): - dtype, x, k = dtype_x_k - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - k=k, + helpers.assert_all_close( + arr_normed_inf, + gt_arr_normed_inf, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) + helpers.assert_all_close( + arr_normed_min_inf, + gt_arr_normed_min_inf, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, ) -# vander @handle_test( - fn_tree="functional.ivy.vander", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.tuples( - helpers.ints(min_value=1, max_value=10), - ), - large_abs_safety_factor=15, - small_abs_safety_factor=15, - safety_factor_scale="log", - ), - N=st.integers(min_value=1, max_value=10) | st.none(), - increasing=st.booleans(), + fn_tree="functional.ivy.vector_to_skew_symmetric_matrix", + dtype_x=_get_dtype_and_vector(), ) -def test_vander( - *, dtype_and_x, N, increasing, test_flags, backend_fw, fn_name, on_device +def test_vector_to_skew_symmetric_matrix( + *, dtype_x, test_flags, backend_fw, fn_name, on_device ): - input_dtype, x = dtype_and_x + input_dtype, x = dtype_x helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - N=N, - increasing=increasing, + vector=x, ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py index e881a673a7caf..f3564e10e7317 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_manipulation.py @@ -1,4 +1,3 @@ -# For Review """Collection of tests for manipulation functions.""" # global @@ -12,6 +11,10 @@ from ivy_tests.test_ivy.helpers import handle_test +# --- Helpers --- # +# --------------- # + + @st.composite def _arrays_idx_n_dtypes(draw): num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims")) @@ -46,82 +49,109 @@ def _arrays_idx_n_dtypes(draw): return xs, input_dtypes, unique_idx -# concat -@handle_test( - fn_tree="functional.ivy.concat", - xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), -) -def test_concat( - *, xs_n_input_dtypes_n_unique_idx, test_flags, backend_fw, fn_name, on_device -): - xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx - helpers.test_function( - input_dtypes=input_dtypes, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - xs=xs, - axis=unique_idx, - ) +# Extra # +# ------# -# expand_dims -@handle_test( - fn_tree="functional.ivy.expand_dims", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(), 
key="value_shape"), - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), -) -def test_expand_dims(*, dtype_value, axis, test_flags, backend_fw, fn_name, on_device): - dtype, value = dtype_value - try: - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=value[0], - axis=axis, +@st.composite +def _basic_min_x_max(draw): + dtype, value = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), ) - # ToDo: fix `get_axis`; `unique=True` does not always work - except (ValueError, Exception) as e: - if "repeated axis" in str(e): - assume(False) - raise e + ) + min_val = draw(helpers.array_values(dtype=dtype[0], shape=())) + max_val = draw( + helpers.array_values(dtype=dtype[0], shape=()).filter(lambda x: x > min_val) + ) + return [dtype], (value[0], min_val, max_val) -# flip -@handle_test( - fn_tree="functional.ivy.flip", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=True), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - min_size=1, - max_size=1, - force_int=True, - ), -) -def test_flip(*, dtype_value, axis, test_flags, backend_fw, fn_name, on_device): - dtype, value = dtype_value +@st.composite +def _constant_pad_helper(draw): + dtype, value, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), ret_shape=True, min_num_dims=1 + ) + ) + pad_width = tuple( + draw( + st.lists( + st.tuples( + helpers.ints(min_value=0, max_value=5), + helpers.ints(min_value=0, max_value=5), + ), + min_size=len(shape), + max_size=len(shape), + ) + ) + ) + return dtype, value, pad_width - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=value[0], - axis=axis, + +@st.composite +def _get_splits( + draw, + allow_none=True, + min_num_dims=1, + axis=None, + allow_array_indices=True, + is_mod_split=False, +): + """Generate valid splits, either by generating an integer that evenly divides the + axis or a list of splits that sum to the length of the axis being split.""" + shape = draw( + st.shared(helpers.get_shape(min_num_dims=min_num_dims), key="value_shape") ) + if axis is None: + axis = draw( + st.shared(helpers.get_axis(shape=shape, force_int=True), key="target_axis") + ) + + @st.composite + def _get_int_split(draw): + if shape[axis] == 0: + return 0 + factors = [] + for i in range(1, shape[axis] + 1): + if shape[axis] % i == 0: + factors.append(i) + return draw(st.sampled_from(factors)) + + @st.composite + def _get_list_split(draw, allow_arr_indices=True, is_other_split=False): + num_or_size_splits = [] + while sum(num_or_size_splits) < shape[axis]: + split_value = draw( + helpers.ints( + min_value=1, + max_value=shape[axis] - sum(num_or_size_splits), + ) + ) + num_or_size_splits.append(split_value) + if is_other_split: + num_or_size_splits = list(set(num_or_size_splits)) + if allow_arr_indices: + gen_random_native = draw(st.booleans()) + if gen_random_native: + return np.asarray(num_or_size_splits, dtype=np.int32) + return num_or_size_splits + + if allow_none: + return draw( + _get_list_split( + allow_arr_indices=allow_array_indices, is_other_split=is_mod_split + ) + | _get_int_split() + | st.none() + ) + else: + return draw( + _get_list_split( + 
allow_arr_indices=allow_array_indices, is_other_split=is_mod_split + ) + | _get_int_split() + ) @st.composite @@ -132,128 +162,32 @@ def _permute_dims_helper(draw): return permutation -# permute_dims -@handle_test( - fn_tree="functional.ivy.permute_dims", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=True), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - ), - permutation=_permute_dims_helper(), -) -def test_permute_dims( - *, dtype_value, permutation, test_flags, backend_fw, fn_name, on_device -): - dtype, value = dtype_value - - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=value[0], - axes=permutation, +@st.composite +def _repeat_helper(draw): + shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="value_shape")) + axis = draw( + st.shared( + st.one_of(st.none(), helpers.get_axis(shape=shape, max_size=1)), key="axis" + ) ) + if not isinstance(axis, int) and axis is not None: + axis = axis[0] -@handle_test( - fn_tree="functional.ivy.reshape", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=True), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - reshape=helpers.reshape_shapes( - shape=st.shared(helpers.get_shape(), key="value_shape") - ), - order=st.sampled_from(["C", "F"]), - allowzero=st.booleans(), -) -def test_reshape( - *, - dtype_value, - reshape, - order, - allowzero, - test_flags, - backend_fw, - fn_name, - on_device -): - dtype, value = dtype_value - - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=value[0], - shape=reshape, - order=order, - allowzero=allowzero, + repeat_shape = ( + (draw(st.one_of(st.just(1), st.just(shape[axis]))),) + if axis is not None + else (1,) ) - - -# roll -@handle_test( - fn_tree="functional.ivy.roll", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - large_abs_safety_factor=8, - small_abs_safety_factor=8, - safety_factor_scale="log", - ), - shift=helpers.dtype_and_values( - available_dtypes=[ivy.int32], - max_num_dims=1, - min_dim_size=st.shared( - helpers.ints(min_value=1, max_value=10), - key="shift_len", - ), - max_dim_size=st.shared( - helpers.ints(min_value=1, max_value=10), - key="shift_len", - ), - ), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), - force_tuple=True, - unique=False, - min_size=st.shared( - helpers.ints(min_value=1, max_value=10), - key="shift_len", - ), - max_size=st.shared( - helpers.ints(min_value=1, max_value=10), - key="shift_len", - ), - ), - # test_gradients=st.just(False), -) -def test_roll(*, dtype_value, shift, axis, test_flags, backend_fw, fn_name, on_device): - value_dtype, value = dtype_value - shift_dtype, shift_val = shift - - if shift_val[0].ndim == 0: # If shift is an int - shift_val = shift_val[0] # Drop shift's dtype (always int32) - axis = axis[0] # Extract an axis value from the tuple - else: - # Drop shift's dtype (always int32) and convert list to tuple - shift_val = tuple(shift_val[0].tolist()) - - helpers.test_function( - input_dtypes=value_dtype + shift_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=value[0], - shift=shift_val, - axis=axis, - 
xs_grad_idxs=[[0, 0]], + repeat = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + shape=repeat_shape, + min_value=0, + max_value=10, + ) ) + return repeat @st.composite @@ -267,29 +201,6 @@ def _squeeze_helper(draw): return draw(st.sampled_from(valid_axes)) -# squeeze -@handle_test( - fn_tree="functional.ivy.squeeze", - dtype_value=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid", full=True), - shape=st.shared(helpers.get_shape(), key="value_shape"), - ), - axis=_squeeze_helper(), -) -def test_squeeze(*, dtype_value, axis, test_flags, backend_fw, fn_name, on_device): - dtype, value = dtype_value - - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=value[0], - axis=axis, - ) - - @st.composite def _stack_helper(draw): shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="values_shape")) @@ -306,45 +217,8 @@ def _stack_helper(draw): return dtypes, arrays -# stack -@handle_test( - fn_tree="functional.ivy.stack", - dtypes_arrays=_stack_helper(), - axis=helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="values_shape"), - force_int=True, - ), -) -def test_stack(*, dtypes_arrays, axis, test_flags, backend_fw, fn_name, on_device): - dtypes, arrays = dtypes_arrays - - helpers.test_function( - input_dtypes=dtypes, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - arrays=arrays, - axis=axis, - ) - - -# Extra # -# ------# - - -@st.composite -def _basic_min_x_max(draw): - dtype, value = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ) - ) - min_val = draw(helpers.array_values(dtype=dtype[0], shape=())) - max_val = draw( - helpers.array_values(dtype=dtype[0], shape=()).filter(lambda x: x > min_val) - ) - return [dtype], (value[0], min_val, max_val) +# --- Main --- # +# ------------ # # clip @@ -366,26 +240,24 @@ def test_clip(*, dtype_x_min_max, test_flags, backend_fw, fn_name, on_device): ) -@st.composite -def _constant_pad_helper(draw): - dtype, value, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), ret_shape=True, min_num_dims=1 - ) - ) - pad_width = tuple( - draw( - st.lists( - st.tuples( - helpers.ints(min_value=0, max_value=5), - helpers.ints(min_value=0, max_value=5), - ), - min_size=len(shape), - max_size=len(shape), - ) - ) +# concat +@handle_test( + fn_tree="functional.ivy.concat", + xs_n_input_dtypes_n_unique_idx=_arrays_idx_n_dtypes(), +) +def test_concat( + *, xs_n_input_dtypes_n_unique_idx, test_flags, backend_fw, fn_name, on_device +): + xs, input_dtypes, unique_idx = xs_n_input_dtypes_n_unique_idx + helpers.test_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + xs=xs, + axis=unique_idx, ) - return dtype, value, pad_width # constant_pad @@ -410,32 +282,87 @@ def test_constant_pad( ) -@st.composite -def _repeat_helper(draw): - shape = draw(st.shared(helpers.get_shape(min_num_dims=1), key="value_shape")) - axis = draw( - st.shared( - st.one_of(st.none(), helpers.get_axis(shape=shape, max_size=1)), key="axis" +# expand_dims +@handle_test( + fn_tree="functional.ivy.expand_dims", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(), 
key="value_shape"), + ), +) +def test_expand_dims(*, dtype_value, axis, test_flags, backend_fw, fn_name, on_device): + dtype, value = dtype_value + try: + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=value[0], + axis=axis, ) + # ToDo: fix `get_axis`; `unique=True` does not always work + except (ValueError, Exception) as e: + if "repeated axis" in str(e): + assume(False) + raise e + + +# flip +@handle_test( + fn_tree="functional.ivy.flip", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", full=True), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + min_size=1, + max_size=1, + force_int=True, + ), +) +def test_flip(*, dtype_value, axis, test_flags, backend_fw, fn_name, on_device): + dtype, value = dtype_value + + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=value[0], + axis=axis, ) - if not isinstance(axis, int) and axis is not None: - axis = axis[0] - repeat_shape = ( - (draw(st.one_of(st.just(1), st.just(shape[axis]))),) - if axis is not None - else (1,) - ) - repeat = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - shape=repeat_shape, - min_value=0, - max_value=10, - ) +# permute_dims +@handle_test( + fn_tree="functional.ivy.permute_dims", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", full=True), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + ), + permutation=_permute_dims_helper(), +) +def test_permute_dims( + *, dtype_value, permutation, test_flags, backend_fw, fn_name, on_device +): + dtype, value = dtype_value + + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=value[0], + axes=permutation, ) - return repeat # repeat @@ -483,69 +410,103 @@ def test_repeat( ) -@st.composite -def _get_splits( - draw, - allow_none=True, - min_num_dims=1, - axis=None, - allow_array_indices=True, - is_mod_split=False, +@handle_test( + fn_tree="functional.ivy.reshape", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", full=True), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + reshape=helpers.reshape_shapes( + shape=st.shared(helpers.get_shape(), key="value_shape") + ), + order=st.sampled_from(["C", "F"]), + allowzero=st.booleans(), +) +def test_reshape( + *, + dtype_value, + reshape, + order, + allowzero, + test_flags, + backend_fw, + fn_name, + on_device ): - """Generate valid splits, either by generating an integer that evenly divides the - axis or a list of splits that sum to the length of the axis being split.""" - shape = draw( - st.shared(helpers.get_shape(min_num_dims=min_num_dims), key="value_shape") + dtype, value = dtype_value + + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=value[0], + shape=reshape, + order=order, + allowzero=allowzero, ) - if axis is None: - axis = draw( - st.shared(helpers.get_axis(shape=shape, force_int=True), key="target_axis") - ) - @st.composite - def _get_int_split(draw): - if shape[axis] == 0: - return 0 - factors = [] - for i in range(1, shape[axis] + 
1): - if shape[axis] % i == 0: - factors.append(i) - return draw(st.sampled_from(factors)) - @st.composite - def _get_list_split(draw, allow_arr_indices=True, is_other_split=False): - num_or_size_splits = [] - while sum(num_or_size_splits) < shape[axis]: - split_value = draw( - helpers.ints( - min_value=1, - max_value=shape[axis] - sum(num_or_size_splits), - ) - ) - num_or_size_splits.append(split_value) - if is_other_split: - num_or_size_splits = list(set(num_or_size_splits)) - if allow_arr_indices: - gen_random_native = draw(st.booleans()) - if gen_random_native: - return np.asarray(num_or_size_splits, dtype=np.int32) - return num_or_size_splits +# roll +@handle_test( + fn_tree="functional.ivy.roll", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + large_abs_safety_factor=8, + small_abs_safety_factor=8, + safety_factor_scale="log", + ), + shift=helpers.dtype_and_values( + available_dtypes=[ivy.int32], + max_num_dims=1, + min_dim_size=st.shared( + helpers.ints(min_value=1, max_value=10), + key="shift_len", + ), + max_dim_size=st.shared( + helpers.ints(min_value=1, max_value=10), + key="shift_len", + ), + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"), + force_tuple=True, + unique=False, + min_size=st.shared( + helpers.ints(min_value=1, max_value=10), + key="shift_len", + ), + max_size=st.shared( + helpers.ints(min_value=1, max_value=10), + key="shift_len", + ), + ), + # test_gradients=st.just(False), +) +def test_roll(*, dtype_value, shift, axis, test_flags, backend_fw, fn_name, on_device): + value_dtype, value = dtype_value + shift_dtype, shift_val = shift - if allow_none: - return draw( - _get_list_split( - allow_arr_indices=allow_array_indices, is_other_split=is_mod_split - ) - | _get_int_split() - | st.none() - ) + if shift_val[0].ndim == 0: # If shift is an int + shift_val = shift_val[0] # Drop shift's dtype (always int32) + axis = axis[0] # Extract an axis value from the tuple else: - return draw( - _get_list_split( - allow_arr_indices=allow_array_indices, is_other_split=is_mod_split - ) - | _get_int_split() - ) + # Drop shift's dtype (always int32) and convert list to tuple + shift_val = tuple(shift_val[0].tolist()) + + helpers.test_function( + input_dtypes=value_dtype + shift_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=value[0], + shift=shift_val, + axis=axis, + xs_grad_idxs=[[0, 0]], + ) # TODO: there is a failure with paddle (dtype('int32')) caused by the `_get_splits` @@ -598,6 +559,52 @@ def test_split( ) +# squeeze +@handle_test( + fn_tree="functional.ivy.squeeze", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", full=True), + shape=st.shared(helpers.get_shape(), key="value_shape"), + ), + axis=_squeeze_helper(), +) +def test_squeeze(*, dtype_value, axis, test_flags, backend_fw, fn_name, on_device): + dtype, value = dtype_value + + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=value[0], + axis=axis, + ) + + +# stack +@handle_test( + fn_tree="functional.ivy.stack", + dtypes_arrays=_stack_helper(), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="values_shape"), + force_int=True, + ), +) +def test_stack(*, dtypes_arrays, axis, test_flags, backend_fw, fn_name, on_device): + dtypes, 
arrays = dtypes_arrays + + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + arrays=arrays, + axis=axis, + ) + + # swapaxes @handle_test( fn_tree="functional.ivy.swapaxes", @@ -661,24 +668,6 @@ def test_tile(*, dtype_value, repeat, test_flags, backend_fw, fn_name, on_device ) -# zero_pad -@handle_test( - fn_tree="functional.ivy.zero_pad", - dtype_value_pad_width=_constant_pad_helper(), -) -def test_zero_pad(*, dtype_value_pad_width, test_flags, backend_fw, fn_name, on_device): - dtype, value, pad_width = dtype_value_pad_width - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=value[0], - pad_width=pad_width, - ) - - # unstack @handle_test( fn_tree="functional.ivy.unstack", @@ -708,3 +697,21 @@ def test_unstack( axis=axis, keepdims=keepdims, ) + + +# zero_pad +@handle_test( + fn_tree="functional.ivy.zero_pad", + dtype_value_pad_width=_constant_pad_helper(), +) +def test_zero_pad(*, dtype_value_pad_width, test_flags, backend_fw, fn_name, on_device): + dtype, value, pad_width = dtype_value_pad_width + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=value[0], + pad_width=pad_width, + ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_meta.py b/ivy_tests/test_ivy/test_functional/test_core/test_meta.py index 19b63949a3494..bc6b5d524a1b4 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_meta.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_meta.py @@ -11,14 +11,7 @@ from ivy_tests.test_ivy.helpers.pipeline_helper import BackendHandler -# ToDo: replace dict checks for verifying costs with analytic calculations - - -# First Order # -# ------------# - - -# fomaml step unique vars +# fomaml step overlapping vars @handle_test( fn_tree="functional.ivy.fomaml_step", inner_grad_steps=helpers.ints(min_value=1, max_value=3), @@ -29,7 +22,7 @@ num_tasks=helpers.ints(min_value=1, max_value=2), return_inner_v=st.sampled_from(["first", "all", False]), ) -def test_fomaml_step_unique_vars( +def test_fomaml_step_overlapping_vars( on_device, inner_grad_steps, with_outer_cost_fn, @@ -40,10 +33,10 @@ def test_fomaml_step_unique_vars( return_inner_v, backend_fw, ): - # Numpy does not support gradients, and jax does not support gradients on - # custom nested classes + # Numpy does not support gradients, jax does not support gradients on custom + # nested classes if backend_fw == "numpy": - return + pytest.skip() with BackendHandler.update_backend(backend_fw) as ivy_backend: # config @@ -91,7 +84,7 @@ def inner_cost_fn(batch_in, v): batch_in.cont_unstack_conts(0, keepdims=True), v.cont_unstack_conts(0, keepdims=True), ): - cost = cost - (sub_v["latent"] * sub_batch_in["x"] * sub_v["weight"])[0] + cost = cost - (sub_batch_in["x"] * sub_v["latent"] * sub_v["weight"])[0] return cost / batch_size # outer cost function @@ -102,12 +95,12 @@ def outer_cost_fn(batch_in, v): batch_in.cont_unstack_conts(0, keepdims=True), v.cont_unstack_conts(0, keepdims=True), ): - cost = cost + (sub_v["latent"] * sub_batch_in["x"] * sub_v["weight"])[0] + cost = cost + (sub_batch_in["x"] * sub_v["latent"] * sub_v["weight"])[0] return cost / batch_size # numpy - weight_np = ivy_backend.to_numpy(variables.weight[0:1]) latent_np = ivy_backend.to_numpy(variables.latent[0:1]) + weight_np = 
ivy_backend.to_numpy(variables.weight[0:1]) batch_np = batch.cont_map(lambda x, kc: ivy_backend.to_numpy(x)) # true gradient @@ -130,6 +123,11 @@ def outer_cost_fn(batch_in, v): else: true_weight_grad = sum([og[-1] for og in all_outer_grads]) / num_tasks + # true latent gradient + true_latent_grad = np.array( + [(-1 - (num_tasks - 1) / 2) * (-1 if with_outer_cost_fn else 1)] + ) + # true cost true_cost_dict = { 1: { @@ -160,7 +158,6 @@ def outer_cost_fn(batch_in, v): average_across_steps=average_across_steps, batched=batched, inner_v="latent", - outer_v="weight", return_inner_v=return_inner_v, stop_gradients=stop_gradients, ) @@ -175,6 +172,9 @@ def outer_cost_fn(batch_in, v): assert np.allclose( ivy_backend.to_numpy(outer_grads.weight[0]), np.array(true_weight_grad) ) + assert np.allclose( + ivy_backend.to_numpy(outer_grads.latent[0]), np.array(true_latent_grad) + ) if return_inner_v: inner_v_rets = rets[2] assert isinstance(inner_v_rets, ivy_backend.Container) @@ -368,7 +368,14 @@ def loss_grad_fn(sub_batch_in, w_in, outer=False): assert list(inner_v_rets.cont_shape) == [1, 1] -# fomaml step overlapping vars +# ToDo: replace dict checks for verifying costs with analytic calculations + + +# First Order # +# ------------# + + +# fomaml step unique vars @handle_test( fn_tree="functional.ivy.fomaml_step", inner_grad_steps=helpers.ints(min_value=1, max_value=3), @@ -379,7 +386,7 @@ def loss_grad_fn(sub_batch_in, w_in, outer=False): num_tasks=helpers.ints(min_value=1, max_value=2), return_inner_v=st.sampled_from(["first", "all", False]), ) -def test_fomaml_step_overlapping_vars( +def test_fomaml_step_unique_vars( on_device, inner_grad_steps, with_outer_cost_fn, @@ -390,10 +397,10 @@ def test_fomaml_step_overlapping_vars( return_inner_v, backend_fw, ): - # Numpy does not support gradients, jax does not support gradients on custom - # nested classes + # Numpy does not support gradients, and jax does not support gradients on + # custom nested classes if backend_fw == "numpy": - pytest.skip() + return with BackendHandler.update_backend(backend_fw) as ivy_backend: # config @@ -441,7 +448,7 @@ def inner_cost_fn(batch_in, v): batch_in.cont_unstack_conts(0, keepdims=True), v.cont_unstack_conts(0, keepdims=True), ): - cost = cost - (sub_batch_in["x"] * sub_v["latent"] * sub_v["weight"])[0] + cost = cost - (sub_v["latent"] * sub_batch_in["x"] * sub_v["weight"])[0] return cost / batch_size # outer cost function @@ -452,12 +459,12 @@ def outer_cost_fn(batch_in, v): batch_in.cont_unstack_conts(0, keepdims=True), v.cont_unstack_conts(0, keepdims=True), ): - cost = cost + (sub_batch_in["x"] * sub_v["latent"] * sub_v["weight"])[0] + cost = cost + (sub_v["latent"] * sub_batch_in["x"] * sub_v["weight"])[0] return cost / batch_size # numpy - latent_np = ivy_backend.to_numpy(variables.latent[0:1]) weight_np = ivy_backend.to_numpy(variables.weight[0:1]) + latent_np = ivy_backend.to_numpy(variables.latent[0:1]) batch_np = batch.cont_map(lambda x, kc: ivy_backend.to_numpy(x)) # true gradient @@ -480,11 +487,6 @@ def outer_cost_fn(batch_in, v): else: true_weight_grad = sum([og[-1] for og in all_outer_grads]) / num_tasks - # true latent gradient - true_latent_grad = np.array( - [(-1 - (num_tasks - 1) / 2) * (-1 if with_outer_cost_fn else 1)] - ) - # true cost true_cost_dict = { 1: { @@ -515,6 +517,7 @@ def outer_cost_fn(batch_in, v): average_across_steps=average_across_steps, batched=batched, inner_v="latent", + outer_v="weight", return_inner_v=return_inner_v, stop_gradients=stop_gradients, ) @@ -529,9 +532,6 @@ def 
outer_cost_fn(batch_in, v): assert np.allclose( ivy_backend.to_numpy(outer_grads.weight[0]), np.array(true_weight_grad) ) - assert np.allclose( - ivy_backend.to_numpy(outer_grads.latent[0]), np.array(true_latent_grad) - ) if return_inner_v: inner_v_rets = rets[2] assert isinstance(inner_v_rets, ivy_backend.Container) @@ -541,137 +541,7 @@ def outer_cost_fn(batch_in, v): assert list(inner_v_rets.cont_shape) == [1, 1] -# reptile step -@pytest.mark.parametrize("inner_grad_steps", [1, 2, 3]) -@pytest.mark.parametrize("batched", [True, False]) -@pytest.mark.parametrize("stop_gradients", [True, False]) -@pytest.mark.parametrize("num_tasks", [1, 2]) -@pytest.mark.parametrize("return_inner_v", ["first", "all", False]) -def test_reptile_step( - on_device, - inner_grad_steps, - batched, - stop_gradients, - num_tasks, - return_inner_v, - backend_fw, -): - if backend_fw == "numpy": - # Numpy does not support gradients, jax does not support gradients on custom - # nested classes, - pytest.skip() - - with BackendHandler.update_backend(backend_fw) as ivy_backend: - # config - inner_learning_rate = 1e-2 - variable_fn = ivy_backend.functional.ivy._variable - - # create variable - if batched: - variables = ivy_backend.Container( - { - "latent": variable_fn( - ivy_backend.repeat( - ivy_backend.array([[1.0]], device=on_device), - num_tasks, - axis=0, - ) - ) - } - ) - else: - variables = ivy_backend.Container( - {"latent": variable_fn(ivy_backend.array([1.0], device=on_device))} - ) - - # batch - batch = ivy_backend.Container( - {"x": ivy_backend.arange(1, num_tasks + 1, dtype="float32")} - ) - - # inner cost function - def inner_cost_fn(batch_in, v): - cost = 0 - batch_size = batch_in.cont_shape[0] - for sub_batch_in, sub_v in zip( - batch_in.cont_unstack_conts(0, keepdims=True), - v.cont_unstack_conts(0, keepdims=True), - ): - cost = cost - (sub_batch_in["x"] * sub_v["latent"] ** 2)[0] - return cost / batch_size - - # numpy - latent_np = ivy_backend.to_numpy(variables.latent[0:1]) - batch_np = batch.cont_map(lambda x, kc: ivy_backend.to_numpy(x)) - - # loss grad function - def loss_grad_fn(sub_batch_in, w_in): - return -2 * sub_batch_in["x"][0] * w_in - - # true gradient - true_outer_grads = list() - for sub_batch in batch_np.cont_unstack_conts(0, True, num_tasks): - ws = list() - grads = list() - ws.append(latent_np) - for step in range(inner_grad_steps): - update_grad = loss_grad_fn(sub_batch, ws[-1]) - w = ws[-1] - inner_learning_rate * update_grad - grads.append(update_grad) - ws.append(w) - grads.append(loss_grad_fn(sub_batch, ws[-1])) - - # true outer grad - true_outer_grad = sum(grads) / len(grads) - true_outer_grads.append(true_outer_grad) - true_outer_grad = ( - sum(true_outer_grads) / len(true_outer_grads) - ) / inner_learning_rate - - # true cost - true_cost_dict = { - 1: {1: -1.0202, 2: -1.5509}, - 2: {1: -1.0409441, 2: -1.6042916}, - 3: {1: -1.0622487, 2: -1.6603187}, - } - true_cost = true_cost_dict[inner_grad_steps][num_tasks] - - # meta update - rets = ivy_backend.reptile_step( - batch, - inner_cost_fn, - variables, - inner_grad_steps, - inner_learning_rate, - batched=batched, - return_inner_v=return_inner_v, - stop_gradients=stop_gradients, - ) - calc_cost = rets[0] - if stop_gradients: - assert ivy_backend.equal( - ivy_backend.functional.ivy._is_variable(calc_cost, exclusive=True), - False, - ) - assert np.allclose(ivy_backend.to_scalar(calc_cost), true_cost) - outer_grads = rets[1] - assert np.allclose( - ivy_backend.to_numpy(outer_grads.latent[0]), np.array(true_outer_grad) - ) - if 
return_inner_v: - inner_v_rets = rets[2] - assert isinstance(inner_v_rets, ivy_backend.Container) - if return_inner_v == "all": - assert list(inner_v_rets.cont_shape) == [num_tasks, 1] - elif return_inner_v == "first": - assert list(inner_v_rets.cont_shape) == [1, 1] - - -# Second Order # -# -------------# - - -# maml step unique vars +# maml step overlapping vars @pytest.mark.parametrize("inner_grad_steps", [1, 2, 3]) @pytest.mark.parametrize("with_outer_cost_fn", [True, False]) @pytest.mark.parametrize("average_across_steps", [True, False]) @@ -679,7 +549,7 @@ def loss_grad_fn(sub_batch_in, w_in): @pytest.mark.parametrize("stop_gradients", [True, False]) @pytest.mark.parametrize("num_tasks", [1, 2]) @pytest.mark.parametrize("return_inner_v", ["first", "all", False]) -def test_maml_step_unique_vars( +def test_maml_step_overlapping_vars( on_device, inner_grad_steps, with_outer_cost_fn, @@ -757,11 +627,11 @@ def outer_cost_fn(batch_in, v): return cost / batch_size # numpy - weight_np = ivy_backend.to_numpy(variables.weight[0:1]) - latent_np = ivy_backend.to_numpy(variables.latent[0:1]) + latent_np = ivy_backend.to_numpy(variables.latent) + weight_np = ivy_backend.to_numpy(variables.weight) batch_np = batch.cont_map(lambda x, kc: ivy_backend.to_numpy(x)) - # true gradient + # true weight gradient all_outer_grads = list() for sub_batch in batch_np.cont_unstack_conts(0, True, num_tasks): all_outer_grads.append( @@ -779,11 +649,16 @@ def outer_cost_fn(batch_in, v): ] ) if average_across_steps: - true_outer_grad = ( + true_weight_grad = ( sum([sum(og) / len(og) for og in all_outer_grads]) / num_tasks ) else: - true_outer_grad = sum([og[-1] for og in all_outer_grads]) / num_tasks + true_weight_grad = sum([og[-1] for og in all_outer_grads]) / num_tasks + + # true latent gradient + true_latent_grad = np.array( + [(-1 - (num_tasks - 1) / 2) * (-1 if with_outer_cost_fn else 1)] + ) # true cost true_cost_dict = { @@ -815,7 +690,6 @@ def outer_cost_fn(batch_in, v): average_across_steps=average_across_steps, batched=batched, inner_v="latent", - outer_v="weight", return_inner_v=return_inner_v, stop_gradients=stop_gradients, ) @@ -828,7 +702,10 @@ def outer_cost_fn(batch_in, v): assert np.allclose(ivy_backend.to_scalar(calc_cost), true_cost) outer_grads = rets[1] assert np.allclose( - ivy_backend.to_numpy(outer_grads.weight), np.array(true_outer_grad) + ivy_backend.to_numpy(outer_grads.weight), np.array(true_weight_grad) + ) + assert np.allclose( + ivy_backend.to_numpy(outer_grads.latent), np.array(true_latent_grad) ) if return_inner_v: inner_v_rets = rets[2] @@ -1068,7 +945,11 @@ def update_grad_fn(w_init, sub_batch_in, num_steps, average=False): assert list(inner_v_rets.cont_shape) == [1, 1] -# maml step overlapping vars +# Second Order # +# -------------# + + +# maml step unique vars @pytest.mark.parametrize("inner_grad_steps", [1, 2, 3]) @pytest.mark.parametrize("with_outer_cost_fn", [True, False]) @pytest.mark.parametrize("average_across_steps", [True, False]) @@ -1076,7 +957,7 @@ def update_grad_fn(w_init, sub_batch_in, num_steps, average=False): @pytest.mark.parametrize("stop_gradients", [True, False]) @pytest.mark.parametrize("num_tasks", [1, 2]) @pytest.mark.parametrize("return_inner_v", ["first", "all", False]) -def test_maml_step_overlapping_vars( +def test_maml_step_unique_vars( on_device, inner_grad_steps, with_outer_cost_fn, @@ -1154,11 +1035,11 @@ def outer_cost_fn(batch_in, v): return cost / batch_size # numpy - latent_np = ivy_backend.to_numpy(variables.latent) - weight_np = 
ivy_backend.to_numpy(variables.weight) + weight_np = ivy_backend.to_numpy(variables.weight[0:1]) + latent_np = ivy_backend.to_numpy(variables.latent[0:1]) batch_np = batch.cont_map(lambda x, kc: ivy_backend.to_numpy(x)) - # true weight gradient + # true gradient all_outer_grads = list() for sub_batch in batch_np.cont_unstack_conts(0, True, num_tasks): all_outer_grads.append( @@ -1176,16 +1057,11 @@ def outer_cost_fn(batch_in, v): ] ) if average_across_steps: - true_weight_grad = ( + true_outer_grad = ( sum([sum(og) / len(og) for og in all_outer_grads]) / num_tasks ) else: - true_weight_grad = sum([og[-1] for og in all_outer_grads]) / num_tasks - - # true latent gradient - true_latent_grad = np.array( - [(-1 - (num_tasks - 1) / 2) * (-1 if with_outer_cost_fn else 1)] - ) + true_outer_grad = sum([og[-1] for og in all_outer_grads]) / num_tasks # true cost true_cost_dict = { @@ -1217,6 +1093,7 @@ def outer_cost_fn(batch_in, v): average_across_steps=average_across_steps, batched=batched, inner_v="latent", + outer_v="weight", return_inner_v=return_inner_v, stop_gradients=stop_gradients, ) @@ -1229,10 +1106,7 @@ def outer_cost_fn(batch_in, v): assert np.allclose(ivy_backend.to_scalar(calc_cost), true_cost) outer_grads = rets[1] assert np.allclose( - ivy_backend.to_numpy(outer_grads.weight), np.array(true_weight_grad) - ) - assert np.allclose( - ivy_backend.to_numpy(outer_grads.latent), np.array(true_latent_grad) + ivy_backend.to_numpy(outer_grads.weight), np.array(true_outer_grad) ) if return_inner_v: inner_v_rets = rets[2] @@ -1243,11 +1117,127 @@ def outer_cost_fn(batch_in, v): assert list(inner_v_rets.cont_shape) == [1, 1] -# Still to Add # -# ---------------# +# reptile step +@pytest.mark.parametrize("inner_grad_steps", [1, 2, 3]) +@pytest.mark.parametrize("batched", [True, False]) +@pytest.mark.parametrize("stop_gradients", [True, False]) +@pytest.mark.parametrize("num_tasks", [1, 2]) +@pytest.mark.parametrize("return_inner_v", ["first", "all", False]) +def test_reptile_step( + on_device, + inner_grad_steps, + batched, + stop_gradients, + num_tasks, + return_inner_v, + backend_fw, +): + if backend_fw == "numpy": + # Numpy does not support gradients, jax does not support gradients on custom + # nested classes, + pytest.skip() + + with BackendHandler.update_backend(backend_fw) as ivy_backend: + # config + inner_learning_rate = 1e-2 + variable_fn = ivy_backend.functional.ivy._variable -# _compute_cost_and_update_grads -# _train_tasks -# _train_tasks_batched -# _train_tasks_with_for_loop -# _fomaml_step + # create variable + if batched: + variables = ivy_backend.Container( + { + "latent": variable_fn( + ivy_backend.repeat( + ivy_backend.array([[1.0]], device=on_device), + num_tasks, + axis=0, + ) + ) + } + ) + else: + variables = ivy_backend.Container( + {"latent": variable_fn(ivy_backend.array([1.0], device=on_device))} + ) + + # batch + batch = ivy_backend.Container( + {"x": ivy_backend.arange(1, num_tasks + 1, dtype="float32")} + ) + + # inner cost function + def inner_cost_fn(batch_in, v): + cost = 0 + batch_size = batch_in.cont_shape[0] + for sub_batch_in, sub_v in zip( + batch_in.cont_unstack_conts(0, keepdims=True), + v.cont_unstack_conts(0, keepdims=True), + ): + cost = cost - (sub_batch_in["x"] * sub_v["latent"] ** 2)[0] + return cost / batch_size + + # numpy + latent_np = ivy_backend.to_numpy(variables.latent[0:1]) + batch_np = batch.cont_map(lambda x, kc: ivy_backend.to_numpy(x)) + + # loss grad function + def loss_grad_fn(sub_batch_in, w_in): + return -2 * sub_batch_in["x"][0] * w_in 
+ + # true gradient + true_outer_grads = list() + for sub_batch in batch_np.cont_unstack_conts(0, True, num_tasks): + ws = list() + grads = list() + ws.append(latent_np) + for step in range(inner_grad_steps): + update_grad = loss_grad_fn(sub_batch, ws[-1]) + w = ws[-1] - inner_learning_rate * update_grad + grads.append(update_grad) + ws.append(w) + grads.append(loss_grad_fn(sub_batch, ws[-1])) + + # true outer grad + true_outer_grad = sum(grads) / len(grads) + true_outer_grads.append(true_outer_grad) + true_outer_grad = ( + sum(true_outer_grads) / len(true_outer_grads) + ) / inner_learning_rate + + # true cost + true_cost_dict = { + 1: {1: -1.0202, 2: -1.5509}, + 2: {1: -1.0409441, 2: -1.6042916}, + 3: {1: -1.0622487, 2: -1.6603187}, + } + true_cost = true_cost_dict[inner_grad_steps][num_tasks] + + # meta update + rets = ivy_backend.reptile_step( + batch, + inner_cost_fn, + variables, + inner_grad_steps, + inner_learning_rate, + batched=batched, + return_inner_v=return_inner_v, + stop_gradients=stop_gradients, + ) + calc_cost = rets[0] + if stop_gradients: + assert ivy_backend.equal( + ivy_backend.functional.ivy._is_variable(calc_cost, exclusive=True), + False, + ) + assert np.allclose(ivy_backend.to_scalar(calc_cost), true_cost) + outer_grads = rets[1] + assert np.allclose( + ivy_backend.to_numpy(outer_grads.latent[0]), np.array(true_outer_grad) + ) + if return_inner_v: + inner_v_rets = rets[2] + assert isinstance(inner_v_rets, ivy_backend.Container) + if return_inner_v == "all": + assert list(inner_v_rets.cont_shape) == [num_tasks, 1] + elif return_inner_v == "first": + assert list(inner_v_rets.cont_shape) == [1, 1] diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_nest.py b/ivy_tests/test_ivy/test_functional/test_core/test_nest.py index 25bbbc3ee1e7e..c4a30e9a4bf08 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_nest.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_nest.py @@ -9,15 +9,9 @@ # local import ivy -# Helpers # -# --------# - -def _snai(n, idx, v): - if len(idx) == 1: - n[idx[0]] = v - else: - _snai(n[idx[0]], idx[1:], v) +# --- Helpers --- # +# --------------- # def _mnai(n, idx, fn): @@ -34,6 +28,17 @@ def _pnai(n, idx): _pnai(n[idx[0]], idx[1:]) +def _snai(n, idx, v): + if len(idx) == 1: + n[idx[0]] = v + else: + _snai(n[idx[0]], idx[1:], v) + + +# --- Main --- # +# ------------ # + + # only checking for dicts but can test other nested functions using # collections.abc.Sequences/Mapping/Iterable def apply_fn_to_list(item, fun): @@ -51,187 +56,24 @@ def map_nested_dicts(ob, func): ob[k] = apply_fn_to_list(v, func) -# Tests # -# ------# - - -# index_nest -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": (((2,), (4,)), ((6,), (8,)))}}] -) -@pytest.mark.parametrize( - "index", [("a", 0, 0), ("a", 1, 0), ("b", "c", 0), ("b", "c", 1, 0)] -) -def test_index_nest(nest, index): - ret = ivy.index_nest(nest, index) - true_ret = nest - for i in index: - true_ret = true_ret[i] - assert ret == true_ret - - -# set_nest_at_index -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] -) -@pytest.mark.parametrize( - "index", [("a", 0, 0), ("a", 1, 0), ("b", "c", 0), ("b", "c", 1, 0)] -) -@pytest.mark.parametrize("value", [-1]) -@pytest.mark.parametrize("shallow", [True, False]) -def test_set_nest_at_index(nest, index, value, shallow): - nest_copy = copy.deepcopy(nest) - result = ivy.set_nest_at_index(nest, index, value, shallow=shallow) - _snai(nest_copy, index, value) - - assert result == 
nest_copy - if shallow: - assert nest == nest_copy - else: - assert nest != nest_copy - - -# map_nest_at_index -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] -) -@pytest.mark.parametrize( - "index", [("a", 0, 0), ("a", 1, 0), ("b", "c", 0, 0, 0), ("b", "c", 1, 0, 0)] -) -@pytest.mark.parametrize("fn", [lambda x: x + 2]) -@pytest.mark.parametrize("shallow", [True, False]) -def test_map_nest_at_index(nest, index, fn, shallow): - nest_copy = copy.deepcopy(nest) - result = ivy.map_nest_at_index(nest, index, fn, shallow=shallow) - _mnai(nest_copy, index, fn) - - assert result == nest_copy - if shallow: - assert nest == nest_copy - else: - assert nest != nest_copy - - -# multi_index_nest -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": (((2,), (4,)), ((6,), (8,)))}}] -) -@pytest.mark.parametrize( - "multi_indices", [(("a", 0, 0), ("a", 1, 0)), (("b", "c", 0), ("b", "c", 1, 0))] -) -def test_multi_index_nest(nest, multi_indices): - rets = ivy.multi_index_nest(nest, multi_indices) - true_rets = list() - for indices in multi_indices: - true_ret = nest - for i in indices: - true_ret = true_ret[i] - true_rets.append(true_ret) - assert rets == true_rets - - -# set_nest_at_indices -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] -) -@pytest.mark.parametrize( - "indices", [(("a", 0, 0), ("a", 1, 0)), (("b", "c", 0), ("b", "c", 1, 0))] -) -@pytest.mark.parametrize("values", [(1, 2)]) -@pytest.mark.parametrize("shallow", [False, True]) -def test_set_nest_at_indices(nest, indices, values, shallow): - nest_copy = copy.deepcopy(nest) - result = ivy.set_nest_at_indices(nest, indices, values, shallow=shallow) - - def snais(n, idxs, vs): - [_snai(n, index, value) for index, value in zip(idxs, vs)] - - snais(nest_copy, indices, values) - - assert result == nest_copy - if shallow: - assert nest == nest_copy - else: - assert nest != nest_copy - - -# map_nest_at_indices -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] -) -@pytest.mark.parametrize( - "indices", [(("a", 0, 0), ("a", 1, 0)), (("b", "c", 0, 0, 0), ("b", "c", 1, 0, 0))] -) -@pytest.mark.parametrize("fn", [lambda x: x + 2, lambda x: x**2]) -@pytest.mark.parametrize("shallow", [True, False]) -def test_map_nest_at_indices(nest, indices, fn, shallow): - nest_copy = copy.deepcopy(nest) - result = ivy.map_nest_at_indices(nest, indices, fn, shallow) - - def mnais(n, idxs, vs): - [_mnai(n, index, vs) for index in idxs] - - mnais(nest_copy, indices, fn) - - assert result == nest_copy - if shallow: - assert nest == nest_copy - else: - assert nest != nest_copy - - -# nested_argwhere +# all_nested_indices @pytest.mark.parametrize( "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] ) -def test_nested_argwhere(nest): - indices = ivy.nested_argwhere(nest, lambda x: x < 5) +def test_all_nested_indices(nest): + indices = ivy.all_nested_indices(nest) assert indices[0] == ["a", 0, 0] assert indices[1] == ["a", 1, 0] assert indices[2] == ["b", "c", 0, 0, 0] assert indices[3] == ["b", "c", 0, 1, 0] + assert indices[4] == ["b", "c", 1, 0, 0] + assert indices[5] == ["b", "c", 1, 1, 0] -# nested_argwhere_w_nest_checks -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] -) -def test_nested_argwhere_w_nest_checks(nest): - indices = ivy.nested_argwhere( - nest, lambda x: isinstance(x, list) or (isinstance(x, int) and x < 5), True - ) - assert indices[0] == ["a", 0, 0] 
- assert indices[1] == ["a", 0] - assert indices[2] == ["a", 1, 0] - assert indices[3] == ["a", 1] - assert indices[4] == ["a"] - assert indices[5] == ["b", "c", 0, 0, 0] - assert indices[6] == ["b", "c", 0, 0] - assert indices[7] == ["b", "c", 0, 1, 0] - assert indices[8] == ["b", "c", 0, 1] - assert indices[9] == ["b", "c", 0] - assert indices[10] == ["b", "c", 1, 0] - assert indices[11] == ["b", "c", 1, 1] - assert indices[12] == ["b", "c", 1] - assert indices[13] == ["b", "c"] - - -# nested_argwhere_w_extra_nest_types -def test_nested_argwhere_w_extra_nest_types(): +# all_nested_indices_w_extra_nest_types +def test_all_nested_indices_w_extra_nest_types(): nest = {"a": ivy.array([[0], [1]]), "b": {"c": ivy.array([[[2], [4]], [[6], [8]]])}} - indices = ivy.nested_argwhere(nest, lambda x: x < 5, extra_nest_types=ivy.Array) - assert indices[0] == ["a", 0, 0] - assert indices[1] == ["a", 1, 0] - assert indices[2] == ["b", "c", 0, 0, 0] - assert indices[3] == ["b", "c", 0, 1, 0] - - -# all_nested_indices -@pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] -) -def test_all_nested_indices(nest): - indices = ivy.all_nested_indices(nest) + indices = ivy.all_nested_indices(nest, extra_nest_types=ivy.Array) assert indices[0] == ["a", 0, 0] assert indices[1] == ["a", 1, 0] assert indices[2] == ["b", "c", 0, 0, 0] @@ -265,18 +107,6 @@ def test_all_nested_indices_w_nest_checks(nest): assert indices[16] == ["b"] -# all_nested_indices_w_extra_nest_types -def test_all_nested_indices_w_extra_nest_types(): - nest = {"a": ivy.array([[0], [1]]), "b": {"c": ivy.array([[[2], [4]], [[6], [8]]])}} - indices = ivy.all_nested_indices(nest, extra_nest_types=ivy.Array) - assert indices[0] == ["a", 0, 0] - assert indices[1] == ["a", 1, 0] - assert indices[2] == ["b", "c", 0, 0, 0] - assert indices[3] == ["b", "c", 0, 1, 0] - assert indices[4] == ["b", "c", 1, 0, 0] - assert indices[5] == ["b", "c", 1, 1, 0] - - # copy_nest def test_copy_nest(): nest = { @@ -329,70 +159,56 @@ def test_copy_nest_w_extra_nest_types(): assert nest["b"]["c"][1][1] is not nest_copy["b"]["c"][1][0] -# nested_multi_map -@pytest.mark.parametrize("func", [lambda x, _: x[0] - x[1]]) +# duplicate_array_index_chains +@pytest.mark.parametrize("x", [[-1.0]]) +@pytest.mark.parametrize("y", [[1.0]]) @pytest.mark.parametrize( - "nests", - [ - [ - np.asarray([-1.82, 1.25, -2.91, 0.109, 0.76, 1.7, 0.231, 4.45]), - np.asarray([-3.98, -3.86, 7.94, 2.08, 9.3, 2.35, 9.37, 1.7]), - ] - ], + "nest", [[{"a": None, "b": {"c": None, "d": None}}, [None, None]]] ) -def test_nested_multi_map(func, nests): - nests = ivy.nested_map( - nests, - lambda x: ivy.array(x) if isinstance(x, np.ndarray) else x, - include_derived=True, - shallow=False, - ) - # without index_chains specification - nested_multi_map_res = ivy.nested_multi_map(func, nests) +def test_duplicate_array_index_chains(nest, x, y): + x = ivy.array(x) + y = ivy.array(y) + nest[0]["a"] = nest[0]["b"]["d"] = nest[1][0] = x + nest[0]["b"]["c"] = nest[1][1] = y + duplicate_index_chains = ivy.duplicate_array_index_chains(nest) + assert duplicate_index_chains[0] == [[0, "a"], [0, "b", "d"], [1, 0]] + assert duplicate_index_chains[1] == [[0, "b", "c"], [1, 1]] - # modify this to test for other functions - nests_without_multi_map_res = nests[0] - nests[1] - assert ivy.all_equal(nested_multi_map_res, nests_without_multi_map_res) +# Tests # +# ------# -# prune_nest_at_index +# index_nest @pytest.mark.parametrize( - "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], 
[8]]]}}] + "nest", [{"a": [[0], [1]], "b": {"c": (((2,), (4,)), ((6,), (8,)))}}] ) @pytest.mark.parametrize( "index", [("a", 0, 0), ("a", 1, 0), ("b", "c", 0), ("b", "c", 1, 0)] ) -def test_prune_nest_at_index(nest, index): - nest_copy = copy.deepcopy(nest) +def test_index_nest(nest, index): + ret = ivy.index_nest(nest, index) + true_ret = nest + for i in index: + true_ret = true_ret[i] + assert ret == true_ret - # handling cases where there is nothing to prune - try: - ivy.prune_nest_at_index(nest, index) - _pnai(nest_copy, index) - except Exception: - warnings.warn("Nothing to delete.") - - assert nest == nest_copy - -# prune_nest_at_indices +# insert_into_nest_at_indices @pytest.mark.parametrize( "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] ) -@pytest.mark.parametrize("indices", [(("a", 0), ("a", 0, 0), ("a", 1), ("b", "c", 0))]) -def test_prune_nest_at_indices(nest, indices): - nest_copy = copy.deepcopy(nest) - ivy.prune_nest_at_indices(nest_copy, indices) - print(nest_copy) - for idx in indices: - try: - ele_org = ivy.index_nest(nest, idx) - ele_new = ivy.index_nest(nest_copy, idx) - except ivy.utils.exceptions.IvyIndexError: - return - else: - assert ele_org != ele_new +@pytest.mark.parametrize("indices", [(("a", 0, 0), ("b", "c", 1, 0))]) +@pytest.mark.parametrize("values", [(1, 2)]) +def test_insert_into_nest_at_indices(nest, indices, values): + ivy.insert_into_nest_at_indices(nest, indices, values) + + def indices_nest(nest, indices): + ret = tuple(ivy.index_nest(nest, index) for index in indices) + + return ret + + assert indices_nest(nest, indices) == values # insert_into_nest_at_index @@ -407,49 +223,68 @@ def test_insert_into_nest_index(nest, index, value): assert ivy.index_nest(nest, index) == value -# insert_into_nest_at_indices +# map_nest_at_index @pytest.mark.parametrize( "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] ) -@pytest.mark.parametrize("indices", [(("a", 0, 0), ("b", "c", 1, 0))]) -@pytest.mark.parametrize("values", [(1, 2)]) -def test_insert_into_nest_at_indices(nest, indices, values): - ivy.insert_into_nest_at_indices(nest, indices, values) +@pytest.mark.parametrize( + "index", [("a", 0, 0), ("a", 1, 0), ("b", "c", 0, 0, 0), ("b", "c", 1, 0, 0)] +) +@pytest.mark.parametrize("fn", [lambda x: x + 2]) +@pytest.mark.parametrize("shallow", [True, False]) +def test_map_nest_at_index(nest, index, fn, shallow): + nest_copy = copy.deepcopy(nest) + result = ivy.map_nest_at_index(nest, index, fn, shallow=shallow) + _mnai(nest_copy, index, fn) - def indices_nest(nest, indices): - ret = tuple(ivy.index_nest(nest, index) for index in indices) + assert result == nest_copy + if shallow: + assert nest == nest_copy + else: + assert nest != nest_copy - return ret - assert indices_nest(nest, indices) == values +# map_nest_at_indices +@pytest.mark.parametrize( + "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] +) +@pytest.mark.parametrize( + "indices", [(("a", 0, 0), ("a", 1, 0)), (("b", "c", 0, 0, 0), ("b", "c", 1, 0, 0))] +) +@pytest.mark.parametrize("fn", [lambda x: x + 2, lambda x: x**2]) +@pytest.mark.parametrize("shallow", [True, False]) +def test_map_nest_at_indices(nest, indices, fn, shallow): + nest_copy = copy.deepcopy(nest) + result = ivy.map_nest_at_indices(nest, indices, fn, shallow) + def mnais(n, idxs, vs): + [_mnai(n, index, vs) for index in idxs] -# nested_map -@pytest.mark.parametrize("x", [{"a": [[0, 1], [2, 3]], "b": {"c": [[0], [1]]}}]) -@pytest.mark.parametrize("fn", [lambda x: x**2]) 
-@pytest.mark.parametrize("shallow", [True, False]) -def test_nested_map(x, fn, shallow): - x_copy = copy.deepcopy(x) - result = ivy.nested_map(x, fn, shallow=shallow) - map_nested_dicts(x_copy, fn) + mnais(nest_copy, indices, fn) - assert result == x_copy + assert result == nest_copy if shallow: - assert x == x_copy + assert nest == nest_copy else: - assert x != x_copy + assert nest != nest_copy -# nested_map_w_extra_nest_types -@pytest.mark.parametrize("fn", [lambda x: x**2]) -def test_nested_map_w_extra_nest_types(fn): - x = {"a": ivy.array([[0, 1], [2, 3]]), "b": {"c": ivy.array([[0], [1]])}} - x_copy = copy.deepcopy(x) - x = ivy.nested_map(x, fn, extra_nest_types=ivy.Array) - map_nested_dicts(x_copy, fn) - - assert ivy.all(x_copy["a"] == x["a"]) - assert ivy.all(x_copy["b"]["c"] == x["b"]["c"]) +# multi_index_nest +@pytest.mark.parametrize( + "nest", [{"a": [[0], [1]], "b": {"c": (((2,), (4,)), ((6,), (8,)))}}] +) +@pytest.mark.parametrize( + "multi_indices", [(("a", 0, 0), ("a", 1, 0)), (("b", "c", 0), ("b", "c", 1, 0))] +) +def test_multi_index_nest(nest, multi_indices): + rets = ivy.multi_index_nest(nest, multi_indices) + true_rets = list() + for indices in multi_indices: + true_ret = nest + for i in indices: + true_ret = true_ret[i] + true_rets.append(true_ret) + assert rets == true_rets # nested_any @@ -492,20 +327,105 @@ def is_true_any(ob): assert x_copy_bool == x_bool -# duplicate_array_index_chains -@pytest.mark.parametrize("x", [[-1.0]]) -@pytest.mark.parametrize("y", [[1.0]]) +# nested_argwhere @pytest.mark.parametrize( - "nest", [[{"a": None, "b": {"c": None, "d": None}}, [None, None]]] + "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] ) -def test_duplicate_array_index_chains(nest, x, y): - x = ivy.array(x) - y = ivy.array(y) - nest[0]["a"] = nest[0]["b"]["d"] = nest[1][0] = x - nest[0]["b"]["c"] = nest[1][1] = y - duplicate_index_chains = ivy.duplicate_array_index_chains(nest) - assert duplicate_index_chains[0] == [[0, "a"], [0, "b", "d"], [1, 0]] - assert duplicate_index_chains[1] == [[0, "b", "c"], [1, 1]] +def test_nested_argwhere(nest): + indices = ivy.nested_argwhere(nest, lambda x: x < 5) + assert indices[0] == ["a", 0, 0] + assert indices[1] == ["a", 1, 0] + assert indices[2] == ["b", "c", 0, 0, 0] + assert indices[3] == ["b", "c", 0, 1, 0] + + +# nested_argwhere_w_extra_nest_types +def test_nested_argwhere_w_extra_nest_types(): + nest = {"a": ivy.array([[0], [1]]), "b": {"c": ivy.array([[[2], [4]], [[6], [8]]])}} + indices = ivy.nested_argwhere(nest, lambda x: x < 5, extra_nest_types=ivy.Array) + assert indices[0] == ["a", 0, 0] + assert indices[1] == ["a", 1, 0] + assert indices[2] == ["b", "c", 0, 0, 0] + assert indices[3] == ["b", "c", 0, 1, 0] + + +# nested_argwhere_w_nest_checks +@pytest.mark.parametrize( + "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] +) +def test_nested_argwhere_w_nest_checks(nest): + indices = ivy.nested_argwhere( + nest, lambda x: isinstance(x, list) or (isinstance(x, int) and x < 5), True + ) + assert indices[0] == ["a", 0, 0] + assert indices[1] == ["a", 0] + assert indices[2] == ["a", 1, 0] + assert indices[3] == ["a", 1] + assert indices[4] == ["a"] + assert indices[5] == ["b", "c", 0, 0, 0] + assert indices[6] == ["b", "c", 0, 0] + assert indices[7] == ["b", "c", 0, 1, 0] + assert indices[8] == ["b", "c", 0, 1] + assert indices[9] == ["b", "c", 0] + assert indices[10] == ["b", "c", 1, 0] + assert indices[11] == ["b", "c", 1, 1] + assert indices[12] == ["b", "c", 1] + assert indices[13] == ["b", 
"c"] + + +# nested_map +@pytest.mark.parametrize("x", [{"a": [[0, 1], [2, 3]], "b": {"c": [[0], [1]]}}]) +@pytest.mark.parametrize("fn", [lambda x: x**2]) +@pytest.mark.parametrize("shallow", [True, False]) +def test_nested_map(x, fn, shallow): + x_copy = copy.deepcopy(x) + result = ivy.nested_map(x, fn, shallow=shallow) + map_nested_dicts(x_copy, fn) + + assert result == x_copy + if shallow: + assert x == x_copy + else: + assert x != x_copy + + +# nested_map_w_extra_nest_types +@pytest.mark.parametrize("fn", [lambda x: x**2]) +def test_nested_map_w_extra_nest_types(fn): + x = {"a": ivy.array([[0, 1], [2, 3]]), "b": {"c": ivy.array([[0], [1]])}} + x_copy = copy.deepcopy(x) + x = ivy.nested_map(x, fn, extra_nest_types=ivy.Array) + map_nested_dicts(x_copy, fn) + + assert ivy.all(x_copy["a"] == x["a"]) + assert ivy.all(x_copy["b"]["c"] == x["b"]["c"]) + + +# nested_multi_map +@pytest.mark.parametrize("func", [lambda x, _: x[0] - x[1]]) +@pytest.mark.parametrize( + "nests", + [ + [ + np.asarray([-1.82, 1.25, -2.91, 0.109, 0.76, 1.7, 0.231, 4.45]), + np.asarray([-3.98, -3.86, 7.94, 2.08, 9.3, 2.35, 9.37, 1.7]), + ] + ], +) +def test_nested_multi_map(func, nests): + nests = ivy.nested_map( + nests, + lambda x: ivy.array(x) if isinstance(x, np.ndarray) else x, + include_derived=True, + shallow=False, + ) + # without index_chains specification + nested_multi_map_res = ivy.nested_multi_map(func, nests) + + # modify this to test for other functions + nests_without_multi_map_res = nests[0] - nests[1] + + assert ivy.all_equal(nested_multi_map_res, nests_without_multi_map_res) # prune_empty @@ -513,3 +433,88 @@ def test_duplicate_array_index_chains(nest, x, y): def test_prune_empty(nest): ret = ivy.prune_empty(ivy.copy_nest(nest)) assert ret == {"b": {"c": [1]}} + + +# prune_nest_at_index +@pytest.mark.parametrize( + "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] +) +@pytest.mark.parametrize( + "index", [("a", 0, 0), ("a", 1, 0), ("b", "c", 0), ("b", "c", 1, 0)] +) +def test_prune_nest_at_index(nest, index): + nest_copy = copy.deepcopy(nest) + + # handling cases where there is nothing to prune + try: + ivy.prune_nest_at_index(nest, index) + _pnai(nest_copy, index) + except Exception: + warnings.warn("Nothing to delete.") + + assert nest == nest_copy + + +# prune_nest_at_indices +@pytest.mark.parametrize( + "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] +) +@pytest.mark.parametrize("indices", [(("a", 0), ("a", 0, 0), ("a", 1), ("b", "c", 0))]) +def test_prune_nest_at_indices(nest, indices): + nest_copy = copy.deepcopy(nest) + ivy.prune_nest_at_indices(nest_copy, indices) + print(nest_copy) + for idx in indices: + try: + ele_org = ivy.index_nest(nest, idx) + ele_new = ivy.index_nest(nest_copy, idx) + except ivy.utils.exceptions.IvyIndexError: + return + else: + assert ele_org != ele_new + + +# set_nest_at_index +@pytest.mark.parametrize( + "nest", [{"a": [[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] +) +@pytest.mark.parametrize( + "index", [("a", 0, 0), ("a", 1, 0), ("b", "c", 0), ("b", "c", 1, 0)] +) +@pytest.mark.parametrize("value", [-1]) +@pytest.mark.parametrize("shallow", [True, False]) +def test_set_nest_at_index(nest, index, value, shallow): + nest_copy = copy.deepcopy(nest) + result = ivy.set_nest_at_index(nest, index, value, shallow=shallow) + _snai(nest_copy, index, value) + + assert result == nest_copy + if shallow: + assert nest == nest_copy + else: + assert nest != nest_copy + + +# set_nest_at_indices +@pytest.mark.parametrize( + "nest", [{"a": 
[[0], [1]], "b": {"c": [[[2], [4]], [[6], [8]]]}}] +) +@pytest.mark.parametrize( + "indices", [(("a", 0, 0), ("a", 1, 0)), (("b", "c", 0), ("b", "c", 1, 0))] +) +@pytest.mark.parametrize("values", [(1, 2)]) +@pytest.mark.parametrize("shallow", [False, True]) +def test_set_nest_at_indices(nest, indices, values, shallow): + nest_copy = copy.deepcopy(nest) + result = ivy.set_nest_at_indices(nest, indices, values, shallow=shallow) + + def snais(n, idxs, vs): + [_snai(n, index, value) for index, value in zip(idxs, vs)] + + snais(nest_copy, indices, values) + + assert result == nest_copy + if shallow: + assert nest == nest_copy + else: + assert nest != nest_copy diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_random.py b/ivy_tests/test_ivy/test_functional/test_core/test_random.py index 4bc107ad3af74..13a6e7652448f 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_random.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_random.py @@ -9,132 +9,32 @@ from ivy_tests.test_ivy.helpers import handle_test, BackendHandler -# random_uniform -@handle_test( - fn_tree="functional.ivy.random_uniform", - dtype_and_low=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-1000, - max_value=100, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - ), - dtype_and_high=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=101, - max_value=1000, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - ), - dtype=helpers.get_dtypes("float", full=False), - seed=helpers.ints(min_value=0, max_value=100), - test_gradients=st.just(False), -) -def test_random_uniform( - *, - dtype_and_low, - dtype_and_high, - dtype, - seed, - test_flags, - backend_fw, - fn_name, - on_device -): - low_dtype, low = dtype_and_low - high_dtype, high = dtype_and_high +# --- Helpers --- # +# --------------- # - def call(): - return helpers.test_function( - input_dtypes=low_dtype + high_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - test_values=False, - low=low[0], - high=high[0], - shape=None, + +@st.composite +def _gen_randint_data(draw): + dtype = draw(helpers.get_dtypes("signed_integer", full=False)) + dim1 = draw(helpers.ints(min_value=1, max_value=5)) + dim2 = draw(helpers.ints(min_value=2, max_value=8)) + low = draw( + helpers.array_values( dtype=dtype[0], - seed=seed, + shape=(dim1, dim2), + min_value=-100, + max_value=25, ) - - ret, ret_gt = call() - if seed: - ret1, ret_gt2 = call() - assert ivy.any(ret == ret1) - ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np( - ret=ret_gt, backend=test_flags.ground_truth_backend ) - - for u, v in zip(ret, ret_gt): - assert u.dtype == v.dtype - - -# random_normal -@handle_test( - fn_tree="functional.ivy.random_normal", - dtype_and_mean=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-1000, - max_value=1000, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - ), - dtype_and_std=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=1000, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - ), - dtype=helpers.get_dtypes("float", full=False), - seed=helpers.ints(min_value=0, max_value=100), - test_gradients=st.just(False), -) -def test_random_normal( - dtype_and_mean, - dtype_and_std, - dtype, - seed, - test_flags, - backend_fw, - fn_name, - on_device, -): - mean_dtype, mean = dtype_and_mean - std_dtype, 
std = dtype_and_std - - def call(): - return helpers.test_function( - input_dtypes=mean_dtype + std_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - test_values=False, - mean=mean[0], - std=std[0], - shape=None, + high = draw( + helpers.array_values( dtype=dtype[0], - seed=seed, + shape=(dim1, dim2), + min_value=26, + max_value=100, ) - - ret, ret_gt = call() - if seed: - ret1, ret_gt1 = call() - assert ivy.any(ret == ret1) - ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np( - ret=ret_gt, backend=test_flags.ground_truth_backend ) - for u, v in zip(ret, ret_gt): - assert u.dtype == v.dtype + return dtype, low, high @st.composite @@ -161,6 +61,10 @@ def _pop_size_num_samples_replace_n_probs(draw): return prob_dtype, batch_size, population_size, num_samples, replace, probs +# --- Main --- # +# ------------ # + + # multinomial @handle_test( fn_tree="functional.ivy.multinomial", @@ -208,30 +112,6 @@ def call(): assert u.shape == v.shape -@st.composite -def _gen_randint_data(draw): - dtype = draw(helpers.get_dtypes("signed_integer", full=False)) - dim1 = draw(helpers.ints(min_value=1, max_value=5)) - dim2 = draw(helpers.ints(min_value=2, max_value=8)) - low = draw( - helpers.array_values( - dtype=dtype[0], - shape=(dim1, dim2), - min_value=-100, - max_value=25, - ) - ) - high = draw( - helpers.array_values( - dtype=dtype[0], - shape=(dim1, dim2), - min_value=26, - max_value=100, - ) - ) - return dtype, low, high - - # randint @handle_test( fn_tree="functional.ivy.randint", @@ -270,6 +150,134 @@ def call(): assert ivy.all(v >= low) and ivy.all(v < high) +# random_normal +@handle_test( + fn_tree="functional.ivy.random_normal", + dtype_and_mean=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1000, + max_value=1000, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + ), + dtype_and_std=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=1000, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + ), + dtype=helpers.get_dtypes("float", full=False), + seed=helpers.ints(min_value=0, max_value=100), + test_gradients=st.just(False), +) +def test_random_normal( + dtype_and_mean, + dtype_and_std, + dtype, + seed, + test_flags, + backend_fw, + fn_name, + on_device, +): + mean_dtype, mean = dtype_and_mean + std_dtype, std = dtype_and_std + + def call(): + return helpers.test_function( + input_dtypes=mean_dtype + std_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + test_values=False, + mean=mean[0], + std=std[0], + shape=None, + dtype=dtype[0], + seed=seed, + ) + + ret, ret_gt = call() + if seed: + ret1, ret_gt1 = call() + assert ivy.any(ret == ret1) + ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np( + ret=ret_gt, backend=test_flags.ground_truth_backend + ) + for u, v in zip(ret, ret_gt): + assert u.dtype == v.dtype + + +# random_uniform +@handle_test( + fn_tree="functional.ivy.random_uniform", + dtype_and_low=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1000, + max_value=100, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + ), + dtype_and_high=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=101, + max_value=1000, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + ), + dtype=helpers.get_dtypes("float", full=False), + 
seed=helpers.ints(min_value=0, max_value=100), + test_gradients=st.just(False), +) +def test_random_uniform( + *, + dtype_and_low, + dtype_and_high, + dtype, + seed, + test_flags, + backend_fw, + fn_name, + on_device +): + low_dtype, low = dtype_and_low + high_dtype, high = dtype_and_high + + def call(): + return helpers.test_function( + input_dtypes=low_dtype + high_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + test_values=False, + low=low[0], + high=high[0], + shape=None, + dtype=dtype[0], + seed=seed, + ) + + ret, ret_gt = call() + if seed: + ret1, ret_gt2 = call() + assert ivy.any(ret == ret1) + ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np( + ret=ret_gt, backend=test_flags.ground_truth_backend + ) + + for u, v in zip(ret, ret_gt): + assert u.dtype == v.dtype + + # seed @handle_test( fn_tree="functional.ivy.seed", diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_searching.py b/ivy_tests/test_ivy/test_functional/test_core/test_searching.py index 553a97349e43e..4a1e87c7f32f0 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_searching.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_searching.py @@ -8,6 +8,28 @@ from ivy_tests.test_ivy.helpers import handle_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _broadcastable_trio(draw): + shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) + cond = draw(helpers.array_values(dtype="bool", shape=shape)) + dtypes, xs = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shape=shape, + shared_dtype=True, + large_abs_safety_factor=16, + small_abs_safety_factor=16, + safety_factor_scale="log", + ) + ) + return cond, xs, dtypes + + # Helpers # ############ @@ -29,22 +51,8 @@ def _dtype_x_limited_axis(draw, *, allow_none=False): return dtype, x, axis -@st.composite -def _broadcastable_trio(draw): - shape = draw(helpers.get_shape(min_num_dims=1, min_dim_size=1)) - cond = draw(helpers.array_values(dtype="bool", shape=shape)) - dtypes, xs = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - shape=shape, - shared_dtype=True, - large_abs_safety_factor=16, - small_abs_safety_factor=16, - safety_factor_scale="log", - ) - ) - return cond, xs, dtypes +# --- Main --- # +# ------------ # # Functions # @@ -117,6 +125,24 @@ def test_argmin( ) +# argwhere +@handle_test( + fn_tree="functional.ivy.argwhere", + dtype_and_x=helpers.dtype_and_values(available_dtypes=("bool",)), + ground_truth_backend="torch", +) +def test_argwhere(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + ) + + @handle_test( fn_tree="functional.ivy.nonzero", dtype_and_x=helpers.dtype_and_values( @@ -171,21 +197,3 @@ def test_where(*, broadcastables, test_flags, backend_fw, fn_name, on_device): x1=xs[0], x2=xs[1], ) - - -# argwhere -@handle_test( - fn_tree="functional.ivy.argwhere", - dtype_and_x=helpers.dtype_and_values(available_dtypes=("bool",)), - ground_truth_backend="torch", -) -def test_argwhere(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, 
- on_device=on_device, - x=x[0], - ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_set.py b/ivy_tests/test_ivy/test_functional/test_core/test_set.py index b0c5a8f3d458e..1f80d4ce7b7a2 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_set.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_set.py @@ -8,30 +8,6 @@ from ivy_tests.test_ivy.helpers import handle_test -# unique_values -@handle_test( - fn_tree="functional.ivy.unique_values", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=1, - ), - test_gradients=st.just(False), -) -def test_unique_values(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_and_x - assume(not np.any(np.isclose(x, 0.0))) - - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - ) - - # unique_all @handle_test( fn_tree="functional.ivy.unique_all", @@ -114,3 +90,27 @@ def test_unique_inverse(*, dtype_and_x, test_flags, backend_fw, fn_name, on_devi fn_name=fn_name, x=x[0], ) + + +# unique_values +@handle_test( + fn_tree="functional.ivy.unique_values", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=1, + ), + test_gradients=st.just(False), +) +def test_unique_values(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_and_x + assume(not np.any(np.isclose(x, 0.0))) + + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_sorting.py b/ivy_tests/test_ivy/test_functional/test_core/test_sorting.py index eca19608ed025..62112a704b0a9 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_sorting.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_sorting.py @@ -9,40 +9,60 @@ from ivy_tests.test_ivy.helpers import handle_test -# argsort -@handle_test( - fn_tree="functional.ivy.argsort", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - min_dim_size=1, - min_axis=-1, - max_axis=0, - ), - descending=st.booleans(), - stable=st.booleans(), - test_gradients=st.just(False), -) -def test_argsort( - *, dtype_x_axis, descending, stable, test_flags, backend_fw, fn_name, on_device -): - dtype, x, axis = dtype_x_axis - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - axis=axis, - descending=descending, - stable=stable, +# --- Helpers --- # +# --------------- # + + +@st.composite +def _searchsorted_case1(draw): + # 1-D for x, N-D for v + dtype_x, x = draw( + helpers.dtype_and_values( + dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), + shape=(draw(st.integers(min_value=1, max_value=5)),), + ) ) + dtype_v, v = draw( + helpers.dtype_and_values( + dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), + min_num_dims=1, + ) + ) + return dtype_x + dtype_v, x + v -# sort +@st.composite +def _searchsorted_case2(draw): + # N-D for x, N-D for v + arb_leading_dims = draw( + helpers.get_shape( + min_num_dims=1, + ) + ) + nx = draw(st.integers(min_value=1, max_value=5)) + nv = draw(st.integers(min_value=1, max_value=5)) + dtype_x, x = draw( + helpers.dtype_and_values( + 
dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), + shape=arb_leading_dims + (nx,), + ) + ) + dtype_v, v = draw( + helpers.dtype_and_values( + dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), + shape=arb_leading_dims + (nv,), + ) + ) + return dtype_x + dtype_v, x + v + + +# --- Main --- # +# ------------ # + + +# argsort @handle_test( - fn_tree="functional.ivy.sort", + fn_tree="functional.ivy.argsort", dtype_x_axis=helpers.dtype_values_axis( available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, @@ -54,7 +74,7 @@ def test_argsort( stable=st.booleans(), test_gradients=st.just(False), ) -def test_sort( +def test_argsort( *, dtype_x_axis, descending, stable, test_flags, backend_fw, fn_name, on_device ): dtype, x, axis = dtype_x_axis @@ -97,49 +117,6 @@ def test_msort(dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -@st.composite -def _searchsorted_case1(draw): - # 1-D for x, N-D for v - dtype_x, x = draw( - helpers.dtype_and_values( - dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), - shape=(draw(st.integers(min_value=1, max_value=5)),), - ) - ) - dtype_v, v = draw( - helpers.dtype_and_values( - dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), - min_num_dims=1, - ) - ) - return dtype_x + dtype_v, x + v - - -@st.composite -def _searchsorted_case2(draw): - # N-D for x, N-D for v - arb_leading_dims = draw( - helpers.get_shape( - min_num_dims=1, - ) - ) - nx = draw(st.integers(min_value=1, max_value=5)) - nv = draw(st.integers(min_value=1, max_value=5)) - dtype_x, x = draw( - helpers.dtype_and_values( - dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), - shape=arb_leading_dims + (nx,), - ) - ) - dtype_v, v = draw( - helpers.dtype_and_values( - dtype=draw(helpers.get_dtypes("numeric", full=False, key="searchsorted")), - shape=arb_leading_dims + (nv,), - ) - ) - return dtype_x + dtype_v, x + v - - @handle_test( fn_tree="functional.ivy.searchsorted", data=st.data(), @@ -181,3 +158,34 @@ def test_searchsorted( sorter=sorter, ret_dtype=ret_dtype[0], ) + + +# sort +@handle_test( + fn_tree="functional.ivy.sort", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + min_dim_size=1, + min_axis=-1, + max_axis=0, + ), + descending=st.booleans(), + stable=st.booleans(), + test_gradients=st.just(False), +) +def test_sort( + *, dtype_x_axis, descending, stable, test_flags, backend_fw, fn_name, on_device +): + dtype, x, axis = dtype_x_axis + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + axis=axis, + descending=descending, + stable=stable, + ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py b/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py index 9f5f7314d7a93..43e74bf3c51ca 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py @@ -8,6 +8,33 @@ from ivy_tests.test_ivy.helpers import handle_test +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_castable_dtype(draw, min_value=None, max_value=None): + available_dtypes = helpers.get_dtypes("numeric") + shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) + dtype, values = draw( + helpers.dtype_and_values( + available_dtypes=available_dtypes, + num_arrays=1, + 
large_abs_safety_factor=6, + small_abs_safety_factor=24, + safety_factor_scale="log", + shape=shape, + min_value=min_value, + max_value=max_value, + ) + ) + axis = draw(helpers.get_axis(shape=shape, force_int=True)) + dtype1, values, dtype2 = draw( + helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0]) + ) + return dtype1, [values], axis, dtype2 + + @st.composite def _statistical_dtype_values(draw, *, function, min_value=None, max_value=None): large_abs_safety_factor = 2 @@ -53,46 +80,129 @@ def _statistical_dtype_values(draw, *, function, min_value=None, max_value=None) return dtype, values, axis -@st.composite -def _get_castable_dtype(draw, min_value=None, max_value=None): - available_dtypes = helpers.get_dtypes("numeric") - shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) - dtype, values = draw( - helpers.dtype_and_values( - available_dtypes=available_dtypes, - num_arrays=1, - large_abs_safety_factor=6, - small_abs_safety_factor=24, - safety_factor_scale="log", - shape=shape, - min_value=min_value, - max_value=max_value, - ) - ) - axis = draw(helpers.get_axis(shape=shape, force_int=True)) - dtype1, values, dtype2 = draw( - helpers.get_castable_dtype(draw(available_dtypes), dtype[0], values[0]) +# --- Main --- # +# ------------ # + + +# cumprod +@handle_test( + fn_tree="functional.ivy.cumprod", + dtype_x_axis_castable=_get_castable_dtype(), + exclusive=st.booleans(), + reverse=st.booleans(), +) +def test_cumprod( + *, + dtype_x_axis_castable, + exclusive, + reverse, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtype, x, axis, castable_dtype = dtype_x_axis_castable + # ToDo: set as_variable_flags as the parameter generated by test_cumprod once + # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 + if "torch" in backend_fw: + assume(not test_flags.as_variable[0]) + assume(not test_flags.test_gradients) + # gradient tests have been disabled for cumprod as the gradients computed by the + # backends are inconsistent with tensorflow returning a zero gradient when the + # product is zero (discrete optimization), and torch and jax returning a non-zero + # gradient based on the value used to compute the product even if it's zero + # ToDo: Revisit this later + if np.abs(np.min(np.abs(x[0])) - 0) < 1e-4: + assume(not test_flags.test_gradients) + helpers.test_function( + input_dtypes=[input_dtype], + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + axis=axis, + exclusive=exclusive, + reverse=reverse, + dtype=castable_dtype, + rtol_=1e-1, + atol_=1e-1, ) - return dtype1, [values], axis, dtype2 -# min @handle_test( - fn_tree="functional.ivy.min", - dtype_and_x=_statistical_dtype_values(function="min"), - keep_dims=st.booleans(), + fn_tree="functional.ivy.cumsum", + dtype_x_axis_castable=_get_castable_dtype(), + exclusive=st.booleans(), + reverse=st.booleans(), ) -def test_min(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device): - input_dtype, x, axis = dtype_and_x +def test_cumsum( + *, + dtype_x_axis_castable, + exclusive, + reverse, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtype, x, axis, castable_dtype = dtype_x_axis_castable + # ToDo: set as_variable_flags as the parameter generated by test_cumsum once + # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 + if "torch" in backend_fw: + assume(not test_flags.as_variable[0]) + assume(not test_flags.test_gradients) 
helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=[input_dtype], test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, x=x[0], axis=axis, - keepdims=keep_dims, + exclusive=exclusive, + reverse=reverse, + dtype=castable_dtype, + rtol_=1e-1, + atol_=1e-1, + ) + + +# TODO: add more general tests and fix get instance method testing passing +# einsum +@handle_test( + fn_tree="functional.ivy.einsum", + eq_n_op_n_shp=helpers.einsum_helper(), + test_instance_method=st.just(False), + dtype=helpers.get_dtypes("numeric", full=False), +) +def test_einsum( + *, + eq_n_op_n_shp, + dtype, + test_flags, + backend_fw, + fn_name, + on_device, +): + eq, operands, dtypes = eq_n_op_n_shp + kw = {} + # x_dtype = np.dtype(dtype[0]) + for i, x_ in enumerate(operands): + dtype = dtypes[i][0] + kw["x{}".format(i)] = np.array(x_).astype(dtype) + # len(operands) + 1 because of the equation + test_flags.num_positional_args = len(operands) + 1 + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + equation=eq, + **kw, + rtol_=1e-2, + atol_=1e-2, ) @@ -139,25 +249,22 @@ def test_mean(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_dev ) -# var +# min @handle_test( - fn_tree="functional.ivy.var", - dtype_and_x=_statistical_dtype_values(function="var"), + fn_tree="functional.ivy.min", + dtype_and_x=_statistical_dtype_values(function="min"), keep_dims=st.booleans(), ) -def test_var(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device): - input_dtype, x, axis, correction = dtype_and_x +def test_min(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device): + input_dtype, x, axis = dtype_and_x helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, x=x[0], axis=axis, - correction=correction, keepdims=keep_dims, ) @@ -192,36 +299,6 @@ def test_prod( ) -# sum -@handle_test( - fn_tree="functional.ivy.sum", - dtype_x_axis_castable=_get_castable_dtype(), - keep_dims=st.booleans(), -) -def test_sum( - *, dtype_x_axis_castable, keep_dims, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x, axis, castable_dtype = dtype_x_axis_castable - # ToDo: set as_variable_flags as the parameter generated by test_sum once - # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 - if "torch" in backend_fw: - assume(not test_flags.as_variable[0]) - assume(not test_flags.test_gradients) - helpers.test_function( - input_dtypes=[input_dtype], - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-2, - x=x[0], - axis=axis, - keepdims=keep_dims, - dtype=castable_dtype, - ) - - # std @handle_test( fn_tree="functional.ivy.std", @@ -245,24 +322,17 @@ def test_std(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_devi ) +# sum @handle_test( - fn_tree="functional.ivy.cumsum", + fn_tree="functional.ivy.sum", dtype_x_axis_castable=_get_castable_dtype(), - exclusive=st.booleans(), - reverse=st.booleans(), + keep_dims=st.booleans(), ) -def test_cumsum( - *, - dtype_x_axis_castable, - exclusive, - reverse, - test_flags, - backend_fw, - fn_name, - on_device, +def test_sum( + *, dtype_x_axis_castable, keep_dims, test_flags, backend_fw, fn_name, on_device ): input_dtype, x, axis, castable_dtype = dtype_x_axis_castable - # ToDo: set as_variable_flags as 
the parameter generated by test_cumsum once + # ToDo: set as_variable_flags as the parameter generated by test_sum once # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 if "torch" in backend_fw: assume(not test_flags.as_variable[0]) @@ -273,95 +343,33 @@ def test_cumsum( backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, + rtol_=1e-1, + atol_=1e-2, x=x[0], axis=axis, - exclusive=exclusive, - reverse=reverse, + keepdims=keep_dims, dtype=castable_dtype, - rtol_=1e-1, - atol_=1e-1, ) -# cumprod +# var @handle_test( - fn_tree="functional.ivy.cumprod", - dtype_x_axis_castable=_get_castable_dtype(), - exclusive=st.booleans(), - reverse=st.booleans(), + fn_tree="functional.ivy.var", + dtype_and_x=_statistical_dtype_values(function="var"), + keep_dims=st.booleans(), ) -def test_cumprod( - *, - dtype_x_axis_castable, - exclusive, - reverse, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtype, x, axis, castable_dtype = dtype_x_axis_castable - # ToDo: set as_variable_flags as the parameter generated by test_cumprod once - # this issue is marked as completed https://github.com/pytorch/pytorch/issues/75733 - if "torch" in backend_fw: - assume(not test_flags.as_variable[0]) - assume(not test_flags.test_gradients) - # gradient tests have been disabled for cumprod as the gradients computed by the - # backends are inconsistent with tensorflow returning a zero gradient when the - # product is zero (discrete optimization), and torch and jax returning a non-zero - # gradient based on the value used to compute the product even if it's zero - # ToDo: Revisit this later - if np.abs(np.min(np.abs(x[0])) - 0) < 1e-4: - assume(not test_flags.test_gradients) +def test_var(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device): + input_dtype, x, axis, correction = dtype_and_x helpers.test_function( - input_dtypes=[input_dtype], + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x[0], - axis=axis, - exclusive=exclusive, - reverse=reverse, - dtype=castable_dtype, rtol_=1e-1, atol_=1e-1, - ) - - -# TODO: add more general tests and fix get instance method testing passing -# einsum -@handle_test( - fn_tree="functional.ivy.einsum", - eq_n_op_n_shp=helpers.einsum_helper(), - test_instance_method=st.just(False), - dtype=helpers.get_dtypes("numeric", full=False), -) -def test_einsum( - *, - eq_n_op_n_shp, - dtype, - test_flags, - backend_fw, - fn_name, - on_device, -): - eq, operands, dtypes = eq_n_op_n_shp - kw = {} - # x_dtype = np.dtype(dtype[0]) - for i, x_ in enumerate(operands): - dtype = dtypes[i][0] - kw["x{}".format(i)] = np.array(x_).astype(dtype) - # len(operands) + 1 because of the equation - test_flags.num_positional_args = len(operands) + 1 - helpers.test_function( - input_dtypes=dtypes, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - equation=eq, - **kw, - rtol_=1e-2, - atol_=1e-2, + x=x[0], + axis=axis, + correction=correction, + keepdims=keep_dims, ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py index 66378354fc5cb..e132bb235e375 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py @@ -7,126 +7,111 @@ from ivy_tests.test_ivy.helpers import handle_test -# 
vorbis_window -@handle_test( - fn_tree="functional.ivy.experimental.vorbis_window", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - max_num_dims=0, - min_value=1, - max_value=10, - ), - dtype=helpers.get_dtypes("float", full=False), - test_gradients=st.just(False), - test_instance_method=st.just(False), -) -def test_vorbis_window( - *, dtype_and_x, dtype, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - atol_=1e-02, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - window_length=int(x[0]), - dtype=dtype[0], - ) +# --- Helpers --- # +# --------------- # -# TODO: fix return precision problem when dtype=bfloat16 -# hann_window -@handle_test( - fn_tree="functional.ivy.experimental.hann_window", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - max_num_dims=0, - min_value=1, - max_value=10, - ), - periodic=st.booleans(), - dtype=helpers.get_dtypes("float", full=False), - test_gradients=st.just(False), - test_instance_method=st.just(False), -) -def test_hann_window( - *, dtype_and_x, periodic, dtype, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - atol_=0.015, - fn_name=fn_name, - on_device=on_device, - size=int(x[0]), - periodic=periodic, - dtype=dtype[0], +@st.composite +def _random_cp_data(draw): + shape = draw( + st.lists(helpers.ints(min_value=1, max_value=5), min_size=2, max_size=4) ) + rank = draw(helpers.ints(min_value=1, max_value=10)) + dtype = draw(helpers.get_dtypes("float", full=False)) + full = draw(st.booleans()) + orthogonal = draw(st.booleans()) + if (rank > min(shape)) and orthogonal: + rank = min(shape) + seed = draw(st.one_of((st.just(None), helpers.ints(min_value=0, max_value=2000)))) + normalise_factors = draw(st.booleans()) + return shape, rank, dtype[0], full, orthogonal, seed, normalise_factors -# kaiser_window -@handle_test( - fn_tree="functional.ivy.experimental.kaiser_window", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - max_num_dims=0, - min_value=1, - max_value=10, - ), - periodic=st.booleans(), - beta=st.floats(min_value=0, max_value=5), - dtype=helpers.get_dtypes("float", full=False), - test_gradients=st.just(False), - test_instance_method=st.just(False), -) -def test_kaiser_window( - *, dtype_and_x, periodic, beta, dtype, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - window_length=int(x[0]), - periodic=periodic, - beta=beta, - dtype=dtype[0], +@st.composite +def _random_tucker_data(draw): + shape = draw( + st.lists(helpers.ints(min_value=1, max_value=5), min_size=2, max_size=4) ) + rank = [] + for dim in shape: + rank.append(draw(helpers.ints(min_value=1, max_value=dim))) + dtype = draw(helpers.get_dtypes("float", full=False)) + full = draw(st.booleans()) + orthogonal = draw(st.booleans()) + seed = draw(st.one_of((st.just(None), helpers.ints(min_value=0, max_value=2000)))) + non_negative = draw(st.booleans()) + return shape, rank, dtype[0], full, orthogonal, seed, non_negative -# kaiser_bessel_derived_window +@st.composite +def valid_unsorted_segment_min_inputs(draw): + while True: + 
dtype = draw(st.sampled_from([ivy.int32, ivy.int64, ivy.float32, ivy.float64])) + segment_ids_dim = draw(st.integers(min_value=3, max_value=10)) + num_segments = draw(st.integers(min_value=2, max_value=segment_ids_dim)) + + data_dim = draw( + helpers.get_shape( + min_dim_size=segment_ids_dim, + max_dim_size=segment_ids_dim, + min_num_dims=1, + max_num_dims=4, + ) + ) + data_dim = (segment_ids_dim,) + data_dim[1:] + + data = draw( + helpers.array_values( + dtype=dtype, + shape=data_dim, + min_value=1, + max_value=10, + ) + ) + + segment_ids = draw( + helpers.array_values( + dtype=ivy.int32, + shape=(segment_ids_dim,), + min_value=0, + max_value=num_segments + 1, + ) + ) + if data.shape[0] == segment_ids.shape[0]: + if np.max(segment_ids) < num_segments: + return (dtype, ivy.int32), data, num_segments, segment_ids + + +# --- Main --- # +# ------------ # + + +# eye_like @handle_test( - fn_tree="functional.ivy.experimental.kaiser_bessel_derived_window", + fn_tree="functional.ivy.experimental.eye_like", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - max_num_dims=0, - min_value=1, - max_value=10, + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=1, + min_dim_size=1, + max_dim_size=5, ), - beta=st.floats(min_value=1, max_value=5), - dtype=helpers.get_dtypes("float", full=False), + k=helpers.ints(min_value=-10, max_value=10), test_gradients=st.just(False), - test_instance_method=st.just(False), + number_positional_args=st.just(1), ) -def test_kaiser_bessel_derived_window( - *, dtype_and_x, beta, dtype, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x = dtype_and_x +def test_eye_like(*, dtype_and_x, k, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_and_x helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=dtype, test_flags=test_flags, + on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - on_device=on_device, - window_length=int(x[0]), - beta=beta, + x=x[0], + k=k, dtype=dtype[0], + device=on_device, ) @@ -179,94 +164,38 @@ def test_hamming_window( ) +# TODO: fix return precision problem when dtype=bfloat16 +# hann_window @handle_test( - fn_tree="functional.ivy.experimental.tril_indices", - dtype_and_n=helpers.dtype_and_values( + fn_tree="functional.ivy.experimental.hann_window", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("integer"), max_num_dims=0, - num_arrays=2, - min_value=0, + min_value=1, max_value=10, ), - k=helpers.ints(min_value=-11, max_value=11), - test_with_out=st.just(False), + periodic=st.booleans(), + dtype=helpers.get_dtypes("float", full=False), test_gradients=st.just(False), test_instance_method=st.just(False), ) -def test_tril_indices(*, dtype_and_n, k, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_n +def test_hann_window( + *, dtype_and_x, periodic, dtype, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, - on_device=on_device, + atol_=0.015, fn_name=fn_name, - n_rows=int(x[0]), - n_cols=int(x[1]), - k=k, - ) - - -# eye_like -@handle_test( - fn_tree="functional.ivy.experimental.eye_like", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=1, - min_dim_size=1, - max_dim_size=5, - ), - k=helpers.ints(min_value=-10, max_value=10), - test_gradients=st.just(False), - 
number_positional_args=st.just(1), -) -def test_eye_like(*, dtype_and_x, k, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - k=k, + size=int(x[0]), + periodic=periodic, dtype=dtype[0], - device=on_device, ) -# ndenumerate -@handle_test( - fn_tree="functional.ivy.experimental.ndenumerate", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - ), -) -def test_ndenumerate(dtype_and_x): - values = dtype_and_x[1][0] - for (index1, x1), (index2, x2) in zip( - np.ndenumerate(values), ivy.ndenumerate(values) - ): - assert index1 == index2 and x1 == x2.to_numpy() - - -# ndindex -@handle_test( - fn_tree="functional.ivy.experimental.ndindex", - dtype_x_shape=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - ret_shape=True, - ), -) -def test_ndindex(dtype_x_shape): - shape = dtype_x_shape[2] - for index1, index2 in zip(np.ndindex(shape), ivy.ndindex(shape)): - assert index1 == index2 - - # indices @handle_test( fn_tree="functional.ivy.experimental.indices", @@ -301,150 +230,106 @@ def test_indices(*, shape, dtypes, sparse, test_flags, backend_fw, fn_name, on_d ) -@st.composite -def valid_unsorted_segment_min_inputs(draw): - while True: - dtype = draw(st.sampled_from([ivy.int32, ivy.int64, ivy.float32, ivy.float64])) - segment_ids_dim = draw(st.integers(min_value=3, max_value=10)) - num_segments = draw(st.integers(min_value=2, max_value=segment_ids_dim)) - - data_dim = draw( - helpers.get_shape( - min_dim_size=segment_ids_dim, - max_dim_size=segment_ids_dim, - min_num_dims=1, - max_num_dims=4, - ) - ) - data_dim = (segment_ids_dim,) + data_dim[1:] - - data = draw( - helpers.array_values( - dtype=dtype, - shape=data_dim, - min_value=1, - max_value=10, - ) - ) - - segment_ids = draw( - helpers.array_values( - dtype=ivy.int32, - shape=(segment_ids_dim,), - min_value=0, - max_value=num_segments + 1, - ) - ) - if data.shape[0] == segment_ids.shape[0]: - if np.max(segment_ids) < num_segments: - return (dtype, ivy.int32), data, num_segments, segment_ids - - -# unsorted_segment_min -@handle_test( - fn_tree="functional.ivy.experimental.unsorted_segment_min", - d_x_n_s=valid_unsorted_segment_min_inputs(), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_unsorted_segment_min( - *, - d_x_n_s, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtypes, data, num_segments, segment_ids = d_x_n_s - helpers.test_function( - input_dtypes=dtypes, - backend_to_test=backend_fw, - test_flags=test_flags, - on_device=on_device, - fn_name=fn_name, - data=data, - segment_ids=segment_ids, - num_segments=num_segments, - ) - - +# kaiser_bessel_derived_window @handle_test( - fn_tree="functional.ivy.experimental.unsorted_segment_sum", - d_x_n_s=valid_unsorted_segment_min_inputs(), - test_with_out=st.just(False), + fn_tree="functional.ivy.experimental.kaiser_bessel_derived_window", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + max_num_dims=0, + min_value=1, + max_value=10, + ), + beta=st.floats(min_value=1, max_value=5), + dtype=helpers.get_dtypes("float", full=False), test_gradients=st.just(False), + test_instance_method=st.just(False), ) -def test_unsorted_segment_sum( - *, - d_x_n_s, - test_flags, - backend_fw, - fn_name, - on_device, +def test_kaiser_bessel_derived_window( + *, 
dtype_and_x, beta, dtype, test_flags, backend_fw, fn_name, on_device ): - dtypes, data, num_segments, segment_ids = d_x_n_s + input_dtype, x = dtype_and_x helpers.test_function( - input_dtypes=dtypes, + input_dtypes=input_dtype, test_flags=test_flags, - on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - data=data, - segment_ids=segment_ids, - num_segments=num_segments, + on_device=on_device, + window_length=int(x[0]), + beta=beta, + dtype=dtype[0], ) +# kaiser_window @handle_test( - fn_tree="functional.ivy.experimental.unsorted_segment_sum", - d_x_n_s=valid_unsorted_segment_min_inputs(), - test_with_out=st.just(False), + fn_tree="functional.ivy.experimental.kaiser_window", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + max_num_dims=0, + min_value=1, + max_value=10, + ), + periodic=st.booleans(), + beta=st.floats(min_value=0, max_value=5), + dtype=helpers.get_dtypes("float", full=False), test_gradients=st.just(False), -) -def test_unsorted_segment_sum( - *, - d_x_n_s, - test_flags, - backend_fw, - fn_name, - on_device, + test_instance_method=st.just(False), +) +def test_kaiser_window( + *, dtype_and_x, periodic, beta, dtype, test_flags, backend_fw, fn_name, on_device ): - dtypes, data, num_segments, segment_ids = d_x_n_s + input_dtype, x = dtype_and_x helpers.test_function( - input_dtypes=dtypes, + input_dtypes=input_dtype, test_flags=test_flags, - on_device=on_device, - fw=backend_fw, + backend_to_test=backend_fw, fn_name=fn_name, - data=data, - segment_ids=segment_ids, - num_segments=num_segments, + on_device=on_device, + window_length=int(x[0]), + periodic=periodic, + beta=beta, + dtype=dtype[0], ) -@st.composite -def _random_tucker_data(draw): - shape = draw( - st.lists(helpers.ints(min_value=1, max_value=5), min_size=2, max_size=4) - ) - rank = [] - for dim in shape: - rank.append(draw(helpers.ints(min_value=1, max_value=dim))) - dtype = draw(helpers.get_dtypes("float", full=False)) - full = draw(st.booleans()) - orthogonal = draw(st.booleans()) - seed = draw(st.one_of((st.just(None), helpers.ints(min_value=0, max_value=2000)))) - non_negative = draw(st.booleans()) - return shape, rank, dtype[0], full, orthogonal, seed, non_negative +# ndenumerate +@handle_test( + fn_tree="functional.ivy.experimental.ndenumerate", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + ), +) +def test_ndenumerate(dtype_and_x): + values = dtype_and_x[1][0] + for (index1, x1), (index2, x2) in zip( + np.ndenumerate(values), ivy.ndenumerate(values) + ): + assert index1 == index2 and x1 == x2.to_numpy() +# ndindex @handle_test( - fn_tree="functional.ivy.experimental.random_tucker", - data=_random_tucker_data(), + fn_tree="functional.ivy.experimental.ndindex", + dtype_x_shape=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + ret_shape=True, + ), +) +def test_ndindex(dtype_x_shape): + shape = dtype_x_shape[2] + for index1, index2 in zip(np.ndindex(shape), ivy.ndindex(shape)): + assert index1 == index2 + + +@handle_test( + fn_tree="functional.ivy.experimental.random_cp", + data=_random_cp_data(), test_with_out=st.just(False), test_instance_method=st.just(False), ) -def test_random_tucker( +def test_random_cp( *, data, test_flags, @@ -452,7 +337,7 @@ def test_random_tucker( fn_name, on_device, ): - shape, rank, dtype, full, orthogonal, seed, non_negative = data + shape, rank, dtype, full, orthogonal, seed, normalise_factors = data results = 
helpers.test_function( input_dtypes=[], backend_to_test=backend_fw, @@ -465,7 +350,7 @@ def test_random_tucker( full=full, orthogonal=orthogonal, seed=seed, - non_negative=non_negative, + normalise_factors=normalise_factors, test_values=False, ) @@ -481,46 +366,30 @@ def test_random_tucker( assert np.prod(shape) == np.prod(x_gt.shape) else: - core = helpers.flatten_and_to_np(ret=ret_np[0], backend=backend_fw) + weights = helpers.flatten_and_to_np(ret=ret_np[0], backend=backend_fw) factors = helpers.flatten_and_to_np(ret=ret_np[1], backend=backend_fw) - core_gt = helpers.flatten_and_to_np( + weights_gt = helpers.flatten_and_to_np( ret=ret_from_gt_np[0], backend=test_flags.ground_truth_backend ) factors_gt = helpers.flatten_and_to_np( ret=ret_from_gt_np[1], backend=test_flags.ground_truth_backend ) - for c, c_gt in zip(core, core_gt): - assert np.prod(c.shape) == np.prod(rank) - assert np.prod(c_gt.shape) == np.prod(rank) + for w, w_gt in zip(weights, weights_gt): + assert len(w) == rank + assert len(w_gt) == rank for f, f_gt in zip(factors, factors_gt): assert np.prod(f.shape) == np.prod(f_gt.shape) -@st.composite -def _random_cp_data(draw): - shape = draw( - st.lists(helpers.ints(min_value=1, max_value=5), min_size=2, max_size=4) - ) - rank = draw(helpers.ints(min_value=1, max_value=10)) - dtype = draw(helpers.get_dtypes("float", full=False)) - full = draw(st.booleans()) - orthogonal = draw(st.booleans()) - if (rank > min(shape)) and orthogonal: - rank = min(shape) - seed = draw(st.one_of((st.just(None), helpers.ints(min_value=0, max_value=2000)))) - normalise_factors = draw(st.booleans()) - return shape, rank, dtype[0], full, orthogonal, seed, normalise_factors - - @handle_test( - fn_tree="functional.ivy.experimental.random_cp", - data=_random_cp_data(), + fn_tree="functional.ivy.experimental.random_tucker", + data=_random_tucker_data(), test_with_out=st.just(False), test_instance_method=st.just(False), ) -def test_random_cp( +def test_random_tucker( *, data, test_flags, @@ -528,7 +397,7 @@ def test_random_cp( fn_name, on_device, ): - shape, rank, dtype, full, orthogonal, seed, normalise_factors = data + shape, rank, dtype, full, orthogonal, seed, non_negative = data results = helpers.test_function( input_dtypes=[], backend_to_test=backend_fw, @@ -541,7 +410,7 @@ def test_random_cp( full=full, orthogonal=orthogonal, seed=seed, - normalise_factors=normalise_factors, + non_negative=non_negative, test_values=False, ) @@ -557,23 +426,51 @@ def test_random_cp( assert np.prod(shape) == np.prod(x_gt.shape) else: - weights = helpers.flatten_and_to_np(ret=ret_np[0], backend=backend_fw) + core = helpers.flatten_and_to_np(ret=ret_np[0], backend=backend_fw) factors = helpers.flatten_and_to_np(ret=ret_np[1], backend=backend_fw) - weights_gt = helpers.flatten_and_to_np( + core_gt = helpers.flatten_and_to_np( ret=ret_from_gt_np[0], backend=test_flags.ground_truth_backend ) factors_gt = helpers.flatten_and_to_np( ret=ret_from_gt_np[1], backend=test_flags.ground_truth_backend ) - for w, w_gt in zip(weights, weights_gt): - assert len(w) == rank - assert len(w_gt) == rank + for c, c_gt in zip(core, core_gt): + assert np.prod(c.shape) == np.prod(rank) + assert np.prod(c_gt.shape) == np.prod(rank) for f, f_gt in zip(factors, factors_gt): assert np.prod(f.shape) == np.prod(f_gt.shape) +@handle_test( + fn_tree="functional.ivy.experimental.tril_indices", + dtype_and_n=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + max_num_dims=0, + num_arrays=2, + min_value=0, + max_value=10, + ), 
+ k=helpers.ints(min_value=-11, max_value=11), + test_with_out=st.just(False), + test_gradients=st.just(False), + test_instance_method=st.just(False), +) +def test_tril_indices(*, dtype_and_n, k, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_n + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + n_rows=int(x[0]), + n_cols=int(x[1]), + k=k, + ) + + @handle_test( fn_tree="functional.ivy.experimental.trilu", dtype_and_x=helpers.dtype_and_values( @@ -599,3 +496,114 @@ def test_trilu(*, dtype_and_x, k, upper, test_flags, backend_fw, fn_name, on_dev upper=upper, k=k, ) + + +# unsorted_segment_min +@handle_test( + fn_tree="functional.ivy.experimental.unsorted_segment_min", + d_x_n_s=valid_unsorted_segment_min_inputs(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_unsorted_segment_min( + *, + d_x_n_s, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtypes, data, num_segments, segment_ids = d_x_n_s + helpers.test_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + on_device=on_device, + fn_name=fn_name, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + ) + + +@handle_test( + fn_tree="functional.ivy.experimental.unsorted_segment_sum", + d_x_n_s=valid_unsorted_segment_min_inputs(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_unsorted_segment_sum( + *, + d_x_n_s, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtypes, data, num_segments, segment_ids = d_x_n_s + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + ) + + +@handle_test( + fn_tree="functional.ivy.experimental.unsorted_segment_sum", + d_x_n_s=valid_unsorted_segment_min_inputs(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_unsorted_segment_sum( + *, + d_x_n_s, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtypes, data, num_segments, segment_ids = d_x_n_s + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + fw=backend_fw, + fn_name=fn_name, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + ) + + +# vorbis_window +@handle_test( + fn_tree="functional.ivy.experimental.vorbis_window", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + max_num_dims=0, + min_value=1, + max_value=10, + ), + dtype=helpers.get_dtypes("float", full=False), + test_gradients=st.just(False), + test_instance_method=st.just(False), +) +def test_vorbis_window( + *, dtype_and_x, dtype, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + atol_=1e-02, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + window_length=int(x[0]), + dtype=dtype[0], + ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py index 79224648053ec..dcb1132163caf 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py @@ -7,82 +7,8 @@ from 
ivy_tests.test_ivy.helpers import handle_test -# Helpers # -# ------- # -# lgamma -@handle_test( - fn_tree="functional.ivy.experimental.lgamma", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - safety_factor_scale="log", - ), - test_gradients=st.just(False), -) -def test_lgamma( - *, - dtype_and_x, - test_flags, - backend_fw, - fn_name, - on_device, -): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - test_flags=test_flags, - on_device=on_device, - fn_name=fn_name, - x=x[0], - ) - - -# sinc -@handle_test( - fn_tree="functional.ivy.experimental.sinc", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=4, - small_abs_safety_factor=4, - ), - test_gradients=st.just(False), -) -def test_sinc(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - atol_=1e-02, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - ) - - -# fmax -@handle_test( - fn_tree="functional.ivy.experimental.fmax", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", - num_arrays=2, - ), - test_gradients=st.just(False), -) -def test_fmax(dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x1=x[0], - x2=x[1], - ) +# --- Helpers --- # +# --------------- # # float_power_helper @@ -113,50 +39,29 @@ def _float_power_helper(draw, *, available_dtypes=None): return (dtype1[0], dtype2[0]), (x1[0], x2[0]) -# float_power -@handle_test( - fn_tree="functional.ivy.experimental.float_power", - dtype_and_x=_float_power_helper(), - test_gradients=st.just(False), -) -def test_float_power(dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtypes, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtypes, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x1=x[0], - x2=x[1], - rtol_=1e-1, - atol_=1e-1, +# nansum +@st.composite +def _get_castable_dtypes_values(draw, *, allow_nan=False): + available_dtypes = helpers.get_dtypes("numeric") + shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) + dtype, values = draw( + helpers.dtype_and_values( + available_dtypes=available_dtypes, + num_arrays=1, + large_abs_safety_factor=24, + small_abs_safety_factor=24, + safety_factor_scale="log", + shape=shape, + allow_nan=allow_nan, + ) ) - - -# copysign -@handle_test( - fn_tree="functional.ivy.experimental.copysign", - dtype_x1_x2=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - min_num_dims=0, - allow_nan=False, - shared_dtype=False, - ), - test_gradients=st.just(False), -) -def test_copysign(dtype_x1_x2, test_flags, backend_fw, fn_name, on_device): - (x1_dtype, x2_dtype), (x1, x2) = dtype_x1_x2 - helpers.test_function( - input_dtypes=[x1_dtype, x2_dtype], - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x1=x1, - x2=x2, + axis = draw(helpers.get_axis(shape=shape, force_int=True)) + dtype1, values, dtype2 = draw( + helpers.get_castable_dtype( + 
draw(helpers.get_dtypes("float")), dtype[0], values[0] + ) ) + return [dtype1], [values], axis, dtype2 @st.composite @@ -184,90 +89,111 @@ def _get_dtype_values_axis_for_count_nonzero( return [input_dtype, output_dtype], values, axis -# count_nonzero -@handle_test( - fn_tree="functional.ivy.experimental.count_nonzero", - dtype_values_axis=_get_dtype_values_axis_for_count_nonzero( - in_available_dtypes="integer", - out_available_dtypes="integer", - min_num_dims=1, - max_num_dims=10, - min_dim_size=1, - max_dim_size=10, - ), - keepdims=st.booleans(), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_count_nonzero( - *, dtype_values_axis, keepdims, test_flags, on_device, fn_name, backend_fw -): - i_o_dtype, a, axis = dtype_values_axis - helpers.test_function( - input_dtypes=i_o_dtype[0], - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - a=a[0], - axis=axis, - keepdims=keepdims, - dtype=i_o_dtype[1][0], +@st.composite +def _lerp_data_helper(draw): + mixed_fn_compos = draw(st.booleans()) + is_torch_backend = ivy.current_backend_str() == "torch" + + kwargs = { + "shared_dtype": True, + "large_abs_safety_factor": 2.5, + "small_abs_safety_factor": 2.5, + "safety_factor_scale": "log", + "allow_nan": False, + "allow_inf": False, + } + + if is_torch_backend and not mixed_fn_compos: + dtype1, start_end = draw( + helpers.dtype_and_values( + available_dtypes=( + helpers.get_dtypes("numeric", mixed_fn_compos=mixed_fn_compos) + ), + num_arrays=2, + **kwargs, + ) + ) + dtype2, weight = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "integer", mixed_fn_compos=mixed_fn_compos + ), + num_arrays=1, + **kwargs, + ) + ) + input_dtypes = dtype1 + dtype2 + inputs = start_end + weight + else: + input_dtypes, inputs = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "valid", mixed_fn_compos=mixed_fn_compos + ), + num_arrays=3, + **kwargs, + ) + ) + + return input_dtypes, inputs[0], inputs[1], inputs[2] + + +@st.composite +def _sparsify_tensor_stg(draw): + dtype, tensor, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + ret_shape=True, + min_num_dims=1, + min_dim_size=1, + min_value=10, + ) ) + size = 1 + for dim in shape: + size *= dim -# nansum + card = draw(st.integers(min_value=1, max_value=size)) + + return dtype, tensor[0], card + + +# ldexp @st.composite -def _get_castable_dtypes_values(draw, *, allow_nan=False): - available_dtypes = helpers.get_dtypes("numeric") - shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6)) - dtype, values = draw( +def ldexp_args(draw): + dtype1, x1 = draw( helpers.dtype_and_values( - available_dtypes=available_dtypes, + available_dtypes=["float32", "float64"], num_arrays=1, - large_abs_safety_factor=24, - small_abs_safety_factor=24, - safety_factor_scale="log", - shape=shape, - allow_nan=allow_nan, + shared_dtype=True, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, ) ) - axis = draw(helpers.get_axis(shape=shape, force_int=True)) - dtype1, values, dtype2 = draw( - helpers.get_castable_dtype( - draw(helpers.get_dtypes("float")), dtype[0], values[0] + dtype2, x2 = draw( + helpers.dtype_and_values( + available_dtypes=["int32", "int64"], + num_arrays=1, + shared_dtype=True, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, ) ) - return [dtype1], [values], axis, dtype2 + return (dtype1[0], dtype2[0]), (x1[0], x2[0]) -# nansum -@handle_test( - 
fn_tree="functional.ivy.experimental.nansum", - dtype_x_axis_dtype=_get_castable_dtypes_values(allow_nan=True), - keep_dims=st.booleans(), - test_gradients=st.just(False), -) -def test_nansum( - *, dtype_x_axis_dtype, keep_dims, test_flags, on_device, fn_name, backend_fw -): - input_dtype, x, axis, dtype = dtype_x_axis_dtype - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - x=x[0], - axis=axis, - keepdims=keep_dims, - dtype=dtype, - ) +# --- Main --- # +# ------------ # -# isclose +# allclose @handle_test( - fn_tree="functional.ivy.experimental.isclose", + fn_tree="functional.ivy.experimental.allclose", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=4, @@ -285,50 +211,58 @@ def test_nansum( ), equal_nan=st.booleans(), test_gradients=st.just(False), + test_with_out=st.just(False), ) -def test_isclose( - *, dtype_and_x, rtol, atol, equal_nan, test_flags, backend_fw, fn_name, on_device +def test_allclose( + dtype_and_x, rtol, atol, equal_nan, test_flags, backend_fw, fn_name, on_device ): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, - on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - a=x[0], - b=x[1], + on_device=on_device, + x1=x[0], + x2=x[1], rtol=rtol, atol=atol, equal_nan=equal_nan, ) -# allclose @handle_test( - fn_tree="functional.ivy.experimental.allclose", + fn_tree="functional.ivy.experimental.binarizer", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", - num_arrays=2, - shared_dtype=True, - allow_nan=True, - ), - rtol=st.floats( - min_value=1e-05, max_value=1e-01, exclude_min=True, exclude_max=True + available_dtypes=helpers.get_dtypes("numeric") ), - atol=st.floats( - min_value=1e-08, max_value=1e-01, exclude_min=True, exclude_max=True + threshold=helpers.floats(), + container_flags=st.just([False]), +) +def test_binarizer( + *, dtype_and_x, threshold, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + threshold=threshold, + ) + + +# conj +@handle_test( + fn_tree="conj", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("real_and_complex") ), - equal_nan=st.booleans(), - test_gradients=st.just(False), test_with_out=st.just(False), ) -def test_allclose( - dtype_and_x, rtol, atol, equal_nan, test_flags, backend_fw, fn_name, on_device -): +def test_conj(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -336,62 +270,64 @@ def test_allclose( backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x1=x[0], - x2=x[1], - rtol=rtol, - atol=atol, - equal_nan=equal_nan, + x=x[0], ) -# fix +# copysign @handle_test( - fn_tree="functional.ivy.experimental.fix", - dtype_and_x=helpers.dtype_and_values( + fn_tree="functional.ivy.experimental.copysign", + dtype_x1_x2=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, + num_arrays=2, + min_num_dims=0, + allow_nan=False, + shared_dtype=False, ), test_gradients=st.just(False), ) -def 
test_fix(dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x +def test_copysign(dtype_x1_x2, test_flags, backend_fw, fn_name, on_device): + (x1_dtype, x2_dtype), (x1, x2) = dtype_x1_x2 helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=[x1_dtype, x2_dtype], test_flags=test_flags, + on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - on_device=on_device, - x=x[0], + x1=x1, + x2=x2, ) -# nextafter +# count_nonzero @handle_test( - fn_tree="functional.ivy.experimental.nextafter", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - num_arrays=2, - shared_dtype=True, - min_value=-10, - max_value=10, + fn_tree="functional.ivy.experimental.count_nonzero", + dtype_values_axis=_get_dtype_values_axis_for_count_nonzero( + in_available_dtypes="integer", + out_available_dtypes="integer", min_num_dims=1, - max_num_dims=3, + max_num_dims=10, + min_dim_size=1, + max_dim_size=10, ), + keepdims=st.booleans(), + test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_nextafter(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x +def test_count_nonzero( + *, dtype_values_axis, keepdims, test_flags, on_device, fn_name, backend_fw +): + i_o_dtype, a, axis = dtype_values_axis helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=i_o_dtype[0], test_flags=test_flags, + on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - on_device=on_device, - x1=x[0], - x2=x[1], + a=a[0], + axis=axis, + keepdims=keepdims, + dtype=i_o_dtype[1][0], ) @@ -445,108 +381,114 @@ def test_diff( ) -# zeta +# digamma @handle_test( - fn_tree="functional.ivy.experimental.zeta", + fn_tree="functional.ivy.experimental.digamma", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, min_value=-10, max_value=10, min_num_dims=1, max_num_dims=3, - ), - test_gradients=st.just(False), + min_dim_size=1, + max_dim_size=3, + ).filter(lambda x: "bfloat16" not in x[0] and "float16" not in x[0]), + ground_truth_backend="tensorflow", ) -def test_zeta( +def test_digamma( dtype_and_x, + backend_fw, test_flags, fn_name, on_device, - backend_fw, ): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, - test_flags=test_flags, backend_to_test=backend_fw, - on_device=on_device, + test_flags=test_flags, fn_name=fn_name, - rtol_=1e-02, - atol_=1e-02, + on_device=on_device, x=x[0], - q=x[1], ) -# gradient +# fix @handle_test( - fn_tree="functional.ivy.experimental.gradient", - dtype_n_x_n_axis=helpers.dtype_values_axis( - available_dtypes=("float32", "float16", "float64"), + fn_tree="functional.ivy.experimental.fix", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_num_dims=1, max_num_dims=3, - min_dim_size=2, - max_dim_size=4, - valid_axis=True, - force_int_axis=True, - ), - spacing=helpers.ints( - min_value=-3, - max_value=3, + min_dim_size=1, + max_dim_size=3, ), - test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_gradient( - *, dtype_n_x_n_axis, spacing, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x, axis = dtype_n_x_n_axis +def test_fix(dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, - on_device=on_device, fn_name=fn_name, + on_device=on_device, x=x[0], - 
spacing=spacing, - axis=axis, ) -# xlogy +# float_power @handle_test( - fn_tree="functional.ivy.experimental.xlogy", + fn_tree="functional.ivy.experimental.float_power", + dtype_and_x=_float_power_helper(), + test_gradients=st.just(False), +) +def test_float_power(dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtypes, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x1=x[0], + x2=x[1], + rtol_=1e-1, + atol_=1e-1, + ) + + +# fmax +@handle_test( + fn_tree="functional.ivy.experimental.fmax", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("valid"), + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", num_arrays=2, - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=3, ), test_gradients=st.just(False), ) -def test_xlogy(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_fmax(dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, + on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - on_device=on_device, - x=x[0], - y=x[1], + x1=x[0], + x2=x[1], ) -# hypot +# frexp @handle_test( - fn_tree="functional.ivy.experimental.hypot", + fn_tree="functional.ivy.experimental.frexp", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, + available_dtypes=["float32", "float64"], + num_arrays=1, shared_dtype=True, min_value=-100, max_value=100, @@ -555,7 +497,7 @@ def test_xlogy(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ), test_gradients=st.just(False), ) -def test_hypot(dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_frexp(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -563,44 +505,60 @@ def test_hypot(dtype_and_x, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - atol_=1e-2, - x1=x[0], - x2=x[1], + x=x[0], ) +# gradient @handle_test( - fn_tree="functional.ivy.experimental.binarizer", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric") + fn_tree="functional.ivy.experimental.gradient", + dtype_n_x_n_axis=helpers.dtype_values_axis( + available_dtypes=("float32", "float16", "float64"), + min_num_dims=1, + max_num_dims=3, + min_dim_size=2, + max_dim_size=4, + valid_axis=True, + force_int_axis=True, ), - threshold=helpers.floats(), - container_flags=st.just([False]), + spacing=helpers.ints( + min_value=-3, + max_value=3, + ), + test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_binarizer( - *, dtype_and_x, threshold, test_flags, backend_fw, fn_name, on_device +def test_gradient( + *, dtype_n_x_n_axis, spacing, test_flags, backend_fw, fn_name, on_device ): - input_dtype, x = dtype_and_x + input_dtype, x, axis = dtype_n_x_n_axis helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, - fn_name=fn_name, on_device=on_device, + fn_name=fn_name, x=x[0], - threshold=threshold, + spacing=spacing, + axis=axis, ) -# conj +# hypot @handle_test( - fn_tree="conj", + fn_tree="functional.ivy.experimental.hypot", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("real_and_complex") + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, ), - test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_conj(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_hypot(dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -608,36 +566,49 @@ def test_conj(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x[0], + atol_=1e-2, + x1=x[0], + x2=x[1], ) -# ldexp -@st.composite -def ldexp_args(draw): - dtype1, x1 = draw( - helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - num_arrays=1, - shared_dtype=True, - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - ) - ) - dtype2, x2 = draw( - helpers.dtype_and_values( - available_dtypes=["int32", "int64"], - num_arrays=1, - shared_dtype=True, - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - ) +# isclose +@handle_test( + fn_tree="functional.ivy.experimental.isclose", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", + num_arrays=2, + shared_dtype=True, + allow_nan=True, + ), + rtol=st.floats( + min_value=1e-05, max_value=1e-01, exclude_min=True, exclude_max=True + ), + atol=st.floats( + min_value=1e-08, max_value=1e-01, exclude_min=True, exclude_max=True + ), + equal_nan=st.booleans(), + test_gradients=st.just(False), +) +def test_isclose( + *, dtype_and_x, rtol, atol, equal_nan, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + a=x[0], + b=x[1], + rtol=rtol, + atol=atol, + equal_nan=equal_nan, ) - return (dtype1[0], dtype2[0]), (x1[0], x2[0]) @handle_test( @@ -658,55 +629,6 @@ def test_ldexp(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -@st.composite -def _lerp_data_helper(draw): - mixed_fn_compos = draw(st.booleans()) - is_torch_backend = ivy.current_backend_str() == "torch" - - kwargs = { - "shared_dtype": True, - "large_abs_safety_factor": 2.5, - "small_abs_safety_factor": 2.5, - "safety_factor_scale": "log", - "allow_nan": False, - "allow_inf": False, - } - - if is_torch_backend and not mixed_fn_compos: - dtype1, start_end = draw( - helpers.dtype_and_values( - available_dtypes=( - helpers.get_dtypes("numeric", mixed_fn_compos=mixed_fn_compos) - ), - num_arrays=2, - **kwargs, - ) - ) - dtype2, weight = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "integer", mixed_fn_compos=mixed_fn_compos - ), - num_arrays=1, - **kwargs, - ) - ) - input_dtypes = dtype1 + dtype2 - inputs = start_end + weight - else: - input_dtypes, inputs = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "valid", mixed_fn_compos=mixed_fn_compos - ), - num_arrays=3, - **kwargs, - ) - ) - - return input_dtypes, inputs[0], inputs[1], inputs[2] - - # lerp @handle_test( fn_tree="functional.ivy.experimental.lerp", @@ -736,28 +658,30 @@ def test_lerp( ) -# frexp +# lgamma @handle_test( - fn_tree="functional.ivy.experimental.frexp", + fn_tree="functional.ivy.experimental.lgamma", 
dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - num_arrays=1, - shared_dtype=True, - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, + available_dtypes=helpers.get_dtypes("float"), + safety_factor_scale="log", ), test_gradients=st.just(False), ) -def test_frexp(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_lgamma( + *, + dtype_and_x, + test_flags, + backend_fw, + fn_name, + on_device, +): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, - test_flags=test_flags, backend_to_test=backend_fw, - fn_name=fn_name, + test_flags=test_flags, on_device=on_device, + fn_name=fn_name, x=x[0], ) @@ -792,58 +716,79 @@ def test_modf( ) -# digamma +# nansum @handle_test( - fn_tree="functional.ivy.experimental.digamma", + fn_tree="functional.ivy.experimental.nansum", + dtype_x_axis_dtype=_get_castable_dtypes_values(allow_nan=True), + keep_dims=st.booleans(), + test_gradients=st.just(False), +) +def test_nansum( + *, dtype_x_axis_dtype, keep_dims, test_flags, on_device, fn_name, backend_fw +): + input_dtype, x, axis, dtype = dtype_x_axis_dtype + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + x=x[0], + axis=axis, + keepdims=keep_dims, + dtype=dtype, + ) + + +# nextafter +@handle_test( + fn_tree="functional.ivy.experimental.nextafter", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=["float32", "float64"], + num_arrays=2, + shared_dtype=True, min_value=-10, max_value=10, min_num_dims=1, max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ).filter(lambda x: "bfloat16" not in x[0] and "float16" not in x[0]), - ground_truth_backend="tensorflow", + ), + test_gradients=st.just(False), ) -def test_digamma( - dtype_and_x, - backend_fw, - test_flags, - fn_name, - on_device, -): +def test_nextafter(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, - backend_to_test=backend_fw, test_flags=test_flags, + backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x[0], + x1=x[0], + x2=x[1], ) -@st.composite -def _sparsify_tensor_stg(draw): - dtype, tensor, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - ret_shape=True, - min_num_dims=1, - min_dim_size=1, - min_value=10, - ) +# sinc +@handle_test( + fn_tree="functional.ivy.experimental.sinc", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=4, + small_abs_safety_factor=4, + ), + test_gradients=st.just(False), +) +def test_sinc(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + atol_=1e-02, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], ) - size = 1 - for dim in shape: - size *= dim - - card = draw(st.integers(min_value=1, max_value=size)) - - return dtype, tensor[0], card - # sparsify_tensor @handle_test( @@ -868,3 +813,64 @@ def test_sparsify_tensor( tensor=tensor, card=card, ) + + +# xlogy +@handle_test( + fn_tree="functional.ivy.experimental.xlogy", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float_and_complex"), + num_arrays=2, + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=3, + ), + 
test_gradients=st.just(False), +) +def test_xlogy(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + y=x[1], + ) + + +# zeta +@handle_test( + fn_tree="functional.ivy.experimental.zeta", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=3, + ), + test_gradients=st.just(False), +) +def test_zeta( + dtype_and_x, + test_flags, + fn_name, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + rtol_=1e-02, + atol_=1e-02, + x=x[0], + q=x[1], + ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_general.py index a46a6aa651c58..f044edfbf8fe6 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_general.py @@ -7,6 +7,10 @@ from ivy_tests.test_ivy.helpers import handle_test +# --- Helpers --- # +# --------------- # + + @st.composite def _reduce_helper(draw): # ToDo: remove the filtering when supported dtypes are fixed for mixed functions @@ -35,6 +39,10 @@ def _reduce_helper(draw): return dtype, operand[0], init_value[0], func, axes +# --- Main --- # +# ------------ # + + # reduce @handle_test( fn_tree="functional.ivy.experimental.reduce", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py index 12eb222261611..ffdd059f5d1ca 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_linalg.py @@ -11,6 +11,10 @@ import ivy +# --- Helpers --- # +# --------------- # + + @st.composite def _generate_diag_args(draw): x_shape = draw( @@ -184,228 +188,484 @@ def _generate_eigh_tridiagonal_args(draw): return dtype, alpha, beta, eigvals_only, select, select_range, tol -# eigh_tridiagonal -@handle_test( - fn_tree="eigh_tridiagonal", - args_packet=_generate_eigh_tridiagonal_args(), - ground_truth_backend="numpy", - test_gradients=st.just(False), -) -def test_eigh_tridiagonal( - *, - args_packet, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtype, alpha, beta, eigvals_only, select, select_range, tol = args_packet - test_flags.with_out = False - results = helpers.test_function( - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - input_dtypes=dtype, - alpha=alpha[0], - beta=beta[0], - eigvals_only=eigvals_only, - select=select, - select_range=select_range, - tol=tol, - test_values=eigvals_only, - return_flat_np_arrays=True, +# multi_dot +@st.composite +def _generate_multi_dot_dtype_and_arrays(draw): + input_dtype = [draw(st.sampled_from(draw(helpers.get_dtypes("numeric"))))] + matrices_dims = draw( + st.lists(st.integers(min_value=2, max_value=10), min_size=4, max_size=4) ) - if results is None: - return - ret_np_flat, ret_np_from_gt_flat = results - reconstructed_np = None - for i in 
range(len(ret_np_flat) // 2): - eigenvalue = ret_np_flat[i] - eigenvector = ret_np_flat[len(ret_np_flat) // 2 + i] - if reconstructed_np is not None: - reconstructed_np += eigenvalue * np.matmul( - eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) - ) - else: - reconstructed_np = eigenvalue * np.matmul( - eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) - ) + shape_1 = (matrices_dims[0], matrices_dims[1]) + shape_2 = (matrices_dims[1], matrices_dims[2]) + shape_3 = (matrices_dims[2], matrices_dims[3]) - reconstructed_from_np = None - for i in range(len(ret_np_from_gt_flat) // 2): - eigenvalue = ret_np_from_gt_flat[i] - eigenvector = ret_np_from_gt_flat[len(ret_np_flat) // 2 + i] - if reconstructed_from_np is not None: - reconstructed_from_np += eigenvalue * np.matmul( - eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) - ) - else: - reconstructed_from_np = eigenvalue * np.matmul( - eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) - ) - # value test - helpers.assert_all_close( - reconstructed_np, - reconstructed_from_np, - rtol=1e-1, - atol=1e-2, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, + matrix_1 = draw( + helpers.dtype_and_values( + shape=shape_1, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) + ) + matrix_2 = draw( + helpers.dtype_and_values( + shape=shape_2, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) + ) + matrix_3 = draw( + helpers.dtype_and_values( + shape=shape_3, + dtype=input_dtype, + min_value=-10, + max_value=10, + ) ) + return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]] -@handle_test( - fn_tree="functional.ivy.experimental.diagflat", - args_packet=_generate_diag_args(), - test_gradients=st.just(False), -) -def test_diagflat(*, test_flags, backend_fw, fn_name, args_packet, on_device): - dtype_x, offset, dtype_padding_value, align, num_rows, num_cols = args_packet - - x_dtype, x = dtype_x - padding_value_dtype, padding_value = dtype_padding_value - padding_value = padding_value[0][0] - helpers.test_function( - input_dtypes=x_dtype + ["int64"] + padding_value_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - offset=offset, - padding_value=padding_value, - align=align, - num_rows=num_rows, - num_cols=num_cols, - on_device=on_device, - atol_=1e-01, - rtol_=1 / 64, +@st.composite +def _get_dtype_value1_value2_cov( + draw, + available_dtypes, + min_num_dims, + max_num_dims, + min_dim_size, + max_dim_size, + abs_smallest_val=None, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", +): + shape = draw( + helpers.get_shape( + allow_none=False, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) ) + dtype = draw(st.sampled_from(draw(available_dtypes))) -@handle_test( - fn_tree="functional.ivy.experimental.kron", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=1, - max_dim_size=10, - num_arrays=2, - shared_dtype=True, - ), - test_gradients=st.just(False), -) -def test_kron(*, dtype_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - a=x[0], - b=x[1], - ) + values = [] + for i in range(2): + 
values.append( + draw( + helpers.array_values( + dtype=dtype, + shape=shape, + abs_smallest_val=abs_smallest_val, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + ) + ) + ) + value1, value2 = values[0], values[1] -# matrix_exp -@handle_test( - fn_tree="functional.ivy.experimental.matrix_exp", - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=2, - max_dim_size=2, - min_value=-100, - max_value=100, - allow_nan=False, - shared_dtype=True, - ), - test_gradients=st.just(False), -) -def test_matrix_exp(dtype_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], + # modifiers: rowVar, bias, ddof + rowVar = draw(st.booleans()) + bias = draw(st.booleans()) + ddof = draw(helpers.ints(min_value=0, max_value=1)) + + numVals = None + if rowVar is False: + numVals = -1 if numVals == 0 else 0 + else: + numVals = 0 if len(shape) == 1 else -1 + + fweights = draw( + helpers.array_values( + dtype="int64", + shape=shape[numVals], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + ) ) + aweights = draw( + helpers.array_values( + dtype="float64", + shape=shape[numVals], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + small_abs_safety_factor=1, + ) + ) -@handle_test( - fn_tree="functional.ivy.experimental.eig", - dtype_x=helpers.dtype_and_values( - available_dtypes=( - ivy.float32, - ivy.float64, - ivy.int32, - ivy.int64, - ivy.complex64, - ivy.complex128, - ), - min_num_dims=2, - max_num_dims=3, - min_dim_size=10, - max_dim_size=10, - min_value=1.0, - max_value=1.0e5, - shared_dtype=True, - ), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_eig(dtype_x, test_flags, backend_fw, fn_name): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - test_values=False, - x=x[0], + return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights + + +# intialize tucker +@st.composite +def _initialize_tucker_data(draw): + x_dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + min_value=0.1, + max_value=10.0, + ret_shape=True, + ) + ) + dims = len(shape) + rank = [] + for i in range(dims): + rank.append(draw(helpers.ints(min_value=1, max_value=shape[i]))) + n_modes = draw(helpers.ints(min_value=2, max_value=dims)) + modes = [*range(dims)][:n_modes] + mask_dtype, mask = draw( + helpers.dtype_and_values( + dtype=["int32"], + shape=shape, + min_value=0, + max_value=1, + ) + ) + svd_mask_repeats = draw(helpers.ints(min_value=0, max_value=3)) + non_negative = draw(st.booleans()) + return ( + x_dtype + mask_dtype, + x[0], + rank, + modes, + non_negative, + mask[0], + svd_mask_repeats, + ) + + +@st.composite +def _khatri_rao_data(draw): + num_matrices = draw(helpers.ints(min_value=2, max_value=4)) + m = draw(helpers.ints(min_value=1, max_value=5)) + input_dtypes, input = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + 
num_arrays=num_matrices, + min_dim_size=m, + max_dim_size=m, + min_num_dims=2, + max_num_dims=2, + large_abs_safety_factor=20, + small_abs_safety_factor=20, + safety_factor_scale="log", + ) + ) + skip_matrix = draw(helpers.ints(min_value=0, max_value=len(input) - 1)) + weight_dtype, weights = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), shape=(m,) + ) + ) + mask_dtype, mask = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=0, + max_value=1, + shape=(m,), + ) + ) + return ( + input_dtypes + weight_dtype + mask_dtype, + input, + skip_matrix, + weights[0], + mask[0], + ) + + +@st.composite +def _kronecker_data(draw): + num_arrays = draw(helpers.ints(min_value=2, max_value=5)) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=num_arrays, + large_abs_safety_factor=20, + small_abs_safety_factor=20, + safety_factor_scale="log", + shared_dtype=True, + min_num_dims=2, + max_num_dims=2, + ) + ) + skip_matrix = draw( + st.lists(st.integers(min_value=0, max_value=num_arrays - 1), unique=True) + ) + reverse = draw(st.booleans()) + return x_dtype, x, skip_matrix, reverse + + +# truncated svd +@st.composite +def _make_svd_nn_data(draw): + x_dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + max_dim_size=5, + min_value=1.0, + max_value=10.0, + ret_shape=True, + ) + ) + n, m = shape + _, U = draw( + helpers.dtype_and_values( + dtype=x_dtype, + available_dtypes=helpers.get_dtypes("float"), + shape=(n, m), + min_value=1.0, + max_value=10.0, + ) + ) + _, S = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=(m,), + min_value=1.0, + max_value=10.0, + ) + ) + _, V = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=(m, m), + min_value=1.0, + max_value=10.0, + ) + ) + nntype = draw(st.sampled_from(["nndsvd", "nndsvda"])) + return x_dtype, x[0], U[0], S[0], V[0], nntype + + +@st.composite +def _mode_dot_data(draw): + shape_t1 = draw(helpers.get_shape(min_num_dims=2, max_num_dims=5)) + mode = draw(helpers.ints(min_value=0, max_value=len(shape_t1) - 1)) + mode_dimsize = shape_t1[mode] + t1_dtype, t1 = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=shape_t1, + large_abs_safety_factor=20, + small_abs_safety_factor=20, + safety_factor_scale="log", + ) + ) + t2_rows = draw(helpers.ints(min_value=1, max_value=4)) + shape_t2 = draw(st.sampled_from([(mode_dimsize,), (t2_rows, mode_dimsize)])) + t2_dtype, t2 = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=shape_t2, + large_abs_safety_factor=20, + small_abs_safety_factor=20, + safety_factor_scale="log", + ) + ) + return t1_dtype + t2_dtype, t1[0], t2[0], mode + + +@st.composite +def _multi_mode_dot_data(draw): + t1_dtype, t1, shape_t1 = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ret_shape=True, + min_num_dims=2, + large_abs_safety_factor=20, + small_abs_safety_factor=20, + safety_factor_scale="log", + ) + ) + modes = [*range(len(shape_t1))] + skip = draw(st.lists(helpers.ints(min_value=0, max_value=len(shape_t1) - 1))) + t2 = [] + t2_dtype = [] + for i in modes: + mode_dimsize = shape_t1[i] + rows = draw(helpers.ints(min_value=1, max_value=4)) + shape = draw(st.sampled_from([(mode_dimsize,), (rows, mode_dimsize)])) + mat_or_vec_dtype, mat_or_vec = draw( 
+ helpers.dtype_and_values( + dtype=t1_dtype, + shape=shape, + large_abs_safety_factor=20, + small_abs_safety_factor=20, + safety_factor_scale="log", + ) + ) + t2.append(mat_or_vec[0]) + t2_dtype.append(mat_or_vec_dtype[0]) + + return t1_dtype + t2_dtype, t1[0], t2, modes, skip + + +# partial tucker +@st.composite +def _partial_tucker_data(draw): + x_dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + min_value=0.1, + max_value=10.0, + ret_shape=True, + ) + ) + dims = len(shape) + rank = [] + for i in range(dims): + rank.append(draw(helpers.ints(min_value=1, max_value=shape[i]))) + n_modes = draw(helpers.ints(min_value=2, max_value=dims)) + modes = [*range(dims)][:n_modes] + mask_dtype, mask = draw( + helpers.dtype_and_values( + dtype=["int32"], + shape=shape, + min_value=0, + max_value=1, + ) + ) + svd_mask_repeats = draw(helpers.ints(min_value=0, max_value=3)) + n_iter_max = draw(helpers.ints(min_value=1, max_value=7)) + tol = draw(helpers.floats(min_value=1e-5, max_value=1e-1)) + return ( + x_dtype + mask_dtype, + x[0], + rank, + modes, + n_iter_max, + mask[0], + svd_mask_repeats, + tol, + ) + + +# truncated svd +@st.composite +def _truncated_svd_data(draw): + x_dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + max_dim_size=5, + min_value=0.1, + max_value=10.0, + ret_shape=True, + ) + ) + uv = draw(st.booleans()) + n_eigen = draw(helpers.ints(min_value=1, max_value=max(shape[-2:]))) + return x_dtype, x[0], uv, n_eigen + + +# tucker +@st.composite +def _tucker_data(draw): + x_dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=4, + min_dim_size=2, + max_dim_size=3, + min_value=0.1, + max_value=10.0, + ret_shape=True, + ) + ) + dims = len(shape) + rank = [] + for i in range(dims): + rank.append(draw(helpers.ints(min_value=1, max_value=shape[i]))) + mask_dtype, mask = draw( + helpers.dtype_and_values( + dtype=["int32"], + shape=shape, + min_value=0, + max_value=1, + ) + ) + svd_mask_repeats = draw(helpers.ints(min_value=0, max_value=1)) + n_iter_max = draw(helpers.ints(min_value=0, max_value=2)) + tol = draw(helpers.floats(min_value=1e-5, max_value=1e-1)) + init = draw(st.sampled_from(["svd", "random"])) + fixed_factors = draw(st.booleans()) + if fixed_factors: + _, core = draw( + helpers.dtype_and_values( + dtype=x_dtype, + min_value=0.1, + max_value=10.0, + shape=rank, + ) + ) + factors = [] + for i in range(dims): + _, factor = draw( + helpers.dtype_and_values( + dtype=x_dtype, + min_value=0.1, + max_value=10.0, + shape=(shape[i], rank[i]), + ) + ) + factors.append(factor[0]) + fixed_factors = draw( + st.lists( + helpers.ints(min_value=0, max_value=dims - 1), unique=True, min_size=1 + ) + ) + rank = [rank[i] for i in range(dims) if i not in fixed_factors] + init = ivy.TuckerTensor((core[0], factors)) + return ( + x_dtype + mask_dtype, + x[0], + rank, + fixed_factors, + init, + n_iter_max, + mask[0], + svd_mask_repeats, + tol, ) -@handle_test( - fn_tree="functional.ivy.experimental.eigvals", - dtype_x=helpers.dtype_and_values( - available_dtypes=( - ivy.float32, - ivy.float64, - ivy.int32, - ivy.int64, - ivy.complex64, - ivy.complex128, - ), - min_num_dims=2, - max_num_dims=3, - min_dim_size=10, - max_dim_size=10, - min_value=1.0, - max_value=1.0e5, - shared_dtype=True, - ), - 
test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_eigvals(dtype_x, test_flags, backend_fw, fn_name): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - test_values=False, - x=x[0], - ) +# --- Main --- # +# ------------ # @handle_test( @@ -427,185 +687,37 @@ def test_eigvals(dtype_x, test_flags, backend_fw, fn_name): allow_nan=False, shared_dtype=True, ), -) -def test_adjoint(dtype_x, test_flags, backend_fw, fn_name): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - ) - - -# multi_dot -@st.composite -def _generate_multi_dot_dtype_and_arrays(draw): - input_dtype = [draw(st.sampled_from(draw(helpers.get_dtypes("numeric"))))] - matrices_dims = draw( - st.lists(st.integers(min_value=2, max_value=10), min_size=4, max_size=4) - ) - shape_1 = (matrices_dims[0], matrices_dims[1]) - shape_2 = (matrices_dims[1], matrices_dims[2]) - shape_3 = (matrices_dims[2], matrices_dims[3]) - - matrix_1 = draw( - helpers.dtype_and_values( - shape=shape_1, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - matrix_2 = draw( - helpers.dtype_and_values( - shape=shape_2, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - matrix_3 = draw( - helpers.dtype_and_values( - shape=shape_3, - dtype=input_dtype, - min_value=-10, - max_value=10, - ) - ) - - return input_dtype, [matrix_1[1][0], matrix_2[1][0], matrix_3[1][0]] - - -@handle_test( - fn_tree="functional.ivy.experimental.multi_dot", - dtype_x=_generate_multi_dot_dtype_and_arrays(), - test_gradients=st.just(False), -) -def test_multi_dot(dtype_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - test_values=True, - x=x, - rtol_=1e-1, - atol_=6e-1, - ) - - -@handle_test( - fn_tree="functional.ivy.experimental.cond", - dtype_x=helpers.cond_data_gen_helper(), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_cond(dtype_x, test_flags, backend_fw, on_device, fn_name): - dtype, x = dtype_x - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - rtol_=1e-3, - atol_=1e-3, - x=x[0], - p=x[1], - ) - - -@st.composite -def _get_dtype_value1_value2_cov( - draw, - available_dtypes, - min_num_dims, - max_num_dims, - min_dim_size, - max_dim_size, - abs_smallest_val=None, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", -): - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - - dtype = draw(st.sampled_from(draw(available_dtypes))) - - values = [] - for i in range(2): - values.append( - draw( - helpers.array_values( - dtype=dtype, - shape=shape, - abs_smallest_val=abs_smallest_val, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, - ) - ) - ) - - value1, value2 = values[0], values[1] - - # 
modifiers: rowVar, bias, ddof - rowVar = draw(st.booleans()) - bias = draw(st.booleans()) - ddof = draw(helpers.ints(min_value=0, max_value=1)) - - numVals = None - if rowVar is False: - numVals = -1 if numVals == 0 else 0 - else: - numVals = 0 if len(shape) == 1 else -1 - - fweights = draw( - helpers.array_values( - dtype="int64", - shape=shape[numVals], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - ) - ) - - aweights = draw( - helpers.array_values( - dtype="float64", - shape=shape[numVals], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - small_abs_safety_factor=1, - ) +) +def test_adjoint(dtype_x, test_flags, backend_fw, fn_name): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], ) - return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights + +@handle_test( + fn_tree="functional.ivy.experimental.cond", + dtype_x=helpers.cond_data_gen_helper(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_cond(dtype_x, test_flags, backend_fw, on_device, fn_name): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + rtol_=1e-3, + atol_=1e-3, + x=x[0], + p=x[1], + ) # cov @@ -647,88 +759,220 @@ def test_cov(*, dtype_x1_x2_cov, test_flags, backend_fw, fn_name, on_device): ) -@st.composite -def _kronecker_data(draw): - num_arrays = draw(helpers.ints(min_value=2, max_value=5)) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=num_arrays, - large_abs_safety_factor=20, - small_abs_safety_factor=20, - safety_factor_scale="log", - shared_dtype=True, - min_num_dims=2, - max_num_dims=2, - ) - ) - skip_matrix = draw( - st.lists(st.integers(min_value=0, max_value=num_arrays - 1), unique=True) +@handle_test( + fn_tree="functional.ivy.experimental.diagflat", + args_packet=_generate_diag_args(), + test_gradients=st.just(False), +) +def test_diagflat(*, test_flags, backend_fw, fn_name, args_packet, on_device): + dtype_x, offset, dtype_padding_value, align, num_rows, num_cols = args_packet + + x_dtype, x = dtype_x + padding_value_dtype, padding_value = dtype_padding_value + padding_value = padding_value[0][0] + + helpers.test_function( + input_dtypes=x_dtype + ["int64"] + padding_value_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + offset=offset, + padding_value=padding_value, + align=align, + num_rows=num_rows, + num_cols=num_cols, + on_device=on_device, + atol_=1e-01, + rtol_=1 / 64, ) - reverse = draw(st.booleans()) - return x_dtype, x, skip_matrix, reverse @handle_test( - fn_tree="functional.ivy.experimental.kronecker", - data=_kronecker_data(), - test_instance_method=st.just(False), + fn_tree="functional.ivy.experimental.eig", + dtype_x=helpers.dtype_and_values( + available_dtypes=( + ivy.float32, + ivy.float64, + ivy.int32, + ivy.int64, + ivy.complex64, + ivy.complex128, + ), + min_num_dims=2, + max_num_dims=3, + min_dim_size=10, + max_dim_size=10, + min_value=1.0, + max_value=1.0e5, + shared_dtype=True, + ), + test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_kronecker(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtypes, input, skip_matrix, reverse = data +def test_eig(dtype_x, test_flags, backend_fw, fn_name): + dtype, x = dtype_x helpers.test_function( + 
input_dtypes=dtype, + test_flags=test_flags, backend_to_test=backend_fw, + fn_name=fn_name, + test_values=False, + x=x[0], + ) + + +# eigh_tridiagonal +@handle_test( + fn_tree="eigh_tridiagonal", + args_packet=_generate_eigh_tridiagonal_args(), + ground_truth_backend="numpy", + test_gradients=st.just(False), +) +def test_eigh_tridiagonal( + *, + args_packet, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtype, alpha, beta, eigvals_only, select, select_range, tol = args_packet + test_flags.with_out = False + results = helpers.test_function( test_flags=test_flags, + backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - input_dtypes=input_dtypes, - x=input, - skip_matrix=skip_matrix, - reverse=reverse, + rtol_=1e-2, + atol_=1e-2, + input_dtypes=dtype, + alpha=alpha[0], + beta=beta[0], + eigvals_only=eigvals_only, + select=select, + select_range=select_range, + tol=tol, + test_values=eigvals_only, + return_flat_np_arrays=True, ) + if results is None: + return + ret_np_flat, ret_np_from_gt_flat = results + reconstructed_np = None + for i in range(len(ret_np_flat) // 2): + eigenvalue = ret_np_flat[i] + eigenvector = ret_np_flat[len(ret_np_flat) // 2 + i] + if reconstructed_np is not None: + reconstructed_np += eigenvalue * np.matmul( + eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) + ) + else: + reconstructed_np = eigenvalue * np.matmul( + eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) + ) + reconstructed_from_np = None + for i in range(len(ret_np_from_gt_flat) // 2): + eigenvalue = ret_np_from_gt_flat[i] + eigenvector = ret_np_from_gt_flat[len(ret_np_flat) // 2 + i] + if reconstructed_from_np is not None: + reconstructed_from_np += eigenvalue * np.matmul( + eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) + ) + else: + reconstructed_from_np = eigenvalue * np.matmul( + eigenvector.reshape(1, -1), eigenvector.reshape(-1, 1) + ) + # value test + helpers.assert_all_close( + reconstructed_np, + reconstructed_from_np, + rtol=1e-1, + atol=1e-2, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) -@st.composite -def _khatri_rao_data(draw): - num_matrices = draw(helpers.ints(min_value=2, max_value=4)) - m = draw(helpers.ints(min_value=1, max_value=5)) - input_dtypes, input = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=num_matrices, - min_dim_size=m, - max_dim_size=m, - min_num_dims=2, - max_num_dims=2, - large_abs_safety_factor=20, - small_abs_safety_factor=20, - safety_factor_scale="log", - ) + +@handle_test( + fn_tree="functional.ivy.experimental.eigvals", + dtype_x=helpers.dtype_and_values( + available_dtypes=( + ivy.float32, + ivy.float64, + ivy.int32, + ivy.int64, + ivy.complex64, + ivy.complex128, + ), + min_num_dims=2, + max_num_dims=3, + min_dim_size=10, + max_dim_size=10, + min_value=1.0, + max_value=1.0e5, + shared_dtype=True, + ), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_eigvals(dtype_x, test_flags, backend_fw, fn_name): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + test_values=False, + x=x[0], ) - skip_matrix = draw(helpers.ints(min_value=0, max_value=len(input) - 1)) - weight_dtype, weights = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), shape=(m,) - ) + + +@handle_test( + fn_tree="functional.ivy.experimental.initialize_tucker", + data=_initialize_tucker_data(), + 
test_with_out=st.just(False), +) +def test_initialize_tucker(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtypes, x, rank, modes, non_negative, mask, svd_mask_repeats = data + results = helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + input_dtypes=input_dtypes, + x=x, + rank=rank, + modes=modes, + non_negative=non_negative, + mask=mask, + svd_mask_repeats=svd_mask_repeats, + test_values=False, ) - mask_dtype, mask = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=0, - max_value=1, - shape=(m,), - ) + + ret_np, ret_from_gt_np = results + + core = helpers.flatten_and_to_np(ret=ret_np[0], backend=backend_fw) + factors = helpers.flatten_and_to_np(ret=ret_np[1], backend=backend_fw) + core_gt = helpers.flatten_and_to_np( + ret=ret_from_gt_np[0], backend=test_flags.ground_truth_backend ) - return ( - input_dtypes + weight_dtype + mask_dtype, - input, - skip_matrix, - weights[0], - mask[0], + factors_gt = helpers.flatten_and_to_np( + ret=ret_from_gt_np[1], backend=test_flags.ground_truth_backend ) + with BackendHandler.update_backend(backend_fw) as ivy_backend: + n_elem = int(ivy_backend.prod(rank[: len(modes)])) * int( + ivy_backend.prod(x.shape[len(modes) :]) + ) + for c, c_gt in zip(core, core_gt): + assert np.prod(c.shape) == n_elem + assert np.prod(c_gt.shape) == n_elem + + for f, f_gt in zip(factors, factors_gt): + assert np.prod(f.shape) == np.prod(f_gt.shape) + @handle_test( fn_tree="functional.ivy.experimental.khatri_rao", @@ -796,238 +1040,51 @@ def test_khatri_rao_tensorly_2(t1, t2, true_res): assert np.allclose(res, true_res) -@st.composite -def _mode_dot_data(draw): - shape_t1 = draw(helpers.get_shape(min_num_dims=2, max_num_dims=5)) - mode = draw(helpers.ints(min_value=0, max_value=len(shape_t1) - 1)) - mode_dimsize = shape_t1[mode] - t1_dtype, t1 = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=shape_t1, - large_abs_safety_factor=20, - small_abs_safety_factor=20, - safety_factor_scale="log", - ) - ) - t2_rows = draw(helpers.ints(min_value=1, max_value=4)) - shape_t2 = draw(st.sampled_from([(mode_dimsize,), (t2_rows, mode_dimsize)])) - t2_dtype, t2 = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=shape_t2, - large_abs_safety_factor=20, - small_abs_safety_factor=20, - safety_factor_scale="log", - ) - ) - return t1_dtype + t2_dtype, t1[0], t2[0], mode - - -@handle_test( - fn_tree="functional.ivy.experimental.mode_dot", - data=_mode_dot_data(), -) -def test_mode_dot(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtypes, t1, t2, mode = data - helpers.test_function( - backend_to_test=backend_fw, - test_flags=test_flags, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - input_dtypes=input_dtypes, - x=t1, - matrix_or_vector=t2, - mode=mode, - ) - - -@pytest.mark.parametrize( - "X, U, true_res", - [ - ( - [ - [[1, 13], [4, 16], [7, 19], [10, 22]], - [[2, 14], [5, 17], [8, 20], [11, 23]], - [[3, 15], [6, 18], [9, 21], [12, 24]], - ], - [[1, 3, 5], [2, 4, 6]], - [ - [[22, 130], [49, 157], [76, 184], [103, 211]], - [[28, 172], [64, 208], [100, 244], [136, 280]], - ], - ) - ], -) -def test_mode_dot_tensorly(X, U, true_res): - X = ivy.array(X) - U = ivy.array(U) - true_res = ivy.array(true_res) - res = ivy.mode_dot(X, U, 0) - assert np.allclose(true_res, res, atol=1e-1, rtol=1e-1) - - -@st.composite -def _multi_mode_dot_data(draw): - 
t1_dtype, t1, shape_t1 = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ret_shape=True, - min_num_dims=2, - large_abs_safety_factor=20, - small_abs_safety_factor=20, - safety_factor_scale="log", - ) - ) - modes = [*range(len(shape_t1))] - skip = draw(st.lists(helpers.ints(min_value=0, max_value=len(shape_t1) - 1))) - t2 = [] - t2_dtype = [] - for i in modes: - mode_dimsize = shape_t1[i] - rows = draw(helpers.ints(min_value=1, max_value=4)) - shape = draw(st.sampled_from([(mode_dimsize,), (rows, mode_dimsize)])) - mat_or_vec_dtype, mat_or_vec = draw( - helpers.dtype_and_values( - dtype=t1_dtype, - shape=shape, - large_abs_safety_factor=20, - small_abs_safety_factor=20, - safety_factor_scale="log", - ) - ) - t2.append(mat_or_vec[0]) - t2_dtype.append(mat_or_vec_dtype[0]) - - return t1_dtype + t2_dtype, t1[0], t2, modes, skip - - -@handle_test( - fn_tree="functional.ivy.experimental.multi_mode_dot", - data=_multi_mode_dot_data(), -) -def test_multi_mode_dot(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtypes, t1, t2, modes, skip = data - helpers.test_function( - backend_to_test=backend_fw, - test_flags=test_flags, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - input_dtypes=input_dtypes, - x=t1, - mat_or_vec_list=t2, - modes=modes, - skip=skip, - ) - - -# The following 2 tests have been adapted from TensorLy -# https://github.com/tensorly/tensorly/blob/main/tensorly/tenalg/tests/test_n_mode_product.py#L81 -@pytest.mark.parametrize( - "X, U, true_res", - [ - ([[1, 2], [0, -1]], [[2, 1], [-1, 1]], [1]), - ], -) -def test_multi_mode_dot_tensorly_1(X, U, true_res): - X, U, true_res = ivy.array(X), ivy.array(U), ivy.array(true_res) - res = ivy.multi_mode_dot(X, U, [0, 1]) - assert np.allclose(true_res, res) - - -@pytest.mark.parametrize("shape", ((3, 5, 4, 2),)) -def test_multi_mode_dot_tensorly_2(shape): - print(shape) - X = ivy.ones(shape) - vecs = [ivy.ones(s) for s in shape] - res = ivy.multi_mode_dot(X, vecs) - # result should be a scalar - assert ivy.shape(res) == () - assert np.allclose(res, np.prod(shape)) - - # Average pooling each mode - # Order should not matter - vecs = [vecs[i] / s for i, s in enumerate(shape)] - for modes in itertools.permutations(range(len(shape))): - res = ivy.multi_mode_dot(X, [vecs[i] for i in modes], modes=modes) - assert ivy.shape(res) == () - assert np.allclose(res, 1) - - @handle_test( - fn_tree="functional.ivy.experimental.svd_flip", - uv=helpers.dtype_and_values( + fn_tree="functional.ivy.experimental.kron", + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, min_num_dims=2, max_num_dims=2, - ), - u_based_decision=st.booleans(), - test_with_out=st.just(False), -) -def test_svd_flip(*, uv, u_based_decision, test_flags, backend_fw, fn_name, on_device): - input_dtypes, input = uv - helpers.test_function( - backend_to_test=backend_fw, - test_flags=test_flags, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - input_dtypes=input_dtypes, - U=input[0], - V=input[1], - u_based_decision=u_based_decision, - ) - - -# truncated svd -@st.composite -def _make_svd_nn_data(draw): - x_dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=2, - max_dim_size=5, - min_value=1.0, - max_value=10.0, - ret_shape=True, - ) - ) - n, m = shape - _, U = draw( - helpers.dtype_and_values( - dtype=x_dtype, - 
available_dtypes=helpers.get_dtypes("float"), - shape=(n, m), - min_value=1.0, - max_value=10.0, - ) - ) - _, S = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=(m,), - min_value=1.0, - max_value=10.0, - ) + min_dim_size=1, + max_dim_size=10, + num_arrays=2, + shared_dtype=True, + ), + test_gradients=st.just(False), +) +def test_kron(*, dtype_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + a=x[0], + b=x[1], ) - _, V = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=(m, m), - min_value=1.0, - max_value=10.0, - ) + + +@handle_test( + fn_tree="functional.ivy.experimental.kronecker", + data=_kronecker_data(), + test_instance_method=st.just(False), +) +def test_kronecker(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtypes, input, skip_matrix, reverse = data + helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + input_dtypes=input_dtypes, + x=input, + skip_matrix=skip_matrix, + reverse=reverse, ) - nntype = draw(st.sampled_from(["nndsvd", "nndsvda"])) - return x_dtype, x[0], U[0], S[0], V[0], nntype @handle_test( @@ -1078,232 +1135,151 @@ def test_make_svd_non_negative(*, data, test_flags, backend_fw, fn_name, on_devi ) -# truncated svd -@st.composite -def _truncated_svd_data(draw): - x_dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=2, - min_dim_size=2, - max_dim_size=5, - min_value=0.1, - max_value=10.0, - ret_shape=True, - ) +# matrix_exp +@handle_test( + fn_tree="functional.ivy.experimental.matrix_exp", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + max_dim_size=2, + min_value=-100, + max_value=100, + allow_nan=False, + shared_dtype=True, + ), + test_gradients=st.just(False), +) +def test_matrix_exp(dtype_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], ) - uv = draw(st.booleans()) - n_eigen = draw(helpers.ints(min_value=1, max_value=max(shape[-2:]))) - return x_dtype, x[0], uv, n_eigen @handle_test( - fn_tree="functional.ivy.experimental.truncated_svd", - data=_truncated_svd_data(), - test_with_out=st.just(False), + fn_tree="functional.ivy.experimental.mode_dot", + data=_mode_dot_data(), ) -def test_truncated_svd(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtype, x, uv, n_eigenvecs = data - results = helpers.test_function( +def test_mode_dot(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtypes, t1, t2, mode = data + helpers.test_function( backend_to_test=backend_fw, test_flags=test_flags, fn_name=fn_name, on_device=on_device, - input_dtypes=input_dtype, - x=x, - compute_uv=uv, - n_eigenvecs=n_eigenvecs, - test_values=False, - return_flat_np_arrays=True, + rtol_=1e-1, + atol_=1e-1, + input_dtypes=input_dtypes, + x=t1, + matrix_or_vector=t2, + mode=mode, ) - if results is None: - return - - # value test based on recreating the original matrix and testing the consistency - ret_flat_np, ret_from_gt_flat_np = results - - if uv: - for i in range(len(ret_flat_np) // 3): - U = ret_flat_np[i] 
- S = ret_flat_np[len(ret_flat_np) // 3 + i] - Vh = ret_flat_np[2 * len(ret_flat_np) // 3 + i] - m = U.shape[-1] - n = Vh.shape[-1] - S = np.expand_dims(S, -2) if m > n else np.expand_dims(S, -1) - - for i in range(len(ret_from_gt_flat_np) // 3): - U_gt = ret_from_gt_flat_np[i] - S_gt = ret_from_gt_flat_np[len(ret_from_gt_flat_np) // 3 + i] - Vh_gt = ret_from_gt_flat_np[2 * len(ret_from_gt_flat_np) // 3 + i] - S_gt = np.expand_dims(S_gt, -2) if m > n else np.expand_dims(S_gt, -1) - - with BackendHandler.update_backend("numpy") as ivy_backend: - S_mat = ( - S - * ivy_backend.eye( - U.shape[-1], Vh.shape[-2], batch_shape=U.shape[:-2] - ).data - ) - S_mat_gt = ( - S_gt - * ivy_backend.eye( - U_gt.shape[-1], Vh_gt.shape[-2], batch_shape=U_gt.shape[:-2] - ).data - ) - reconstructed = np.matmul(np.matmul(U, S_mat), Vh) - reconstructed_gt = np.matmul(np.matmul(U_gt, S_mat_gt), Vh_gt) - # value test - helpers.assert_all_close( - reconstructed, - reconstructed_gt, - atol=1e-04, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) - else: - S = ret_flat_np - S_gt = ret_from_gt_flat_np - helpers.assert_all_close( - S[0], - S_gt[0], - atol=1e-04, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, +@pytest.mark.parametrize( + "X, U, true_res", + [ + ( + [ + [[1, 13], [4, 16], [7, 19], [10, 22]], + [[2, 14], [5, 17], [8, 20], [11, 23]], + [[3, 15], [6, 18], [9, 21], [12, 24]], + ], + [[1, 3, 5], [2, 4, 6]], + [ + [[22, 130], [49, 157], [76, 184], [103, 211]], + [[28, 172], [64, 208], [100, 244], [136, 280]], + ], ) + ], +) +def test_mode_dot_tensorly(X, U, true_res): + X = ivy.array(X) + U = ivy.array(U) + true_res = ivy.array(true_res) + res = ivy.mode_dot(X, U, 0) + assert np.allclose(true_res, res, atol=1e-1, rtol=1e-1) -# intialize tucker -@st.composite -def _initialize_tucker_data(draw): - x_dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=5, - min_dim_size=2, - max_dim_size=5, - min_value=0.1, - max_value=10.0, - ret_shape=True, - ) - ) - dims = len(shape) - rank = [] - for i in range(dims): - rank.append(draw(helpers.ints(min_value=1, max_value=shape[i]))) - n_modes = draw(helpers.ints(min_value=2, max_value=dims)) - modes = [*range(dims)][:n_modes] - mask_dtype, mask = draw( - helpers.dtype_and_values( - dtype=["int32"], - shape=shape, - min_value=0, - max_value=1, - ) - ) - svd_mask_repeats = draw(helpers.ints(min_value=0, max_value=3)) - non_negative = draw(st.booleans()) - return ( - x_dtype + mask_dtype, - x[0], - rank, - modes, - non_negative, - mask[0], - svd_mask_repeats, +@handle_test( + fn_tree="functional.ivy.experimental.multi_dot", + dtype_x=_generate_multi_dot_dtype_and_arrays(), + test_gradients=st.just(False), +) +def test_multi_dot(dtype_x, test_flags, backend_fw, fn_name, on_device): + dtype, x = dtype_x + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + test_values=True, + x=x, + rtol_=1e-1, + atol_=6e-1, ) @handle_test( - fn_tree="functional.ivy.experimental.initialize_tucker", - data=_initialize_tucker_data(), - test_with_out=st.just(False), + fn_tree="functional.ivy.experimental.multi_mode_dot", + data=_multi_mode_dot_data(), ) -def test_initialize_tucker(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtypes, x, rank, modes, non_negative, mask, svd_mask_repeats = data - results = helpers.test_function( +def test_multi_mode_dot(*, 
data, test_flags, backend_fw, fn_name, on_device): + input_dtypes, t1, t2, modes, skip = data + helpers.test_function( backend_to_test=backend_fw, test_flags=test_flags, fn_name=fn_name, on_device=on_device, + rtol_=1e-1, + atol_=1e-1, input_dtypes=input_dtypes, - x=x, - rank=rank, + x=t1, + mat_or_vec_list=t2, modes=modes, - non_negative=non_negative, - mask=mask, - svd_mask_repeats=svd_mask_repeats, - test_values=False, + skip=skip, ) - ret_np, ret_from_gt_np = results - - core = helpers.flatten_and_to_np(ret=ret_np[0], backend=backend_fw) - factors = helpers.flatten_and_to_np(ret=ret_np[1], backend=backend_fw) - core_gt = helpers.flatten_and_to_np( - ret=ret_from_gt_np[0], backend=test_flags.ground_truth_backend - ) - factors_gt = helpers.flatten_and_to_np( - ret=ret_from_gt_np[1], backend=test_flags.ground_truth_backend - ) - with BackendHandler.update_backend(backend_fw) as ivy_backend: - n_elem = int(ivy_backend.prod(rank[: len(modes)])) * int( - ivy_backend.prod(x.shape[len(modes) :]) - ) - for c, c_gt in zip(core, core_gt): - assert np.prod(c.shape) == n_elem - assert np.prod(c_gt.shape) == n_elem +# The following 2 tests have been adapted from TensorLy +# https://github.com/tensorly/tensorly/blob/main/tensorly/tenalg/tests/test_n_mode_product.py#L81 +@pytest.mark.parametrize( + "X, U, true_res", + [ + ([[1, 2], [0, -1]], [[2, 1], [-1, 1]], [1]), + ], +) +def test_multi_mode_dot_tensorly_1(X, U, true_res): + X, U, true_res = ivy.array(X), ivy.array(U), ivy.array(true_res) + res = ivy.multi_mode_dot(X, U, [0, 1]) + assert np.allclose(true_res, res) - for f, f_gt in zip(factors, factors_gt): - assert np.prod(f.shape) == np.prod(f_gt.shape) +@pytest.mark.parametrize("shape", ((3, 5, 4, 2),)) +def test_multi_mode_dot_tensorly_2(shape): + print(shape) + X = ivy.ones(shape) + vecs = [ivy.ones(s) for s in shape] + res = ivy.multi_mode_dot(X, vecs) + # result should be a scalar + assert ivy.shape(res) == () + assert np.allclose(res, np.prod(shape)) -# partial tucker -@st.composite -def _partial_tucker_data(draw): - x_dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=5, - min_dim_size=2, - max_dim_size=5, - min_value=0.1, - max_value=10.0, - ret_shape=True, - ) - ) - dims = len(shape) - rank = [] - for i in range(dims): - rank.append(draw(helpers.ints(min_value=1, max_value=shape[i]))) - n_modes = draw(helpers.ints(min_value=2, max_value=dims)) - modes = [*range(dims)][:n_modes] - mask_dtype, mask = draw( - helpers.dtype_and_values( - dtype=["int32"], - shape=shape, - min_value=0, - max_value=1, - ) - ) - svd_mask_repeats = draw(helpers.ints(min_value=0, max_value=3)) - n_iter_max = draw(helpers.ints(min_value=1, max_value=7)) - tol = draw(helpers.floats(min_value=1e-5, max_value=1e-1)) - return ( - x_dtype + mask_dtype, - x[0], - rank, - modes, - n_iter_max, - mask[0], - svd_mask_repeats, - tol, - ) + # Average pooling each mode + # Order should not matter + vecs = [vecs[i] / s for i, s in enumerate(shape)] + for modes in itertools.permutations(range(len(shape))): + res = ivy.multi_mode_dot(X, [vecs[i] for i in modes], modes=modes) + assert ivy.shape(res) == () + assert np.allclose(res, 1) @handle_test( @@ -1418,76 +1394,108 @@ def test_partial_tucker_tensorly(tol_norm_2, tol_max_abs, modes, shape): np.allclose(factor1, factor2) -# tucker -@st.composite -def _tucker_data(draw): - x_dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - 
max_num_dims=4, - min_dim_size=2, - max_dim_size=3, - min_value=0.1, - max_value=10.0, - ret_shape=True, - ) +@handle_test( + fn_tree="functional.ivy.experimental.svd_flip", + uv=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + min_num_dims=2, + max_num_dims=2, + ), + u_based_decision=st.booleans(), + test_with_out=st.just(False), +) +def test_svd_flip(*, uv, u_based_decision, test_flags, backend_fw, fn_name, on_device): + input_dtypes, input = uv + helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + input_dtypes=input_dtypes, + U=input[0], + V=input[1], + u_based_decision=u_based_decision, ) - dims = len(shape) - rank = [] - for i in range(dims): - rank.append(draw(helpers.ints(min_value=1, max_value=shape[i]))) - mask_dtype, mask = draw( - helpers.dtype_and_values( - dtype=["int32"], - shape=shape, - min_value=0, - max_value=1, - ) + + +@handle_test( + fn_tree="functional.ivy.experimental.truncated_svd", + data=_truncated_svd_data(), + test_with_out=st.just(False), +) +def test_truncated_svd(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtype, x, uv, n_eigenvecs = data + results = helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + input_dtypes=input_dtype, + x=x, + compute_uv=uv, + n_eigenvecs=n_eigenvecs, + test_values=False, + return_flat_np_arrays=True, ) - svd_mask_repeats = draw(helpers.ints(min_value=0, max_value=1)) - n_iter_max = draw(helpers.ints(min_value=0, max_value=2)) - tol = draw(helpers.floats(min_value=1e-5, max_value=1e-1)) - init = draw(st.sampled_from(["svd", "random"])) - fixed_factors = draw(st.booleans()) - if fixed_factors: - _, core = draw( - helpers.dtype_and_values( - dtype=x_dtype, - min_value=0.1, - max_value=10.0, - shape=rank, - ) - ) - factors = [] - for i in range(dims): - _, factor = draw( - helpers.dtype_and_values( - dtype=x_dtype, - min_value=0.1, - max_value=10.0, - shape=(shape[i], rank[i]), - ) + + if results is None: + return + + # value test based on recreating the original matrix and testing the consistency + ret_flat_np, ret_from_gt_flat_np = results + + if uv: + for i in range(len(ret_flat_np) // 3): + U = ret_flat_np[i] + S = ret_flat_np[len(ret_flat_np) // 3 + i] + Vh = ret_flat_np[2 * len(ret_flat_np) // 3 + i] + m = U.shape[-1] + n = Vh.shape[-1] + S = np.expand_dims(S, -2) if m > n else np.expand_dims(S, -1) + + for i in range(len(ret_from_gt_flat_np) // 3): + U_gt = ret_from_gt_flat_np[i] + S_gt = ret_from_gt_flat_np[len(ret_from_gt_flat_np) // 3 + i] + Vh_gt = ret_from_gt_flat_np[2 * len(ret_from_gt_flat_np) // 3 + i] + S_gt = np.expand_dims(S_gt, -2) if m > n else np.expand_dims(S_gt, -1) + + with BackendHandler.update_backend("numpy") as ivy_backend: + S_mat = ( + S + * ivy_backend.eye( + U.shape[-1], Vh.shape[-2], batch_shape=U.shape[:-2] + ).data ) - factors.append(factor[0]) - fixed_factors = draw( - st.lists( - helpers.ints(min_value=0, max_value=dims - 1), unique=True, min_size=1 + S_mat_gt = ( + S_gt + * ivy_backend.eye( + U_gt.shape[-1], Vh_gt.shape[-2], batch_shape=U_gt.shape[:-2] + ).data ) + reconstructed = np.matmul(np.matmul(U, S_mat), Vh) + reconstructed_gt = np.matmul(np.matmul(U_gt, S_mat_gt), Vh_gt) + + # value test + helpers.assert_all_close( + reconstructed, + reconstructed_gt, + atol=1e-04, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) + else: + S = 
ret_flat_np + S_gt = ret_from_gt_flat_np + helpers.assert_all_close( + S[0], + S_gt[0], + atol=1e-04, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, ) - rank = [rank[i] for i in range(dims) if i not in fixed_factors] - init = ivy.TuckerTensor((core[0], factors)) - return ( - x_dtype + mask_dtype, - x[0], - rank, - fixed_factors, - init, - n_iter_max, - mask[0], - svd_mask_repeats, - tol, - ) @handle_test( diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py index 38686fcd1ea61..6f964a658cf1e 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py @@ -11,178 +11,126 @@ from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits -# Helpers # -# ------- # +# --- Helpers --- # +# --------------- # -# moveaxis -@handle_test( - fn_tree="functional.ivy.experimental.moveaxis", - dtype_and_a=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - ), - source=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - destination=helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - key="a_s_d", - ), - min_size=1, - force_int=True, - ), - test_gradients=st.just(False), -) -def test_moveaxis( - *, dtype_and_a, source, destination, test_flags, backend_fw, fn_name, on_device -): - input_dtype, a = dtype_and_a - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - a=a[0], - source=source, - destination=destination, +@st.composite +def _as_strided_helper(draw): + dtype, x = draw(helpers.dtype_and_values(min_num_dims=1, max_num_dims=5)) + x = x[0] + itemsize = x.itemsize + shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=5)) + new_ndim = len(shape) + strides = draw( + st.lists( + st.integers(min_value=1, max_value=16), + min_size=new_ndim, + max_size=new_ndim, + ).filter(lambda x: all(x[i] % itemsize == 0 for i in range(new_ndim))) ) + assume(_check_bounds(x.shape, shape, strides, itemsize)) + return dtype, x, shape, strides -# heaviside -@handle_test( - fn_tree="functional.ivy.experimental.heaviside", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - num_arrays=2, - shared_dtype=True, - ), - test_gradients=st.just(False), -) -def test_heaviside(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - x1=x[0], - x2=x[0], +@st.composite +def _associative_scan_helper(draw): + input_dtype = draw( + st.shared( + st.sampled_from(draw(helpers.get_dtypes("float"))), + key="shared_dtype", + 
).filter(lambda _x: "float16" not in _x) + ) + random_size = draw( + st.shared(helpers.ints(min_value=1, max_value=5), key="shared_size") + ) + shared_size = draw( + st.shared(helpers.ints(min_value=1, max_value=5), key="shared_size") + ) + shape = tuple([random_size, shared_size, shared_size]) + matrix = draw( + helpers.array_values( + dtype=input_dtype, + shape=shape, + min_value=1, + max_value=10, + ) + ) + axis = draw( + helpers.get_axis( + shape=shape, + allow_neg=False, + force_int=True, + ).filter(lambda _x: _x < len(shape) - 2) ) + return [input_dtype], matrix, axis -# flipud -@handle_test( - fn_tree="functional.ivy.experimental.flipud", - dtype_and_m=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-100, - max_value=100, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - test_gradients=st.just(False), -) -def test_flipud(*, dtype_and_m, test_flags, backend_fw, fn_name, on_device): - input_dtype, m = dtype_and_m - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - m=m[0], +@st.composite +def _concat_from_sequence_helper(draw): + dtypes, arrays, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=helpers.ints(min_value=1, max_value=6), + ret_shape=True, + min_num_dims=2, + min_dim_size=2, + shared_dtype=True, + ) + ) + axis = draw( + helpers.get_axis( + shape=shape, + force_int=True, + ) ) + return dtypes, arrays, axis -# vstack -@handle_test( - fn_tree="functional.ivy.experimental.vstack", - dtype_and_m=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=helpers.get_shape( - min_num_dims=1, - ), - shared_dtype=True, - num_arrays=helpers.ints(min_value=2, max_value=10), - ), - test_gradients=st.just(False), -) -def test_vstack(*, dtype_and_m, test_flags, backend_fw, fn_name, on_device): - input_dtype, m = dtype_and_m - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - arrays=m, +@st.composite +def _flatten_data_helper(draw): + mixed_fn_compos = draw(st.booleans()) + is_torch_backend = ivy.current_backend_str() == "torch" + + dtype_and_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "valid", mixed_fn_compos=mixed_fn_compos + ), + shape=st.shared(helpers.get_shape(), key="flatten_shape"), + ) + ) + axes = draw( + helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="flatten_shape"), + min_size=2, + max_size=2, + unique=False, + force_tuple=True, + ) ) + order = draw(st.sampled_from(["C", "F"])) + if not mixed_fn_compos and is_torch_backend: + order = "C" + return dtype_and_x, axes, order -# hstack -@handle_test( - fn_tree="functional.ivy.experimental.hstack", - dtype_and_m=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shared_dtype=True, - num_arrays=helpers.ints(min_value=2, max_value=10), - shape=helpers.get_shape( - min_num_dims=1, - ), - ), - test_gradients=st.just(False), -) -def test_hstack(dtype_and_m, test_flags, backend_fw, fn_name, on_device): - input_dtype, m = dtype_and_m - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - arrays=m, +@st.composite +def _fold_data(draw): + shape = draw( + helpers.get_shape( + min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=3 
+ ) ) + mode = draw(helpers.ints(min_value=0, max_value=len(shape) - 1)) + reduced_dims = int(ivy.prod(shape[0:mode]) * ivy.prod(shape[mode + 1 :])) + unfolded_shape = (shape[mode], reduced_dims) + dtype, input = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), shape=unfolded_shape + ) + ) + return dtype, input, shape, mode @st.composite @@ -214,208 +162,44 @@ def _get_dtype_values_k_axes_for_rot90( helpers.ints(min_value=-len(shape), max_value=len(shape) - 1), min_size=2, max_size=2, - unique=True, - ).filter(lambda axes: abs(axes[0] - axes[1]) != len(shape)) - ) - dtype = draw(st.sampled_from(draw(available_dtypes))) - values = draw( - helpers.array_values( - dtype=dtype, - shape=shape, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=72, - small_abs_safety_factor=72, - safety_factor_scale="log", - ) - ) - return [dtype], values, k, axes - - -# rot90 -@handle_test( - fn_tree="functional.ivy.experimental.rot90", - dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - ), - test_gradients=st.just(False), -) -def test_rot90(dtype_m_k_axes, test_flags, backend_fw, fn_name, on_device): - input_dtype, m, k, axes = dtype_m_k_axes - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - m=m, - k=k, - axes=tuple(axes), - ) - - -# top_k -@handle_test( - fn_tree="functional.ivy.experimental.top_k", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - force_int_axis=True, - valid_axis=True, - ), - k=helpers.ints(min_value=1, max_value=4), - largest=st.booleans(), - sorted=st.booleans(), - test_gradients=st.just(False), -) -def test_top_k( - *, dtype_x_axis, k, largest, sorted, test_flags, backend_fw, fn_name, on_device -): - dtype, x, axis = dtype_x_axis - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - k=k, - axis=axis, - largest=largest, - sorted=sorted, - ) - - -# fliplr -@handle_test( - fn_tree="functional.ivy.experimental.fliplr", - dtype_and_m=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=2, - ), - test_gradients=st.just(False), -) -def test_fliplr(*, dtype_and_m, test_flags, backend_fw, fn_name, on_device): - input_dtype, m = dtype_and_m - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - m=m[0], - ) - - -# i0 -@handle_test( - fn_tree="functional.ivy.experimental.i0", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-10, - max_value=10, - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, - ), - test_gradients=st.just(False), -) -def test_i0(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - ) - - -@st.composite -def _flatten_data_helper(draw): - mixed_fn_compos = draw(st.booleans()) - is_torch_backend = ivy.current_backend_str() == "torch" - - dtype_and_x = draw( - 
helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "valid", mixed_fn_compos=mixed_fn_compos - ), - shape=st.shared(helpers.get_shape(), key="flatten_shape"), - ) - ) - axes = draw( - helpers.get_axis( - shape=st.shared(helpers.get_shape(), key="flatten_shape"), - min_size=2, - max_size=2, - unique=False, - force_tuple=True, + unique=True, + ).filter(lambda axes: abs(axes[0] - axes[1]) != len(shape)) + ) + dtype = draw(st.sampled_from(draw(available_dtypes))) + values = draw( + helpers.array_values( + dtype=dtype, + shape=shape, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=72, + small_abs_safety_factor=72, + safety_factor_scale="log", ) ) - order = draw(st.sampled_from(["C", "F"])) - if not mixed_fn_compos and is_torch_backend: - order = "C" - return dtype_and_x, axes, order + return [dtype], values, k, axes -@handle_test( - fn_tree="functional.ivy.experimental.flatten", - data=_flatten_data_helper(), -) -def test_flatten( - *, - data, - test_flags, - backend_fw, - fn_name, - on_device, -): - (input_dtypes, x), axes, order = data - helpers.test_function( - input_dtypes=input_dtypes, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - start_dim=axes[0], - end_dim=axes[1], - order=order, +@st.composite +def _matricize_data(draw): + input_dtype, input, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ret_shape=True, + min_num_dims=2, + max_num_dims=5, + ) ) - - -def st_tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False): - return st.lists( - elements, - min_size=min_size, - max_size=max_size, - unique_by=unique_by, - unique=unique, - ).map(tuple) - - -def _st_tuples_or_int(n_pairs, min_val=0): - return st.one_of( - st_tuples( - st.tuples( - st.integers(min_value=min_val, max_value=4), - st.integers(min_value=min_val, max_value=4), - ), - min_size=n_pairs, - max_size=n_pairs, - ), - helpers.ints(min_value=min_val, max_value=4), + ndims = len(shape) + dims = set([*range(ndims)]) + row_modes = set( + draw(st.lists(helpers.ints(min_value=0, max_value=ndims - 1), min_size=1)) ) + col_modes = dims - row_modes + return input_dtype, input, row_modes, col_modes @st.composite @@ -478,97 +262,194 @@ def _pad_helper(draw): return dtype, input[0], pad_width, stat_length, constant_values, end_values, mode -@handle_test( - fn_tree="functional.ivy.experimental.pad", - ground_truth_backend="numpy", - dtype_and_input_and_other=_pad_helper(), - reflect_type=st.sampled_from(["even", "odd"]), - test_with_out=st.just(False), - test_gradients=st.just(False), -) -def test_pad( - *, - dtype_and_input_and_other, - reflect_type, - test_flags, - backend_fw, - fn_name, - on_device, -): - ( - dtype, - input, - pad_width, - stat_length, - constant_values, - end_values, - mode, - ) = dtype_and_input_and_other - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - input=input, - pad_width=pad_width, - mode=mode, - stat_length=stat_length, - constant_values=constant_values, - end_values=end_values, - reflect_type=reflect_type, +@st.composite +def _partial_fold_data(draw): + shape = draw( + helpers.get_shape( + min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=3 + ) + ) + ndims = len(shape) + mode_and_skip_begin = draw( + st.lists( + helpers.ints(min_value=0, max_value=ndims - 1), 
min_size=2, max_size=2 + ).filter(lambda nums: np.sum(nums) <= ndims - 1) + ) + skip_begin, mode = sorted(mode_and_skip_begin) + skip_end = draw( + helpers.ints(min_value=0, max_value=ndims - (skip_begin + mode) - 1) + ) + if skip_end != 0: + reduced_dims = int( + ivy.prod(shape[skip_begin : skip_begin + mode]) + * ivy.prod(shape[skip_begin + mode + 1 : -skip_end]) + ) + unfolded_shape = ( + *shape[:skip_begin], + shape[skip_begin + mode], + reduced_dims, + *shape[-skip_end:], + ) + else: + reduced_dims = int( + ivy.prod(shape[skip_begin : skip_begin + mode]) + * ivy.prod(shape[skip_begin + mode + 1 :]) + ) + unfolded_shape = (*shape[:skip_begin], shape[skip_begin + mode], reduced_dims) + + dtype, input = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), shape=unfolded_shape + ) + ) + return dtype, input, skip_begin, shape, mode + + +@st.composite +def _partial_tensor_to_vec_data(draw): + input_dtype, input, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, ret_shape=True + ) + ) + ndims = len(shape) + skip_begin = draw(helpers.ints(min_value=0, max_value=ndims - 1)) + skip_end = draw(helpers.ints(min_value=0, max_value=ndims - 1 - skip_begin)) + return input_dtype, input, skip_begin, skip_end + + +@st.composite +def _partial_unfold_data(draw): + dtype, input = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + ) + ) + ndims = len(input[0].shape) + mode_and_skip_begin = draw( + st.lists( + helpers.ints(min_value=0, max_value=ndims - 1), min_size=2, max_size=2 + ).filter(lambda nums: np.sum(nums) <= ndims - 1) ) + skip_begin, mode = sorted(mode_and_skip_begin) + skip_end = draw( + helpers.ints(min_value=0, max_value=ndims - (skip_begin + mode) - 1) + ) + ravel_tensors = draw(st.booleans()) + return dtype, input, mode, skip_begin, skip_end, ravel_tensors + + +@st.composite +def _partial_vec_to_tensor(draw): + shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=5)) + numel = int(ivy.prod(shape)) + input_dtype, input = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), shape=(numel,) + ) + ) + ndims = len(shape) + skip_begin = draw(helpers.ints(min_value=0, max_value=ndims - 1)) + return input_dtype, input, shape, skip_begin + + +@st.composite +def _soft_thresholding_data(draw): + x_min, x_max = 1e-2, 1e2 + x_dtype, x, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ret_shape=True, + min_value=x_min, + max_value=x_max, + ) + ) + threshold_choice_1 = draw(helpers.floats(min_value=x_min, max_value=x_max)) + t_dtype, threshold_choice_2 = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=shape, + min_value=x_min, + max_value=x_max, + ) + ) + threshold = draw(st.sampled_from([threshold_choice_1, threshold_choice_2])) + return x_dtype + t_dtype, x, threshold + + +def _st_tuples_or_int(n_pairs, min_val=0): + return st.one_of( + st_tuples( + st.tuples( + st.integers(min_value=min_val, max_value=4), + st.integers(min_value=min_val, max_value=4), + ), + min_size=n_pairs, + max_size=n_pairs, + ), + helpers.ints(min_value=min_val, max_value=4), + ) + + +# --- Main --- # +# ------------ # + + +def st_tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False): + return st.lists( + elements, + min_size=min_size, + max_size=max_size, + unique_by=unique_by, + unique=unique, + ).map(tuple) -# vsplit @handle_test( - 
fn_tree="functional.ivy.experimental.vsplit", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), - ), - indices_or_sections=_get_splits(allow_none=False, min_num_dims=2, axis=0), - test_gradients=st.just(False), + fn_tree="as_strided", + all_args=_as_strided_helper(), test_with_out=st.just(False), + test_gradients=st.just(False), + ground_truth_backend="numpy", ) -def test_vsplit( - dtype_and_x, indices_or_sections, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x = dtype_and_x +def test_as_strided(*, all_args, test_flags, backend_fw, fn_name, on_device): + dtype, x, shape, strides = all_args helpers.test_function( - input_dtypes=input_dtype, - on_device=on_device, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, - x=x[0], - indices_or_sections=indices_or_sections, + on_device=on_device, + x=x, + shape=shape, + strides=strides, ) -# dsplit +# associative_scan @handle_test( - fn_tree="functional.ivy.experimental.dsplit", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), - ), - indices_or_sections=_get_splits(allow_none=False, min_num_dims=3, axis=2), - test_gradients=st.just(False), + fn_tree="functional.ivy.experimental.associative_scan", + dtype_elems_axis=_associative_scan_helper(), + fn=st.sampled_from([ivy.matmul, ivy.multiply, ivy.add]), + reverse=st.booleans(), test_with_out=st.just(False), + ground_truth_backend="jax", ) -def test_dsplit( - dtype_and_x, indices_or_sections, test_flags, backend_fw, fn_name, on_device +def test_associative_scan( + *, dtype_elems_axis, fn, reverse, fn_name, test_flags, backend_fw, on_device ): - input_dtype, x = dtype_and_x + dtype, elems, axis = dtype_elems_axis helpers.test_function( - input_dtypes=input_dtype, + fn_name=fn_name, test_flags=test_flags, backend_to_test=backend_fw, - fn_name=fn_name, on_device=on_device, - x=x[0], - indices_or_sections=indices_or_sections, + input_dtypes=dtype, + elems=elems, + fn=fn, + reverse=reverse, + axis=axis, ) @@ -598,31 +479,6 @@ def test_atleast_1d(dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# dstack -@handle_test( - fn_tree="functional.ivy.experimental.dstack", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shared_dtype=True, - num_arrays=helpers.ints(min_value=1, max_value=10), - shape=helpers.get_shape( - min_num_dims=1, - ), - ), - test_gradients=st.just(False), -) -def test_dstack(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - arrays=x, - ) - - # atleast_2d @handle_test( fn_tree="functional.ivy.experimental.atleast_2d", @@ -675,69 +531,70 @@ def test_atleast_3d(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# take_along_axis +# broadcast_shapes @handle_test( - fn_tree="functional.ivy.experimental.take_along_axis", - dtype_x_indices_axis=helpers.array_indices_axis( - array_dtypes=helpers.get_dtypes("numeric"), - indices_dtypes=["int32", "int64"], - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=10, - indices_same_dims=True, - valid_bounds=False, + fn_tree="functional.ivy.experimental.broadcast_shapes", + shapes=nph.mutually_broadcastable_shapes( 
+ num_shapes=4, min_dims=1, max_dims=5, min_side=1, max_side=5 ), - mode=st.sampled_from(["clip", "fill", "drop"]), - ground_truth_backend="jax", + test_instance_method=st.just(False), + test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_take_along_axis( - *, - dtype_x_indices_axis, - mode, - test_flags, - backend_fw, - fn_name, - on_device, +def test_broadcast_shapes(*, shapes, test_flags, backend_fw, fn_name, on_device): + shape, _ = shapes + shapes = {f"shape{i}": shape[i] for i in range(len(shape))} + test_flags.num_positional_args = len(shapes) + helpers.test_function( + input_dtypes=["int64"], + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + **shapes, + ) + + +# concat_from_sequence +@handle_test( + fn_tree="functional.ivy.experimental.concat_from_sequence", + dtypes_arrays_axis=_concat_from_sequence_helper(), + new_axis=st.integers(min_value=0, max_value=1), + container_flags=st.just([False]), + test_instance_method=st.just(False), +) +def test_concat_from_sequence( + *, dtypes_arrays_axis, new_axis, test_flags, backend_fw, fn_name, on_device ): - dtypes, x, indices, axis, _ = dtype_x_indices_axis + dtypes, arrays, axis = dtypes_arrays_axis + helpers.test_function( input_dtypes=dtypes, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - arr=x, - indices=indices, + input_sequence=arrays, + new_axis=new_axis, axis=axis, - mode=mode, ) -# hsplit -# TODO: there is a failure with paddle (dtype('int32')) caused by the `_get_splits` -# method which returns a numpy array with a numpy dtype +# dsplit @handle_test( - fn_tree="functional.ivy.experimental.hsplit", + fn_tree="functional.ivy.experimental.dsplit", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"), ), - indices_or_sections=_get_splits(allow_none=False, min_num_dims=2, axis=1), + indices_or_sections=_get_splits(allow_none=False, min_num_dims=3, axis=2), test_gradients=st.just(False), test_with_out=st.just(False), ) -def test_hsplit( +def test_dsplit( dtype_and_x, indices_or_sections, test_flags, backend_fw, fn_name, on_device ): input_dtype, x = dtype_and_x - if ( - not isinstance(indices_or_sections, int) - and not isinstance(indices_or_sections, list) - and indices_or_sections is not None - ): - input_dtype = [*input_dtype, indices_or_sections.dtype] helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, @@ -749,27 +606,28 @@ def test_hsplit( ) -# broadcast_shapes +# dstack @handle_test( - fn_tree="functional.ivy.experimental.broadcast_shapes", - shapes=nph.mutually_broadcastable_shapes( - num_shapes=4, min_dims=1, max_dims=5, min_side=1, max_side=5 + fn_tree="functional.ivy.experimental.dstack", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shared_dtype=True, + num_arrays=helpers.ints(min_value=1, max_value=10), + shape=helpers.get_shape( + min_num_dims=1, + ), ), - test_instance_method=st.just(False), - test_with_out=st.just(False), test_gradients=st.just(False), ) -def test_broadcast_shapes(*, shapes, test_flags, backend_fw, fn_name, on_device): - shape, _ = shapes - shapes = {f"shape{i}": shape[i] for i in range(len(shape))} - test_flags.num_positional_args = len(shapes) +def test_dstack(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x 
helpers.test_function( - input_dtypes=["int64"], + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - **shapes, + arrays=x, ) @@ -815,233 +673,256 @@ def test_expand(*, dtype_and_x, shape, test_flags, backend_fw, fn_name, on_devic ) -@st.composite -def _as_strided_helper(draw): - dtype, x = draw(helpers.dtype_and_values(min_num_dims=1, max_num_dims=5)) - x = x[0] - itemsize = x.itemsize - shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=5)) - new_ndim = len(shape) - strides = draw( - st.lists( - st.integers(min_value=1, max_value=16), - min_size=new_ndim, - max_size=new_ndim, - ).filter(lambda x: all(x[i] % itemsize == 0 for i in range(new_ndim))) +# fill_diag +@handle_test( + fn_tree="fill_diagonal", + dt_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=4, + min_dim_size=3, + max_dim_size=3, + ), + v=st.sampled_from([1, 2, 3, 10]), + wrap=st.booleans(), + test_with_out=st.just(False), +) +def test_fill_diagonal( + *, + dt_a, + v, + wrap, + test_flags, + backend_fw, + fn_name, + on_device, +): + dt, a = dt_a + helpers.test_function( + input_dtypes=dt, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + a=a[0], + v=v, + wrap=wrap, + ) + + +@handle_test( + fn_tree="functional.ivy.experimental.flatten", + data=_flatten_data_helper(), +) +def test_flatten( + *, + data, + test_flags, + backend_fw, + fn_name, + on_device, +): + (input_dtypes, x), axes, order = data + helpers.test_function( + input_dtypes=input_dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + start_dim=axes[0], + end_dim=axes[1], + order=order, ) - assume(_check_bounds(x.shape, shape, strides, itemsize)) - return dtype, x, shape, strides +# fliplr @handle_test( - fn_tree="as_strided", - all_args=_as_strided_helper(), - test_with_out=st.just(False), + fn_tree="functional.ivy.experimental.fliplr", + dtype_and_m=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + ), test_gradients=st.just(False), - ground_truth_backend="numpy", ) -def test_as_strided(*, all_args, test_flags, backend_fw, fn_name, on_device): - dtype, x, shape, strides = all_args +def test_fliplr(*, dtype_and_m, test_flags, backend_fw, fn_name, on_device): + input_dtype, m = dtype_and_m helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - x=x, - shape=shape, - strides=strides, + m=m[0], ) -@st.composite -def _concat_from_sequence_helper(draw): - dtypes, arrays, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - num_arrays=helpers.ints(min_value=1, max_value=6), - ret_shape=True, - min_num_dims=2, - min_dim_size=2, - shared_dtype=True, - ) - ) - axis = draw( - helpers.get_axis( - shape=shape, - force_int=True, - ) +# flipud +@handle_test( + fn_tree="functional.ivy.experimental.flipud", + dtype_and_m=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + test_gradients=st.just(False), +) +def test_flipud(*, dtype_and_m, test_flags, backend_fw, fn_name, on_device): + input_dtype, m = dtype_and_m + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + 
backend_to_test=backend_fw, + fn_name=fn_name, + m=m[0], ) - return dtypes, arrays, axis -# concat_from_sequence @handle_test( - fn_tree="functional.ivy.experimental.concat_from_sequence", - dtypes_arrays_axis=_concat_from_sequence_helper(), - new_axis=st.integers(min_value=0, max_value=1), - container_flags=st.just([False]), - test_instance_method=st.just(False), + fn_tree="functional.ivy.experimental.fold", + data=_fold_data(), ) -def test_concat_from_sequence( - *, dtypes_arrays_axis, new_axis, test_flags, backend_fw, fn_name, on_device -): - dtypes, arrays, axis = dtypes_arrays_axis - +def test_fold(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtype, input, shape, mode = data helpers.test_function( - input_dtypes=dtypes, - test_flags=test_flags, backend_to_test=backend_fw, + test_flags=test_flags, fn_name=fn_name, on_device=on_device, - input_sequence=arrays, - new_axis=new_axis, - axis=axis, + rtol_=1e-1, + atol_=1e-1, + input_dtypes=input_dtype, + x=input[0], + mode=mode, + shape=shape, ) -@st.composite -def _associative_scan_helper(draw): - input_dtype = draw( - st.shared( - st.sampled_from(draw(helpers.get_dtypes("float"))), - key="shared_dtype", - ).filter(lambda _x: "float16" not in _x) - ) - random_size = draw( - st.shared(helpers.ints(min_value=1, max_value=5), key="shared_size") - ) - shared_size = draw( - st.shared(helpers.ints(min_value=1, max_value=5), key="shared_size") - ) - shape = tuple([random_size, shared_size, shared_size]) - matrix = draw( - helpers.array_values( - dtype=input_dtype, - shape=shape, - min_value=1, - max_value=10, - ) - ) - axis = draw( - helpers.get_axis( - shape=shape, - allow_neg=False, - force_int=True, - ).filter(lambda _x: _x < len(shape) - 2) +# heaviside +@handle_test( + fn_tree="functional.ivy.experimental.heaviside", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + num_arrays=2, + shared_dtype=True, + ), + test_gradients=st.just(False), +) +def test_heaviside(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x1=x[0], + x2=x[0], ) - return [input_dtype], matrix, axis -# associative_scan +# hsplit +# TODO: there is a failure with paddle (dtype('int32')) caused by the `_get_splits` +# method which returns a numpy array with a numpy dtype @handle_test( - fn_tree="functional.ivy.experimental.associative_scan", - dtype_elems_axis=_associative_scan_helper(), - fn=st.sampled_from([ivy.matmul, ivy.multiply, ivy.add]), - reverse=st.booleans(), + fn_tree="functional.ivy.experimental.hsplit", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + ), + indices_or_sections=_get_splits(allow_none=False, min_num_dims=2, axis=1), + test_gradients=st.just(False), test_with_out=st.just(False), - ground_truth_backend="jax", ) -def test_associative_scan( - *, dtype_elems_axis, fn, reverse, fn_name, test_flags, backend_fw, on_device +def test_hsplit( + dtype_and_x, indices_or_sections, test_flags, backend_fw, fn_name, on_device ): - dtype, elems, axis = dtype_elems_axis + input_dtype, x = dtype_and_x + if ( + not isinstance(indices_or_sections, int) + and not isinstance(indices_or_sections, list) + and 
indices_or_sections is not None + ): + input_dtype = [*input_dtype, indices_or_sections.dtype] helpers.test_function( - fn_name=fn_name, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, + fn_name=fn_name, on_device=on_device, - input_dtypes=dtype, - elems=elems, - fn=fn, - reverse=reverse, - axis=axis, + x=x[0], + indices_or_sections=indices_or_sections, ) -# unique_consecutive +# hstack @handle_test( - fn_tree="unique_consecutive", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - min_dim_size=2, - force_int_axis=True, - valid_axis=True, + fn_tree="functional.ivy.experimental.hstack", + dtype_and_m=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shared_dtype=True, + num_arrays=helpers.ints(min_value=2, max_value=10), + shape=helpers.get_shape( + min_num_dims=1, + ), ), - none_axis=st.booleans(), - test_with_out=st.just(False), test_gradients=st.just(False), - ground_truth_backend="torch", ) -def test_unique_consecutive( - *, dtype_x_axis, none_axis, test_flags, backend_fw, fn_name, on_device -): - dtype, x, axis = dtype_x_axis - if none_axis: - axis = None +def test_hstack(dtype_and_m, test_flags, backend_fw, fn_name, on_device): + input_dtype, m = dtype_and_m helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - x=x[0], - axis=axis, + arrays=m, ) -# fill_diag +# i0 @handle_test( - fn_tree="fill_diagonal", - dt_a=helpers.dtype_and_values( + fn_tree="functional.ivy.experimental.i0", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=4, - min_dim_size=3, + min_value=-10, + max_value=10, + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, max_dim_size=3, ), - v=st.sampled_from([1, 2, 3, 10]), - wrap=st.booleans(), - test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_fill_diagonal( - *, - dt_a, - v, - wrap, - test_flags, - backend_fw, - fn_name, - on_device, -): - dt, a = dt_a +def test_i0(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x helpers.test_function( - input_dtypes=dt, + input_dtypes=input_dtype, test_flags=test_flags, - on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - a=a[0], - v=v, - wrap=wrap, + on_device=on_device, + x=x[0], ) @handle_test( - fn_tree="functional.ivy.experimental.unfold", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - valid_axis=True, - allow_neg_axes=False, - force_int_axis=True, - ), + fn_tree="functional.ivy.experimental.matricize", + data=_matricize_data(), ) -def test_unfold(*, dtype_values_axis, test_flags, backend_fw, fn_name, on_device): - input_dtype, input, axis = dtype_values_axis - if axis is None: - axis = 0 +def test_matricize(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtype, input, row_modes, column_modes = data helpers.test_function( backend_to_test=backend_fw, test_flags=test_flags, @@ -1051,133 +932,116 @@ def test_unfold(*, dtype_values_axis, test_flags, backend_fw, fn_name, on_device atol_=1e-1, input_dtypes=input_dtype, x=input[0], - mode=axis, - ) - - -@st.composite -def _fold_data(draw): - shape = draw( - helpers.get_shape( - min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=3 - ) - ) - mode = draw(helpers.ints(min_value=0, max_value=len(shape) - 1)) - reduced_dims 
= int(ivy.prod(shape[0:mode]) * ivy.prod(shape[mode + 1 :])) - unfolded_shape = (shape[mode], reduced_dims) - dtype, input = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), shape=unfolded_shape - ) + row_modes=row_modes, + column_modes=column_modes, ) - return dtype, input, shape, mode +# moveaxis @handle_test( - fn_tree="functional.ivy.experimental.fold", - data=_fold_data(), + fn_tree="functional.ivy.experimental.moveaxis", + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + ), + source=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + destination=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d", + ), + min_size=1, + force_int=True, + ), + test_gradients=st.just(False), ) -def test_fold(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtype, input, shape, mode = data +def test_moveaxis( + *, dtype_and_a, source, destination, test_flags, backend_fw, fn_name, on_device +): + input_dtype, a = dtype_and_a helpers.test_function( - backend_to_test=backend_fw, + input_dtypes=input_dtype, test_flags=test_flags, - fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - input_dtypes=input_dtype, - x=input[0], - mode=mode, - shape=shape, - ) - - -@st.composite -def _partial_unfold_data(draw): - dtype, input = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - ) - ) - ndims = len(input[0].shape) - mode_and_skip_begin = draw( - st.lists( - helpers.ints(min_value=0, max_value=ndims - 1), min_size=2, max_size=2 - ).filter(lambda nums: np.sum(nums) <= ndims - 1) - ) - skip_begin, mode = sorted(mode_and_skip_begin) - skip_end = draw( - helpers.ints(min_value=0, max_value=ndims - (skip_begin + mode) - 1) + backend_to_test=backend_fw, + fn_name=fn_name, + a=a[0], + source=source, + destination=destination, ) - ravel_tensors = draw(st.booleans()) - return dtype, input, mode, skip_begin, skip_end, ravel_tensors @handle_test( - fn_tree="functional.ivy.experimental.partial_unfold", - data=_partial_unfold_data(), + fn_tree="functional.ivy.experimental.pad", + ground_truth_backend="numpy", + dtype_and_input_and_other=_pad_helper(), + reflect_type=st.sampled_from(["even", "odd"]), + test_with_out=st.just(False), + test_gradients=st.just(False), ) -def test_partial_unfold(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtype, input, axis, skip_begin, skip_end, ravel_tensors = data +def test_pad( + *, + dtype_and_input_and_other, + reflect_type, + test_flags, + backend_fw, + fn_name, + on_device, +): + ( + dtype, + input, + pad_width, + stat_length, + constant_values, + end_values, + mode, + ) = dtype_and_input_and_other helpers.test_function( - backend_to_test=backend_fw, + input_dtypes=dtype, test_flags=test_flags, + backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - input_dtypes=input_dtype, - input=input[0], - mode=axis, - skip_begin=skip_begin, - skip_end=skip_end, - ravel_tensors=ravel_tensors, - ) - - -@st.composite -def 
_partial_fold_data(draw): - shape = draw( - helpers.get_shape( - min_num_dims=2, max_num_dims=5, min_dim_size=2, max_dim_size=3 - ) - ) - ndims = len(shape) - mode_and_skip_begin = draw( - st.lists( - helpers.ints(min_value=0, max_value=ndims - 1), min_size=2, max_size=2 - ).filter(lambda nums: np.sum(nums) <= ndims - 1) - ) - skip_begin, mode = sorted(mode_and_skip_begin) - skip_end = draw( - helpers.ints(min_value=0, max_value=ndims - (skip_begin + mode) - 1) - ) - if skip_end != 0: - reduced_dims = int( - ivy.prod(shape[skip_begin : skip_begin + mode]) - * ivy.prod(shape[skip_begin + mode + 1 : -skip_end]) - ) - unfolded_shape = ( - *shape[:skip_begin], - shape[skip_begin + mode], - reduced_dims, - *shape[-skip_end:], - ) - else: - reduced_dims = int( - ivy.prod(shape[skip_begin : skip_begin + mode]) - * ivy.prod(shape[skip_begin + mode + 1 :]) - ) - unfolded_shape = (*shape[:skip_begin], shape[skip_begin + mode], reduced_dims) - - dtype, input = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), shape=unfolded_shape - ) + input=input, + pad_width=pad_width, + mode=mode, + stat_length=stat_length, + constant_values=constant_values, + end_values=end_values, + reflect_type=reflect_type, ) - return dtype, input, skip_begin, shape, mode @handle_test( @@ -1201,19 +1065,6 @@ def test_partial_fold(*, data, test_flags, backend_fw, fn_name, on_device): ) -@st.composite -def _partial_tensor_to_vec_data(draw): - input_dtype, input, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), min_num_dims=1, ret_shape=True - ) - ) - ndims = len(shape) - skip_begin = draw(helpers.ints(min_value=0, max_value=ndims - 1)) - skip_end = draw(helpers.ints(min_value=0, max_value=ndims - 1 - skip_begin)) - return input_dtype, input, skip_begin, skip_end - - @handle_test( fn_tree="functional.ivy.experimental.partial_tensor_to_vec", data=_partial_tensor_to_vec_data(), @@ -1234,18 +1085,26 @@ def test_partial_tensor_to_vec(*, data, test_flags, backend_fw, fn_name, on_devi ) -@st.composite -def _partial_vec_to_tensor(draw): - shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=5)) - numel = int(ivy.prod(shape)) - input_dtype, input = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), shape=(numel,) - ) +@handle_test( + fn_tree="functional.ivy.experimental.partial_unfold", + data=_partial_unfold_data(), +) +def test_partial_unfold(*, data, test_flags, backend_fw, fn_name, on_device): + input_dtype, input, axis, skip_begin, skip_end, ravel_tensors = data + helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + input_dtypes=input_dtype, + input=input[0], + mode=axis, + skip_begin=skip_begin, + skip_end=skip_end, + ravel_tensors=ravel_tensors, ) - ndims = len(shape) - skip_begin = draw(helpers.ints(min_value=0, max_value=ndims - 1)) - return input_dtype, input, shape, skip_begin @handle_test( @@ -1268,31 +1127,38 @@ def test_partial_vec_to_tensor(*, data, test_flags, backend_fw, fn_name, on_devi ) -@st.composite -def _matricize_data(draw): - input_dtype, input, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ret_shape=True, - min_num_dims=2, - max_num_dims=5, - ) - ) - ndims = len(shape) - dims = set([*range(ndims)]) - row_modes = set( - draw(st.lists(helpers.ints(min_value=0, max_value=ndims - 1), min_size=1)) +# rot90 +@handle_test( + fn_tree="functional.ivy.experimental.rot90", 
+ dtype_m_k_axes=_get_dtype_values_k_axes_for_rot90( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + test_gradients=st.just(False), +) +def test_rot90(dtype_m_k_axes, test_flags, backend_fw, fn_name, on_device): + input_dtype, m, k, axes = dtype_m_k_axes + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + m=m, + k=k, + axes=tuple(axes), ) - col_modes = dims - row_modes - return input_dtype, input, row_modes, col_modes @handle_test( - fn_tree="functional.ivy.experimental.matricize", - data=_matricize_data(), + fn_tree="functional.ivy.experimental.soft_thresholding", + data=_soft_thresholding_data(), ) -def test_matricize(*, data, test_flags, backend_fw, fn_name, on_device): - input_dtype, input, row_modes, column_modes = data +def test_soft_thresholding(*, data, test_flags, backend_fw, fn_name, on_device): + x_dtype, x, threshold = data helpers.test_function( backend_to_test=backend_fw, test_flags=test_flags, @@ -1300,43 +1166,98 @@ def test_matricize(*, data, test_flags, backend_fw, fn_name, on_device): on_device=on_device, rtol_=1e-1, atol_=1e-1, - input_dtypes=input_dtype, - x=input[0], - row_modes=row_modes, - column_modes=column_modes, + input_dtypes=x_dtype, + x=x[0], + threshold=threshold, ) -@st.composite -def _soft_thresholding_data(draw): - x_min, x_max = 1e-2, 1e2 - x_dtype, x, shape = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - ret_shape=True, - min_value=x_min, - max_value=x_max, - ) +# take_along_axis +@handle_test( + fn_tree="functional.ivy.experimental.take_along_axis", + dtype_x_indices_axis=helpers.array_indices_axis( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + indices_same_dims=True, + valid_bounds=False, + ), + mode=st.sampled_from(["clip", "fill", "drop"]), + ground_truth_backend="jax", + test_gradients=st.just(False), +) +def test_take_along_axis( + *, + dtype_x_indices_axis, + mode, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtypes, x, indices, axis, _ = dtype_x_indices_axis + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + arr=x, + indices=indices, + axis=axis, + mode=mode, ) - threshold_choice_1 = draw(helpers.floats(min_value=x_min, max_value=x_max)) - t_dtype, threshold_choice_2 = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - shape=shape, - min_value=x_min, - max_value=x_max, - ) + + +# top_k +@handle_test( + fn_tree="functional.ivy.experimental.top_k", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + force_int_axis=True, + valid_axis=True, + ), + k=helpers.ints(min_value=1, max_value=4), + largest=st.booleans(), + sorted=st.booleans(), + test_gradients=st.just(False), +) +def test_top_k( + *, dtype_x_axis, k, largest, sorted, test_flags, backend_fw, fn_name, on_device +): + dtype, x, axis = dtype_x_axis + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + k=k, + axis=axis, + largest=largest, + sorted=sorted, ) - threshold = draw(st.sampled_from([threshold_choice_1, threshold_choice_2])) - return x_dtype + t_dtype, x, 
threshold @handle_test( - fn_tree="functional.ivy.experimental.soft_thresholding", - data=_soft_thresholding_data(), + fn_tree="functional.ivy.experimental.unfold", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + valid_axis=True, + allow_neg_axes=False, + force_int_axis=True, + ), ) -def test_soft_thresholding(*, data, test_flags, backend_fw, fn_name, on_device): - x_dtype, x, threshold = data +def test_unfold(*, dtype_values_axis, test_flags, backend_fw, fn_name, on_device): + input_dtype, input, axis = dtype_values_axis + if axis is None: + axis = 0 helpers.test_function( backend_to_test=backend_fw, test_flags=test_flags, @@ -1344,7 +1265,90 @@ def test_soft_thresholding(*, data, test_flags, backend_fw, fn_name, on_device): on_device=on_device, rtol_=1e-1, atol_=1e-1, - input_dtypes=x_dtype, + input_dtypes=input_dtype, + x=input[0], + mode=axis, + ) + + +# unique_consecutive +@handle_test( + fn_tree="unique_consecutive", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + min_dim_size=2, + force_int_axis=True, + valid_axis=True, + ), + none_axis=st.booleans(), + test_with_out=st.just(False), + test_gradients=st.just(False), + ground_truth_backend="torch", +) +def test_unique_consecutive( + *, dtype_x_axis, none_axis, test_flags, backend_fw, fn_name, on_device +): + dtype, x, axis = dtype_x_axis + if none_axis: + axis = None + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, x=x[0], - threshold=threshold, + axis=axis, + ) + + +# vsplit +@handle_test( + fn_tree="functional.ivy.experimental.vsplit", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=st.shared(helpers.get_shape(min_num_dims=2), key="value_shape"), + ), + indices_or_sections=_get_splits(allow_none=False, min_num_dims=2, axis=0), + test_gradients=st.just(False), + test_with_out=st.just(False), +) +def test_vsplit( + dtype_and_x, indices_or_sections, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + on_device=on_device, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + indices_or_sections=indices_or_sections, + ) + + +# vstack +@handle_test( + fn_tree="functional.ivy.experimental.vstack", + dtype_and_m=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=helpers.get_shape( + min_num_dims=1, + ), + shared_dtype=True, + num_arrays=helpers.ints(min_value=2, max_value=10), + ), + test_gradients=st.just(False), +) +def test_vstack(*, dtype_and_m, test_flags, backend_fw, fn_name, on_device): + input_dtype, m = dtype_and_m + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + arrays=m, ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py index 3372550df06fe..5f53a416ebd1a 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py @@ -6,63 +6,37 @@ from ivy_tests.test_ivy.helpers import handle_test, BackendHandler -# Helpers # -# ------- # - - -# dirichlet @handle_test( - 
fn_tree="functional.ivy.experimental.dirichlet", - dtype_and_alpha=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - shape=st.tuples( - st.integers(min_value=2, max_value=5), - ), + fn_tree="functional.ivy.experimental.bernoulli", + dtype_and_probs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", full=False), min_value=0, - max_value=100, - exclude_min=True, - ), - size=st.tuples( - st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + max_value=1, + min_num_dims=0, ), seed=helpers.ints(min_value=0, max_value=100), test_gradients=st.just(False), ) -def test_dirichlet( - *, dtype_and_alpha, size, seed, test_flags, backend_fw, fn_name, on_device +def test_bernoulli( + *, dtype_and_probs, seed, test_flags, backend_fw, fn_name, on_device ): - dtype, alpha = dtype_and_alpha - assume("bfloat16" not in dtype) - - def call(): - return helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - test_values=False, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - alpha=alpha[0], - size=size, - seed=seed, - ) - - ret, ret_gt = call() - with BackendHandler.update_backend(backend_fw) as ivy_backend: - if seed: - ret1, ret_gt1 = call() - assert ivy_backend.any(ret == ret1) - ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np( - ret=ret_gt, backend=test_flags.ground_truth_backend - ) - for u, v in zip(ret, ret_gt): - u, v = ivy_backend.array(u), ivy_backend.array(v) - assert ivy_backend.all( - ivy_backend.sum(u, axis=-1) == ivy_backend.sum(v, axis=-1) - ) - assert ivy_backend.all(u >= 0) and ivy_backend.all(u <= 1) - assert ivy_backend.all(v >= 0) and ivy_backend.all(v <= 1) + dtype, probs = dtype_and_probs + # torch doesn't support half precision on CPU + assume( + not ("torch" in str(backend_fw) and "float16" in dtype and on_device == "cpu") + ) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + test_values=False, + probs=probs[0], + logits=None, + shape=None, + seed=seed, + ) # beta @@ -114,6 +88,61 @@ def test_beta( assert ivy_backend.all(v >= 0) and ivy_backend.all(v <= 1) +# dirichlet +@handle_test( + fn_tree="functional.ivy.experimental.dirichlet", + dtype_and_alpha=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.tuples( + st.integers(min_value=2, max_value=5), + ), + min_value=0, + max_value=100, + exclude_min=True, + ), + size=st.tuples( + st.integers(min_value=2, max_value=5), st.integers(min_value=2, max_value=5) + ), + seed=helpers.ints(min_value=0, max_value=100), + test_gradients=st.just(False), +) +def test_dirichlet( + *, dtype_and_alpha, size, seed, test_flags, backend_fw, fn_name, on_device +): + dtype, alpha = dtype_and_alpha + assume("bfloat16" not in dtype) + + def call(): + return helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + test_values=False, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + alpha=alpha[0], + size=size, + seed=seed, + ) + + ret, ret_gt = call() + with BackendHandler.update_backend(backend_fw) as ivy_backend: + if seed: + ret1, ret_gt1 = call() + assert ivy_backend.any(ret == ret1) + ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np( + ret=ret_gt, backend=test_flags.ground_truth_backend + ) + for u, v in zip(ret, ret_gt): + u, v = ivy_backend.array(u), ivy_backend.array(v) + assert 
ivy_backend.all( + ivy_backend.sum(u, axis=-1) == ivy_backend.sum(v, axis=-1) + ) + assert ivy_backend.all(u >= 0) and ivy_backend.all(u <= 1) + assert ivy_backend.all(v >= 0) and ivy_backend.all(v <= 1) + + # gamma @handle_test( fn_tree="functional.ivy.experimental.gamma", @@ -213,36 +242,3 @@ def call(): for u, v in zip(ret, ret_gt): assert u.dtype == v.dtype assert u.shape == v.shape - - -@handle_test( - fn_tree="functional.ivy.experimental.bernoulli", - dtype_and_probs=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", full=False), - min_value=0, - max_value=1, - min_num_dims=0, - ), - seed=helpers.ints(min_value=0, max_value=100), - test_gradients=st.just(False), -) -def test_bernoulli( - *, dtype_and_probs, seed, test_flags, backend_fw, fn_name, on_device -): - dtype, probs = dtype_and_probs - # torch doesn't support half precision on CPU - assume( - not ("torch" in str(backend_fw) and "float16" in dtype and on_device == "cpu") - ) - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - test_values=False, - probs=probs[0], - logits=None, - shape=None, - seed=seed, - ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_searching.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_searching.py index d10e00d17d6fe..ff91b8f1ed47b 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_searching.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_searching.py @@ -7,6 +7,10 @@ from ivy_tests.test_ivy.helpers import handle_test +# --- Helpers --- # +# --------------- # + + # unravel_index @st.composite def max_value_as_shape_prod(draw): diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py index 6c8e666540b6b..5a325efb85620 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py @@ -7,6 +7,10 @@ from ivy_tests.test_ivy.helpers import handle_test +# --- Helpers --- # +# --------------- # + + @st.composite def _invert_permutation_helper(draw): perm = draw( @@ -18,6 +22,10 @@ def _invert_permutation_helper(draw): return dtype, perm +# --- Main --- # +# ------------ # + + # invert_permutation @handle_test( fn_tree="functional.ivy.experimental.invert_permutation", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py index f712fea1d261a..8082dc60733ce 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py @@ -5,58 +5,51 @@ import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_method -# Helpers # -# ------- # - -@st.composite -def _sparse_coo_indices_values_shape(draw): - num_elem = draw(helpers.ints(min_value=2, max_value=8)) - dim1 = draw(helpers.ints(min_value=2, max_value=5)) - dim2 = draw(helpers.ints(min_value=5, max_value=10)) - value_dtype = draw(helpers.get_dtypes("numeric", full=False))[0] - coo_indices = draw( - helpers.array_values( - dtype="int64", - shape=(2, num_elem), - min_value=0, - max_value=dim1, - exclude_min=False, - ) - ) - values = 
draw(helpers.array_values(dtype=value_dtype, shape=(num_elem,))) - shape = (dim1, dim2) - return coo_indices, value_dtype, values, shape +# --- Helpers --- # +# --------------- # @st.composite -def _sparse_csr_indices_values_shape(draw): - num_elem = draw(helpers.ints(min_value=2, max_value=8)) +def _sparse_bsc_indices_values_shape(draw): + nblockrows = draw(helpers.ints(min_value=2, max_value=5)) + nblockcols = draw(helpers.ints(min_value=2, max_value=5)) + dim1 = draw(helpers.ints(min_value=2, max_value=5)) - dim2 = draw(helpers.ints(min_value=5, max_value=10)) + dim2 = draw(helpers.ints(min_value=3, max_value=5)) + value_dtype = draw(helpers.get_dtypes("numeric", full=False))[0] - values = draw(helpers.array_values(dtype=value_dtype, shape=(num_elem,))) - col_indices = draw( - helpers.array_values( - dtype="int64", - shape=(num_elem,), - min_value=0, - max_value=dim2, - exclude_min=False, - ) + + ccol_indices, row_indices, values = ( + [0], + [], + [ + [ + [], + ], + ], ) - indices = draw( + for _ in range(dim2): + index = draw( + helpers.ints( + min_value=max(ccol_indices[-1] + 1, 1), + max_value=ccol_indices[-1] + dim1, + ) + ) + cur_num_elem = index - ccol_indices[-1] + row_indices += list(range(cur_num_elem)) + ccol_indices.append(index) + + shape = (dim1 * nblockrows, dim2 * nblockcols) + values = draw( helpers.array_values( - dtype="int64", - shape=(dim1 - 1,), + dtype=value_dtype, + shape=(ccol_indices[-1], nblockrows, nblockcols), min_value=0, - max_value=num_elem, - exclude_min=False, ) ) - crow_indices = [0] + sorted(indices) + [num_elem] - shape = (dim1, dim2) - return crow_indices, col_indices, value_dtype, values, shape + + return ccol_indices, row_indices, value_dtype, values, shape @st.composite @@ -101,6 +94,26 @@ def _sparse_bsr_indices_values_shape(draw): return crow_indices, col_indices, value_dtype, values, shape +@st.composite +def _sparse_coo_indices_values_shape(draw): + num_elem = draw(helpers.ints(min_value=2, max_value=8)) + dim1 = draw(helpers.ints(min_value=2, max_value=5)) + dim2 = draw(helpers.ints(min_value=5, max_value=10)) + value_dtype = draw(helpers.get_dtypes("numeric", full=False))[0] + coo_indices = draw( + helpers.array_values( + dtype="int64", + shape=(2, num_elem), + min_value=0, + max_value=dim1, + exclude_min=False, + ) + ) + values = draw(helpers.array_values(dtype=value_dtype, shape=(num_elem,))) + shape = (dim1, dim2) + return coo_indices, value_dtype, values, shape + + @st.composite def _sparse_csc_indices_values_shape(draw): num_elem = draw(helpers.ints(min_value=2, max_value=8)) @@ -132,77 +145,70 @@ def _sparse_csc_indices_values_shape(draw): @st.composite -def _sparse_bsc_indices_values_shape(draw): - nblockrows = draw(helpers.ints(min_value=2, max_value=5)) - nblockcols = draw(helpers.ints(min_value=2, max_value=5)) - +def _sparse_csr_indices_values_shape(draw): + num_elem = draw(helpers.ints(min_value=2, max_value=8)) dim1 = draw(helpers.ints(min_value=2, max_value=5)) - dim2 = draw(helpers.ints(min_value=3, max_value=5)) - + dim2 = draw(helpers.ints(min_value=5, max_value=10)) value_dtype = draw(helpers.get_dtypes("numeric", full=False))[0] - - ccol_indices, row_indices, values = ( - [0], - [], - [ - [ - [], - ], - ], - ) - for _ in range(dim2): - index = draw( - helpers.ints( - min_value=max(ccol_indices[-1] + 1, 1), - max_value=ccol_indices[-1] + dim1, - ) + values = draw(helpers.array_values(dtype=value_dtype, shape=(num_elem,))) + col_indices = draw( + helpers.array_values( + dtype="int64", + shape=(num_elem,), + min_value=0, + 
max_value=dim2, + exclude_min=False, ) - cur_num_elem = index - ccol_indices[-1] - row_indices += list(range(cur_num_elem)) - ccol_indices.append(index) - - shape = (dim1 * nblockrows, dim2 * nblockcols) - values = draw( + ) + indices = draw( helpers.array_values( - dtype=value_dtype, - shape=(ccol_indices[-1], nblockrows, nblockcols), + dtype="int64", + shape=(dim1 - 1,), min_value=0, + max_value=num_elem, + exclude_min=False, ) ) + crow_indices = [0] + sorted(indices) + [num_elem] + shape = (dim1, dim2) + return crow_indices, col_indices, value_dtype, values, shape - return ccol_indices, row_indices, value_dtype, values, shape +# --- Main --- # +# ------------ # -# coo - to_dense_array + +# bsc - to_dense_array @handle_method( method_tree="SparseArray.to_dense_array", - sparse_data=_sparse_coo_indices_values_shape(), + sparse_data=_sparse_bsc_indices_values_shape(), method_num_positional_args=st.just(0), # TODO should not be hardcoded init_num_positional_args=st.just(0), # TODO should not be hardcoded ) -def test_sparse_coo( +def test_sparse_bsc( sparse_data, class_name, method_name, + on_device, backend_fw, + ground_truth_backend, init_flags, method_flags, - on_device, - ground_truth_backend, ): - coo_ind, val_dtype, val, shp = sparse_data + ccol_indices, row_indices, value_dtype, values, shape = sparse_data helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, on_device=on_device, - init_input_dtypes=["int64", val_dtype], + init_input_dtypes=["int64", "int64", value_dtype], init_all_as_kwargs_np={ - "coo_indices": coo_ind, - "values": val, - "dense_shape": shp, - "format": "coo", + "ccol_indices": ccol_indices, + "row_indices": row_indices, + "values": values, + "dense_shape": shape, + "format": "bsc", }, method_input_dtypes=[], method_all_as_kwargs_np={}, @@ -211,21 +217,21 @@ def test_sparse_coo( ) -# csr - to_dense_array +# bsr - to_dense_array @handle_method( method_tree="SparseArray.to_dense_array", - sparse_data=_sparse_csr_indices_values_shape(), + sparse_data=_sparse_bsr_indices_values_shape(), method_num_positional_args=st.just(0), # TODO should not be hardcoded init_num_positional_args=st.just(0), # TODO should not be hardcoded ) -def test_sparse_csr( +def test_sparse_bsr( sparse_data, class_name, method_name, + on_device, backend_fw, ground_truth_backend, init_flags, - on_device, method_flags, ): crow_indices, col_indices, value_dtype, values, shape = sparse_data @@ -233,15 +239,15 @@ def test_sparse_csr( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, - method_flags=method_flags, on_device=on_device, + method_flags=method_flags, init_input_dtypes=["int64", "int64", value_dtype], init_all_as_kwargs_np={ "crow_indices": crow_indices, "col_indices": col_indices, "values": values, "dense_shape": shape, - "format": "csr", + "format": "bsr", }, method_input_dtypes=[], method_all_as_kwargs_np={}, @@ -250,37 +256,36 @@ def test_sparse_csr( ) -# csc - to_dense_array +# coo - to_dense_array @handle_method( method_tree="SparseArray.to_dense_array", - sparse_data=_sparse_csc_indices_values_shape(), + sparse_data=_sparse_coo_indices_values_shape(), method_num_positional_args=st.just(0), # TODO should not be hardcoded init_num_positional_args=st.just(0), # TODO should not be hardcoded ) -def test_sparse_csc( +def test_sparse_coo( sparse_data, class_name, method_name, backend_fw, - ground_truth_backend, init_flags, - on_device, method_flags, + on_device, + 
ground_truth_backend, ): - ccol_indices, row_indices, value_dtype, values, shape = sparse_data + coo_ind, val_dtype, val, shp = sparse_data helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, on_device=on_device, - init_input_dtypes=["int64", "int64", value_dtype], + init_input_dtypes=["int64", val_dtype], init_all_as_kwargs_np={ - "ccol_indices": ccol_indices, - "row_indices": row_indices, - "values": values, - "dense_shape": shape, - "format": "csc", + "coo_indices": coo_ind, + "values": val, + "dense_shape": shp, + "format": "coo", }, method_input_dtypes=[], method_all_as_kwargs_np={}, @@ -289,21 +294,21 @@ def test_sparse_csc( ) -# bsc - to_dense_array +# csc - to_dense_array @handle_method( method_tree="SparseArray.to_dense_array", - sparse_data=_sparse_bsc_indices_values_shape(), + sparse_data=_sparse_csc_indices_values_shape(), method_num_positional_args=st.just(0), # TODO should not be hardcoded init_num_positional_args=st.just(0), # TODO should not be hardcoded ) -def test_sparse_bsc( +def test_sparse_csc( sparse_data, class_name, method_name, - on_device, backend_fw, ground_truth_backend, init_flags, + on_device, method_flags, ): ccol_indices, row_indices, value_dtype, values, shape = sparse_data @@ -319,7 +324,7 @@ def test_sparse_bsc( "row_indices": row_indices, "values": values, "dense_shape": shape, - "format": "bsc", + "format": "csc", }, method_input_dtypes=[], method_all_as_kwargs_np={}, @@ -328,21 +333,21 @@ def test_sparse_bsc( ) -# bsr - to_dense_array +# csr - to_dense_array @handle_method( method_tree="SparseArray.to_dense_array", - sparse_data=_sparse_bsr_indices_values_shape(), + sparse_data=_sparse_csr_indices_values_shape(), method_num_positional_args=st.just(0), # TODO should not be hardcoded init_num_positional_args=st.just(0), # TODO should not be hardcoded ) -def test_sparse_bsr( +def test_sparse_csr( sparse_data, class_name, method_name, - on_device, backend_fw, ground_truth_backend, init_flags, + on_device, method_flags, ): crow_indices, col_indices, value_dtype, values, shape = sparse_data @@ -350,15 +355,15 @@ def test_sparse_bsr( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, - on_device=on_device, method_flags=method_flags, + on_device=on_device, init_input_dtypes=["int64", "int64", value_dtype], init_all_as_kwargs_np={ "crow_indices": crow_indices, "col_indices": col_indices, "values": values, "dense_shape": shape, - "format": "bsr", + "format": "csr", }, method_input_dtypes=[], method_all_as_kwargs_np={}, diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py index 3e0c36038cca1..f43f00c3f4526 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py @@ -11,8 +11,97 @@ ) -# Helpers # -# ------- # +# --- Helpers --- # +# --------------- # + + +@st.composite +def _get_dtype_value1_value2_cov( + draw, + available_dtypes, + min_num_dims, + max_num_dims, + min_dim_size, + max_dim_size, + abs_smallest_val=None, + min_value=None, + max_value=None, + allow_inf=False, + exclude_min=False, + exclude_max=False, + large_abs_safety_factor=4, + small_abs_safety_factor=4, + safety_factor_scale="log", +): + shape = draw( + helpers.get_shape( + allow_none=False, + 
min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + + dtype = draw(st.sampled_from(draw(available_dtypes))) + + values = [] + for i in range(2): + values.append( + draw( + helpers.array_values( + dtype=dtype, + shape=shape, + abs_smallest_val=abs_smallest_val, + min_value=min_value, + max_value=max_value, + allow_inf=allow_inf, + exclude_min=exclude_min, + exclude_max=exclude_max, + large_abs_safety_factor=large_abs_safety_factor, + small_abs_safety_factor=small_abs_safety_factor, + safety_factor_scale=safety_factor_scale, + ) + ) + ) + + value1, value2 = values[0], values[1] + + # modifiers: rowVar, bias, ddof + rowVar = draw(st.booleans()) + bias = draw(st.booleans()) + ddof = draw(helpers.ints(min_value=0, max_value=1)) + + numVals = None + if rowVar is False: + numVals = -1 if numVals == 0 else 0 + else: + numVals = 0 if len(shape) == 1 else -1 + + fweights = draw( + helpers.array_values( + dtype="int64", + shape=shape[numVals], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + ) + ) + + aweights = draw( + helpers.array_values( + dtype="float64", + shape=shape[numVals], + abs_smallest_val=1, + min_value=1, + max_value=10, + allow_inf=False, + small_abs_safety_factor=1, + ) + ) + + return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights @st.composite @@ -130,102 +219,6 @@ def _histogram_helper(draw): ) -# TODO: - Error message from Tensorflow: 'Number of dimensions of `x` and `weights` -# must coincide. Found: x has , weights has ' -# - Error description: typo that throws unintended exceptions when using both -# weights and multiple axis. -# - fixed in TFP 0.20 release. -# - Test helper needs to be modified to handle this case in older verions. 
-@handle_test( - fn_tree="functional.ivy.experimental.histogram", - values=_histogram_helper(), - test_gradients=st.just(False), -) -def test_histogram( - *, - values, - test_flags, - backend_fw, - fn_name, - on_device, -): - ( - a, - bins, - axis, - extend_lower_interval, - extend_upper_interval, - dtype, - range, - weights, - density, - dtype_input, - ) = values - helpers.test_function( - a=a, - bins=bins, - axis=axis, - extend_lower_interval=extend_lower_interval, - extend_upper_interval=extend_upper_interval, - dtype=dtype, - range=range, - weights=weights, - density=density, - input_dtypes=[dtype_input], - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - ) - - -@handle_test( - fn_tree="functional.ivy.experimental.median", - dtype_x_axis=_statistical_dtype_values(function="median"), - keep_dims=st.booleans(), - test_gradients=st.just(False), - test_with_out=st.just(False), -) -def test_median(*, dtype_x_axis, keep_dims, test_flags, backend_fw, fn_name, on_device): - input_dtype, x, axis = dtype_x_axis - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - on_device=on_device, - backend_to_test=backend_fw, - fn_name=fn_name, - input=x[0], - axis=axis, - keepdims=keep_dims, - ) - - -# nanmean -@handle_test( - fn_tree="functional.ivy.experimental.nanmean", - dtype_x_axis=_statistical_dtype_values(function="nanmean"), - keep_dims=st.booleans(), - dtype=helpers.get_dtypes("float", full=False), - test_gradients=st.just(False), -) -def test_nanmean( - *, dtype_x_axis, keep_dims, dtype, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x, axis = dtype_x_axis - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - atol_=1e-02, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - a=x[0], - axis=axis, - keepdims=keep_dims, - dtype=dtype[0], - ) - - @st.composite def _quantile_helper(draw): large_abs_safety_factor = 2 @@ -266,65 +259,6 @@ def _quantile_helper(draw): return dtype, values, axis, interpolation, q -# quantile -@handle_test( - fn_tree="functional.ivy.experimental.quantile", - dtype_and_x=_quantile_helper(), - keep_dims=st.booleans(), - test_gradients=st.just(False), - test_with_out=st.just(False), -) -def test_quantile( - *, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x, axis, interpolation, q = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - a=x[0], - q=q, - axis=axis, - interpolation=interpolation[0], - keepdims=keep_dims, - ) - - -# corrcoef -@handle_test( - fn_tree="functional.ivy.experimental.corrcoef", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=["float32", "float64"], - num_arrays=2, - shared_dtype=True, - abs_smallest_val=1e-5, - min_num_dims=2, - max_num_dims=2, - min_dim_size=3, - max_dim_size=3, - min_value=-100, - max_value=100, - allow_nan=False, - ), - rowvar=st.booleans(), - test_gradients=st.just(False), -) -def test_corrcoef(*, dtype_and_x, rowvar, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - y=x[1], - rowvar=rowvar, - ) - - # bincount @st.composite def bincount_dtype_and_values(draw): @@ -350,6 +284,10 @@ def bincount_dtype_and_values(draw): return dtype_and_x, min_length +# --- 
Main --- # +# ------------ # + + @handle_test( fn_tree="functional.ivy.experimental.bincount", dtype_and_x=bincount_dtype_and_values(), @@ -370,121 +308,38 @@ def test_bincount(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# igamma +# corrcoef @handle_test( - fn_tree="functional.ivy.experimental.igamma", + fn_tree="functional.ivy.experimental.corrcoef", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=["float32", "float64"], num_arrays=2, shared_dtype=True, - min_value=2, + abs_smallest_val=1e-5, + min_num_dims=2, + max_num_dims=2, + min_dim_size=3, + max_dim_size=3, + min_value=-100, max_value=100, + allow_nan=False, ), + rowvar=st.booleans(), test_gradients=st.just(False), - test_with_out=st.just(False), ) -def test_igamma(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_corrcoef(*, dtype_and_x, rowvar, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, - on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - rtol_=1e-04, - a=x[0], - x=x[1], - ) - - -@st.composite -def _get_dtype_value1_value2_cov( - draw, - available_dtypes, - min_num_dims, - max_num_dims, - min_dim_size, - max_dim_size, - abs_smallest_val=None, - min_value=None, - max_value=None, - allow_inf=False, - exclude_min=False, - exclude_max=False, - large_abs_safety_factor=4, - small_abs_safety_factor=4, - safety_factor_scale="log", -): - shape = draw( - helpers.get_shape( - allow_none=False, - min_num_dims=min_num_dims, - max_num_dims=max_num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - - dtype = draw(st.sampled_from(draw(available_dtypes))) - - values = [] - for i in range(2): - values.append( - draw( - helpers.array_values( - dtype=dtype, - shape=shape, - abs_smallest_val=abs_smallest_val, - min_value=min_value, - max_value=max_value, - allow_inf=allow_inf, - exclude_min=exclude_min, - exclude_max=exclude_max, - large_abs_safety_factor=large_abs_safety_factor, - small_abs_safety_factor=small_abs_safety_factor, - safety_factor_scale=safety_factor_scale, - ) - ) - ) - - value1, value2 = values[0], values[1] - - # modifiers: rowVar, bias, ddof - rowVar = draw(st.booleans()) - bias = draw(st.booleans()) - ddof = draw(helpers.ints(min_value=0, max_value=1)) - - numVals = None - if rowVar is False: - numVals = -1 if numVals == 0 else 0 - else: - numVals = 0 if len(shape) == 1 else -1 - - fweights = draw( - helpers.array_values( - dtype="int64", - shape=shape[numVals], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - ) - ) - - aweights = draw( - helpers.array_values( - dtype="float64", - shape=shape[numVals], - abs_smallest_val=1, - min_value=1, - max_value=10, - allow_inf=False, - small_abs_safety_factor=1, - ) + on_device=on_device, + x=x[0], + y=x[1], + rowvar=rowvar, ) - return [dtype], value1, value2, rowVar, bias, ddof, fweights, aweights - # cov @handle_test( @@ -600,6 +455,129 @@ def test_cummin( ) +# TODO: - Error message from Tensorflow: 'Number of dimensions of `x` and `weights` +# must coincide. Found: x has , weights has ' +# - Error description: typo that throws unintended exceptions when using both +# weights and multiple axis. +# - fixed in TFP 0.20 release. +# - Test helper needs to be modified to handle this case in older verions. 
+@handle_test( + fn_tree="functional.ivy.experimental.histogram", + values=_histogram_helper(), + test_gradients=st.just(False), +) +def test_histogram( + *, + values, + test_flags, + backend_fw, + fn_name, + on_device, +): + ( + a, + bins, + axis, + extend_lower_interval, + extend_upper_interval, + dtype, + range, + weights, + density, + dtype_input, + ) = values + helpers.test_function( + a=a, + bins=bins, + axis=axis, + extend_lower_interval=extend_lower_interval, + extend_upper_interval=extend_upper_interval, + dtype=dtype, + range=range, + weights=weights, + density=density, + input_dtypes=[dtype_input], + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + ) + + +# igamma +@handle_test( + fn_tree="functional.ivy.experimental.igamma", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + shared_dtype=True, + min_value=2, + max_value=100, + ), + test_gradients=st.just(False), + test_with_out=st.just(False), +) +def test_igamma(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + rtol_=1e-04, + a=x[0], + x=x[1], + ) + + +@handle_test( + fn_tree="functional.ivy.experimental.median", + dtype_x_axis=_statistical_dtype_values(function="median"), + keep_dims=st.booleans(), + test_gradients=st.just(False), + test_with_out=st.just(False), +) +def test_median(*, dtype_x_axis, keep_dims, test_flags, backend_fw, fn_name, on_device): + input_dtype, x, axis = dtype_x_axis + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + input=x[0], + axis=axis, + keepdims=keep_dims, + ) + + +# nanmean +@handle_test( + fn_tree="functional.ivy.experimental.nanmean", + dtype_x_axis=_statistical_dtype_values(function="nanmean"), + keep_dims=st.booleans(), + dtype=helpers.get_dtypes("float", full=False), + test_gradients=st.just(False), +) +def test_nanmean( + *, dtype_x_axis, keep_dims, dtype, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x, axis = dtype_x_axis + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + atol_=1e-02, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + a=x[0], + axis=axis, + keepdims=keep_dims, + dtype=dtype[0], + ) + + # nanmedian @handle_test( fn_tree="functional.ivy.experimental.nanmedian", @@ -633,3 +611,29 @@ def test_nanmedian( keepdims=keep_dims, overwrite_input=overwriteinput, ) + + +# quantile +@handle_test( + fn_tree="functional.ivy.experimental.quantile", + dtype_and_x=_quantile_helper(), + keep_dims=st.booleans(), + test_gradients=st.just(False), + test_with_out=st.just(False), +) +def test_quantile( + *, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device +): + input_dtype, x, axis, interpolation, q = dtype_and_x + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + a=x[0], + q=q, + axis=axis, + interpolation=interpolation[0], + keepdims=keep_dims, + ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 1d334a33689ca..f940f291a7714 100644 --- 
a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -6,17 +6,28 @@ from ivy_tests.test_ivy.helpers import handle_test -# logit +# elu @handle_test( - fn_tree="functional.ivy.experimental.logit", + fn_tree="functional.ivy.experimental.elu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), + alpha=st.one_of( + st.floats(min_value=0.10, max_value=1.0), + ), ) -def test_logit(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_elu( + *, + dtype_and_x, + alpha, + test_flags, + backend_fw, + fn_name, + on_device, +): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -25,25 +36,21 @@ def test_logit(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): fn_name=fn_name, on_device=on_device, x=x[0], + alpha=alpha, ) -# thresholded_relu +# logit @handle_test( - fn_tree="functional.ivy.experimental.thresholded_relu", + fn_tree="functional.ivy.experimental.logit", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), - threshold=st.one_of( - st.floats(min_value=-0.10, max_value=10.0), - ), ) -def test_thresholded_relu( - *, dtype_and_x, threshold, test_flags, backend_fw, fn_name, on_device -): +def test_logit(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -52,7 +59,29 @@ def test_thresholded_relu( fn_name=fn_name, on_device=on_device, x=x[0], - threshold=threshold, + ) + + +# logsigmoid +@handle_test( + fn_tree="functional.ivy.experimental.logsigmoid", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + safety_factor_scale="log", + large_abs_safety_factor=120, + ), + test_with_out=st.just(False), +) +def test_logsigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): + input_dtype, x = dtype_and_x + test_flags.num_positional_args = len(x) + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + input=x[0], ) @@ -106,29 +135,6 @@ def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# logsigmoid -@handle_test( - fn_tree="functional.ivy.experimental.logsigmoid", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - safety_factor_scale="log", - large_abs_safety_factor=120, - ), - test_with_out=st.just(False), -) -def test_logsigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): - input_dtype, x = dtype_and_x - test_flags.num_positional_args = len(x) - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - input=x[0], - ) - - # selu @handle_test( fn_tree="functional.ivy.experimental.selu", @@ -177,27 +183,21 @@ def test_silu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) -# elu +# thresholded_relu @handle_test( - fn_tree="functional.ivy.experimental.elu", + fn_tree="functional.ivy.experimental.thresholded_relu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), - alpha=st.one_of( - 
st.floats(min_value=0.10, max_value=1.0), + threshold=st.one_of( + st.floats(min_value=-0.10, max_value=10.0), ), ) -def test_elu( - *, - dtype_and_x, - alpha, - test_flags, - backend_fw, - fn_name, - on_device, +def test_thresholded_relu( + *, dtype_and_x, threshold, test_flags, backend_fw, fn_name, on_device ): dtype, x = dtype_and_x helpers.test_function( @@ -207,5 +207,5 @@ def test_elu( fn_name=fn_name, on_device=on_device, x=x[0], - alpha=alpha, + threshold=threshold, ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py index e83868b1639c0..1536aa2be124c 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py @@ -8,301 +8,518 @@ from ivy_tests.test_ivy.helpers import handle_test -@handle_test( - fn_tree="functional.ivy.experimental.max_pool1d", - x_k_s_p=helpers.arrays_for_pooling( - min_dims=3, - max_dims=3, - min_side=1, - max_side=4, - explicit_or_str_padding=True, - return_dilation=True, - data_format=st.sampled_from(["channel_first", "channel_last"]), - return_data_format=True, - ), - ceil_mode=st.sampled_from([True, False]), - test_gradients=st.just(False), - ground_truth_backend="torch", -) -def test_max_pool1d( - *, - x_k_s_p, - ceil_mode, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtype, x, kernel, stride, pad, dilation, data_format = x_k_s_p - data_format = "NCW" if data_format == "channel_first" else "NWC" - assume(not (isinstance(pad, str) and (pad.upper() == "VALID") and ceil_mode)) - # TODO: Remove this once the paddle backend supports dilation - assume(not (backend_fw == "paddle" and max(list(dilation)) > 1)) +# --- Helpers --- # +# --------------- # - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - kernel=kernel, - strides=stride, - padding=pad, - dilation=dilation, - data_format=data_format, - ceil_mode=ceil_mode, + +def _get_reduce_func(dtype): + if dtype == "bool": + return st.sampled_from([ivy.logical_and, ivy.logical_or]) + else: + return st.sampled_from([ivy.add, ivy.maximum, ivy.minimum, ivy.multiply]) + + +@st.composite +def _interp_args(draw, mode=None, mode_list=None): + mixed_fn_compos = draw(st.booleans()) + curr_backend = ivy.current_backend_str() + torch_modes = [ + "linear", + "bilinear", + "trilinear", + "nearest", + "nearest-exact", + "area", + ] + + tf_modes = [ + "linear", + "bilinear", + "trilinear", + "nearest-exact", + "tf_area", + "bicubic_tensorflow", + "lanczos3", + "lanczos5", + "mitchellcubic", + "gaussian", + ] + + jax_modes = [ + "linear", + "bilinear", + "trilinear", + "nearest-exact", + "bicubic_tensorflow", + "lanczos3", + "lanczos5", + ] + + if not mode and not mode_list: + if curr_backend == "torch" and not mixed_fn_compos: + mode = draw(st.sampled_from(torch_modes)) + elif curr_backend == "tensorflow" and not mixed_fn_compos: + mode = draw(st.sampled_from(tf_modes)) + elif curr_backend == "jax" and not mixed_fn_compos: + mode = draw(st.sampled_from(jax_modes)) + else: + mode = draw( + st.sampled_from( + [ + "linear", + "bilinear", + "trilinear", + "nearest", + "nearest-exact", + "area", + "tf_area", + "bicubic_tensorflow", + "lanczos3", + "lanczos5", + "mitchellcubic", + "gaussian", + ] + ) + ) + elif mode_list: + mode = draw(st.sampled_from(mode_list)) + 
align_corners = draw(st.one_of(st.booleans(), st.none())) + if (curr_backend == "tensorflow" or curr_backend == "jax") and not mixed_fn_compos: + align_corners = False + if mode == "linear": + num_dims = 3 + elif mode in [ + "bilinear", + "bicubic_tensorflow", + "bicubic", + "mitchellcubic", + "gaussian", + ]: + num_dims = 4 + elif mode == "trilinear": + num_dims = 5 + elif mode in [ + "nearest", + "area", + "tf_area", + "lanczos3", + "lanczos5", + "nearest-exact", + ]: + num_dims = ( + draw( + helpers.ints(min_value=1, max_value=3, mixed_fn_compos=mixed_fn_compos) + ) + + 2 + ) + align_corners = None + if curr_backend == "tensorflow" and not mixed_fn_compos: + num_dims = 3 + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "float", mixed_fn_compos=mixed_fn_compos + ), + min_num_dims=num_dims, + max_num_dims=num_dims, + min_dim_size=2, + max_dim_size=5, + large_abs_safety_factor=50, + small_abs_safety_factor=50, + safety_factor_scale="log", + ) ) + if draw(st.booleans()): + scale_factor = draw( + st.one_of( + helpers.lists( + x=helpers.floats( + min_value=1.0, max_value=2.0, mixed_fn_compos=mixed_fn_compos + ), + min_size=num_dims - 2, + max_size=num_dims - 2, + ), + helpers.floats( + min_value=1.0, max_value=2.0, mixed_fn_compos=mixed_fn_compos + ), + ) + ) + recompute_scale_factor = draw(st.booleans()) + size = None + else: + size = draw( + st.one_of( + helpers.lists( + x=helpers.ints( + min_value=1, max_value=3, mixed_fn_compos=mixed_fn_compos + ), + min_size=num_dims - 2, + max_size=num_dims - 2, + ), + st.integers(min_value=1, max_value=3), + ) + ) + recompute_scale_factor = False + scale_factor = None + if (curr_backend == "tensorflow" or curr_backend == "jax") and not mixed_fn_compos: + if not recompute_scale_factor: + recompute_scale_factor = True + return (dtype, x, mode, size, align_corners, scale_factor, recompute_scale_factor) -@handle_test( - fn_tree="functional.ivy.experimental.layers.max_unpool1d", - x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), - indices=st.lists(st.integers(0, 1), min_size=1, max_size=4), - ground_truth_backend="jax", - test_gradients=st.just(False), -) -def test_max_unpool1d( - *, - x_k_s_p, - indices, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtype, x, kernel, stride, pad = x_k_s_p - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - kernel=kernel, - strides=stride, - padding=pad, - indices=indices, + +@st.composite +def _reduce_window_helper(draw, get_func_st): + dtype = draw(helpers.get_dtypes("valid", full=False)) + py_func = draw(get_func_st(dtype[0])) + init_value = draw( + helpers.dtype_and_values( + dtype=dtype, + shape=(), + allow_inf=True, + ) + )[1] + ndim = draw(st.integers(min_value=1, max_value=4)) + _, others = draw( + helpers.dtype_and_values( + num_arrays=4, + dtype=["int64"] * 4, + shape=(ndim,), + min_value=1, + max_value=3, + small_abs_safety_factor=1, + large_abs_safety_factor=1, + ) + ) + others = [other.tolist() for other in others] + window, dilation = others[0], others[2] + op_shape = [] + for i in range(ndim): + min_x = window[i] + (window[i] - 1) * (dilation[i] - 1) + op_shape.append(draw(st.integers(min_x, min_x + 1))) + dtype, operand = draw( + helpers.dtype_and_values( + dtype=dtype, + shape=op_shape, + ) + ) + padding = draw( + st.one_of( + st.lists( + st.tuples( + st.integers(min_value=0, max_value=3), + 
st.integers(min_value=0, max_value=3), + ), + min_size=ndim, + max_size=ndim, + ), + st.sampled_from(["SAME", "VALID"]), + ) + ) + for i, arg in enumerate(others): + if len(np.unique(arg)) == 1 and draw(st.booleans()): + others[i] = arg[0] + return dtype * 2, operand, init_value, py_func, others, padding + + +@st.composite +def _valid_dct(draw): + dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + max_value=65280, + min_value=-65280, + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + shared_dtype=True, + ) + ) + dims_len = len(x[0].shape) + n = draw(st.sampled_from([None, "int"])) + axis = draw(helpers.ints(min_value=-dims_len, max_value=dims_len - 1)) + norm = draw(st.sampled_from([None, "ortho"])) + type = draw(helpers.ints(min_value=1, max_value=4)) + if n == "int": + n = draw(helpers.ints(min_value=1, max_value=20)) + if n <= 1 and type == 1: + n = 2 + if norm == "ortho" and type == 1: + norm = None + return dtype, x, type, n, axis, norm + + +@st.composite +def _x_and_fft(draw): + min_fft_points = 2 + dtype = draw(helpers.get_dtypes("valid", full=False)) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e5, + max_value=1e5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ) + ) + dim = draw(helpers.get_axis(shape=x_dim, allow_neg=True, force_int=True)) + norm = draw(st.sampled_from(["backward", "forward", "ortho"])) + n = draw(st.integers(min_fft_points, 256)) + return dtype, x, dim, norm, n + + +@st.composite +def _x_and_fft2(draw): + min_fft2_points = 2 + dtype = draw(helpers.get_dtypes("float_and_complex", full=False)) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=2, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e5, + max_value=1e5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ) + ) + s = ( + draw(st.integers(min_fft2_points, 256)), + draw(st.integers(min_fft2_points, 256)), + ) + dim = draw(st.sampled_from([(0, 1), (-1, -2), (1, 0)])) + norm = draw(st.sampled_from(["backward", "forward", "ortho"])) + return dtype, x, s, dim, norm + + +@st.composite +def _x_and_ifft(draw): + min_fft_points = 2 + dtype = draw(helpers.get_dtypes("complex")) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e-10, + max_value=1e10, + ) ) + dim = draw(st.integers(1 - len(list(x_dim)), len(list(x_dim)) - 1)) + norm = draw(st.sampled_from(["backward", "forward", "ortho"])) + n = draw(st.integers(min_fft_points, 256)) + return dtype, x, dim, norm, n + + +@st.composite +def _x_and_ifftn(draw): + min_fft_points = 2 + dtype = draw(helpers.get_dtypes("complex")) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e-10, + max_value=1e10, + ) + ) + axes = draw( + st.lists( + st.integers(0, len(x_dim) - 1), min_size=1, max_size=len(x_dim), unique=True + ) + ) + norm = draw(st.sampled_from(["forward", "ortho", "backward"])) + + # Shape for s can be larger, smaller or 
equal to the size of the input + # along the axes specified by axes. + # Here, we're generating a list of integers corresponding to each axis in axes. + s = draw( + st.lists( + st.integers(min_fft_points, 256), min_size=len(axes), max_size=len(axes) + ) + ) + + return dtype, x, s, axes, norm -@handle_test( - fn_tree="functional.ivy.experimental.max_pool2d", - x_k_s_p=helpers.arrays_for_pooling( - min_dims=4, - max_dims=4, - min_side=2, - max_side=4, - explicit_or_str_padding=True, - return_dilation=True, - data_format=st.sampled_from(["channel_first", "channel_last"]), - return_data_format=True, - ), - ceil_mode=st.sampled_from([True, False]), - test_gradients=st.just(False), - ground_truth_backend="jax", -) -def test_max_pool2d( - *, - x_k_s_p, - ceil_mode, - test_flags, - backend_fw, - fn_name, - on_device, -): - dtype, x, kernel, stride, pad, dilation, data_format = x_k_s_p - assume( - not ( - backend_fw == "tensorflow" - and ( - (stride[0] > kernel[0] or stride[0] > kernel[1]) - or ( - (stride[0] > 1 and dilation[0] > 1) - or (stride[0] > 1 and dilation[1] > 1) - ) - ) +@st.composite +def _x_and_rfftn(draw): + min_rfftn_points = 2 + dtype = draw(helpers.get_dtypes("float")) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=3 ) ) - data_format = "NCHW" if data_format == "channel_first" else "NHWC" - assume(not (isinstance(pad, str) and (pad.upper() == "VALID") and ceil_mode)) - # TODO: Remove this once the paddle backend supports dilation - assume(not (backend_fw == "paddle" and max(list(dilation)) > 1)) - - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - kernel=kernel, - strides=stride, - padding=pad, - dilation=dilation, - ceil_mode=ceil_mode, - data_format=data_format, + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e10, + max_value=1e10, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ) + ) + axes = draw( + st.lists( + st.integers(0, len(x_dim) - 1), min_size=1, max_size=len(x_dim), unique=True + ) + ) + s = draw( + st.lists( + st.integers(min_rfftn_points, 256), min_size=len(axes), max_size=len(axes) + ) ) + norm = draw(st.sampled_from(["backward", "forward", "ortho"])) + return dtype, x, s, axes, norm + + +# --- Main --- # +# ------------ # @handle_test( - fn_tree="functional.ivy.experimental.max_pool3d", - x_k_s_p=helpers.arrays_for_pooling( - min_dims=5, - max_dims=5, - min_side=1, - max_side=4, - explicit_or_str_padding=True, - return_dilation=True, - data_format=st.sampled_from(["channel_first", "channel_last"]), - return_data_format=True, + fn_tree="functional.ivy.experimental.adaptive_avg_pool1d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=2, + max_num_dims=3, + min_dim_size=1, + max_value=100, + min_value=-100, ), - ceil_mode=st.sampled_from([True, False]), - test_gradients=st.just(False), + output_size=helpers.ints(min_value=1, max_value=5), + test_with_out=st.just(False), ground_truth_backend="torch", ) -def test_max_pool3d( - *, - x_k_s_p, - ceil_mode, - test_flags, - backend_fw, - fn_name, - on_device, +def test_adaptive_avg_pool1d( + *, dtype_and_x, output_size, test_flags, backend_fw, fn_name, on_device ): - dtype, x, kernel, stride, pad, dilation, data_format = x_k_s_p - - data_format = "NCDHW" if data_format == "channel_first" else "NDHWC" 
- assume(not (isinstance(pad, str) and (pad.upper() == "VALID") and ceil_mode)) - # TODO: Remove this once the paddle backend supports dilation - assume(not (backend_fw == "paddle" and max(list(dilation)) > 1)) - + input_dtype, x = dtype_and_x helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, - on_device=on_device, fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - kernel=kernel, - strides=stride, - padding=pad, - data_format=data_format, - dilation=dilation, - ceil_mode=ceil_mode, + on_device=on_device, + input=x[0], + output_size=output_size, ) @handle_test( - fn_tree="functional.ivy.experimental.avg_pool1d", - x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), - count_include_pad=st.booleans(), - ceil_mode=st.booleans(), - ground_truth_backend="jax", - test_gradients=st.just(False), + fn_tree="functional.ivy.experimental.adaptive_avg_pool2d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + max_num_dims=4, + min_dim_size=1, + max_value=100, + min_value=-100, + ), + output_size=st.one_of( + st.tuples( + helpers.ints(min_value=1, max_value=5), + helpers.ints(min_value=1, max_value=5), + ), + helpers.ints(min_value=1, max_value=5), + ), + test_with_out=st.just(False), + ground_truth_backend="torch", ) -def test_avg_pool1d( - *, - x_k_s_p, - count_include_pad, - ceil_mode, - test_flags, - backend_fw, - on_device, +def test_adaptive_avg_pool2d( + *, dtype_and_x, output_size, test_flags, backend_fw, fn_name, on_device ): - dtype, x, kernel, stride, pad = x_k_s_p + input_dtype, x = dtype_and_x helpers.test_function( - input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, - fn_name="avg_pool1d", - rtol_=1e-2, - atol_=1e-2, on_device=on_device, - x=x[0], - kernel=kernel, - strides=stride, - padding=pad, - count_include_pad=count_include_pad, - ceil_mode=ceil_mode, + fn_name=fn_name, + input=x[0], + output_size=output_size, ) -# avg_pool2d @handle_test( - fn_tree="functional.ivy.experimental.avg_pool2d", - x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), - count_include_pad=st.booleans(), - ceil_mode=st.booleans(), - divisor_override=st.one_of(st.none(), st.integers(min_value=1, max_value=4)), - data_format=st.sampled_from(["NCHW", "NHWC"]), - ground_truth_backend="jax", - test_gradients=st.just(False), + fn_tree="functional.ivy.experimental.adaptive_max_pool2d", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_num_dims=3, + max_num_dims=4, + min_dim_size=1, + # Setting max and min value because this operation in paddle is not + # numerically stable + max_value=100, + min_value=-100, + ), + output_size=st.one_of( + st.tuples( + helpers.ints(min_value=1, max_value=5), + helpers.ints(min_value=1, max_value=5), + ), + helpers.ints(min_value=1, max_value=5), + ), + test_with_out=st.just(False), + ground_truth_backend="torch", ) -def test_avg_pool2d( - *, - x_k_s_p, - count_include_pad, - ceil_mode, - divisor_override, - data_format, - test_flags, - backend_fw, - on_device, - fn_name, +def test_adaptive_max_pool2d( + *, dtype_and_x, output_size, test_flags, backend_fw, fn_name, on_device ): - dtype, x, kernel, stride, pad = x_k_s_p - - if data_format == "NCHW": - x[0] = x[0].reshape( - (x[0].shape[0], x[0].shape[3], x[0].shape[1], x[0].shape[2]) - ) - + input_dtype, x = dtype_and_x helpers.test_function( - 
input_dtypes=dtype, + input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, on_device=on_device, fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, - x=x[0], - kernel=kernel, - strides=stride, - padding=pad, - data_format=data_format, - count_include_pad=count_include_pad, - ceil_mode=ceil_mode, - divisor_override=divisor_override, + input=x[0], + output_size=output_size, ) @handle_test( - fn_tree="functional.ivy.experimental.avg_pool3d", - x_k_s_p=helpers.arrays_for_pooling(min_dims=5, max_dims=5, min_side=1, max_side=4), + fn_tree="functional.ivy.experimental.avg_pool1d", + x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), count_include_pad=st.booleans(), ceil_mode=st.booleans(), - divisor_override=st.one_of(st.none(), st.integers(min_value=1, max_value=4)), ground_truth_backend="jax", test_gradients=st.just(False), ) -def test_avg_pool3d( +def test_avg_pool1d( *, x_k_s_p, count_include_pad, ceil_mode, - divisor_override, test_flags, backend_fw, - fn_name, on_device, ): dtype, x, kernel, stride, pad = x_k_s_p @@ -310,327 +527,165 @@ def test_avg_pool3d( input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, - fn_name=fn_name, + fn_name="avg_pool1d", + rtol_=1e-2, + atol_=1e-2, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, x=x[0], kernel=kernel, strides=stride, padding=pad, count_include_pad=count_include_pad, ceil_mode=ceil_mode, - divisor_override=divisor_override, - ) - - -@st.composite -def _valid_dct(draw): - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - max_value=65280, - min_value=-65280, - min_num_dims=1, - max_num_dims=5, - min_dim_size=2, - max_dim_size=10, - shared_dtype=True, - ) - ) - dims_len = len(x[0].shape) - n = draw(st.sampled_from([None, "int"])) - axis = draw(helpers.ints(min_value=-dims_len, max_value=dims_len - 1)) - norm = draw(st.sampled_from([None, "ortho"])) - type = draw(helpers.ints(min_value=1, max_value=4)) - if n == "int": - n = draw(helpers.ints(min_value=1, max_value=20)) - if n <= 1 and type == 1: - n = 2 - if norm == "ortho" and type == 1: - norm = None - return dtype, x, type, n, axis, norm - - -@handle_test( - fn_tree="dct", - dtype_x_and_args=_valid_dct(), - test_gradients=st.just(False), -) -def test_dct(*, dtype_x_and_args, test_flags, backend_fw, fn_name, on_device): - input_dtype, x, type, n, axis, norm = dtype_x_and_args - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - x=x[0], - type=type, - n=n, - axis=axis, - norm=norm, - rtol_=1e-3, - atol_=1e-1, ) +# avg_pool2d @handle_test( - fn_tree="idct", - dtype_x_and_args=_valid_dct(), + fn_tree="functional.ivy.experimental.avg_pool2d", + x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), + count_include_pad=st.booleans(), + ceil_mode=st.booleans(), + divisor_override=st.one_of(st.none(), st.integers(min_value=1, max_value=4)), + data_format=st.sampled_from(["NCHW", "NHWC"]), + ground_truth_backend="jax", test_gradients=st.just(False), ) -def test_idct(dtype_x_and_args, test_flags, backend_fw, fn_name, on_device): - input_dtype, x, type, n, axis, norm = dtype_x_and_args - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - x=x[0], - type=type, - n=n, - axis=axis, - norm=norm, - rtol_=1e-3, - atol_=1e-1, - on_device=on_device, - ) - - -@st.composite -def _interp_args(draw, 
mode=None, mode_list=None): - mixed_fn_compos = draw(st.booleans()) - curr_backend = ivy.current_backend_str() - torch_modes = [ - "linear", - "bilinear", - "trilinear", - "nearest", - "nearest-exact", - "area", - ] - - tf_modes = [ - "linear", - "bilinear", - "trilinear", - "nearest-exact", - "tf_area", - "bicubic_tensorflow", - "lanczos3", - "lanczos5", - "mitchellcubic", - "gaussian", - ] - - jax_modes = [ - "linear", - "bilinear", - "trilinear", - "nearest-exact", - "bicubic_tensorflow", - "lanczos3", - "lanczos5", - ] +def test_avg_pool2d( + *, + x_k_s_p, + count_include_pad, + ceil_mode, + divisor_override, + data_format, + test_flags, + backend_fw, + on_device, + fn_name, +): + dtype, x, kernel, stride, pad = x_k_s_p - if not mode and not mode_list: - if curr_backend == "torch" and not mixed_fn_compos: - mode = draw(st.sampled_from(torch_modes)) - elif curr_backend == "tensorflow" and not mixed_fn_compos: - mode = draw(st.sampled_from(tf_modes)) - elif curr_backend == "jax" and not mixed_fn_compos: - mode = draw(st.sampled_from(jax_modes)) - else: - mode = draw( - st.sampled_from( - [ - "linear", - "bilinear", - "trilinear", - "nearest", - "nearest-exact", - "area", - "tf_area", - "bicubic_tensorflow", - "lanczos3", - "lanczos5", - "mitchellcubic", - "gaussian", - ] - ) - ) - elif mode_list: - mode = draw(st.sampled_from(mode_list)) - align_corners = draw(st.one_of(st.booleans(), st.none())) - if (curr_backend == "tensorflow" or curr_backend == "jax") and not mixed_fn_compos: - align_corners = False - if mode == "linear": - num_dims = 3 - elif mode in [ - "bilinear", - "bicubic_tensorflow", - "bicubic", - "mitchellcubic", - "gaussian", - ]: - num_dims = 4 - elif mode == "trilinear": - num_dims = 5 - elif mode in [ - "nearest", - "area", - "tf_area", - "lanczos3", - "lanczos5", - "nearest-exact", - ]: - num_dims = ( - draw( - helpers.ints(min_value=1, max_value=3, mixed_fn_compos=mixed_fn_compos) - ) - + 2 - ) - align_corners = None - if curr_backend == "tensorflow" and not mixed_fn_compos: - num_dims = 3 - dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "float", mixed_fn_compos=mixed_fn_compos - ), - min_num_dims=num_dims, - max_num_dims=num_dims, - min_dim_size=2, - max_dim_size=5, - large_abs_safety_factor=50, - small_abs_safety_factor=50, - safety_factor_scale="log", - ) - ) - if draw(st.booleans()): - scale_factor = draw( - st.one_of( - helpers.lists( - x=helpers.floats( - min_value=1.0, max_value=2.0, mixed_fn_compos=mixed_fn_compos - ), - min_size=num_dims - 2, - max_size=num_dims - 2, - ), - helpers.floats( - min_value=1.0, max_value=2.0, mixed_fn_compos=mixed_fn_compos - ), - ) - ) - recompute_scale_factor = draw(st.booleans()) - size = None - else: - size = draw( - st.one_of( - helpers.lists( - x=helpers.ints( - min_value=1, max_value=3, mixed_fn_compos=mixed_fn_compos - ), - min_size=num_dims - 2, - max_size=num_dims - 2, - ), - st.integers(min_value=1, max_value=3), - ) + if data_format == "NCHW": + x[0] = x[0].reshape( + (x[0].shape[0], x[0].shape[3], x[0].shape[1], x[0].shape[2]) ) - recompute_scale_factor = False - scale_factor = None - if (curr_backend == "tensorflow" or curr_backend == "jax") and not mixed_fn_compos: - if not recompute_scale_factor: - recompute_scale_factor = True - return (dtype, x, mode, size, align_corners, scale_factor, recompute_scale_factor) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + rtol_=1e-2, + 
atol_=1e-2, + x=x[0], + kernel=kernel, + strides=stride, + padding=pad, + data_format=data_format, + count_include_pad=count_include_pad, + ceil_mode=ceil_mode, + divisor_override=divisor_override, + ) @handle_test( - fn_tree="functional.ivy.experimental.interpolate", - dtype_x_mode=_interp_args(), - antialias=st.just(False), + fn_tree="functional.ivy.experimental.avg_pool3d", + x_k_s_p=helpers.arrays_for_pooling(min_dims=5, max_dims=5, min_side=1, max_side=4), + count_include_pad=st.booleans(), + ceil_mode=st.booleans(), + divisor_override=st.one_of(st.none(), st.integers(min_value=1, max_value=4)), + ground_truth_backend="jax", test_gradients=st.just(False), - number_positional_args=st.just(2), ) -def test_interpolate( - dtype_x_mode, antialias, test_flags, backend_fw, fn_name, on_device +def test_avg_pool3d( + *, + x_k_s_p, + count_include_pad, + ceil_mode, + divisor_override, + test_flags, + backend_fw, + fn_name, + on_device, ): - ( - input_dtype, - x, - mode, - size, - align_corners, - scale_factor, - recompute_scale_factor, - ) = dtype_x_mode + dtype, x, kernel, stride, pad = x_k_s_p helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-01, - atol_=1e-01, + rtol_=1e-1, + atol_=1e-1, x=x[0], - size=size, - mode=mode, - align_corners=align_corners, - antialias=antialias, - scale_factor=scale_factor, - recompute_scale_factor=recompute_scale_factor, + kernel=kernel, + strides=stride, + padding=pad, + count_include_pad=count_include_pad, + ceil_mode=ceil_mode, + divisor_override=divisor_override, ) -@st.composite -def _x_and_fft(draw): - min_fft_points = 2 - dtype = draw(helpers.get_dtypes("valid", full=False)) - x_dim = draw( - helpers.get_shape( - min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 - ) - ) - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=tuple(x_dim), - min_value=-1e5, - max_value=1e5, - allow_inf=False, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ) +@handle_test( + fn_tree="dct", + dtype_x_and_args=_valid_dct(), + test_gradients=st.just(False), +) +def test_dct(*, dtype_x_and_args, test_flags, backend_fw, fn_name, on_device): + input_dtype, x, type, n, axis, norm = dtype_x_and_args + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + x=x[0], + type=type, + n=n, + axis=axis, + norm=norm, + rtol_=1e-3, + atol_=1e-1, ) - dim = draw(helpers.get_axis(shape=x_dim, allow_neg=True, force_int=True)) - norm = draw(st.sampled_from(["backward", "forward", "ortho"])) - n = draw(st.integers(min_fft_points, 256)) - return dtype, x, dim, norm, n @handle_test( - fn_tree="functional.ivy.experimental.fft", - d_x_d_n_n=_x_and_fft(), - ground_truth_backend="jax", - test_gradients=st.just(False), + fn_tree="dft", + d_xfft_axis_n_length=_x_and_fft(), + d_xifft_axis_n_length=_x_and_ifft(), + inverse=st.booleans(), + onesided=st.booleans(), ) -def test_fft(*, d_x_d_n_n, test_flags, backend_fw, on_device, fn_name): - dtype, x, dim, norm, n = d_x_d_n_n +def test_dft( + *, + d_xfft_axis_n_length, + d_xifft_axis_n_length, + inverse, + onesided, + test_flags, + backend_fw, + fn_name, + on_device, +): + if inverse: + dtype, x, axis, norm, dft_length = d_xifft_axis_n_length + else: + dtype, x, axis, norm, dft_length = d_xfft_axis_n_length + helpers.test_function( input_dtypes=dtype, test_flags=test_flags, 
backend_to_test=backend_fw, fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, on_device=on_device, x=x, - dim=dim, + axis=axis, + inverse=inverse, + onesided=onesided, + dft_length=dft_length, norm=norm, - n=n, ) @@ -792,27 +847,100 @@ def test_dropout3d( assert u.shape == v.shape == w.shape -@st.composite -def _x_and_ifft(draw): - min_fft_points = 2 - dtype = draw(helpers.get_dtypes("complex")) - x_dim = draw( - helpers.get_shape( - min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 - ) +# embedding +@handle_test( + fn_tree="functional.ivy.experimental.embedding", + dtypes_indices_weights=helpers.embedding_helper(), + max_norm=st.one_of(st.none(), st.floats(min_value=1, max_value=5)), + number_positional_args=st.just(2), +) +def test_embedding( + *, dtypes_indices_weights, max_norm, test_flags, backend_fw, on_device, fn_name +): + dtypes, indices, weights, _ = dtypes_indices_weights + dtypes = [dtypes[1], dtypes[0]] + + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + xs_grad_idxs=[[0, 0]], + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + weights=weights, + indices=indices, + max_norm=max_norm, ) - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=tuple(x_dim), - min_value=-1e-10, - max_value=1e10, - ) + + +@handle_test( + fn_tree="functional.ivy.experimental.fft", + d_x_d_n_n=_x_and_fft(), + ground_truth_backend="jax", + test_gradients=st.just(False), +) +def test_fft(*, d_x_d_n_n, test_flags, backend_fw, on_device, fn_name): + dtype, x, dim, norm, n = d_x_d_n_n + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + rtol_=1e-2, + atol_=1e-2, + on_device=on_device, + x=x, + dim=dim, + norm=norm, + n=n, + ) + + +@handle_test( + fn_tree="functional.ivy.experimental.fft2", + d_x_d_s_n=_x_and_fft2(), + ground_truth_backend="numpy", + container_flags=st.just([False]), + test_gradients=st.just(False), +) +def test_fft2(*, d_x_d_s_n, test_flags, backend_fw, fn_name, on_device): + dtype, x, s, dim, norm = d_x_d_s_n + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + # rtol_=1e-2, + # atol_=1e-2, + x=x, + s=s, + dim=dim, + norm=norm, + ) + + +@handle_test( + fn_tree="idct", + dtype_x_and_args=_valid_dct(), + test_gradients=st.just(False), +) +def test_idct(dtype_x_and_args, test_flags, backend_fw, fn_name, on_device): + input_dtype, x, type, n, axis, norm = dtype_x_and_args + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + x=x[0], + type=type, + n=n, + axis=axis, + norm=norm, + rtol_=1e-3, + atol_=1e-1, + on_device=on_device, ) - dim = draw(st.integers(1 - len(list(x_dim)), len(list(x_dim)) - 1)) - norm = draw(st.sampled_from(["backward", "forward", "ortho"])) - n = draw(st.integers(min_fft_points, 256)) - return dtype, x, dim, norm, n @handle_test( @@ -837,407 +965,287 @@ def test_ifft(*, d_x_d_n_n, test_flags, backend_fw, fn_name): ) -# embedding -@handle_test( - fn_tree="functional.ivy.experimental.embedding", - dtypes_indices_weights=helpers.embedding_helper(), - max_norm=st.one_of(st.none(), st.floats(min_value=1, max_value=5)), - number_positional_args=st.just(2), -) -def test_embedding( - *, dtypes_indices_weights, max_norm, test_flags, backend_fw, on_device, fn_name -): - dtypes, indices, weights, _ = dtypes_indices_weights - dtypes = [dtypes[1], dtypes[0]] - - helpers.test_function( - 
input_dtypes=dtypes, - test_flags=test_flags, - xs_grad_idxs=[[0, 0]], - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - weights=weights, - indices=indices, - max_norm=max_norm, - ) - - @handle_test( - fn_tree="dft", - d_xfft_axis_n_length=_x_and_fft(), - d_xifft_axis_n_length=_x_and_ifft(), - inverse=st.booleans(), - onesided=st.booleans(), + fn_tree="functional.ivy.experimental.ifftn", + d_x_d_s_n=_x_and_ifftn(), + ground_truth_backend="numpy", + test_gradients=st.just(False), ) -def test_dft( +def test_ifftn( *, - d_xfft_axis_n_length, - d_xifft_axis_n_length, - inverse, - onesided, + d_x_d_s_n, test_flags, backend_fw, fn_name, on_device, ): - if inverse: - dtype, x, axis, norm, dft_length = d_xifft_axis_n_length - else: - dtype, x, axis, norm, dft_length = d_xfft_axis_n_length - + dtype, x, s, axes, norm = d_x_d_s_n helpers.test_function( input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, - fn_name=fn_name, on_device=on_device, + fn_name=fn_name, x=x, - axis=axis, - inverse=inverse, - onesided=onesided, - dft_length=dft_length, + s=s, + axes=axes, norm=norm, ) @handle_test( - fn_tree="functional.ivy.experimental.adaptive_max_pool2d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, - max_num_dims=4, - min_dim_size=1, - # Setting max and min value because this operation in paddle is not - # numerically stable - max_value=100, - min_value=-100, - ), - output_size=st.one_of( - st.tuples( - helpers.ints(min_value=1, max_value=5), - helpers.ints(min_value=1, max_value=5), - ), - helpers.ints(min_value=1, max_value=5), - ), - test_with_out=st.just(False), - ground_truth_backend="torch", + fn_tree="functional.ivy.experimental.interpolate", + dtype_x_mode=_interp_args(), + antialias=st.just(False), + test_gradients=st.just(False), + number_positional_args=st.just(2), ) -def test_adaptive_max_pool2d( - *, dtype_and_x, output_size, test_flags, backend_fw, fn_name, on_device +def test_interpolate( + dtype_x_mode, antialias, test_flags, backend_fw, fn_name, on_device ): - input_dtype, x = dtype_and_x + ( + input_dtype, + x, + mode, + size, + align_corners, + scale_factor, + recompute_scale_factor, + ) = dtype_x_mode helpers.test_function( input_dtypes=input_dtype, test_flags=test_flags, backend_to_test=backend_fw, - on_device=on_device, fn_name=fn_name, - input=x[0], - output_size=output_size, + on_device=on_device, + rtol_=1e-01, + atol_=1e-01, + x=x[0], + size=size, + mode=mode, + align_corners=align_corners, + antialias=antialias, + scale_factor=scale_factor, + recompute_scale_factor=recompute_scale_factor, ) @handle_test( - fn_tree="functional.ivy.experimental.adaptive_avg_pool1d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=2, - max_num_dims=3, - min_dim_size=1, - max_value=100, - min_value=-100, + fn_tree="functional.ivy.experimental.max_pool1d", + x_k_s_p=helpers.arrays_for_pooling( + min_dims=3, + max_dims=3, + min_side=1, + max_side=4, + explicit_or_str_padding=True, + return_dilation=True, + data_format=st.sampled_from(["channel_first", "channel_last"]), + return_data_format=True, ), - output_size=helpers.ints(min_value=1, max_value=5), - test_with_out=st.just(False), + ceil_mode=st.sampled_from([True, False]), + test_gradients=st.just(False), ground_truth_backend="torch", ) -def test_adaptive_avg_pool1d( - *, dtype_and_x, output_size, test_flags, backend_fw, fn_name, on_device +def test_max_pool1d( + *, + x_k_s_p, + ceil_mode, + 
test_flags, + backend_fw, + fn_name, + on_device, ): - input_dtype, x = dtype_and_x + dtype, x, kernel, stride, pad, dilation, data_format = x_k_s_p + data_format = "NCW" if data_format == "channel_first" else "NWC" + assume(not (isinstance(pad, str) and (pad.upper() == "VALID") and ceil_mode)) + # TODO: Remove this once the paddle backend supports dilation + assume(not (backend_fw == "paddle" and max(list(dilation)) > 1)) + helpers.test_function( - input_dtypes=input_dtype, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, - fn_name=fn_name, on_device=on_device, - input=x[0], - output_size=output_size, + fn_name=fn_name, + rtol_=1e-2, + atol_=1e-2, + x=x[0], + kernel=kernel, + strides=stride, + padding=pad, + dilation=dilation, + data_format=data_format, + ceil_mode=ceil_mode, ) @handle_test( - fn_tree="functional.ivy.experimental.adaptive_avg_pool2d", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=3, - max_num_dims=4, - min_dim_size=1, - max_value=100, - min_value=-100, - ), - output_size=st.one_of( - st.tuples( - helpers.ints(min_value=1, max_value=5), - helpers.ints(min_value=1, max_value=5), - ), - helpers.ints(min_value=1, max_value=5), + fn_tree="functional.ivy.experimental.max_pool2d", + x_k_s_p=helpers.arrays_for_pooling( + min_dims=4, + max_dims=4, + min_side=2, + max_side=4, + explicit_or_str_padding=True, + return_dilation=True, + data_format=st.sampled_from(["channel_first", "channel_last"]), + return_data_format=True, ), - test_with_out=st.just(False), - ground_truth_backend="torch", + ceil_mode=st.sampled_from([True, False]), + test_gradients=st.just(False), + ground_truth_backend="jax", ) -def test_adaptive_avg_pool2d( - *, dtype_and_x, output_size, test_flags, backend_fw, fn_name, on_device -): - input_dtype, x = dtype_and_x - helpers.test_function( - input_dtypes=input_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - on_device=on_device, - fn_name=fn_name, - input=x[0], - output_size=output_size, - ) - - -@st.composite -def _reduce_window_helper(draw, get_func_st): - dtype = draw(helpers.get_dtypes("valid", full=False)) - py_func = draw(get_func_st(dtype[0])) - init_value = draw( - helpers.dtype_and_values( - dtype=dtype, - shape=(), - allow_inf=True, - ) - )[1] - ndim = draw(st.integers(min_value=1, max_value=4)) - _, others = draw( - helpers.dtype_and_values( - num_arrays=4, - dtype=["int64"] * 4, - shape=(ndim,), - min_value=1, - max_value=3, - small_abs_safety_factor=1, - large_abs_safety_factor=1, - ) - ) - others = [other.tolist() for other in others] - window, dilation = others[0], others[2] - op_shape = [] - for i in range(ndim): - min_x = window[i] + (window[i] - 1) * (dilation[i] - 1) - op_shape.append(draw(st.integers(min_x, min_x + 1))) - dtype, operand = draw( - helpers.dtype_and_values( - dtype=dtype, - shape=op_shape, - ) - ) - padding = draw( - st.one_of( - st.lists( - st.tuples( - st.integers(min_value=0, max_value=3), - st.integers(min_value=0, max_value=3), - ), - min_size=ndim, - max_size=ndim, - ), - st.sampled_from(["SAME", "VALID"]), +def test_max_pool2d( + *, + x_k_s_p, + ceil_mode, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtype, x, kernel, stride, pad, dilation, data_format = x_k_s_p + assume( + not ( + backend_fw == "tensorflow" + and ( + (stride[0] > kernel[0] or stride[0] > kernel[1]) + or ( + (stride[0] > 1 and dilation[0] > 1) + or (stride[0] > 1 and dilation[1] > 1) + ) + ) ) ) - for i, arg in enumerate(others): - if 
len(np.unique(arg)) == 1 and draw(st.booleans()): - others[i] = arg[0] - return dtype * 2, operand, init_value, py_func, others, padding - - -def _get_reduce_func(dtype): - if dtype == "bool": - return st.sampled_from([ivy.logical_and, ivy.logical_or]) - else: - return st.sampled_from([ivy.add, ivy.maximum, ivy.minimum, ivy.multiply]) - + data_format = "NCHW" if data_format == "channel_first" else "NHWC" + assume(not (isinstance(pad, str) and (pad.upper() == "VALID") and ceil_mode)) + # TODO: Remove this once the paddle backend supports dilation + assume(not (backend_fw == "paddle" and max(list(dilation)) > 1)) -@handle_test( - fn_tree="functional.ivy.experimental.reduce_window", - all_args=_reduce_window_helper(_get_reduce_func), - test_with_out=st.just(False), - ground_truth_backend="jax", -) -def test_reduce_window(*, all_args, test_flags, backend_fw, fn_name, on_device): - dtypes, operand, init_value, computation, others, padding = all_args helpers.test_function( - input_dtypes=dtypes, + input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, - on_device=on_device, fn_name=fn_name, - operand=operand[0], - init_value=init_value[0], - computation=computation, - window_dimensions=others[0], - window_strides=others[1], - padding=padding, - base_dilation=others[2], - window_dilation=None, - ) - - -@st.composite -def _x_and_fft2(draw): - min_fft2_points = 2 - dtype = draw(helpers.get_dtypes("float_and_complex", full=False)) - x_dim = draw( - helpers.get_shape( - min_dim_size=2, max_dim_size=100, min_num_dims=2, max_num_dims=4 - ) - ) - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=tuple(x_dim), - min_value=-1e5, - max_value=1e5, - allow_inf=False, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ) - ) - s = ( - draw(st.integers(min_fft2_points, 256)), - draw(st.integers(min_fft2_points, 256)), + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + x=x[0], + kernel=kernel, + strides=stride, + padding=pad, + dilation=dilation, + ceil_mode=ceil_mode, + data_format=data_format, ) - dim = draw(st.sampled_from([(0, 1), (-1, -2), (1, 0)])) - norm = draw(st.sampled_from(["backward", "forward", "ortho"])) - return dtype, x, s, dim, norm @handle_test( - fn_tree="functional.ivy.experimental.fft2", - d_x_d_s_n=_x_and_fft2(), - ground_truth_backend="numpy", - container_flags=st.just([False]), + fn_tree="functional.ivy.experimental.max_pool3d", + x_k_s_p=helpers.arrays_for_pooling( + min_dims=5, + max_dims=5, + min_side=1, + max_side=4, + explicit_or_str_padding=True, + return_dilation=True, + data_format=st.sampled_from(["channel_first", "channel_last"]), + return_data_format=True, + ), + ceil_mode=st.sampled_from([True, False]), test_gradients=st.just(False), + ground_truth_backend="torch", ) -def test_fft2(*, d_x_d_s_n, test_flags, backend_fw, fn_name, on_device): - dtype, x, s, dim, norm = d_x_d_s_n +def test_max_pool3d( + *, + x_k_s_p, + ceil_mode, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtype, x, kernel, stride, pad, dilation, data_format = x_k_s_p + + data_format = "NCDHW" if data_format == "channel_first" else "NDHWC" + assume(not (isinstance(pad, str) and (pad.upper() == "VALID") and ceil_mode)) + # TODO: Remove this once the paddle backend supports dilation + assume(not (backend_fw == "paddle" and max(list(dilation)) > 1)) + helpers.test_function( input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, on_device=on_device, fn_name=fn_name, - # rtol_=1e-2, - # atol_=1e-2, - x=x, - s=s, - 
dim=dim, - norm=norm, - ) - - -@st.composite -def _x_and_ifftn(draw): - min_fft_points = 2 - dtype = draw(helpers.get_dtypes("complex")) - x_dim = draw( - helpers.get_shape( - min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 - ) - ) - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=tuple(x_dim), - min_value=-1e-10, - max_value=1e10, - ) - ) - axes = draw( - st.lists( - st.integers(0, len(x_dim) - 1), min_size=1, max_size=len(x_dim), unique=True - ) - ) - norm = draw(st.sampled_from(["forward", "ortho", "backward"])) - - # Shape for s can be larger, smaller or equal to the size of the input - # along the axes specified by axes. - # Here, we're generating a list of integers corresponding to each axis in axes. - s = draw( - st.lists( - st.integers(min_fft_points, 256), min_size=len(axes), max_size=len(axes) - ) + rtol_=1e-2, + atol_=1e-2, + x=x[0], + kernel=kernel, + strides=stride, + padding=pad, + data_format=data_format, + dilation=dilation, + ceil_mode=ceil_mode, ) - return dtype, x, s, axes, norm - @handle_test( - fn_tree="functional.ivy.experimental.ifftn", - d_x_d_s_n=_x_and_ifftn(), - ground_truth_backend="numpy", + fn_tree="functional.ivy.experimental.layers.max_unpool1d", + x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), + indices=st.lists(st.integers(0, 1), min_size=1, max_size=4), + ground_truth_backend="jax", test_gradients=st.just(False), ) -def test_ifftn( +def test_max_unpool1d( *, - d_x_d_s_n, + x_k_s_p, + indices, test_flags, backend_fw, fn_name, on_device, ): - dtype, x, s, axes, norm = d_x_d_s_n + dtype, x, kernel, stride, pad = x_k_s_p helpers.test_function( input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, on_device=on_device, fn_name=fn_name, - x=x, - s=s, - axes=axes, - norm=norm, + rtol_=1e-2, + atol_=1e-2, + x=x[0], + kernel=kernel, + strides=stride, + padding=pad, + indices=indices, ) -@st.composite -def _x_and_rfftn(draw): - min_rfftn_points = 2 - dtype = draw(helpers.get_dtypes("float")) - x_dim = draw( - helpers.get_shape( - min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=3 - ) - ) - x = draw( - helpers.array_values( - dtype=dtype[0], - shape=tuple(x_dim), - min_value=-1e10, - max_value=1e10, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - ) - ) - axes = draw( - st.lists( - st.integers(0, len(x_dim) - 1), min_size=1, max_size=len(x_dim), unique=True - ) - ) - s = draw( - st.lists( - st.integers(min_rfftn_points, 256), min_size=len(axes), max_size=len(axes) - ) +@handle_test( + fn_tree="functional.ivy.experimental.reduce_window", + all_args=_reduce_window_helper(_get_reduce_func), + test_with_out=st.just(False), + ground_truth_backend="jax", +) +def test_reduce_window(*, all_args, test_flags, backend_fw, fn_name, on_device): + dtypes, operand, init_value, computation, others, padding = all_args + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + on_device=on_device, + fn_name=fn_name, + operand=operand[0], + init_value=init_value[0], + computation=computation, + window_dimensions=others[0], + window_strides=others[1], + padding=padding, + base_dilation=others[2], + window_dilation=None, ) - norm = draw(st.sampled_from(["backward", "forward", "ortho"])) - return dtype, x, s, axes, norm @handle_test( diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py index 
77520704697f3..2542c2495d4e4 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py @@ -6,56 +6,52 @@ from ivy_tests.test_ivy.helpers import handle_test -# log_poisson_loss +# huber_loss @handle_test( - fn_tree="functional.ivy.experimental.log_poisson_loss", - dtype_and_targets=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=3, + fn_tree="functional.ivy.experimental.huber_loss", + dtype_and_true=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, allow_inf=False, min_num_dims=1, max_num_dims=3, min_dim_size=3, ), - dtype_and_log_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - small_abs_safety_factor=4, - safety_factor_scale="log", - min_value=0, - max_value=3, + dtype_and_pred=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, allow_inf=False, - exclude_min=True, - exclude_max=True, min_num_dims=1, max_num_dims=3, min_dim_size=3, ), - compute_full_loss=st.sampled_from([True, False]), - test_with_out=st.just(False), + reduction=st.sampled_from(["none", "sum", "mean"]), + delta=helpers.floats(min_value=0.01, max_value=2.0), ) -def test_log_poisson_loss( - *, - dtype_and_targets, - dtype_and_log_input, - compute_full_loss, +def test_huber_loss( + dtype_and_true, + dtype_and_pred, + reduction, + delta, test_flags, backend_fw, fn_name, on_device, ): - targets_dtype, targets = dtype_and_targets - log_input_dtype, log_input = dtype_and_log_input + true_dtype, true = dtype_and_true + pred_dtype, pred = dtype_and_pred helpers.test_function( - input_dtypes=targets_dtype + log_input_dtype, + input_dtypes=true_dtype + pred_dtype, test_flags=test_flags, - backend_to_fix=backend_fw, + backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - targets=targets[0], - log_input=log_input[0], - compute_full_loss=compute_full_loss, - atol_=1e-2, + true=true[0], + pred=pred[0], + reduction=reduction, + delta=delta, ) @@ -101,102 +97,106 @@ def test_l1_loss( ) -# smooth_l1_loss -# all loss functions failing for paddle backend due to -# "There is no grad op for inputs:[0] or it's stop_gradient=True." 
+# log_poisson_loss @handle_test( - fn_tree="functional.ivy.experimental.smooth_l1_loss", - dtype_and_input=helpers.dtype_and_values( + fn_tree="functional.ivy.experimental.log_poisson_loss", + dtype_and_targets=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-10.0, - max_value=10.0, + min_value=0, + max_value=3, allow_inf=False, min_num_dims=1, max_num_dims=3, min_dim_size=3, ), - dtype_and_target=helpers.dtype_and_values( + dtype_and_log_input=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_value=-10.0, - max_value=10.0, + small_abs_safety_factor=4, + safety_factor_scale="log", + min_value=0, + max_value=3, allow_inf=False, + exclude_min=True, + exclude_max=True, min_num_dims=1, max_num_dims=3, min_dim_size=3, ), - beta=helpers.floats(min_value=0.0, max_value=1.0), - reduction=st.sampled_from(["none", "sum", "mean"]), + compute_full_loss=st.sampled_from([True, False]), + test_with_out=st.just(False), ) -def test_smooth_l1_loss( - dtype_and_input, - dtype_and_target, - beta, - reduction, +def test_log_poisson_loss( + *, + dtype_and_targets, + dtype_and_log_input, + compute_full_loss, test_flags, backend_fw, fn_name, on_device, ): - dtype_input, input = dtype_and_input - dtype_target, target = dtype_and_target - + targets_dtype, targets = dtype_and_targets + log_input_dtype, log_input = dtype_and_log_input helpers.test_function( - input_dtypes=dtype_input + dtype_target, + input_dtypes=targets_dtype + log_input_dtype, test_flags=test_flags, - backend_to_test=backend_fw, + backend_to_fix=backend_fw, fn_name=fn_name, on_device=on_device, - input=input[0], - target=target[0], - beta=beta, - reduction=reduction, + targets=targets[0], + log_input=log_input[0], + compute_full_loss=compute_full_loss, + atol_=1e-2, ) -# huber_loss +# smooth_l1_loss +# all loss functions failing for paddle backend due to +# "There is no grad op for inputs:[0] or it's stop_gradient=True." 
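# A possible workaround sketch (hypothetical, not taken from this patch): while
# the paddle gradient issue persists, these loss tests could disable the
# gradient check with the pattern already used elsewhere in this suite, e.g.
#
#     test_gradients=st.just(False),
#
# added to the @handle_test decorator, or the paddle backend alone could be
# skipped inside the test body with assume(backend_fw != "paddle") (assuming
# `assume` from hypothesis is imported).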
@handle_test( - fn_tree="functional.ivy.experimental.huber_loss", - dtype_and_true=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-10, - max_value=10, + fn_tree="functional.ivy.experimental.smooth_l1_loss", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-10.0, + max_value=10.0, allow_inf=False, min_num_dims=1, max_num_dims=3, min_dim_size=3, ), - dtype_and_pred=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_value=-10, - max_value=10, + dtype_and_target=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-10.0, + max_value=10.0, allow_inf=False, min_num_dims=1, max_num_dims=3, min_dim_size=3, ), + beta=helpers.floats(min_value=0.0, max_value=1.0), reduction=st.sampled_from(["none", "sum", "mean"]), - delta=helpers.floats(min_value=0.01, max_value=2.0), ) -def test_huber_loss( - dtype_and_true, - dtype_and_pred, +def test_smooth_l1_loss( + dtype_and_input, + dtype_and_target, + beta, reduction, - delta, test_flags, backend_fw, fn_name, on_device, ): - true_dtype, true = dtype_and_true - pred_dtype, pred = dtype_and_pred + dtype_input, input = dtype_and_input + dtype_target, target = dtype_and_target + helpers.test_function( - input_dtypes=true_dtype + pred_dtype, + input_dtypes=dtype_input + dtype_target, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - true=true[0], - pred=pred[0], + input=input[0], + target=target[0], + beta=beta, reduction=reduction, - delta=delta, ) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_norms.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_norms.py index c551338943bbf..1affe3770ce1b 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_norms.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_norms.py @@ -7,25 +7,56 @@ import ivy -@handle_test( - fn_tree="functional.ivy.experimental.l1_normalize", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("valid"), valid_axis=True - ), -) -def test_l1_normalize(*, dtype_values_axis, test_flags, backend_fw, fn_name, on_device): - x_dtype, x, axis = dtype_values_axis - helpers.test_function( - backend_to_test=backend_fw, - test_flags=test_flags, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-1, - atol_=1e-1, - input_dtypes=x_dtype, - x=x, - axis=axis, +# --- Helpers --- # +# --------------- # + + +@st.composite +def _group_norm_helper(draw): + data_format = draw(st.sampled_from(["NSC", "NCS"])) + shape = draw( + helpers.get_shape( + min_num_dims=2, max_num_dims=4, min_dim_size=2, max_dim_size=4 + ) + ) + channel_size = shape[-1] + group_list = [*range(1, 4)] + group_list = list(filter(lambda x: (channel_size % x == 0), group_list)) + num_groups = draw(st.sampled_from(group_list)) + if data_format == "NCS": + shape = (shape[0], shape[-1], *shape[1:-1]) + x_dtype, x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes( + "float", + ), + shape=shape, + large_abs_safety_factor=50, + small_abs_safety_factor=50, + safety_factor_scale="log", + ) + ) + _, offset = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=(channel_size,), + large_abs_safety_factor=50, + small_abs_safety_factor=50, + safety_factor_scale="log", + ) + ) + + _, scale = draw( + helpers.dtype_and_values( + dtype=x_dtype, + shape=(channel_size,), + large_abs_safety_factor=50, + 
small_abs_safety_factor=50, + safety_factor_scale="log", + ) ) + eps = draw(helpers.floats(min_value=1e-5, max_value=0.1)) + return x_dtype, x[0], num_groups, data_format, scale[0], offset[0], eps @st.composite @@ -130,32 +161,8 @@ def _instance_and_batch_norm_helper(draw, *, min_dims=1, test_function="instance ) -@handle_test( - fn_tree="functional.ivy.experimental.instance_norm", - data=_instance_and_batch_norm_helper(min_dims=3), - training=st.booleans(), -) -def test_instance_norm(*, data, training, test_flags, backend_fw, fn_name, on_device): - x_dtype, x, mean, variance, offset, scale, eps, momentum, data_format = data - helpers.test_function( - backend_to_test=backend_fw, - test_flags=test_flags, - fn_name=fn_name, - on_device=on_device, - xs_grad_idxs=[[0, 0]], - rtol_=1e-1, - atol_=1e-1, - input_dtypes=x_dtype, - x=x, - mean=mean, - variance=variance, - scale=scale, - offset=offset, - eps=eps, - training=training, - momentum=momentum, - data_format=data_format, - ) +# --- Main --- # +# ------------ # # batch_norm @@ -187,54 +194,6 @@ def test_batch_norm(*, data, training, test_flags, backend_fw, fn_name, on_devic ) -@st.composite -def _group_norm_helper(draw): - data_format = draw(st.sampled_from(["NSC", "NCS"])) - shape = draw( - helpers.get_shape( - min_num_dims=2, max_num_dims=4, min_dim_size=2, max_dim_size=4 - ) - ) - channel_size = shape[-1] - group_list = [*range(1, 4)] - group_list = list(filter(lambda x: (channel_size % x == 0), group_list)) - num_groups = draw(st.sampled_from(group_list)) - if data_format == "NCS": - shape = (shape[0], shape[-1], *shape[1:-1]) - x_dtype, x = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "float", - ), - shape=shape, - large_abs_safety_factor=50, - small_abs_safety_factor=50, - safety_factor_scale="log", - ) - ) - _, offset = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=(channel_size,), - large_abs_safety_factor=50, - small_abs_safety_factor=50, - safety_factor_scale="log", - ) - ) - - _, scale = draw( - helpers.dtype_and_values( - dtype=x_dtype, - shape=(channel_size,), - large_abs_safety_factor=50, - small_abs_safety_factor=50, - safety_factor_scale="log", - ) - ) - eps = draw(helpers.floats(min_value=1e-5, max_value=0.1)) - return x_dtype, x[0], num_groups, data_format, scale[0], offset[0], eps - - # group_norm @handle_test( fn_tree="functional.ivy.experimental.group_norm", @@ -265,3 +224,52 @@ def test_group_norm( eps=eps, data_format=data_format, ) + + +@handle_test( + fn_tree="functional.ivy.experimental.instance_norm", + data=_instance_and_batch_norm_helper(min_dims=3), + training=st.booleans(), +) +def test_instance_norm(*, data, training, test_flags, backend_fw, fn_name, on_device): + x_dtype, x, mean, variance, offset, scale, eps, momentum, data_format = data + helpers.test_function( + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + xs_grad_idxs=[[0, 0]], + rtol_=1e-1, + atol_=1e-1, + input_dtypes=x_dtype, + x=x, + mean=mean, + variance=variance, + scale=scale, + offset=offset, + eps=eps, + training=training, + momentum=momentum, + data_format=data_format, + ) + + +@handle_test( + fn_tree="functional.ivy.experimental.l1_normalize", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), valid_axis=True + ), +) +def test_l1_normalize(*, dtype_values_axis, test_flags, backend_fw, fn_name, on_device): + x_dtype, x, axis = dtype_values_axis + helpers.test_function( + backend_to_test=backend_fw, + 
test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + input_dtypes=x_dtype, + x=x, + axis=axis, + ) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py index 07e0b6dab339f..4aa174c565573 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py @@ -8,42 +8,54 @@ from ivy_tests.test_ivy.helpers import handle_test -# relu +# gelu @handle_test( - fn_tree="functional.ivy.relu", + fn_tree="functional.ivy.gelu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=8, - small_abs_safety_factor=8, - safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("float_and_complex"), + large_abs_safety_factor=1, + small_abs_safety_factor=1, + safety_factor_scale="linear", + min_value=-1e4, + max_value=1e4, ), + approximate=st.booleans(), ) -def test_relu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_gelu(*, dtype_and_x, approximate, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x + if "complex" in str(x[0].dtype): + approximate = True helpers.test_function( input_dtypes=dtype, backend_to_test=backend_fw, test_flags=test_flags, fn_name=fn_name, on_device=on_device, + atol_=1e-2, + rtol_=1e-2, x=x[0], + approximate=approximate, ) -# leaky_relu +# hardswish @handle_test( - fn_tree="functional.ivy.leaky_relu", + fn_tree="functional.ivy.hardswish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "float_and_complex", full=False, key="leaky_relu" - ), - large_abs_safety_factor=16, - small_abs_safety_factor=16, + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=8, + small_abs_safety_factor=8, safety_factor_scale="log", ), - alpha=st.floats(min_value=-1e-4, max_value=1e-4), ) -def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_device): +def test_hardswish( + *, + dtype_and_x, + test_flags, + backend_fw, + fn_name, + on_device, +): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -51,54 +63,51 @@ def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_d test_flags=test_flags, fn_name=fn_name, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, x=x[0], - alpha=alpha, ) -# gelu +# leaky_relu @handle_test( - fn_tree="functional.ivy.gelu", + fn_tree="functional.ivy.leaky_relu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), - large_abs_safety_factor=1, - small_abs_safety_factor=1, - safety_factor_scale="linear", - min_value=-1e4, - max_value=1e4, + available_dtypes=helpers.get_dtypes( + "float_and_complex", full=False, key="leaky_relu" + ), + large_abs_safety_factor=16, + small_abs_safety_factor=16, + safety_factor_scale="log", ), - approximate=st.booleans(), + alpha=st.floats(min_value=-1e-4, max_value=1e-4), ) -def test_gelu(*, dtype_and_x, approximate, test_flags, backend_fw, fn_name, on_device): +def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x - if "complex" in str(x[0].dtype): - approximate = True helpers.test_function( input_dtypes=dtype, backend_to_test=backend_fw, test_flags=test_flags, fn_name=fn_name, on_device=on_device, - atol_=1e-2, rtol_=1e-2, + atol_=1e-2, x=x[0], - approximate=approximate, + alpha=alpha, ) -# sigmoid +# log_softmax @handle_test( 
- fn_tree="functional.ivy.sigmoid", + fn_tree="functional.ivy.log_softmax", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), + axis=st.one_of(helpers.ints(min_value=-1, max_value=0), st.none()), ) -def test_sigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_log_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -106,28 +115,24 @@ def test_sigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): test_flags=test_flags, fn_name=fn_name, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, + rtol_=1e-02, + atol_=1e-02, x=x[0], + axis=axis, ) -# softmax +# mish @handle_test( - fn_tree="functional.ivy.softmax", + fn_tree="functional.ivy.mish", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), - axis=st.one_of( - helpers.ints(min_value=-1, max_value=0), - st.none(), - ), ) -def test_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_device): +def test_mish(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -138,28 +143,20 @@ def test_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_devic rtol_=1e-02, atol_=1e-02, x=x[0], - axis=axis, ) -# softplus +# relu @handle_test( - fn_tree="functional.ivy.softplus", + fn_tree="functional.ivy.relu", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - min_num_dims=1, - large_abs_safety_factor=4, - small_abs_safety_factor=4, + large_abs_safety_factor=8, + small_abs_safety_factor=8, safety_factor_scale="log", ), - beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()), - threshold=st.one_of(helpers.number(min_value=0.1, max_value=30), st.none()), ) -def test_softplus( - *, dtype_and_x, beta, threshold, test_flags, backend_fw, fn_name, on_device -): - assume(beta != 0) - assume(threshold != 0) +def test_relu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -167,27 +164,21 @@ def test_softplus( test_flags=test_flags, fn_name=fn_name, on_device=on_device, - rtol_=1e-02, - atol_=1e-02, x=x[0], - beta=beta, - threshold=threshold, ) -# log_softmax +# sigmoid @handle_test( - fn_tree="functional.ivy.log_softmax", + fn_tree="functional.ivy.sigmoid", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), - axis=st.one_of(helpers.ints(min_value=-1, max_value=0), st.none()), ) -def test_log_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_device): +def test_sigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -195,24 +186,28 @@ def test_log_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_d test_flags=test_flags, fn_name=fn_name, on_device=on_device, - rtol_=1e-02, - atol_=1e-02, + rtol_=1e-2, + atol_=1e-2, x=x[0], - axis=axis, ) -# mish +# softmax @handle_test( - fn_tree="functional.ivy.mish", + fn_tree="functional.ivy.softmax", dtype_and_x=helpers.dtype_and_values( 
available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), + axis=st.one_of( + helpers.ints(min_value=-1, max_value=0), + st.none(), + ), ) -def test_mish(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -223,27 +218,28 @@ def test_mish(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): rtol_=1e-02, atol_=1e-02, x=x[0], + axis=axis, ) -# hardswish +# softplus @handle_test( - fn_tree="functional.ivy.hardswish", + fn_tree="functional.ivy.softplus", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=8, - small_abs_safety_factor=8, + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + large_abs_safety_factor=4, + small_abs_safety_factor=4, safety_factor_scale="log", ), + beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()), + threshold=st.one_of(helpers.number(min_value=0.1, max_value=30), st.none()), ) -def test_hardswish( - *, - dtype_and_x, - test_flags, - backend_fw, - fn_name, - on_device, +def test_softplus( + *, dtype_and_x, beta, threshold, test_flags, backend_fw, fn_name, on_device ): + assume(beta != 0) + assume(threshold != 0) dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -251,5 +247,9 @@ def test_hardswish( test_flags=test_flags, fn_name=fn_name, on_device=on_device, + rtol_=1e-02, + atol_=1e-02, x=x[0], + beta=beta, + threshold=threshold, ) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py index e0bbfa29038bb..196a6ffdf4934 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py @@ -12,74 +12,19 @@ from ivy.functional.ivy.layers import _deconv_length -# Linear # -# -------# -@st.composite -def _x_and_linear(draw): - mixed_fn_compos = draw(st.booleans()) - is_torch_backend = ivy.current_backend_str() == "torch" - dtype = draw( - helpers.get_dtypes("numeric", full=False, mixed_fn_compos=mixed_fn_compos) - ) - in_features = draw( - helpers.ints(min_value=1, max_value=2, mixed_fn_compos=mixed_fn_compos) - ) - out_features = draw( - helpers.ints(min_value=1, max_value=2, mixed_fn_compos=mixed_fn_compos) - ) - - x_shape = ( - 1, - 1, - in_features, - ) - - weight_shape = (1,) + (out_features,) + (in_features,) - # if backend is torch and we're testing the primary implementation - # weight.ndim should be equal to 2 - if is_torch_backend and not mixed_fn_compos: - weight_shape = (out_features,) + (in_features,) +# --- Helpers --- # +# --------------- # - bias_shape = ( - 1, - out_features, - ) - x = draw( - helpers.array_values(dtype=dtype[0], shape=x_shape, min_value=0, max_value=10) - ) - weight = draw( - helpers.array_values( - dtype=dtype[0], shape=weight_shape, min_value=0, max_value=10 - ) - ) - bias = draw( - helpers.array_values( - dtype=dtype[0], shape=bias_shape, min_value=0, max_value=10 +def _assume_tf_dilation_gt_1(backend_fw, on_device, dilations): + if backend_fw == "tensorflow": + assume( + not ( + on_device == "cpu" and (dilations > 1) + if isinstance(dilations, int) + else any(d > 1 for d in dilations) + ) ) - ) - return dtype, x, weight, bias - - -# linear -@handle_test( - fn_tree="functional.ivy.linear", - 
dtype_x_weight_bias=_x_and_linear(), -) -def test_linear(*, dtype_x_weight_bias, test_flags, backend_fw, fn_name, on_device): - dtype, x, weight, bias = dtype_x_weight_bias - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-02, - atol_=1e-02, - x=x, - weight=weight, - bias=bias, - ) # Dropout # @@ -122,150 +67,6 @@ def _dropout_helper(draw): return dtype_and_x, noise_shape, seed, dtype, prob, scale, training -# dropout -@handle_test( - fn_tree="functional.ivy.dropout", - data=_dropout_helper(), - test_gradients=st.just(False), -) -def test_dropout( - *, - data, - test_flags, - backend_fw, - fn_name, - on_device, -): - (x_dtype, x), noise_shape, seed, dtype, prob, scale, training = data - ret, gt_ret = helpers.test_function( - input_dtypes=x_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - test_values=False, - x=x[0], - prob=prob, - scale=scale, - noise_shape=noise_shape, - dtype=dtype[0], - training=training, - seed=seed, - ) - ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) - gt_ret = helpers.flatten_and_to_np( - ret=gt_ret, backend=test_flags.ground_truth_backend - ) - for u, v, w in zip(ret, gt_ret, x): - # cardinality test - assert u.shape == v.shape == w.shape - - -# Attention # -# ----------# - - -@st.composite -def _x_and_scaled_attention(draw, dtypes): - dtype = draw(dtypes) - num_queries = draw(helpers.ints(min_value=2, max_value=4)) - num_keys = draw(helpers.ints(min_value=2, max_value=4)) - feat_dim = draw(helpers.ints(min_value=2, max_value=4)) - batch_size = draw(helpers.ints(min_value=1, max_value=2)) - q_shape = (batch_size,) + (num_queries,) + (feat_dim,) - k_shape = (batch_size,) + (num_keys,) + (feat_dim,) - v_shape = (batch_size,) + (num_keys,) + (feat_dim,) - mask_shape = (batch_size,) + (num_queries,) + (num_keys,) - - query = draw( - helpers.array_values( - dtype=dtype[0], - shape=q_shape, - min_value=1e-3, - max_value=1e2, - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", - ) - ) - key = draw( - helpers.array_values( - dtype=dtype[0], - shape=k_shape, - min_value=1e-3, - max_value=1e2, - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", - ) - ) - value = draw( - helpers.array_values( - dtype=dtype[0], - shape=v_shape, - min_value=1e-3, - max_value=1e2, - large_abs_safety_factor=7, - small_abs_safety_factor=7, - safety_factor_scale="linear", - ) - ) - mask = draw( - helpers.array_values( - dtype="bool", - shape=mask_shape, - ) - | st.none() - ) - return dtype, query, key, value, mask - - -# scaled_dot_product_attention -@handle_test( - fn_tree="functional.ivy.scaled_dot_product_attention", - dtype_q_k_v_mask=_x_and_scaled_attention( - dtypes=helpers.get_dtypes("float", full=False), - ), - scale=st.floats(min_value=0.1, max_value=1), - dropout_p=st.floats(min_value=0, max_value=0.99), - is_causal=st.booleans(), - training=st.just(False), # st.booleans(), disabled until proper testing is used - ground_truth_backend="jax", - test_with_out=st.just(True), -) -def test_scaled_dot_product_attention( - *, - dtype_q_k_v_mask, - scale, - dropout_p, - is_causal, - training, - test_flags, - backend_fw, - fn_name, - on_device, -): - (dtype, query, key, value, mask) = dtype_q_k_v_mask - is_causal = is_causal if mask is None else False - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, 
- fn_name=fn_name, - on_device=on_device, - atol_=1e-02, - rtol_=1e-02, - query=query, - key=key, - value=value, - scale=scale, - mask=mask, - dropout_p=dropout_p, - is_causal=is_causal, - training=training, - ) - - @st.composite def _mha_helper(draw): _qkv_same_dim = draw(st.booleans()) @@ -436,89 +237,19 @@ def _mha_helper(draw): ) -# multi_head_attention -@handle_test( - fn_tree="functional.ivy.multi_head_attention", - dtype_mha=_mha_helper(), - scale=st.one_of(st.floats(), st.none()), - dropout=st.floats(min_value=0, max_value=0.99), - training=st.just(False), # st.booleans(), disabled until proper testing is used - is_causal=st.booleans(), - return_attention_weights=st.booleans(), - average_attention_weights=st.booleans(), - ground_truth_backend="jax", -) -def test_multi_head_attention( - *, - dtype_mha, - scale, - dropout, - training, - is_causal, - return_attention_weights, - average_attention_weights, - test_flags, - backend_fw, - fn_name, - on_device, -): - ( - dtype, - q, - k, - v, - num_heads, - attention_mask, - in_proj_weights, - q_proj_weights, - k_proj_weights, - v_proj_weights, - out_proj_weights, - in_proj_bias, - out_proj_bias, - ) = dtype_mha - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - atol_=1e-02, - rtol_=1e-02, - query=q, - key=k, - value=v, - num_heads=num_heads, - scale=scale, - attention_mask=attention_mask, - in_proj_weights=in_proj_weights, - q_proj_weights=q_proj_weights, - k_proj_weights=k_proj_weights, - v_proj_weights=v_proj_weights, - out_proj_weights=out_proj_weights, - in_proj_bias=in_proj_bias, - out_proj_bias=out_proj_bias, - is_causal=is_causal, - return_attention_weights=return_attention_weights, - average_attention_weights=average_attention_weights, - dropout=dropout, - training=training, - ) - - -# Convolutions # -# -------------# - - -@st.composite -def _x_and_filters( - draw, - dim: int = 2, - transpose: bool = False, - depthwise=False, - general=False, - bias=False, - filter_format=None, +# Convolutions # +# -------------# + + +@st.composite +def _x_and_filters( + draw, + dim: int = 2, + transpose: bool = False, + depthwise=False, + general=False, + bias=False, + filter_format=None, ): if not isinstance(dim, int): dim = draw(dim) @@ -672,212 +403,284 @@ def _x_and_filters( return ret -def _assume_tf_dilation_gt_1(backend_fw, on_device, dilations): - if backend_fw == "tensorflow": - assume( - not ( - on_device == "cpu" and (dilations > 1) - if isinstance(dilations, int) - else any(d > 1 for d in dilations) - ) +# filter_format not in conv_general_transpose +# output_shape not in conv_general_dilated +@st.composite +def _x_and_filters_and_transpose( + draw, + dim: int = 2, + general=False, + bias=False, + filter_format=None, +): + transpose = draw(st.booleans()) + if not transpose: + filter_format = st.sampled_from(["channel_last", "channel_first"]) + all_args = draw( + _x_and_filters( + dim=dim, + general=general, + bias=bias, + filter_format=filter_format, + transpose=transpose, ) - - -# conv1d -@handle_test( - fn_tree="functional.ivy.conv1d", - x_f_d_df=_x_and_filters( - dim=1, - bias=True, - filter_format=st.sampled_from(["channel_last", "channel_first"]), - ), - ground_truth_backend="jax", -) -def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): - ( - dtype, - x, - filters, - dilations, - data_format, - stride, - pad, - fc, - ff_format, - bias, - ) = x_f_d_df - # ToDo: Enable gradient tests for dilations > 1 when tensorflow 
supports it. - _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-02, - atol_=1e-02, - x=x, - filters=filters, - strides=stride, - padding=pad, - data_format=data_format, - filter_format=ff_format, - x_dilations=dilations[1], - dilations=dilations[0], - bias=bias, ) - - -# conv1d_transpose -@handle_test( - fn_tree="functional.ivy.conv1d_transpose", - x_f_d_df=_x_and_filters( - dim=1, - transpose=True, - bias=True, - ), - ground_truth_backend="jax", -) -def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): - ( + output_shape = None + filter_format = "channel_last" + if transpose: + ( + dtype, + x, + filters, + dilations, + data_format, + stride, + pad, + output_shape, + fc, + bias, + ) = all_args + else: + ( + dtype, + x, + filters, + dilations, + data_format, + stride, + pad, + fc, + filter_format, + bias, + ) = all_args + return ( dtype, x, filters, - dilations, - data_format, stride, pad, + transpose, output_shape, + data_format, + filter_format, fc, + dilations, bias, - ) = x_f_d_df - _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - # tensorflow does not work with dilations > 1 on cpu - x=x, - filters=filters, - strides=stride, - padding=pad, - output_shape=output_shape, - data_format=data_format, - dilations=dilations[0], - bias=bias, ) -# conv2d -@handle_test( - fn_tree="functional.ivy.conv2d", - x_f_d_df=_x_and_filters( - dim=2, - bias=True, - filter_format=st.sampled_from(["channel_last", "channel_first"]), - ), - ground_truth_backend="jax", -) -def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): - ( - dtype, - x, - filters, - dilations, - data_format, - stride, - pad, - fc, - ff_format, - bias, - ) = x_f_d_df - # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it. 
- _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - on_device=on_device, - rtol_=1e-2, - atol_=1e-2, - x=x, - filters=filters, - strides=stride, - padding=pad, - data_format=data_format, - filter_format=ff_format, - x_dilations=dilations[1], - dilations=dilations[0], - bias=bias, +# Linear # +# -------# +@st.composite +def _x_and_linear(draw): + mixed_fn_compos = draw(st.booleans()) + is_torch_backend = ivy.current_backend_str() == "torch" + dtype = draw( + helpers.get_dtypes("numeric", full=False, mixed_fn_compos=mixed_fn_compos) + ) + in_features = draw( + helpers.ints(min_value=1, max_value=2, mixed_fn_compos=mixed_fn_compos) + ) + out_features = draw( + helpers.ints(min_value=1, max_value=2, mixed_fn_compos=mixed_fn_compos) ) - -# conv2d_transpose -@handle_test( - fn_tree="functional.ivy.conv2d_transpose", - x_f_d_df=_x_and_filters( - dim=2, - transpose=True, + x_shape = ( + 1, + 1, + in_features, + ) + + weight_shape = (1,) + (out_features,) + (in_features,) + # if backend is torch and we're testing the primary implementation + # weight.ndim should be equal to 2 + if is_torch_backend and not mixed_fn_compos: + weight_shape = (out_features,) + (in_features,) + + bias_shape = ( + 1, + out_features, + ) + + x = draw( + helpers.array_values(dtype=dtype[0], shape=x_shape, min_value=0, max_value=10) + ) + weight = draw( + helpers.array_values( + dtype=dtype[0], shape=weight_shape, min_value=0, max_value=10 + ) + ) + bias = draw( + helpers.array_values( + dtype=dtype[0], shape=bias_shape, min_value=0, max_value=10 + ) + ) + return dtype, x, weight, bias + + +# LSTM # +# -----# + + +@st.composite +def _x_and_lstm(draw, dtypes): + dtype = draw(dtypes) + batch_shape = (1,) + + t = draw(helpers.ints(min_value=1, max_value=2)) + _in_ = draw(helpers.ints(min_value=1, max_value=2)) + _out_ = draw(helpers.ints(min_value=1, max_value=2)) + + x_lstm_shape = batch_shape + (t,) + (_in_,) + init_h_shape = batch_shape + (_out_,) + init_c_shape = init_h_shape + kernel_shape = (_in_,) + (4 * _out_,) + recurrent_kernel_shape = (_out_,) + (4 * _out_,) + bias_shape = (4 * _out_,) + recurrent_bias_shape = bias_shape + + x_lstm = draw( + helpers.array_values( + dtype=dtype[0], shape=x_lstm_shape, min_value=0, max_value=1 + ) + ) + init_h = draw( + helpers.array_values( + dtype=dtype[0], shape=init_h_shape, min_value=0, max_value=1 + ) + ) + init_c = draw( + helpers.array_values( + dtype=dtype[0], shape=init_c_shape, min_value=0, max_value=1 + ) + ) + kernel = draw( + helpers.array_values( + dtype=dtype[0], shape=kernel_shape, min_value=0, max_value=1 + ) + ) + recurrent_kernel = draw( + helpers.array_values( + dtype=dtype[0], shape=recurrent_kernel_shape, min_value=0, max_value=1 + ) + ) + lstm_bias = draw( + helpers.array_values(dtype=dtype[0], shape=bias_shape, min_value=0, max_value=1) + ) + recurrent_bias = draw( + helpers.array_values( + dtype=dtype[0], shape=recurrent_bias_shape, min_value=0, max_value=1 + ) + ) + return ( + dtype, + x_lstm, + init_h, + init_c, + kernel, + recurrent_kernel, + lstm_bias, + recurrent_bias, + ) + + +# Attention # +# ----------# + + +@st.composite +def _x_and_scaled_attention(draw, dtypes): + dtype = draw(dtypes) + num_queries = draw(helpers.ints(min_value=2, max_value=4)) + num_keys = draw(helpers.ints(min_value=2, max_value=4)) + feat_dim = draw(helpers.ints(min_value=2, max_value=4)) + batch_size = draw(helpers.ints(min_value=1, 
max_value=2)) + q_shape = (batch_size,) + (num_queries,) + (feat_dim,) + k_shape = (batch_size,) + (num_keys,) + (feat_dim,) + v_shape = (batch_size,) + (num_keys,) + (feat_dim,) + mask_shape = (batch_size,) + (num_queries,) + (num_keys,) + + query = draw( + helpers.array_values( + dtype=dtype[0], + shape=q_shape, + min_value=1e-3, + max_value=1e2, + large_abs_safety_factor=7, + small_abs_safety_factor=7, + safety_factor_scale="linear", + ) + ) + key = draw( + helpers.array_values( + dtype=dtype[0], + shape=k_shape, + min_value=1e-3, + max_value=1e2, + large_abs_safety_factor=7, + small_abs_safety_factor=7, + safety_factor_scale="linear", + ) + ) + value = draw( + helpers.array_values( + dtype=dtype[0], + shape=v_shape, + min_value=1e-3, + max_value=1e2, + large_abs_safety_factor=7, + small_abs_safety_factor=7, + safety_factor_scale="linear", + ) + ) + mask = draw( + helpers.array_values( + dtype="bool", + shape=mask_shape, + ) + | st.none() + ) + return dtype, query, key, value, mask + + +# --- Main --- # +# ------------ # + + +# conv +@handle_test( + fn_tree="functional.ivy.conv", + dims=st.shared(st.integers(1, 3), key="dims"), + x_f_d_df_tr=_x_and_filters_and_transpose( + dim=st.shared(st.integers(1, 3), key="dims"), + general=True, bias=True, ), - # tensorflow does not work with dilations > 1 on cpu ground_truth_backend="jax", ) -def test_conv2d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): +def test_conv(*, dims, x_f_d_df_tr, test_flags, backend_fw, fn_name, on_device): + # pass ( dtype, x, filters, - dilations, - data_format, stride, pad, + transpose, output_shape, + data_format, + filter_format, fc, + dilations, bias, - ) = x_f_d_df - _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) - - helpers.test_function( - input_dtypes=dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - fn_name=fn_name, - rtol_=1e-2, - atol_=1e-2, - on_device=on_device, - x=x, - filters=filters, - strides=stride, - padding=pad, - output_shape=output_shape, - data_format=data_format, - dilations=dilations[0], - bias=bias, - ) - - -# depthwise_conv2d -@handle_test( - fn_tree="functional.ivy.depthwise_conv2d", - x_f_d_df=_x_and_filters( - dim=2, - depthwise=True, - ), - # tensorflow does not support dilations > 1 and stride > 1 - ground_truth_backend="jax", -) -def test_depthwise_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): - dtype, x, filters, dilations, data_format, stride, pad, fc = x_f_d_df - _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) - # tensorflow only supports equal length strides in row and column - if backend_fw == "tensorflow" and isinstance(stride, list) and len(stride) > 1: - assume(stride[0] == stride[1]) + ) = x_f_d_df_tr + tf_dilations = dilations + if not transpose: + tf_dilations = tf_dilations[0] + dilations, x_dilations = dilations + else: + x_dilations = None + _assume_tf_dilation_gt_1(backend_fw, on_device, tf_dilations) helpers.test_function( input_dtypes=dtype, test_flags=test_flags, @@ -890,22 +693,29 @@ def test_depthwise_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic filters=filters, strides=stride, padding=pad, + transpose=transpose, + dims=dims, + output_shape=output_shape, data_format=data_format, + filter_format=filter_format, + feature_group_count=fc, + x_dilations=x_dilations, dilations=dilations, + bias=bias, ) -# conv3d +# conv1d @handle_test( - fn_tree="functional.ivy.conv3d", + fn_tree="functional.ivy.conv1d", x_f_d_df=_x_and_filters( - dim=3, + dim=1, bias=True, 
filter_format=st.sampled_from(["channel_last", "channel_first"]), ), ground_truth_backend="jax", ) -def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): +def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): ( dtype, x, @@ -918,6 +728,7 @@ def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): ff_format, bias, ) = x_f_d_df + # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it. _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) helpers.test_function( input_dtypes=dtype, @@ -925,8 +736,8 @@ def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): backend_to_test=backend_fw, fn_name=fn_name, on_device=on_device, - rtol_=1e-2, - atol_=1e-2, + rtol_=1e-02, + atol_=1e-02, x=x, filters=filters, strides=stride, @@ -939,17 +750,17 @@ def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): ) -# conv3d_transpose +# conv1d_transpose @handle_test( - fn_tree="functional.ivy.conv3d_transpose", + fn_tree="functional.ivy.conv1d_transpose", x_f_d_df=_x_and_filters( - dim=3, + dim=1, transpose=True, bias=True, ), ground_truth_backend="jax", ) -def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): +def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): ( dtype, x, @@ -971,6 +782,7 @@ def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic on_device=on_device, rtol_=1e-2, atol_=1e-2, + # tensorflow does not work with dilations > 1 on cpu x=x, filters=filters, strides=stride, @@ -982,21 +794,17 @@ def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic ) -# conv_general_dilated +# conv2d @handle_test( - fn_tree="functional.ivy.conv_general_dilated", - dims=st.shared(st.integers(1, 3), key="dims"), + fn_tree="functional.ivy.conv2d", x_f_d_df=_x_and_filters( - dim=st.shared(st.integers(1, 3), key="dims"), - general=True, + dim=2, bias=True, filter_format=st.sampled_from(["channel_last", "channel_first"]), ), ground_truth_backend="jax", ) -def test_conv_general_dilated( - *, dims, x_f_d_df, test_flags, backend_fw, fn_name, on_device -): +def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): ( dtype, x, @@ -1009,6 +817,7 @@ def test_conv_general_dilated( ff_format, bias, ) = x_f_d_df + # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it. 
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) helpers.test_function( input_dtypes=dtype, @@ -1022,30 +831,26 @@ def test_conv_general_dilated( filters=filters, strides=stride, padding=pad, - dims=dims, data_format=data_format, filter_format=ff_format, - feature_group_count=fc, x_dilations=dilations[1], dilations=dilations[0], bias=bias, ) +# conv2d_transpose @handle_test( - fn_tree="functional.ivy.conv_general_transpose", - dims=st.shared(st.integers(1, 3), key="dims"), + fn_tree="functional.ivy.conv2d_transpose", x_f_d_df=_x_and_filters( - dim=st.shared(st.integers(1, 3), key="dims"), - general=True, + dim=2, transpose=True, bias=True, ), + # tensorflow does not work with dilations > 1 on cpu ground_truth_backend="jax", ) -def test_conv_general_transpose( - *, dims, x_f_d_df, test_flags, backend_fw, fn_name, on_device -): +def test_conv2d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): ( dtype, x, @@ -1058,128 +863,142 @@ def test_conv_general_transpose( fc, bias, ) = x_f_d_df - _assume_tf_dilation_gt_1(backend_fw, on_device, dilations) + _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) + helpers.test_function( input_dtypes=dtype, test_flags=test_flags, backend_to_test=backend_fw, fn_name=fn_name, + rtol_=1e-2, + atol_=1e-2, on_device=on_device, - rtol_=1e-1, - atol_=1e-1, x=x, filters=filters, strides=stride, padding=pad, - dims=dims, output_shape=output_shape, data_format=data_format, - dilations=dilations, - feature_group_count=fc, + dilations=dilations[0], bias=bias, ) -# filter_format not in conv_general_transpose -# output_shape not in conv_general_dilated -@st.composite -def _x_and_filters_and_transpose( - draw, - dim: int = 2, - general=False, - bias=False, - filter_format=None, -): - transpose = draw(st.booleans()) - if not transpose: - filter_format = st.sampled_from(["channel_last", "channel_first"]) - all_args = draw( - _x_and_filters( - dim=dim, - general=general, - bias=bias, - filter_format=filter_format, - transpose=transpose, - ) +# conv3d +@handle_test( + fn_tree="functional.ivy.conv3d", + x_f_d_df=_x_and_filters( + dim=3, + bias=True, + filter_format=st.sampled_from(["channel_last", "channel_first"]), + ), + ground_truth_backend="jax", +) +def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): + ( + dtype, + x, + filters, + dilations, + data_format, + stride, + pad, + fc, + ff_format, + bias, + ) = x_f_d_df + _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + x=x, + filters=filters, + strides=stride, + padding=pad, + data_format=data_format, + filter_format=ff_format, + x_dilations=dilations[1], + dilations=dilations[0], + bias=bias, ) - output_shape = None - filter_format = "channel_last" - if transpose: - ( - dtype, - x, - filters, - dilations, - data_format, - stride, - pad, - output_shape, - fc, - bias, - ) = all_args - else: - ( - dtype, - x, - filters, - dilations, - data_format, - stride, - pad, - fc, - filter_format, - bias, - ) = all_args - return ( + + +# conv3d_transpose +@handle_test( + fn_tree="functional.ivy.conv3d_transpose", + x_f_d_df=_x_and_filters( + dim=3, + transpose=True, + bias=True, + ), + ground_truth_backend="jax", +) +def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): + ( dtype, x, filters, + dilations, + data_format, stride, pad, - transpose, output_shape, - 
data_format, - filter_format, fc, - dilations, bias, + ) = x_f_d_df + _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + x=x, + filters=filters, + strides=stride, + padding=pad, + output_shape=output_shape, + data_format=data_format, + dilations=dilations[0], + bias=bias, ) -# conv +# conv_general_dilated @handle_test( - fn_tree="functional.ivy.conv", + fn_tree="functional.ivy.conv_general_dilated", dims=st.shared(st.integers(1, 3), key="dims"), - x_f_d_df_tr=_x_and_filters_and_transpose( + x_f_d_df=_x_and_filters( dim=st.shared(st.integers(1, 3), key="dims"), general=True, bias=True, + filter_format=st.sampled_from(["channel_last", "channel_first"]), ), ground_truth_backend="jax", ) -def test_conv(*, dims, x_f_d_df_tr, test_flags, backend_fw, fn_name, on_device): - # pass +def test_conv_general_dilated( + *, dims, x_f_d_df, test_flags, backend_fw, fn_name, on_device +): ( dtype, x, filters, + dilations, + data_format, stride, pad, - transpose, - output_shape, - data_format, - filter_format, fc, - dilations, + ff_format, bias, - ) = x_f_d_df_tr - tf_dilations = dilations - if not transpose: - tf_dilations = tf_dilations[0] - dilations, x_dilations = dilations - else: - x_dilations = None - _assume_tf_dilation_gt_1(backend_fw, on_device, tf_dilations) + ) = x_f_d_df + _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) helpers.test_function( input_dtypes=dtype, test_flags=test_flags, @@ -1192,81 +1011,154 @@ def test_conv(*, dims, x_f_d_df_tr, test_flags, backend_fw, fn_name, on_device): filters=filters, strides=stride, padding=pad, - transpose=transpose, dims=dims, - output_shape=output_shape, data_format=data_format, - filter_format=filter_format, + filter_format=ff_format, feature_group_count=fc, - x_dilations=x_dilations, - dilations=dilations, + x_dilations=dilations[1], + dilations=dilations[0], bias=bias, ) -# LSTM # -# -----# - +@handle_test( + fn_tree="functional.ivy.conv_general_transpose", + dims=st.shared(st.integers(1, 3), key="dims"), + x_f_d_df=_x_and_filters( + dim=st.shared(st.integers(1, 3), key="dims"), + general=True, + transpose=True, + bias=True, + ), + ground_truth_backend="jax", +) +def test_conv_general_transpose( + *, dims, x_f_d_df, test_flags, backend_fw, fn_name, on_device +): + ( + dtype, + x, + filters, + dilations, + data_format, + stride, + pad, + output_shape, + fc, + bias, + ) = x_f_d_df + _assume_tf_dilation_gt_1(backend_fw, on_device, dilations) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-1, + atol_=1e-1, + x=x, + filters=filters, + strides=stride, + padding=pad, + dims=dims, + output_shape=output_shape, + data_format=data_format, + dilations=dilations, + feature_group_count=fc, + bias=bias, + ) -@st.composite -def _x_and_lstm(draw, dtypes): - dtype = draw(dtypes) - batch_shape = (1,) - t = draw(helpers.ints(min_value=1, max_value=2)) - _in_ = draw(helpers.ints(min_value=1, max_value=2)) - _out_ = draw(helpers.ints(min_value=1, max_value=2)) +# depthwise_conv2d +@handle_test( + fn_tree="functional.ivy.depthwise_conv2d", + x_f_d_df=_x_and_filters( + dim=2, + depthwise=True, + ), + # tensorflow does not support dilations > 1 and stride > 1 + ground_truth_backend="jax", +) +def test_depthwise_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device): 
+ dtype, x, filters, dilations, data_format, stride, pad, fc = x_f_d_df + _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0]) + # tensorflow only supports equal length strides in row and column + if backend_fw == "tensorflow" and isinstance(stride, list) and len(stride) > 1: + assume(stride[0] == stride[1]) + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-2, + atol_=1e-2, + x=x, + filters=filters, + strides=stride, + padding=pad, + data_format=data_format, + dilations=dilations, + ) - x_lstm_shape = batch_shape + (t,) + (_in_,) - init_h_shape = batch_shape + (_out_,) - init_c_shape = init_h_shape - kernel_shape = (_in_,) + (4 * _out_,) - recurrent_kernel_shape = (_out_,) + (4 * _out_,) - bias_shape = (4 * _out_,) - recurrent_bias_shape = bias_shape - x_lstm = draw( - helpers.array_values( - dtype=dtype[0], shape=x_lstm_shape, min_value=0, max_value=1 - ) - ) - init_h = draw( - helpers.array_values( - dtype=dtype[0], shape=init_h_shape, min_value=0, max_value=1 - ) - ) - init_c = draw( - helpers.array_values( - dtype=dtype[0], shape=init_c_shape, min_value=0, max_value=1 - ) - ) - kernel = draw( - helpers.array_values( - dtype=dtype[0], shape=kernel_shape, min_value=0, max_value=1 - ) - ) - recurrent_kernel = draw( - helpers.array_values( - dtype=dtype[0], shape=recurrent_kernel_shape, min_value=0, max_value=1 - ) - ) - lstm_bias = draw( - helpers.array_values(dtype=dtype[0], shape=bias_shape, min_value=0, max_value=1) +# dropout +@handle_test( + fn_tree="functional.ivy.dropout", + data=_dropout_helper(), + test_gradients=st.just(False), +) +def test_dropout( + *, + data, + test_flags, + backend_fw, + fn_name, + on_device, +): + (x_dtype, x), noise_shape, seed, dtype, prob, scale, training = data + ret, gt_ret = helpers.test_function( + input_dtypes=x_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + test_values=False, + x=x[0], + prob=prob, + scale=scale, + noise_shape=noise_shape, + dtype=dtype[0], + training=training, + seed=seed, ) - recurrent_bias = draw( - helpers.array_values( - dtype=dtype[0], shape=recurrent_bias_shape, min_value=0, max_value=1 - ) + ret = helpers.flatten_and_to_np(ret=ret, backend=backend_fw) + gt_ret = helpers.flatten_and_to_np( + ret=gt_ret, backend=test_flags.ground_truth_backend ) - return ( - dtype, - x_lstm, - init_h, - init_c, - kernel, - recurrent_kernel, - lstm_bias, - recurrent_bias, + for u, v, w in zip(ret, gt_ret, x): + # cardinality test + assert u.shape == v.shape == w.shape + + +# linear +@handle_test( + fn_tree="functional.ivy.linear", + dtype_x_weight_bias=_x_and_linear(), +) +def test_linear(*, dtype_x_weight_bias, test_flags, backend_fw, fn_name, on_device): + dtype, x, weight, bias = dtype_x_weight_bias + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-02, + atol_=1e-02, + x=x, + weight=weight, + bias=bias, ) @@ -1305,3 +1197,119 @@ def test_lstm_update(*, dtype_lstm, test_flags, backend_fw, fn_name, on_device): bias=bias, recurrent_bias=recurrent_bias, ) + + +# multi_head_attention +@handle_test( + fn_tree="functional.ivy.multi_head_attention", + dtype_mha=_mha_helper(), + scale=st.one_of(st.floats(), st.none()), + dropout=st.floats(min_value=0, max_value=0.99), + training=st.just(False), # st.booleans(), disabled until proper testing is used + is_causal=st.booleans(), + 
return_attention_weights=st.booleans(), + average_attention_weights=st.booleans(), + ground_truth_backend="jax", +) +def test_multi_head_attention( + *, + dtype_mha, + scale, + dropout, + training, + is_causal, + return_attention_weights, + average_attention_weights, + test_flags, + backend_fw, + fn_name, + on_device, +): + ( + dtype, + q, + k, + v, + num_heads, + attention_mask, + in_proj_weights, + q_proj_weights, + k_proj_weights, + v_proj_weights, + out_proj_weights, + in_proj_bias, + out_proj_bias, + ) = dtype_mha + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + atol_=1e-02, + rtol_=1e-02, + query=q, + key=k, + value=v, + num_heads=num_heads, + scale=scale, + attention_mask=attention_mask, + in_proj_weights=in_proj_weights, + q_proj_weights=q_proj_weights, + k_proj_weights=k_proj_weights, + v_proj_weights=v_proj_weights, + out_proj_weights=out_proj_weights, + in_proj_bias=in_proj_bias, + out_proj_bias=out_proj_bias, + is_causal=is_causal, + return_attention_weights=return_attention_weights, + average_attention_weights=average_attention_weights, + dropout=dropout, + training=training, + ) + + +# scaled_dot_product_attention +@handle_test( + fn_tree="functional.ivy.scaled_dot_product_attention", + dtype_q_k_v_mask=_x_and_scaled_attention( + dtypes=helpers.get_dtypes("float", full=False), + ), + scale=st.floats(min_value=0.1, max_value=1), + dropout_p=st.floats(min_value=0, max_value=0.99), + is_causal=st.booleans(), + training=st.just(False), # st.booleans(), disabled until proper testing is used + ground_truth_backend="jax", + test_with_out=st.just(True), +) +def test_scaled_dot_product_attention( + *, + dtype_q_k_v_mask, + scale, + dropout_p, + is_causal, + training, + test_flags, + backend_fw, + fn_name, + on_device, +): + (dtype, query, key, value, mask) = dtype_q_k_v_mask + is_causal = is_causal if mask is None else False + helpers.test_function( + input_dtypes=dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + atol_=1e-02, + rtol_=1e-02, + query=query, + key=key, + value=value, + scale=scale, + mask=mask, + dropout_p=dropout_p, + is_causal=is_causal, + training=training, + ) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py b/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py index 069b392630ce1..bba4b09173af4 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py @@ -6,56 +6,6 @@ from ivy_tests.test_ivy.helpers import handle_test -# cross_entropy -@handle_test( - fn_tree="functional.ivy.cross_entropy", - dtype_true_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("integer"), - min_value=1e-04, - max_value=1, - allow_inf=False, - valid_axis=True, - allow_neg_axes=True, - force_int_axis=True, - ), - dtype_and_pred=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1e-04, - max_value=1, - allow_inf=False, - ), - reduction=st.sampled_from(["none", "sum", "mean"]), - epsilon=helpers.floats(min_value=0.0, max_value=1.0), -) -def test_cross_entropy( - dtype_true_axis, - dtype_and_pred, - reduction, - epsilon, - test_flags, - backend_fw, - fn_name, - on_device, -): - pred_dtype, pred = dtype_and_pred - true_dtype, true, axis = dtype_true_axis - - helpers.test_function( - input_dtypes=true_dtype + pred_dtype, - test_flags=test_flags, - backend_to_test=backend_fw, - 
fn_name=fn_name, - on_device=on_device, - rtol_=1e-02, - atol_=1e-02, - true=true[0], - pred=pred[0], - axis=axis, - epsilon=epsilon, - reduction=reduction, - ) - - # binary_cross_entropy @handle_test( fn_tree="functional.ivy.binary_cross_entropy", @@ -150,6 +100,56 @@ def test_binary_cross_entropy( ) +# cross_entropy +@handle_test( + fn_tree="functional.ivy.cross_entropy", + dtype_true_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("integer"), + min_value=1e-04, + max_value=1, + allow_inf=False, + valid_axis=True, + allow_neg_axes=True, + force_int_axis=True, + ), + dtype_and_pred=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=1e-04, + max_value=1, + allow_inf=False, + ), + reduction=st.sampled_from(["none", "sum", "mean"]), + epsilon=helpers.floats(min_value=0.0, max_value=1.0), +) +def test_cross_entropy( + dtype_true_axis, + dtype_and_pred, + reduction, + epsilon, + test_flags, + backend_fw, + fn_name, + on_device, +): + pred_dtype, pred = dtype_and_pred + true_dtype, true, axis = dtype_true_axis + + helpers.test_function( + input_dtypes=true_dtype + pred_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + rtol_=1e-02, + atol_=1e-02, + true=true[0], + pred=pred[0], + axis=axis, + epsilon=epsilon, + reduction=reduction, + ) + + # sparse_cross_entropy @handle_test( fn_tree="functional.ivy.sparse_cross_entropy", diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_norms.py b/ivy_tests/test_ivy/test_functional/test_nn/test_norms.py index ba2eb62e73eb8..c37fef19ea173 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_norms.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_norms.py @@ -8,6 +8,10 @@ from ivy_tests.test_ivy.helpers import handle_test +# --- Helpers --- # +# --------------- # + + @st.composite def _generate_data_layer_norm( draw, @@ -84,6 +88,10 @@ def _generate_data_layer_norm( return dtype, values, normalized_idxs, weight_values, bias_values +# --- Main --- # +# ------------ # + + @handle_test( fn_tree="functional.ivy.layer_norm", values_tuple=_generate_data_layer_norm( diff --git a/ivy_tests/test_ivy/test_misc/test_array.py b/ivy_tests/test_ivy/test_misc/test_array.py index 6b2cf18646625..93b15c77a456e 100644 --- a/ivy_tests/test_ivy/test_misc/test_array.py +++ b/ivy_tests/test_ivy/test_misc/test_array.py @@ -18,335 +18,253 @@ CLASS_TREE = "ivy.array" -def test_array_function(): - HANDLED_FUNCTIONS = {} - - class MyArray: - def __init__(self, data=None): - self.data = data - - def __ivy_array_function__(self, func, types, args, kwargs): - if func not in HANDLED_FUNCTIONS: - return NotImplemented - if not all( - issubclass(t, (MyArray, ivy.Array, ivy.NativeArray)) for t in types - ): - return NotImplemented - return HANDLED_FUNCTIONS[func](*args, **kwargs) - - def implements(ivy_function): - """Register an __ivy_array_function__ implementation for MyArray objects.""" - - def decorator(func): - HANDLED_FUNCTIONS[ivy_function] = func - return func - - return decorator - - @implements(ivy.abs) - def _(my_array, ivy_array): - my_array.data = abs(my_array.data) - ivy_array = ivy.abs(ivy_array) - return (my_array, ivy_array) - - x = MyArray(-3) - y = ivy.array([1, -1]) - xy = _(x, ivy_array=y) - x1 = xy[0] - y1 = xy[1] - assert x1.data == 3 - assert all(y1 == ivy.array([1, 1])) - - -# TODO: avoid using dummy fn_tree in property tests - - -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - 
dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), -) -def test_array_property_data( - dtype_x, - backend_fw, - test_flags, -): - _, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ret = helpers.flatten_and_to_np(ret=x.data, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np(ret=data, backend=backend_fw) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) - - -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), -) -def test_array_property_dtype( - dtype_x, - backend_fw, -): - _, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ivy_backend.utils.assertions.check_equal( - x.dtype, ivy_backend.dtype(data), as_array=False - ) - - -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), -) -def test_array_property_device( - dtype_x, - backend_fw, -): - _, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ivy_backend.utils.assertions.check_equal( - x.device, ivy_backend.dev(data), as_array=False - ) - - -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values( +@handle_method( + init_tree=CLASS_TREE, + method_tree="Array.__abs__", + dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - ret_shape=True, + large_abs_safety_factor=1.5, + small_abs_safety_factor=1.5, + safety_factor_scale="log", ), ) -def test_array_property_ndim( - dtype_x, +def test_array__abs__( + dtype_and_x, + method_name, + class_name, + ground_truth_backend, backend_fw, + init_flags, + method_flags, + on_device, ): - _, data, input_shape = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ivy_backend.utils.assertions.check_equal( - x.ndim, len(input_shape), as_array=False - ) + dtype, x = dtype_and_x + helpers.test_method( + backend_to_test=backend_fw, + on_device=on_device, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=[], + method_all_as_kwargs_np={}, + class_name=class_name, + method_name=method_name, + ) -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ret_shape=True, +@handle_method( + init_tree=CLASS_TREE, + method_tree="Array.__add__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_array_property_shape( - dtype_x, +def test_array__add__( + dtype_and_x, + method_name, + class_name, + ground_truth_backend, backend_fw, + init_flags, + method_flags, + on_device, ): - _, data, input_shape = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = 
ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ivy_backend.utils.assertions.check_equal( - x.shape, ivy_backend.Shape(input_shape), as_array=False - ) + dtype, x = dtype_and_x + helpers.test_method( + backend_to_test=backend_fw, + on_device=on_device, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, + class_name=class_name, + method_name=method_name, + ) -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ret_shape=True, - min_num_dims=1, +@handle_method( + init_tree=CLASS_TREE, + method_tree="Array.__and__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + num_arrays=2, + shared_dtype=True, ), ) -def test_array_property_size( - dtype_x, +def test_array__and__( + dtype_and_x, + method_name, + class_name, + ground_truth_backend, backend_fw, + init_flags, + method_flags, + on_device, ): - _, data, input_shape = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - size_gt = 1 - for dim in input_shape: - size_gt *= dim - ivy_backend.utils.assertions.check_equal(x.size, size_gt, as_array=False) + dtype, x = dtype_and_x + helpers.test_method( + backend_to_test=backend_fw, + on_device=on_device, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, + class_name=class_name, + method_name=method_name, + ) -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), +@handle_method( + init_tree=CLASS_TREE, + method_tree="Array.__bool__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + max_num_dims=0, + min_value=0, + max_value=1, ), ) -def test_array_property_itemsize( - dtype_x, +def test_array__bool__( + dtype_and_x, + method_name, + class_name, + ground_truth_backend, backend_fw, + init_flags, + method_flags, + on_device, ): - dtype, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ivy_backend.utils.assertions.check_equal( - x.itemsize, ivy_backend.to_numpy(x).itemsize, as_array=False - ) - - -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - ), -) -def test_array_property_strides(dtype_x, backend_fw): - dtype, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ivy_backend.utils.assertions.check_equal( - x.strides, ivy_backend.to_numpy(x).strides, as_array=False - ) + dtype, x = dtype_and_x + helpers.test_method( + backend_to_test=backend_fw, + on_device=on_device, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=[], + 
method_all_as_kwargs_np={}, + class_name=class_name, + method_name=method_name, + ) -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, +@handle_method( + init_tree=CLASS_TREE, + method_tree="Array.__complex__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + max_num_dims=0, ), + method_container_flags=st.just([False]), ) -def test_array_property_mT( - dtype_x, +def test_array__complex__( + dtype_and_x, + method_name, + class_name, + ground_truth_backend, backend_fw, - test_flags, + init_flags, + method_flags, + on_device, ): - _, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ret = helpers.flatten_and_to_np(ret=x.mT, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np( - ret=ivy_backend.matrix_transpose(data), backend=backend_fw - ) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) + dtype, x = dtype_and_x + helpers.test_method( + backend_to_test=backend_fw, + on_device=on_device, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=[], + method_all_as_kwargs_np={}, + class_name=class_name, + method_name=method_name, + ) -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=2, - max_num_dims=2, +@handle_method( + init_tree=CLASS_TREE, + method_tree="Array.__deepcopy__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), ), ) -def test_array_property_T( - dtype_x, - backend_fw, - test_flags, -): - _, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ret = helpers.flatten_and_to_np(ret=x.T, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np( - ret=ivy_backend.matrix_transpose(data), backend=backend_fw - ) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) - - -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("complex")), -) -def test_array_property_real( - dtype_x, - backend_fw, - test_flags, -): - _, data = dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ret = helpers.flatten_and_to_np(ret=x.real, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np(ret=ivy_backend.real(x), backend=backend_fw) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) - - -@handle_test( - fn_tree="functional.ivy.native_array", # dummy fn_tree - dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("complex")), -) -def test_array_property_imag( - dtype_x, +def test_array__deepcopy__( + dtype_and_x, + method_name, + class_name, + ground_truth_backend, backend_fw, - test_flags, + init_flags, + method_flags, + on_device, ): - _, data = 
dtype_x - with BackendHandler.update_backend(backend_fw) as ivy_backend: - data = ivy_backend.native_array(data[0]) - x = ivy_backend.Array(data) - ret = helpers.flatten_and_to_np(ret=x.imag, backend=backend_fw) - ret_gt = helpers.flatten_and_to_np(ret=ivy_backend.imag(x), backend=backend_fw) - helpers.value_test( - ret_np_flat=ret, - ret_np_from_gt_flat=ret_gt, - backend=backend_fw, - ground_truth_backend=test_flags.ground_truth_backend, - ) + dtype, x = dtype_and_x + helpers.test_method( + backend_to_test=backend_fw, + on_device=on_device, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=[], + method_all_as_kwargs_np={"memodict": {}}, + class_name=class_name, + method_name=method_name, + ) @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__getitem__", - ground_truth_backend="numpy", - dtypes_x_query=helpers.dtype_array_query( - available_dtypes=helpers.get_dtypes("valid"), + method_tree="Array.__divmod__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_array__getitem__( - dtypes_x_query, - init_flags, - method_flags, +def test_array__divmod__( + dtype_and_x, method_name, class_name, - backend_fw, ground_truth_backend, + backend_fw, + init_flags, + method_flags, on_device, ): - dtypes, x, query = dtypes_x_query + dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x}, - init_input_dtypes=[dtypes[0]], - method_input_dtypes=[*dtypes[1:]], - method_all_as_kwargs_np={"query": query}, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -354,16 +272,15 @@ def test_array__getitem__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__setitem__", - ground_truth_backend="numpy", - dtypes_x_query_val=helpers.dtype_array_query_val( - available_dtypes=helpers.get_dtypes("valid"), + method_tree="Array.__eq__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, ), - # ToDo: fix container method - method_container_flags=st.just([False]), ) -def test_array__setitem__( - dtypes_x_query_val, +def test_array__eq__( + dtype_and_x, method_name, class_name, ground_truth_backend, @@ -372,17 +289,17 @@ def test_array__setitem__( method_flags, on_device, ): - dtypes, x, query, val = dtypes_x_query_val + dtype, x = dtype_and_x helpers.test_method( backend_to_test=backend_fw, on_device=on_device, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x}, - init_input_dtypes=[dtypes[0]], - method_input_dtypes=[*dtypes[1:]], - method_all_as_kwargs_np={"query": query, "val": val}, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -390,12 +307,13 @@ def test_array__setitem__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__pos__", + 
method_tree="Array.__float__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + max_num_dims=0, ), ) -def test_array__pos__( +def test_array__float__( dtype_and_x, method_name, class_name, @@ -414,7 +332,7 @@ def test_array__pos__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, @@ -423,12 +341,17 @@ def test_array__pos__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__neg__", + method_tree="Array.__floordiv__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=3.0, + small_abs_safety_factor=3.0, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_array__neg__( +def test_array__floordiv__( dtype_and_x, method_name, class_name, @@ -439,6 +362,7 @@ def test_array__neg__( on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -448,7 +372,7 @@ def test_array__neg__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -456,10 +380,14 @@ def test_array__neg__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__pow__", - dtype_and_x=pow_helper(), + method_tree="Array.__ge__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), ) -def test_array__pow__( +def test_array__ge__( dtype_and_x, method_name, class_name, @@ -469,22 +397,7 @@ def test_array__pow__( method_flags, on_device, ): - input_dtype, x = dtype_and_x - - # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) - - # Make sure x2 isn't a float when x1 is integer - assume( - not (ivy.is_int_dtype(input_dtype[0] and ivy.is_float_dtype(input_dtype[1]))) - ) - - # Make sure x2 is non-negative when both is integer - if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]): - x[1] = np.abs(x[1]) - - x[0] = not_too_close_to_zero(x[0]) - x[1] = not_too_close_to_zero(x[1]) + dtype, x = dtype_and_x helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -492,9 +405,9 @@ def test_array__pow__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[input_dtype[0]], - method_input_dtypes=[input_dtype[1]], - method_all_as_kwargs_np={"power": x[1]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -502,45 +415,33 @@ def test_array__pow__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rpow__", - dtype_and_x=pow_helper(), + method_tree="Array.__getitem__", + ground_truth_backend="numpy", + dtypes_x_query=helpers.dtype_array_query( + available_dtypes=helpers.get_dtypes("valid"), + ), ) -def test_array__rpow__( - dtype_and_x, +def test_array__getitem__( + dtypes_x_query, + init_flags, + method_flags, method_name, class_name, - ground_truth_backend, backend_fw, - init_flags, - method_flags, + ground_truth_backend, on_device, ): - input_dtype, x = dtype_and_x - - # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) - - # Make sure x2 isn't a float when 
x1 is integer - assume( - not (ivy.is_int_dtype(input_dtype[0] and ivy.is_float_dtype(input_dtype[1]))) - ) - - # Make sure x2 is non-negative when both is integer - if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]): - x[1] = np.abs(x[1]) - - x[0] = not_too_close_to_zero(x[0]) - x[1] = not_too_close_to_zero(x[1]) + dtypes, x, query = dtypes_x_query helpers.test_method( backend_to_test=backend_fw, on_device=on_device, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x[1]}, - init_input_dtypes=[input_dtype[1]], - method_input_dtypes=[input_dtype[0]], - method_all_as_kwargs_np={"power": x[0]}, + init_all_as_kwargs_np={"data": x}, + init_input_dtypes=[dtypes[0]], + method_input_dtypes=[*dtypes[1:]], + method_all_as_kwargs_np={"query": query}, class_name=class_name, method_name=method_name, ) @@ -548,11 +449,14 @@ def test_array__rpow__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__ipow__", - dtype_and_x=pow_helper(), - method_container_flags=st.just([False]), + method_tree="Array.__gt__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), ) -def test_array__ipow__( +def test_array__gt__( dtype_and_x, method_name, class_name, @@ -562,22 +466,7 @@ def test_array__ipow__( method_flags, on_device, ): - input_dtype, x = dtype_and_x - - # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) - - # Make sure x2 isn't a float when x1 is integer - assume( - not (ivy.is_int_dtype(input_dtype[0] and ivy.is_float_dtype(input_dtype[1]))) - ) - - # Make sure x2 is non-negative when both is integer - if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]): - x[1] = np.abs(x[1]) - - x[0] = not_too_close_to_zero(x[0]) - x[1] = not_too_close_to_zero(x[1]) + dtype, x = dtype_and_x helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -585,9 +474,9 @@ def test_array__ipow__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[input_dtype[0]], - method_input_dtypes=[input_dtype[1]], - method_all_as_kwargs_np={"power": x[1]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -595,7 +484,7 @@ def test_array__ipow__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__add__", + method_tree="Array.__iadd__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -604,8 +493,9 @@ def test_array__ipow__( safety_factor_scale="log", shared_dtype=True, ), + method_container_flags=st.just([False]), ) -def test_array__add__( +def test_array__iadd__( dtype_and_x, method_name, class_name, @@ -633,17 +523,15 @@ def test_array__add__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__radd__", + method_tree="Array.__iand__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), + method_container_flags=st.just([False]), ) -def test_array__radd__( +def test_array__iand__( dtype_and_x, method_name, class_name, @@ -671,18 +559,18 @@ def test_array__radd__( @handle_method( init_tree=CLASS_TREE, - 
method_tree="Array.__iadd__", + method_tree="Array.__ifloordiv__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, + large_abs_safety_factor=3.0, + small_abs_safety_factor=3.0, safety_factor_scale="log", shared_dtype=True, ), method_container_flags=st.just([False]), ) -def test_array__iadd__( +def test_array__ifloordiv__( dtype_and_x, method_name, class_name, @@ -693,6 +581,7 @@ def test_array__iadd__( on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -710,17 +599,15 @@ def test_array__iadd__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__sub__", + method_tree="Array.__ilshift__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + array_api_dtypes=True, ), + method_container_flags=st.just([False]), ) -def test_array__sub__( +def test_array__ilshift__( dtype_and_x, method_name, class_name, @@ -731,6 +618,7 @@ def test_array__sub__( on_device, ): dtype, x = dtype_and_x + x[1] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -738,8 +626,8 @@ def test_array__sub__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, + init_input_dtypes=[dtype[0]], + method_input_dtypes=[dtype[1]], method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, @@ -748,18 +636,14 @@ def test_array__sub__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rsub__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, - ), + method_tree="Array.__imatmul__", + x1=_get_first_matrix_and_dtype(), + x2=_get_second_matrix_and_dtype(), + method_container_flags=st.just([False]), ) -def test_array__rsub__( - dtype_and_x, +def test_array__imatmul__( + x1, + x2, method_name, class_name, ground_truth_backend, @@ -768,17 +652,18 @@ def test_array__rsub__( method_flags, on_device, ): - dtype, x = dtype_and_x + dtype1, x1 = x1 + dtype2, x2 = x2 helpers.test_method( backend_to_test=backend_fw, on_device=on_device, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + init_all_as_kwargs_np={"data": x1}, + init_input_dtypes=dtype1, + method_input_dtypes=dtype2, + method_all_as_kwargs_np={"other": x2}, class_name=class_name, method_name=method_name, ) @@ -786,7 +671,7 @@ def test_array__rsub__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__isub__", + method_tree="Array.__imod__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -797,7 +682,7 @@ def test_array__rsub__( ), method_container_flags=st.just([False]), ) -def test_array__isub__( +def test_array__imod__( dtype_and_x, method_name, class_name, @@ -808,6 +693,7 @@ def test_array__isub__( on_device, ): dtype, x 
= dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -825,7 +711,7 @@ def test_array__isub__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__mul__", + method_tree="Array.__imul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -834,8 +720,9 @@ def test_array__isub__( safety_factor_scale="log", shared_dtype=True, ), + method_container_flags=st.just([False]), ) -def test_array__mul__( +def test_array__imul__( dtype_and_x, method_name, class_name, @@ -863,17 +750,16 @@ def test_array__mul__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rmul__", + method_tree="Array.__int__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + max_num_dims=0, + min_value=-1e15, + max_value=1e15, ), + method_container_flags=st.just([False]), ) -def test_array__rmul__( +def test_array__int__( dtype_and_x, method_name, class_name, @@ -892,8 +778,8 @@ def test_array__rmul__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_input_dtypes=[], + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @@ -901,18 +787,12 @@ def test_array__rmul__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__imul__", + method_tree="Array.__invert__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), ), - method_container_flags=st.just([False]), ) -def test_array__imul__( +def test_array__invert__( dtype_and_x, method_name, class_name, @@ -932,7 +812,7 @@ def test_array__imul__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @@ -940,17 +820,15 @@ def test_array__imul__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__mod__", + method_tree="Array.__ior__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), + method_container_flags=st.just([False]), ) -def test_array__mod__( +def test_array__ior__( dtype_and_x, method_name, class_name, @@ -961,7 +839,6 @@ def test_array__mod__( on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -979,17 +856,11 @@ def test_array__mod__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rmod__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, - ), + method_tree="Array.__ipow__", + dtype_and_x=pow_helper(), + method_container_flags=st.just([False]), ) -def test_array__rmod__( +def 
test_array__ipow__(
     dtype_and_x,
     method_name,
     class_name,
@@ -999,8 +870,22 @@ def test_array__rmod__(
     method_flags,
     on_device,
 ):
-    dtype, x = dtype_and_x
-    assume(not np.any(np.isclose(x[0], 0)))
+    input_dtype, x = dtype_and_x
+
+    # bfloat16 is not supported by numpy
+    assume(not ("bfloat16" in input_dtype))
+
+    # Make sure x2 isn't a float when x1 is integer
+    assume(
+        not (ivy.is_int_dtype(input_dtype[0]) and ivy.is_float_dtype(input_dtype[1]))
+    )
+
+    # Make sure x2 is non-negative when both are integers
+    if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]):
+        x[1] = np.abs(x[1])
+
+    x[0] = not_too_close_to_zero(x[0])
+    x[1] = not_too_close_to_zero(x[1])
     helpers.test_method(
         backend_to_test=backend_fw,
         on_device=on_device,
@@ -1008,9 +893,9 @@ def test_array__rmod__(
         init_flags=init_flags,
         method_flags=method_flags,
         init_all_as_kwargs_np={"data": x[0]},
-        init_input_dtypes=dtype,
-        method_input_dtypes=dtype,
-        method_all_as_kwargs_np={"other": x[1]},
+        init_input_dtypes=[input_dtype[0]],
+        method_input_dtypes=[input_dtype[1]],
+        method_all_as_kwargs_np={"power": x[1]},
         class_name=class_name,
         method_name=method_name,
     )
@@ -1018,18 +903,15 @@ def test_array__rmod__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__imod__",
+    method_tree="Array.__irshift__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("numeric"),
+        available_dtypes=helpers.get_dtypes("integer"),
         num_arrays=2,
-        large_abs_safety_factor=2.5,
-        small_abs_safety_factor=2.5,
-        safety_factor_scale="log",
-        shared_dtype=True,
+        array_api_dtypes=True,
     ),
     method_container_flags=st.just([False]),
 )
-def test_array__imod__(
+def test_array__irshift__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1040,7 +922,7 @@ def test_array__imod__(
     on_device,
 ):
     dtype, x = dtype_and_x
-    assume(not np.any(np.isclose(x[1], 0)))
+    x[1] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1])
     helpers.test_method(
         backend_to_test=backend_fw,
         on_device=on_device,
@@ -1048,8 +930,8 @@ def test_array__imod__(
         init_flags=init_flags,
         method_flags=method_flags,
         init_all_as_kwargs_np={"data": x[0]},
-        init_input_dtypes=dtype,
-        method_input_dtypes=dtype,
+        init_input_dtypes=[dtype[0]],
+        method_input_dtypes=[dtype[1]],
         method_all_as_kwargs_np={"other": x[1]},
         class_name=class_name,
         method_name=method_name,
@@ -1058,7 +940,7 @@ def test_array__imod__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__divmod__",
+    method_tree="Array.__isub__",
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("numeric"),
         num_arrays=2,
@@ -1067,8 +949,9 @@ def test_array__imod__(
         safety_factor_scale="log",
         shared_dtype=True,
     ),
+    method_container_flags=st.just([False]),
 )
-def test_array__divmod__(
+def test_array__isub__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1079,7 +962,6 @@ def test_array__divmod__(
     on_device,
 ):
     dtype, x = dtype_and_x
-    assume(not np.any(np.isclose(x[1], 0)))
     helpers.test_method(
         backend_to_test=backend_fw,
         on_device=on_device,
@@ -1097,17 +979,14 @@ def test_array__divmod__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__rdivmod__",
+    method_tree="Array.__iter__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("numeric"),
-        num_arrays=2,
-        large_abs_safety_factor=2.5,
-        small_abs_safety_factor=2.5,
-        safety_factor_scale="log",
-        shared_dtype=True,
+        available_dtypes=helpers.get_dtypes("integer"),
+        min_dim_size=2,
+        min_num_dims=1,
     ),
 )
-def test_array__rdivmod__(
+def test_array__iter__(
dtype_and_x, method_name, class_name, @@ -1118,7 +997,6 @@ def test_array__rdivmod__( on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -1128,7 +1006,7 @@ def test_array__rdivmod__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @@ -1136,7 +1014,7 @@ def test_array__rdivmod__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__truediv__", + method_tree="Array.__itruediv__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -1145,8 +1023,9 @@ def test_array__rdivmod__( safety_factor_scale="log", shared_dtype=True, ), + method_container_flags=st.just([False]), ) -def test_array__truediv__( +def test_array__itruediv__( dtype_and_x, method_name, class_name, @@ -1174,17 +1053,15 @@ def test_array__truediv__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rtruediv__", + method_tree="Array.__ixor__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), + method_container_flags=st.just([False]), ) -def test_array__rtruediv__( +def test_array__ixor__( dtype_and_x, method_name, class_name, @@ -1212,18 +1089,14 @@ def test_array__rtruediv__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__itruediv__", + method_tree="Array.__le__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), - method_container_flags=st.just([False]), ) -def test_array__itruediv__( +def test_array__le__( dtype_and_x, method_name, class_name, @@ -1251,17 +1124,12 @@ def test_array__itruediv__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__floordiv__", + method_tree="Array.__len__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=3.0, - small_abs_safety_factor=3.0, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("valid"), ), ) -def test_array__floordiv__( +def test_array__len__( dtype_and_x, method_name, class_name, @@ -1272,7 +1140,6 @@ def test_array__floordiv__( on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -1282,7 +1149,7 @@ def test_array__floordiv__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @@ -1290,17 +1157,14 @@ def test_array__floordiv__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rfloordiv__", + method_tree="Array.__lshift__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, - large_abs_safety_factor=3.0, - small_abs_safety_factor=3.0, - safety_factor_scale="log", - shared_dtype=True, + array_api_dtypes=True, ), ) -def test_array__rfloordiv__( 
+def test_array__lshift__( dtype_and_x, method_name, class_name, @@ -1311,7 +1175,16 @@ def test_array__rfloordiv__( on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) + max_bits = np.iinfo(dtype[0]).bits + max_shift = max_bits - 1 + x[1] = np.asarray(np.clip(x[1], 0, max_shift), dtype=dtype[1]) + max_value_before_shift = 2 ** (max_bits - x[1]) - 1 + overflow_threshold = 2 ** (max_bits - 1) + x[0] = np.asarray(np.clip(x[0], None, max_value_before_shift), dtype=dtype[0]) + if np.any(x[0] > overflow_threshold): + x[0] = np.asarray(np.clip(x[0], None, overflow_threshold), dtype=dtype[0]) + if np.any(x[0] < 0): + x[0] = np.asarray(np.abs(x[0]), dtype=dtype[0]) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -1319,28 +1192,26 @@ def test_array__rfloordiv__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, + init_input_dtypes=[dtype[0]], + method_input_dtypes=[dtype[1]], method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, + rtol_=1e-5, + atol_=1e-5, ) @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__ifloordiv__", + method_tree="Array.__lt__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=3.0, - small_abs_safety_factor=3.0, - safety_factor_scale="log", shared_dtype=True, ), - method_container_flags=st.just([False]), ) -def test_array__ifloordiv__( +def test_array__lt__( dtype_and_x, method_name, class_name, @@ -1351,7 +1222,6 @@ def test_array__ifloordiv__( on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -1403,48 +1273,18 @@ def test_array__matmul__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rmatmul__", - x1=_get_first_matrix_and_dtype(), - x2=_get_second_matrix_and_dtype(), -) -def test_array__rmatmul__( - x1, - x2, - method_name, - class_name, - ground_truth_backend, - backend_fw, - init_flags, - method_flags, - on_device, -): - dtype1, x1 = x1 - dtype2, x2 = x2 - helpers.test_method( - backend_to_test=backend_fw, - on_device=on_device, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={"data": x2}, - init_input_dtypes=dtype1, - method_input_dtypes=dtype2, - method_all_as_kwargs_np={"other": x1}, - class_name=class_name, - method_name=method_name, - ) - - -@handle_method( - init_tree=CLASS_TREE, - method_tree="Array.__imatmul__", - x1=_get_first_matrix_and_dtype(), - x2=_get_second_matrix_and_dtype(), - method_container_flags=st.just([False]), + method_tree="Array.__mod__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, + ), ) -def test_array__imatmul__( - x1, - x2, +def test_array__mod__( + dtype_and_x, method_name, class_name, ground_truth_backend, @@ -1453,18 +1293,18 @@ def test_array__imatmul__( method_flags, on_device, ): - dtype1, x1 = x1 - dtype2, x2 = x2 + dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x1}, - init_input_dtypes=dtype1, - 
method_input_dtypes=dtype2, - method_all_as_kwargs_np={"other": x2}, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -1472,15 +1312,17 @@ def test_array__imatmul__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__abs__", + method_tree="Array.__mul__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - large_abs_safety_factor=1.5, - small_abs_safety_factor=1.5, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, safety_factor_scale="log", + shared_dtype=True, ), ) -def test_array__abs__( +def test_array__mul__( dtype_and_x, method_name, class_name, @@ -1499,8 +1341,8 @@ def test_array__abs__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -1508,13 +1350,14 @@ def test_array__abs__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__float__", + method_tree="Array.__ne__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - max_num_dims=0, + num_arrays=2, + shared_dtype=True, ), ) -def test_array__float__( +def test_array__ne__( dtype_and_x, method_name, class_name, @@ -1533,8 +1376,8 @@ def test_array__float__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -1542,16 +1385,12 @@ def test_array__float__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__int__", + method_tree="Array.__neg__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - max_num_dims=0, - min_value=-1e15, - max_value=1e15, ), - method_container_flags=st.just([False]), ) -def test_array__int__( +def test_array__neg__( dtype_and_x, method_name, class_name, @@ -1570,7 +1409,7 @@ def test_array__int__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=[], + method_input_dtypes=dtype, method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, @@ -1579,14 +1418,14 @@ def test_array__int__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__complex__", + method_tree="Array.__or__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - max_num_dims=0, + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + num_arrays=2, + shared_dtype=True, ), - method_container_flags=st.just([False]), ) -def test_array__complex__( +def test_array__or__( dtype_and_x, method_name, class_name, @@ -1605,8 +1444,8 @@ def test_array__complex__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -1614,15 +1453,12 @@ def test_array__complex__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__bool__", + method_tree="Array.__pos__", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")),
-        max_num_dims=0,
-        min_value=0,
-        max_value=1,
+        available_dtypes=helpers.get_dtypes("numeric"),
     ),
 )
-def test_array__bool__(
+def test_array__pos__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1641,7 +1477,7 @@ def test_array__bool__(
         method_flags=method_flags,
         init_all_as_kwargs_np={"data": x[0]},
         init_input_dtypes=dtype,
-        method_input_dtypes=[],
+        method_input_dtypes=dtype,
         method_all_as_kwargs_np={},
         class_name=class_name,
         method_name=method_name,
@@ -1650,14 +1486,10 @@ def test_array__bool__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__lt__",
-    dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("numeric"),
-        num_arrays=2,
-        shared_dtype=True,
-    ),
+    method_tree="Array.__pow__",
+    dtype_and_x=pow_helper(),
 )
-def test_array__lt__(
+def test_array__pow__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1667,7 +1499,22 @@ def test_array__lt__(
     method_flags,
     on_device,
 ):
-    dtype, x = dtype_and_x
+    input_dtype, x = dtype_and_x
+
+    # bfloat16 is not supported by numpy
+    assume(not ("bfloat16" in input_dtype))
+
+    # Make sure x2 isn't a float when x1 is integer
+    assume(
+        not (ivy.is_int_dtype(input_dtype[0]) and ivy.is_float_dtype(input_dtype[1]))
+    )
+
+    # Make sure x2 is non-negative when both are integers
+    if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]):
+        x[1] = np.abs(x[1])
+
+    x[0] = not_too_close_to_zero(x[0])
+    x[1] = not_too_close_to_zero(x[1])
     helpers.test_method(
         backend_to_test=backend_fw,
         on_device=on_device,
@@ -1675,9 +1522,9 @@ def test_array__lt__(
         init_flags=init_flags,
         method_flags=method_flags,
         init_all_as_kwargs_np={"data": x[0]},
-        init_input_dtypes=dtype,
-        method_input_dtypes=dtype,
-        method_all_as_kwargs_np={"other": x[1]},
+        init_input_dtypes=[input_dtype[0]],
+        method_input_dtypes=[input_dtype[1]],
+        method_all_as_kwargs_np={"power": x[1]},
         class_name=class_name,
         method_name=method_name,
     )
@@ -1685,14 +1532,17 @@ def test_array__lt__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__le__",
+    method_tree="Array.__radd__",
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("numeric"),
         num_arrays=2,
+        large_abs_safety_factor=2.5,
+        small_abs_safety_factor=2.5,
+        safety_factor_scale="log",
         shared_dtype=True,
     ),
 )
-def test_array__le__(
+def test_array__radd__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1720,14 +1570,14 @@ def test_array__le__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__eq__",
+    method_tree="Array.__rand__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("numeric"),
+        available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")),
         num_arrays=2,
         shared_dtype=True,
     ),
 )
-def test_array__eq__(
+def test_array__rand__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1755,14 +1605,17 @@ def test_array__eq__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__ne__",
+    method_tree="Array.__rdivmod__",
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("numeric"),
         num_arrays=2,
+        large_abs_safety_factor=2.5,
+        small_abs_safety_factor=2.5,
+        safety_factor_scale="log",
         shared_dtype=True,
     ),
 )
-def test_array__ne__(
+def test_array__rdivmod__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1773,6 +1626,7 @@ def test_array__ne__(
     on_device,
 ):
     dtype, x = dtype_and_x
+    assume(not np.any(np.isclose(x[0], 0)))
     helpers.test_method(
         backend_to_test=backend_fw,
         on_device=on_device,
@@ -1790,14 +1644,17 @@ def
test_array__ne__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__gt__", + method_tree="Array.__rfloordiv__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=3.0, + small_abs_safety_factor=3.0, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_array__gt__( +def test_array__rfloordiv__( dtype_and_x, method_name, class_name, @@ -1808,6 +1665,7 @@ def test_array__gt__( on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -1825,14 +1683,14 @@ def test_array__gt__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__ge__", + method_tree="Array.__rlshift__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("integer"), num_arrays=2, - shared_dtype=True, + array_api_dtypes=True, ), ) -def test_array__ge__( +def test_array__rlshift__( dtype_and_x, method_name, class_name, @@ -1843,6 +1701,7 @@ def test_array__ge__( on_device, ): dtype, x = dtype_and_x + x[0] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -1850,8 +1709,8 @@ def test_array__ge__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, + init_input_dtypes=[dtype[0]], + method_input_dtypes=[dtype[1]], method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, @@ -1860,15 +1719,13 @@ def test_array__ge__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__and__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - num_arrays=2, - shared_dtype=True, - ), + method_tree="Array.__rmatmul__", + x1=_get_first_matrix_and_dtype(), + x2=_get_second_matrix_and_dtype(), ) -def test_array__and__( - dtype_and_x, +def test_array__rmatmul__( + x1, + x2, method_name, class_name, ground_truth_backend, @@ -1877,17 +1734,18 @@ def test_array__and__( method_flags, on_device, ): - dtype, x = dtype_and_x + dtype1, x1 = x1 + dtype2, x2 = x2 helpers.test_method( backend_to_test=backend_fw, on_device=on_device, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + init_all_as_kwargs_np={"data": x2}, + init_input_dtypes=dtype1, + method_input_dtypes=dtype2, + method_all_as_kwargs_np={"other": x1}, class_name=class_name, method_name=method_name, ) @@ -1895,14 +1753,17 @@ def test_array__and__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rand__", + method_tree="Array.__rmod__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_array__rand__( +def test_array__rmod__( dtype_and_x, method_name, class_name, @@ -1913,6 +1774,7 @@ def test_array__rand__( on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -1930,15 +1792,17 @@ def 
test_array__rand__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__iand__",
+    method_tree="Array.__rmul__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")),
+        available_dtypes=helpers.get_dtypes("numeric"),
         num_arrays=2,
+        large_abs_safety_factor=2.5,
+        small_abs_safety_factor=2.5,
+        safety_factor_scale="log",
         shared_dtype=True,
     ),
-    method_container_flags=st.just([False]),
 )
-def test_array__iand__(
+def test_array__rmul__(
     dtype_and_x,
     method_name,
     class_name,
@@ -1966,14 +1830,14 @@ def test_array__iand__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__or__",
+    method_tree="Array.__ror__",
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")),
         num_arrays=2,
         shared_dtype=True,
     ),
 )
-def test_array__or__(
+def test_array__ror__(
     dtype_and_x,
     method_name,
     class_name,
@@ -2001,14 +1865,10 @@ def test_array__or__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__ror__",
-    dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")),
-        num_arrays=2,
-        shared_dtype=True,
-    ),
+    method_tree="Array.__rpow__",
+    dtype_and_x=pow_helper(),
 )
-def test_array__ror__(
+def test_array__rpow__(
     dtype_and_x,
     method_name,
     class_name,
@@ -2018,17 +1878,32 @@ def test_array__ror__(
     method_flags,
     on_device,
 ):
-    dtype, x = dtype_and_x
+    input_dtype, x = dtype_and_x
+
+    # bfloat16 is not supported by numpy
+    assume(not ("bfloat16" in input_dtype))
+
+    # Make sure x2 isn't a float when x1 is integer
+    assume(
+        not (ivy.is_int_dtype(input_dtype[0]) and ivy.is_float_dtype(input_dtype[1]))
+    )
+
+    # Make sure x2 is non-negative when both are integers
+    if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]):
+        x[1] = np.abs(x[1])
+
+    x[0] = not_too_close_to_zero(x[0])
+    x[1] = not_too_close_to_zero(x[1])
     helpers.test_method(
         backend_to_test=backend_fw,
         on_device=on_device,
         ground_truth_backend=ground_truth_backend,
         init_flags=init_flags,
         method_flags=method_flags,
-        init_all_as_kwargs_np={"data": x[0]},
-        init_input_dtypes=dtype,
-        method_input_dtypes=dtype,
-        method_all_as_kwargs_np={"other": x[1]},
+        init_all_as_kwargs_np={"data": x[1]},
+        init_input_dtypes=[input_dtype[1]],
+        method_input_dtypes=[input_dtype[0]],
+        method_all_as_kwargs_np={"power": x[0]},
         class_name=class_name,
         method_name=method_name,
     )
@@ -2036,15 +1911,14 @@ def test_array__ror__(

 @handle_method(
     init_tree=CLASS_TREE,
-    method_tree="Array.__ior__",
+    method_tree="Array.__rrshift__",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")),
+        available_dtypes=helpers.get_dtypes("integer"),
         num_arrays=2,
-        shared_dtype=True,
+        array_api_dtypes=True,
     ),
-    method_container_flags=st.just([False]),
 )
-def test_array__ior__(
+def test_array__rrshift__(
     dtype_and_x,
     method_name,
     class_name,
@@ -2055,6 +1929,7 @@ def test_array__ior__(
     on_device,
 ):
     dtype, x = dtype_and_x
+    x[0] = np.asarray(np.clip(x[0], 0, np.iinfo(dtype[0]).bits - 1), dtype=dtype[0])
     helpers.test_method(
         backend_to_test=backend_fw,
         on_device=on_device,
@@ -2062,8 +1937,8 @@ def test_array__ior__(
         init_flags=init_flags,
         method_flags=method_flags,
         init_all_as_kwargs_np={"data": x[0]},
-        init_input_dtypes=dtype,
-        method_input_dtypes=dtype,
+        init_input_dtypes=[dtype[0]],
+        method_input_dtypes=[dtype[1]],
         method_all_as_kwargs_np={"other": x[1]},
         class_name=class_name,
method_name=method_name, @@ -2072,12 +1947,15 @@ def test_array__ior__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__invert__", + method_tree="Array.__rshift__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + available_dtypes=helpers.get_dtypes("integer"), + num_arrays=2, + min_value=0, + shared_dtype=True, ), ) -def test_array__invert__( +def test_array__rshift__( dtype_and_x, method_name, class_name, @@ -2088,6 +1966,7 @@ def test_array__invert__( on_device, ): dtype, x = dtype_and_x + x[1] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -2095,9 +1974,9 @@ def test_array__invert__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={}, + init_input_dtypes=[dtype[0]], + method_input_dtypes=[dtype[1]], + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @@ -2105,14 +1984,17 @@ def test_array__invert__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__xor__", + method_tree="Array.__rsub__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_array__xor__( +def test_array__rsub__( dtype_and_x, method_name, class_name, @@ -2140,14 +2022,17 @@ def test_array__xor__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rxor__", + method_tree="Array.__rtruediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_array__rxor__( +def test_array__rtruediv__( dtype_and_x, method_name, class_name, @@ -2175,15 +2060,14 @@ def test_array__rxor__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__ixor__", + method_tree="Array.__rxor__", dtype_and_x=helpers.dtype_and_values( available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, shared_dtype=True, ), - method_container_flags=st.just([False]), ) -def test_array__ixor__( +def test_array__rxor__( dtype_and_x, method_name, class_name, @@ -2211,15 +2095,16 @@ def test_array__ixor__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__lshift__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - array_api_dtypes=True, + method_tree="Array.__setitem__", + ground_truth_backend="numpy", + dtypes_x_query_val=helpers.dtype_array_query_val( + available_dtypes=helpers.get_dtypes("valid"), ), + # ToDo: fix container method + method_container_flags=st.just([False]), ) -def test_array__lshift__( - dtype_and_x, +def test_array__setitem__( + dtypes_x_query_val, method_name, class_name, ground_truth_backend, @@ -2228,44 +2113,35 @@ def test_array__lshift__( method_flags, on_device, ): - dtype, x = dtype_and_x - max_bits = np.iinfo(dtype[0]).bits - max_shift = max_bits - 1 - x[1] = np.asarray(np.clip(x[1], 0, max_shift), dtype=dtype[1]) - max_value_before_shift = 2 ** (max_bits - x[1]) - 1 - 
overflow_threshold = 2 ** (max_bits - 1) - x[0] = np.asarray(np.clip(x[0], None, max_value_before_shift), dtype=dtype[0]) - if np.any(x[0] > overflow_threshold): - x[0] = np.asarray(np.clip(x[0], None, overflow_threshold), dtype=dtype[0]) - if np.any(x[0] < 0): - x[0] = np.asarray(np.abs(x[0]), dtype=dtype[0]) + dtypes, x, query, val = dtypes_x_query_val helpers.test_method( backend_to_test=backend_fw, on_device=on_device, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[dtype[0]], - method_input_dtypes=[dtype[1]], - method_all_as_kwargs_np={"other": x[1]}, + init_all_as_kwargs_np={"data": x}, + init_input_dtypes=[dtypes[0]], + method_input_dtypes=[*dtypes[1:]], + method_all_as_kwargs_np={"query": query, "val": val}, class_name=class_name, method_name=method_name, - rtol_=1e-5, - atol_=1e-5, ) @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rlshift__", + method_tree="Array.__sub__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - array_api_dtypes=True, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_array__rlshift__( +def test_array__sub__( dtype_and_x, method_name, class_name, @@ -2276,7 +2152,6 @@ def test_array__rlshift__( on_device, ): dtype, x = dtype_and_x - x[0] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -2284,8 +2159,8 @@ def test_array__rlshift__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[dtype[0]], - method_input_dtypes=[dtype[1]], + init_input_dtypes=dtype, + method_input_dtypes=dtype, method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, @@ -2294,15 +2169,17 @@ def test_array__rlshift__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__ilshift__", + method_tree="Array.__truediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - array_api_dtypes=True, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), - method_container_flags=st.just([False]), ) -def test_array__ilshift__( +def test_array__truediv__( dtype_and_x, method_name, class_name, @@ -2313,7 +2190,6 @@ def test_array__ilshift__( on_device, ): dtype, x = dtype_and_x - x[1] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -2321,8 +2197,8 @@ def test_array__ilshift__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[dtype[0]], - method_input_dtypes=[dtype[1]], + init_input_dtypes=dtype, + method_input_dtypes=dtype, method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, @@ -2331,15 +2207,14 @@ def test_array__ilshift__( @handle_method( init_tree=CLASS_TREE, - method_tree="Array.__rshift__", + method_tree="Array.__xor__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), num_arrays=2, - min_value=0, shared_dtype=True, ), ) -def 
test_array__rshift__( +def test_array__xor__( dtype_and_x, method_name, class_name, @@ -2350,7 +2225,6 @@ def test_array__rshift__( on_device, ): dtype, x = dtype_and_x - x[1] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) helpers.test_method( backend_to_test=backend_fw, on_device=on_device, @@ -2358,183 +2232,309 @@ def test_array__rshift__( init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[dtype[0]], - method_input_dtypes=[dtype[1]], + init_input_dtypes=dtype, + method_input_dtypes=dtype, method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) -@handle_method( - init_tree=CLASS_TREE, - method_tree="Array.__rrshift__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - array_api_dtypes=True, +def test_array_function(): + HANDLED_FUNCTIONS = {} + + class MyArray: + def __init__(self, data=None): + self.data = data + + def __ivy_array_function__(self, func, types, args, kwargs): + if func not in HANDLED_FUNCTIONS: + return NotImplemented + if not all( + issubclass(t, (MyArray, ivy.Array, ivy.NativeArray)) for t in types + ): + return NotImplemented + return HANDLED_FUNCTIONS[func](*args, **kwargs) + + def implements(ivy_function): + """Register an __ivy_array_function__ implementation for MyArray objects.""" + + def decorator(func): + HANDLED_FUNCTIONS[ivy_function] = func + return func + + return decorator + + @implements(ivy.abs) + def _(my_array, ivy_array): + my_array.data = abs(my_array.data) + ivy_array = ivy.abs(ivy_array) + return (my_array, ivy_array) + + x = MyArray(-3) + y = ivy.array([1, -1]) + xy = _(x, ivy_array=y) + x1 = xy[0] + y1 = xy[1] + assert x1.data == 3 + assert all(y1 == ivy.array([1, 1])) + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=2, + ), +) +def test_array_property_T( + dtype_x, + backend_fw, + test_flags, +): + _, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ret = helpers.flatten_and_to_np(ret=x.T, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np( + ret=ivy_backend.matrix_transpose(data), backend=backend_fw + ) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) + + +# TODO: avoid using dummy fn_tree in property tests + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), +) +def test_array_property_data( + dtype_x, + backend_fw, + test_flags, +): + _, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ret = helpers.flatten_and_to_np(ret=x.data, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np(ret=data, backend=backend_fw) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), +) +def test_array_property_device( + dtype_x, + backend_fw, +): + _, data = dtype_x + 
with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ivy_backend.utils.assertions.check_equal( + x.device, ivy_backend.dev(data), as_array=False + ) + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")), +) +def test_array_property_dtype( + dtype_x, + backend_fw, +): + _, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ivy_backend.utils.assertions.check_equal( + x.dtype, ivy_backend.dtype(data), as_array=False + ) + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("complex")), +) +def test_array_property_imag( + dtype_x, + backend_fw, + test_flags, +): + _, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ret = helpers.flatten_and_to_np(ret=x.imag, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np(ret=ivy_backend.imag(x), backend=backend_fw) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_array_property_itemsize( + dtype_x, + backend_fw, +): + dtype, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ivy_backend.utils.assertions.check_equal( + x.itemsize, ivy_backend.to_numpy(x).itemsize, as_array=False + ) + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, ), ) -def test_array__rrshift__( - dtype_and_x, - method_name, - class_name, - ground_truth_backend, +def test_array_property_mT( + dtype_x, backend_fw, - init_flags, - method_flags, - on_device, + test_flags, ): - dtype, x = dtype_and_x - x[0] = np.asarray(np.clip(x[0], 0, np.iinfo(dtype[0]).bits - 1), dtype=dtype[0]) - helpers.test_method( - backend_to_test=backend_fw, - on_device=on_device, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[dtype[0]], - method_input_dtypes=[dtype[1]], - method_all_as_kwargs_np={"other": x[1]}, - class_name=class_name, - method_name=method_name, - ) + _, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ret = helpers.flatten_and_to_np(ret=x.mT, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np( + ret=ivy_backend.matrix_transpose(data), backend=backend_fw + ) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) -@handle_method( - init_tree=CLASS_TREE, - method_tree="Array.__irshift__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - num_arrays=2, - array_api_dtypes=True, +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + 
dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ret_shape=True, ), - method_container_flags=st.just([False]), ) -def test_array__irshift__( - dtype_and_x, - method_name, - class_name, - ground_truth_backend, +def test_array_property_ndim( + dtype_x, backend_fw, - init_flags, - method_flags, - on_device, ): - dtype, x = dtype_and_x - x[1] = np.asarray(np.clip(x[1], 0, np.iinfo(dtype[1]).bits - 1), dtype=dtype[1]) - helpers.test_method( - backend_to_test=backend_fw, - on_device=on_device, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[dtype[0]], - method_input_dtypes=[dtype[1]], - method_all_as_kwargs_np={"other": x[1]}, - class_name=class_name, - method_name=method_name, - ) + _, data, input_shape = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ivy_backend.utils.assertions.check_equal( + x.ndim, len(input_shape), as_array=False + ) -@handle_method( - init_tree=CLASS_TREE, - method_tree="Array.__deepcopy__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - ), +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("complex")), ) -def test_array__deepcopy__( - dtype_and_x, - method_name, - class_name, - ground_truth_backend, +def test_array_property_real( + dtype_x, backend_fw, - init_flags, - method_flags, - on_device, + test_flags, ): - dtype, x = dtype_and_x - helpers.test_method( - backend_to_test=backend_fw, - on_device=on_device, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=[], - method_all_as_kwargs_np={"memodict": {}}, - class_name=class_name, - method_name=method_name, - ) + _, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ret = helpers.flatten_and_to_np(ret=x.real, backend=backend_fw) + ret_gt = helpers.flatten_and_to_np(ret=ivy_backend.real(x), backend=backend_fw) + helpers.value_test( + ret_np_flat=ret, + ret_np_from_gt_flat=ret_gt, + backend=backend_fw, + ground_truth_backend=test_flags.ground_truth_backend, + ) -@handle_method( - init_tree=CLASS_TREE, - method_tree="Array.__len__", - dtype_and_x=helpers.dtype_and_values( +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), + ret_shape=True, ), ) -def test_array__len__( - dtype_and_x, - method_name, - class_name, - ground_truth_backend, +def test_array_property_shape( + dtype_x, backend_fw, - init_flags, - method_flags, - on_device, ): - dtype, x = dtype_and_x - helpers.test_method( - backend_to_test=backend_fw, - on_device=on_device, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={}, - class_name=class_name, - method_name=method_name, - ) + _, data, input_shape = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + 
ivy_backend.utils.assertions.check_equal( + x.shape, ivy_backend.Shape(input_shape), as_array=False + ) -@handle_method( - init_tree=CLASS_TREE, - method_tree="Array.__iter__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_dim_size=2, +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ret_shape=True, min_num_dims=1, ), ) -def test_array__iter__( - dtype_and_x, - method_name, - class_name, - ground_truth_backend, +def test_array_property_size( + dtype_x, backend_fw, - init_flags, - method_flags, - on_device, ): - dtype, x = dtype_and_x - helpers.test_method( - backend_to_test=backend_fw, - on_device=on_device, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={}, - class_name=class_name, - method_name=method_name, - ) + _, data, input_shape = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + size_gt = 1 + for dim in input_shape: + size_gt *= dim + ivy_backend.utils.assertions.check_equal(x.size, size_gt, as_array=False) + + +@handle_test( + fn_tree="functional.ivy.native_array", # dummy fn_tree + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_array_property_strides(dtype_x, backend_fw): + dtype, data = dtype_x + with BackendHandler.update_backend(backend_fw) as ivy_backend: + data = ivy_backend.native_array(data[0]) + x = ivy_backend.Array(data) + ivy_backend.utils.assertions.check_equal( + x.strides, ivy_backend.to_numpy(x).strides, as_array=False + ) diff --git a/ivy_tests/test_ivy/test_misc/test_assertions.py b/ivy_tests/test_ivy/test_misc/test_assertions.py index a905815baf12b..d2dddc60f04d8 100644 --- a/ivy_tests/test_ivy/test_misc/test_assertions.py +++ b/ivy_tests/test_ivy/test_misc/test_assertions.py @@ -31,22 +31,21 @@ @pytest.mark.parametrize( - "x1, x2, allow_equal", + "results", [ - (5, 10, False), - (10, 5, False), - (5, 5, True), - (10, 5, True), + ([0, 1, 2]), + ([True, False]), + ([True, True]), ], ) -def test_check_less(x1, x2, allow_equal): +def test_check_all(results): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_less(x1, x2, allow_equal) + check_all(results) except Exception as e: print(e) sys.stdout = orig_stdout @@ -55,13 +54,10 @@ def test_check_less(x1, x2, allow_equal): with open(filename) as f: lines += f.read() - if x1 > x2 and allow_equal: - assert "lesser than or equal" in lines.strip() - - if x1 > x2 and not allow_equal: - assert "lesser than" in lines.strip() + if not all(results): + assert "one" in lines.strip() - if x1 < x2: + if all(results): assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -69,23 +65,31 @@ def test_check_less(x1, x2, allow_equal): @pytest.mark.parametrize( - "x1, x2, allow_equal", + "args, fn, type, limit", [ - (5, 10, False), - (10, 5, False), - (5, 5, True), - (10, 5, True), + # INVALID CASES + ((1, 2, 0), ivy.array, "all", [3]), + ((0, 0), ivy.array, "all", [2]), + ((1, 1), ivy.array, "any", [3]), + ((0, 0, 1), ivy.array, "any", [3]), + ((1, 0, 1), ivy.array, "all_any", [3]), + # VALID + ((1, 1), ivy.array, "any", [2]), + ((0, 1), ivy.array, "any", [1]), + ((1, 1, 2), 
ivy.array, "all", [3]), ], ) -def test_check_greater(x1, x2, allow_equal): +def test_check_all_or_any_fn(args, fn, type, limit): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_greater(x1, x2, allow_equal) + check_all_or_any_fn(*args, fn=fn, type=type, limit=limit) + local_vars = {**locals()} except Exception as e: + local_vars = {**locals()} print(e) sys.stdout = orig_stdout f.close() @@ -93,36 +97,31 @@ def test_check_greater(x1, x2, allow_equal): with open(filename) as f: lines += f.read() - if x1 < x2 and allow_equal: - assert "greater than or equal" in lines.strip() - - if x1 < x2 and not allow_equal: - assert "greater than" in lines.strip() + if type == "all" or type == "any": + if "e" in local_vars.keys(): + assert "args must exist according to" in lines.strip() + else: + assert not lines.strip() - if x1 > x2: - assert not lines.strip() + else: + assert "type must be all or any" in lines.strip() with contextlib.suppress(FileNotFoundError): os.remove(filename) @pytest.mark.parametrize( - "x1, x2, inverse", - [ - (5, 10, False), - (10, 10, False), - (5, 5, True), - (10, 5, True), - ], + "results", + [([0, 1, 2]), ([False, False]), ([True, False]), ([0, False])], ) -def test_check_equal(x1, x2, inverse): +def test_check_any(results): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_equal(x1, x2, inverse) + check_any(results) except Exception as e: print(e) sys.stdout = orig_stdout @@ -131,48 +130,68 @@ def test_check_equal(x1, x2, inverse): with open(filename) as f: lines += f.read() - if inverse: - if x1 == x2: - assert "must not be equal" in lines.strip() - - if x1 != x2: - assert not lines.strip() - - if not inverse: - if x1 != x2: - assert "must be equal" in lines.strip() + if not any(results): + assert "all" in lines.strip() - if x1 == x2: - assert not lines.strip() + if all(results): + assert not lines.strip() with contextlib.suppress(FileNotFoundError): os.remove(filename) @pytest.mark.parametrize( - "x, allowed_types", - [(5.0, float), (ivy.array(5), type(ivy.array(8))), (5, float), ([5, 10], tuple)], + "device", + [ + # VALID CASES + "cpu", + "gpu:0", + "tpu:1", + # INVALID + "cuda", + "gpu;", + "tpu:abc12", + ], ) -def test_check_isinstance(x, allowed_types): +def test_check_dev_correct_formatting(device): + with pytest.raises(AssertionError): + check_dev_correct_formatting(device) + + +@pytest.mark.parametrize( + "x", + [ + # INVALID CASES + (ivy.array([1])), + (ivy.array([])), + # VALID + (ivy.array([1, 2])), + (ivy.array([[1, 2], [2, 3]])), + ], +) +def test_check_dimensions(x): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_isinstance(x, allowed_types) + check_dimensions(x) + local_vars = {**locals()} except Exception as e: + local_vars = {**locals()} print(e) + sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if not isinstance(x, allowed_types): - assert "must be one of the" in lines.strip() + if "e" in local_vars.keys(): + assert "greater than one dimension" in lines.strip() - if isinstance(x, allowed_types): + if "e" not in local_vars.keys(): assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -180,17 +199,22 @@ def test_check_isinstance(x, allowed_types): @pytest.mark.parametrize( - "x, inverse", - [(None, False), ([], False), (None, True), ("abc", True)], + "elem, list, inverse", + [ + (1, [1, 2], False), + ("a", 
[1, 2], False), + (1, [2, 3], True), + (0, ["a", "b", "c"], True), + ], ) -def test_check_exists(x, inverse): +def test_check_elem_in_list(elem, list, inverse): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_exists(x, inverse) + check_elem_in_list(elem, list, inverse) except Exception as e: print(e) sys.stdout = orig_stdout @@ -200,40 +224,40 @@ def test_check_exists(x, inverse): lines += f.read() if not inverse: - if x is None: - assert "must not be" in lines.strip() + if elem not in list: + assert "must be one" in lines.strip() - if x: + if elem in list: assert not lines.strip() if inverse: - if x is None: + if elem not in list: assert not lines.strip() - if x: - assert "must be None" in lines.strip() + if elem in list: + assert "must not be one" in lines.strip() with contextlib.suppress(FileNotFoundError): os.remove(filename) @pytest.mark.parametrize( - "elem, list, inverse", + "x1, x2, inverse", [ - (1, [1, 2], False), - ("a", [1, 2], False), - (1, [2, 3], True), - (0, ["a", "b", "c"], True), + (5, 10, False), + (10, 10, False), + (5, 5, True), + (10, 5, True), ], ) -def test_check_elem_in_list(elem, list, inverse): +def test_check_equal(x1, x2, inverse): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_elem_in_list(elem, list, inverse) + check_equal(x1, x2, inverse) except Exception as e: print(e) sys.stdout = orig_stdout @@ -242,41 +266,36 @@ def test_check_elem_in_list(elem, list, inverse): with open(filename) as f: lines += f.read() - if not inverse: - if elem not in list: - assert "must be one" in lines.strip() + if inverse: + if x1 == x2: + assert "must not be equal" in lines.strip() - if elem in list: + if x1 != x2: assert not lines.strip() - if inverse: - if elem not in list: - assert not lines.strip() + if not inverse: + if x1 != x2: + assert "must be equal" in lines.strip() - if elem in list: - assert "must not be one" in lines.strip() + if x1 == x2: + assert not lines.strip() with contextlib.suppress(FileNotFoundError): os.remove(filename) @pytest.mark.parametrize( - "expression", - [ - (True), - "a", - (None), - (False), - ], + "x, inverse", + [(None, False), ([], False), (None, True), ("abc", True)], ) -def test_check_true(expression): +def test_check_exists(x, inverse): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_true(expression) + check_exists(x, inverse) except Exception as e: print(e) sys.stdout = orig_stdout @@ -285,11 +304,19 @@ def test_check_true(expression): with open(filename) as f: lines += f.read() - if not expression: - assert "True" in lines.strip() + if not inverse: + if x is None: + assert "must not be" in lines.strip() - if expression: - assert not lines.strip() + if x: + assert not lines.strip() + + if inverse: + if x is None: + assert not lines.strip() + + if x: + assert "must be None" in lines.strip() with contextlib.suppress(FileNotFoundError): os.remove(filename) @@ -331,33 +358,42 @@ def test_check_false(expression): @pytest.mark.parametrize( - "results", + "fill_value, dtype", [ - ([0, 1, 2]), - ([True, False]), - ([True, True]), + # INVALID CASES + (1.0, ivy.int16), + (1, ivy.float16), + (1, ivy.complex64), + # VALID + (1j, ivy.complex64), + (1.0, ivy.complex64), + (1.0, ivy.float16), + (1, ivy.int16), ], ) -def test_check_all(results): +def test_check_fill_value_and_dtype_are_compatible(fill_value, dtype): filename = "except_out.txt" 
orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_all(results) + check_fill_value_and_dtype_are_compatible(fill_value, dtype) + local_vars = {**locals()} except Exception as e: + local_vars = {**locals()} print(e) + sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if not all(results): - assert "one" in lines.strip() + if "e" in local_vars.keys(): + assert "not compatible" in lines.strip() - if all(results): + if "e" not in local_vars.keys(): assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -365,29 +401,45 @@ def test_check_all(results): @pytest.mark.parametrize( - "results", - [([0, 1, 2]), ([False, False]), ([True, False]), ([0, False])], + "params, indices, axis, batch_dims", + [ + # INVALID CASES + (ivy.array([1, 2, 3]), ivy.array([1]), 2, 3), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0], [2]]), 1, 2), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2], [2, 3]]), 1, 0), + (ivy.array([1, 2, 3]), ivy.array([[1, 2]]), 1, 0), + # VALID + (ivy.array([1, 2, 3]), ivy.array([1]), 0, 1), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([0, 2]), -1, 0), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2]]), -1, 0), + ], ) -def test_check_any(results): +def test_check_gather_input_valid(params, indices, axis, batch_dims): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_any(results) + check_gather_input_valid(params, indices, axis, batch_dims) + local_vars = {**locals()} except Exception as e: + local_vars = {**locals()} print(e) + sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if not any(results): - assert "all" in lines.strip() + if "e" in local_vars.keys(): + assert ( + "must be less than or equal" in lines.strip() + or "batch dimensions must match in" in lines.strip() + ) - if all(results): + if "e" not in local_vars.keys(): assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -395,84 +447,85 @@ def test_check_any(results): @pytest.mark.parametrize( - "args, fn, type, limit", + "params, indices, batch_dims", [ # INVALID CASES - ((1, 2, 0), ivy.array, "all", [3]), - ((0, 0), ivy.array, "all", [2]), - ((1, 1), ivy.array, "any", [3]), - ((0, 0, 1), ivy.array, "any", [3]), - ((1, 0, 1), ivy.array, "all_any", [3]), + (ivy.array([1, 2, 3]), ivy.array([1]), 2), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([0, 2]), 1), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2], [2, 3]]), 1), + (ivy.array([1, 2, 3]), ivy.array([[1, 2]]), 0), # VALID - ((1, 1), ivy.array, "any", [2]), - ((0, 1), ivy.array, "any", [1]), - ((1, 1, 2), ivy.array, "all", [3]), + (ivy.array([1, 2, 3]), ivy.array([1]), 0), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([0, 2]), 0), + (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2]]), 1), ], ) -def test_check_all_or_any_fn(args, fn, type, limit): +def test_check_gather_nd_input_valid(params, indices, batch_dims): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_all_or_any_fn(*args, fn=fn, type=type, limit=limit) + check_gather_nd_input_valid(params, indices, batch_dims) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} print(e) + sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if type == "all" or type == "any": - if "e" in local_vars.keys(): - assert "args must exist according to" in 
lines.strip() - else: - assert not lines.strip() + if "e" in local_vars.keys(): + assert ( + "less than rank(`params`)" in lines.strip() + or "less than rank(`indices`)" in lines.strip() + or "dimensions must match in `params` and `indices`" in lines.strip() + or "index innermost dimension length must be <=" in lines.strip() + ) - else: - assert "type must be all or any" in lines.strip() + if "e" not in local_vars.keys(): + assert not lines.strip() with contextlib.suppress(FileNotFoundError): os.remove(filename) @pytest.mark.parametrize( - "x1, x2", + "x1, x2, allow_equal", [ - (ivy.array([1, 2, 3]), ivy.array([[4, 5, 6], [2, 3, 1]])), - (ivy.array([[1.0, 2.0], [3.0, 4.0]]), ivy.array([4, 5, 6])), - (ivy.array([1, 2]), ivy.array([3, 4, 5])), - (ivy.array([1]), ivy.array([2])), - (ivy.array([1, 2]), ivy.array([2, 3])), + (5, 10, False), + (10, 5, False), + (5, 5, True), + (10, 5, True), ], ) -def test_check_shape(x1, x2): +def test_check_greater(x1, x2, allow_equal): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_shape(x1, x2) - local_vars = {**locals()} + check_greater(x1, x2, allow_equal) except Exception as e: - local_vars = {**locals()} print(e) - sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): - assert "same shape" in lines.strip() + if x1 < x2 and allow_equal: + assert "greater than or equal" in lines.strip() - if "e" not in local_vars.keys(): + if x1 < x2 and not allow_equal: + assert "greater than" in lines.strip() + + if x1 > x2: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -480,22 +533,24 @@ def test_check_shape(x1, x2): @pytest.mark.parametrize( - "x1, x2", + "var, data", [ - (ivy.array([1, 2, 3]), ivy.array([4, 5, 6])), - (ivy.array([1.0, 2.0, 3.0]), ivy.array([4, 5, 6])), - (ivy.array([1, 2, 3]), ivy.array([4j, 5 + 1j, 6])), - (ivy.array([1j]), ivy.array([2, 3 + 4j])), + # INVALID CASES + (ivy.array([1]), ivy.array([1, 2])), + (ivy.array([[1], [1], [2]]), ivy.array([1, 2])), + # VALID + (ivy.array([1, 2]), ivy.array([1])), + (ivy.array([[[1]]]), ivy.array([1, 2])), ], ) -def test_check_same_dtype(x1, x2): +def test_check_inplace_sizes_valid(var, data): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_same_dtype(x1, x2) + check_inplace_sizes_valid(var, data) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} @@ -508,7 +563,7 @@ def test_check_same_dtype(x1, x2): lines += f.read() if "e" in local_vars.keys(): - assert "same dtype" in lines.strip() + assert "Could not output values of shape" in lines.strip() if "e" not in local_vars.keys(): assert not lines.strip() @@ -518,42 +573,29 @@ def test_check_same_dtype(x1, x2): @pytest.mark.parametrize( - "fill_value, dtype", - [ - # INVALID CASES - (1.0, ivy.int16), - (1, ivy.float16), - (1, ivy.complex64), - # VALID - (1j, ivy.complex64), - (1.0, ivy.complex64), - (1.0, ivy.float16), - (1, ivy.int16), - ], + "x, allowed_types", + [(5.0, float), (ivy.array(5), type(ivy.array(8))), (5, float), ([5, 10], tuple)], ) -def test_check_fill_value_and_dtype_are_compatible(fill_value, dtype): +def test_check_isinstance(x, allowed_types): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_fill_value_and_dtype_are_compatible(fill_value, dtype) - local_vars = {**locals()} + check_isinstance(x, allowed_types) except Exception as e: - 
local_vars = {**locals()} print(e) - sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): - assert "not compatible" in lines.strip() + if not isinstance(x, allowed_types): + assert "must be one of the" in lines.strip() - if "e" not in local_vars.keys(): + if isinstance(x, allowed_types): assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -561,34 +603,28 @@ def test_check_fill_value_and_dtype_are_compatible(fill_value, dtype): @pytest.mark.parametrize( - "data, segment_ids, num_segments", + "dtype", [ # INVALID CASES - (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), 2.0), - (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), 0), - (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), -2), - (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), 0), - (ivy.array([1, 2, 3]), ivy.array([0.0, 1.0, 0.0], dtype=ivy.float16), 0), - (ivy.array([1, 2]), ivy.array([0, 1, 0], dtype=ivy.int32), 0), - (ivy.array([1, 2, 3]), ivy.array([0, 1], dtype=ivy.int32), 0), - (ivy.array([1, 2, 3]), ivy.array([0, 1, 2], dtype=ivy.int32), 2), + "float64", + "int64", + "uint64", + "complex128" # VALID - ( - ivy.array([1, 2, 3]), - ivy.array([0, 1, 0], dtype=ivy.int32), - 2, - ), - (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), ivy.array([2])), + "float16", + "float32int32", + "int16", + "complex64", ], ) -def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): +def test_check_jax_x64_flag(dtype): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_unsorted_segment_min_valid_params(data, segment_ids, num_segments) + _check_jax_x64_flag(dtype) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} @@ -601,13 +637,7 @@ def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments lines += f.read() if "e" in local_vars.keys(): - assert ( - "num_segments must be of integer type" in lines.strip() - or "segment_ids must have an integer dtype" in lines.strip() - or "segment_ids should be equal to data.shape[0]" in lines.strip() - or "is out of range" in lines.strip() - or "num_segments must be positive" in lines.strip() - ) + assert "output not supported while jax_enable_x64" in lines.strip() if "e" not in local_vars.keys(): assert not lines.strip() @@ -617,27 +647,24 @@ def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments @pytest.mark.parametrize( - "params, indices, axis, batch_dims", + "kernel_size, padding_size", [ # INVALID CASES - (ivy.array([1, 2, 3]), ivy.array([1]), 2, 3), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0], [2]]), 1, 2), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2], [2, 3]]), 1, 0), - (ivy.array([1, 2, 3]), ivy.array([[1, 2]]), 1, 0), + (((2, 2), ((2, 2), (1, 1)))), + (((3, 3), ((2, 2), (1, 1)))), # VALID - (ivy.array([1, 2, 3]), ivy.array([1]), 0, 1), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([0, 2]), -1, 0), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2]]), -1, 0), + (((5, 5), ((1, 1), (2, 2)))), + (((3, 3), ((1, 1), (0, 0)))), ], ) -def test_check_gather_input_valid(params, indices, axis, batch_dims): +def test_check_kernel_padding_size(kernel_size, padding_size): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_gather_input_valid(params, indices, axis, batch_dims) + 
check_kernel_padding_size(kernel_size, padding_size) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} @@ -650,10 +677,7 @@ def test_check_gather_input_valid(params, indices, axis, batch_dims): lines += f.read() if "e" in local_vars.keys(): - assert ( - "must be less than or equal" in lines.strip() - or "batch dimensions must match in" in lines.strip() - ) + assert "less than or equal to half" in lines.strip() if "e" not in local_vars.keys(): assert not lines.strip() @@ -663,47 +687,37 @@ def test_check_gather_input_valid(params, indices, axis, batch_dims): @pytest.mark.parametrize( - "params, indices, batch_dims", + "x1, x2, allow_equal", [ - # INVALID CASES - (ivy.array([1, 2, 3]), ivy.array([1]), 2), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([0, 2]), 1), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2], [2, 3]]), 1), - (ivy.array([1, 2, 3]), ivy.array([[1, 2]]), 0), - # VALID - (ivy.array([1, 2, 3]), ivy.array([1]), 0), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([0, 2]), 0), - (ivy.array([[1, 2, 3], [4, 5, 6]]), ivy.array([[0, 1], [1, 2]]), 1), + (5, 10, False), + (10, 5, False), + (5, 5, True), + (10, 5, True), ], ) -def test_check_gather_nd_input_valid(params, indices, batch_dims): +def test_check_less(x1, x2, allow_equal): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_gather_nd_input_valid(params, indices, batch_dims) - local_vars = {**locals()} + check_less(x1, x2, allow_equal) except Exception as e: - local_vars = {**locals()} print(e) - sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): - assert ( - "less than rank(`params`)" in lines.strip() - or "less than rank(`indices`)" in lines.strip() - or "dimensions must match in `params` and `indices`" in lines.strip() - or "index innermost dimension length must be <=" in lines.strip() - ) + if x1 > x2 and allow_equal: + assert "lesser than or equal" in lines.strip() - if "e" not in local_vars.keys(): + if x1 > x2 and not allow_equal: + assert "lesser than" in lines.strip() + + if x1 < x2: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -711,24 +725,22 @@ def test_check_gather_nd_input_valid(params, indices, batch_dims): @pytest.mark.parametrize( - "var, data", + "x1, x2", [ - # INVALID CASES - (ivy.array([1]), ivy.array([1, 2])), - (ivy.array([[1], [1], [2]]), ivy.array([1, 2])), - # VALID - (ivy.array([1, 2]), ivy.array([1])), - (ivy.array([[[1]]]), ivy.array([1, 2])), + (ivy.array([1, 2, 3]), ivy.array([4, 5, 6])), + (ivy.array([1.0, 2.0, 3.0]), ivy.array([4, 5, 6])), + (ivy.array([1, 2, 3]), ivy.array([4j, 5 + 1j, 6])), + (ivy.array([1j]), ivy.array([2, 3 + 4j])), ], ) -def test_check_inplace_sizes_valid(var, data): +def test_check_same_dtype(x1, x2): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_inplace_sizes_valid(var, data) + check_same_dtype(x1, x2) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} @@ -741,7 +753,7 @@ def test_check_inplace_sizes_valid(var, data): lines += f.read() if "e" in local_vars.keys(): - assert "Could not output values of shape" in lines.strip() + assert "same dtype" in lines.strip() if "e" not in local_vars.keys(): assert not lines.strip() @@ -751,24 +763,23 @@ def test_check_inplace_sizes_valid(var, data): @pytest.mark.parametrize( - "var, data", + "x1, x2", [ - # INVALID CASES - ((2, 1), (1, 2, 1)), - ((2, 
1), (3, 1)), - # VALID - ((1, 2), (1, 2)), - ((1, 2), (1, 1, 1)), + (ivy.array([1, 2, 3]), ivy.array([[4, 5, 6], [2, 3, 1]])), + (ivy.array([[1.0, 2.0], [3.0, 4.0]]), ivy.array([4, 5, 6])), + (ivy.array([1, 2]), ivy.array([3, 4, 5])), + (ivy.array([1]), ivy.array([2])), + (ivy.array([1, 2]), ivy.array([2, 3])), ], ) -def test_check_shapes_broadcastable(var, data): +def test_check_shape(x1, x2): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_shapes_broadcastable(var, data) + check_shape(x1, x2) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} @@ -781,7 +792,7 @@ def test_check_shapes_broadcastable(var, data): lines += f.read() if "e" in local_vars.keys(): - assert "Could not broadcast shape" in lines.strip() + assert "same shape" in lines.strip() if "e" not in local_vars.keys(): assert not lines.strip() @@ -791,24 +802,24 @@ def test_check_shapes_broadcastable(var, data): @pytest.mark.parametrize( - "x", + "var, data", [ # INVALID CASES - (ivy.array([1])), - (ivy.array([])), + ((2, 1), (1, 2, 1)), + ((2, 1), (3, 1)), # VALID - (ivy.array([1, 2])), - (ivy.array([[1, 2], [2, 3]])), + ((1, 2), (1, 2)), + ((1, 2), (1, 1, 1)), ], ) -def test_check_dimensions(x): +def test_check_shapes_broadcastable(var, data): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_dimensions(x) + check_shapes_broadcastable(var, data) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} @@ -821,7 +832,7 @@ def test_check_dimensions(x): lines += f.read() if "e" in local_vars.keys(): - assert "greater than one dimension" in lines.strip() + assert "Could not broadcast shape" in lines.strip() if "e" not in local_vars.keys(): assert not lines.strip() @@ -831,39 +842,34 @@ def test_check_dimensions(x): @pytest.mark.parametrize( - "kernel_size, padding_size", + "expression", [ - # INVALID CASES - (((2, 2), ((2, 2), (1, 1)))), - (((3, 3), ((2, 2), (1, 1)))), - # VALID - (((5, 5), ((1, 1), (2, 2)))), - (((3, 3), ((1, 1), (0, 0)))), + (True), + "a", + (None), + (False), ], ) -def test_check_kernel_padding_size(kernel_size, padding_size): +def test_check_true(expression): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - check_kernel_padding_size(kernel_size, padding_size) - local_vars = {**locals()} + check_true(expression) except Exception as e: - local_vars = {**locals()} print(e) - sys.stdout = orig_stdout f.close() with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): - assert "less than or equal to half" in lines.strip() + if not expression: + assert "True" in lines.strip() - if "e" not in local_vars.keys(): + if expression: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -871,28 +877,34 @@ def test_check_kernel_padding_size(kernel_size, padding_size): @pytest.mark.parametrize( - "dtype", + "data, segment_ids, num_segments", [ # INVALID CASES - "float64", - "int64", - "uint64", - "complex128" + (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), 2.0), + (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), 0), + (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), -2), + (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), 0), + (ivy.array([1, 2, 3]), ivy.array([0.0, 1.0, 0.0], dtype=ivy.float16), 0), + (ivy.array([1, 2]), ivy.array([0, 1, 0], dtype=ivy.int32), 0), + (ivy.array([1, 2, 3]), ivy.array([0, 
1], dtype=ivy.int32), 0), + (ivy.array([1, 2, 3]), ivy.array([0, 1, 2], dtype=ivy.int32), 2), # VALID - "float16", - "float32int32", - "int16", - "complex64", + ( + ivy.array([1, 2, 3]), + ivy.array([0, 1, 0], dtype=ivy.int32), + 2, + ), + (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), ivy.array([2])), ], ) -def test_check_jax_x64_flag(dtype): +def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): filename = "except_out.txt" orig_stdout = sys.stdout f = open(filename, "w") sys.stdout = f lines = "" try: - _check_jax_x64_flag(dtype) + check_unsorted_segment_min_valid_params(data, segment_ids, num_segments) local_vars = {**locals()} except Exception as e: local_vars = {**locals()} @@ -905,28 +917,16 @@ def test_check_jax_x64_flag(dtype): lines += f.read() if "e" in local_vars.keys(): - assert "output not supported while jax_enable_x64" in lines.strip() + assert ( + "num_segments must be of integer type" in lines.strip() + or "segment_ids must have an integer dtype" in lines.strip() + or "segment_ids should be equal to data.shape[0]" in lines.strip() + or "is out of range" in lines.strip() + or "num_segments must be positive" in lines.strip() + ) if "e" not in local_vars.keys(): assert not lines.strip() with contextlib.suppress(FileNotFoundError): os.remove(filename) - - -@pytest.mark.parametrize( - "device", - [ - # VALID CASES - "cpu", - "gpu:0", - "tpu:1", - # INVALID - "cuda", - "gpu;", - "tpu:abc12", - ], -) -def test_check_dev_correct_formatting(device): - with pytest.raises(AssertionError): - check_dev_correct_formatting(device) diff --git a/ivy_tests/test_ivy/test_misc/test_backend_handler.py b/ivy_tests/test_ivy/test_misc/test_backend_handler.py index 1bb09f050e2ce..a5229a9c195ba 100644 --- a/ivy_tests/test_ivy/test_misc/test_backend_handler.py +++ b/ivy_tests/test_ivy/test_misc/test_backend_handler.py @@ -3,6 +3,14 @@ import pytest import importlib import types +import numpy as np + +# local +import ivy +from ivy.utils.backend.handler import _backend_dict + +# TODO fix due to refactor +from ivy_tests.test_ivy.helpers.available_frameworks import _available_frameworks try: @@ -27,121 +35,28 @@ except ImportError: paddle = types.SimpleNamespace() paddle.Tensor = lambda x: x -import numpy as np - -# local -import ivy -from ivy.utils.backend.handler import _backend_dict - -# TODO fix due to refactor -from ivy_tests.test_ivy.helpers.available_frameworks import _available_frameworks - -available_frameworks_with_none = _available_frameworks()[:] -available_frameworks_with_none.append(None) available_array_types_class = [ ("numpy", ""), ] - available_array_types_input = [ ("numpy", np.array(3.0)), ] +available_frameworks_with_none = _available_frameworks()[:] +# Dynamic Backend -if "tensorflow" in _available_frameworks(): - available_array_types_input.append(("tensorflow", tf.constant([3.0]))) - available_array_types_class.append( - ("tensorflow", "") - ) - -if "jax" in _available_frameworks(): - available_array_types_input.append(("jax", jnp.array(3.0))) - if version.parse(jax.__version__) >= version.parse("0.4.1"): - available_array_types_class.append( - ("jax", "") - ) - else: - available_array_types_class.append( - ("jax", "") - ) - - -if "torch" in _available_frameworks(): - available_array_types_input.append(("torch", torch.tensor([3.0]))) - available_array_types_class.append(("torch", "")) - -if "paddle" in _available_frameworks(): - available_array_types_input.append(("paddle", paddle.to_tensor([3.0]))) - 
available_array_types_class.append(("paddle", "")) - - -@pytest.mark.parametrize( - ( - "backend", - "array_type", - ), - available_array_types_class, -) -def test_set_backend(backend, array_type): - # recording data before backend change - stack_before = [] - func_address_before = id(ivy.sum) - stack_before.extend(ivy.backend_stack) - - ivy.set_backend(backend) - stack_after = ivy.backend_stack - # check that the function id has changed as inverse=True. - ivy.utils.assertions.check_equal( - func_address_before, id(ivy.sum), inverse=True, as_array=False - ) - # using ivy assertions to ensure the desired backend is set - ivy.utils.assertions.check_less(len(stack_before), len(stack_after), as_array=False) - ivy.utils.assertions.check_equal(ivy.current_backend_str(), backend, as_array=False) - backend = importlib.import_module(_backend_dict[backend]) - ivy.utils.assertions.check_equal(stack_after[-1], backend, as_array=False) - x = ivy.array([1, 2, 3]) - ivy.utils.assertions.check_equal( - str(type(ivy.to_native(x))), array_type, as_array=False - ) - - -@pytest.mark.parametrize("backend", _available_frameworks()) -def test_previous_backend(backend): - if not ivy.backend_stack: - assert ivy.previous_backend() is None - - ivy.set_backend(backend) - stack_before_unset = [] - func_address_before_unset = id(ivy.sum) - stack_before_unset.extend(ivy.backend_stack) - - previous_backend = ivy.previous_backend() - stack_after_unset = ivy.backend_stack - # check that the function id has changed as inverse=True. - ivy.utils.assertions.check_equal( - func_address_before_unset, id(ivy.sum), inverse=True, as_array=False - ) - ivy.utils.assertions.check_equal( - previous_backend, - importlib.import_module(_backend_dict[backend]), - as_array=False, - ) - ivy.utils.assertions.check_greater( - len(stack_before_unset), len(stack_after_unset), as_array=False - ) - - # checking a previously set backend is still set - ivy.set_backend(backend) - ivy.set_backend("numpy") - ivy.previous_backend() - ivy.utils.assertions.check_equal(ivy.current_backend_str(), backend, as_array=False) - +backends = list(_backend_dict.keys()) -def test_unset_backend(): - for backend_str in _available_frameworks(): - ivy.set_backend(backend_str) - ivy.unset_backend() - ivy.utils.assertions.check_equal(ivy.backend_stack, [], as_array=False) +@pytest.mark.parametrize("excluded", available_frameworks_with_none) +def test_choose_random_backend(excluded): + backend = ivy.choose_random_backend(excluded=excluded) + if excluded is None: + assert backend in list(_backend_dict.keys()) + else: + backends_list = list(_backend_dict.keys()) + backends_list.remove(excluded) + assert backend in backends_list @pytest.mark.parametrize( @@ -172,22 +87,6 @@ def test_current_backend(backend, array_type): ) -@pytest.mark.parametrize("excluded", available_frameworks_with_none) -def test_choose_random_backend(excluded): - backend = ivy.choose_random_backend(excluded=excluded) - if excluded is None: - assert backend in list(_backend_dict.keys()) - else: - backends_list = list(_backend_dict.keys()) - backends_list.remove(excluded) - assert backend in backends_list - - -# Dynamic Backend - -backends = list(_backend_dict.keys()) - - @pytest.mark.parametrize( "middle_backend,end_backend", [(a, b) for a in backends for b in backends if a != b] ) @@ -240,6 +139,21 @@ def test_dynamic_backend_all_combos(middle_backend, end_backend): assert isinstance(nativ_cont["b"].data, ivy.current_backend().NativeArray) +def test_dynamic_backend_context_manager(): + with 
ivy.dynamic_backend_as(True): + a = ivy.array([0.0, 1.0]) + b = ivy.array([2.0, 3.0]) + + with ivy.dynamic_backend_as(False): + c = ivy.array([4.0, 5.0]) + d = ivy.array([6.0, 7.0]) + + assert a.dynamic_backend is True + assert b.dynamic_backend is True + assert c.dynamic_backend is False + assert d.dynamic_backend is False + + def test_dynamic_backend_setter(): a = ivy.array([1, 2, 3]) type_a = type(a.data) @@ -258,6 +172,76 @@ def test_dynamic_backend_setter(): assert isinstance(a.data, torch.Tensor) +@pytest.mark.parametrize("backend", _available_frameworks()) +def test_previous_backend(backend): + if not ivy.backend_stack: + assert ivy.previous_backend() is None + + ivy.set_backend(backend) + stack_before_unset = [] + func_address_before_unset = id(ivy.sum) + stack_before_unset.extend(ivy.backend_stack) + + previous_backend = ivy.previous_backend() + stack_after_unset = ivy.backend_stack + # check that the function id has changed as inverse=True. + ivy.utils.assertions.check_equal( + func_address_before_unset, id(ivy.sum), inverse=True, as_array=False + ) + ivy.utils.assertions.check_equal( + previous_backend, + importlib.import_module(_backend_dict[backend]), + as_array=False, + ) + ivy.utils.assertions.check_greater( + len(stack_before_unset), len(stack_after_unset), as_array=False + ) + + # checking a previously set backend is still set + ivy.set_backend(backend) + ivy.set_backend("numpy") + ivy.previous_backend() + ivy.utils.assertions.check_equal(ivy.current_backend_str(), backend, as_array=False) + + +@pytest.mark.parametrize( + ( + "backend", + "array_type", + ), + available_array_types_class, +) +def test_set_backend(backend, array_type): + # recording data before backend change + stack_before = [] + func_address_before = id(ivy.sum) + stack_before.extend(ivy.backend_stack) + + ivy.set_backend(backend) + stack_after = ivy.backend_stack + # check that the function id has changed as inverse=True. 
+ ivy.utils.assertions.check_equal( + func_address_before, id(ivy.sum), inverse=True, as_array=False + ) + # using ivy assertions to ensure the desired backend is set + ivy.utils.assertions.check_less(len(stack_before), len(stack_after), as_array=False) + ivy.utils.assertions.check_equal(ivy.current_backend_str(), backend, as_array=False) + backend = importlib.import_module(_backend_dict[backend]) + ivy.utils.assertions.check_equal(stack_after[-1], backend, as_array=False) + x = ivy.array([1, 2, 3]) + ivy.utils.assertions.check_equal( + str(type(ivy.to_native(x))), array_type, as_array=False + ) + + +def test_unset_backend(): + for backend_str in _available_frameworks(): + ivy.set_backend(backend_str) + + ivy.unset_backend() + ivy.utils.assertions.check_equal(ivy.backend_stack, [], as_array=False) + + def test_variables(): # clear the backend stack ivy.unset_backend() @@ -280,16 +264,30 @@ def test_variables(): assert isinstance(stat_cont["w"], tf.Variable) -def test_dynamic_backend_context_manager(): - with ivy.dynamic_backend_as(True): - a = ivy.array([0.0, 1.0]) - b = ivy.array([2.0, 3.0]) +available_frameworks_with_none.append(None) - with ivy.dynamic_backend_as(False): - c = ivy.array([4.0, 5.0]) - d = ivy.array([6.0, 7.0]) +if "tensorflow" in _available_frameworks(): + available_array_types_input.append(("tensorflow", tf.constant([3.0]))) + available_array_types_class.append( + ("tensorflow", "") + ) - assert a.dynamic_backend is True - assert b.dynamic_backend is True - assert c.dynamic_backend is False - assert d.dynamic_backend is False +if "jax" in _available_frameworks(): + available_array_types_input.append(("jax", jnp.array(3.0))) + if version.parse(jax.__version__) >= version.parse("0.4.1"): + available_array_types_class.append( + ("jax", "") + ) + else: + available_array_types_class.append( + ("jax", "") + ) + + +if "torch" in _available_frameworks(): + available_array_types_input.append(("torch", torch.tensor([3.0]))) + available_array_types_class.append(("torch", "")) + +if "paddle" in _available_frameworks(): + available_array_types_input.append(("paddle", paddle.to_tensor([3.0]))) + available_array_types_class.append(("paddle", "")) diff --git a/ivy_tests/test_ivy/test_misc/test_container.py b/ivy_tests/test_ivy/test_misc/test_container.py index d8986b03440e7..151cc08dee600 100644 --- a/ivy_tests/test_ivy/test_misc/test_container.py +++ b/ivy_tests/test_ivy/test_misc/test_container.py @@ -14,42 +14,131 @@ from ivy.utils.exceptions import IvyException -def test_container_list_join(on_device): - container_0 = Container( - { - "a": [ivy.array([1], device=on_device)], - "b": { - "c": [ivy.array([2], device=on_device)], - "d": [ivy.array([3], device=on_device)], - }, - } - ) - container_1 = Container( +def test_container_all_false(on_device): + assert Container({"a": False, "b": {"c": [], "d": 0}}).cont_all_false() + assert not Container({"a": False, "b": {"c": [1], "d": 0}}).cont_all_false() + # noinspection PyBroadException + try: + assert Container( + {"a": ivy.array([1], device=on_device), "b": {"c": [1], "d": True}} + ).cont_all_false(assert_is_bool=True) + error_raised = False + except IvyException: + error_raised = True + assert error_raised + + +@pytest.mark.parametrize("include_empty", [True, False]) +def test_container_all_key_chains(include_empty, on_device): + a_val = Container() if include_empty else ivy.array([1], device=on_device) + bc_val = Container() if include_empty else ivy.array([2], device=on_device) + bd_val = Container() if include_empty else 
ivy.array([3], device=on_device) + dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} + container = Container(dict_in) + kcs = container.cont_all_key_chains(include_empty) + assert kcs[0] == "a" + assert kcs[1] == "b/c" + assert kcs[2] == "b/d" + + +def test_container_all_true(on_device): + assert not Container( + {"a": ivy.array([1], device=on_device), "b": {"c": [], "d": True}} + ).cont_all_true() + assert Container( + {"a": ivy.array([1], device=on_device), "b": {"c": [1], "d": True}} + ).cont_all_true() + # noinspection PyBroadException + try: + assert Container( + {"a": ivy.array([1], device=on_device), "b": {"c": [1], "d": True}} + ).cont_all_true(assert_is_bool=True) + error_raised = False + except IvyException: + error_raised = True + assert error_raised + + +def test_container_as_bools(on_device): + dict_in = {"a": ivy.array([1], device=on_device), "b": {"c": [], "d": True}} + container = Container(dict_in) + + container_bools = container.cont_as_bools() + assert container_bools["a"] is True + assert container_bools.a is True + assert container_bools["b"]["c"] is False + assert container_bools.b.c is False + assert container_bools["b"]["d"] is True + assert container_bools.b.d is True + + +def test_container_assert_contains(on_device): + arr0 = ivy.array([0.0], device=on_device) + arr1 = ivy.array([1.0], device=on_device) + arr2 = ivy.array([2.0], device=on_device) + sub_cont = Container({"c": arr1, "d": arr2}) + container = Container({"a": arr0, "b": sub_cont}) + + # keys + assert "a" in container + assert "b" in container + assert "c" not in container + assert "b/c" in container + assert "d" not in container + assert "b/d" in container + + # sub-container + container.cont_assert_contains_sub_container(container) + container.cont_assert_contains_sub_container(sub_cont) + assert sub_cont in container + + # partial sub-container + partial_sub_cont = Container({"b": {"d": arr2}}) + container.cont_assert_contains_sub_container(container, partial=True) + container.cont_assert_contains_sub_container(partial_sub_cont, partial=True) + try: + partial_sub_cont.cont_assert_contains_sub_container(container, partial=True) + error_caught = False + except IvyException: + error_caught = True + assert error_caught + # sub-structure + sub_struc = Container( { - "a": [ivy.array([4], device=on_device)], - "b": { - "c": [ivy.array([5], device=on_device)], - "d": [ivy.array([6], device=on_device)], - }, + "c": ivy.array([3.0], device=on_device), + "d": ivy.array([4.0], device=on_device), } ) - container_list_joined = ivy.Container.cont_list_join([container_0, container_1]) - assert np.allclose(ivy.to_numpy(container_list_joined["a"][0]), np.array([1])) - assert np.allclose(ivy.to_numpy(container_list_joined.a[0]), np.array([1])) - assert np.allclose(ivy.to_numpy(container_list_joined["b"]["c"][0]), np.array([2])) - assert np.allclose(ivy.to_numpy(container_list_joined.b.c[0]), np.array([2])) - assert np.allclose(ivy.to_numpy(container_list_joined["b"]["d"][0]), np.array([3])) - assert np.allclose(ivy.to_numpy(container_list_joined.b.d[0]), np.array([3])) - assert np.allclose(ivy.to_numpy(container_list_joined["a"][1]), np.array([4])) - assert np.allclose(ivy.to_numpy(container_list_joined.a[1]), np.array([4])) - assert np.allclose(ivy.to_numpy(container_list_joined["b"]["c"][1]), np.array([5])) - assert np.allclose(ivy.to_numpy(container_list_joined.b.c[1]), np.array([5])) - assert np.allclose(ivy.to_numpy(container_list_joined["b"]["d"][1]), np.array([6])) - assert 
np.allclose(ivy.to_numpy(container_list_joined.b.d[1]), np.array([6])) + try: + not container.cont_assert_contains_sub_container(sub_struc) + error_caught = False + except IvyException: + error_caught = True + assert error_caught + assert sub_struc not in container + container.cont_assert_contains_sub_structure(sub_struc) + container.cont_assert_contains_sub_structure(container) + + # partial sub-structure + partial_sub_struc = Container({"b": {"d": ivy.array([4.0], device=on_device)}}) + container.cont_assert_contains_sub_structure(container, partial=True) + container.cont_assert_contains_sub_structure(partial_sub_struc, partial=True) + try: + partial_sub_struc.cont_assert_contains_sub_structure(container, partial=True) + error_caught = False + except IvyException: + error_caught = True + assert error_caught -def test_container_list_stack(on_device): - container_0 = Container( +def test_container_assert_identical(on_device): + # without key_chains specification + arr1 = ivy.array([1], device=on_device) + arr2 = ivy.array([2], device=on_device) + arr3 = ivy.array([3], device=on_device) + container0 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) + container1 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) + container2 = Container( { "a": ivy.array([1], device=on_device), "b": { @@ -58,71 +147,41 @@ def test_container_list_stack(on_device): }, } ) - container_1 = Container( - { - "a": ivy.array([4], device=on_device), - "b": { - "c": ivy.array([5], device=on_device), - "d": ivy.array([6], device=on_device), - }, - } - ) - container_list_stacked = ivy.Container.cont_list_stack( - [container_0, container_1], 0 - ) - assert np.allclose(ivy.to_numpy(container_list_stacked["a"][0]), np.array([1])) - assert np.allclose(ivy.to_numpy(container_list_stacked.a[0]), np.array([1])) - assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["c"][0]), np.array([2])) - assert np.allclose(ivy.to_numpy(container_list_stacked.b.c[0]), np.array([2])) - assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["d"][0]), np.array([3])) - assert np.allclose(ivy.to_numpy(container_list_stacked.b.d[0]), np.array([3])) - assert np.allclose(ivy.to_numpy(container_list_stacked["a"][1]), np.array([4])) - assert np.allclose(ivy.to_numpy(container_list_stacked.a[1]), np.array([4])) - assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["c"][1]), np.array([5])) - assert np.allclose(ivy.to_numpy(container_list_stacked.b.c[1]), np.array([5])) - assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["d"][1]), np.array([6])) - assert np.allclose(ivy.to_numpy(container_list_stacked.b.d[1]), np.array([6])) + container3 = Container({"b": {"d": arr3}}) + container4 = Container({"d": arr3}) + # the same + ivy.Container.cont_assert_identical([container0, container1]) + ivy.Container.cont_assert_identical([container1, container0]) -def test_container_unify(on_device): - # on_devices and containers - on_devices = list() - dev0 = on_device - on_devices.append(dev0) - conts = dict() - conts[dev0] = Container( - { - "a": ivy.array([1], device=dev0), - "b": {"c": ivy.array([2], device=dev0), "d": ivy.array([3], device=dev0)}, - } - ) - if "gpu" in on_device and ivy.num_gpus() > 1: - idx = ivy.num_gpus() - 1 - dev1 = on_device[:-1] + str(idx) - on_devices.append(dev1) - conts[dev1] = Container( - { - "a": ivy.array([4], device=dev1), - "b": { - "c": ivy.array([5], device=dev1), - "d": ivy.array([6], device=dev1), - }, - } - ) + # not the same + try: + ivy.Container.cont_assert_identical([container0, 
container2]) + error_caught = False + except IvyException: + error_caught = True + assert error_caught + try: + ivy.Container.cont_assert_identical([container1, container2]) + error_caught = False + except IvyException: + error_caught = True + assert error_caught - # test - container_unified = ivy.Container.cont_unify(conts, dev0, "concat", 0) - assert np.allclose(ivy.to_numpy(container_unified.a[0]), np.array([1])) - assert np.allclose(ivy.to_numpy(container_unified.b.c[0]), np.array([2])) - assert np.allclose(ivy.to_numpy(container_unified.b.d[0]), np.array([3])) - if len(on_devices) > 1: - assert np.allclose(ivy.to_numpy(container_unified.a[1]), np.array([4])) - assert np.allclose(ivy.to_numpy(container_unified.b.c[1]), np.array([5])) - assert np.allclose(ivy.to_numpy(container_unified.b.d[1]), np.array([6])) + # partial + ivy.Container.cont_assert_identical([container0, container3], partial=True) + ivy.Container.cont_assert_identical([container3, container0], partial=True) + try: + ivy.Container.cont_assert_identical([container4, container0], partial=True) + error_caught = False + except IvyException: + error_caught = True + assert error_caught -def test_container_combine(on_device): - container_0 = Container( +def test_container_assert_identical_structure(on_device): + # without key_chains specification + container0 = Container( { "a": ivy.array([1], device=on_device), "b": { @@ -131,59 +190,142 @@ def test_container_combine(on_device): }, } ) - container_1 = Container( + container1 = Container( { - "a": ivy.array([4], device=on_device), + "a": ivy.array([3], device=on_device), "b": { - "c": ivy.array([5], device=on_device), - "e": ivy.array([6], device=on_device), + "c": ivy.array([4], device=on_device), + "d": ivy.array([5], device=on_device), }, } ) - container_comb = ivy.Container.cont_combine(container_0, container_1) - assert np.equal(ivy.to_numpy(container_comb.a), np.array([4])) - assert np.equal(ivy.to_numpy(container_comb.b.c), np.array([5])) - assert np.equal(ivy.to_numpy(container_comb.b.d), np.array([3])) - assert np.equal(ivy.to_numpy(container_comb.b.e), np.array([6])) - - -def test_container_diff(on_device): - # all different arrays - container_0 = Container( + container2 = Container( { - "a": ivy.array([1], device=on_device), + "a": ivy.array([3], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": ivy.array([4], device=on_device), + "d": ivy.array([5], device=on_device), + "e": ivy.array([6], device=on_device), }, } ) - container_1 = Container( + container3 = Container( { - "a": ivy.array([4], device=on_device), + "a": ivy.array([3], device=on_device), "b": { - "c": ivy.array([5], device=on_device), - "d": ivy.array([6], device=on_device), + "c": ivy.array([4], device=on_device), + "d": ivy.array([5], device=on_device), }, + "e": ivy.array([6], device=on_device), } ) - container_diff = ivy.Container.cont_diff(container_0, container_1) - assert np.equal(ivy.to_numpy(container_diff.a.diff_0), np.array([1])) - assert np.equal(ivy.to_numpy(container_diff.a.diff_1), np.array([4])) - assert np.equal(ivy.to_numpy(container_diff.b.c.diff_0), np.array([2])) - assert np.equal(ivy.to_numpy(container_diff.b.c.diff_1), np.array([5])) - assert np.equal(ivy.to_numpy(container_diff.b.d.diff_0), np.array([3])) - assert np.equal(ivy.to_numpy(container_diff.b.d.diff_1), np.array([6])) - container_diff_diff_only = ivy.Container.cont_diff( - container_0, container_1, mode="diff_only" - ) - assert 
container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() - container_diff_same_only = ivy.Container.cont_diff( - container_0, container_1, mode="same_only" - ) - assert container_diff_same_only.cont_to_dict() == {} + container4 = Container({"b": {"d": ivy.array([4], device=on_device)}}) + container5 = Container({"d": ivy.array([4], device=on_device)}) - # some different arrays + # with identical + ivy.Container.cont_assert_identical_structure([container0, container1]) + ivy.Container.cont_assert_identical_structure([container1, container0]) + ivy.Container.cont_assert_identical_structure([container1, container0, container1]) + + # without identical + try: + ivy.Container.cont_assert_identical_structure( + [container0, container1, container2, container3] + ) + error_caught = False + except IvyException: + error_caught = True + # partial + try: + ivy.Container.cont_assert_identical_structure( + [container0, container1, container2, container3, container4, container5], + partial=True, + ) + error_caught = False + except IvyException: + error_caught = True + assert error_caught + try: + ivy.Container.cont_assert_identical_structure( + [container0, container5], partial=True + ) + error_caught = False + except IvyException: + error_caught = True + assert error_caught + + +def test_container_at_key_chain(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container = Container(dict_in) + + # explicit function call + sub_container = container.cont_at_key_chain("b") + assert np.allclose(ivy.to_numpy(sub_container["c"]), np.array([2])) + sub_container = container.cont_at_key_chain("b/c") + assert np.allclose(ivy.to_numpy(sub_container), np.array([2])) + + # overridden built-in function call + sub_container = container["b"] + assert np.allclose(ivy.to_numpy(sub_container["c"]), np.array([2])) + sub_container = container["b/c"] + assert np.allclose(ivy.to_numpy(sub_container), np.array([2])) + + +def test_container_at_key_chains(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container = Container(dict_in) + target_cont = Container({"a": True, "b": {"c": True}}) + new_container = container.cont_at_key_chains(target_cont) + assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert "d" not in new_container["b"] + new_container = container.cont_at_key_chains(["b/c", "b/d"]) + assert "a" not in new_container + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) + new_container = container.cont_at_key_chains("b/c") + assert "a" not in new_container + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert "d" not in new_container["b"] + + +def test_container_at_keys(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container = Container(dict_in) + new_container = container.cont_at_keys(["a", "c"]) + assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert "d" not in new_container["b"] + new_container = container.cont_at_keys("c") + 
assert "a" not in new_container + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert "d" not in new_container["b"] + new_container = container.cont_at_keys(["b"]) + assert "a" not in new_container + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) + + +def test_container_combine(on_device): container_0 = Container( { "a": ivy.array([1], device=on_device), @@ -195,111 +337,281 @@ def test_container_diff(on_device): ) container_1 = Container( { - "a": ivy.array([1], device=on_device), + "a": ivy.array([4], device=on_device), "b": { "c": ivy.array([5], device=on_device), - "d": ivy.array([3], device=on_device), + "e": ivy.array([6], device=on_device), }, } ) - container_diff = ivy.Container.cont_diff(container_0, container_1) - assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) - assert np.equal(ivy.to_numpy(container_diff.b.c.diff_0), np.array([2])) - assert np.equal(ivy.to_numpy(container_diff.b.c.diff_1), np.array([5])) - assert np.equal(ivy.to_numpy(container_diff.b.d), np.array([3])) - container_diff_diff_only = ivy.Container.cont_diff( - container_0, container_1, mode="diff_only" - ) - assert "a" not in container_diff_diff_only - assert "b" in container_diff_diff_only - assert "c" in container_diff_diff_only["b"] - assert "d" not in container_diff_diff_only["b"] - container_diff_same_only = ivy.Container.cont_diff( - container_0, container_1, mode="same_only" - ) - assert "a" in container_diff_same_only - assert "b" in container_diff_same_only - assert "c" not in container_diff_same_only["b"] - assert "d" in container_diff_same_only["b"] + container_comb = ivy.Container.cont_combine(container_0, container_1) + assert np.equal(ivy.to_numpy(container_comb.a), np.array([4])) + assert np.equal(ivy.to_numpy(container_comb.b.c), np.array([5])) + assert np.equal(ivy.to_numpy(container_comb.b.d), np.array([3])) + assert np.equal(ivy.to_numpy(container_comb.b.e), np.array([6])) - # all different keys - container_0 = Container( + +def test_container_common_key_chains(on_device): + arr1 = ivy.array([1], device=on_device) + arr2 = ivy.array([2], device=on_device) + arr3 = ivy.array([3], device=on_device) + cont0 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) + cont1 = Container({"b": {"c": arr2, "d": arr3, "e": arr1}}) + cont2 = Container({"a": arr1, "b": {"d": arr3, "e": arr1}}) + + # 0 + common_kcs = Container.cont_common_key_chains([cont0]) + assert len(common_kcs) == 3 + assert "a" in common_kcs + assert "b/c" in common_kcs + assert "b/d" in common_kcs + + # 0-1 + common_kcs = Container.cont_common_key_chains([cont0, cont1]) + assert len(common_kcs) == 2 + assert "b/c" in common_kcs + assert "b/d" in common_kcs + + # 0-2 + common_kcs = Container.cont_common_key_chains([cont0, cont2]) + assert len(common_kcs) == 2 + assert "a" in common_kcs + assert "b/d" in common_kcs + + # 1-2 + common_kcs = Container.cont_common_key_chains([cont1, cont2]) + assert len(common_kcs) == 2 + assert "b/d" in common_kcs + assert "b/e" in common_kcs + + # all + common_kcs = Container.cont_common_key_chains([cont0, cont1, cont2]) + assert len(common_kcs) == 1 + assert "b/d" in common_kcs + + +def test_container_cont_inplace_update(on_device): + container0 = Container( { "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - ) - container_1 = Container( - { - "e": ivy.array([1], 
device=on_device), - "f": { - "g": ivy.array([2], device=on_device), - "h": ivy.array([3], device=on_device), + "c": ivy.array([1], device=on_device), + "d": ivy.array([2], device=on_device), }, } ) - container_diff = ivy.Container.cont_diff(container_0, container_1) - assert np.equal(ivy.to_numpy(container_diff.a.diff_0), np.array([1])) - assert np.equal(ivy.to_numpy(container_diff.b.diff_0.c), np.array([2])) - assert np.equal(ivy.to_numpy(container_diff.b.diff_0.d), np.array([3])) - assert np.equal(ivy.to_numpy(container_diff.e.diff_1), np.array([1])) - assert np.equal(ivy.to_numpy(container_diff.f.diff_1.g), np.array([2])) - assert np.equal(ivy.to_numpy(container_diff.f.diff_1.h), np.array([3])) - container_diff_diff_only = ivy.Container.cont_diff( - container_0, container_1, mode="diff_only" - ) - assert container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() - container_diff_same_only = ivy.Container.cont_diff( - container_0, container_1, mode="same_only" - ) - assert container_diff_same_only.cont_to_dict() == {} - - # some different keys - container_0 = Container( + id0 = id(container0) + container1 = Container( { - "a": ivy.array([1], device=on_device), + "a": ivy.array([0], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": ivy.array([0], device=on_device), + "d": ivy.array([0], device=on_device), }, } ) - container_1 = Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "e": ivy.array([3], device=on_device), - }, + id1 = id(container1) + assert ivy.Container.cont_all_false(container0.all_equal(container1)) + container0.inplace_update(container1) + assert id0 == id(container0) + assert id1 == id(container1) + assert ivy.Container.cont_all_true(container0.all_equal(container1)) + + +def test_container_contains(on_device): + arr0 = ivy.array([0.0], device=on_device) + arr1 = ivy.array([1.0], device=on_device) + arr2 = ivy.array([2.0], device=on_device) + sub_cont = Container({"c": arr1, "d": arr2}) + container = Container({"a": arr0, "b": sub_cont}) + + # keys + assert "a" in container + assert "b" in container + assert "c" not in container + assert "b/c" in container + assert "d" not in container + assert "b/d" in container + + # sub-container + assert container.cont_contains_sub_container(container) + assert container.cont_contains_sub_container(sub_cont) + assert sub_cont in container + + # partial sub-container + partial_sub_cont = Container({"b": {"d": arr2}}) + assert container.cont_contains_sub_container(container, partial=True) + assert container.cont_contains_sub_container(partial_sub_cont, partial=True) + assert not partial_sub_cont.cont_contains_sub_container(container, partial=True) + + # sub-structure + sub_struc = Container( + { + "c": ivy.array([3.0], device=on_device), + "d": ivy.array([4.0], device=on_device), } ) - container_diff = ivy.Container.cont_diff(container_0, container_1) - assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) - assert np.equal(ivy.to_numpy(container_diff.b.c), np.array([2])) - assert np.equal(ivy.to_numpy(container_diff.b.d.diff_0), np.array([3])) - assert np.equal(ivy.to_numpy(container_diff.b.e.diff_1), np.array([3])) - container_diff_diff_only = ivy.Container.cont_diff( - container_0, container_1, mode="diff_only" - ) - assert "a" not in container_diff_diff_only - assert "b" in container_diff_diff_only - assert "c" not in container_diff_diff_only["b"] - assert "d" in 
container_diff_diff_only["b"] - assert "e" in container_diff_diff_only["b"] - container_diff_same_only = ivy.Container.cont_diff( - container_0, container_1, mode="same_only" - ) - assert "a" in container_diff_same_only - assert "b" in container_diff_same_only - assert "c" in container_diff_same_only["b"] - assert "d" not in container_diff_same_only["b"] - assert "e" not in container_diff_same_only["b"] + assert not container.cont_contains_sub_container(sub_struc) + assert sub_struc not in container + assert container.cont_contains_sub_structure(sub_struc) + assert container.cont_contains_sub_structure(container) - # same containers - container_0 = Container( + # partial sub-structure + partial_sub_struc = Container({"b": {"d": ivy.array([4.0], device=on_device)}}) + assert container.cont_contains_sub_structure(container, partial=True) + assert container.cont_contains_sub_structure(partial_sub_struc, partial=True) + assert not partial_sub_struc.cont_contains_sub_structure(container, partial=True) + + +def test_container_copy(on_device): + dict_in = { + "a": ivy.array([0.0], device=on_device), + "b": { + "c": ivy.array([1.0], device=on_device), + "d": ivy.array([2.0], device=on_device), + }, + } + cont = Container(dict_in) + cont_deepcopy = cont.cont_copy() + assert np.allclose(ivy.to_numpy(cont.a), ivy.to_numpy(cont_deepcopy.a)) + assert np.allclose(ivy.to_numpy(cont.b.c), ivy.to_numpy(cont_deepcopy.b.c)) + assert np.allclose(ivy.to_numpy(cont.b.d), ivy.to_numpy(cont_deepcopy.b.d)) + assert id(cont) != id(cont_deepcopy) + assert id(cont.a) == id(cont_deepcopy.a) + assert id(cont.b.c) == id(cont_deepcopy.b.c) + assert id(cont.b.d) == id(cont_deepcopy.b.d) + + +def test_container_create_if_absent(on_device): + dict_in = { + "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), + "b": { + "c": ivy.array([[[2.0], [4.0], [6.0]]], device=on_device), + "d": ivy.array([[[3.0], [6.0], [9.0]]], device=on_device), + }, + } + + # depth 1 + container = Container(dict_in) + container.cont_create_if_absent("a", None, True) + assert np.allclose(ivy.to_numpy(container.a), np.array([[[1.0], [2.0], [3.0]]])) + container.cont_create_if_absent("e", ivy.array([[[4.0], [8.0], [12.0]]]), True) + assert np.allclose(ivy.to_numpy(container.e), np.array([[[4.0], [8.0], [12.0]]])) + + # depth 2 + container.cont_create_if_absent("f/g", np.array([[[5.0], [10.0], [15.0]]]), True) + assert np.allclose(ivy.to_numpy(container.f.g), np.array([[[5.0], [10.0], [15.0]]])) + + +@pytest.mark.parametrize("inplace", [True, False]) +def test_container_cutoff_at_depth(inplace, on_device): + # values + a_val = ivy.array([1], device=on_device) + bcde_val = ivy.array([2], device=on_device) + + # depth 1 + cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) + cont_cutoff = cont.cont_cutoff_at_depth(1, inplace=inplace) + if inplace: + cont_cutoff = cont + assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) + assert not cont_cutoff.b + + # depth 2 + cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) + cont_cutoff = cont.cont_cutoff_at_depth(2, inplace=inplace) + if inplace: + cont_cutoff = cont + assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) + assert not cont_cutoff.b.c + + # depth 3 + cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) + cont_cutoff = cont.cont_cutoff_at_depth(3, inplace=inplace) + if inplace: + cont_cutoff = cont + assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) + assert not cont_cutoff.b.c.d + + # depth 4 + 
cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) + cont_cutoff = cont.cont_cutoff_at_depth(4, inplace=inplace) + if inplace: + cont_cutoff = cont + assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) + assert np.allclose(ivy.to_numpy(cont_cutoff.b.c.d.e), ivy.to_numpy(bcde_val)) + + +@pytest.mark.parametrize("inplace", [True, False]) +def test_container_cutoff_at_height(inplace, on_device): + # values + d_val = ivy.array([2], device=on_device) + e_val = ivy.array([3], device=on_device) + + # height 0 + cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) + cont_cutoff = cont.cont_cutoff_at_height(0, inplace=inplace) + if inplace: + cont_cutoff = cont + assert np.allclose(ivy.to_numpy(cont_cutoff.a.c.d), ivy.to_numpy(d_val)) + assert np.allclose(ivy.to_numpy(cont_cutoff.b.c.d.e), ivy.to_numpy(e_val)) + + # height 1 + cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) + cont_cutoff = cont.cont_cutoff_at_height(1, inplace=inplace) + if inplace: + cont_cutoff = cont + assert not cont_cutoff.a.c + assert not cont_cutoff.b.c.d + + # height 2 + cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) + cont_cutoff = cont.cont_cutoff_at_height(2, inplace=inplace) + if inplace: + cont_cutoff = cont + assert not cont_cutoff.a + assert not cont_cutoff.b.c + + # height 3 + cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) + cont_cutoff = cont.cont_cutoff_at_height(3, inplace=inplace) + if inplace: + cont_cutoff = cont + assert not cont_cutoff.a + assert not cont_cutoff.b + + # height 4 + cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) + cont_cutoff = cont.cont_cutoff_at_height(4, inplace=inplace) + if inplace: + cont_cutoff = cont + assert not cont_cutoff + + +def test_container_deep_copy(on_device): + dict_in = { + "a": ivy.array([0.0], device=on_device), + "b": { + "c": ivy.array([1.0], device=on_device), + "d": ivy.array([2.0], device=on_device), + }, + } + cont = Container(dict_in) + cont_deepcopy = cont.cont_deep_copy() + assert np.allclose(ivy.to_numpy(cont.a), ivy.to_numpy(cont_deepcopy.a)) + assert np.allclose(ivy.to_numpy(cont.b.c), ivy.to_numpy(cont_deepcopy.b.c)) + assert np.allclose(ivy.to_numpy(cont.b.d), ivy.to_numpy(cont_deepcopy.b.d)) + assert id(cont.a) != id(cont_deepcopy.a) + assert id(cont.b.c) != id(cont_deepcopy.b.c) + assert id(cont.b.d) != id(cont_deepcopy.b.d) + + +def test_container_depth(on_device): + cont_depth1 = Container( + {"a": ivy.array([1], device=on_device), "b": ivy.array([2], device=on_device)} + ) + assert cont_depth1.cont_max_depth == 1 + cont_depth2 = Container( { "a": ivy.array([1], device=on_device), "b": { @@ -308,50 +620,40 @@ def test_container_diff(on_device): }, } ) - container_1 = Container( + assert cont_depth2.cont_max_depth == 2 + cont_depth3 = Container( { "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": {"d": ivy.array([2], device=on_device)}, + "e": ivy.array([3], device=on_device), }, } ) - container_diff = ivy.Container.cont_diff(container_0, container_1) - assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) - assert np.equal(ivy.to_numpy(container_diff.b.c), np.array([2])) - assert np.equal(ivy.to_numpy(container_diff.b.d), np.array([3])) - container_diff_diff_only = ivy.Container.cont_diff( - container_0, container_1, mode="diff_only" - ) - assert 
container_diff_diff_only.cont_to_dict() == {} - container_diff_same_only = ivy.Container.cont_diff( - container_0, container_1, mode="same_only" + assert cont_depth3.cont_max_depth == 3 + cont_depth4 = Container( + { + "a": ivy.array([1], device=on_device), + "b": {"c": {"d": {"e": ivy.array([2], device=on_device)}}}, + } ) - assert container_diff_same_only.cont_to_dict() == container_diff.cont_to_dict() + assert cont_depth4.cont_max_depth == 4 - # all different strings - container_0 = Container({"a": "1", "b": {"c": "2", "d": "3"}}) - container_1 = Container({"a": "4", "b": {"c": "5", "d": "6"}}) - container_diff = ivy.Container.cont_diff(container_0, container_1) - assert container_diff.a.diff_0 == "1" - assert container_diff.a.diff_1 == "4" - assert container_diff.b.c.diff_0 == "2" - assert container_diff.b.c.diff_1 == "5" - assert container_diff.b.d.diff_0 == "3" - assert container_diff.b.d.diff_1 == "6" - container_diff_diff_only = ivy.Container.cont_diff( - container_0, container_1, mode="diff_only" - ) - assert container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() - container_diff_same_only = ivy.Container.cont_diff( - container_0, container_1, mode="same_only" - ) - assert container_diff_same_only.cont_to_dict() == {} + +def test_container_dev_str(on_device): + dict_in = { + "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), + "b": { + "c": ivy.array([[[2.0], [4.0], [6.0]]], device=on_device), + "d": ivy.array([[[3.0], [6.0], [9.0]]], device=on_device), + }, + } + container = Container(dict_in) + assert container.cont_dev_str == on_device -def test_container_structural_diff(on_device): - # all different keys or shapes +def test_container_diff(on_device): + # all different arrays container_0 = Container( { "a": ivy.array([1], device=on_device), @@ -363,30 +665,30 @@ def test_container_structural_diff(on_device): ) container_1 = Container( { - "a": ivy.array([[4]], device=on_device), + "a": ivy.array([4], device=on_device), "b": { - "c": ivy.array([[[5]]], device=on_device), - "e": ivy.array([3], device=on_device), + "c": ivy.array([5], device=on_device), + "d": ivy.array([6], device=on_device), }, } ) - container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + container_diff = ivy.Container.cont_diff(container_0, container_1) assert np.equal(ivy.to_numpy(container_diff.a.diff_0), np.array([1])) - assert np.equal(ivy.to_numpy(container_diff.a.diff_1), np.array([[4]])) + assert np.equal(ivy.to_numpy(container_diff.a.diff_1), np.array([4])) assert np.equal(ivy.to_numpy(container_diff.b.c.diff_0), np.array([2])) - assert np.equal(ivy.to_numpy(container_diff.b.c.diff_1), np.array([[[5]]])) + assert np.equal(ivy.to_numpy(container_diff.b.c.diff_1), np.array([5])) assert np.equal(ivy.to_numpy(container_diff.b.d.diff_0), np.array([3])) - assert np.equal(ivy.to_numpy(container_diff.b.e.diff_1), np.array([3])) - container_diff_diff_only = ivy.Container.cont_structural_diff( + assert np.equal(ivy.to_numpy(container_diff.b.d.diff_1), np.array([6])) + container_diff_diff_only = ivy.Container.cont_diff( container_0, container_1, mode="diff_only" ) assert container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() - container_diff_same_only = ivy.Container.cont_structural_diff( + container_diff_same_only = ivy.Container.cont_diff( container_0, container_1, mode="same_only" ) assert container_diff_same_only.cont_to_dict() == {} - # some different shapes + # some different arrays container_0 = Container( { "a": ivy.array([1], 
device=on_device), @@ -398,26 +700,26 @@ def test_container_structural_diff(on_device): ) container_1 = Container( { - "a": ivy.array([4], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([[5]], device=on_device), - "d": ivy.array([6], device=on_device), + "c": ivy.array([5], device=on_device), + "d": ivy.array([3], device=on_device), }, } ) - container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + container_diff = ivy.Container.cont_diff(container_0, container_1) assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) assert np.equal(ivy.to_numpy(container_diff.b.c.diff_0), np.array([2])) assert np.equal(ivy.to_numpy(container_diff.b.c.diff_1), np.array([5])) assert np.equal(ivy.to_numpy(container_diff.b.d), np.array([3])) - container_diff_diff_only = ivy.Container.cont_structural_diff( + container_diff_diff_only = ivy.Container.cont_diff( container_0, container_1, mode="diff_only" ) assert "a" not in container_diff_diff_only assert "b" in container_diff_diff_only assert "c" in container_diff_diff_only["b"] assert "d" not in container_diff_diff_only["b"] - container_diff_same_only = ivy.Container.cont_structural_diff( + container_diff_same_only = ivy.Container.cont_diff( container_0, container_1, mode="same_only" ) assert "a" in container_diff_same_only @@ -437,25 +739,25 @@ def test_container_structural_diff(on_device): ) container_1 = Container( { - "e": ivy.array([4], device=on_device), + "e": ivy.array([1], device=on_device), "f": { - "g": ivy.array([5], device=on_device), - "h": ivy.array([6], device=on_device), + "g": ivy.array([2], device=on_device), + "h": ivy.array([3], device=on_device), }, } ) - container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + container_diff = ivy.Container.cont_diff(container_0, container_1) assert np.equal(ivy.to_numpy(container_diff.a.diff_0), np.array([1])) assert np.equal(ivy.to_numpy(container_diff.b.diff_0.c), np.array([2])) assert np.equal(ivy.to_numpy(container_diff.b.diff_0.d), np.array([3])) - assert np.equal(ivy.to_numpy(container_diff.e.diff_1), np.array([4])) - assert np.equal(ivy.to_numpy(container_diff.f.diff_1.g), np.array([5])) - assert np.equal(ivy.to_numpy(container_diff.f.diff_1.h), np.array([6])) - container_diff_diff_only = ivy.Container.cont_structural_diff( + assert np.equal(ivy.to_numpy(container_diff.e.diff_1), np.array([1])) + assert np.equal(ivy.to_numpy(container_diff.f.diff_1.g), np.array([2])) + assert np.equal(ivy.to_numpy(container_diff.f.diff_1.h), np.array([3])) + container_diff_diff_only = ivy.Container.cont_diff( container_0, container_1, mode="diff_only" ) assert container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() - container_diff_same_only = ivy.Container.cont_structural_diff( + container_diff_same_only = ivy.Container.cont_diff( container_0, container_1, mode="same_only" ) assert container_diff_same_only.cont_to_dict() == {} @@ -472,19 +774,19 @@ def test_container_structural_diff(on_device): ) container_1 = Container( { - "a": ivy.array([4], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([5], device=on_device), - "e": ivy.array([6], device=on_device), + "c": ivy.array([2], device=on_device), + "e": ivy.array([3], device=on_device), }, } ) - container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + container_diff = ivy.Container.cont_diff(container_0, container_1) assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) assert 
np.equal(ivy.to_numpy(container_diff.b.c), np.array([2])) assert np.equal(ivy.to_numpy(container_diff.b.d.diff_0), np.array([3])) - assert np.equal(ivy.to_numpy(container_diff.b.e.diff_1), np.array([6])) - container_diff_diff_only = ivy.Container.cont_structural_diff( + assert np.equal(ivy.to_numpy(container_diff.b.e.diff_1), np.array([3])) + container_diff_diff_only = ivy.Container.cont_diff( container_0, container_1, mode="diff_only" ) assert "a" not in container_diff_diff_only @@ -492,7 +794,7 @@ def test_container_structural_diff(on_device): assert "c" not in container_diff_diff_only["b"] assert "d" in container_diff_diff_only["b"] assert "e" in container_diff_diff_only["b"] - container_diff_same_only = ivy.Container.cont_structural_diff( + container_diff_same_only = ivy.Container.cont_diff( container_0, container_1, mode="same_only" ) assert "a" in container_diff_same_only @@ -501,7 +803,7 @@ def test_container_structural_diff(on_device): assert "d" not in container_diff_same_only["b"] assert "e" not in container_diff_same_only["b"] - # all same + # same containers container_0 = Container( { "a": ivy.array([1], device=on_device), @@ -513,28 +815,92 @@ def test_container_structural_diff(on_device): ) container_1 = Container( { - "a": ivy.array([4], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([5], device=on_device), - "d": ivy.array([6], device=on_device), + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, } ) - container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + container_diff = ivy.Container.cont_diff(container_0, container_1) assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) assert np.equal(ivy.to_numpy(container_diff.b.c), np.array([2])) assert np.equal(ivy.to_numpy(container_diff.b.d), np.array([3])) - container_diff_diff_only = ivy.Container.cont_structural_diff( + container_diff_diff_only = ivy.Container.cont_diff( container_0, container_1, mode="diff_only" ) assert container_diff_diff_only.cont_to_dict() == {} - container_diff_same_only = ivy.Container.cont_structural_diff( + container_diff_same_only = ivy.Container.cont_diff( container_0, container_1, mode="same_only" ) assert container_diff_same_only.cont_to_dict() == container_diff.cont_to_dict() + # all different strings + container_0 = Container({"a": "1", "b": {"c": "2", "d": "3"}}) + container_1 = Container({"a": "4", "b": {"c": "5", "d": "6"}}) + container_diff = ivy.Container.cont_diff(container_0, container_1) + assert container_diff.a.diff_0 == "1" + assert container_diff.a.diff_1 == "4" + assert container_diff.b.c.diff_0 == "2" + assert container_diff.b.c.diff_1 == "5" + assert container_diff.b.d.diff_0 == "3" + assert container_diff.b.d.diff_1 == "6" + container_diff_diff_only = ivy.Container.cont_diff( + container_0, container_1, mode="diff_only" + ) + assert container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() + container_diff_same_only = ivy.Container.cont_diff( + container_0, container_1, mode="same_only" + ) + assert container_diff_same_only.cont_to_dict() == {} + -def test_container_from_dict(on_device): +def test_container_duplicate_array_keychains(on_device): + arr1 = ivy.array([1], device=on_device) + arr2 = ivy.array([2], device=on_device) + container0 = Container({"a": arr1, "b": {"c": arr1, "d": arr2}}) + container1 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([1], device=on_device), + "d": ivy.array([2], device=on_device), + }, + } 
+ ) + res = ivy.Container.cont_duplicate_array_keychains(container0) + assert res == (("a", "b/c"),) + res = ivy.Container.cont_duplicate_array_keychains(container1) + assert res == () + + +def test_container_find_sub_container(on_device): + arr1 = ivy.array([1], device=on_device) + arr2 = ivy.array([2], device=on_device) + arr3 = ivy.array([3], device=on_device) + dict_in = {"a": arr1, "b": {"c": arr2, "d": arr3}} + top_cont = Container(dict_in) + + # full + sub_cont = Container(dict_in["b"]) + assert sub_cont in top_cont + found_kc = top_cont.cont_find_sub_container(sub_cont) + assert found_kc == "b" + found_kc = top_cont.cont_find_sub_container(top_cont) + assert found_kc == "" + + # partial + partial_sub_cont = Container({"d": arr3}) + found_kc = top_cont.cont_find_sub_container(partial_sub_cont, partial=True) + assert found_kc == "b" + assert partial_sub_cont.cont_find_sub_container(top_cont, partial=True) is False + partial_sub_cont = Container({"b": {"d": arr3}}) + found_kc = top_cont.cont_find_sub_container(partial_sub_cont, partial=True) + assert found_kc == "" + assert partial_sub_cont.cont_find_sub_container(top_cont, partial=True) is False + + +def test_container_find_sub_structure(on_device): dict_in = { "a": ivy.array([1], device=on_device), "b": { @@ -542,279 +908,96 @@ def test_container_from_dict(on_device): "d": ivy.array([3], device=on_device), }, } - container = Container(dict_in) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container.a), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(container.b.c), np.array([2])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) - assert np.allclose(ivy.to_numpy(container.b.d), np.array([3])) - + top_cont = Container(dict_in) -def test_container_depth(on_device): - cont_depth1 = Container( - {"a": ivy.array([1], device=on_device), "b": ivy.array([2], device=on_device)} - ) - assert cont_depth1.cont_max_depth == 1 - cont_depth2 = Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } + # full + sub_cont = Container( + {"c": ivy.array([4], device=on_device), "d": ivy.array([5], device=on_device)} ) - assert cont_depth2.cont_max_depth == 2 - cont_depth3 = Container( + assert not top_cont.cont_find_sub_container(sub_cont) + found_kc = top_cont.cont_find_sub_structure(sub_cont) + assert found_kc == "b" + found_kc = top_cont.cont_find_sub_structure(top_cont) + assert found_kc == "" + + # partial + partial_sub_cont = Container({"d": ivy.array([5], device=on_device)}) + found_kc = top_cont.cont_find_sub_structure(partial_sub_cont, partial=True) + assert found_kc == "b" + partial_sub_cont = Container({"b": {"d": ivy.array([5], device=on_device)}}) + found_kc = top_cont.cont_find_sub_structure(partial_sub_cont, partial=True) + assert found_kc == "" + + +def test_container_flatten_key_chains(on_device): + container = Container( { "a": ivy.array([1], device=on_device), "b": { "c": {"d": ivy.array([2], device=on_device)}, - "e": ivy.array([3], device=on_device), + "e": {"f": {"g": ivy.array([3], device=on_device)}}, }, } ) - assert cont_depth3.cont_max_depth == 3 - cont_depth4 = Container( - { - "a": ivy.array([1], device=on_device), - "b": {"c": {"d": {"e": ivy.array([2], device=on_device)}}}, - } - ) - assert cont_depth4.cont_max_depth == 4 + # full + container_flat = 
container.cont_flatten_key_chains() + assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat["b__c__d"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat.b__c__d), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat["b__e__f__g"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_flat.b__e__f__g), np.array([[3]])) -@pytest.mark.parametrize("inplace", [True, False]) -def test_container_cutoff_at_depth(inplace, on_device): - # values - a_val = ivy.array([1], device=on_device) - bcde_val = ivy.array([2], device=on_device) - - # depth 1 - cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) - cont_cutoff = cont.cont_cutoff_at_depth(1, inplace=inplace) - if inplace: - cont_cutoff = cont - assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) - assert not cont_cutoff.b + # above height 1 + container_flat = container.cont_flatten_key_chains(above_height=1) + assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat["b__c"]["d"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat.b__c.d), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat["b__e__f"]["g"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_flat.b__e__f.g), np.array([[3]])) - # depth 2 - cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) - cont_cutoff = cont.cont_cutoff_at_depth(2, inplace=inplace) - if inplace: - cont_cutoff = cont - assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) - assert not cont_cutoff.b.c - - # depth 3 - cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) - cont_cutoff = cont.cont_cutoff_at_depth(3, inplace=inplace) - if inplace: - cont_cutoff = cont - assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) - assert not cont_cutoff.b.c.d - - # depth 4 - cont = Container({"a": a_val, "b": {"c": {"d": {"e": bcde_val}}}}) - cont_cutoff = cont.cont_cutoff_at_depth(4, inplace=inplace) - if inplace: - cont_cutoff = cont - assert np.allclose(ivy.to_numpy(cont_cutoff.a), ivy.to_numpy(a_val)) - assert np.allclose(ivy.to_numpy(cont_cutoff.b.c.d.e), ivy.to_numpy(bcde_val)) - - -@pytest.mark.parametrize("inplace", [True, False]) -def test_container_cutoff_at_height(inplace, on_device): - # values - d_val = ivy.array([2], device=on_device) - e_val = ivy.array([3], device=on_device) - - # height 0 - cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) - cont_cutoff = cont.cont_cutoff_at_height(0, inplace=inplace) - if inplace: - cont_cutoff = cont - assert np.allclose(ivy.to_numpy(cont_cutoff.a.c.d), ivy.to_numpy(d_val)) - assert np.allclose(ivy.to_numpy(cont_cutoff.b.c.d.e), ivy.to_numpy(e_val)) - - # height 1 - cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) - cont_cutoff = cont.cont_cutoff_at_height(1, inplace=inplace) - if inplace: - cont_cutoff = cont - assert not cont_cutoff.a.c - assert not cont_cutoff.b.c.d - - # height 2 - cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) - cont_cutoff = cont.cont_cutoff_at_height(2, inplace=inplace) - if inplace: - cont_cutoff = cont - assert not cont_cutoff.a - assert not cont_cutoff.b.c - - # height 3 - cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": 
{"d": {"e": e_val}}}}) - cont_cutoff = cont.cont_cutoff_at_height(3, inplace=inplace) - if inplace: - cont_cutoff = cont - assert not cont_cutoff.a - assert not cont_cutoff.b - - # height 4 - cont = Container({"a": {"c": {"d": d_val}}, "b": {"c": {"d": {"e": e_val}}}}) - cont_cutoff = cont.cont_cutoff_at_height(4, inplace=inplace) - if inplace: - cont_cutoff = cont - assert not cont_cutoff - - -@pytest.mark.parametrize("str_slice", [True, False]) -def test_container_slice_keys(str_slice, on_device): - # values - a_val = ivy.array([1], device=on_device) - b_val = ivy.array([2], device=on_device) - c_val = ivy.array([3], device=on_device) - d_val = ivy.array([4], device=on_device) - e_val = ivy.array([5], device=on_device) - - # slice - if str_slice: - slc = "b:d" - else: - slc = slice(1, 4, 1) - - # without dict - cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) - cont_sliced = cont.cont_slice_keys(slc) - assert "a" not in cont_sliced - assert np.allclose(ivy.to_numpy(cont_sliced.b), ivy.to_numpy(b_val)) - assert np.allclose(ivy.to_numpy(cont_sliced.c), ivy.to_numpy(c_val)) - assert np.allclose(ivy.to_numpy(cont_sliced.d), ivy.to_numpy(d_val)) - assert "e" not in cont_sliced - - # with dict, depth 0 - sub_cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) - cont = Container( - {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} - ) - cont_sliced = cont.cont_slice_keys({0: slc}) - assert "a" not in cont_sliced - assert Container.cont_identical([cont_sliced.b, sub_cont]) - assert Container.cont_identical([cont_sliced.c, sub_cont]) - assert Container.cont_identical([cont_sliced.d, sub_cont]) - assert "e" not in cont_sliced - - # with dict, depth 1 - sub_cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) - sub_sub_cont = Container({"b": b_val, "c": c_val, "d": d_val}) - cont = Container( - {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} - ) - cont_sliced = cont.cont_slice_keys({1: slc}) - assert Container.cont_identical([cont_sliced.a, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.b, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.c, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.d, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.e, sub_sub_cont]) - - # with dict, depth 0, 1 - sub_cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) - sub_sub_cont = Container({"b": b_val, "c": c_val, "d": d_val}) - cont = Container( - {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} - ) - cont_sliced = cont.cont_slice_keys({0: slc, 1: slc}) - assert "a" not in cont_sliced - assert Container.cont_identical([cont_sliced.b, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.c, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.d, sub_sub_cont]) - assert "e" not in cont_sliced + # below depth 1 + container_flat = container.cont_flatten_key_chains(below_depth=1) + assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat["b"]["c__d"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat.b.c__d), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat["b"]["e__f__g"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_flat.b.e__f__g), np.array([[3]])) - # all depths - sub_cont = 
Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) - sub_sub_cont = Container({"b": b_val, "c": c_val, "d": d_val}) - cont = Container( - {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} - ) - cont_sliced = cont.cont_slice_keys(slc, all_depths=True) - assert "a" not in cont_sliced - assert Container.cont_identical([cont_sliced.b, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.c, sub_sub_cont]) - assert Container.cont_identical([cont_sliced.d, sub_sub_cont]) - assert "e" not in cont_sliced + # above height 1, below depth 1 + container_flat = container.cont_flatten_key_chains(above_height=1, below_depth=1) + assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_flat["b"]["c"]["d"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat.b.c.d), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_flat["b"]["e__f"]["g"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_flat.b.e__f.g), np.array([[3]])) -def test_container_show(on_device): +def test_container_format_key_chains(on_device): dict_in = { - "a": ivy.array([1], device=on_device), - "b": { + "_a": ivy.array([1], device=on_device), + "b ": { "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "d-": ivy.array([3], device=on_device), }, } cont = Container(dict_in) - print(cont) - cont.cont_show() - - -def test_container_find_sub_container(on_device): - arr1 = ivy.array([1], device=on_device) - arr2 = ivy.array([2], device=on_device) - arr3 = ivy.array([3], device=on_device) - dict_in = {"a": arr1, "b": {"c": arr2, "d": arr3}} - top_cont = Container(dict_in) - - # full - sub_cont = Container(dict_in["b"]) - assert sub_cont in top_cont - found_kc = top_cont.cont_find_sub_container(sub_cont) - assert found_kc == "b" - found_kc = top_cont.cont_find_sub_container(top_cont) - assert found_kc == "" - - # partial - partial_sub_cont = Container({"d": arr3}) - found_kc = top_cont.cont_find_sub_container(partial_sub_cont, partial=True) - assert found_kc == "b" - assert partial_sub_cont.cont_find_sub_container(top_cont, partial=True) is False - partial_sub_cont = Container({"b": {"d": arr3}}) - found_kc = top_cont.cont_find_sub_container(partial_sub_cont, partial=True) - assert found_kc == "" - assert partial_sub_cont.cont_find_sub_container(top_cont, partial=True) is False - - -def test_container_find_sub_structure(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - top_cont = Container(dict_in) - - # full - sub_cont = Container( - {"c": ivy.array([4], device=on_device), "d": ivy.array([5], device=on_device)} + cont_formatted = cont.cont_format_key_chains( + lambda s: s.replace("_", "").replace(" ", "").replace("-", "") ) - assert not top_cont.cont_find_sub_container(sub_cont) - found_kc = top_cont.cont_find_sub_structure(sub_cont) - assert found_kc == "b" - found_kc = top_cont.cont_find_sub_structure(top_cont) - assert found_kc == "" - - # partial - partial_sub_cont = Container({"d": ivy.array([5], device=on_device)}) - found_kc = top_cont.cont_find_sub_structure(partial_sub_cont, partial=True) - assert found_kc == "b" - partial_sub_cont = Container({"b": {"d": ivy.array([5], device=on_device)}}) - found_kc = top_cont.cont_find_sub_structure(partial_sub_cont, partial=True) - assert 
found_kc == "" + assert np.allclose(ivy.to_numpy(cont_formatted["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(cont_formatted.a), np.array([1])) + assert np.allclose(ivy.to_numpy(cont_formatted["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(cont_formatted.b.c), np.array([2])) + assert np.allclose(ivy.to_numpy(cont_formatted["b"]["d"]), np.array([3])) + assert np.allclose(ivy.to_numpy(cont_formatted.b.d), np.array([3])) -def test_container_show_sub_container(on_device): +def test_container_from_dict(on_device): dict_in = { "a": ivy.array([1], device=on_device), "b": { @@ -822,10 +1005,13 @@ def test_container_show_sub_container(on_device): "d": ivy.array([3], device=on_device), }, } - top_cont = Container(dict_in) - sub_cont = Container(dict_in["b"]) - top_cont.cont_show_sub_container("b") - top_cont.cont_show_sub_container(sub_cont) + container = Container(dict_in) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container.a), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(container.b.c), np.array([2])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) + assert np.allclose(ivy.to_numpy(container.b.d), np.array([3])) def test_container_from_dict_w_cont_types(on_device): @@ -852,17 +1038,36 @@ def test_container_from_dict_w_cont_types(on_device): assert np.allclose(ivy.to_numpy(container.b.d), np.array([3])) -def test_container_from_kwargs(on_device): - container = Container( - a=ivy.array([1], device=on_device), - b={ +def test_container_from_flat_list(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { "c": ivy.array([2], device=on_device), "d": ivy.array([3], device=on_device), }, - ) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container.a), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) + } + container = Container(dict_in) + flat_list = [4, 5, 6] + container = container.cont_from_flat_list(flat_list) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([4])) + assert np.allclose(ivy.to_numpy(container.a), np.array([4])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([5])) + assert np.allclose(ivy.to_numpy(container.b.c), np.array([5])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([6])) + assert np.allclose(ivy.to_numpy(container.b.d), np.array([6])) + + +def test_container_from_kwargs(on_device): + container = Container( + a=ivy.array([1], device=on_device), + b={ + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + ) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container.a), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) assert np.allclose(ivy.to_numpy(container.b.c), np.array([2])) assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) assert np.allclose(ivy.to_numpy(container.b.d), np.array([3])) @@ -882,6 +1087,110 @@ def test_container_from_list(on_device): assert np.allclose(ivy.to_numpy(container.it_1.it_1), np.array([3])) +@pytest.mark.skip("Prevents PyTest from Terminating.") +def test_container_from_queues(on_device): + if "gpu" in on_device: + # Cannot re-initialize CUDA in forked subprocess. 'spawn' + # start method must be used. 
+ pytest.skip() + + if ivy.gpu_is_available() and ivy.current_backend_str() == "jax": + # Not found a way to set default on_device for JAX, and this causes + # issues with multiprocessing and CUDA, even when device=cpu + # ToDo: find a fix for this problem ^^ + pytest.skip() + + def worker_fn(in_queue, out_queue, load_size, worker_id): + keep_going = True + while keep_going: + try: + keep_going = in_queue.get(timeout=0.1) + except queue.Empty: + continue + out_queue.put( + { + "a": [ + ivy.to_native(ivy.array([1.0, 2.0, 3.0], device=on_device)) + * worker_id + ] * load_size + } + ) + + workers = list() + in_queues = list() + out_queues = list() + queue_load_sizes = [1, 2, 1] + for i, queue_load_size in enumerate(queue_load_sizes): + input_queue = multiprocessing.Queue() + output_queue = multiprocessing.Queue() + worker = multiprocessing.Process( + target=worker_fn, args=(input_queue, output_queue, queue_load_size, i + 1) + ) + worker.start() + in_queues.append(input_queue) + out_queues.append(output_queue) + workers.append(worker) + + container = Container( + queues=out_queues, queue_load_sizes=queue_load_sizes, queue_timeout=0.25 + ) + + # queue 0 + queue_was_empty = False + try: + container[0] + except queue.Empty: + queue_was_empty = True + assert queue_was_empty + in_queues[0].put(True) + assert np.allclose(ivy.to_numpy(container[0].a), np.array([1.0, 2.0, 3.0])) + assert np.allclose(ivy.to_numpy(container[0].a), np.array([1.0, 2.0, 3.0])) + + # queue 1 + queue_was_empty = False + try: + container[1] + except queue.Empty: + queue_was_empty = True + assert queue_was_empty + queue_was_empty = False + try: + container[2] + except queue.Empty: + queue_was_empty = True + assert queue_was_empty + in_queues[1].put(True) + assert np.allclose(ivy.to_numpy(container[1].a), np.array([2.0, 4.0, 6.0])) + assert np.allclose(ivy.to_numpy(container[1].a), np.array([2.0, 4.0, 6.0])) + assert np.allclose(ivy.to_numpy(container[2].a), np.array([2.0, 4.0, 6.0])) + assert np.allclose(ivy.to_numpy(container[2].a), np.array([2.0, 4.0, 6.0])) + + # queue 2 + queue_was_empty = False + try: + container[3] + except queue.Empty: + queue_was_empty = True + assert queue_was_empty + in_queues[2].put(True) + assert np.allclose(ivy.to_numpy(container[3].a), np.array([3.0, 6.0, 9.0])) + assert np.allclose(ivy.to_numpy(container[3].a), np.array([3.0, 6.0, 9.0])) + + # stop workers + in_queues[0].put(False) + in_queues[1].put(False) + in_queues[2].put(False) + in_queues[0].close() + in_queues[1].close() + in_queues[2].close() + + # join workers + for worker in workers: + worker.join() + + del container + + def test_container_from_tuple(on_device): tuple_in = ( ivy.array([1], device=on_device), @@ -896,381 +1205,137 @@ def test_container_from_tuple(on_device): assert np.allclose(ivy.to_numpy(container.it_1.it_1), np.array([3])) -def test_container_to_raw(on_device): - tuple_in = ( - ivy.array([1], device=on_device), - (ivy.array([2], device=on_device), ivy.array([3], device=on_device)), - ) - container = Container(tuple_in, types_to_iteratively_nest=[tuple]) - raw = container.cont_to_raw() - assert np.allclose(ivy.to_numpy(raw[0]), np.array([1])) - assert np.allclose(ivy.to_numpy(raw[1][0]), np.array([2])) - assert np.allclose(ivy.to_numpy(raw[1][1]), np.array([3])) - - -def test_container_as_bools(on_device): - dict_in = {"a": ivy.array([1], device=on_device), "b": {"c": [], "d": True}} - container = Container(dict_in) - - container_bools = container.cont_as_bools() - assert container_bools["a"] is True - assert 
container_bools.a is True - assert container_bools["b"]["c"] is False - assert container_bools.b.c is False - assert container_bools["b"]["d"] is True - assert container_bools.b.d is True - - -def test_container_all_true(on_device): - assert not Container( - {"a": ivy.array([1], device=on_device), "b": {"c": [], "d": True}} - ).cont_all_true() - assert Container( - {"a": ivy.array([1], device=on_device), "b": {"c": [1], "d": True}} - ).cont_all_true() - # noinspection PyBroadException - try: - assert Container( - {"a": ivy.array([1], device=on_device), "b": {"c": [1], "d": True}} - ).cont_all_true(assert_is_bool=True) - error_raised = False - except IvyException: - error_raised = True - assert error_raised - - -def test_container_all_false(on_device): - assert Container({"a": False, "b": {"c": [], "d": 0}}).cont_all_false() - assert not Container({"a": False, "b": {"c": [1], "d": 0}}).cont_all_false() - # noinspection PyBroadException - try: - assert Container( - {"a": ivy.array([1], device=on_device), "b": {"c": [1], "d": True}} - ).cont_all_false(assert_is_bool=True) - error_raised = False - except IvyException: - error_raised = True - assert error_raised - - -def test_container_unstack_conts(on_device): +def test_container_has_key(on_device): dict_in = { - "a": ivy.array([[1], [2], [3]], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([[2], [3], [4]], device=on_device), - "d": ivy.array([[3], [4], [5]], device=on_device), + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, } container = Container(dict_in) - - # without key_chains specification - container_unstacked = container.cont_unstack_conts(0) - for cont, a, bc, bd in zip(container_unstacked, [1, 2, 3], [2, 3, 4], [3, 4, 5]): - assert np.array_equal(ivy.to_numpy(cont["a"]), np.array([a])) - assert np.array_equal(ivy.to_numpy(cont.a), np.array([a])) - assert np.array_equal(ivy.to_numpy(cont["b"]["c"]), np.array([bc])) - assert np.array_equal(ivy.to_numpy(cont.b.c), np.array([bc])) - assert np.array_equal(ivy.to_numpy(cont["b"]["d"]), np.array([bd])) - assert np.array_equal(ivy.to_numpy(cont.b.d), np.array([bd])) + assert container.cont_has_key("a") # noqa + assert container.cont_has_key("b") # noqa + assert container.cont_has_key("c") # noqa + assert container.cont_has_key("d") # noqa + assert not container.cont_has_key("e") # noqa + assert not container.cont_has_key("f") # noqa -def test_container_split_conts(on_device): +def test_container_has_key_chain(on_device): dict_in = { - "a": ivy.array([[1], [2], [3]], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([[2], [3], [4]], device=on_device), - "d": ivy.array([[3], [4], [5]], device=on_device), + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, } container = Container(dict_in) + assert container.cont_has_key_chain("a") + assert container.cont_has_key_chain("b") + assert container.cont_has_key_chain("b/c") + assert container.cont_has_key_chain("b/d") + assert not container.cont_has_key_chain("b/e") + assert not container.cont_has_key_chain("c") + +def test_container_identical(on_device): # without key_chains specification - container_split = container.split_conts(1, -1) - for cont, a, bc, bd in zip(container_split, [1, 2, 3], [2, 3, 4], [3, 4, 5]): - assert np.array_equal(ivy.to_numpy(cont["a"])[0], np.array([a])) - assert np.array_equal(ivy.to_numpy(cont.a)[0], np.array([a])) - assert np.array_equal(ivy.to_numpy(cont["b"]["c"])[0], 
np.array([bc])) - assert np.array_equal(ivy.to_numpy(cont.b.c)[0], np.array([bc])) - assert np.array_equal(ivy.to_numpy(cont["b"]["d"])[0], np.array([bd])) - assert np.array_equal(ivy.to_numpy(cont.b.d)[0], np.array([bd])) + arr1 = ivy.array([1], device=on_device) + arr2 = ivy.array([2], device=on_device) + arr3 = ivy.array([3], device=on_device) + container0 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) + container1 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) + container2 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + ) + container3 = Container({"b": {"d": arr3}}) + container4 = Container({"d": arr3}) + # the same + assert ivy.Container.cont_identical([container0, container1]) + assert ivy.Container.cont_identical([container1, container0]) -def test_container_num_arrays(on_device): - dict_in = { - "a": ivy.array([[0.0, 1.0, 2.0, 3.0]], device=on_device), - "b": { - "c": ivy.array([[5.0, 10.0, 15.0, 20.0]], device=on_device), - "d": ivy.array([[10.0, 9.0, 8.0, 7.0]], device=on_device), - }, - } - container = Container(dict_in) - assert container.cont_num_arrays() == 3 - dict_in = { - "a": ivy.array([[0.0, 1.0, 2.0, 3.0]], device=on_device), - "b": { - "c": _variable(ivy.array([[5.0, 10.0, 15.0, 20.0]], device=on_device)), - "d": ivy.array([[10.0, 9.0, 8.0, 7.0]], device=on_device), - }, - } - container = Container(dict_in) - assert ( - container.cont_num_arrays() == 3 - if ivy.current_backend_str() in ("numpy", "jax") - else 2 - ) - - -def test_container_size_ordered_arrays(on_device): - dict_in = { - "a": ivy.array([[0.0, 1.0, 2.0, 3.0]], device=on_device), - "b": { - "c": ivy.array([[5.0, 10.0]], device=on_device), - "d": ivy.array([[10.0, 9.0, 8.0]], device=on_device), - }, - } - container = Container(dict_in) - size_ordered = container.cont_size_ordered_arrays() - assert np.allclose(ivy.to_numpy(size_ordered.a), np.array([[0.0, 1.0, 2.0, 3.0]])) - assert np.allclose(ivy.to_numpy(size_ordered.b__c), np.array([[5.0, 10.0]])) - assert np.allclose(ivy.to_numpy(size_ordered.b__d), np.array([[10.0, 9.0, 8.0]])) - for v, arr in zip( - size_ordered.values(), - [ - np.array([[5.0, 10.0]]), - np.array([[10.0, 9.0, 8.0]]), - np.array([[0.0, 1.0, 2.0, 3.0]]), - ], - ): - assert np.allclose(ivy.to_numpy(v), arr) - - -def test_container_has_key(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - assert container.cont_has_key("a") # noqa - assert container.cont_has_key("b") # noqa - assert container.cont_has_key("c") # noqa - assert container.cont_has_key("d") # noqa - assert not container.cont_has_key("e") # noqa - assert not container.cont_has_key("f") # noqa - - -def test_container_has_key_chain(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - assert container.cont_has_key_chain("a") - assert container.cont_has_key_chain("b") - assert container.cont_has_key_chain("b/c") - assert container.cont_has_key_chain("b/d") - assert not container.cont_has_key_chain("b/e") - assert not container.cont_has_key_chain("c") - - -def test_container_at_keys(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": 
ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - new_container = container.cont_at_keys(["a", "c"]) - assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert "d" not in new_container["b"] - new_container = container.cont_at_keys("c") - assert "a" not in new_container - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert "d" not in new_container["b"] - new_container = container.cont_at_keys(["b"]) - assert "a" not in new_container - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) - - -def test_container_at_key_chain(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - - # explicit function call - sub_container = container.cont_at_key_chain("b") - assert np.allclose(ivy.to_numpy(sub_container["c"]), np.array([2])) - sub_container = container.cont_at_key_chain("b/c") - assert np.allclose(ivy.to_numpy(sub_container), np.array([2])) - - # overridden built-in function call - sub_container = container["b"] - assert np.allclose(ivy.to_numpy(sub_container["c"]), np.array([2])) - sub_container = container["b/c"] - assert np.allclose(ivy.to_numpy(sub_container), np.array([2])) - - -def test_container_at_key_chains(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - target_cont = Container({"a": True, "b": {"c": True}}) - new_container = container.cont_at_key_chains(target_cont) - assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert "d" not in new_container["b"] - new_container = container.cont_at_key_chains(["b/c", "b/d"]) - assert "a" not in new_container - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) - new_container = container.cont_at_key_chains("b/c") - assert "a" not in new_container - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert "d" not in new_container["b"] - - -@pytest.mark.parametrize("include_empty", [True, False]) -def test_container_all_key_chains(include_empty, on_device): - a_val = Container() if include_empty else ivy.array([1], device=on_device) - bc_val = Container() if include_empty else ivy.array([2], device=on_device) - bd_val = Container() if include_empty else ivy.array([3], device=on_device) - dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} - container = Container(dict_in) - kcs = container.cont_all_key_chains(include_empty) - assert kcs[0] == "a" - assert kcs[1] == "b/c" - assert kcs[2] == "b/d" - - -@pytest.mark.parametrize("include_empty", [True, False]) -def test_container_key_chains_containing(include_empty, on_device): - a_val = Container() if include_empty else ivy.array([1], device=on_device) - bc_val = Container() if include_empty else ivy.array([2], device=on_device) - bd_val = Container() if include_empty else ivy.array([3], device=on_device) - dict_in = {"a_sub": a_val, "b": {"c": bc_val, "d_sub": bd_val}} - container = Container(dict_in) - kcs = 
container.cont_key_chains_containing("sub", include_empty) - assert kcs[0] == "a_sub" - assert kcs[1] == "b/d_sub" + # not the same + assert not ivy.Container.cont_identical([container0, container2]) + assert not ivy.Container.cont_identical([container2, container0]) + assert not ivy.Container.cont_identical([container1, container2]) + assert not ivy.Container.cont_identical([container2, container1]) + # partial + assert ivy.Container.cont_identical([container0, container3], partial=True) + assert ivy.Container.cont_identical([container3, container0], partial=True) + assert not ivy.Container.cont_identical([container0, container4], partial=True) + assert not ivy.Container.cont_identical([container4, container0], partial=True) -# noinspection PyUnresolvedReferences -def test_container_set_at_keys(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container_orig = Container(dict_in) - # explicit function call - orig_container = container_orig.cont_copy() - container = orig_container.cont_set_at_keys({"b": ivy.array([4], device=on_device)}) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]), np.array([4])) - assert not container.cont_has_key("c") # noqa - assert not container.cont_has_key("d") # noqa - container = orig_container.cont_set_at_keys( - {"a": ivy.array([5], device=on_device), "c": ivy.array([6], device=on_device)} +def test_container_identical_array_shapes(on_device): + # without key_chains specification + container0 = Container( + { + "a": ivy.array([1, 2], device=on_device), + "b": { + "c": ivy.array([2, 3, 4], device=on_device), + "d": ivy.array([3, 4, 5, 6], device=on_device), + }, + } + ) + container1 = Container( + { + "a": ivy.array([1, 2, 3, 4], device=on_device), + "b": { + "c": ivy.array([3, 4], device=on_device), + "d": ivy.array([3, 4, 5], device=on_device), + }, + } + ) + container2 = Container( + { + "a": ivy.array([1, 2, 3, 4], device=on_device), + "b": { + "c": ivy.array([3, 4], device=on_device), + "d": ivy.array([3, 4, 5, 6], device=on_device), + }, + } ) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([5])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([6])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) - - -# noinspection PyUnresolvedReferences -def test_container_set_at_key_chain(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container_orig = Container(dict_in) - # explicit function call - container = container_orig.cont_copy() - container = container.cont_set_at_key_chain("b/e", ivy.array([4], device=on_device)) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) - assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) - container = container.cont_set_at_key_chain("f", ivy.array([5], device=on_device)) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) - assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) - assert np.allclose(ivy.to_numpy(container["f"]), 
np.array([5])) + # with identical + assert ivy.Container.cont_identical_array_shapes([container0, container1]) + assert ivy.Container.cont_identical_array_shapes([container1, container0]) + assert ivy.Container.cont_identical_array_shapes( + [container1, container0, container1] + ) + assert not ivy.Container.cont_identical([container0, container2]) + assert not ivy.Container.cont_identical([container1, container2]) + assert not ivy.Container.cont_identical([container0, container1, container2]) - # overridden built-in function call - container = container_orig.cont_copy() - assert "b/e" not in container - container["b/e"] = ivy.array([4], device=on_device) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) - assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) - assert "f" not in container - container["f"] = ivy.array([5], device=on_device) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) - assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) - assert np.allclose(ivy.to_numpy(container["f"]), np.array([5])) +def test_container_identical_configs(on_device): + container0 = Container({"a": ivy.array([1], device=on_device)}, print_limit=5) + container1 = Container({"a": ivy.array([1], device=on_device)}, print_limit=5) + container2 = Container({"a": ivy.array([1], device=on_device)}, print_limit=10) -# noinspection PyUnresolvedReferences -def test_container_overwrite_at_key_chain(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container_orig = Container(dict_in) + # with identical + assert ivy.Container.cont_identical_configs([container0, container1]) + assert ivy.Container.cont_identical_configs([container1, container0]) + assert ivy.Container.cont_identical_configs([container1, container0, container1]) - # explicit function call - container = container_orig.cont_copy() - # noinspection PyBroadException - try: - container.cont_overwrite_at_key_chain("b/e", ivy.array([4], device=on_device)) - exception_raised = False - except Exception: - exception_raised = True - assert exception_raised - container = container.cont_overwrite_at_key_chain( - "b/d", ivy.array([4], device=on_device) + # without identical + assert not ivy.Container.cont_identical_configs([container1, container2]) + assert not ivy.Container.cont_identical_configs( + [container1, container0, container2] ) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([4])) -def test_container_set_at_key_chains(on_device): - container = Container( +def test_container_identical_structure(on_device): + # without key_chains specification + container0 = Container( { "a": ivy.array([1], device=on_device), "b": { @@ -1279,333 +1344,362 @@ def test_container_set_at_key_chains(on_device): }, } ) - target_container = Container( + container1 = Container( { - "a": ivy.array([4], device=on_device), - "b": {"d": ivy.array([5], device=on_device)}, + "a": ivy.array([3], device=on_device), + "b": { + "c": ivy.array([4], device=on_device), + "d": 
ivy.array([5], device=on_device), + }, } ) - new_container = container.cont_set_at_key_chains(target_container, inplace=False) - assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([4])) - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([5])) - target_container = Container({"b": {"c": ivy.array([7], device=on_device)}}) - new_container = container.cont_set_at_key_chains(target_container, inplace=False) - assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([7])) - assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) - - -def test_container_overwrite_at_key_chains(on_device): - container = Container( + container2 = Container( { - "a": ivy.array([1], device=on_device), + "a": ivy.array([3], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": ivy.array([4], device=on_device), + "d": ivy.array([5], device=on_device), + "e": ivy.array([6], device=on_device), }, } ) - target_container = Container( + container3 = Container( { - "a": ivy.array([4], device=on_device), - "b": {"d": ivy.array([5], device=on_device)}, + "a": ivy.array([3], device=on_device), + "b": { + "c": ivy.array([4], device=on_device), + "d": ivy.array([5], device=on_device), + }, + "e": ivy.array([6], device=on_device), } ) - new_container = container.cont_overwrite_at_key_chains( - target_container, inplace=False - ) - assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([4])) - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([5])) - target_container = Container({"b": {"c": ivy.array([7], device=on_device)}}) - new_container = container.cont_overwrite_at_key_chains( - target_container, inplace=False - ) - assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([7])) - assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) - # noinspection PyBroadException - try: - container.cont_overwrite_at_key_chains( - Container({"b": {"e": ivy.array([5], device=on_device)}}) - ) - exception_raised = False - except Exception: - exception_raised = True - assert exception_raised - - -def test_container_prune_keys(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - container_pruned = container.cont_prune_keys(["a", "c"]) - assert "a" not in container_pruned - assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) - assert "c" not in container_pruned["b"] - - def _test_a_exception(container_in): - try: - _ = container_in.a - return False - except AttributeError: - return True - - def _test_bc_exception(container_in): - try: - _ = container_in.b.c - return False - except AttributeError: - return True - - def _test_bd_exception(container_in): - try: - _ = container_in.b.d - return False - except AttributeError: - return True - - assert _test_a_exception(container_pruned) - assert _test_bc_exception(container_pruned) - - container_pruned = container.cont_prune_keys(["a", "d"]) - assert "a" not in container_pruned - assert 
np.allclose(ivy.to_numpy(container_pruned["b"]["c"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned.b.c), np.array([[2]])) - assert "d" not in container_pruned["b"] - assert _test_a_exception(container_pruned) - assert _test_bd_exception(container_pruned) - - -def test_container_prune_key_chain(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": {"c": ivy.array([2], device=on_device), "d": None}, - } - container = Container(dict_in) - container_pruned = container.cont_prune_key_chain("b/c") - assert np.allclose(ivy.to_numpy(container_pruned["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned.a), np.array([[1]])) - assert container_pruned["b"]["d"] is None - assert container_pruned.b.d is None - assert "c" not in container_pruned["b"].keys() - - def _test_exception(container_in): - try: - _ = container_in.b.c - return False - except AttributeError: - return True - - assert _test_exception(container_pruned) + container4 = Container({"b": {"d": ivy.array([4], device=on_device)}}) + container5 = Container({"d": ivy.array([4], device=on_device)}) - container_pruned = container.cont_prune_key_chain("b") - assert np.allclose(ivy.to_numpy(container_pruned["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned.a), np.array([[1]])) - assert "b" not in container_pruned.keys() + # with identical + assert ivy.Container.cont_identical_structure([container0, container1]) + assert ivy.Container.cont_identical_structure([container1, container0]) + assert ivy.Container.cont_identical_structure([container1, container0, container1]) - def _test_exception(container_in): - try: - _ = container_in.b - return False - except AttributeError: - return True + # without identical + assert not ivy.Container.cont_identical_structure([container2, container3]) + assert not ivy.Container.cont_identical_structure([container0, container3]) + assert not ivy.Container.cont_identical_structure([container1, container2]) + assert not ivy.Container.cont_identical_structure( + [container1, container0, container2] + ) - assert _test_exception(container_pruned) + # partial + assert ivy.Container.cont_identical_structure( + [container0, container4], partial=True + ) + assert ivy.Container.cont_identical_structure( + [container1, container4], partial=True + ) + assert ivy.Container.cont_identical_structure( + [container2, container4], partial=True + ) + assert ivy.Container.cont_identical_structure( + [container3, container4], partial=True + ) + assert ivy.Container.cont_identical_structure( + [container4, container4], partial=True + ) + assert not ivy.Container.cont_identical_structure( + [container0, container5], partial=True + ) + assert not ivy.Container.cont_identical_structure( + [container1, container5], partial=True + ) + assert not ivy.Container.cont_identical_structure( + [container2, container5], partial=True + ) + assert not ivy.Container.cont_identical_structure( + [container3, container5], partial=True + ) + assert not ivy.Container.cont_identical_structure( + [container4, container5], partial=True + ) -def test_container_prune_key_chains(on_device): +def test_container_if_exists(on_device): dict_in = { - "a": ivy.array([1], device=on_device), + "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": ivy.array([[[2.0], [4.0], [6.0]]], device=on_device), + "d": ivy.array([[[3.0], [6.0], [9.0]]], device=on_device), }, } container = 
Container(dict_in) - container_pruned = container.cont_prune_key_chains(["a", "b/c"]) - assert "a" not in container_pruned - assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) - assert "c" not in container_pruned["b"] - - def _test_a_exception(container_in): - try: - _ = container_in.a - return False - except AttributeError: - return True - - def _test_bc_exception(container_in): - try: - _ = container_in.b.c - return False - except AttributeError: - return True + assert np.allclose( + ivy.to_numpy(container.cont_if_exists("a")), np.array([[[1.0], [2.0], [3.0]]]) + ) + assert "c" not in container + assert container.cont_if_exists("c") is None + container["c"] = ivy.array([[[1.0], [2.0], [3.0]]], device=on_device) + assert np.allclose( + ivy.to_numpy(container.cont_if_exists("c")), np.array([[[1.0], [2.0], [3.0]]]) + ) + assert container.cont_if_exists("d") is None + container.d = ivy.array([[[1.0], [2.0], [3.0]]], device=on_device) + assert np.allclose( + ivy.to_numpy(container.cont_if_exists("d")), np.array([[[1.0], [2.0], [3.0]]]) + ) - assert _test_a_exception(container_pruned) - assert _test_bc_exception(container_pruned) - container_pruned = container.cont_prune_key_chains( - Container({"a": True, "b": {"c": True}}) +def test_container_inplace(on_device): + container0 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([1], device=on_device), + "d": ivy.array([2], device=on_device), + }, + } + ) + const = 3 + arr = ivy.array([1], device=on_device) + container1 = Container( + { + "a": ivy.array([3], device=on_device), + "b": { + "c": ivy.array([4], device=on_device), + "d": ivy.array([5], device=on_device), + }, + } ) - assert "a" not in container_pruned - assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) - assert "c" not in container_pruned["b"] - assert _test_a_exception(container_pruned) - assert _test_bc_exception(container_pruned) + special_funcs = [ + "__add__", + "__and__", + "__floordiv__", + "__lshift__", + "__matmul__", + "__mod__", + "__mul__", + "__pow__", + "__rshift__", + "__sub__", + "__truediv__", + "__xor__", + ] -def test_container_format_key_chains(on_device): - dict_in = { - "_a": ivy.array([1], device=on_device), - "b ": { - "c": ivy.array([2], device=on_device), - "d-": ivy.array([3], device=on_device), - }, - } - cont = Container(dict_in) - cont_formatted = cont.cont_format_key_chains( - lambda s: s.replace("_", "").replace(" ", "").replace("-", "") - ) - assert np.allclose(ivy.to_numpy(cont_formatted["a"]), np.array([1])) - assert np.allclose(ivy.to_numpy(cont_formatted.a), np.array([1])) - assert np.allclose(ivy.to_numpy(cont_formatted["b"]["c"]), np.array([2])) - assert np.allclose(ivy.to_numpy(cont_formatted.b.c), np.array([2])) - assert np.allclose(ivy.to_numpy(cont_formatted["b"]["d"]), np.array([3])) - assert np.allclose(ivy.to_numpy(cont_formatted.b.d), np.array([3])) + for func_str in special_funcs: + func = getattr(Container, func_str) + ifunc = getattr(Container, func_str[:2] + "i" + func_str[2:]) + for value in [ + const, + arr, + container1, + ]: + if value == const and func_str == "__matmul__": + continue + container0_copy = container0.cont_deep_copy() + id_before_op = id(container0_copy) + og_ids = container0_copy.cont_map(lambda x, _: id(x)) + ifunc(container0_copy, value) + op_ids = container0_copy.cont_map(lambda 
x, _: id(x)) -def test_container_sort_by_key(on_device): - dict_in = { - "b": ivy.array([1], device=on_device), - "a": { - "d": ivy.array([2], device=on_device), - "c": ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - container_sorted = container.cont_sort_by_key() - for k, k_true in zip(container_sorted.keys(), ["a", "b"]): - assert k == k_true - for k, k_true in zip(container_sorted.a.keys(), ["c", "d"]): - assert k == k_true + assert func(container0, value) == container0_copy # values + assert id(container0_copy) == id_before_op # container ids + assert og_ids == op_ids # value ids -def test_container_prune_empty(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": {"c": {}, "d": ivy.array([3], device=on_device)}, - } +@pytest.mark.parametrize("include_empty", [True, False]) +def test_container_key_chains_containing(include_empty, on_device): + a_val = Container() if include_empty else ivy.array([1], device=on_device) + bc_val = Container() if include_empty else ivy.array([2], device=on_device) + bd_val = Container() if include_empty else ivy.array([3], device=on_device) + dict_in = {"a_sub": a_val, "b": {"c": bc_val, "d_sub": bd_val}} container = Container(dict_in) - container_pruned = container.cont_prune_empty() - assert np.allclose(ivy.to_numpy(container_pruned["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned.a), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) - assert "c" not in container_pruned["b"] + kcs = container.cont_key_chains_containing("sub", include_empty) + assert kcs[0] == "a_sub" + assert kcs[1] == "b/d_sub" - def _test_exception(container_in): - try: - _ = container_in.b.c - return False - except AttributeError: - return True - assert _test_exception(container_pruned) +def test_container_list_join(on_device): + container_0 = Container( + { + "a": [ivy.array([1], device=on_device)], + "b": { + "c": [ivy.array([2], device=on_device)], + "d": [ivy.array([3], device=on_device)], + }, + } + ) + container_1 = Container( + { + "a": [ivy.array([4], device=on_device)], + "b": { + "c": [ivy.array([5], device=on_device)], + "d": [ivy.array([6], device=on_device)], + }, + } + ) + container_list_joined = ivy.Container.cont_list_join([container_0, container_1]) + assert np.allclose(ivy.to_numpy(container_list_joined["a"][0]), np.array([1])) + assert np.allclose(ivy.to_numpy(container_list_joined.a[0]), np.array([1])) + assert np.allclose(ivy.to_numpy(container_list_joined["b"]["c"][0]), np.array([2])) + assert np.allclose(ivy.to_numpy(container_list_joined.b.c[0]), np.array([2])) + assert np.allclose(ivy.to_numpy(container_list_joined["b"]["d"][0]), np.array([3])) + assert np.allclose(ivy.to_numpy(container_list_joined.b.d[0]), np.array([3])) + assert np.allclose(ivy.to_numpy(container_list_joined["a"][1]), np.array([4])) + assert np.allclose(ivy.to_numpy(container_list_joined.a[1]), np.array([4])) + assert np.allclose(ivy.to_numpy(container_list_joined["b"]["c"][1]), np.array([5])) + assert np.allclose(ivy.to_numpy(container_list_joined.b.c[1]), np.array([5])) + assert np.allclose(ivy.to_numpy(container_list_joined["b"]["d"][1]), np.array([6])) + assert np.allclose(ivy.to_numpy(container_list_joined.b.d[1]), np.array([6])) -def test_container_prune_key_from_key_chains(on_device): - container = Container( +def test_container_list_stack(on_device): + container_0 = Container( { - "Ayy": 
ivy.array([1], device=on_device), - "Bee": { - "Cee": ivy.array([2], device=on_device), - "Dee": ivy.array([3], device=on_device), + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, - "Beh": { - "Ceh": ivy.array([4], device=on_device), - "Deh": ivy.array([5], device=on_device), + } + ) + container_1 = Container( + { + "a": ivy.array([4], device=on_device), + "b": { + "c": ivy.array([5], device=on_device), + "d": ivy.array([6], device=on_device), }, } ) + container_list_stacked = ivy.Container.cont_list_stack( + [container_0, container_1], 0 + ) + assert np.allclose(ivy.to_numpy(container_list_stacked["a"][0]), np.array([1])) + assert np.allclose(ivy.to_numpy(container_list_stacked.a[0]), np.array([1])) + assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["c"][0]), np.array([2])) + assert np.allclose(ivy.to_numpy(container_list_stacked.b.c[0]), np.array([2])) + assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["d"][0]), np.array([3])) + assert np.allclose(ivy.to_numpy(container_list_stacked.b.d[0]), np.array([3])) + assert np.allclose(ivy.to_numpy(container_list_stacked["a"][1]), np.array([4])) + assert np.allclose(ivy.to_numpy(container_list_stacked.a[1]), np.array([4])) + assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["c"][1]), np.array([5])) + assert np.allclose(ivy.to_numpy(container_list_stacked.b.c[1]), np.array([5])) + assert np.allclose(ivy.to_numpy(container_list_stacked["b"]["d"][1]), np.array([6])) + assert np.allclose(ivy.to_numpy(container_list_stacked.b.d[1]), np.array([6])) - # absolute - container_pruned = container.cont_prune_key_from_key_chains("Bee") - assert np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) - assert "Bee" not in container_pruned - # containing - container_pruned = container.cont_prune_key_from_key_chains(containing="B") - assert np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned["Ceh"]), np.array([[4]])) - assert np.allclose(ivy.to_numpy(container_pruned.Ceh), np.array([[4]])) - assert np.allclose(ivy.to_numpy(container_pruned["Deh"]), np.array([[5]])) - assert np.allclose(ivy.to_numpy(container_pruned.Deh), np.array([[5]])) - assert "Bee" not in container_pruned - assert "Beh" not in container_pruned +@pytest.mark.parametrize("inplace", [True, False]) +def test_container_map(inplace, on_device): + # without key_chains specification + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container_orig = Container(dict_in) + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map(lambda x, _: 
x + 1, inplace=inplace) + if inplace: + container_iterator = container.cont_to_iterator() + else: + container_iterator = container_mapped.cont_to_iterator() + for (key, value), expected_value in zip( + container_iterator, + [ + ivy.array([2], device=on_device), + ivy.array([3], device=on_device), + ivy.array([4], device=on_device), + ], + ): + assert ivy.to_numpy(value) == ivy.to_numpy(expected_value) + + # with key_chains to apply + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map( + lambda x, _: x + 1, ["a", "b/c"], inplace=inplace + ) + if inplace: + container_mapped = container + assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped["b"]["d"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.d), np.array([[3]])) + + # with key_chains to apply pruned + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map( + lambda x, _: x + 1, ["a", "b/c"], prune_unapplied=True, inplace=inplace + ) + if inplace: + container_mapped = container + assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) + if not inplace: + assert "b/d" not in container_mapped + + # with key_chains to not apply + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map( + lambda x, _: x + 1, + Container({"a": None, "b": {"d": None}}), + to_apply=False, + inplace=inplace, + ) + if inplace: + container_mapped = container + assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped["b"]["d"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.d), np.array([[3]])) + # with key_chains to not apply pruned + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map( + lambda x, _: x + 1, + Container({"a": None, "b": {"d": None}}), + to_apply=False, + prune_unapplied=True, + inplace=inplace, + ) + if inplace: + container_mapped = container + if not inplace: + assert "a" not in container_mapped + assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) + if not inplace: + assert "b/d" not in container_mapped -def test_container_prune_keys_from_key_chains(on_device): - container = Container( + # with sequences + container_orig = Container( { - "Ayy": ivy.array([1], device=on_device), - "Bee": { - "Cee": ivy.array([2], device=on_device), - "Dee": ivy.array([3], device=on_device), - }, - "Eee": {"Fff": ivy.array([4], device=on_device)}, + "a": ivy.array([1], device=on_device), + "b": [ivy.array([2], device=on_device), ivy.array([3], device=on_device)], } ) - - # absolute - container_pruned = container.cont_prune_keys_from_key_chains(["Bee", "Eee"]) - assert 
np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned["Fff"]), np.array([[4]])) - assert np.allclose(ivy.to_numpy(container_pruned.Fff), np.array([[4]])) - assert "Bee" not in container_pruned - assert "Eee" not in container_pruned - - # containing - container_pruned = container.cont_prune_keys_from_key_chains(containing=["B", "E"]) - assert np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_pruned["Fff"]), np.array([[4]])) - assert np.allclose(ivy.to_numpy(container_pruned.Fff), np.array([[4]])) - assert "Bee" not in container_pruned - assert "Eee" not in container_pruned + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map( + lambda x, _: x + 1, inplace=inplace, map_sequences=True + ) + if inplace: + container_mapped = container + assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([2])) + assert np.allclose(ivy.to_numpy(container_mapped["b"][0]), np.array([3])) + assert np.allclose(ivy.to_numpy(container_mapped["b"][1]), np.array([4])) -def test_container_restructure_key_chains(on_device): - # single - container = Container( +@pytest.mark.parametrize("inplace", [True, False]) +def test_container_map_sub_conts(inplace, on_device): + # without key_chains specification + container_orig = Container( { "a": ivy.array([1], device=on_device), "b": { @@ -1614,16 +1708,38 @@ def test_container_restructure_key_chains(on_device): }, } ) - container_restructured = container.cont_restructure_key_chains({"a": "A"}) - assert np.allclose(ivy.to_numpy(container_restructured["A"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_restructured.A), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_restructured["b/c"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_restructured.b.c), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_restructured["b/d"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_restructured.b.d), np.array([[3]])) - # full - container = Container( + def _add_e_attr(cont_in): + cont_in.e = ivy.array([4], device=on_device) + return cont_in + + # with self + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map_sub_conts( + lambda c, _: _add_e_attr(c), inplace=inplace + ) + if inplace: + container_mapped = container + assert "e" in container_mapped + assert np.array_equal(ivy.to_numpy(container_mapped.e), np.array([4])) + assert "e" in container_mapped.b + assert np.array_equal(ivy.to_numpy(container_mapped.b.e), np.array([4])) + + # without self + container = container_orig.cont_deep_copy() + container_mapped = container.cont_map_sub_conts( + lambda c, _: _add_e_attr(c), include_self=False, 
inplace=inplace + ) + if inplace: + container_mapped = container + assert "e" not in container_mapped + assert "e" in container_mapped.b + assert np.array_equal(ivy.to_numpy(container_mapped.b.e), np.array([4])) + + +def test_container_multi_map(on_device): + # without key_chains specification + container0 = Container( { "a": ivy.array([1], device=on_device), "b": { @@ -1632,235 +1748,264 @@ def test_container_restructure_key_chains(on_device): }, } ) - container_restructured = container.cont_restructure_key_chains( - {"a": "A", "b/c": "B/C", "b/d": "B/D"} - ) - assert np.allclose(ivy.to_numpy(container_restructured["A"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_restructured.A), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_restructured["B/C"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_restructured.B.C), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_restructured["B/D"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_restructured.B.D), np.array([[3]])) - - -def test_container_restructure(on_device): - container = Container( + container1 = Container( { - "a": ivy.array([[1, 2], [3, 4]], device=on_device), + "a": ivy.array([3], device=on_device), "b": { - "c": ivy.array([[2, 4], [6, 8]], device=on_device), - "d": ivy.array([3, 6, 9, 12], device=on_device), + "c": ivy.array([4], device=on_device), + "d": ivy.array([5], device=on_device), }, } ) - container_restructured = container.cont_restructure( - { - "a": {"key_chain": "A", "pattern": "a b -> b a"}, - "b/c": {"key_chain": "B/C", "pattern": "a b -> (a b)"}, - "b/d": { - "key_chain": "B/D", - "pattern": "(a b) -> a b", - "axes_lengths": {"a": 2, "b": 2}, - }, - }, - keep_orig=False, + + # with key_chains to apply + container_mapped = ivy.Container.cont_multi_map( + lambda x, _: x[0] + x[1], [container0, container1], assert_identical=True ) - assert np.allclose( - ivy.to_numpy(container_restructured["A"]), np.array([[1, 3], [2, 4]]) + assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[4]])) + assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[4]])) + assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[6]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[6]])) + assert np.allclose(ivy.to_numpy(container_mapped["b"]["d"]), np.array([[8]])) + assert np.allclose(ivy.to_numpy(container_mapped.b.d), np.array([[8]])) + + # with sequences + container0 = Container( + { + "a": ivy.array([1], device=on_device), + "b": [ + ivy.array([2], device=on_device), + ivy.array([3], device=on_device), + ], + } ) - assert np.allclose( - ivy.to_numpy(container_restructured.A), np.array([[1, 3], [2, 4]]) + container1 = Container( + { + "a": ivy.array([3], device=on_device), + "b": [ + ivy.array([4], device=on_device), + ivy.array([5], device=on_device), + ], + } ) - assert np.allclose( - ivy.to_numpy(container_restructured["B/C"]), np.array([2, 4, 6, 8]) + + container_mapped = ivy.Container.cont_multi_map( + lambda x, _: x[0] + x[1], + [container0, container1], + map_nests=True, + assert_identical=True, ) - assert np.allclose(ivy.to_numpy(container_restructured.B.C), np.array([2, 4, 6, 8])) - assert np.allclose( - ivy.to_numpy(container_restructured["B/D"]), np.array([[3, 6], [9, 12]]) + + assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([4])) + assert np.allclose(ivy.to_numpy(container_mapped["b"][0]), np.array([6])) + assert np.allclose(ivy.to_numpy(container_mapped["b"][1]), np.array([8])) + + # 
Non identical containers + a = ivy.Container(a={"b": 2, "c": 4}, d={"e": 6, "f": 9}) + b = ivy.Container(a=2, d=3) + container_mapped = ivy.Container.cont_multi_map(lambda xs, _: xs[0] / xs[1], [a, b]) + + assert np.allclose(ivy.to_numpy(container_mapped["a"].b), 1) + assert np.allclose(ivy.to_numpy(container_mapped["a"]["c"]), 2) + assert np.allclose(ivy.to_numpy(container_mapped.d.e), 2) + assert np.allclose(ivy.to_numpy(container_mapped["d"].f), 3) + + +def test_container_num_arrays(on_device): + dict_in = { + "a": ivy.array([[0.0, 1.0, 2.0, 3.0]], device=on_device), + "b": { + "c": ivy.array([[5.0, 10.0, 15.0, 20.0]], device=on_device), + "d": ivy.array([[10.0, 9.0, 8.0, 7.0]], device=on_device), + }, + } + container = Container(dict_in) + assert container.cont_num_arrays() == 3 + dict_in = { + "a": ivy.array([[0.0, 1.0, 2.0, 3.0]], device=on_device), + "b": { + "c": _variable(ivy.array([[5.0, 10.0, 15.0, 20.0]], device=on_device)), + "d": ivy.array([[10.0, 9.0, 8.0, 7.0]], device=on_device), + }, + } + container = Container(dict_in) + assert ( + container.cont_num_arrays() == 3 + if ivy.current_backend_str() in ("numpy", "jax") + else 2 ) - assert np.allclose( - ivy.to_numpy(container_restructured.B.D), np.array([[3, 6], [9, 12]]) + + +# noinspection PyUnresolvedReferences +def test_container_overwrite_at_key_chain(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container_orig = Container(dict_in) + + # explicit function call + container = container_orig.cont_copy() + # noinspection PyBroadException + try: + container.cont_overwrite_at_key_chain("b/e", ivy.array([4], device=on_device)) + exception_raised = False + except Exception: + exception_raised = True + assert exception_raised + container = container.cont_overwrite_at_key_chain( + "b/d", ivy.array([4], device=on_device) ) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([4])) -def test_container_flatten_key_chains(on_device): +def test_container_overwrite_at_key_chains(on_device): container = Container( { "a": ivy.array([1], device=on_device), "b": { - "c": {"d": ivy.array([2], device=on_device)}, - "e": {"f": {"g": ivy.array([3], device=on_device)}}, + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, } ) - - # full - container_flat = container.cont_flatten_key_chains() - assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat["b__c__d"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat.b__c__d), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat["b__e__f__g"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_flat.b__e__f__g), np.array([[3]])) - - # above height 1 - container_flat = container.cont_flatten_key_chains(above_height=1) - assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat["b__c"]["d"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat.b__c.d), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat["b__e__f"]["g"]), np.array([[3]])) - assert 
np.allclose(ivy.to_numpy(container_flat.b__e__f.g), np.array([[3]])) - - # below depth 1 - container_flat = container.cont_flatten_key_chains(below_depth=1) - assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat["b"]["c__d"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat.b.c__d), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat["b"]["e__f__g"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_flat.b.e__f__g), np.array([[3]])) - - # above height 1, below depth 1 - container_flat = container.cont_flatten_key_chains(above_height=1, below_depth=1) - assert np.allclose(ivy.to_numpy(container_flat["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat.a), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_flat["b"]["c"]["d"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat.b.c.d), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_flat["b"]["e__f"]["g"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_flat.b.e__f.g), np.array([[3]])) + target_container = Container( + { + "a": ivy.array([4], device=on_device), + "b": {"d": ivy.array([5], device=on_device)}, + } + ) + new_container = container.cont_overwrite_at_key_chains( + target_container, inplace=False + ) + assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([4])) + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([5])) + target_container = Container({"b": {"c": ivy.array([7], device=on_device)}}) + new_container = container.cont_overwrite_at_key_chains( + target_container, inplace=False + ) + assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([7])) + assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) + # noinspection PyBroadException + try: + container.cont_overwrite_at_key_chains( + Container({"b": {"e": ivy.array([5], device=on_device)}}) + ) + exception_raised = False + except Exception: + exception_raised = True + assert exception_raised -def test_container_deep_copy(on_device): +def test_container_pickle(on_device): dict_in = { - "a": ivy.array([0.0], device=on_device), + "a": ivy.array([np.float32(1.0)], device=on_device), "b": { - "c": ivy.array([1.0], device=on_device), - "d": ivy.array([2.0], device=on_device), + "c": ivy.array([np.float32(2.0)], device=on_device), + "d": ivy.array([np.float32(3.0)], device=on_device), }, } - cont = Container(dict_in) - cont_deepcopy = cont.cont_deep_copy() - assert np.allclose(ivy.to_numpy(cont.a), ivy.to_numpy(cont_deepcopy.a)) - assert np.allclose(ivy.to_numpy(cont.b.c), ivy.to_numpy(cont_deepcopy.b.c)) - assert np.allclose(ivy.to_numpy(cont.b.d), ivy.to_numpy(cont_deepcopy.b.d)) - assert id(cont.a) != id(cont_deepcopy.a) - assert id(cont.b.c) != id(cont_deepcopy.b.c) - assert id(cont.b.d) != id(cont_deepcopy.b.d) - -def test_container_contains(on_device): - arr0 = ivy.array([0.0], device=on_device) - arr1 = ivy.array([1.0], device=on_device) - arr2 = ivy.array([2.0], device=on_device) - sub_cont = Container({"c": arr1, "d": arr2}) - container = Container({"a": arr0, "b": sub_cont}) + # without module attribute + cont = Container(dict_in) - # keys - assert "a" in container - assert "b" in container - assert "c" not in container - 
assert "b/c" in container - assert "d" not in container - assert "b/d" in container + # paddle tansor can't be pickled directly as mentioned + # in the issue https://github.com/PaddlePaddle/Paddle/issues/41107 + if ivy.backend == "paddle": + cont = cont.to_numpy() - # sub-container - assert container.cont_contains_sub_container(container) - assert container.cont_contains_sub_container(sub_cont) - assert sub_cont in container + assert cont._local_ivy is None + pickled = pickle.dumps(cont) + cont_again = pickle.loads(pickled) + assert cont_again._local_ivy is None + ivy.Container.cont_identical_structure([cont, cont_again]) + ivy.Container.cont_identical_configs([cont, cont_again]) - # partial sub-container - partial_sub_cont = Container({"b": {"d": arr2}}) - assert container.cont_contains_sub_container(container, partial=True) - assert container.cont_contains_sub_container(partial_sub_cont, partial=True) - assert not partial_sub_cont.cont_contains_sub_container(container, partial=True) + # with module attribute + cont = Container(dict_in, ivyh=ivy) - # sub-structure - sub_struc = Container( - { - "c": ivy.array([3.0], device=on_device), - "d": ivy.array([4.0], device=on_device), - } - ) - assert not container.cont_contains_sub_container(sub_struc) - assert sub_struc not in container - assert container.cont_contains_sub_structure(sub_struc) - assert container.cont_contains_sub_structure(container) + # paddle tansor can't be pickled directly as mentioned + # in the issue https://github.com/PaddlePaddle/Paddle/issues/41107 + if ivy.backend == "paddle": + cont = cont.to_numpy() - # partial sub-structure - partial_sub_struc = Container({"b": {"d": ivy.array([4.0], device=on_device)}}) - assert container.cont_contains_sub_structure(container, partial=True) - assert container.cont_contains_sub_structure(partial_sub_struc, partial=True) - assert not partial_sub_struc.cont_contains_sub_structure(container, partial=True) + assert cont._local_ivy is ivy + pickled = pickle.dumps(cont) + cont_again = pickle.loads(pickled) + # noinspection PyUnresolvedReferences + assert cont_again._local_ivy.current_backend_str() is ivy.current_backend_str() + ivy.Container.cont_identical_structure([cont, cont_again]) + ivy.Container.cont_identical_configs([cont, cont_again]) -@pytest.mark.parametrize("include_empty", [True, False]) -def test_container_to_iterator(include_empty, on_device): - a_val = Container() if include_empty else ivy.array([1], device=on_device) - bc_val = Container() if include_empty else ivy.array([2], device=on_device) - bd_val = Container() if include_empty else ivy.array([3], device=on_device) - dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} +def test_container_prune_empty(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": {"c": {}, "d": ivy.array([3], device=on_device)}, + } container = Container(dict_in) + container_pruned = container.cont_prune_empty() + assert np.allclose(ivy.to_numpy(container_pruned["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned.a), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) + assert "c" not in container_pruned["b"] - # with key chains - container_iterator = container.cont_to_iterator(include_empty=include_empty) - for (key_chain, value), expected in zip( - container_iterator, [("a", a_val), ("b/c", bc_val), ("b/d", bd_val)] - ): - expected_key_chain = expected[0] - expected_value = 
expected[1] - assert key_chain == expected_key_chain - assert value is expected_value + def _test_exception(container_in): + try: + _ = container_in.b.c + return False + except AttributeError: + return True - # with leaf keys - container_iterator = container.cont_to_iterator( - leaf_keys_only=True, include_empty=include_empty - ) - for (key_chain, value), expected in zip( - container_iterator, [("a", a_val), ("c", bc_val), ("d", bd_val)] - ): - expected_key_chain = expected[0] - expected_value = expected[1] - assert key_chain == expected_key_chain - assert value is expected_value + assert _test_exception(container_pruned) -@pytest.mark.parametrize("include_empty", [True, False]) -def test_container_to_iterator_values(include_empty, on_device): - a_val = Container() if include_empty else ivy.array([1], device=on_device) - bc_val = Container() if include_empty else ivy.array([2], device=on_device) - bd_val = Container() if include_empty else ivy.array([3], device=on_device) - dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} +def test_container_prune_key_chain(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": {"c": ivy.array([2], device=on_device), "d": None}, + } container = Container(dict_in) + container_pruned = container.cont_prune_key_chain("b/c") + assert np.allclose(ivy.to_numpy(container_pruned["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned.a), np.array([[1]])) + assert container_pruned["b"]["d"] is None + assert container_pruned.b.d is None + assert "c" not in container_pruned["b"].keys() - # with key chains - container_iterator = container.cont_to_iterator_values(include_empty=include_empty) - for value, expected_value in zip(container_iterator, [a_val, bc_val, bd_val]): - assert value is expected_value - + def _test_exception(container_in): + try: + _ = container_in.b.c + return False + except AttributeError: + return True -@pytest.mark.parametrize("include_empty", [True, False]) -def test_container_to_iterator_keys(include_empty, on_device): - a_val = Container() if include_empty else ivy.array([1], device=on_device) - bc_val = Container() if include_empty else ivy.array([2], device=on_device) - bd_val = Container() if include_empty else ivy.array([3], device=on_device) - dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} - container = Container(dict_in) + assert _test_exception(container_pruned) - # with key chains - container_iterator = container.cont_to_iterator_keys(include_empty=include_empty) - for key_chain, expected_key_chain in zip(container_iterator, ["a", "b/c", "b/d"]): - assert key_chain == expected_key_chain + container_pruned = container.cont_prune_key_chain("b") + assert np.allclose(ivy.to_numpy(container_pruned["a"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned.a), np.array([[1]])) + assert "b" not in container_pruned.keys() - # with leaf keys - container_iterator = container.cont_to_iterator_keys( - leaf_keys_only=True, include_empty=include_empty - ) - for key, expected_key in zip(container_iterator, ["a", "c", "d"]): - assert key == expected_key + def _test_exception(container_in): + try: + _ = container_in.b + return False + except AttributeError: + return True + + assert _test_exception(container_pruned) -def test_container_to_flat_list(on_device): +def test_container_prune_key_chains(on_device): dict_in = { "a": ivy.array([1], device=on_device), "b": { @@ -1869,40 +2014,82 @@ def test_container_to_flat_list(on_device): }, } container = Container(dict_in) - 
container_flat_list = container.cont_to_flat_list() - for value, expected_value in zip( - container_flat_list, - [ - ivy.array([1], device=on_device), - ivy.array([2], device=on_device), - ivy.array([3], device=on_device), - ], - ): - assert value == expected_value + container_pruned = container.cont_prune_key_chains(["a", "b/c"]) + assert "a" not in container_pruned + assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) + assert "c" not in container_pruned["b"] + def _test_a_exception(container_in): + try: + _ = container_in.a + return False + except AttributeError: + return True -def test_container_from_flat_list(on_device): - dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - container = Container(dict_in) - flat_list = [4, 5, 6] - container = container.cont_from_flat_list(flat_list) - assert np.allclose(ivy.to_numpy(container["a"]), np.array([4])) - assert np.allclose(ivy.to_numpy(container.a), np.array([4])) - assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([5])) - assert np.allclose(ivy.to_numpy(container.b.c), np.array([5])) - assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([6])) - assert np.allclose(ivy.to_numpy(container.b.d), np.array([6])) + def _test_bc_exception(container_in): + try: + _ = container_in.b.c + return False + except AttributeError: + return True + assert _test_a_exception(container_pruned) + assert _test_bc_exception(container_pruned) -@pytest.mark.parametrize("inplace", [True, False]) -def test_container_map(inplace, on_device): - # without key_chains specification + container_pruned = container.cont_prune_key_chains( + Container({"a": True, "b": {"c": True}}) + ) + assert "a" not in container_pruned + assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) + assert "c" not in container_pruned["b"] + assert _test_a_exception(container_pruned) + assert _test_bc_exception(container_pruned) + + +def test_container_prune_key_from_key_chains(on_device): + container = Container( + { + "Ayy": ivy.array([1], device=on_device), + "Bee": { + "Cee": ivy.array([2], device=on_device), + "Dee": ivy.array([3], device=on_device), + }, + "Beh": { + "Ceh": ivy.array([4], device=on_device), + "Deh": ivy.array([5], device=on_device), + }, + } + ) + + # absolute + container_pruned = container.cont_prune_key_from_key_chains("Bee") + assert np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) + assert "Bee" not in container_pruned + + # containing + container_pruned = container.cont_prune_key_from_key_chains(containing="B") + assert np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), 
np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned["Ceh"]), np.array([[4]])) + assert np.allclose(ivy.to_numpy(container_pruned.Ceh), np.array([[4]])) + assert np.allclose(ivy.to_numpy(container_pruned["Deh"]), np.array([[5]])) + assert np.allclose(ivy.to_numpy(container_pruned.Deh), np.array([[5]])) + assert "Bee" not in container_pruned + assert "Beh" not in container_pruned + + +def test_container_prune_keys(on_device): dict_in = { "a": ivy.array([1], device=on_device), "b": { @@ -1910,148 +2097,87 @@ def test_container_map(inplace, on_device): "d": ivy.array([3], device=on_device), }, } - container_orig = Container(dict_in) - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map(lambda x, _: x + 1, inplace=inplace) - if inplace: - container_iterator = container.cont_to_iterator() - else: - container_iterator = container_mapped.cont_to_iterator() - for (key, value), expected_value in zip( - container_iterator, - [ - ivy.array([2], device=on_device), - ivy.array([3], device=on_device), - ivy.array([4], device=on_device), - ], - ): - assert ivy.to_numpy(value) == ivy.to_numpy(expected_value) + container = Container(dict_in) + container_pruned = container.cont_prune_keys(["a", "c"]) + assert "a" not in container_pruned + assert np.allclose(ivy.to_numpy(container_pruned["b"]["d"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.b.d), np.array([[3]])) + assert "c" not in container_pruned["b"] - # with key_chains to apply - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map( - lambda x, _: x + 1, ["a", "b/c"], inplace=inplace - ) - if inplace: - container_mapped = container - assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped["b"]["d"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.d), np.array([[3]])) + def _test_a_exception(container_in): + try: + _ = container_in.a + return False + except AttributeError: + return True - # with key_chains to apply pruned - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map( - lambda x, _: x + 1, ["a", "b/c"], prune_unapplied=True, inplace=inplace - ) - if inplace: - container_mapped = container - assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[2]])) - assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) - if not inplace: - assert "b/d" not in container_mapped + def _test_bc_exception(container_in): + try: + _ = container_in.b.c + return False + except AttributeError: + return True - # with key_chains to not apply - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map( - lambda x, _: x + 1, - Container({"a": None, "b": {"d": None}}), - to_apply=False, - inplace=inplace, - ) - if inplace: - container_mapped = container - assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[1]])) - assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[1]])) - assert 
np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped["b"]["d"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.d), np.array([[3]])) + def _test_bd_exception(container_in): + try: + _ = container_in.b.d + return False + except AttributeError: + return True - # with key_chains to not apply pruned - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map( - lambda x, _: x + 1, - Container({"a": None, "b": {"d": None}}), - to_apply=False, - prune_unapplied=True, - inplace=inplace, - ) - if inplace: - container_mapped = container - if not inplace: - assert "a" not in container_mapped - assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[3]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[3]])) - if not inplace: - assert "b/d" not in container_mapped + assert _test_a_exception(container_pruned) + assert _test_bc_exception(container_pruned) - # with sequences - container_orig = Container( - { - "a": ivy.array([1], device=on_device), - "b": [ivy.array([2], device=on_device), ivy.array([3], device=on_device)], - } - ) - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map( - lambda x, _: x + 1, inplace=inplace, map_sequences=True - ) - if inplace: - container_mapped = container - assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([2])) - assert np.allclose(ivy.to_numpy(container_mapped["b"][0]), np.array([3])) - assert np.allclose(ivy.to_numpy(container_mapped["b"][1]), np.array([4])) + container_pruned = container.cont_prune_keys(["a", "d"]) + assert "a" not in container_pruned + assert np.allclose(ivy.to_numpy(container_pruned["b"]["c"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned.b.c), np.array([[2]])) + assert "d" not in container_pruned["b"] + assert _test_a_exception(container_pruned) + assert _test_bd_exception(container_pruned) -@pytest.mark.parametrize("inplace", [True, False]) -def test_container_map_sub_conts(inplace, on_device): - # without key_chains specification - container_orig = Container( +def test_container_prune_keys_from_key_chains(on_device): + container = Container( { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "Ayy": ivy.array([1], device=on_device), + "Bee": { + "Cee": ivy.array([2], device=on_device), + "Dee": ivy.array([3], device=on_device), }, - } - ) - - def _add_e_attr(cont_in): - cont_in.e = ivy.array([4], device=on_device) - return cont_in - - # with self - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map_sub_conts( - lambda c, _: _add_e_attr(c), inplace=inplace + "Eee": {"Fff": ivy.array([4], device=on_device)}, + } ) - if inplace: - container_mapped = container - assert "e" in container_mapped - assert np.array_equal(ivy.to_numpy(container_mapped.e), np.array([4])) - assert "e" in container_mapped.b - assert np.array_equal(ivy.to_numpy(container_mapped.b.e), np.array([4])) - # without self - container = container_orig.cont_deep_copy() - container_mapped = container.cont_map_sub_conts( - lambda c, _: _add_e_attr(c), include_self=False, inplace=inplace - ) - if inplace: - container_mapped = container - assert "e" not in container_mapped - assert "e" in container_mapped.b - assert np.array_equal(ivy.to_numpy(container_mapped.b.e), 
np.array([4])) + # absolute + container_pruned = container.cont_prune_keys_from_key_chains(["Bee", "Eee"]) + assert np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned["Fff"]), np.array([[4]])) + assert np.allclose(ivy.to_numpy(container_pruned.Fff), np.array([[4]])) + assert "Bee" not in container_pruned + assert "Eee" not in container_pruned + + # containing + container_pruned = container.cont_prune_keys_from_key_chains(containing=["B", "E"]) + assert np.allclose(ivy.to_numpy(container_pruned["Ayy"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned.Ayy), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_pruned["Cee"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned.Cee), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_pruned["Dee"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned.Dee), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_pruned["Fff"]), np.array([[4]])) + assert np.allclose(ivy.to_numpy(container_pruned.Fff), np.array([[4]])) + assert "Bee" not in container_pruned + assert "Eee" not in container_pruned -def test_container_multi_map(on_device): - # without key_chains specification - container0 = Container( +def test_container_reduce(on_device): + container_a = ivy.Container( { "a": ivy.array([1], device=on_device), "b": { @@ -2060,354 +2186,360 @@ def test_container_multi_map(on_device): }, } ) - container1 = Container( + container_b = ivy.Container( { - "a": ivy.array([3], device=on_device), + "a": ivy.array([2], device=on_device), "b": { "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), + "d": ivy.array([6], device=on_device), }, } ) + res = ivy.Container.cont_reduce([container_a, container_b], lambda x: x[0] + x[1]) + assert np.allclose(ivy.to_numpy(res.a), np.array([3.0])) + assert np.allclose(ivy.to_numpy(res.b.c), np.array([6])) + assert np.allclose(ivy.to_numpy(res.b.d), np.array([9])) - # with key_chains to apply - container_mapped = ivy.Container.cont_multi_map( - lambda x, _: x[0] + x[1], [container0, container1], assert_identical=True - ) - assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([[4]])) - assert np.allclose(ivy.to_numpy(container_mapped.a), np.array([[4]])) - assert np.allclose(ivy.to_numpy(container_mapped["b"]["c"]), np.array([[6]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.c), np.array([[6]])) - assert np.allclose(ivy.to_numpy(container_mapped["b"]["d"]), np.array([[8]])) - assert np.allclose(ivy.to_numpy(container_mapped.b.d), np.array([[8]])) - # with sequences - container0 = Container( +def test_container_remove_key_length_limit(on_device): + cont = Container( { - "a": ivy.array([1], device=on_device), - "b": [ - ivy.array([2], device=on_device), - ivy.array([3], device=on_device), - ], + "a": ivy.array([0.0], device=on_device), + "b": { + "c": ivy.array([1.0], device=on_device), + "d": ivy.array([2.0], device=on_device), + }, } ) - container1 = Container( + cont.cont_with_key_length_limit(5, inplace=True) + default_key_length_limit = cont._key_length_limit + 
id_cont = id(cont) + cont1 = cont.cont_remove_key_length_limit() + assert cont1._key_length_limit is None + assert id(cont1) != id(cont) + assert cont._key_length_limit == default_key_length_limit + assert cont.b._key_length_limit == default_key_length_limit + assert cont._key_length_limit != cont1._key_length_limit + cont.cont_remove_key_length_limit(inplace=True) + assert cont._key_length_limit is None + assert cont.b._key_length_limit is None + assert id(cont) == id_cont + + +def test_container_remove_print_limit(on_device): + cont = Container( { - "a": ivy.array([3], device=on_device), - "b": [ - ivy.array([4], device=on_device), - ivy.array([5], device=on_device), - ], + "a": ivy.array([0.0], device=on_device), + "b": { + "c": ivy.array([1.0], device=on_device), + "d": ivy.array([2.0], device=on_device), + }, } ) - - container_mapped = ivy.Container.cont_multi_map( - lambda x, _: x[0] + x[1], - [container0, container1], - map_nests=True, - assert_identical=True, - ) - - assert np.allclose(ivy.to_numpy(container_mapped["a"]), np.array([4])) - assert np.allclose(ivy.to_numpy(container_mapped["b"][0]), np.array([6])) - assert np.allclose(ivy.to_numpy(container_mapped["b"][1]), np.array([8])) - - # Non identical containers - a = ivy.Container(a={"b": 2, "c": 4}, d={"e": 6, "f": 9}) - b = ivy.Container(a=2, d=3) - container_mapped = ivy.Container.cont_multi_map(lambda xs, _: xs[0] / xs[1], [a, b]) - - assert np.allclose(ivy.to_numpy(container_mapped["a"].b), 1) - assert np.allclose(ivy.to_numpy(container_mapped["a"]["c"]), 2) - assert np.allclose(ivy.to_numpy(container_mapped.d.e), 2) - assert np.allclose(ivy.to_numpy(container_mapped["d"].f), 3) - - -def test_container_common_key_chains(on_device): - arr1 = ivy.array([1], device=on_device) - arr2 = ivy.array([2], device=on_device) - arr3 = ivy.array([3], device=on_device) - cont0 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) - cont1 = Container({"b": {"c": arr2, "d": arr3, "e": arr1}}) - cont2 = Container({"a": arr1, "b": {"d": arr3, "e": arr1}}) - - # 0 - common_kcs = Container.cont_common_key_chains([cont0]) - assert len(common_kcs) == 3 - assert "a" in common_kcs - assert "b/c" in common_kcs - assert "b/d" in common_kcs - - # 0-1 - common_kcs = Container.cont_common_key_chains([cont0, cont1]) - assert len(common_kcs) == 2 - assert "b/c" in common_kcs - assert "b/d" in common_kcs - - # 0-2 - common_kcs = Container.cont_common_key_chains([cont0, cont2]) - assert len(common_kcs) == 2 - assert "a" in common_kcs - assert "b/d" in common_kcs - - # 1-2 - common_kcs = Container.cont_common_key_chains([cont1, cont2]) - assert len(common_kcs) == 2 - assert "b/d" in common_kcs - assert "b/e" in common_kcs - - # all - common_kcs = Container.cont_common_key_chains([cont0, cont1, cont2]) - assert len(common_kcs) == 1 - assert "b/d" in common_kcs + default_print_limit = cont._print_limit + id_cont = id(cont) + cont1 = cont.cont_remove_print_limit() + assert cont1._print_limit is None + assert id(cont1) != id(cont) + assert cont._print_limit == default_print_limit + assert cont._print_limit != cont1._print_limit + assert cont.b._print_limit == default_print_limit + cont.cont_remove_print_limit(inplace=True) + assert cont._print_limit is None + assert cont.b._print_limit is None + assert id(cont) == id_cont -def test_container_identical(on_device): - # without key_chains specification - arr1 = ivy.array([1], device=on_device) - arr2 = ivy.array([2], device=on_device) - arr3 = ivy.array([3], device=on_device) - container0 = Container({"a": arr1, 
"b": {"c": arr2, "d": arr3}}) - container1 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) - container2 = Container( +def test_container_reshape_like(on_device): + container = Container( { - "a": ivy.array([1], device=on_device), + "a": ivy.array([[1.0]], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": ivy.array([[3.0], [4.0]], device=on_device), + "d": ivy.array([[5.0], [6.0], [7.0]], device=on_device), }, } ) - container3 = Container({"b": {"d": arr3}}) - container4 = Container({"d": arr3}) - - # the same - assert ivy.Container.cont_identical([container0, container1]) - assert ivy.Container.cont_identical([container1, container0]) + new_shapes = Container({"a": (1,), "b": {"c": (1, 2, 1), "d": (3, 1, 1)}}) - # not the same - assert not ivy.Container.cont_identical([container0, container2]) - assert not ivy.Container.cont_identical([container2, container0]) - assert not ivy.Container.cont_identical([container1, container2]) - assert not ivy.Container.cont_identical([container2, container1]) + # without leading shape + container_reshaped = container.cont_reshape_like(new_shapes) + assert list(container_reshaped["a"].shape) == [1] + assert list(container_reshaped.a.shape) == [1] + assert list(container_reshaped["b"]["c"].shape) == [1, 2, 1] + assert list(container_reshaped.b.c.shape) == [1, 2, 1] + assert list(container_reshaped["b"]["d"].shape) == [3, 1, 1] + assert list(container_reshaped.b.d.shape) == [3, 1, 1] - # partial - assert ivy.Container.cont_identical([container0, container3], partial=True) - assert ivy.Container.cont_identical([container3, container0], partial=True) - assert not ivy.Container.cont_identical([container0, container4], partial=True) - assert not ivy.Container.cont_identical([container4, container0], partial=True) + # with leading shape + container = Container( + { + "a": ivy.array([[[1.0]], [[1.0]], [[1.0]]], device=on_device), + "b": { + "c": ivy.array( + [[[3.0], [4.0]], [[3.0], [4.0]], [[3.0], [4.0]]], device=on_device + ), + "d": ivy.array( + [ + [[5.0], [6.0], [7.0]], + [[5.0], [6.0], [7.0]], + [[5.0], [6.0], [7.0]], + ], + device=on_device, + ), + }, + } + ) + container_reshaped = container.cont_reshape_like(new_shapes, leading_shape=[3]) + assert list(container_reshaped["a"].shape) == [3, 1] + assert list(container_reshaped.a.shape) == [3, 1] + assert list(container_reshaped["b"]["c"].shape) == [3, 1, 2, 1] + assert list(container_reshaped.b.c.shape) == [3, 1, 2, 1] + assert list(container_reshaped["b"]["d"].shape) == [3, 3, 1, 1] + assert list(container_reshaped.b.d.shape) == [3, 3, 1, 1] -def test_container_identical_structure(on_device): - # without key_chains specification - container0 = Container( +def test_container_restructure(on_device): + container = Container( { - "a": ivy.array([1], device=on_device), + "a": ivy.array([[1, 2], [3, 4]], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": ivy.array([[2, 4], [6, 8]], device=on_device), + "d": ivy.array([3, 6, 9, 12], device=on_device), }, } ) - container1 = Container( + container_restructured = container.cont_restructure( { - "a": ivy.array([3], device=on_device), - "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), + "a": {"key_chain": "A", "pattern": "a b -> b a"}, + "b/c": {"key_chain": "B/C", "pattern": "a b -> (a b)"}, + "b/d": { + "key_chain": "B/D", + "pattern": "(a b) -> a b", + "axes_lengths": {"a": 2, "b": 2}, }, - } + 
}, + keep_orig=False, ) - container2 = Container( + assert np.allclose( + ivy.to_numpy(container_restructured["A"]), np.array([[1, 3], [2, 4]]) + ) + assert np.allclose( + ivy.to_numpy(container_restructured.A), np.array([[1, 3], [2, 4]]) + ) + assert np.allclose( + ivy.to_numpy(container_restructured["B/C"]), np.array([2, 4, 6, 8]) + ) + assert np.allclose(ivy.to_numpy(container_restructured.B.C), np.array([2, 4, 6, 8])) + assert np.allclose( + ivy.to_numpy(container_restructured["B/D"]), np.array([[3, 6], [9, 12]]) + ) + assert np.allclose( + ivy.to_numpy(container_restructured.B.D), np.array([[3, 6], [9, 12]]) + ) + + +def test_container_restructure_key_chains(on_device): + # single + container = Container( { - "a": ivy.array([3], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), - "e": ivy.array([6], device=on_device), + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, } ) - container3 = Container( + container_restructured = container.cont_restructure_key_chains({"a": "A"}) + assert np.allclose(ivy.to_numpy(container_restructured["A"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_restructured.A), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_restructured["b/c"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_restructured.b.c), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_restructured["b/d"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_restructured.b.d), np.array([[3]])) + + # full + container = Container( { - "a": ivy.array([3], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, - "e": ivy.array([6], device=on_device), } ) - container4 = Container({"b": {"d": ivy.array([4], device=on_device)}}) - container5 = Container({"d": ivy.array([4], device=on_device)}) - - # with identical - assert ivy.Container.cont_identical_structure([container0, container1]) - assert ivy.Container.cont_identical_structure([container1, container0]) - assert ivy.Container.cont_identical_structure([container1, container0, container1]) - - # without identical - assert not ivy.Container.cont_identical_structure([container2, container3]) - assert not ivy.Container.cont_identical_structure([container0, container3]) - assert not ivy.Container.cont_identical_structure([container1, container2]) - assert not ivy.Container.cont_identical_structure( - [container1, container0, container2] - ) - - # partial - assert ivy.Container.cont_identical_structure( - [container0, container4], partial=True - ) - assert ivy.Container.cont_identical_structure( - [container1, container4], partial=True - ) - assert ivy.Container.cont_identical_structure( - [container2, container4], partial=True - ) - assert ivy.Container.cont_identical_structure( - [container3, container4], partial=True - ) - assert ivy.Container.cont_identical_structure( - [container4, container4], partial=True - ) - assert not ivy.Container.cont_identical_structure( - [container0, container5], partial=True - ) - assert not ivy.Container.cont_identical_structure( - [container1, container5], partial=True - ) - assert not ivy.Container.cont_identical_structure( - [container2, container5], partial=True - ) - assert not ivy.Container.cont_identical_structure( - [container3, 
container5], partial=True - ) - assert not ivy.Container.cont_identical_structure( - [container4, container5], partial=True + container_restructured = container.cont_restructure_key_chains( + {"a": "A", "b/c": "B/C", "b/d": "B/D"} ) + assert np.allclose(ivy.to_numpy(container_restructured["A"]), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_restructured.A), np.array([[1]])) + assert np.allclose(ivy.to_numpy(container_restructured["B/C"]), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_restructured.B.C), np.array([[2]])) + assert np.allclose(ivy.to_numpy(container_restructured["B/D"]), np.array([[3]])) + assert np.allclose(ivy.to_numpy(container_restructured.B.D), np.array([[3]])) -def test_container_identical_configs(on_device): - container0 = Container({"a": ivy.array([1], device=on_device)}, print_limit=5) - container1 = Container({"a": ivy.array([1], device=on_device)}, print_limit=5) - container2 = Container({"a": ivy.array([1], device=on_device)}, print_limit=10) +# noinspection PyUnresolvedReferences +def test_container_set_at_key_chain(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container_orig = Container(dict_in) - # with identical - assert ivy.Container.cont_identical_configs([container0, container1]) - assert ivy.Container.cont_identical_configs([container1, container0]) - assert ivy.Container.cont_identical_configs([container1, container0, container1]) + # explicit function call + container = container_orig.cont_copy() + container = container.cont_set_at_key_chain("b/e", ivy.array([4], device=on_device)) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) + assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) + container = container.cont_set_at_key_chain("f", ivy.array([5], device=on_device)) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) + assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) + assert np.allclose(ivy.to_numpy(container["f"]), np.array([5])) - # without identical - assert not ivy.Container.cont_identical_configs([container1, container2]) - assert not ivy.Container.cont_identical_configs( - [container1, container0, container2] - ) + # overridden built-in function call + container = container_orig.cont_copy() + assert "b/e" not in container + container["b/e"] = ivy.array([4], device=on_device) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) + assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) + assert "f" not in container + container["f"] = ivy.array([5], device=on_device) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) + assert np.allclose(ivy.to_numpy(container["b"]["e"]), np.array([4])) + assert np.allclose(ivy.to_numpy(container["f"]), np.array([5])) -def test_container_identical_array_shapes(on_device): - # without key_chains 
specification - container0 = Container( - { - "a": ivy.array([1, 2], device=on_device), - "b": { - "c": ivy.array([2, 3, 4], device=on_device), - "d": ivy.array([3, 4, 5, 6], device=on_device), - }, - } - ) - container1 = Container( +def test_container_set_at_key_chains(on_device): + container = Container( { - "a": ivy.array([1, 2, 3, 4], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([3, 4], device=on_device), - "d": ivy.array([3, 4, 5], device=on_device), + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), }, } ) - container2 = Container( + target_container = Container( { - "a": ivy.array([1, 2, 3, 4], device=on_device), - "b": { - "c": ivy.array([3, 4], device=on_device), - "d": ivy.array([3, 4, 5, 6], device=on_device), - }, + "a": ivy.array([4], device=on_device), + "b": {"d": ivy.array([5], device=on_device)}, } ) + new_container = container.cont_set_at_key_chains(target_container, inplace=False) + assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([4])) + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([2])) + assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([5])) + target_container = Container({"b": {"c": ivy.array([7], device=on_device)}}) + new_container = container.cont_set_at_key_chains(target_container, inplace=False) + assert np.allclose(ivy.to_numpy(new_container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(new_container["b"]["c"]), np.array([7])) + assert np.allclose(ivy.to_numpy(new_container["b"]["d"]), np.array([3])) - # with identical - assert ivy.Container.cont_identical_array_shapes([container0, container1]) - assert ivy.Container.cont_identical_array_shapes([container1, container0]) - assert ivy.Container.cont_identical_array_shapes( - [container1, container0, container1] + +# noinspection PyUnresolvedReferences +def test_container_set_at_keys(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container_orig = Container(dict_in) + + # explicit function call + orig_container = container_orig.cont_copy() + container = orig_container.cont_set_at_keys({"b": ivy.array([4], device=on_device)}) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([1])) + assert np.allclose(ivy.to_numpy(container["b"]), np.array([4])) + assert not container.cont_has_key("c") # noqa + assert not container.cont_has_key("d") # noqa + container = orig_container.cont_set_at_keys( + {"a": ivy.array([5], device=on_device), "c": ivy.array([6], device=on_device)} ) - assert not ivy.Container.cont_identical([container0, container2]) - assert not ivy.Container.cont_identical([container1, container2]) - assert not ivy.Container.cont_identical([container0, container1, container2]) + assert np.allclose(ivy.to_numpy(container["a"]), np.array([5])) + assert np.allclose(ivy.to_numpy(container["b"]["c"]), np.array([6])) + assert np.allclose(ivy.to_numpy(container["b"]["d"]), np.array([3])) + + +def test_container_shapes(on_device): + dict_in = { + "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), + "b": { + "c": ivy.array([[[2.0], [4.0]]], device=on_device), + "d": ivy.array([[9.0]], device=on_device), + }, + } + container_shapes = Container(dict_in).cont_shapes + assert list(container_shapes["a"]) == [1, 3, 1] + assert list(container_shapes.a) == [1, 3, 1] + assert list(container_shapes["b"]["c"]) == [1, 2, 1] + assert list(container_shapes.b.c) == [1, 
2, 1] + assert list(container_shapes["b"]["d"]) == [1, 1] + assert list(container_shapes.b.d) == [1, 1] -def test_container_with_entries_as_lists(on_device): - if ivy.current_backend_str() == "tensorflow": - # to_list() requires eager execution - pytest.skip() +def test_container_show(on_device): dict_in = { "a": ivy.array([1], device=on_device), - "b": {"c": ivy.array([2.0], device=on_device), "d": "some string"}, + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, } - container = Container(dict_in) - container_w_list_entries = container.cont_with_entries_as_lists() - for (key, value), expected_value in zip( - container_w_list_entries.cont_to_iterator(), [[1], [2.0], "some string"] - ): - assert value == expected_value + cont = Container(dict_in) + print(cont) + cont.cont_show() -def test_container_reshape_like(on_device): - container = Container( - { - "a": ivy.array([[1.0]], device=on_device), - "b": { - "c": ivy.array([[3.0], [4.0]], device=on_device), - "d": ivy.array([[5.0], [6.0], [7.0]], device=on_device), - }, - } - ) - new_shapes = Container({"a": (1,), "b": {"c": (1, 2, 1), "d": (3, 1, 1)}}) +def test_container_show_sub_container(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + top_cont = Container(dict_in) + sub_cont = Container(dict_in["b"]) + top_cont.cont_show_sub_container("b") + top_cont.cont_show_sub_container(sub_cont) - # without leading shape - container_reshaped = container.cont_reshape_like(new_shapes) - assert list(container_reshaped["a"].shape) == [1] - assert list(container_reshaped.a.shape) == [1] - assert list(container_reshaped["b"]["c"].shape) == [1, 2, 1] - assert list(container_reshaped.b.c.shape) == [1, 2, 1] - assert list(container_reshaped["b"]["d"].shape) == [3, 1, 1] - assert list(container_reshaped.b.d.shape) == [3, 1, 1] - # with leading shape - container = Container( - { - "a": ivy.array([[[1.0]], [[1.0]], [[1.0]]], device=on_device), - "b": { - "c": ivy.array( - [[[3.0], [4.0]], [[3.0], [4.0]], [[3.0], [4.0]]], device=on_device - ), - "d": ivy.array( - [ - [[5.0], [6.0], [7.0]], - [[5.0], [6.0], [7.0]], - [[5.0], [6.0], [7.0]], - ], - device=on_device, - ), - }, - } - ) - container_reshaped = container.cont_reshape_like(new_shapes, leading_shape=[3]) - assert list(container_reshaped["a"].shape) == [3, 1] - assert list(container_reshaped.a.shape) == [3, 1] - assert list(container_reshaped["b"]["c"].shape) == [3, 1, 2, 1] - assert list(container_reshaped.b.c.shape) == [3, 1, 2, 1] - assert list(container_reshaped["b"]["d"].shape) == [3, 3, 1, 1] - assert list(container_reshaped.b.d.shape) == [3, 3, 1, 1] +def test_container_size_ordered_arrays(on_device): + dict_in = { + "a": ivy.array([[0.0, 1.0, 2.0, 3.0]], device=on_device), + "b": { + "c": ivy.array([[5.0, 10.0]], device=on_device), + "d": ivy.array([[10.0, 9.0, 8.0]], device=on_device), + }, + } + container = Container(dict_in) + size_ordered = container.cont_size_ordered_arrays() + assert np.allclose(ivy.to_numpy(size_ordered.a), np.array([[0.0, 1.0, 2.0, 3.0]])) + assert np.allclose(ivy.to_numpy(size_ordered.b__c), np.array([[5.0, 10.0]])) + assert np.allclose(ivy.to_numpy(size_ordered.b__d), np.array([[10.0, 9.0, 8.0]])) + for v, arr in zip( + size_ordered.values(), + [ + np.array([[5.0, 10.0]]), + np.array([[10.0, 9.0, 8.0]]), + np.array([[0.0, 1.0, 2.0, 3.0]]), + ], + ): + assert np.allclose(ivy.to_numpy(v), arr) def 
test_container_slice(on_device): @@ -2435,6 +2567,82 @@ def test_container_slice(on_device): assert np.array_equal(ivy.to_numpy(container1.b.d), np.array([3.0])) +@pytest.mark.parametrize("str_slice", [True, False]) +def test_container_slice_keys(str_slice, on_device): + # values + a_val = ivy.array([1], device=on_device) + b_val = ivy.array([2], device=on_device) + c_val = ivy.array([3], device=on_device) + d_val = ivy.array([4], device=on_device) + e_val = ivy.array([5], device=on_device) + + # slice + if str_slice: + slc = "b:d" + else: + slc = slice(1, 4, 1) + + # without dict + cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) + cont_sliced = cont.cont_slice_keys(slc) + assert "a" not in cont_sliced + assert np.allclose(ivy.to_numpy(cont_sliced.b), ivy.to_numpy(b_val)) + assert np.allclose(ivy.to_numpy(cont_sliced.c), ivy.to_numpy(c_val)) + assert np.allclose(ivy.to_numpy(cont_sliced.d), ivy.to_numpy(d_val)) + assert "e" not in cont_sliced + + # with dict, depth 0 + sub_cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) + cont = Container( + {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} + ) + cont_sliced = cont.cont_slice_keys({0: slc}) + assert "a" not in cont_sliced + assert Container.cont_identical([cont_sliced.b, sub_cont]) + assert Container.cont_identical([cont_sliced.c, sub_cont]) + assert Container.cont_identical([cont_sliced.d, sub_cont]) + assert "e" not in cont_sliced + + # with dict, depth 1 + sub_cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) + sub_sub_cont = Container({"b": b_val, "c": c_val, "d": d_val}) + cont = Container( + {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} + ) + cont_sliced = cont.cont_slice_keys({1: slc}) + assert Container.cont_identical([cont_sliced.a, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.b, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.c, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.d, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.e, sub_sub_cont]) + + # with dict, depth 0, 1 + sub_cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) + sub_sub_cont = Container({"b": b_val, "c": c_val, "d": d_val}) + cont = Container( + {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} + ) + cont_sliced = cont.cont_slice_keys({0: slc, 1: slc}) + assert "a" not in cont_sliced + assert Container.cont_identical([cont_sliced.b, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.c, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.d, sub_sub_cont]) + assert "e" not in cont_sliced + + # all depths + sub_cont = Container({"a": a_val, "b": b_val, "c": c_val, "d": d_val, "e": e_val}) + sub_sub_cont = Container({"b": b_val, "c": c_val, "d": d_val}) + cont = Container( + {"a": sub_cont, "b": sub_cont, "c": sub_cont, "d": sub_cont, "e": sub_cont} + ) + cont_sliced = cont.cont_slice_keys(slc, all_depths=True) + assert "a" not in cont_sliced + assert Container.cont_identical([cont_sliced.b, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.c, sub_sub_cont]) + assert Container.cont_identical([cont_sliced.d, sub_sub_cont]) + assert "e" not in cont_sliced + + def test_container_slice_via_key(on_device): dict_in = { "a": { @@ -2459,625 +2667,504 @@ def test_container_slice_via_key(on_device): assert np.array_equal(ivy.to_numpy(containerx.a), np.array([0.0])) assert 
np.array_equal(ivy.to_numpy(containerx["b"]["c"]), np.array([1.0])) assert np.array_equal(ivy.to_numpy(containerx.b.c), np.array([1.0])) - assert np.array_equal(ivy.to_numpy(containerx["b"]["d"]), np.array([2.0])) - assert np.array_equal(ivy.to_numpy(containerx.b.d), np.array([2.0])) - assert np.array_equal(ivy.to_numpy(containery["a"]), np.array([1.0])) - assert np.array_equal(ivy.to_numpy(containery.a), np.array([1.0])) - assert np.array_equal(ivy.to_numpy(containery["b"]["c"]), np.array([2.0])) - assert np.array_equal(ivy.to_numpy(containery.b.c), np.array([2.0])) - assert np.array_equal(ivy.to_numpy(containery["b"]["d"]), np.array([3.0])) - assert np.array_equal(ivy.to_numpy(containery.b.d), np.array([3.0])) - - -def test_container_to_and_from_disk_as_hdf5(on_device): - if ivy.current_backend_str() == "tensorflow": - # container disk saving requires eager execution - pytest.skip() - save_filepath = "container_on_disk.hdf5" - dict_in_1 = { - "a": ivy.array([np.float32(1.0)], device=on_device), - "b": { - "c": ivy.array([np.float32(2.0)], device=on_device), - "d": ivy.array([np.float32(3.0)], device=on_device), - }, - } - container1 = Container(dict_in_1) - dict_in_2 = { - "a": ivy.array([np.float32(1.0), np.float32(1.0)], device=on_device), - "b": { - "c": ivy.array([np.float32(2.0), np.float32(2.0)], device=on_device), - "d": ivy.array([np.float32(3.0), np.float32(3.0)], device=on_device), - }, - } - container2 = Container(dict_in_2) - - # saving - container1.cont_to_disk_as_hdf5(save_filepath, max_batch_size=2) - assert os.path.exists(save_filepath) - - # loading - loaded_container = Container.cont_from_disk_as_hdf5(save_filepath, slice(1)) - assert np.array_equal(ivy.to_numpy(loaded_container.a), ivy.to_numpy(container1.a)) - assert np.array_equal( - ivy.to_numpy(loaded_container.b.c), ivy.to_numpy(container1.b.c) - ) - assert np.array_equal( - ivy.to_numpy(loaded_container.b.d), ivy.to_numpy(container1.b.d) - ) - - # appending - container1.cont_to_disk_as_hdf5(save_filepath, max_batch_size=2, starting_index=1) - assert os.path.exists(save_filepath) - - # loading after append - loaded_container = Container.cont_from_disk_as_hdf5(save_filepath) - assert np.array_equal(ivy.to_numpy(loaded_container.a), ivy.to_numpy(container2.a)) - assert np.array_equal( - ivy.to_numpy(loaded_container.b.c), ivy.to_numpy(container2.b.c) - ) - assert np.array_equal( - ivy.to_numpy(loaded_container.b.d), ivy.to_numpy(container2.b.d) - ) - - # load slice - loaded_sliced_container = Container.cont_from_disk_as_hdf5( - save_filepath, slice(1, 2) - ) - assert np.array_equal( - ivy.to_numpy(loaded_sliced_container.a), ivy.to_numpy(container1.a) - ) - assert np.array_equal( - ivy.to_numpy(loaded_sliced_container.b.c), ivy.to_numpy(container1.b.c) - ) - assert np.array_equal( - ivy.to_numpy(loaded_sliced_container.b.d), ivy.to_numpy(container1.b.d) - ) - - # file size - file_size, batch_size = Container.h5_file_size(save_filepath) - assert file_size == 6 * np.dtype(np.float32).itemsize - assert batch_size == 2 - - os.remove(save_filepath) + assert np.array_equal(ivy.to_numpy(containerx["b"]["d"]), np.array([2.0])) + assert np.array_equal(ivy.to_numpy(containerx.b.d), np.array([2.0])) + assert np.array_equal(ivy.to_numpy(containery["a"]), np.array([1.0])) + assert np.array_equal(ivy.to_numpy(containery.a), np.array([1.0])) + assert np.array_equal(ivy.to_numpy(containery["b"]["c"]), np.array([2.0])) + assert np.array_equal(ivy.to_numpy(containery.b.c), np.array([2.0])) + assert 
np.array_equal(ivy.to_numpy(containery["b"]["d"]), np.array([3.0])) + assert np.array_equal(ivy.to_numpy(containery.b.d), np.array([3.0])) -def test_container_to_disk_shuffle_and_from_disk_as_hdf5(on_device): - if ivy.current_backend_str() == "tensorflow": - # container disk saving requires eager execution - pytest.skip() - save_filepath = "container_on_disk.hdf5" +def test_container_sort_by_key(on_device): dict_in = { - "a": ivy.array([1, 2, 3], device=on_device), - "b": { - "c": ivy.array([1, 2, 3], device=on_device), - "d": ivy.array([1, 2, 3], device=on_device), + "b": ivy.array([1], device=on_device), + "a": { + "d": ivy.array([2], device=on_device), + "c": ivy.array([3], device=on_device), }, } container = Container(dict_in) - - # saving - container.cont_to_disk_as_hdf5(save_filepath, max_batch_size=3) - assert os.path.exists(save_filepath) - - # shuffling - Container.shuffle_h5_file(save_filepath) - - # loading - container_shuffled = Container.cont_from_disk_as_hdf5(save_filepath, slice(3)) - - # testing - data = np.array([1, 2, 3]) - random.seed(0) - random.shuffle(data) - - assert (ivy.to_numpy(container_shuffled["a"]) == data).all() - assert (ivy.to_numpy(container_shuffled.a) == data).all() - assert (ivy.to_numpy(container_shuffled["b"]["c"]) == data).all() - assert (ivy.to_numpy(container_shuffled.b.c) == data).all() - assert (ivy.to_numpy(container_shuffled["b"]["d"]) == data).all() - assert (ivy.to_numpy(container_shuffled.b.d) == data).all() - - os.remove(save_filepath) + container_sorted = container.cont_sort_by_key() + for k, k_true in zip(container_sorted.keys(), ["a", "b"]): + assert k == k_true + for k, k_true in zip(container_sorted.a.keys(), ["c", "d"]): + assert k == k_true -def test_container_pickle(on_device): +def test_container_split_conts(on_device): dict_in = { - "a": ivy.array([np.float32(1.0)], device=on_device), + "a": ivy.array([[1], [2], [3]], device=on_device), "b": { - "c": ivy.array([np.float32(2.0)], device=on_device), - "d": ivy.array([np.float32(3.0)], device=on_device), + "c": ivy.array([[2], [3], [4]], device=on_device), + "d": ivy.array([[3], [4], [5]], device=on_device), }, } + container = Container(dict_in) - # without module attribute - cont = Container(dict_in) + # without key_chains specification + container_split = container.split_conts(1, -1) + for cont, a, bc, bd in zip(container_split, [1, 2, 3], [2, 3, 4], [3, 4, 5]): + assert np.array_equal(ivy.to_numpy(cont["a"])[0], np.array([a])) + assert np.array_equal(ivy.to_numpy(cont.a)[0], np.array([a])) + assert np.array_equal(ivy.to_numpy(cont["b"]["c"])[0], np.array([bc])) + assert np.array_equal(ivy.to_numpy(cont.b.c)[0], np.array([bc])) + assert np.array_equal(ivy.to_numpy(cont["b"]["d"])[0], np.array([bd])) + assert np.array_equal(ivy.to_numpy(cont.b.d)[0], np.array([bd])) - # paddle tansor can't be pickled directly as mentioned - # in the issue https://github.com/PaddlePaddle/Paddle/issues/41107 - if ivy.backend == "paddle": - cont = cont.to_numpy() - assert cont._local_ivy is None - pickled = pickle.dumps(cont) - cont_again = pickle.loads(pickled) - assert cont_again._local_ivy is None - ivy.Container.cont_identical_structure([cont, cont_again]) - ivy.Container.cont_identical_configs([cont, cont_again]) +def test_container_structural_diff(on_device): + # all different keys or shapes + container_0 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + ) + container_1 = Container( + { + 
"a": ivy.array([[4]], device=on_device), + "b": { + "c": ivy.array([[[5]]], device=on_device), + "e": ivy.array([3], device=on_device), + }, + } + ) + container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + assert np.equal(ivy.to_numpy(container_diff.a.diff_0), np.array([1])) + assert np.equal(ivy.to_numpy(container_diff.a.diff_1), np.array([[4]])) + assert np.equal(ivy.to_numpy(container_diff.b.c.diff_0), np.array([2])) + assert np.equal(ivy.to_numpy(container_diff.b.c.diff_1), np.array([[[5]]])) + assert np.equal(ivy.to_numpy(container_diff.b.d.diff_0), np.array([3])) + assert np.equal(ivy.to_numpy(container_diff.b.e.diff_1), np.array([3])) + container_diff_diff_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="diff_only" + ) + assert container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() + container_diff_same_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="same_only" + ) + assert container_diff_same_only.cont_to_dict() == {} - # with module attribute - cont = Container(dict_in, ivyh=ivy) + # some different shapes + container_0 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + ) + container_1 = Container( + { + "a": ivy.array([4], device=on_device), + "b": { + "c": ivy.array([[5]], device=on_device), + "d": ivy.array([6], device=on_device), + }, + } + ) + container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) + assert np.equal(ivy.to_numpy(container_diff.b.c.diff_0), np.array([2])) + assert np.equal(ivy.to_numpy(container_diff.b.c.diff_1), np.array([5])) + assert np.equal(ivy.to_numpy(container_diff.b.d), np.array([3])) + container_diff_diff_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="diff_only" + ) + assert "a" not in container_diff_diff_only + assert "b" in container_diff_diff_only + assert "c" in container_diff_diff_only["b"] + assert "d" not in container_diff_diff_only["b"] + container_diff_same_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="same_only" + ) + assert "a" in container_diff_same_only + assert "b" in container_diff_same_only + assert "c" not in container_diff_same_only["b"] + assert "d" in container_diff_same_only["b"] - # paddle tansor can't be pickled directly as mentioned - # in the issue https://github.com/PaddlePaddle/Paddle/issues/41107 - if ivy.backend == "paddle": - cont = cont.to_numpy() + # all different keys + container_0 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + ) + container_1 = Container( + { + "e": ivy.array([4], device=on_device), + "f": { + "g": ivy.array([5], device=on_device), + "h": ivy.array([6], device=on_device), + }, + } + ) + container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + assert np.equal(ivy.to_numpy(container_diff.a.diff_0), np.array([1])) + assert np.equal(ivy.to_numpy(container_diff.b.diff_0.c), np.array([2])) + assert np.equal(ivy.to_numpy(container_diff.b.diff_0.d), np.array([3])) + assert np.equal(ivy.to_numpy(container_diff.e.diff_1), np.array([4])) + assert np.equal(ivy.to_numpy(container_diff.f.diff_1.g), np.array([5])) + assert np.equal(ivy.to_numpy(container_diff.f.diff_1.h), np.array([6])) + container_diff_diff_only = 
ivy.Container.cont_structural_diff( + container_0, container_1, mode="diff_only" + ) + assert container_diff_diff_only.cont_to_dict() == container_diff.cont_to_dict() + container_diff_same_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="same_only" + ) + assert container_diff_same_only.cont_to_dict() == {} - assert cont._local_ivy is ivy - pickled = pickle.dumps(cont) - cont_again = pickle.loads(pickled) - # noinspection PyUnresolvedReferences - assert cont_again._local_ivy.current_backend_str() is ivy.current_backend_str() - ivy.Container.cont_identical_structure([cont, cont_again]) - ivy.Container.cont_identical_configs([cont, cont_again]) + # some different keys + container_0 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + ) + container_1 = Container( + { + "a": ivy.array([4], device=on_device), + "b": { + "c": ivy.array([5], device=on_device), + "e": ivy.array([6], device=on_device), + }, + } + ) + container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) + assert np.equal(ivy.to_numpy(container_diff.b.c), np.array([2])) + assert np.equal(ivy.to_numpy(container_diff.b.d.diff_0), np.array([3])) + assert np.equal(ivy.to_numpy(container_diff.b.e.diff_1), np.array([6])) + container_diff_diff_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="diff_only" + ) + assert "a" not in container_diff_diff_only + assert "b" in container_diff_diff_only + assert "c" not in container_diff_diff_only["b"] + assert "d" in container_diff_diff_only["b"] + assert "e" in container_diff_diff_only["b"] + container_diff_same_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="same_only" + ) + assert "a" in container_diff_same_only + assert "b" in container_diff_same_only + assert "c" in container_diff_same_only["b"] + assert "d" not in container_diff_same_only["b"] + assert "e" not in container_diff_same_only["b"] + + # all same + container_0 = Container( + { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + ) + container_1 = Container( + { + "a": ivy.array([4], device=on_device), + "b": { + "c": ivy.array([5], device=on_device), + "d": ivy.array([6], device=on_device), + }, + } + ) + container_diff = ivy.Container.cont_structural_diff(container_0, container_1) + assert np.equal(ivy.to_numpy(container_diff.a), np.array([1])) + assert np.equal(ivy.to_numpy(container_diff.b.c), np.array([2])) + assert np.equal(ivy.to_numpy(container_diff.b.d), np.array([3])) + container_diff_diff_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="diff_only" + ) + assert container_diff_diff_only.cont_to_dict() == {} + container_diff_same_only = ivy.Container.cont_structural_diff( + container_0, container_1, mode="same_only" + ) + assert container_diff_same_only.cont_to_dict() == container_diff.cont_to_dict() -def test_container_to_and_from_disk_as_pickled(on_device): - save_filepath = "container_on_disk.pickled" - dict_in = { +def test_container_to_and_from_disk_as_hdf5(on_device): + if ivy.current_backend_str() == "tensorflow": + # container disk saving requires eager execution + pytest.skip() + save_filepath = "container_on_disk.hdf5" + dict_in_1 = { "a": ivy.array([np.float32(1.0)], device=on_device), "b": { "c": ivy.array([np.float32(2.0)], 
device=on_device), "d": ivy.array([np.float32(3.0)], device=on_device), }, } - container = Container(dict_in) - - # paddle tansor can't be pickled directly as mentioned - # in the issue https://github.com/PaddlePaddle/Paddle/issues/41107 - if ivy.backend == "paddle": - container = container.to_numpy() + container1 = Container(dict_in_1) + dict_in_2 = { + "a": ivy.array([np.float32(1.0), np.float32(1.0)], device=on_device), + "b": { + "c": ivy.array([np.float32(2.0), np.float32(2.0)], device=on_device), + "d": ivy.array([np.float32(3.0), np.float32(3.0)], device=on_device), + }, + } + container2 = Container(dict_in_2) # saving - container.cont_to_disk_as_pickled(save_filepath) + container1.cont_to_disk_as_hdf5(save_filepath, max_batch_size=2) assert os.path.exists(save_filepath) # loading - loaded_container = Container.cont_from_disk_as_pickled(save_filepath) - assert np.array_equal(ivy.to_numpy(loaded_container.a), ivy.to_numpy(container.a)) + loaded_container = Container.cont_from_disk_as_hdf5(save_filepath, slice(1)) + assert np.array_equal(ivy.to_numpy(loaded_container.a), ivy.to_numpy(container1.a)) assert np.array_equal( - ivy.to_numpy(loaded_container.b.c), ivy.to_numpy(container.b.c) + ivy.to_numpy(loaded_container.b.c), ivy.to_numpy(container1.b.c) ) assert np.array_equal( - ivy.to_numpy(loaded_container.b.d), ivy.to_numpy(container.b.d) + ivy.to_numpy(loaded_container.b.d), ivy.to_numpy(container1.b.d) ) - os.remove(save_filepath) - - -def test_container_to_and_from_disk_as_json(on_device): - save_filepath = "container_on_disk.json" - dict_in = { - "a": 1.274e-7, - "b": {"c": True, "d": ivy.array([np.float32(3.0)], device=on_device)}, - } - container = Container(dict_in) - - # saving - container.cont_to_disk_as_json(save_filepath) + # appending + container1.cont_to_disk_as_hdf5(save_filepath, max_batch_size=2, starting_index=1) assert os.path.exists(save_filepath) - # loading - loaded_container = Container.cont_from_disk_as_json(save_filepath) - assert np.array_equal(loaded_container.a, container.a) - assert np.array_equal(loaded_container.b.c, container.b.c) - assert isinstance(loaded_container.b.d, str) - - os.remove(save_filepath) - - -def test_container_shapes(on_device): - dict_in = { - "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), - "b": { - "c": ivy.array([[[2.0], [4.0]]], device=on_device), - "d": ivy.array([[9.0]], device=on_device), - }, - } - container_shapes = Container(dict_in).cont_shapes - assert list(container_shapes["a"]) == [1, 3, 1] - assert list(container_shapes.a) == [1, 3, 1] - assert list(container_shapes["b"]["c"]) == [1, 2, 1] - assert list(container_shapes.b.c) == [1, 2, 1] - assert list(container_shapes["b"]["d"]) == [1, 1] - assert list(container_shapes.b.d) == [1, 1] - - -def test_container_dev_str(on_device): - dict_in = { - "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), - "b": { - "c": ivy.array([[[2.0], [4.0], [6.0]]], device=on_device), - "d": ivy.array([[[3.0], [6.0], [9.0]]], device=on_device), - }, - } - container = Container(dict_in) - assert container.cont_dev_str == on_device - - -def test_container_create_if_absent(on_device): - dict_in = { - "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), - "b": { - "c": ivy.array([[[2.0], [4.0], [6.0]]], device=on_device), - "d": ivy.array([[[3.0], [6.0], [9.0]]], device=on_device), - }, - } - - # depth 1 - container = Container(dict_in) - container.cont_create_if_absent("a", None, True) - assert np.allclose(ivy.to_numpy(container.a), np.array([[[1.0], [2.0], [3.0]]])) - 
container.cont_create_if_absent("e", ivy.array([[[4.0], [8.0], [12.0]]]), True) - assert np.allclose(ivy.to_numpy(container.e), np.array([[[4.0], [8.0], [12.0]]])) - - # depth 2 - container.cont_create_if_absent("f/g", np.array([[[5.0], [10.0], [15.0]]]), True) - assert np.allclose(ivy.to_numpy(container.f.g), np.array([[[5.0], [10.0], [15.0]]])) - + # loading after append + loaded_container = Container.cont_from_disk_as_hdf5(save_filepath) + assert np.array_equal(ivy.to_numpy(loaded_container.a), ivy.to_numpy(container2.a)) + assert np.array_equal( + ivy.to_numpy(loaded_container.b.c), ivy.to_numpy(container2.b.c) + ) + assert np.array_equal( + ivy.to_numpy(loaded_container.b.d), ivy.to_numpy(container2.b.d) + ) -def test_container_if_exists(on_device): - dict_in = { - "a": ivy.array([[[1.0], [2.0], [3.0]]], device=on_device), - "b": { - "c": ivy.array([[[2.0], [4.0], [6.0]]], device=on_device), - "d": ivy.array([[[3.0], [6.0], [9.0]]], device=on_device), - }, - } - container = Container(dict_in) - assert np.allclose( - ivy.to_numpy(container.cont_if_exists("a")), np.array([[[1.0], [2.0], [3.0]]]) + # load slice + loaded_sliced_container = Container.cont_from_disk_as_hdf5( + save_filepath, slice(1, 2) ) - assert "c" not in container - assert container.cont_if_exists("c") is None - container["c"] = ivy.array([[[1.0], [2.0], [3.0]]], device=on_device) - assert np.allclose( - ivy.to_numpy(container.cont_if_exists("c")), np.array([[[1.0], [2.0], [3.0]]]) + assert np.array_equal( + ivy.to_numpy(loaded_sliced_container.a), ivy.to_numpy(container1.a) ) - assert container.cont_if_exists("d") is None - container.d = ivy.array([[[1.0], [2.0], [3.0]]], device=on_device) - assert np.allclose( - ivy.to_numpy(container.cont_if_exists("d")), np.array([[[1.0], [2.0], [3.0]]]) + assert np.array_equal( + ivy.to_numpy(loaded_sliced_container.b.c), ivy.to_numpy(container1.b.c) + ) + assert np.array_equal( + ivy.to_numpy(loaded_sliced_container.b.d), ivy.to_numpy(container1.b.d) ) + # file size + file_size, batch_size = Container.h5_file_size(save_filepath) + assert file_size == 6 * np.dtype(np.float32).itemsize + assert batch_size == 2 -def test_jax_pytree_compatibility(on_device): - if ivy.current_backend_str() != "jax": - pytest.skip() + os.remove(save_filepath) - # import - from jax.tree_util import tree_flatten - # dict in +def test_container_to_and_from_disk_as_json(on_device): + save_filepath = "container_on_disk.json" dict_in = { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - - # container - container = Container(dict_in) - - # container flattened - cont_values = tree_flatten(container)[0] - - # dict flattened - true_values = tree_flatten(dict_in)[0] - - # assertion - for i, true_val in enumerate(true_values): - assert np.array_equal(ivy.to_numpy(cont_values[i]), ivy.to_numpy(true_val)) - - -@pytest.mark.skip("Prevents PyTest from Terminating.") -def test_container_from_queues(on_device): - if "gpu" in on_device: - # Cannot re-initialize CUDA in forked subprocess. 'spawn' - # start method must be used. 
- pytest.skip() - - if ivy.gpu_is_available() and ivy.current_backend_str() == "jax": - # Not found a way to set default on_device for JAX, and this causes - # issues with multiprocessing and CUDA, even when device=cpu - # ToDo: find a fix for this problem ^^ - pytest.skip() - - def worker_fn(in_queue, out_queue, load_size, worker_id): - keep_going = True - while keep_going: - try: - keep_going = in_queue.get(timeout=0.1) - except queue.Empty: - continue - out_queue.put( - { - "a": [ - ivy.to_native(ivy.array([1.0, 2.0, 3.0], device=on_device)) - * worker_id - ] * load_size - } - ) - - workers = list() - in_queues = list() - out_queues = list() - queue_load_sizes = [1, 2, 1] - for i, queue_load_size in enumerate(queue_load_sizes): - input_queue = multiprocessing.Queue() - output_queue = multiprocessing.Queue() - worker = multiprocessing.Process( - target=worker_fn, args=(input_queue, output_queue, queue_load_size, i + 1) - ) - worker.start() - in_queues.append(input_queue) - out_queues.append(output_queue) - workers.append(worker) - - container = Container( - queues=out_queues, queue_load_sizes=queue_load_sizes, queue_timeout=0.25 - ) - - # queue 0 - queue_was_empty = False - try: - container[0] - except queue.Empty: - queue_was_empty = True - assert queue_was_empty - in_queues[0].put(True) - assert np.allclose(ivy.to_numpy(container[0].a), np.array([1.0, 2.0, 3.0])) - assert np.allclose(ivy.to_numpy(container[0].a), np.array([1.0, 2.0, 3.0])) + "a": 1.274e-7, + "b": {"c": True, "d": ivy.array([np.float32(3.0)], device=on_device)}, + } + container = Container(dict_in) - # queue 1 - queue_was_empty = False - try: - container[1] - except queue.Empty: - queue_was_empty = True - assert queue_was_empty - queue_was_empty = False - try: - container[2] - except queue.Empty: - queue_was_empty = True - assert queue_was_empty - in_queues[1].put(True) - assert np.allclose(ivy.to_numpy(container[1].a), np.array([2.0, 4.0, 6.0])) - assert np.allclose(ivy.to_numpy(container[1].a), np.array([2.0, 4.0, 6.0])) - assert np.allclose(ivy.to_numpy(container[2].a), np.array([2.0, 4.0, 6.0])) - assert np.allclose(ivy.to_numpy(container[2].a), np.array([2.0, 4.0, 6.0])) + # saving + container.cont_to_disk_as_json(save_filepath) + assert os.path.exists(save_filepath) - # queue 2 - queue_was_empty = False - try: - container[3] - except queue.Empty: - queue_was_empty = True - assert queue_was_empty - in_queues[2].put(True) - assert np.allclose(ivy.to_numpy(container[3].a), np.array([3.0, 6.0, 9.0])) - assert np.allclose(ivy.to_numpy(container[3].a), np.array([3.0, 6.0, 9.0])) + # loading + loaded_container = Container.cont_from_disk_as_json(save_filepath) + assert np.array_equal(loaded_container.a, container.a) + assert np.array_equal(loaded_container.b.c, container.b.c) + assert isinstance(loaded_container.b.d, str) - # stop workers - in_queues[0].put(False) - in_queues[1].put(False) - in_queues[2].put(False) - in_queues[0].close() - in_queues[1].close() - in_queues[2].close() + os.remove(save_filepath) - # join workers - for worker in workers: - worker.join() - del container +def test_container_to_and_from_disk_as_pickled(on_device): + save_filepath = "container_on_disk.pickled" + dict_in = { + "a": ivy.array([np.float32(1.0)], device=on_device), + "b": { + "c": ivy.array([np.float32(2.0)], device=on_device), + "d": ivy.array([np.float32(3.0)], device=on_device), + }, + } + container = Container(dict_in) + + # paddle tansor can't be pickled directly as mentioned + # in the issue 
https://github.com/PaddlePaddle/Paddle/issues/41107 + if ivy.backend == "paddle": + container = container.to_numpy() + # saving + container.cont_to_disk_as_pickled(save_filepath) + assert os.path.exists(save_filepath) -def test_container_reduce(on_device): - container_a = ivy.Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } + # loading + loaded_container = Container.cont_from_disk_as_pickled(save_filepath) + assert np.array_equal(ivy.to_numpy(loaded_container.a), ivy.to_numpy(container.a)) + assert np.array_equal( + ivy.to_numpy(loaded_container.b.c), ivy.to_numpy(container.b.c) ) - container_b = ivy.Container( - { - "a": ivy.array([2], device=on_device), - "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([6], device=on_device), - }, - } + assert np.array_equal( + ivy.to_numpy(loaded_container.b.d), ivy.to_numpy(container.b.d) ) - res = ivy.Container.cont_reduce([container_a, container_b], lambda x: x[0] + x[1]) - assert np.allclose(ivy.to_numpy(res.a), np.array([3.0])) - assert np.allclose(ivy.to_numpy(res.b.c), np.array([6])) - assert np.allclose(ivy.to_numpy(res.b.d), np.array([9])) + os.remove(save_filepath) -def test_container_assert_identical(on_device): - # without key_chains specification - arr1 = ivy.array([1], device=on_device) - arr2 = ivy.array([2], device=on_device) - arr3 = ivy.array([3], device=on_device) - container0 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) - container1 = Container({"a": arr1, "b": {"c": arr2, "d": arr3}}) - container2 = Container( + +def test_container_to_dict(on_device): + container0 = Container( { "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), + "c": ivy.array([True], device=on_device), + "d": { + "g": ivy.array([2.0], device=on_device), + "h": ivy.array([3], device=on_device), + }, }, } ) - container3 = Container({"b": {"d": arr3}}) - container4 = Container({"d": arr3}) + res = ivy.Container.cont_to_dict(container0) + assert res == {"a": 1, "b": {"c": True, "d": {"g": 2.0, "h": 3}}} - # the same - ivy.Container.cont_assert_identical([container0, container1]) - ivy.Container.cont_assert_identical([container1, container0]) - # not the same - try: - ivy.Container.cont_assert_identical([container0, container2]) - error_caught = False - except IvyException: - error_caught = True - assert error_caught - try: - ivy.Container.cont_assert_identical([container1, container2]) - error_caught = False - except IvyException: - error_caught = True - assert error_caught +def test_container_to_disk_shuffle_and_from_disk_as_hdf5(on_device): + if ivy.current_backend_str() == "tensorflow": + # container disk saving requires eager execution + pytest.skip() + save_filepath = "container_on_disk.hdf5" + dict_in = { + "a": ivy.array([1, 2, 3], device=on_device), + "b": { + "c": ivy.array([1, 2, 3], device=on_device), + "d": ivy.array([1, 2, 3], device=on_device), + }, + } + container = Container(dict_in) - # partial - ivy.Container.cont_assert_identical([container0, container3], partial=True) - ivy.Container.cont_assert_identical([container3, container0], partial=True) - try: - ivy.Container.cont_assert_identical([container4, container0], partial=True) - error_caught = False - except IvyException: - error_caught = True - assert error_caught + # saving + container.cont_to_disk_as_hdf5(save_filepath, max_batch_size=3) + assert os.path.exists(save_filepath) + # 
shuffling + Container.shuffle_h5_file(save_filepath) -def test_container_assert_identical_structure(on_device): - # without key_chains specification - container0 = Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([2], device=on_device), - "d": ivy.array([3], device=on_device), - }, - } - ) - container1 = Container( - { - "a": ivy.array([3], device=on_device), - "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), - }, - } - ) - container2 = Container( - { - "a": ivy.array([3], device=on_device), - "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), - "e": ivy.array([6], device=on_device), - }, - } - ) - container3 = Container( - { - "a": ivy.array([3], device=on_device), - "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), - }, - "e": ivy.array([6], device=on_device), - } + # loading + container_shuffled = Container.cont_from_disk_as_hdf5(save_filepath, slice(3)) + + # testing + data = np.array([1, 2, 3]) + random.seed(0) + random.shuffle(data) + + assert (ivy.to_numpy(container_shuffled["a"]) == data).all() + assert (ivy.to_numpy(container_shuffled.a) == data).all() + assert (ivy.to_numpy(container_shuffled["b"]["c"]) == data).all() + assert (ivy.to_numpy(container_shuffled.b.c) == data).all() + assert (ivy.to_numpy(container_shuffled["b"]["d"]) == data).all() + assert (ivy.to_numpy(container_shuffled.b.d) == data).all() + + os.remove(save_filepath) + + +def test_container_to_flat_list(on_device): + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } + container = Container(dict_in) + container_flat_list = container.cont_to_flat_list() + for value, expected_value in zip( + container_flat_list, + [ + ivy.array([1], device=on_device), + ivy.array([2], device=on_device), + ivy.array([3], device=on_device), + ], + ): + assert value == expected_value + + +@pytest.mark.parametrize("include_empty", [True, False]) +def test_container_to_iterator(include_empty, on_device): + a_val = Container() if include_empty else ivy.array([1], device=on_device) + bc_val = Container() if include_empty else ivy.array([2], device=on_device) + bd_val = Container() if include_empty else ivy.array([3], device=on_device) + dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} + container = Container(dict_in) + + # with key chains + container_iterator = container.cont_to_iterator(include_empty=include_empty) + for (key_chain, value), expected in zip( + container_iterator, [("a", a_val), ("b/c", bc_val), ("b/d", bd_val)] + ): + expected_key_chain = expected[0] + expected_value = expected[1] + assert key_chain == expected_key_chain + assert value is expected_value + + # with leaf keys + container_iterator = container.cont_to_iterator( + leaf_keys_only=True, include_empty=include_empty ) - container4 = Container({"b": {"d": ivy.array([4], device=on_device)}}) - container5 = Container({"d": ivy.array([4], device=on_device)}) + for (key_chain, value), expected in zip( + container_iterator, [("a", a_val), ("c", bc_val), ("d", bd_val)] + ): + expected_key_chain = expected[0] + expected_value = expected[1] + assert key_chain == expected_key_chain + assert value is expected_value - # with identical - ivy.Container.cont_assert_identical_structure([container0, container1]) - ivy.Container.cont_assert_identical_structure([container1, container0]) - 
ivy.Container.cont_assert_identical_structure([container1, container0, container1]) - # without identical - try: - ivy.Container.cont_assert_identical_structure( - [container0, container1, container2, container3] - ) - error_caught = False - except IvyException: - error_caught = True - # partial - try: - ivy.Container.cont_assert_identical_structure( - [container0, container1, container2, container3, container4, container5], - partial=True, - ) - error_caught = False - except IvyException: - error_caught = True - assert error_caught - try: - ivy.Container.cont_assert_identical_structure( - [container0, container5], partial=True - ) - error_caught = False - except IvyException: - error_caught = True - assert error_caught +@pytest.mark.parametrize("include_empty", [True, False]) +def test_container_to_iterator_keys(include_empty, on_device): + a_val = Container() if include_empty else ivy.array([1], device=on_device) + bc_val = Container() if include_empty else ivy.array([2], device=on_device) + bd_val = Container() if include_empty else ivy.array([3], device=on_device) + dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} + container = Container(dict_in) + # with key chains + container_iterator = container.cont_to_iterator_keys(include_empty=include_empty) + for key_chain, expected_key_chain in zip(container_iterator, ["a", "b/c", "b/d"]): + assert key_chain == expected_key_chain -def test_container_duplicate_array_keychains(on_device): - arr1 = ivy.array([1], device=on_device) - arr2 = ivy.array([2], device=on_device) - container0 = Container({"a": arr1, "b": {"c": arr1, "d": arr2}}) - container1 = Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([1], device=on_device), - "d": ivy.array([2], device=on_device), - }, - } + # with leaf keys + container_iterator = container.cont_to_iterator_keys( + leaf_keys_only=True, include_empty=include_empty ) - res = ivy.Container.cont_duplicate_array_keychains(container0) - assert res == (("a", "b/c"),) - res = ivy.Container.cont_duplicate_array_keychains(container1) - assert res == () + for key, expected_key in zip(container_iterator, ["a", "c", "d"]): + assert key == expected_key -def test_container_cont_inplace_update(on_device): - container0 = Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([1], device=on_device), - "d": ivy.array([2], device=on_device), - }, - } - ) - id0 = id(container0) - container1 = Container( - { - "a": ivy.array([0], device=on_device), - "b": { - "c": ivy.array([0], device=on_device), - "d": ivy.array([0], device=on_device), - }, - } - ) - id1 = id(container1) - assert ivy.Container.cont_all_false(container0.all_equal(container1)) - container0.inplace_update(container1) - assert id0 == id(container0) - assert id1 == id(container1) - assert ivy.Container.cont_all_true(container0.all_equal(container1)) +@pytest.mark.parametrize("include_empty", [True, False]) +def test_container_to_iterator_values(include_empty, on_device): + a_val = Container() if include_empty else ivy.array([1], device=on_device) + bc_val = Container() if include_empty else ivy.array([2], device=on_device) + bd_val = Container() if include_empty else ivy.array([3], device=on_device) + dict_in = {"a": a_val, "b": {"c": bc_val, "d": bd_val}} + container = Container(dict_in) + + # with key chains + container_iterator = container.cont_to_iterator_values(include_empty=include_empty) + for value, expected_value in zip(container_iterator, [a_val, bc_val, bd_val]): + assert value is 
expected_value def test_container_to_nested_list(on_device): @@ -3097,118 +3184,100 @@ def test_container_to_nested_list(on_device): assert res == [1, [True, [2.0, 3]]] -def test_container_to_dict(on_device): - container0 = Container( +def test_container_to_raw(on_device): + tuple_in = ( + ivy.array([1], device=on_device), + (ivy.array([2], device=on_device), ivy.array([3], device=on_device)), + ) + container = Container(tuple_in, types_to_iteratively_nest=[tuple]) + raw = container.cont_to_raw() + assert np.allclose(ivy.to_numpy(raw[0]), np.array([1])) + assert np.allclose(ivy.to_numpy(raw[1][0]), np.array([2])) + assert np.allclose(ivy.to_numpy(raw[1][1]), np.array([3])) + + +def test_container_trim_key(on_device): + key = "abcdefg" + max_length = 3 + trimmed_key = ivy.Container.cont_trim_key(key, max_length) + assert trimmed_key == "adg" + + +def test_container_try_kc(on_device): + cont = Container( { - "a": ivy.array([1], device=on_device), + "a": ivy.array([0.0], device=on_device), "b": { - "c": ivy.array([True], device=on_device), - "d": { - "g": ivy.array([2.0], device=on_device), - "h": ivy.array([3], device=on_device), - }, + "c": ivy.array([1.0], device=on_device), + "d": ivy.array([2.0], device=on_device), }, } ) - res = ivy.Container.cont_to_dict(container0) - assert res == {"a": 1, "b": {"c": True, "d": {"g": 2.0, "h": 3}}} - - -def test_container_assert_contains(on_device): - arr0 = ivy.array([0.0], device=on_device) - arr1 = ivy.array([1.0], device=on_device) - arr2 = ivy.array([2.0], device=on_device) - sub_cont = Container({"c": arr1, "d": arr2}) - container = Container({"a": arr0, "b": sub_cont}) - - # keys - assert "a" in container - assert "b" in container - assert "c" not in container - assert "b/c" in container - assert "d" not in container - assert "b/d" in container + assert cont.cont_try_kc("a") == cont.a + assert cont.cont_try_kc("b/c") == cont.b.c + assert cont.cont_try_kc("b/d") == cont.b.d + assert cont.cont_try_kc("b/e") is cont - # sub-container - container.cont_assert_contains_sub_container(container) - container.cont_assert_contains_sub_container(sub_cont) - assert sub_cont in container - # partial sub-container - partial_sub_cont = Container({"b": {"d": arr2}}) - container.cont_assert_contains_sub_container(container, partial=True) - container.cont_assert_contains_sub_container(partial_sub_cont, partial=True) - try: - partial_sub_cont.cont_assert_contains_sub_container(container, partial=True) - error_caught = False - except IvyException: - error_caught = True - assert error_caught - # sub-structure - sub_struc = Container( +def test_container_unify(on_device): + # on_devices and containers + on_devices = list() + dev0 = on_device + on_devices.append(dev0) + conts = dict() + conts[dev0] = Container( { - "c": ivy.array([3.0], device=on_device), - "d": ivy.array([4.0], device=on_device), + "a": ivy.array([1], device=dev0), + "b": {"c": ivy.array([2], device=dev0), "d": ivy.array([3], device=dev0)}, } ) - try: - not container.cont_assert_contains_sub_container(sub_struc) - error_caught = False - except IvyException: - error_caught = True - assert error_caught - assert sub_struc not in container - container.cont_assert_contains_sub_structure(sub_struc) - container.cont_assert_contains_sub_structure(container) + if "gpu" in on_device and ivy.num_gpus() > 1: + idx = ivy.num_gpus() - 1 + dev1 = on_device[:-1] + str(idx) + on_devices.append(dev1) + conts[dev1] = Container( + { + "a": ivy.array([4], device=dev1), + "b": { + "c": ivy.array([5], device=dev1), + "d": 
ivy.array([6], device=dev1), + }, + } + ) - # partial sub-structure - partial_sub_struc = Container({"b": {"d": ivy.array([4.0], device=on_device)}}) - container.cont_assert_contains_sub_structure(container, partial=True) - container.cont_assert_contains_sub_structure(partial_sub_struc, partial=True) - try: - partial_sub_struc.cont_assert_contains_sub_structure(container, partial=True) - error_caught = False - except IvyException: - error_caught = True - assert error_caught + # test + container_unified = ivy.Container.cont_unify(conts, dev0, "concat", 0) + assert np.allclose(ivy.to_numpy(container_unified.a[0]), np.array([1])) + assert np.allclose(ivy.to_numpy(container_unified.b.c[0]), np.array([2])) + assert np.allclose(ivy.to_numpy(container_unified.b.d[0]), np.array([3])) + if len(on_devices) > 1: + assert np.allclose(ivy.to_numpy(container_unified.a[1]), np.array([4])) + assert np.allclose(ivy.to_numpy(container_unified.b.c[1]), np.array([5])) + assert np.allclose(ivy.to_numpy(container_unified.b.d[1]), np.array([6])) -def test_container_copy(on_device): +def test_container_unstack_conts(on_device): dict_in = { - "a": ivy.array([0.0], device=on_device), + "a": ivy.array([[1], [2], [3]], device=on_device), "b": { - "c": ivy.array([1.0], device=on_device), - "d": ivy.array([2.0], device=on_device), + "c": ivy.array([[2], [3], [4]], device=on_device), + "d": ivy.array([[3], [4], [5]], device=on_device), }, - } - cont = Container(dict_in) - cont_deepcopy = cont.cont_copy() - assert np.allclose(ivy.to_numpy(cont.a), ivy.to_numpy(cont_deepcopy.a)) - assert np.allclose(ivy.to_numpy(cont.b.c), ivy.to_numpy(cont_deepcopy.b.c)) - assert np.allclose(ivy.to_numpy(cont.b.d), ivy.to_numpy(cont_deepcopy.b.d)) - assert id(cont) != id(cont_deepcopy) - assert id(cont.a) == id(cont_deepcopy.a) - assert id(cont.b.c) == id(cont_deepcopy.b.c) - assert id(cont.b.d) == id(cont_deepcopy.b.d) - + } + container = Container(dict_in) -def test_container_try_kc(on_device): - cont = Container( - { - "a": ivy.array([0.0], device=on_device), - "b": { - "c": ivy.array([1.0], device=on_device), - "d": ivy.array([2.0], device=on_device), - }, - } - ) - assert cont.cont_try_kc("a") == cont.a - assert cont.cont_try_kc("b/c") == cont.b.c - assert cont.cont_try_kc("b/d") == cont.b.d - assert cont.cont_try_kc("b/e") is cont + # without key_chains specification + container_unstacked = container.cont_unstack_conts(0) + for cont, a, bc, bd in zip(container_unstacked, [1, 2, 3], [2, 3, 4], [3, 4, 5]): + assert np.array_equal(ivy.to_numpy(cont["a"]), np.array([a])) + assert np.array_equal(ivy.to_numpy(cont.a), np.array([a])) + assert np.array_equal(ivy.to_numpy(cont["b"]["c"]), np.array([bc])) + assert np.array_equal(ivy.to_numpy(cont.b.c), np.array([bc])) + assert np.array_equal(ivy.to_numpy(cont["b"]["d"]), np.array([bd])) + assert np.array_equal(ivy.to_numpy(cont.b.d), np.array([bd])) -def test_container_with_print_limit(on_device): +def test_container_with_default_key_color(on_device): cont = Container( { "a": ivy.array([0.0], device=on_device), @@ -3218,41 +3287,54 @@ def test_container_with_print_limit(on_device): }, } ) - default_print_limit = cont._print_limit + default_default_key_color = cont._default_key_color id_cont = id(cont) - cont1 = cont.cont_with_print_limit(default_print_limit + 5) - assert cont1._print_limit == default_print_limit + 5 + cont1 = cont.cont_with_default_key_color("red") + assert cont1._default_key_color == "red" assert id(cont1) != id(cont) - assert cont._print_limit == default_print_limit - 
assert cont._print_limit != cont1._print_limit - cont.cont_with_print_limit(default_print_limit + 5, inplace=True) - assert cont._print_limit == default_print_limit + 5 - assert cont.b._print_limit == default_print_limit + 5 + assert cont._default_key_color == default_default_key_color + assert cont.b._default_key_color == default_default_key_color + assert cont._default_key_color != cont1._default_key_color + cont.cont_with_default_key_color("red", inplace=True) + assert cont._default_key_color == "red" + assert cont.b._default_key_color == "red" assert id(cont) == id_cont -def test_container_remove_print_limit(on_device): - cont = Container( +def test_container_with_entries_as_lists(on_device): + if ivy.current_backend_str() == "tensorflow": + # to_list() requires eager execution + pytest.skip() + dict_in = { + "a": ivy.array([1], device=on_device), + "b": {"c": ivy.array([2.0], device=on_device), "d": "some string"}, + } + container = Container(dict_in) + container_w_list_entries = container.cont_with_entries_as_lists() + for (key, value), expected_value in zip( + container_w_list_entries.cont_to_iterator(), [[1], [2.0], "some string"] + ): + assert value == expected_value + + +def test_container_with_ivy_backend(on_device): + container0 = Container( { - "a": ivy.array([0.0], device=on_device), + "a": ivy.array([1], device=on_device), "b": { - "c": ivy.array([1.0], device=on_device), - "d": ivy.array([2.0], device=on_device), + "c": ivy.array([1], device=on_device), + "d": ivy.array([2], device=on_device), }, } ) - default_print_limit = cont._print_limit - id_cont = id(cont) - cont1 = cont.cont_remove_print_limit() - assert cont1._print_limit is None - assert id(cont1) != id(cont) - assert cont._print_limit == default_print_limit - assert cont._print_limit != cont1._print_limit - assert cont.b._print_limit == default_print_limit - cont.cont_remove_print_limit(inplace=True) - assert cont._print_limit is None - assert cont.b._print_limit is None - assert id(cont) == id_cont + id_container0 = id(container0) + container0 = ivy.Container.cont_with_ivy_backend(container0, "numpy") + assert container0.cont_config["ivyh"] == "numpy" + assert id_container0 != id(container0) + id_container0 = id(container0) + ivy.Container.cont_with_ivy_backend(container0, "torch", inplace=True) + assert container0.cont_config["ivyh"] == "torch" + assert id(container0) == id_container0 def test_container_with_key_length_limit(on_device): @@ -3279,31 +3361,6 @@ def test_container_with_key_length_limit(on_device): assert id(cont) == id_cont -def test_container_remove_key_length_limit(on_device): - cont = Container( - { - "a": ivy.array([0.0], device=on_device), - "b": { - "c": ivy.array([1.0], device=on_device), - "d": ivy.array([2.0], device=on_device), - }, - } - ) - cont.cont_with_key_length_limit(5, inplace=True) - default_key_length_limit = cont._key_length_limit - id_cont = id(cont) - cont1 = cont.cont_remove_key_length_limit() - assert cont1._key_length_limit is None - assert id(cont1) != id(cont) - assert cont._key_length_limit == default_key_length_limit - assert cont.b._key_length_limit == default_key_length_limit - assert cont._key_length_limit != cont1._key_length_limit - cont.cont_remove_key_length_limit(inplace=True) - assert cont._key_length_limit is None - assert cont.b._key_length_limit is None - assert id(cont) == id_cont - - def test_container_with_print_indent(on_device): cont = Container( { @@ -3328,7 +3385,7 @@ def test_container_with_print_indent(on_device): assert id(cont) == id_cont -def 
test_container_with_print_line_spacing(on_device): +def test_container_with_print_limit(on_device): cont = Container( { "a": ivy.array([0.0], device=on_device), @@ -3338,21 +3395,20 @@ def test_container_with_print_line_spacing(on_device): }, } ) - default_print_line_spacing = cont._print_line_spacing + default_print_limit = cont._print_limit id_cont = id(cont) - cont1 = cont.cont_with_print_line_spacing(default_print_line_spacing + 5) - assert cont1._print_line_spacing == default_print_line_spacing + 5 + cont1 = cont.cont_with_print_limit(default_print_limit + 5) + assert cont1._print_limit == default_print_limit + 5 assert id(cont1) != id(cont) - assert cont._print_line_spacing == default_print_line_spacing - assert cont.b._print_line_spacing == default_print_line_spacing - assert cont._print_line_spacing != cont1._print_line_spacing - cont.cont_with_print_line_spacing(default_print_line_spacing + 5, inplace=True) - assert cont._print_line_spacing == default_print_line_spacing + 5 - assert cont.b._print_line_spacing == default_print_line_spacing + 5 + assert cont._print_limit == default_print_limit + assert cont._print_limit != cont1._print_limit + cont.cont_with_print_limit(default_print_limit + 5, inplace=True) + assert cont._print_limit == default_print_limit + 5 + assert cont.b._print_limit == default_print_limit + 5 assert id(cont) == id_cont -def test_container_with_default_key_color(on_device): +def test_container_with_print_line_spacing(on_device): cont = Container( { "a": ivy.array([0.0], device=on_device), @@ -3362,104 +3418,45 @@ def test_container_with_default_key_color(on_device): }, } ) - default_default_key_color = cont._default_key_color + default_print_line_spacing = cont._print_line_spacing id_cont = id(cont) - cont1 = cont.cont_with_default_key_color("red") - assert cont1._default_key_color == "red" + cont1 = cont.cont_with_print_line_spacing(default_print_line_spacing + 5) + assert cont1._print_line_spacing == default_print_line_spacing + 5 assert id(cont1) != id(cont) - assert cont._default_key_color == default_default_key_color - assert cont.b._default_key_color == default_default_key_color - assert cont._default_key_color != cont1._default_key_color - cont.cont_with_default_key_color("red", inplace=True) - assert cont._default_key_color == "red" - assert cont.b._default_key_color == "red" + assert cont._print_line_spacing == default_print_line_spacing + assert cont.b._print_line_spacing == default_print_line_spacing + assert cont._print_line_spacing != cont1._print_line_spacing + cont.cont_with_print_line_spacing(default_print_line_spacing + 5, inplace=True) + assert cont._print_line_spacing == default_print_line_spacing + 5 + assert cont.b._print_line_spacing == default_print_line_spacing + 5 assert id(cont) == id_cont -def test_container_with_ivy_backend(on_device): - container0 = Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([1], device=on_device), - "d": ivy.array([2], device=on_device), - }, - } - ) - id_container0 = id(container0) - container0 = ivy.Container.cont_with_ivy_backend(container0, "numpy") - assert container0.cont_config["ivyh"] == "numpy" - assert id_container0 != id(container0) - id_container0 = id(container0) - ivy.Container.cont_with_ivy_backend(container0, "torch", inplace=True) - assert container0.cont_config["ivyh"] == "torch" - assert id(container0) == id_container0 - - -def test_container_trim_key(on_device): - key = "abcdefg" - max_length = 3 - trimmed_key = ivy.Container.cont_trim_key(key, max_length) 
- assert trimmed_key == "adg" - - -def test_container_inplace(on_device): - container0 = Container( - { - "a": ivy.array([1], device=on_device), - "b": { - "c": ivy.array([1], device=on_device), - "d": ivy.array([2], device=on_device), - }, - } - ) - const = 3 - arr = ivy.array([1], device=on_device) - container1 = Container( - { - "a": ivy.array([3], device=on_device), - "b": { - "c": ivy.array([4], device=on_device), - "d": ivy.array([5], device=on_device), - }, - } - ) +def test_jax_pytree_compatibility(on_device): + if ivy.current_backend_str() != "jax": + pytest.skip() - special_funcs = [ - "__add__", - "__and__", - "__floordiv__", - "__lshift__", - "__matmul__", - "__mod__", - "__mul__", - "__pow__", - "__rshift__", - "__sub__", - "__truediv__", - "__xor__", - ] + # import + from jax.tree_util import tree_flatten - for func_str in special_funcs: - func = getattr(Container, func_str) - ifunc = getattr(Container, func_str[:2] + "i" + func_str[2:]) + # dict in + dict_in = { + "a": ivy.array([1], device=on_device), + "b": { + "c": ivy.array([2], device=on_device), + "d": ivy.array([3], device=on_device), + }, + } - for value in [ - const, - arr, - container1, - ]: - if value == const and func_str == "__matmul__": - continue - container0_copy = container0.cont_deep_copy() - id_before_op = id(container0_copy) - og_ids = container0_copy.cont_map(lambda x, _: id(x)) - ifunc(container0_copy, value) - op_ids = container0_copy.cont_map(lambda x, _: id(x)) + # container + container = Container(dict_in) - assert func(container0, value) == container0_copy # values - assert id(container0_copy) == id_before_op # container ids - assert og_ids == op_ids # value ids + # container flattened + cont_values = tree_flatten(container)[0] + # dict flattened + true_values = tree_flatten(dict_in)[0] -# TODO: Test non-inplace operator functions like __add__ and __matmul__ + # assertion + for i, true_val in enumerate(true_values): + assert np.array_equal(ivy.to_numpy(cont_values[i]), ivy.to_numpy(true_val)) diff --git a/ivy_tests/test_ivy/test_misc/test_cp_tensor.py b/ivy_tests/test_ivy/test_misc/test_cp_tensor.py index de6422e0bbe39..97b70949ae5a0 100644 --- a/ivy_tests/test_ivy/test_misc/test_cp_tensor.py +++ b/ivy_tests/test_ivy/test_misc/test_cp_tensor.py @@ -3,9 +3,6 @@ import numpy as np import pytest -# These tests have been adapetd from Tensorly -# https://github.com/tensorly/tensorly/blob/main/tensorly/tests/test_cp_tensor.py - @pytest.mark.parametrize( "shape, rank", @@ -16,13 +13,14 @@ ) ], ) -def test_cp_normalize(shape, rank): +def test_cp_flip_sign(shape, rank): cp_tensor = ivy.random_cp(shape, rank) - weights, factors = ivy.CPTensor.cp_normalize(cp_tensor) - expected_norm = ivy.ones((rank,)) - for f in factors: - norm = ivy.sqrt(ivy.sum(ivy.square(f), axis=0)) - assert np.allclose(norm, expected_norm) + weights, factors = ivy.CPTensor.cp_flip_sign(cp_tensor) + + assert ivy.all(ivy.mean(factors[1], axis=0) > 0) + assert ivy.all(ivy.mean(factors[2], axis=0) > 0) + assert cp_tensor.rank == cp_tensor.rank + assert np.allclose(cp_tensor.weights, weights) assert np.allclose( ivy.CPTensor.cp_to_tensor((weights, factors)), ivy.CPTensor.cp_to_tensor(cp_tensor), @@ -33,65 +31,117 @@ def test_cp_normalize(shape, rank): "shape, rank", [ ( - (3, 4, 5), - 4, + (8, 5, 6, 4), + 25, ) ], ) -def test_cp_flip_sign(shape, rank): - cp_tensor = ivy.random_cp(shape, rank) - weights, factors = ivy.CPTensor.cp_flip_sign(cp_tensor) +def test_cp_lstsq_grad(shape, rank): + """Validate the gradient calculation between a CP and 
dense tensor."""
+    cp_tensor = ivy.random_cp(shape, rank, normalise_factors=False)
 
-    assert ivy.all(ivy.mean(factors[1], axis=0) > 0)
-    assert ivy.all(ivy.mean(factors[2], axis=0) > 0)
-    assert cp_tensor.rank == cp_tensor.rank
-    assert np.allclose(cp_tensor.weights, weights)
-    assert np.allclose(
-        ivy.CPTensor.cp_to_tensor((weights, factors)),
-        ivy.CPTensor.cp_to_tensor(cp_tensor),
+    # If we're taking the gradient of comparison with self it should be 0
+    cp_grad = ivy.CPTensor.cp_lstsq_grad(
+        cp_tensor, ivy.CPTensor.cp_to_tensor(cp_tensor)
     )
+    assert ivy.CPTensor.cp_norm(cp_grad) <= 10e-5
+
+    # Check that we can solve for a direction of descent
+    dense = ivy.random_cp(shape, rank, full=True, normalise_factors=False)
+    cost_before = ivy.sqrt(
+        ivy.sum(ivy.square(ivy.CPTensor.cp_to_tensor(cp_tensor) - dense))
+    )
+
+    cp_grad = ivy.CPTensor.cp_lstsq_grad(cp_tensor, dense)
+    cp_new = ivy.CPTensor(cp_tensor)
+    for ii in range(len(shape)):
+        cp_new.factors[ii] = cp_tensor.factors[ii] - 1e-3 * cp_grad.factors[ii]
+
+    cost_after = ivy.sqrt(
+        ivy.sum(ivy.square(ivy.CPTensor.cp_to_tensor(cp_new) - dense))
+    )
+    assert cost_before > cost_after
 
 
 @pytest.mark.parametrize(
-    "true_shape, true_rank",
+    "shape, rank",
     [
         (
-            (3, 4, 5),
+            (5, 4, 6),
             3,
         )
     ],
 )
-def test_validate_cp_tensor(true_shape, true_rank):
-    cp_tensor = ivy.random_cp(true_shape, true_rank)
-    (weights, factors) = ivy.CPTensor.cp_normalize(cp_tensor)
+def test_cp_mode_dot(shape, rank):
+    cp_ten = ivy.random_cp(shape, rank, orthogonal=True, full=False)
+    full_tensor = ivy.CPTensor.cp_to_tensor(cp_ten)
+    # matrix for mode 1
+    matrix = ivy.random_uniform(shape=(7, shape[1]))
+    # vec for mode 2
+    vec = ivy.random_uniform(shape=(shape[2]))
 
-    # Check correct rank and shapes are returned
-    shape, rank = ivy.CPTensor.validate_cp_tensor((weights, factors))
-    np.testing.assert_equal(
-        true_shape,
-        shape,
-        err_msg=f"Returned incorrect shape (got {shape}, expected {true_shape})",
-    )
-    np.testing.assert_equal(
-        rank,
-        true_rank,
-        err_msg=f"Returned incorrect rank (got {rank}, expected {true_rank})",
-    )
+    # Test cp_mode_dot with matrix
+    res = ivy.CPTensor.cp_mode_dot(cp_ten, matrix, mode=1, copy=True)
+    # Note that if copy=True is not respected, factors will be changed
+    # And the next test will fail
+    res = ivy.CPTensor.cp_to_tensor(res)
+    true_res = ivy.mode_dot(full_tensor, matrix, mode=1)
+    assert np.allclose(true_res, res, atol=1e-3, rtol=1e-3)
 
-    # One of the factors has the wrong rank
-    factors[0], copy = ivy.random_uniform(shape=(4, 4)), factors[0]
-    with np.testing.assert_raises(ValueError):
-        ivy.CPTensor.validate_cp_tensor((weights, factors))
+    # Check that the data was indeed copied
+    rec = ivy.CPTensor.cp_to_tensor(cp_ten)
+    assert np.allclose(full_tensor, rec)
 
-    # Not the correct amount of weights
-    factors[0] = copy
-    wrong_weights = weights[1:]
-    with np.testing.assert_raises(ValueError):
-        ivy.CPTensor.validate_cp_tensor((wrong_weights, factors))
+    # Test cp_mode_dot with vec
+    res = ivy.CPTensor.cp_mode_dot(cp_ten, vec, mode=2, copy=True)
+    res = ivy.CPTensor.cp_to_tensor(res)
+    true_res = ivy.mode_dot(full_tensor, vec, mode=2)
+    assert res.shape == true_res.shape
+    assert np.allclose(true_res, res)
 
-    # Not enough factors
-    with np.testing.assert_raises(ValueError):
-        ivy.CPTensor.validate_cp_tensor((weights[:1], factors[:1]))
+
+@pytest.mark.parametrize(
+    "shape, rank, tol",
+    [
+        (
+            (8, 5, 6, 4),
+            25,
+            10e-5,
+        )
+    ],
+)
+def test_cp_norm(shape, rank, tol):
+    cp_tensor = ivy.random_cp(shape, rank, full=False, 
normalise_factors=True)
+    rec = ivy.CPTensor.cp_to_tensor(cp_tensor)
+    true_res = ivy.sqrt(ivy.sum(ivy.square(rec)))
+    res = ivy.CPTensor.cp_norm(cp_tensor)
+    assert ivy.abs(true_res - res) <= tol
+
+
+# These tests have been adapted from Tensorly
+# https://github.com/tensorly/tensorly/blob/main/tensorly/tests/test_cp_tensor.py
+
+
+@pytest.mark.parametrize(
+    "shape, rank",
+    [
+        (
+            (3, 4, 5),
+            4,
+        )
+    ],
+)
+def test_cp_normalize(shape, rank):
+    cp_tensor = ivy.random_cp(shape, rank)
+    weights, factors = ivy.CPTensor.cp_normalize(cp_tensor)
+    expected_norm = ivy.ones((rank,))
+    for f in factors:
+        norm = ivy.sqrt(ivy.sum(ivy.square(f), axis=0))
+        assert np.allclose(norm, expected_norm)
+    assert np.allclose(
+        ivy.CPTensor.cp_to_tensor((weights, factors)),
+        ivy.CPTensor.cp_to_tensor(cp_tensor),
+    )
 
 
 @pytest.mark.parametrize(
@@ -220,55 +270,24 @@ def test_cp_to_vec(shapeU1, shapeU2, shapeU3, shapeU4):
     "shape, rank",
     [
         (
-            (5, 4, 6),
-            3,
+            (10, 10, 10, 4),
+            5,
         )
     ],
 )
-def test_cp_mode_dot(shape, rank):
-    cp_ten = ivy.random_cp(shape, rank, orthogonal=True, full=False)
-    full_tensor = ivy.CPTensor.cp_to_tensor(cp_ten)
-    # matrix for mode 1
-    matrix = ivy.random_uniform(shape=(7, shape[1]))
-    # vec for mode 2
-    vec = ivy.random_uniform(shape=(shape[2]))
-
-    # Test cp_mode_dot with matrix
-    res = ivy.CPTensor.cp_mode_dot(cp_ten, matrix, mode=1, copy=True)
-    # Note that if copy=True is not respected, factors will be changes
-    # And the next test will fail
-    res = ivy.CPTensor.cp_to_tensor(res)
-    true_res = ivy.mode_dot(full_tensor, matrix, mode=1)
-    assert np.allclose(true_res, res, atol=1e-3, rtol=1e-3)
-
-    # Check that the data was indeed copied
-    rec = ivy.CPTensor.cp_to_tensor(cp_ten)
-    assert np.allclose(full_tensor, rec)
-
-    # Test cp_mode_dot with vec
-    res = ivy.CPTensor.cp_mode_dot(cp_ten, vec, mode=2, copy=True)
-    res = ivy.CPTensor.cp_to_tensor(res)
-    true_res = ivy.mode_dot(full_tensor, vec, mode=2)
-    assert res.shape == true_res.shape
-    assert np.allclose(true_res, res)
+def test_unfolding_dot_khatri_rao(shape, rank):
+    tensor = ivy.random_uniform(shape=shape)
+    weights, factors = ivy.random_cp(shape, rank, full=False, normalise_factors=True)
 
+    for mode in range(4):
+        # Version forming explicitly the khatri-rao product
+        unfolded = ivy.unfold(tensor, mode)
+        kr_factors = ivy.khatri_rao(factors, weights=weights, skip_matrix=mode)
+        true_res = ivy.matmul(unfolded, kr_factors)
 
-@pytest.mark.parametrize(
-    "shape, rank, tol",
-    [
-        (
-            (8, 5, 6, 4),
-            25,
-            10e-5,
-        )
-    ],
-)
-def test_cp_norm(shape, rank, tol):
-    cp_tensor = ivy.random_cp(shape, rank, full=False, normalise_factors=True)
-    rec = ivy.CPTensor.cp_to_tensor(cp_tensor)
-    true_res = ivy.sqrt(ivy.sum(ivy.square(rec)))
-    res = ivy.CPTensor.cp_norm(cp_tensor)
-    assert ivy.abs(true_res - res) <= tol
+        # Efficient sparse-safe version
+        res = ivy.CPTensor.unfolding_dot_khatri_rao(tensor, (weights, factors), mode)
+        assert np.allclose(true_res, res)
 
 
 @pytest.mark.parametrize("size", [4])
@@ -288,60 +307,42 @@ def test_validate_cp_rank(size):
 
 
 @pytest.mark.parametrize(
-    "shape, rank",
+    "true_shape, true_rank",
     [
         (
-            (8, 5, 6, 4),
-            25,
+            (3, 4, 5),
+            3,
         )
     ],
 )
-def test_cp_lstsq_grad(shape, rank):
-    """Validate the gradient calculation between a CP and dense tensor."""
-    cp_tensor = ivy.random_cp(shape, rank, normalise_factors=False)
-
-    # If we're taking the gradient of comparison with self it should be 0
-    cp_grad = ivy.CPTensor.cp_lstsq_grad(
-        cp_tensor, ivy.CPTensor.cp_to_tensor(cp_tensor)
-    )
-    assert 
ivy.CPTensor.cp_norm(cp_grad) <= 10e-5 +def test_validate_cp_tensor(true_shape, true_rank): + cp_tensor = ivy.random_cp(true_shape, true_rank) + (weights, factors) = ivy.CPTensor.cp_normalize(cp_tensor) - # Check that we can solve for a direction of descent - dense = ivy.random_cp(shape, rank, full=True, normalise_factors=False) - cost_before = ivy.sqrt( - ivy.sum(ivy.square(ivy.CPTensor.cp_to_tensor(cp_tensor) - dense)) + # Check correct rank and shapes are returned + shape, rank = ivy.CPTensor.validate_cp_tensor((weights, factors)) + np.testing.assert_equal( + true_shape, + shape, + err_msg=f"Returned incorrect shape (got {shape}, expected {true_shape})", ) - - cp_grad = ivy.CPTensor.cp_lstsq_grad(cp_tensor, dense) - cp_new = ivy.CPTensor(cp_tensor) - for ii in range(len(shape)): - cp_new.factors[ii] = cp_tensor.factors[ii] - 1e-3 * cp_grad.factors[ii] - - cost_after = ivy.sqrt( - ivy.sum(ivy.square(ivy.CPTensor.cp_to_tensor(cp_new) - dense)) + np.testing.assert_equal( + rank, + true_rank, + err_msg=f"Returned incorrect rank (got {rank}, expected {true_rank})", ) - assert cost_before > cost_after - -@pytest.mark.parametrize( - "shape, rank", - [ - ( - (10, 10, 10, 4), - 5, - ) - ], -) -def test_unfolding_dot_khatri_rao(shape, rank): - tensor = ivy.random_uniform(shape=shape) - weights, factors = ivy.random_cp(shape, rank, full=False, normalise_factors=True) + # One of the factors has the wrong rank + factors[0], copy = ivy.random_uniform(shape=(4, 4)), factors[0] + with np.testing.assert_raises(ValueError): + ivy.CPTensor.validate_cp_tensor((weights, factors)) - for mode in range(4): - # Version forming explicitely the khatri-rao product - unfolded = ivy.unfold(tensor, mode) - kr_factors = ivy.khatri_rao(factors, weights=weights, skip_matrix=mode) - true_res = ivy.matmul(unfolded, kr_factors) + # Not the correct amount of weights + factors[0] = copy + wrong_weights = weights[1:] + with np.testing.assert_raises(ValueError): + ivy.CPTensor.validate_cp_tensor((wrong_weights, factors)) - # Efficient sparse-safe version - res = ivy.CPTensor.unfolding_dot_khatri_rao(tensor, (weights, factors), mode) - assert np.allclose(true_res, res) + # Not enough factors + with np.testing.assert_raises(ValueError): + ivy.CPTensor.validate_cp_tensor((weights[:1], factors[:1])) diff --git a/ivy_tests/test_ivy/test_misc/test_exceptions.py b/ivy_tests/test_ivy/test_misc/test_exceptions.py index dc69d53a841fd..d08a3c3fffa81 100644 --- a/ivy_tests/test_ivy/test_misc/test_exceptions.py +++ b/ivy_tests/test_ivy/test_misc/test_exceptions.py @@ -5,6 +5,25 @@ import ivy +@pytest.mark.parametrize("trace_mode", ["full", "ivy", "frontend"]) +def test_get_trace_mode(trace_mode, backend_fw): + ivy.set_backend(backend_fw) + ivy.set_exception_trace_mode(trace_mode) + ivy.set_exception_trace_mode("ivy") + ivy.utils.assertions.check_equal(ivy.exception_trace_mode, "ivy", as_array=False) + ivy.previous_backend() + + +@pytest.mark.parametrize("trace_mode", ["full", "ivy", "frontend"]) +def test_set_trace_mode(trace_mode, backend_fw): + ivy.set_backend(backend_fw) + ivy.set_exception_trace_mode(trace_mode) + ivy.utils.assertions.check_equal( + ivy.exception_trace_mode, trace_mode, as_array=False + ) + ivy.previous_backend() + + @pytest.mark.parametrize("trace_mode", ["full", "ivy", "frontend"]) @pytest.mark.parametrize("show_func_wrapper", [True, False]) def test_trace_modes(backend_fw, trace_mode, show_func_wrapper): @@ -57,16 +76,6 @@ def test_trace_modes(backend_fw, trace_mode, show_func_wrapper): ivy.previous_backend() 
-@pytest.mark.parametrize("trace_mode", ["full", "ivy", "frontend"]) -def test_set_trace_mode(trace_mode, backend_fw): - ivy.set_backend(backend_fw) - ivy.set_exception_trace_mode(trace_mode) - ivy.utils.assertions.check_equal( - ivy.exception_trace_mode, trace_mode, as_array=False - ) - ivy.previous_backend() - - @pytest.mark.parametrize("trace_mode", ["full", "ivy", "frontend"]) def test_unset_trace_mode(trace_mode, backend_fw): ivy.set_backend(backend_fw) @@ -78,12 +87,3 @@ def test_unset_trace_mode(trace_mode, backend_fw): ivy.exception_trace_mode, trace_mode, as_array=False ) ivy.previous_backend() - - -@pytest.mark.parametrize("trace_mode", ["full", "ivy", "frontend"]) -def test_get_trace_mode(trace_mode, backend_fw): - ivy.set_backend(backend_fw) - ivy.set_exception_trace_mode(trace_mode) - ivy.set_exception_trace_mode("ivy") - ivy.utils.assertions.check_equal(ivy.exception_trace_mode, "ivy", as_array=False) - ivy.previous_backend() diff --git a/ivy_tests/test_ivy/test_misc/test_func_wrapper.py b/ivy_tests/test_ivy/test_misc/test_func_wrapper.py index 3e5befbccdaef..de3c5a73c9604 100644 --- a/ivy_tests/test_ivy/test_misc/test_func_wrapper.py +++ b/ivy_tests/test_ivy/test_misc/test_func_wrapper.py @@ -7,6 +7,10 @@ from typing import Union, Tuple, List, Sequence +# --- Helpers --- # +# --------------- # + + def _fn1(x: Union[ivy.Array, Tuple[int, int]]): return x @@ -23,133 +27,50 @@ def _fn4(x: Union[Sequence[ivy.Array], ivy.Array]): return x -@pytest.mark.parametrize( - ("fn", "x", "expected_type"), - [ - (_fn1, (1, 2), tuple), - (_fn2, (1, 2), ivy.Array), - (_fn2, [1, 2], ivy.Array), - (_fn3, [1, 2], list), - (_fn4, [1, 2], list), - ], -) -def test_handle_array_like_without_promotion(fn, x, expected_type, backend_fw): - ivy.set_backend(backend_fw) - assert isinstance(handle_array_like_without_promotion(fn)(x), expected_type) - ivy.previous_backend() - - -def test_outputs_to_ivy_arrays(backend_fw): - ivy.set_backend(backend_fw) - assert isinstance( - ivy.outputs_to_ivy_arrays(_fn1)(ivy.to_native(ivy.array([2.0]))), ivy.Array - ) - assert ivy.outputs_to_ivy_arrays(_fn1)(ivy.array(1)) == ivy.array(1) - ivy.previous_backend() - - def _fn5(x): # Test input was converted to native array assert isinstance(x, ivy.NativeArray) -def test_inputs_to_native_arrays(backend_fw): - ivy.set_backend(backend_fw) - ivy.inputs_to_native_arrays(_fn5)(ivy.array(1)) - ivy.previous_backend() - - def _fn6(x): # Assert input was converted to Ivy Array assert isinstance(x, ivy.Array) -def test_inputs_to_ivy_arrays(backend_fw): - ivy.set_backend(backend_fw) - ivy.inputs_to_ivy_arrays(_fn6)(ivy.native_array(1)) - ivy.previous_backend() - - def _fn7(x): # Assert input was converted to native array assert isinstance(x, ivy.NativeArray) return x -def test_to_native_arrays_and_back(backend_fw): - ivy.set_backend(backend_fw) - x = ivy.array(1.0) - res = ivy.func_wrapper.to_native_arrays_and_back(_fn7)(x) - assert isinstance(res, ivy.Array) - ivy.previous_backend() +def _fn8(x): + return ivy.ones_like(x) -@pytest.mark.parametrize( - ("x", "weight", "expected"), - [ - ([[1, 1], [1, 1]], [[1, 1], [1, 1], [1, 1]], True), - ( - [[1, 1], [1, 1]], - [ - [[1, 1], [1, 1], [1, 1]], - [[1, 1], [1, 1], [1, 1]], - [[1, 1], [1, 1], [1, 1]], - ], - False, - ), - ], -) -def test_handle_partial_mixed_function(x, weight, expected, backend_fw): - ivy.set_backend(backend_fw) - test_fn = "torch.nn.functional.linear" - if ivy.current_backend_str() != "torch": - # ivy.matmul is used inside the compositional implementation - test_fn = 
"ivy.matmul" - expected = True - with patch(test_fn) as test_mock_function: - ivy.linear(ivy.array(x), ivy.array(weight)) - assert test_mock_function.called == expected - ivy.previous_backend() +def _jl(x, *args, fn_original, **kwargs): + return fn_original(x) * 3j + + +# --- Main --- # +# ------------ # @pytest.mark.parametrize( - "array_to_update", - [0, 1, 2, 3, 4], + ("fn", "x", "expected_type"), + [ + (_fn1, (1, 2), tuple), + (_fn2, (1, 2), ivy.Array), + (_fn2, [1, 2], ivy.Array), + (_fn3, [1, 2], list), + (_fn4, [1, 2], list), + ], ) -def test_views(array_to_update, backend_fw): +def test_handle_array_like_without_promotion(fn, x, expected_type, backend_fw): ivy.set_backend(backend_fw) - a = ivy.random.random_normal(shape=(6,)) - a_copy = ivy.copy_array(a) - b = a.reshape((2, 3)) - b_copy = ivy.copy_array(b) - c = ivy.flip(b) - c_copy = ivy.copy_array(c) - d = ivy.rot90(c, k=3) - d_copy = ivy.copy_array(d) - e = ivy.split(d) - e_copy = ivy.copy_array(e[0]) - array = (a, b, c, d, e)[array_to_update] - if array_to_update == 4: - for arr in array: - arr += 1 - else: - array += 1 - assert np.allclose(a, a_copy + 1) - assert np.allclose(b, b_copy + 1) - assert np.allclose(c, c_copy + 1) - assert np.allclose(d, d_copy + 1) - assert np.allclose(e[0], e_copy + 1) + assert isinstance(handle_array_like_without_promotion(fn)(x), expected_type) ivy.previous_backend() -def _fn8(x): - return ivy.ones_like(x) - - -def _jl(x, *args, fn_original, **kwargs): - return fn_original(x) * 3j - - @pytest.mark.parametrize( ("x", "mode", "jax_like", "expected"), [ @@ -201,3 +122,90 @@ def test_handle_complex_input(x, mode, jax_like, expected, backend_fw): ) ) ivy.previous_backend() + + +@pytest.mark.parametrize( + ("x", "weight", "expected"), + [ + ([[1, 1], [1, 1]], [[1, 1], [1, 1], [1, 1]], True), + ( + [[1, 1], [1, 1]], + [ + [[1, 1], [1, 1], [1, 1]], + [[1, 1], [1, 1], [1, 1]], + [[1, 1], [1, 1], [1, 1]], + ], + False, + ), + ], +) +def test_handle_partial_mixed_function(x, weight, expected, backend_fw): + ivy.set_backend(backend_fw) + test_fn = "torch.nn.functional.linear" + if ivy.current_backend_str() != "torch": + # ivy.matmul is used inside the compositional implementation + test_fn = "ivy.matmul" + expected = True + with patch(test_fn) as test_mock_function: + ivy.linear(ivy.array(x), ivy.array(weight)) + assert test_mock_function.called == expected + ivy.previous_backend() + + +def test_inputs_to_ivy_arrays(backend_fw): + ivy.set_backend(backend_fw) + ivy.inputs_to_ivy_arrays(_fn6)(ivy.native_array(1)) + ivy.previous_backend() + + +def test_inputs_to_native_arrays(backend_fw): + ivy.set_backend(backend_fw) + ivy.inputs_to_native_arrays(_fn5)(ivy.array(1)) + ivy.previous_backend() + + +def test_outputs_to_ivy_arrays(backend_fw): + ivy.set_backend(backend_fw) + assert isinstance( + ivy.outputs_to_ivy_arrays(_fn1)(ivy.to_native(ivy.array([2.0]))), ivy.Array + ) + assert ivy.outputs_to_ivy_arrays(_fn1)(ivy.array(1)) == ivy.array(1) + ivy.previous_backend() + + +def test_to_native_arrays_and_back(backend_fw): + ivy.set_backend(backend_fw) + x = ivy.array(1.0) + res = ivy.func_wrapper.to_native_arrays_and_back(_fn7)(x) + assert isinstance(res, ivy.Array) + ivy.previous_backend() + + +@pytest.mark.parametrize( + "array_to_update", + [0, 1, 2, 3, 4], +) +def test_views(array_to_update, backend_fw): + ivy.set_backend(backend_fw) + a = ivy.random.random_normal(shape=(6,)) + a_copy = ivy.copy_array(a) + b = a.reshape((2, 3)) + b_copy = ivy.copy_array(b) + c = ivy.flip(b) + c_copy = ivy.copy_array(c) + d = 
ivy.rot90(c, k=3) + d_copy = ivy.copy_array(d) + e = ivy.split(d) + e_copy = ivy.copy_array(e[0]) + array = (a, b, c, d, e)[array_to_update] + if array_to_update == 4: + for arr in array: + arr += 1 + else: + array += 1 + assert np.allclose(a, a_copy + 1) + assert np.allclose(b, b_copy + 1) + assert np.allclose(c, c_copy + 1) + assert np.allclose(d, d_copy + 1) + assert np.allclose(e[0], e_copy + 1) + ivy.previous_backend() diff --git a/ivy_tests/test_ivy/test_misc/test_inspection.py b/ivy_tests/test_ivy/test_misc/test_inspection.py index 16506b3c472b6..a218e77e09928 100644 --- a/ivy_tests/test_ivy/test_misc/test_inspection.py +++ b/ivy_tests/test_ivy/test_misc/test_inspection.py @@ -6,6 +6,10 @@ import ivy +# --- Helpers --- # +# --------------- # + + def _fn0(xs: Optional[List[ivy.Array]] = None): return xs @@ -27,6 +31,10 @@ def _fn2( return a, bs, cs +# --- Main --- # +# ------------ # + + @pytest.mark.parametrize( "fn_n_spec", [ diff --git a/ivy_tests/test_ivy/test_misc/test_ivy_demos.py b/ivy_tests/test_ivy/test_misc/test_ivy_demos.py index febb201bbbed6..2fffb96b487e5 100644 --- a/ivy_tests/test_ivy/test_misc/test_ivy_demos.py +++ b/ivy_tests/test_ivy/test_misc/test_ivy_demos.py @@ -8,6 +8,25 @@ import ivy.functional.backends.numpy +# functional api +def test_array(on_device): + import jax.numpy as jnp + + assert ivy.concat((jnp.ones((1,)), jnp.ones((1,))), axis=-1).shape == (2,) + import tensorflow as tf + + assert ivy.concat((tf.ones((1,)), tf.ones((1,))), axis=-1).shape == (2,) + import numpy as np + + assert ivy.concat((np.ones((1,)), np.ones((1,))), axis=-1).shape == (2,) + import torch + + assert ivy.concat((torch.ones((1,)), torch.ones((1,))), axis=-1).shape == (2,) + import paddle + + assert ivy.concat((paddle.ones((1,)), paddle.ones((1,))), axis=-1).shape == (2,) + + # Tests # # ------# @@ -44,22 +63,3 @@ def loss_fn(v): model.v = optimizer.step(model.v, grads) ivy.previous_backend() - - -# functional api -def test_array(on_device): - import jax.numpy as jnp - - assert ivy.concat((jnp.ones((1,)), jnp.ones((1,))), axis=-1).shape == (2,) - import tensorflow as tf - - assert ivy.concat((tf.ones((1,)), tf.ones((1,))), axis=-1).shape == (2,) - import numpy as np - - assert ivy.concat((np.ones((1,)), np.ones((1,))), axis=-1).shape == (2,) - import torch - - assert ivy.concat((torch.ones((1,)), torch.ones((1,))), axis=-1).shape == (2,) - import paddle - - assert ivy.concat((paddle.ones((1,)), paddle.ones((1,))), axis=-1).shape == (2,) diff --git a/ivy_tests/test_ivy/test_misc/test_logging.py b/ivy_tests/test_ivy/test_misc/test_logging.py index e5901eb7008e4..8f2053c193a31 100644 --- a/ivy_tests/test_ivy/test_misc/test_logging.py +++ b/ivy_tests/test_ivy/test_misc/test_logging.py @@ -3,6 +3,11 @@ import ivy +def test_invalid_logging_mode(): + with pytest.raises(AssertionError): + ivy.set_logging_mode("INVALID") + + def test_set_logging_mode(): ivy.set_logging_mode("DEBUG") assert logging.getLogger().level == logging.DEBUG @@ -22,8 +27,3 @@ def test_unset_logging_mode(): ivy.set_logging_mode("INFO") ivy.unset_logging_mode() assert logging.getLogger().level == logging.DEBUG - - -def test_invalid_logging_mode(): - with pytest.raises(AssertionError): - ivy.set_logging_mode("INVALID") diff --git a/ivy_tests/test_ivy/test_misc/test_pickling.py b/ivy_tests/test_ivy/test_misc/test_pickling.py index 199c28a3a971a..3ff464b53b2fd 100644 --- a/ivy_tests/test_ivy/test_misc/test_pickling.py +++ b/ivy_tests/test_ivy/test_misc/test_pickling.py @@ -9,11 +9,7 @@ import ivy_tests.test_ivy.helpers as 
helpers -# Tests # -# ------# - - -# pickling array test to str +# pickling array test to disk @given( dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False), @@ -23,7 +19,7 @@ max_dim_size=5, ), ) -def test_pickle_to_string(dtype_and_x, on_device, backend_fw): +def test_pickle_to_and_from_disk(dtype_and_x, on_device, backend_fw): ivy.set_backend(backend_fw) input_dtype, x = dtype_and_x assume("bfloat16" not in input_dtype) @@ -33,15 +29,26 @@ def test_pickle_to_string(dtype_and_x, on_device, backend_fw): # https://github.com/PaddlePaddle/Paddle/issues/41107 if ivy.backend == "paddle": x = x.to_numpy() - pickled_arr = pickle.dumps(x) - unpickled_arr = pickle.loads(pickled_arr) + + save_filepath = "ivy_array.pickle" + pickle.dump(x, open(save_filepath, "wb")) + + assert os.path.exists(save_filepath) + + unpickled_arr = pickle.load(open(save_filepath, "rb")) + + os.remove(save_filepath) # check for equality assert np.allclose(ivy.to_numpy(x), ivy.to_numpy(unpickled_arr)) ivy.previous_backend() -# pickling array test to disk +# Tests # +# ------# + + +# pickling array test to str @given( dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid", prune_function=False), @@ -51,7 +58,7 @@ def test_pickle_to_string(dtype_and_x, on_device, backend_fw): max_dim_size=5, ), ) -def test_pickle_to_and_from_disk(dtype_and_x, on_device, backend_fw): +def test_pickle_to_string(dtype_and_x, on_device, backend_fw): ivy.set_backend(backend_fw) input_dtype, x = dtype_and_x assume("bfloat16" not in input_dtype) @@ -61,15 +68,8 @@ def test_pickle_to_and_from_disk(dtype_and_x, on_device, backend_fw): # https://github.com/PaddlePaddle/Paddle/issues/41107 if ivy.backend == "paddle": x = x.to_numpy() - - save_filepath = "ivy_array.pickle" - pickle.dump(x, open(save_filepath, "wb")) - - assert os.path.exists(save_filepath) - - unpickled_arr = pickle.load(open(save_filepath, "rb")) - - os.remove(save_filepath) + pickled_arr = pickle.dumps(x) + unpickled_arr = pickle.loads(pickled_arr) # check for equality assert np.allclose(ivy.to_numpy(x), ivy.to_numpy(unpickled_arr)) diff --git a/ivy_tests/test_ivy/test_misc/test_shape.py b/ivy_tests/test_ivy/test_misc/test_shape.py index ee095cc70ec9f..266292f7e2500 100644 --- a/ivy_tests/test_ivy/test_misc/test_shape.py +++ b/ivy_tests/test_ivy/test_misc/test_shape.py @@ -10,45 +10,52 @@ @handle_method( - method_tree="Shape.__getitem__", - dtypes_x_query=helpers.dtype_array_query( - available_dtypes=helpers.get_dtypes("valid"), - allow_neg_step=False, + method_tree="Shape.__add__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_shape__getitem__( - dtypes_x_query, - init_flags, - method_flags, +def test_shape__add__( + dtype_and_x, method_name, class_name, - backend_fw, ground_truth_backend, + backend_fw, + init_flags, + method_flags, on_device, ): - dtypes, x, query = dtypes_x_query + dtype, x = dtype_and_x helpers.test_method( on_device=on_device, - backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"shape_tup": x}, - init_input_dtypes=[dtypes[0]], - method_input_dtypes=[dtypes[1]], - method_all_as_kwargs_np={"key": query}, + init_all_as_kwargs_np={"data": x[0]}, + init_input_dtypes=dtype, + 
method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__index__", + method_tree="Shape.__bool__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), + max_num_dims=0, + min_value=0, + max_value=1, ), ) -def test_shape__index__( +def test_shape__bool__( dtype_and_x, method_name, class_name, @@ -67,7 +74,7 @@ def test_shape__index__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=dtype, + method_input_dtypes=[], method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, @@ -75,70 +82,48 @@ def test_shape__index__( @handle_method( - method_tree="Shape.__pow__", - dtype_and_x=pow_helper(), + method_tree="Shape.__eq__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + shared_dtype=True, + ), ) -def test_shape__pow__( +def test_shape__eq__( dtype_and_x, method_name, class_name, - backend_fw, ground_truth_backend, + backend_fw, init_flags, method_flags, on_device, ): - input_dtype, x = dtype_and_x - - # bfloat16 is not supported by numpy - assume(not ("bfloat16" in input_dtype)) - - # Make sure x2 isn't a float when x1 is integer - with BackendHandler.update_backend(backend_fw) as ivy_backend: - assume( - not ( - ivy_backend.is_int_dtype( - input_dtype[0] and ivy_backend.is_float_dtype(input_dtype[1]) - ) - ) - ) - - # Make sure x2 is non-negative when both is integer - if ivy_backend.is_int_dtype(input_dtype[1]) and ivy_backend.is_int_dtype( - input_dtype[0] - ): - x[1] = np.abs(x[1]) - - x[0] = not_too_close_to_zero(x[0]) - x[1] = not_too_close_to_zero(x[1]) - + dtype, x = dtype_and_x helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, - init_flags=init_flags, backend_to_test=backend_fw, + init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=[input_dtype[0]], - method_input_dtypes=[input_dtype[1]], - method_all_as_kwargs_np={"power": x[1]}, + init_input_dtypes=dtype, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__add__", + method_tree="Shape.__ge__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__add__( +def test_shape__ge__( dtype_and_x, method_name, class_name, @@ -165,54 +150,47 @@ def test_shape__add__( @handle_method( - method_tree="Shape.__radd__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + method_tree="Shape.__getitem__", + dtypes_x_query=helpers.dtype_array_query( + available_dtypes=helpers.get_dtypes("valid"), + allow_neg_step=False, ), ) -def test_shape__radd__( - dtype_and_x, +def test_shape__getitem__( + dtypes_x_query, + init_flags, + method_flags, method_name, class_name, backend_fw, ground_truth_backend, - init_flags, - method_flags, on_device, ): - dtype, x = dtype_and_x + dtypes, x, query = dtypes_x_query helpers.test_method( on_device=on_device, - 
ground_truth_backend=ground_truth_backend, backend_to_test=backend_fw, + ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + init_all_as_kwargs_np={"shape_tup": x}, + init_input_dtypes=[dtypes[0]], + method_input_dtypes=[dtypes[1]], + method_all_as_kwargs_np={"key": query}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__sub__", + method_tree="Shape.__gt__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__sub__( +def test_shape__gt__( dtype_and_x, method_name, class_name, @@ -239,22 +217,17 @@ def test_shape__sub__( @handle_method( - method_tree="Shape.__rsub__", + method_tree="Shape.__index__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, ), ) -def test_shape__rsub__( +def test_shape__index__( dtype_and_x, method_name, class_name, - backend_fw, ground_truth_backend, + backend_fw, init_flags, method_flags, on_device, @@ -269,29 +242,28 @@ def test_shape__rsub__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__mul__", + method_tree="Shape.__int__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + max_num_dims=0, + min_value=-1e15, + max_value=1e15, ), + method_container_flags=st.just([False]), ) -def test_shape__mul__( +def test_shape__int__( dtype_and_x, method_name, class_name, - backend_fw, ground_truth_backend, + backend_fw, init_flags, method_flags, on_device, @@ -300,33 +272,32 @@ def test_shape__mul__( helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_input_dtypes=[], + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__rmul__", + method_tree="Shape.__iter__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("integer"), + min_dim_size=2, + min_num_dims=1, ), ) -def test_shape__rmul__( +def test_shape__iter__( dtype_and_x, method_name, class_name, ground_truth_backend, + backend_fw, init_flags, method_flags, on_device, @@ -335,40 +306,37 @@ def test_shape__rmul__( helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - 
method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__mod__", + method_tree="Shape.__le__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__mod__( +def test_shape__le__( dtype_and_x, method_name, class_name, - backend_fw, ground_truth_backend, + backend_fw, init_flags, method_flags, on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -385,28 +353,24 @@ def test_shape__mod__( @handle_method( - method_tree="Shape.__rmod__", + method_tree="Shape.__len__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, + available_dtypes=helpers.get_dtypes("integer"), + min_dim_size=2, + min_num_dims=1, ), ) -def test_shape__rmod__( +def test_shape__len__( dtype_and_x, method_name, class_name, - ground_truth_backend, backend_fw, + ground_truth_backend, init_flags, method_flags, on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -416,24 +380,21 @@ def test_shape__rmod__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + method_all_as_kwargs_np={}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__rdivmod__", + method_tree="Shape.__lt__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__rdivmod__( +def test_shape__lt__( dtype_and_x, method_name, class_name, @@ -444,7 +405,6 @@ def test_shape__rdivmod__( on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -461,7 +421,7 @@ def test_shape__rdivmod__( @handle_method( - method_tree="Shape.__truediv__", + method_tree="Shape.__mod__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -471,17 +431,18 @@ def test_shape__rdivmod__( shared_dtype=True, ), ) -def test_shape__truediv__( +def test_shape__mod__( dtype_and_x, method_name, class_name, - ground_truth_backend, backend_fw, + ground_truth_backend, init_flags, method_flags, on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[1], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -498,7 +459,7 @@ def test_shape__truediv__( @handle_method( - method_tree="Shape.__rtruediv__", + method_tree="Shape.__mul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, @@ -508,12 +469,12 @@ def test_shape__truediv__( shared_dtype=True, ), ) -def test_shape__rtruediv__( +def test_shape__mul__( dtype_and_x, method_name, class_name, - ground_truth_backend, backend_fw, + ground_truth_backend, init_flags, method_flags, on_device, @@ -522,7 +483,6 @@ def test_shape__rtruediv__( 
helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, - backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, @@ -535,17 +495,14 @@ def test_shape__rtruediv__( @handle_method( - method_tree="Shape.__rfloordiv__", + method_tree="Shape.__ne__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, - large_abs_safety_factor=3.0, - small_abs_safety_factor=3.0, - safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__rfloordiv__( +def test_shape__ne__( dtype_and_x, method_name, class_name, @@ -556,7 +513,6 @@ def test_shape__rfloordiv__( on_device, ): dtype, x = dtype_and_x - assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -573,56 +529,75 @@ def test_shape__rfloordiv__( @handle_method( - method_tree="Shape.__int__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - max_num_dims=0, - min_value=-1e15, - max_value=1e15, - ), - method_container_flags=st.just([False]), + method_tree="Shape.__pow__", + dtype_and_x=pow_helper(), ) -def test_shape__int__( +def test_shape__pow__( dtype_and_x, method_name, class_name, - ground_truth_backend, backend_fw, + ground_truth_backend, init_flags, method_flags, on_device, ): - dtype, x = dtype_and_x + input_dtype, x = dtype_and_x + + # bfloat16 is not supported by numpy + assume(not ("bfloat16" in input_dtype)) + + # Make sure x2 isn't a float when x1 is integer + with BackendHandler.update_backend(backend_fw) as ivy_backend: + assume( + not ( + ivy_backend.is_int_dtype( + input_dtype[0] and ivy_backend.is_float_dtype(input_dtype[1]) + ) + ) + ) + + # Make sure x2 is non-negative when both is integer + if ivy_backend.is_int_dtype(input_dtype[1]) and ivy_backend.is_int_dtype( + input_dtype[0] + ): + x[1] = np.abs(x[1]) + + x[0] = not_too_close_to_zero(x[0]) + x[1] = not_too_close_to_zero(x[1]) + helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, - backend_to_test=backend_fw, init_flags=init_flags, + backend_to_test=backend_fw, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + init_input_dtypes=[input_dtype[0]], + method_input_dtypes=[input_dtype[1]], + method_all_as_kwargs_np={"power": x[1]}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__bool__", + method_tree="Shape.__radd__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.one_of(st.just(("bool",)), helpers.get_dtypes("integer")), - max_num_dims=0, - min_value=0, - max_value=1, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_shape__bool__( +def test_shape__radd__( dtype_and_x, method_name, class_name, - ground_truth_backend, backend_fw, + ground_truth_backend, init_flags, method_flags, on_device, @@ -636,22 +611,25 @@ def test_shape__bool__( method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, - method_input_dtypes=[], - method_all_as_kwargs_np={}, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__lt__", + method_tree="Shape.__rdivmod__", 
dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__lt__( +def test_shape__rdivmod__( dtype_and_x, method_name, class_name, @@ -662,6 +640,7 @@ def test_shape__lt__( on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -678,14 +657,17 @@ def test_shape__lt__( @handle_method( - method_tree="Shape.__le__", + method_tree="Shape.__rfloordiv__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=3.0, + small_abs_safety_factor=3.0, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__le__( +def test_shape__rfloordiv__( dtype_and_x, method_name, class_name, @@ -696,6 +678,7 @@ def test_shape__le__( on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -712,14 +695,17 @@ def test_shape__le__( @handle_method( - method_tree="Shape.__eq__", + method_tree="Shape.__rmod__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__eq__( +def test_shape__rmod__( dtype_and_x, method_name, class_name, @@ -730,6 +716,7 @@ def test_shape__eq__( on_device, ): dtype, x = dtype_and_x + assume(not np.any(np.isclose(x[0], 0))) helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, @@ -746,19 +733,21 @@ def test_shape__eq__( @handle_method( - method_tree="Shape.__ne__", + method_tree="Shape.__rmul__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__ne__( +def test_shape__rmul__( dtype_and_x, method_name, class_name, ground_truth_backend, - backend_fw, init_flags, method_flags, on_device, @@ -767,7 +756,6 @@ def test_shape__ne__( helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, - backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={"data": x[0]}, @@ -780,19 +768,22 @@ def test_shape__ne__( @handle_method( - method_tree="Shape.__gt__", + method_tree="Shape.__rsub__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__gt__( +def test_shape__rsub__( dtype_and_x, method_name, class_name, - ground_truth_backend, backend_fw, + ground_truth_backend, init_flags, method_flags, on_device, @@ -814,14 +805,17 @@ def test_shape__gt__( @handle_method( - method_tree="Shape.__ge__", + method_tree="Shape.__rtruediv__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", shared_dtype=True, ), ) -def test_shape__ge__( +def test_shape__rtruediv__( dtype_and_x, method_name, class_name, @@ -848,19 +842,22 @@ def test_shape__ge__( @handle_method( - 
method_tree="Shape.__len__", + method_tree="Shape.__sub__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_dim_size=2, - min_num_dims=1, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_shape__len__( +def test_shape__sub__( dtype_and_x, method_name, class_name, - backend_fw, ground_truth_backend, + backend_fw, init_flags, method_flags, on_device, @@ -875,21 +872,24 @@ def test_shape__len__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) @handle_method( - method_tree="Shape.__iter__", + method_tree="Shape.__truediv__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_dim_size=2, - min_num_dims=1, + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=2, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + shared_dtype=True, ), ) -def test_shape__iter__( +def test_shape__truediv__( dtype_and_x, method_name, class_name, @@ -909,7 +909,7 @@ def test_shape__iter__( init_all_as_kwargs_np={"data": x[0]}, init_input_dtypes=dtype, method_input_dtypes=dtype, - method_all_as_kwargs_np={}, + method_all_as_kwargs_np={"other": x[1]}, class_name=class_name, method_name=method_name, ) diff --git a/ivy_tests/test_ivy/test_misc/test_tucker_tensor.py b/ivy_tests/test_ivy/test_misc/test_tucker_tensor.py index 7702c6ddff33a..3601d740e4b51 100644 --- a/ivy_tests/test_ivy/test_misc/test_tucker_tensor.py +++ b/ivy_tests/test_ivy/test_misc/test_tucker_tensor.py @@ -4,38 +4,70 @@ import pytest -# These tests have been adapted from TensorLy -# https://github.com/tensorly/tensorly/blob/main/tensorly/tests/test_tucker_tensor.py -@pytest.mark.parametrize("true_shape, true_rank", [((3, 4, 5), (3, 2, 4))]) -def test_validate_tucker_tensor(true_shape, true_rank): - core, factors = ivy.random_tucker(true_shape, true_rank) +@pytest.mark.parametrize("shape, rank", [((5, 4, 6), (3, 2, 3))]) +def test_n_param_tucker(shape, rank): + tucker_tensor = ivy.random_tucker(shape, rank) + true_n_param = ivy.prod(ivy.shape(tucker_tensor[0])) + ivy.sum( + [ivy.prod(ivy.shape(f)) for f in tucker_tensor[1]] + ) + n_param = tucker_tensor.n_param + assert np.allclose(n_param, true_n_param) - # Check shape and rank returned - shape, rank = ivy.TuckerTensor.validate_tucker_tensor((core, factors)) - np.testing.assert_equal( - shape, - true_shape, - err_msg=f"Returned incorrect shape (got {shape}, expected {true_shape})", + +@pytest.mark.parametrize("shape, rank", [((3, 4, 5), 4)]) +def test_tucker_copy(shape, rank): + tucker_tensor = ivy.random_tucker(shape, rank) + core, factors = tucker_tensor + core_normalized, factors_normalized = ivy.TuckerTensor.tucker_normalize( + tucker_tensor.tucker_copy() ) - np.testing.assert_equal( - rank, - true_rank, - err_msg=f"Returned incorrect rank (got {rank}, expected {true_rank})", + # Check that modifying copy tensor doesn't change the original tensor + assert np.allclose( + ivy.TuckerTensor.tucker_to_tensor((core, factors)), + ivy.TuckerTensor.tucker_to_tensor(tucker_tensor), ) - # One of the factors has the wrong rank - factors[0], copy = ivy.random_uniform(shape=((4, 4))), factors[0] - with np.testing.assert_raises(ValueError): - 
ivy.TuckerTensor.validate_tucker_tensor((core, factors)) - # Not enough factors to match core - factors[0] = copy - with np.testing.assert_raises(ValueError): - ivy.TuckerTensor.validate_tucker_tensor((core, factors[1:])) +@pytest.mark.parametrize("shape, ranks", [((5, 4, 6), (3, 2, 3))]) +def test_tucker_mode_dot(shape, ranks): + tucker_ten = ivy.random_tucker(shape, ranks, full=False) + full_tensor = ivy.TuckerTensor.tucker_to_tensor(tucker_ten) + # matrix for mode 1 + matrix = ivy.random_uniform(shape=(7, shape[1])) + # vec for mode 2 + vec = ivy.random_uniform(shape=(shape[2])) - # Not enough factors - with np.testing.assert_raises(ValueError): - ivy.TuckerTensor.validate_tucker_tensor((core, factors[:1])) + # Test tucker_mode_dot with matrix + res = ivy.TuckerTensor.tucker_mode_dot(tucker_ten, matrix, mode=1, copy=True) + # Note that if copy=True is not respected, factors will be changes + # And the next test will fail + res = ivy.TuckerTensor.tucker_to_tensor(res) + true_res = ivy.mode_dot(full_tensor, matrix, mode=1) + assert np.allclose(true_res, res) + + # Check that the data was indeed copied + rec = ivy.TuckerTensor.tucker_to_tensor(tucker_ten) + assert np.allclose(full_tensor, rec) + + # Test tucker_mode_dot with vec + res = ivy.TuckerTensor.tucker_mode_dot(tucker_ten, vec, mode=2, copy=True) + res = ivy.TuckerTensor.tucker_to_tensor(res) + true_res = ivy.mode_dot(full_tensor, vec, mode=2) + assert np.allclose(res.shape, true_res.shape) + assert np.allclose(true_res, res) + + +@pytest.mark.parametrize("shape, rank", [((3, 4, 5), (3, 2, 4))]) +def test_tucker_normalize(shape, rank): + tucker_ten = ivy.random_tucker(shape, rank) + core, factors = ivy.TuckerTensor.tucker_normalize(tucker_ten) + for i in range(len(factors)): + norm = ivy.sqrt(ivy.sum(ivy.abs(factors[i]) ** 2, axis=0)) + assert np.allclose(norm, ivy.ones(rank[i])) + assert np.allclose( + ivy.TuckerTensor.tucker_to_tensor((core, factors)), + ivy.TuckerTensor.tucker_to_tensor(tucker_ten), + ) @pytest.mark.parametrize( @@ -107,45 +139,6 @@ def test_tucker_to_vec(shape, ranks): ) -@pytest.mark.parametrize("shape, ranks", [((5, 4, 6), (3, 2, 3))]) -def test_tucker_mode_dot(shape, ranks): - tucker_ten = ivy.random_tucker(shape, ranks, full=False) - full_tensor = ivy.TuckerTensor.tucker_to_tensor(tucker_ten) - # matrix for mode 1 - matrix = ivy.random_uniform(shape=(7, shape[1])) - # vec for mode 2 - vec = ivy.random_uniform(shape=(shape[2])) - - # Test tucker_mode_dot with matrix - res = ivy.TuckerTensor.tucker_mode_dot(tucker_ten, matrix, mode=1, copy=True) - # Note that if copy=True is not respected, factors will be changes - # And the next test will fail - res = ivy.TuckerTensor.tucker_to_tensor(res) - true_res = ivy.mode_dot(full_tensor, matrix, mode=1) - assert np.allclose(true_res, res) - - # Check that the data was indeed copied - rec = ivy.TuckerTensor.tucker_to_tensor(tucker_ten) - assert np.allclose(full_tensor, rec) - - # Test tucker_mode_dot with vec - res = ivy.TuckerTensor.tucker_mode_dot(tucker_ten, vec, mode=2, copy=True) - res = ivy.TuckerTensor.tucker_to_tensor(res) - true_res = ivy.mode_dot(full_tensor, vec, mode=2) - assert np.allclose(res.shape, true_res.shape) - assert np.allclose(true_res, res) - - -@pytest.mark.parametrize("shape, rank", [((5, 4, 6), (3, 2, 3))]) -def test_n_param_tucker(shape, rank): - tucker_tensor = ivy.random_tucker(shape, rank) - true_n_param = ivy.prod(ivy.shape(tucker_tensor[0])) + ivy.sum( - [ivy.prod(ivy.shape(f)) for f in tucker_tensor[1]] - ) - n_param = 
tucker_tensor.n_param - assert np.allclose(n_param, true_n_param) - - @pytest.mark.parametrize("tol", [(0.01)]) def test_validate_tucker_rank(tol): tensor_shape = tuple(ivy.randint(1, 100, shape=(5,))) @@ -196,28 +189,35 @@ def test_validate_tucker_rank(tol): assert n_param >= n_param_tensor * 0.5 * (1 - tol) -@pytest.mark.parametrize("shape, rank", [((3, 4, 5), (3, 2, 4))]) -def test_tucker_normalize(shape, rank): - tucker_ten = ivy.random_tucker(shape, rank) - core, factors = ivy.TuckerTensor.tucker_normalize(tucker_ten) - for i in range(len(factors)): - norm = ivy.sqrt(ivy.sum(ivy.abs(factors[i]) ** 2, axis=0)) - assert np.allclose(norm, ivy.ones(rank[i])) - assert np.allclose( - ivy.TuckerTensor.tucker_to_tensor((core, factors)), - ivy.TuckerTensor.tucker_to_tensor(tucker_ten), - ) - +# These tests have been adapted from TensorLy +# https://github.com/tensorly/tensorly/blob/main/tensorly/tests/test_tucker_tensor.py +@pytest.mark.parametrize("true_shape, true_rank", [((3, 4, 5), (3, 2, 4))]) +def test_validate_tucker_tensor(true_shape, true_rank): + core, factors = ivy.random_tucker(true_shape, true_rank) -@pytest.mark.parametrize("shape, rank", [((3, 4, 5), 4)]) -def test_tucker_copy(shape, rank): - tucker_tensor = ivy.random_tucker(shape, rank) - core, factors = tucker_tensor - core_normalized, factors_normalized = ivy.TuckerTensor.tucker_normalize( - tucker_tensor.tucker_copy() + # Check shape and rank returned + shape, rank = ivy.TuckerTensor.validate_tucker_tensor((core, factors)) + np.testing.assert_equal( + shape, + true_shape, + err_msg=f"Returned incorrect shape (got {shape}, expected {true_shape})", ) - # Check that modifying copy tensor doesn't change the original tensor - assert np.allclose( - ivy.TuckerTensor.tucker_to_tensor((core, factors)), - ivy.TuckerTensor.tucker_to_tensor(tucker_tensor), + np.testing.assert_equal( + rank, + true_rank, + err_msg=f"Returned incorrect rank (got {rank}, expected {true_rank})", ) + + # One of the factors has the wrong rank + factors[0], copy = ivy.random_uniform(shape=((4, 4))), factors[0] + with np.testing.assert_raises(ValueError): + ivy.TuckerTensor.validate_tucker_tensor((core, factors)) + + # Not enough factors to match core + factors[0] = copy + with np.testing.assert_raises(ValueError): + ivy.TuckerTensor.validate_tucker_tensor((core, factors[1:])) + + # Not enough factors + with np.testing.assert_raises(ValueError): + ivy.TuckerTensor.validate_tucker_tensor((core, factors[:1])) diff --git a/ivy_tests/test_ivy/test_misc/test_with_backend.py b/ivy_tests/test_ivy/test_misc/test_with_backend.py index 4864839046d67..9525194141cd5 100644 --- a/ivy_tests/test_ivy/test_misc/test_with_backend.py +++ b/ivy_tests/test_ivy/test_misc/test_with_backend.py @@ -18,6 +18,11 @@ def compiled_backends(): return compiled_backends +def test_is_local(backend_fw): + local_ivy = ivy.with_backend(backend_fw) + assert local_ivy.is_local() + + @settings( # To be able to share compiled_backends between examples suppress_health_check=[HealthCheck(9)] @@ -45,20 +50,15 @@ def test_prevent_access(backend_fw): local_ivy.set_backend(backend_fw) -def test_with_backend_cached(backend_fw): - non_cached_local_ivy = ivy.with_backend(backend_fw) - cached_local_ivy = ivy.with_backend(backend_fw) - assert non_cached_local_ivy == cached_local_ivy - - -def test_is_local(backend_fw): - local_ivy = ivy.with_backend(backend_fw) - assert local_ivy.is_local() - - def test_with_backend_array(backend_fw): local_ivy = ivy.with_backend(backend_fw) local_x = local_ivy.array([1, 2, 3, 
4]) ivy.set_backend(backend_fw) x = ivy.array([1, 2, 3, 4]) assert np.allclose(x._data, local_x._data) + + +def test_with_backend_cached(backend_fw): + non_cached_local_ivy = ivy.with_backend(backend_fw) + cached_local_ivy = ivy.with_backend(backend_fw) + assert non_cached_local_ivy == cached_local_ivy diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py index 65a5f2598e972..f9bf437ae759b 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -8,27 +8,29 @@ from ivy_tests.test_ivy.helpers import handle_method -# GELU +# ELU @handle_method( - method_tree="stateful.activations.GELU.__call__", + method_tree="stateful.activations.ELU.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=1, - small_abs_safety_factor=1, - safety_factor_scale="linear", + available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_num_dims=2, + large_abs_safety_factor=8, + small_abs_safety_factor=8, + safety_factor_scale="log", ), - approximate=st.booleans(), - method_num_positional_args=helpers.num_positional_args(fn_name="GELU._forward"), + method_num_positional_args=helpers.num_positional_args(fn_name="ELU._forward"), test_gradients=st.just(True), + alpha=helpers.floats(min_value=0.1, max_value=1), ) -def test_gelu( +def test_elu( *, dtype_and_x, - approximate, + alpha, test_gradients, - method_name, class_name, - backend_fw, + method_name, ground_truth_backend, init_flags, method_flags, @@ -36,18 +38,17 @@ def test_gelu( ): input_dtype, x = dtype_and_x helpers.test_method( - backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_input_dtypes=input_dtype, - init_all_as_kwargs_np={"approximate": approximate}, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"x": x[0]}, + init_all_as_kwargs_np={}, + method_all_as_kwargs_np={"x": x[0], "alpha": alpha}, class_name=class_name, method_name=method_name, - atol_=1e-2, rtol_=1e-2, + atol_=1e-2, test_gradients=test_gradients, on_device=on_device, ) @@ -99,23 +100,26 @@ def test_geglu( ) +# GELU @handle_method( - method_tree="stateful.activations.ReLU.__call__", + method_tree="stateful.activations.GELU.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=8, - small_abs_safety_factor=8, - safety_factor_scale="log", + available_dtypes=helpers.get_dtypes("numeric"), + large_abs_safety_factor=1, + small_abs_safety_factor=1, + safety_factor_scale="linear", ), - method_num_positional_args=helpers.num_positional_args(fn_name="ReLU._forward"), + approximate=st.booleans(), + method_num_positional_args=helpers.num_positional_args(fn_name="GELU._forward"), test_gradients=st.just(True), ) -def test_relu( +def test_gelu( *, dtype_and_x, + approximate, test_gradients, - class_name, method_name, + class_name, backend_fw, ground_truth_backend, init_flags, @@ -129,42 +133,39 @@ def test_relu( init_flags=init_flags, method_flags=method_flags, init_input_dtypes=input_dtype, + init_all_as_kwargs_np={"approximate": approximate}, method_input_dtypes=input_dtype, - init_all_as_kwargs_np={}, method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, - rtol_=1e-2, atol_=1e-2, + rtol_=1e-2, test_gradients=test_gradients, on_device=on_device, ) +# Hardswish @handle_method( - 
method_tree="stateful.activations.LeakyReLU.__call__", + method_tree="stateful.activations.Hardswish.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes( - "float_and_complex", full=False, key="leaky_relu" - ), - large_abs_safety_factor=16, - small_abs_safety_factor=16, + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=8, + small_abs_safety_factor=8, safety_factor_scale="log", + min_num_dims=2, ), - alpha=st.floats(min_value=-1e-4, max_value=1e-4), method_num_positional_args=helpers.num_positional_args( - fn_name="LeakyReLU._forward" + fn_name="Hardswish._forward" ), test_gradients=st.just(True), ) -def test_leaky_relu( +def test_hardswish( *, dtype_and_x, - alpha, test_gradients, class_name, method_name, - backend_fw, ground_truth_backend, init_flags, method_flags, @@ -172,13 +173,12 @@ def test_leaky_relu( ): input_dtype, x = dtype_and_x helpers.test_method( - backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, - init_all_as_kwargs_np={"alpha": alpha}, + init_all_as_kwargs_np={}, method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, @@ -190,22 +190,25 @@ def test_leaky_relu( @handle_method( - method_tree="stateful.activations.Softmax.__call__", + method_tree="stateful.activations.LeakyReLU.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, - large_abs_safety_factor=8, - small_abs_safety_factor=8, + available_dtypes=helpers.get_dtypes( + "float_and_complex", full=False, key="leaky_relu" + ), + large_abs_safety_factor=16, + small_abs_safety_factor=16, safety_factor_scale="log", ), - axis=helpers.ints(min_value=-1, max_value=0), - method_num_positional_args=helpers.num_positional_args(fn_name="Softmax._forward"), + alpha=st.floats(min_value=-1e-4, max_value=1e-4), + method_num_positional_args=helpers.num_positional_args( + fn_name="LeakyReLU._forward" + ), test_gradients=st.just(True), ) -def test_softmax( +def test_leaky_relu( *, dtype_and_x, - axis, + alpha, test_gradients, class_name, method_name, @@ -223,8 +226,8 @@ def test_softmax( method_flags=method_flags, init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, - init_all_as_kwargs_np={}, - method_all_as_kwargs_np={"x": x[0], "axis": axis}, + init_all_as_kwargs_np={"alpha": alpha}, + method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, rtol_=1e-2, @@ -281,29 +284,27 @@ def test_log_softmax( ) +# Logit @handle_method( - method_tree="stateful.activations.Softplus.__call__", + method_tree="stateful.activations.Logit.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - min_num_dims=1, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", + min_num_dims=2, ), - beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()), - threshold=st.one_of(helpers.number(min_value=0.1, max_value=30), st.none()), - method_num_positional_args=helpers.num_positional_args(fn_name="Softplus._forward"), + method_num_positional_args=helpers.num_positional_args(fn_name="Logit._forward"), + eps=helpers.floats(min_value=1e-4, max_value=1e-2), test_gradients=st.just(True), ) -def test_softplus( +def test_logit( *, dtype_and_x, - beta, - threshold, + eps, test_gradients, class_name, method_name, - backend_fw, ground_truth_backend, init_flags, method_flags, @@ 
-311,14 +312,13 @@ def test_softplus( ): input_dtype, x = dtype_and_x helpers.test_method( - backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, init_all_as_kwargs_np={}, - method_all_as_kwargs_np={"x": x[0], "beta": beta, "threshold": threshold}, + method_all_as_kwargs_np={"x": x[0], "eps": eps}, class_name=class_name, method_name=method_name, rtol_=1e-2, @@ -328,24 +328,27 @@ def test_softplus( ) +# Logsigmoid @handle_method( - method_tree="stateful.activations.Mish.__call__", + method_tree="stateful.activations.LogSigmoid.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", + min_num_dims=2, + ), + method_num_positional_args=helpers.num_positional_args( + fn_name="LogSigmoid._forward" ), - method_num_positional_args=helpers.num_positional_args(fn_name="Mish._forward"), test_gradients=st.just(True), ) -def test_mish( +def test_logsigmoid( *, dtype_and_x, test_gradients, class_name, method_name, - backend_fw, ground_truth_backend, init_flags, method_flags, @@ -353,7 +356,6 @@ def test_mish( ): input_dtype, x = dtype_and_x helpers.test_method( - backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, @@ -371,17 +373,17 @@ def test_mish( @handle_method( - method_tree="stateful.activations.SiLU.__call__", + method_tree="stateful.activations.Mish.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), - method_num_positional_args=helpers.num_positional_args(fn_name="SiLU._forward"), + method_num_positional_args=helpers.num_positional_args(fn_name="Mish._forward"), test_gradients=st.just(True), ) -def test_silu( +def test_mish( *, dtype_and_x, test_gradients, @@ -412,20 +414,22 @@ def test_silu( ) -# Sigmoid +# PReLU @handle_method( - method_tree="stateful.activations.Sigmoid.__call__", + method_tree="stateful.activations.PReLU.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + num_arrays=2, + shared_dtype=True, + min_num_dims=2, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", - min_num_dims=2, ), - method_num_positional_args=helpers.num_positional_args(fn_name="Sigmoid._forward"), + method_num_positional_args=helpers.num_positional_args(fn_name="PReLU._forward"), test_gradients=st.just(True), ) -def test_sigmoid( +def test_prelu( *, dtype_and_x, test_gradients, @@ -444,7 +448,7 @@ def test_sigmoid( init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, init_all_as_kwargs_np={}, - method_all_as_kwargs_np={"x": x[0]}, + method_all_as_kwargs_np={"x": x[0], "slope": x[1]}, class_name=class_name, method_name=method_name, rtol_=1e-2, @@ -454,25 +458,24 @@ def test_sigmoid( ) -# Tanh @handle_method( - method_tree="stateful.activations.Tanh.__call__", + method_tree="stateful.activations.ReLU.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", - min_num_dims=2, ), - method_num_positional_args=helpers.num_positional_args(fn_name="Tanh._forward"), + method_num_positional_args=helpers.num_positional_args(fn_name="ReLU._forward"), test_gradients=st.just(True), ) 
-def test_tanh( +def test_relu( *, dtype_and_x, test_gradients, class_name, method_name, + backend_fw, ground_truth_backend, init_flags, method_flags, @@ -480,6 +483,7 @@ def test_tanh( ): input_dtype, x = dtype_and_x helpers.test_method( + backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, @@ -538,9 +542,9 @@ def test_relu6( ) -# Hardswish +# SeLU @handle_method( - method_tree="stateful.activations.Hardswish.__call__", + method_tree="stateful.activations.SeLU.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, @@ -548,12 +552,10 @@ def test_relu6( safety_factor_scale="log", min_num_dims=2, ), - method_num_positional_args=helpers.num_positional_args( - fn_name="Hardswish._forward" - ), + method_num_positional_args=helpers.num_positional_args(fn_name="SeLU._forward"), test_gradients=st.just(True), ) -def test_hardswish( +def test_selu( *, dtype_and_x, test_gradients, @@ -582,9 +584,9 @@ def test_hardswish( ) -# Logit +# Sigmoid @handle_method( - method_tree="stateful.activations.Logit.__call__", + method_tree="stateful.activations.Sigmoid.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, @@ -592,14 +594,12 @@ def test_hardswish( safety_factor_scale="log", min_num_dims=2, ), - method_num_positional_args=helpers.num_positional_args(fn_name="Logit._forward"), - eps=helpers.floats(min_value=1e-4, max_value=1e-2), + method_num_positional_args=helpers.num_positional_args(fn_name="Sigmoid._forward"), test_gradients=st.just(True), ) -def test_logit( +def test_sigmoid( *, dtype_and_x, - eps, test_gradients, class_name, method_name, @@ -616,7 +616,7 @@ def test_logit( init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, init_all_as_kwargs_np={}, - method_all_as_kwargs_np={"x": x[0], "eps": eps}, + method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, rtol_=1e-2, @@ -626,27 +626,24 @@ def test_logit( ) -# PReLU @handle_method( - method_tree="stateful.activations.PReLU.__call__", + method_tree="stateful.activations.SiLU.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_num_dims=2, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), - method_num_positional_args=helpers.num_positional_args(fn_name="PReLU._forward"), + method_num_positional_args=helpers.num_positional_args(fn_name="SiLU._forward"), test_gradients=st.just(True), ) -def test_prelu( +def test_silu( *, dtype_and_x, test_gradients, class_name, method_name, + backend_fw, ground_truth_backend, init_flags, method_flags, @@ -654,13 +651,14 @@ def test_prelu( ): input_dtype, x = dtype_and_x helpers.test_method( + backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, init_all_as_kwargs_np={}, - method_all_as_kwargs_np={"x": x[0], "slope": x[1]}, + method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, rtol_=1e-2, @@ -670,25 +668,27 @@ def test_prelu( ) -# SeLU @handle_method( - method_tree="stateful.activations.SeLU.__call__", + method_tree="stateful.activations.Softmax.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), + min_num_dims=1, large_abs_safety_factor=8, 
small_abs_safety_factor=8, safety_factor_scale="log", - min_num_dims=2, ), - method_num_positional_args=helpers.num_positional_args(fn_name="SeLU._forward"), + axis=helpers.ints(min_value=-1, max_value=0), + method_num_positional_args=helpers.num_positional_args(fn_name="Softmax._forward"), test_gradients=st.just(True), ) -def test_selu( +def test_softmax( *, dtype_and_x, + axis, test_gradients, class_name, method_name, + backend_fw, ground_truth_backend, init_flags, method_flags, @@ -696,13 +696,14 @@ def test_selu( ): input_dtype, x = dtype_and_x helpers.test_method( + backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, init_all_as_kwargs_np={}, - method_all_as_kwargs_np={"x": x[0]}, + method_all_as_kwargs_np={"x": x[0], "axis": axis}, class_name=class_name, method_name=method_name, rtol_=1e-2, @@ -712,29 +713,29 @@ def test_selu( ) -# ELU @handle_method( - method_tree="stateful.activations.ELU.__call__", + method_tree="stateful.activations.Softplus.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), - num_arrays=2, - shared_dtype=True, - min_num_dims=2, + min_num_dims=1, large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", ), - method_num_positional_args=helpers.num_positional_args(fn_name="ELU._forward"), + beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()), + threshold=st.one_of(helpers.number(min_value=0.1, max_value=30), st.none()), + method_num_positional_args=helpers.num_positional_args(fn_name="Softplus._forward"), test_gradients=st.just(True), - alpha=helpers.floats(min_value=0.1, max_value=1), ) -def test_elu( +def test_softplus( *, dtype_and_x, - alpha, + beta, + threshold, test_gradients, class_name, method_name, + backend_fw, ground_truth_backend, init_flags, method_flags, @@ -742,13 +743,14 @@ def test_elu( ): input_dtype, x = dtype_and_x helpers.test_method( + backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_input_dtypes=input_dtype, method_input_dtypes=input_dtype, init_all_as_kwargs_np={}, - method_all_as_kwargs_np={"x": x[0], "alpha": alpha}, + method_all_as_kwargs_np={"x": x[0], "beta": beta, "threshold": threshold}, class_name=class_name, method_name=method_name, rtol_=1e-2, @@ -758,9 +760,9 @@ def test_elu( ) -# Logsigmoid +# Tanh @handle_method( - method_tree="stateful.activations.LogSigmoid.__call__", + method_tree="stateful.activations.Tanh.__call__", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), large_abs_safety_factor=8, @@ -768,12 +770,10 @@ def test_elu( safety_factor_scale="log", min_num_dims=2, ), - method_num_positional_args=helpers.num_positional_args( - fn_name="LogSigmoid._forward" - ), + method_num_positional_args=helpers.num_positional_args(fn_name="Tanh._forward"), test_gradients=st.just(True), ) -def test_logsigmoid( +def test_tanh( *, dtype_and_x, test_gradients, diff --git a/ivy_tests/test_ivy/test_stateful/test_converters.py b/ivy_tests/test_ivy/test_stateful/test_converters.py index 60edee8cdf952..c8c8c741180b0 100644 --- a/ivy_tests/test_ivy/test_stateful/test_converters.py +++ b/ivy_tests/test_ivy/test_stateful/test_converters.py @@ -5,6 +5,10 @@ from types import SimpleNamespace from typing import Sequence + +# local +import ivy + try: import torch import torch.nn as nn @@ -81,9 +85,15 @@ paddle.optimizer.SGD = 
SimpleNamespace paddle.nn.L1Loss = SimpleNamespace - -# local -import ivy +FROM_CONVERTERS = { + "torch": ivy.Module.from_torch_module, + "jax": { + "haiku": ivy.Module.from_haiku_module, + "flax": ivy.Module.from_flax_module, + }, + "tensorflow": ivy.Module.from_keras_module, + "paddle": ivy.Module.from_paddle_module, +} class TensorflowLinear(tf.keras.Model): @@ -209,27 +219,6 @@ def forward(self, x): return paddle.nn.functional.tanh(self._linear2(x))[0] -NATIVE_MODULES = { - "torch": TorchModule, - "jax": { - "haiku": HaikuModule, - "flax": FlaxModule, - }, - "tensorflow": TensorflowModule, - "paddle": PaddleModule, -} - -FROM_CONVERTERS = { - "torch": ivy.Module.from_torch_module, - "jax": { - "haiku": ivy.Module.from_haiku_module, - "flax": ivy.Module.from_flax_module, - }, - "tensorflow": ivy.Module.from_keras_module, - "paddle": ivy.Module.from_paddle_module, -} - - @pytest.mark.parametrize("bs_ic_oc", [([1, 2], 4, 5)]) @pytest.mark.parametrize("from_class_and_args", [True, False]) def test_from_backend_module(bs_ic_oc, from_class_and_args): @@ -358,3 +347,14 @@ def loss_fn(v_=None): assert loss.shape == () # value test assert (abs(grads).max() > 0).cont_all_true() + + +NATIVE_MODULES = { + "torch": TorchModule, + "jax": { + "haiku": HaikuModule, + "flax": FlaxModule, + }, + "tensorflow": TensorflowModule, + "paddle": PaddleModule, +} diff --git a/ivy_tests/test_ivy/test_stateful/test_initializers.py b/ivy_tests/test_ivy/test_stateful/test_initializers.py index 526063e9e6e1d..f2b0e3afa3f34 100644 --- a/ivy_tests/test_ivy/test_stateful/test_initializers.py +++ b/ivy_tests/test_ivy/test_stateful/test_initializers.py @@ -63,9 +63,14 @@ def test_constant( @handle_method( - method_tree="Zeros.create_variables", - ground_truth_backend="numpy", + method_tree="FirstLayerSiren.create_variables", + ground_truth_backend="jax", var_shape=helpers.get_shape(), + fan_in=helpers.ints( + min_value=1, + safety_factor=4, + safety_factor_scale="log", + ), init_with_v=st.booleans(), method_with_v=st.booleans(), init_as_variable_flags=st.just([False]), @@ -76,8 +81,9 @@ def test_constant( method_native_arrays=st.just([False]), method_container_flags=st.just([False]), ) -def test_zeros( +def test_first_layer_siren( var_shape, + fan_in, init_with_v, method_with_v, class_name, @@ -99,6 +105,7 @@ def test_zeros( method_all_as_kwargs_np={ "var_shape": var_shape, "device": "cpu", + "fan_in": fan_in, }, class_name=class_name, method_name=method_name, @@ -108,15 +115,18 @@ def test_zeros( on_device=on_device, ) + bound = fan_in assert ret_ivy.shape == ret_gt.shape assert ret_ivy.dtype == ret_gt.dtype - assert ivy.all(ivy.equal(ret_ivy, ivy.array(0.0))) + assert ivy.all(ivy.less(ivy.abs(ret_ivy), ivy.array(bound))) @handle_method( - method_tree="Ones.create_variables", + method_tree="GlorotUniform.create_variables", ground_truth_backend="numpy", var_shape=helpers.get_shape(), + fan_in=helpers.ints(min_value=1, max_value=100), + fan_out=helpers.ints(min_value=1, max_value=100), init_with_v=st.booleans(), method_with_v=st.booleans(), init_as_variable_flags=st.just([False]), @@ -127,8 +137,10 @@ def test_zeros( method_native_arrays=st.just([False]), method_container_flags=st.just([False]), ) -def test_ones( +def test_glorot_uniform( var_shape, + fan_in, + fan_out, init_with_v, method_with_v, class_name, @@ -150,6 +162,8 @@ def test_ones( method_all_as_kwargs_np={ "var_shape": var_shape, "device": "cpu", + "fan_in": fan_in, + "fan_out": fan_out, }, class_name=class_name, method_name=method_name, @@ -159,23 +173,33 @@ 
def test_ones( on_device=on_device, ) + bound = (6 / (fan_in + fan_out)) ** 0.5 assert ret_ivy.shape == ret_gt.shape assert ret_ivy.dtype == ret_gt.dtype - assert ivy.all(ivy.equal(ret_ivy, ivy.array(1.0))) + assert ivy.all(ivy.less(ivy.abs(ret_ivy), ivy.array(bound))) @handle_method( - method_tree="Uniform.create_variables", - ground_truth_backend="numpy", - numerator=helpers.floats(min_value=1.0, max_value=10.0), + method_tree="KaimingNormal.create_variables", + mean=helpers.floats( + min_value=-1e5, + max_value=1e5, + small_abs_safety_factor=8, + safety_factor_scale="log", + ), fan_mode=st.sampled_from(["fan_in", "fan_out", "fan_sum", "fan_avg"]), - power=helpers.floats(min_value=1.0, max_value=3.0), - gain=helpers.floats(min_value=1.0, max_value=10.0), var_shape=helpers.get_shape(), - fan_in=helpers.ints(min_value=1, max_value=100), - fan_out=helpers.ints(min_value=1, max_value=100), + fan_in=helpers.ints(min_value=1, safety_factor=8, safety_factor_scale="log"), + fan_out=helpers.ints(min_value=1, safety_factor=8, safety_factor_scale="log"), + negative_slope=helpers.floats( + min_value=1e-5, + max_value=5.0, + ), + # should be replaced with helpers.get_dtypes() but somehow it causes inconsistent data generation # noqa + dtype=st.sampled_from([None, "float64", "float32", "float16"]), init_with_v=st.booleans(), method_with_v=st.booleans(), + ground_truth_backend="numpy", init_as_variable_flags=st.just([False]), init_num_positional_args=st.just(0), init_native_arrays=st.just([False]), @@ -184,14 +208,14 @@ def test_ones( method_native_arrays=st.just([False]), method_container_flags=st.just([False]), ) -def test_uniform( - numerator, +def test_kaiming_normal( + mean, fan_mode, - power, - gain, var_shape, fan_in, fan_out, + negative_slope, + dtype, init_with_v, method_with_v, class_name, @@ -209,10 +233,8 @@ def test_uniform( method_flags=method_flags, init_input_dtypes=[], init_all_as_kwargs_np={ - "numerator": numerator, + "mean": mean, "fan_mode": fan_mode, - "power": power, - "gain": gain, }, method_input_dtypes=[], method_all_as_kwargs_np={ @@ -220,6 +242,8 @@ def test_uniform( "device": "cpu", "fan_in": fan_in, "fan_out": fan_out, + "negative_slope": negative_slope, + "dtype": dtype, }, class_name=class_name, method_name=method_name, @@ -228,27 +252,14 @@ def test_uniform( test_values=False, on_device=on_device, ) - if fan_mode == "fan_in": - fan = fan_in - elif fan_mode == "fan_out": - fan = fan_out - elif fan_mode == "fan_sum": - fan = fan_in + fan_out - elif fan_mode == "fan_avg": - fan = (fan_in + fan_out) / 2 - - bound = gain * (numerator / fan) ** power assert ret_ivy.shape == ret_gt.shape assert ret_ivy.dtype == ret_gt.dtype - assert ivy.all(ivy.less(ivy.abs(ret_ivy), ivy.array(bound))) @handle_method( - method_tree="GlorotUniform.create_variables", + method_tree="Ones.create_variables", ground_truth_backend="numpy", var_shape=helpers.get_shape(), - fan_in=helpers.ints(min_value=1, max_value=100), - fan_out=helpers.ints(min_value=1, max_value=100), init_with_v=st.booleans(), method_with_v=st.booleans(), init_as_variable_flags=st.just([False]), @@ -259,10 +270,8 @@ def test_uniform( method_native_arrays=st.just([False]), method_container_flags=st.just([False]), ) -def test_glorot_uniform( +def test_ones( var_shape, - fan_in, - fan_out, init_with_v, method_with_v, class_name, @@ -284,8 +293,6 @@ def test_glorot_uniform( method_all_as_kwargs_np={ "var_shape": var_shape, "device": "cpu", - "fan_in": fan_in, - "fan_out": fan_out, }, class_name=class_name, method_name=method_name, @@ 
-295,23 +302,31 @@ def test_glorot_uniform( on_device=on_device, ) - bound = (6 / (fan_in + fan_out)) ** 0.5 assert ret_ivy.shape == ret_gt.shape assert ret_ivy.dtype == ret_gt.dtype - assert ivy.all(ivy.less(ivy.abs(ret_ivy), ivy.array(bound))) + assert ivy.all(ivy.equal(ret_ivy, ivy.array(1.0))) @handle_method( - method_tree="FirstLayerSiren.create_variables", - ground_truth_backend="jax", - var_shape=helpers.get_shape(), - fan_in=helpers.ints( - min_value=1, - safety_factor=4, + method_tree="RandomNormal.create_variables", + mean=helpers.floats( + min_value=-1e5, + max_value=1e5, + small_abs_safety_factor=8, safety_factor_scale="log", ), + stddev=helpers.floats( + min_value=0, + max_value=1e5, + small_abs_safety_factor=8, + safety_factor_scale="log", + ), + shape=helpers.get_shape(), + # should be replaced with helpers.get_dtypes() but somehow it causes inconsistent data generation # noqa + dtype=st.sampled_from([None, "float64", "float32", "float16"]), init_with_v=st.booleans(), method_with_v=st.booleans(), + ground_truth_backend="numpy", init_as_variable_flags=st.just([False]), init_num_positional_args=st.just(0), init_native_arrays=st.just([False]), @@ -320,18 +335,20 @@ def test_glorot_uniform( method_native_arrays=st.just([False]), method_container_flags=st.just([False]), ) -def test_first_layer_siren( - var_shape, - fan_in, +def test_random_normal( + mean, + stddev, + shape, + dtype, init_with_v, method_with_v, class_name, method_name, - backend_fw, ground_truth_backend, init_flags, method_flags, on_device, + backend_fw, ): ret_ivy, ret_gt = helpers.test_method( backend_to_test=backend_fw, @@ -339,12 +356,15 @@ def test_first_layer_siren( init_flags=init_flags, method_flags=method_flags, init_input_dtypes=[], - init_all_as_kwargs_np={}, + init_all_as_kwargs_np={ + "mean": mean, + "stddev": stddev, + }, method_input_dtypes=[], method_all_as_kwargs_np={ - "var_shape": var_shape, + "var_shape": shape, "device": "cpu", - "fan_in": fan_in, + "dtype": dtype, }, class_name=class_name, method_name=method_name, @@ -353,11 +373,8 @@ def test_first_layer_siren( test_values=False, on_device=on_device, ) - - bound = fan_in assert ret_ivy.shape == ret_gt.shape assert ret_ivy.dtype == ret_gt.dtype - assert ivy.all(ivy.less(ivy.abs(ret_ivy), ivy.array(bound))) @handle_method( @@ -418,26 +435,17 @@ def test_siren( @handle_method( - method_tree="KaimingNormal.create_variables", - mean=helpers.floats( - min_value=-1e5, - max_value=1e5, - small_abs_safety_factor=8, - safety_factor_scale="log", - ), + method_tree="Uniform.create_variables", + ground_truth_backend="numpy", + numerator=helpers.floats(min_value=1.0, max_value=10.0), fan_mode=st.sampled_from(["fan_in", "fan_out", "fan_sum", "fan_avg"]), + power=helpers.floats(min_value=1.0, max_value=3.0), + gain=helpers.floats(min_value=1.0, max_value=10.0), var_shape=helpers.get_shape(), - fan_in=helpers.ints(min_value=1, safety_factor=8, safety_factor_scale="log"), - fan_out=helpers.ints(min_value=1, safety_factor=8, safety_factor_scale="log"), - negative_slope=helpers.floats( - min_value=1e-5, - max_value=5.0, - ), - # should be replaced with helpers.get_dtypes() but somehow it causes inconsistent data generation # noqa - dtype=st.sampled_from([None, "float64", "float32", "float16"]), + fan_in=helpers.ints(min_value=1, max_value=100), + fan_out=helpers.ints(min_value=1, max_value=100), init_with_v=st.booleans(), method_with_v=st.booleans(), - ground_truth_backend="numpy", init_as_variable_flags=st.just([False]), init_num_positional_args=st.just(0), 
init_native_arrays=st.just([False]), @@ -446,14 +454,14 @@ def test_siren( method_native_arrays=st.just([False]), method_container_flags=st.just([False]), ) -def test_kaiming_normal( - mean, +def test_uniform( + numerator, fan_mode, + power, + gain, var_shape, fan_in, fan_out, - negative_slope, - dtype, init_with_v, method_with_v, class_name, @@ -471,8 +479,10 @@ def test_kaiming_normal( method_flags=method_flags, init_input_dtypes=[], init_all_as_kwargs_np={ - "mean": mean, + "numerator": numerator, "fan_mode": fan_mode, + "power": power, + "gain": gain, }, method_input_dtypes=[], method_all_as_kwargs_np={ @@ -480,8 +490,6 @@ def test_kaiming_normal( "device": "cpu", "fan_in": fan_in, "fan_out": fan_out, - "negative_slope": negative_slope, - "dtype": dtype, }, class_name=class_name, method_name=method_name, @@ -490,30 +498,27 @@ def test_kaiming_normal( test_values=False, on_device=on_device, ) + if fan_mode == "fan_in": + fan = fan_in + elif fan_mode == "fan_out": + fan = fan_out + elif fan_mode == "fan_sum": + fan = fan_in + fan_out + elif fan_mode == "fan_avg": + fan = (fan_in + fan_out) / 2 + + bound = gain * (numerator / fan) ** power assert ret_ivy.shape == ret_gt.shape assert ret_ivy.dtype == ret_gt.dtype + assert ivy.all(ivy.less(ivy.abs(ret_ivy), ivy.array(bound))) @handle_method( - method_tree="RandomNormal.create_variables", - mean=helpers.floats( - min_value=-1e5, - max_value=1e5, - small_abs_safety_factor=8, - safety_factor_scale="log", - ), - stddev=helpers.floats( - min_value=0, - max_value=1e5, - small_abs_safety_factor=8, - safety_factor_scale="log", - ), - shape=helpers.get_shape(), - # should be replaced with helpers.get_dtypes() but somehow it causes inconsistent data generation # noqa - dtype=st.sampled_from([None, "float64", "float32", "float16"]), + method_tree="Zeros.create_variables", + ground_truth_backend="numpy", + var_shape=helpers.get_shape(), init_with_v=st.booleans(), method_with_v=st.booleans(), - ground_truth_backend="numpy", init_as_variable_flags=st.just([False]), init_num_positional_args=st.just(0), init_native_arrays=st.just([False]), @@ -522,20 +527,17 @@ def test_kaiming_normal( method_native_arrays=st.just([False]), method_container_flags=st.just([False]), ) -def test_random_normal( - mean, - stddev, - shape, - dtype, +def test_zeros( + var_shape, init_with_v, method_with_v, class_name, method_name, + backend_fw, ground_truth_backend, init_flags, method_flags, on_device, - backend_fw, ): ret_ivy, ret_gt = helpers.test_method( backend_to_test=backend_fw, @@ -543,15 +545,11 @@ def test_random_normal( init_flags=init_flags, method_flags=method_flags, init_input_dtypes=[], - init_all_as_kwargs_np={ - "mean": mean, - "stddev": stddev, - }, + init_all_as_kwargs_np={}, method_input_dtypes=[], method_all_as_kwargs_np={ - "var_shape": shape, + "var_shape": var_shape, "device": "cpu", - "dtype": dtype, }, class_name=class_name, method_name=method_name, @@ -560,5 +558,7 @@ def test_random_normal( test_values=False, on_device=on_device, ) + assert ret_ivy.shape == ret_gt.shape assert ret_ivy.dtype == ret_gt.dtype + assert ivy.all(ivy.equal(ret_ivy, ivy.array(0.0))) diff --git a/ivy_tests/test_ivy/test_stateful/test_layers.py b/ivy_tests/test_ivy/test_stateful/test_layers.py index 91c6334671b57..6326496eda3dd 100644 --- a/ivy_tests/test_ivy/test_stateful/test_layers.py +++ b/ivy_tests/test_ivy/test_stateful/test_layers.py @@ -20,20 +20,13 @@ valid_dct, ) -# Helpers # -# --------# - all_constant_initializers = (ivy.Zeros, ivy.Ones) -all_uniform_initializers = 
(ivy.GlorotUniform, ivy.FirstLayerSiren, ivy.Siren) all_gaussian_initializers = (ivy.KaimingNormal, ivy.Siren) -all_initializers = ( - all_constant_initializers + all_uniform_initializers + all_gaussian_initializers -) +all_uniform_initializers = (ivy.GlorotUniform, ivy.FirstLayerSiren, ivy.Siren) -@st.composite -def _sample_initializer(draw): - return draw(st.sampled_from(all_initializers))() +# --- Helpers --- # +# --------------- # # Linear # @@ -48,6 +41,33 @@ def _bias_flag_and_initializer(draw): return with_bias, None +# Embedding +@st.composite +def _get_embedding_args(draw): + num_embeddings = draw(st.integers(min_value=1, max_value=10)) + embedding_dim = draw(st.integers(min_value=1, max_value=10)) + dtype_indices, indices = draw( + helpers.dtype_and_values( + available_dtypes=["int32", "int64"], + min_num_dims=2, + min_dim_size=1, + min_value=0, + max_value=num_embeddings - 1, + ).filter(lambda x: x[1][0].shape[-1] == embedding_dim) + ) + padding_idx = draw(st.integers(min_value=0, max_value=num_embeddings - 1)) + max_norm = draw(st.one_of(st.none(), st.floats(min_value=1, max_value=5))) + + return ( + num_embeddings, + embedding_dim, + dtype_indices, + indices, + padding_idx, + max_norm, + ) + + @st.composite def _input_channels_and_dtype_and_values(draw): input_channels = draw(st.integers(min_value=1, max_value=2)) @@ -65,119 +85,23 @@ def _input_channels_and_dtype_and_values(draw): return input_channels, dtype, vals -# linear -@handle_method( - method_tree="Linear.__call__", - ic_n_dtype_n_vals=_input_channels_and_dtype_and_values(), - output_channels=st.shared( - st.integers(min_value=1, max_value=2), key="output_channels" - ), - weight_initializer=_sample_initializer(), - wb_n_b_init=_bias_flag_and_initializer(), - init_with_v=st.booleans(), - method_with_v=st.booleans(), - seed=helpers.seed(), -) -def test_linear_layer( - *, - ic_n_dtype_n_vals, - output_channels, - weight_initializer, - wb_n_b_init, - init_with_v, - method_with_v, - seed, - on_device, - class_name, - method_name, - backend_fw, - ground_truth_backend, - init_flags, - method_flags, -): - ivy.seed(seed_value=seed) - input_channels, input_dtype, x = ic_n_dtype_n_vals - with_bias, bias_initializer = wb_n_b_init - helpers.test_method( - backend_to_test=backend_fw, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "input_channels": input_channels, - "output_channels": output_channels, - "weight_initializer": weight_initializer, - "bias_initializer": bias_initializer, - "with_bias": with_bias, - "device": on_device, - "dtype": input_dtype[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"x": x[0]}, - class_name=class_name, - method_name=method_name, - init_with_v=init_with_v, - method_with_v=method_with_v, - rtol_=1e-02, - atol_=1e-02, - on_device=on_device, +# LSTM +@st.composite +def _input_channels_and_dtype_and_values_lstm(draw): + input_channels = draw(st.integers(min_value=1, max_value=10)) + t = draw(st.integers(min_value=1, max_value=3)) + x_shape = draw(helpers.get_shape()) + (t, input_channels) + dtype, vals = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float", full=True), shape=x_shape + ) ) + return input_channels, dtype, vals -# Dropout # -# --------# - - -# dropout -@handle_method( - method_tree="Dropout.__call__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=0, - max_value=50, - allow_inf=False, - min_num_dims=1, - 
max_num_dims=1, - min_dim_size=2, - ), - prob=helpers.floats(min_value=0, max_value=0.9), - scale=st.booleans(), -) -def test_dropout_layer( - *, - dtype_and_x, - prob, - scale, - on_device, - class_name, - method_name, - backend_fw, - ground_truth_backend, - init_flags, - method_flags, -): - input_dtype, x = dtype_and_x - ret = helpers.test_method( - backend_to_test=backend_fw, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "prob": prob, - "scale": scale, - "dtype": input_dtype[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"inputs": x[0]}, - class_name=class_name, - method_name=method_name, - test_values=False, - on_device=on_device, - ) - ret = helpers.flatten_and_to_np(ret=ret) - for u in ret: - # cardinality test - assert u.shape == x[0].shape +@st.composite +def _sample_initializer(draw): + return draw(st.sampled_from(all_initializers))() # Attention # @@ -251,82 +175,6 @@ def _x_and_mha(draw): ) -# multi_head_attention -@handle_method( - method_tree="MultiHeadAttention.__call__", - dtype_mha=_x_and_mha(), - init_with_v=st.booleans(), - method_with_v=st.booleans(), - method_num_positional_args=helpers.num_positional_args( - fn_name="MultiHeadAttention._forward" - ), - build_mode=st.just("on_init"), -) -def test_multi_head_attention_layer( - dtype_mha, - init_with_v, - method_with_v, - build_mode, - on_device, - class_name, - method_name, - backend_fw, - ground_truth_backend, - init_flags, - method_flags, -): - ( - input_dtype, - x_mha, - scale, - num_heads, - context, - mask, - query_dim, - head_dim, - dropout_rate, - context_dim, - with_to_q_fn, - with_to_kv_fn, - with_to_out_fn, - ) = dtype_mha - ret_np_flat, ret_np_from_gt_flat = helpers.test_method( - backend_to_test=backend_fw, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "query_dim": query_dim, - "num_heads": num_heads, - "head_dim": head_dim, - "dropout_rate": dropout_rate, - "context_dim": context_dim, - "with_to_q_fn": with_to_q_fn, - "with_to_kv_fn": with_to_kv_fn, - "with_to_out_fn": with_to_out_fn, - "build_mode": build_mode, - "device": on_device, - "dtype": input_dtype[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={ - "inputs": np.asarray(x_mha, dtype=input_dtype[0]), - "context": np.asarray(context, dtype=input_dtype[0]), - "mask": np.asarray(mask, dtype=input_dtype[0]), - }, - class_name=class_name, - method_name=method_name, - init_with_v=init_with_v, - method_with_v=method_with_v, - rtol_=1e-2, - atol_=1e-2, - test_values=False, - return_flat_np_arrays=True, - on_device=on_device, - ) - assert_same_type_and_shape([ret_np_flat, ret_np_from_gt_flat]) - - # Convolutions # # -------------# @@ -402,25 +250,253 @@ def _x_ic_oc_f_d_df(draw, dim: int = 2, transpose: bool = False, depthwise=False ) -# conv1d -@handle_method( - method_tree="Conv1D.__call__", - _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(dim=1), - weight_initializer=_sample_initializer(), - bias_initializer=_sample_initializer(), - init_with_v=st.booleans(), - method_with_v=st.booleans(), -) -def test_conv1d_layer( - _x_ic_oc_f_s_d_df_p, - weight_initializer, - bias_initializer, - init_with_v, - method_with_v, - on_device, - class_name, - method_name, - backend_fw, +# AdaptiveAveragePool2d +@st.composite +def array_for_adaptive( + draw, + num_dims=3, + max_dim_size=8, + min_dim_size=3, + num_out_size=2, +): + dtypes, arrays = draw( + helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("float"), + min_num_dims=num_dims, + max_num_dims=num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + size = draw( + helpers.list_of_size( + x=helpers.ints(min_value=3, max_value=5), + size=num_out_size, + ) + ) + output_size = size[0] if num_out_size == 1 else size + return dtypes, arrays, output_size + + +# --- Main --- # +# ------------ # + + +@handle_method( + method_tree="AdaptiveAvgPool1d.__call__", + dt_arr_size=array_for_adaptive(max_dim_size=3, min_dim_size=2, num_out_size=1), +) +def test_adaptive_avg_pool1d_layer( + *, + dt_arr_size, + test_gradients, + on_device, + class_name, + method_name, + ground_truth_backend, + init_flags, + method_flags, + backend_fw, +): + input_dtype, x, out_size = dt_arr_size + helpers.test_method( + ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "output_size": out_size, + "device": on_device, + "dtype": input_dtype[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"x": x[0]}, + class_name=class_name, + method_name=method_name, + test_gradients=test_gradients, + on_device=on_device, + ) + + +@handle_method( + method_tree="AdaptiveAvgPool2d.__call__", + dt_arr_size=array_for_adaptive(), +) +def test_adaptive_avg_pool2d_layer( + *, + dt_arr_size, + test_gradients, + on_device, + class_name, + method_name, + ground_truth_backend, + init_flags, + method_flags, + backend_fw, +): + input_dtype, x, out_size = dt_arr_size + helpers.test_method( + ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "output_size": out_size, + "device": on_device, + "dtype": input_dtype[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"x": x[0]}, + class_name=class_name, + method_name=method_name, + test_gradients=test_gradients, + on_device=on_device, + ) + + +# AvgPool1D +@handle_method( + method_tree="AvgPool1D.__call__", + x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), +) +def test_avgpool1d_layer( + *, + x_k_s_p, + test_gradients, + on_device, + class_name, + method_name, + ground_truth_backend, + init_flags, + method_flags, + backend_fw, +): + input_dtype, x, kernel_size, stride, padding = x_k_s_p + helpers.test_method( + ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"inputs": x[0]}, + class_name=class_name, + method_name=method_name, + test_gradients=test_gradients, + on_device=on_device, + ) + + +# AvgPool2D +@handle_method( + method_tree="AvgPool2D.__call__", + x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), +) +def test_avgpool2d_layer( + *, + x_k_s_p, + test_gradients, + on_device, + class_name, + method_name, + backend_fw, + ground_truth_backend, + init_flags, + method_flags, +): + input_dtype, x, kernel_size, stride, padding = x_k_s_p + helpers.test_method( + backend_to_test=backend_fw, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "device": on_device, + "dtype": input_dtype[0], + }, + 
method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"inputs": x[0]}, + class_name=class_name, + method_name=method_name, + test_gradients=test_gradients, + on_device=on_device, + ) + + +# ToDo : Add gradient testing once random number generation is unified + + +@handle_method( + method_tree="AvgPool3D.__call__", + x_k_s_p=helpers.arrays_for_pooling(min_dims=5, max_dims=5, min_side=1, max_side=4), + count_include_pad=st.booleans(), + ceil_mode=st.booleans(), + divisor_override=st.one_of(st.none(), st.integers(min_value=1, max_value=4)), +) +def test_avgpool3d_layer( + *, + x_k_s_p, + count_include_pad, + ceil_mode, + divisor_override, + test_gradients, + on_device, + class_name, + method_name, + backend_fw, + ground_truth_backend, + init_flags, + method_flags, +): + input_dtype, x, kernel_size, stride, padding = x_k_s_p + helpers.test_method( + backend_to_test=backend_fw, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "count_include_pad": count_include_pad, + "ceil_mode": ceil_mode, + "divisor_override": divisor_override, + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"x": x[0]}, + class_name=class_name, + method_name=method_name, + test_gradients=test_gradients, + on_device=on_device, + ) + + +# conv1d +@handle_method( + method_tree="Conv1D.__call__", + _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(dim=1), + weight_initializer=_sample_initializer(), + bias_initializer=_sample_initializer(), + init_with_v=st.booleans(), + method_with_v=st.booleans(), +) +def test_conv1d_layer( + _x_ic_oc_f_s_d_df_p, + weight_initializer, + bias_initializer, + init_with_v, + method_with_v, + on_device, + class_name, + method_name, + backend_fw, ground_truth_backend, init_flags, method_flags, @@ -675,20 +751,17 @@ def test_conv2d_transpose_layer( ) -# # depthwise conv2d +# conv3d @handle_method( - method_tree="DepthwiseConv2D.__call__", + method_tree="Conv3D.__call__", ground_truth_backend="jax", - _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(depthwise=True), + _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(dim=3), weight_initializer=_sample_initializer(), bias_initializer=_sample_initializer(), init_with_v=st.booleans(), method_with_v=st.booleans(), - method_num_positional_args=helpers.num_positional_args( - fn_name="DepthwiseConv2D._forward" - ), ) -def test_depthwise_conv2d_layer( +def test_conv3d_layer( _x_ic_oc_f_s_d_df_p, weight_initializer, bias_initializer, @@ -713,14 +786,15 @@ def test_depthwise_conv2d_layer( data_format, padding, ) = _x_ic_oc_f_s_d_df_p - assume(not (backend_fw == "tensorflow" and dilations > 1 and strides > 1)) + assume(not (backend_fw == "tensorflow" and on_device == "cpu" and dilations > 1)) helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "num_channels": input_channels, + "input_channels": input_channels, + "output_channels": output_channels, "filter_shape": filter_shape, "strides": strides, "padding": padding, @@ -743,17 +817,23 @@ def test_depthwise_conv2d_layer( ) -# conv3d +# conv3d transpose @handle_method( - method_tree="Conv3D.__call__", + method_tree="Conv3DTranspose.__call__", ground_truth_backend="jax", - _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(dim=3), + _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(dim=3, transpose=True), weight_initializer=_sample_initializer(), bias_initializer=_sample_initializer(), 
init_with_v=st.booleans(), method_with_v=st.booleans(), + init_num_positional_args=helpers.num_positional_args( + fn_name="Conv3DTranspose.__init__" + ), + method_num_positional_args=helpers.num_positional_args( + fn_name="Conv3DTranspose._forward" + ), ) -def test_conv3d_layer( +def test_conv3d_transpose_layer( _x_ic_oc_f_s_d_df_p, weight_initializer, bias_initializer, @@ -777,6 +857,7 @@ def test_conv3d_layer( dilations, data_format, padding, + output_shape, ) = _x_ic_oc_f_s_d_df_p assume(not (backend_fw == "tensorflow" and on_device == "cpu" and dilations > 1)) helpers.test_method( @@ -792,6 +873,7 @@ def test_conv3d_layer( "padding": padding, "weight_initializer": weight_initializer, "bias_initializer": bias_initializer, + "output_shape": output_shape, "data_format": data_format, "dilations": dilations, "device": on_device, @@ -809,23 +891,59 @@ def test_conv3d_layer( ) -# conv3d transpose @handle_method( - method_tree="Conv3DTranspose.__call__", + method_tree="Dct.__call__", + dtype_x_and_args=valid_dct(), +) +def test_dct( + *, + dtype_x_and_args, + test_gradients, + on_device, + class_name, + method_name, + ground_truth_backend, + init_flags, + method_flags, + backend_fw, +): + dtype, x, type, n, axis, norm = dtype_x_and_args + helpers.test_method( + ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "dtype": dtype[0], + "type": type, + "n": n, + "axis": axis, + "norm": norm, + "device": on_device, + }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"x": x[0]}, + class_name=class_name, + method_name=method_name, + test_gradients=test_gradients, + on_device=on_device, + ) + + +# # depthwise conv2d +@handle_method( + method_tree="DepthwiseConv2D.__call__", ground_truth_backend="jax", - _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(dim=3, transpose=True), + _x_ic_oc_f_s_d_df_p=_x_ic_oc_f_d_df(depthwise=True), weight_initializer=_sample_initializer(), bias_initializer=_sample_initializer(), init_with_v=st.booleans(), method_with_v=st.booleans(), - init_num_positional_args=helpers.num_positional_args( - fn_name="Conv3DTranspose.__init__" - ), method_num_positional_args=helpers.num_positional_args( - fn_name="Conv3DTranspose._forward" + fn_name="DepthwiseConv2D._forward" ), ) -def test_conv3d_transpose_layer( +def test_depthwise_conv2d_layer( _x_ic_oc_f_s_d_df_p, weight_initializer, bias_initializer, @@ -849,23 +967,20 @@ def test_conv3d_transpose_layer( dilations, data_format, padding, - output_shape, ) = _x_ic_oc_f_s_d_df_p - assume(not (backend_fw == "tensorflow" and on_device == "cpu" and dilations > 1)) + assume(not (backend_fw == "tensorflow" and dilations > 1 and strides > 1)) helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "input_channels": input_channels, - "output_channels": output_channels, + "num_channels": input_channels, "filter_shape": filter_shape, "strides": strides, "padding": padding, "weight_initializer": weight_initializer, "bias_initializer": bias_initializer, - "output_shape": output_shape, "data_format": data_format, "dilations": dilations, "device": on_device, @@ -883,42 +998,77 @@ def test_conv3d_transpose_layer( ) -# LSTM -@st.composite -def _input_channels_and_dtype_and_values_lstm(draw): - input_channels = draw(st.integers(min_value=1, max_value=10)) - t = draw(st.integers(min_value=1, max_value=3)) - x_shape = draw(helpers.get_shape()) + 
(t, input_channels) - dtype, vals = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", full=True), shape=x_shape - ) - ) - return input_channels, dtype, vals +# Dropout # +# --------# +# dropout @handle_method( - method_tree="LSTM.__call__", - input_dtype_val=_input_channels_and_dtype_and_values_lstm(), - output_channels=st.shared( - st.integers(min_value=1, max_value=10), key="output_channels" + method_tree="Dropout.__call__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=0, + max_value=50, + allow_inf=False, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, ), + prob=helpers.floats(min_value=0, max_value=0.9), + scale=st.booleans(), +) +def test_dropout_layer( + *, + dtype_and_x, + prob, + scale, + on_device, + class_name, + method_name, + backend_fw, + ground_truth_backend, + init_flags, + method_flags, +): + input_dtype, x = dtype_and_x + ret = helpers.test_method( + backend_to_test=backend_fw, + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "prob": prob, + "scale": scale, + "dtype": input_dtype[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"inputs": x[0]}, + class_name=class_name, + method_name=method_name, + test_values=False, + on_device=on_device, + ) + ret = helpers.flatten_and_to_np(ret=ret) + for u in ret: + # cardinality test + assert u.shape == x[0].shape + + +@handle_method( + method_tree="Embedding.__call__", + embedding_args=_get_embedding_args(), weight_initializer=_sample_initializer(), - num_layers=st.integers(min_value=1, max_value=3), - return_sequence=st.booleans(), - return_state=st.booleans(), init_with_v=st.booleans(), method_with_v=st.booleans(), + seed=helpers.seed(), ) -def test_lstm_layer( - input_dtype_val, - output_channels, +def test_embedding_layer( + *, + embedding_args, weight_initializer, - num_layers, - return_sequence, - return_state, init_with_v, method_with_v, + seed, on_device, class_name, method_name, @@ -927,177 +1077,96 @@ def test_lstm_layer( init_flags, method_flags, ): - input_channels, input_dtype, vals = input_dtype_val - return_sequence = return_sequence - return_state = return_state + ivy.seed(seed_value=seed) + ( + num_embeddings, + embedding_dim, + dtype_indices, + indices, + padding_idx, + max_norm, + ) = embedding_args + dtype = dtype_indices helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "input_channels": input_channels, - "output_channels": output_channels, - "weight_initializer": weight_initializer, - "num_layers": num_layers, - "return_sequence": return_sequence, - "return_state": return_state, + "num_embeddings": num_embeddings, + "embedding_dim": embedding_dim, + "padding_idx": padding_idx, + "max_norm": max_norm, "device": on_device, - "dtype": input_dtype[0], + "dtype": dtype[0], }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"inputs": np.asarray(vals[0], dtype=input_dtype[0])}, + method_all_as_kwargs_np={"indices": indices[0]}, + method_input_dtypes=dtype, class_name=class_name, method_name=method_name, init_with_v=init_with_v, method_with_v=method_with_v, - rtol_=1e-01, - atol_=1e-01, + rtol_=1e-02, + atol_=1e-02, on_device=on_device, ) -# # Sequential # +# FFT @handle_method( - method_tree="Sequential.__call__", - bs_c_target=st.sampled_from( - [ - ( - [1, 2], - 5, - [ - [ - [-0.34784955, 0.47909835, 
0.7241975, -0.82175905, -0.43836743], - [-0.34784955, 0.47909835, 0.7241975, -0.82175905, -0.43836743], - ] - ], - ) - ] - ), - with_v=st.booleans(), - seq_v=st.booleans(), - dtype=helpers.get_dtypes("float", full=False, none=True), + method_tree="FFT.__call__", + x_and_fft=exp_layers_tests.x_and_fft(), ) -def test_sequential_layer( - bs_c_target, - with_v, - seq_v, - dtype, - method_flags, +def test_fft_layer( + *, + x_and_fft, + test_gradients, on_device, - compile_graph, - method_name, class_name, + method_name, + ground_truth_backend, + init_flags, + method_flags, + backend_fw, ): - dtype = dtype[0] - # smoke test - batch_shape, channels, target = bs_c_target - tolerance_dict = { - "bfloat16": 1e-2, - "float16": 1e-2, - "float32": 1e-5, - "float64": 1e-5, - None: 1e-5, - } - if method_flags.as_variable[0]: - x = _variable( - ivy.asarray( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), channels), - dtype=dtype, - ) - ) - else: - x = ivy.asarray( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), channels), - dtype=dtype, - ) - if with_v: - np.random.seed(0) - wlim = (6 / (channels + channels)) ** 0.5 - v = Container( - { - "submodules": { - "v0": { - "w": _variable( - ivy.array( - np.random.uniform(-wlim, wlim, (channels, channels)), - dtype=dtype, - device=on_device, - ) - ), - "b": _variable( - ivy.zeros([channels], device=on_device, dtype=dtype) - ), - }, - "v2": { - "w": _variable( - ivy.array( - np.random.uniform(-wlim, wlim, (channels, channels)), - dtype=dtype, - device=on_device, - ) - ), - "b": _variable( - ivy.zeros([channels], device=on_device, dtype=dtype) - ), - }, - } - } - ) - else: - v = None - if seq_v: - seq = ivy.Sequential( - ivy.Linear(channels, channels, device=on_device, dtype=dtype), - ivy.Dropout(0.0), - ivy.Linear(channels, channels, device=on_device, dtype=dtype), - device=on_device, - v=v if with_v else None, - dtype=dtype, - ) - else: - seq = ivy.Sequential( - ivy.Linear( - channels, - channels, - device=on_device, - v=v["submodules"]["v0"] if with_v else None, - dtype=dtype, - ), - ivy.Dropout(0.0), - ivy.Linear( - channels, - channels, - device=on_device, - v=v["submodules"]["v2"] if with_v else None, - dtype=dtype, - ), - device=on_device, - ) - ret = seq(x) - # type test - assert ivy.is_ivy_array(ret) - # cardinality test - assert ret.shape == ivy.Shape(batch_shape + [channels]) - # value test - if not with_v: - return - assert np.allclose( - ivy.to_numpy(seq(x)), np.array(target), rtol=tolerance_dict[dtype] + dtype, x, dim, norm, n = x_and_fft + helpers.test_method( + ground_truth_backend=ground_truth_backend, + backend_to_test=backend_fw, + init_flags=init_flags, + method_flags=method_flags, + init_all_as_kwargs_np={ + "dim": dim, + "norm": norm, + "n": n, + "device": on_device, + "dtype": dtype[0], + }, + method_input_dtypes=dtype, + method_all_as_kwargs_np={"inputs": x[0]}, + class_name=class_name, + method_name=method_name, + test_gradients=test_gradients, + on_device=on_device, ) -# # Pooling # - - -# MaxPool2D +# Identity @handle_method( - method_tree="MaxPool2D.__call__", - x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), + method_tree="Identity.__call__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=5, + ), + init_with_v=st.booleans(), + method_with_v=st.booleans(), ) -def test_maxpool2d_layer( +def test_identity_layer( *, - x_k_s_p, + dtype_and_x, + init_with_v, + method_with_v, test_gradients, on_device, 
class_name, @@ -1107,37 +1176,50 @@ def test_maxpool2d_layer( init_flags, method_flags, ): - input_dtype, x, kernel_size, stride, padding = x_k_s_p + input_dtype, x = dtype_and_x helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "kernel_size": kernel_size, - "stride": stride, - "padding": padding, "device": on_device, - "dtype": input_dtype[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"inputs": x[0]}, + method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, + init_with_v=init_with_v, + method_with_v=method_with_v, + rtol_=1e-03, + atol_=1e-03, test_gradients=test_gradients, on_device=on_device, ) -# AvgPool2D +# linear @handle_method( - method_tree="AvgPool2D.__call__", - x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), + method_tree="Linear.__call__", + ic_n_dtype_n_vals=_input_channels_and_dtype_and_values(), + output_channels=st.shared( + st.integers(min_value=1, max_value=2), key="output_channels" + ), + weight_initializer=_sample_initializer(), + wb_n_b_init=_bias_flag_and_initializer(), + init_with_v=st.booleans(), + method_with_v=st.booleans(), + seed=helpers.seed(), ) -def test_avgpool2d_layer( +def test_linear_layer( *, - x_k_s_p, - test_gradients, + ic_n_dtype_n_vals, + output_channels, + weight_initializer, + wb_n_b_init, + init_with_v, + method_with_v, + seed, on_device, class_name, method_name, @@ -1146,45 +1228,57 @@ def test_avgpool2d_layer( init_flags, method_flags, ): - input_dtype, x, kernel_size, stride, padding = x_k_s_p + ivy.seed(seed_value=seed) + input_channels, input_dtype, x = ic_n_dtype_n_vals + with_bias, bias_initializer = wb_n_b_init helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "kernel_size": kernel_size, - "stride": stride, - "padding": padding, + "input_channels": input_channels, + "output_channels": output_channels, + "weight_initializer": weight_initializer, + "bias_initializer": bias_initializer, + "with_bias": with_bias, "device": on_device, "dtype": input_dtype[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"inputs": x[0]}, + method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, - test_gradients=test_gradients, + init_with_v=init_with_v, + method_with_v=method_with_v, + rtol_=1e-02, + atol_=1e-02, on_device=on_device, ) -# ToDo : Add gradient testing once random number generation is unified - - @handle_method( - method_tree="AvgPool3D.__call__", - x_k_s_p=helpers.arrays_for_pooling(min_dims=5, max_dims=5, min_side=1, max_side=4), - count_include_pad=st.booleans(), - ceil_mode=st.booleans(), - divisor_override=st.one_of(st.none(), st.integers(min_value=1, max_value=4)), + method_tree="LSTM.__call__", + input_dtype_val=_input_channels_and_dtype_and_values_lstm(), + output_channels=st.shared( + st.integers(min_value=1, max_value=10), key="output_channels" + ), + weight_initializer=_sample_initializer(), + num_layers=st.integers(min_value=1, max_value=3), + return_sequence=st.booleans(), + return_state=st.booleans(), + init_with_v=st.booleans(), + method_with_v=st.booleans(), ) -def test_avgpool3d_layer( - *, - x_k_s_p, - count_include_pad, - ceil_mode, - divisor_override, - test_gradients, +def test_lstm_layer( + input_dtype_val, + output_channels, + weight_initializer, + num_layers, 
+ return_sequence, + return_state, + init_with_v, + method_with_v, on_device, class_name, method_name, @@ -1193,25 +1287,32 @@ def test_avgpool3d_layer( init_flags, method_flags, ): - input_dtype, x, kernel_size, stride, padding = x_k_s_p + input_channels, input_dtype, vals = input_dtype_val + return_sequence = return_sequence + return_state = return_state helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "kernel_size": kernel_size, - "stride": stride, - "padding": padding, - "count_include_pad": count_include_pad, - "ceil_mode": ceil_mode, - "divisor_override": divisor_override, + "input_channels": input_channels, + "output_channels": output_channels, + "weight_initializer": weight_initializer, + "num_layers": num_layers, + "return_sequence": return_sequence, + "return_state": return_state, + "device": on_device, + "dtype": input_dtype[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"x": x[0]}, + method_all_as_kwargs_np={"inputs": np.asarray(vals[0], dtype=input_dtype[0])}, class_name=class_name, method_name=method_name, - test_gradients=test_gradients, + init_with_v=init_with_v, + method_with_v=method_with_v, + rtol_=1e-01, + atol_=1e-01, on_device=on_device, ) @@ -1255,12 +1356,15 @@ def test_maxpool1d_layer( ) -# MaxPool3D +# # Pooling # + + +# MaxPool2D @handle_method( - method_tree="MaxPool3D.__call__", - x_k_s_p=helpers.arrays_for_pooling(min_dims=5, max_dims=5, min_side=1, max_side=4), + method_tree="MaxPool2D.__call__", + x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4), ) -def test_maxpool3d_layer( +def test_maxpool2d_layer( *, x_k_s_p, test_gradients, @@ -1286,7 +1390,7 @@ def test_maxpool3d_layer( "dtype": input_dtype[0], }, method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"x": x[0]}, + method_all_as_kwargs_np={"inputs": x[0]}, class_name=class_name, method_name=method_name, test_gradients=test_gradients, @@ -1294,213 +1398,37 @@ def test_maxpool3d_layer( ) -# AdaptiveAveragePool2d -@st.composite -def array_for_adaptive( - draw, - num_dims=3, - max_dim_size=8, - min_dim_size=3, - num_out_size=2, -): - dtypes, arrays = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_num_dims=num_dims, - max_num_dims=num_dims, - min_dim_size=min_dim_size, - max_dim_size=max_dim_size, - ) - ) - size = draw( - helpers.list_of_size( - x=helpers.ints(min_value=3, max_value=5), - size=num_out_size, - ) - ) - output_size = size[0] if num_out_size == 1 else size - return dtypes, arrays, output_size - - -@handle_method( - method_tree="AdaptiveAvgPool2d.__call__", - dt_arr_size=array_for_adaptive(), -) -def test_adaptive_avg_pool2d_layer( - *, - dt_arr_size, - test_gradients, - on_device, - class_name, - method_name, - ground_truth_backend, - init_flags, - method_flags, - backend_fw, -): - input_dtype, x, out_size = dt_arr_size - helpers.test_method( - ground_truth_backend=ground_truth_backend, - backend_to_test=backend_fw, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "output_size": out_size, - "device": on_device, - "dtype": input_dtype[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"x": x[0]}, - class_name=class_name, - method_name=method_name, - test_gradients=test_gradients, - on_device=on_device, - ) - - -@handle_method( - method_tree="AdaptiveAvgPool1d.__call__", - dt_arr_size=array_for_adaptive(max_dim_size=3, 
min_dim_size=2, num_out_size=1), -) -def test_adaptive_avg_pool1d_layer( - *, - dt_arr_size, - test_gradients, - on_device, - class_name, - method_name, - ground_truth_backend, - init_flags, - method_flags, - backend_fw, -): - input_dtype, x, out_size = dt_arr_size - helpers.test_method( - ground_truth_backend=ground_truth_backend, - backend_to_test=backend_fw, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "output_size": out_size, - "device": on_device, - "dtype": input_dtype[0], - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"x": x[0]}, - class_name=class_name, - method_name=method_name, - test_gradients=test_gradients, - on_device=on_device, - ) - - -# FFT -@handle_method( - method_tree="FFT.__call__", - x_and_fft=exp_layers_tests.x_and_fft(), -) -def test_fft_layer( - *, - x_and_fft, - test_gradients, - on_device, - class_name, - method_name, - ground_truth_backend, - init_flags, - method_flags, - backend_fw, -): - dtype, x, dim, norm, n = x_and_fft - helpers.test_method( - ground_truth_backend=ground_truth_backend, - backend_to_test=backend_fw, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "dim": dim, - "norm": norm, - "n": n, - "device": on_device, - "dtype": dtype[0], - }, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"inputs": x[0]}, - class_name=class_name, - method_name=method_name, - test_gradients=test_gradients, - on_device=on_device, - ) - - -# AvgPool1D +# MaxPool3D @handle_method( - method_tree="AvgPool1D.__call__", - x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4), + method_tree="MaxPool3D.__call__", + x_k_s_p=helpers.arrays_for_pooling(min_dims=5, max_dims=5, min_side=1, max_side=4), ) -def test_avgpool1d_layer( +def test_maxpool3d_layer( *, x_k_s_p, test_gradients, on_device, class_name, method_name, + backend_fw, ground_truth_backend, init_flags, method_flags, - backend_fw, ): input_dtype, x, kernel_size, stride, padding = x_k_s_p helpers.test_method( - ground_truth_backend=ground_truth_backend, backend_to_test=backend_fw, + ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ "kernel_size": kernel_size, "stride": stride, "padding": padding, - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"inputs": x[0]}, - class_name=class_name, - method_name=method_name, - test_gradients=test_gradients, - on_device=on_device, - ) - - -@handle_method( - method_tree="Dct.__call__", - dtype_x_and_args=valid_dct(), -) -def test_dct( - *, - dtype_x_and_args, - test_gradients, - on_device, - class_name, - method_name, - ground_truth_backend, - init_flags, - method_flags, - backend_fw, -): - dtype, x, type, n, axis, norm = dtype_x_and_args - helpers.test_method( - ground_truth_backend=ground_truth_backend, - backend_to_test=backend_fw, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "dtype": dtype[0], - "type": type, - "n": n, - "axis": axis, - "norm": norm, "device": on_device, + "dtype": input_dtype[0], }, - method_input_dtypes=dtype, + method_input_dtypes=input_dtype, method_all_as_kwargs_np={"x": x[0]}, class_name=class_name, method_name=method_name, @@ -1509,48 +1437,22 @@ def test_dct( ) -# Embedding -@st.composite -def _get_embedding_args(draw): - num_embeddings = draw(st.integers(min_value=1, max_value=10)) - embedding_dim = draw(st.integers(min_value=1, max_value=10)) - dtype_indices, indices = draw( - helpers.dtype_and_values( - 
available_dtypes=["int32", "int64"], - min_num_dims=2, - min_dim_size=1, - min_value=0, - max_value=num_embeddings - 1, - ).filter(lambda x: x[1][0].shape[-1] == embedding_dim) - ) - padding_idx = draw(st.integers(min_value=0, max_value=num_embeddings - 1)) - max_norm = draw(st.one_of(st.none(), st.floats(min_value=1, max_value=5))) - - return ( - num_embeddings, - embedding_dim, - dtype_indices, - indices, - padding_idx, - max_norm, - ) - - +# multi_head_attention @handle_method( - method_tree="Embedding.__call__", - embedding_args=_get_embedding_args(), - weight_initializer=_sample_initializer(), + method_tree="MultiHeadAttention.__call__", + dtype_mha=_x_and_mha(), init_with_v=st.booleans(), method_with_v=st.booleans(), - seed=helpers.seed(), + method_num_positional_args=helpers.num_positional_args( + fn_name="MultiHeadAttention._forward" + ), + build_mode=st.just("on_init"), ) -def test_embedding_layer( - *, - embedding_args, - weight_initializer, +def test_multi_head_attention_layer( + dtype_mha, init_with_v, method_with_v, - seed, + build_mode, on_device, class_name, method_name, @@ -1559,83 +1461,188 @@ def test_embedding_layer( init_flags, method_flags, ): - ivy.seed(seed_value=seed) ( - num_embeddings, - embedding_dim, - dtype_indices, - indices, - padding_idx, - max_norm, - ) = embedding_args - dtype = dtype_indices - helpers.test_method( + input_dtype, + x_mha, + scale, + num_heads, + context, + mask, + query_dim, + head_dim, + dropout_rate, + context_dim, + with_to_q_fn, + with_to_kv_fn, + with_to_out_fn, + ) = dtype_mha + ret_np_flat, ret_np_from_gt_flat = helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "num_embeddings": num_embeddings, - "embedding_dim": embedding_dim, - "padding_idx": padding_idx, - "max_norm": max_norm, + "query_dim": query_dim, + "num_heads": num_heads, + "head_dim": head_dim, + "dropout_rate": dropout_rate, + "context_dim": context_dim, + "with_to_q_fn": with_to_q_fn, + "with_to_kv_fn": with_to_kv_fn, + "with_to_out_fn": with_to_out_fn, + "build_mode": build_mode, "device": on_device, - "dtype": dtype[0], + "dtype": input_dtype[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "inputs": np.asarray(x_mha, dtype=input_dtype[0]), + "context": np.asarray(context, dtype=input_dtype[0]), + "mask": np.asarray(mask, dtype=input_dtype[0]), }, - method_all_as_kwargs_np={"indices": indices[0]}, - method_input_dtypes=dtype, class_name=class_name, method_name=method_name, init_with_v=init_with_v, method_with_v=method_with_v, - rtol_=1e-02, - atol_=1e-02, + rtol_=1e-2, + atol_=1e-2, + test_values=False, + return_flat_np_arrays=True, on_device=on_device, ) + assert_same_type_and_shape([ret_np_flat, ret_np_from_gt_flat]) -# Identity +# # Sequential # @handle_method( - method_tree="Identity.__call__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - max_num_dims=5, + method_tree="Sequential.__call__", + bs_c_target=st.sampled_from( + [ + ( + [1, 2], + 5, + [ + [ + [-0.34784955, 0.47909835, 0.7241975, -0.82175905, -0.43836743], + [-0.34784955, 0.47909835, 0.7241975, -0.82175905, -0.43836743], + ] + ], + ) + ] ), - init_with_v=st.booleans(), - method_with_v=st.booleans(), + with_v=st.booleans(), + seq_v=st.booleans(), + dtype=helpers.get_dtypes("float", full=False, none=True), ) -def test_identity_layer( - *, - dtype_and_x, - init_with_v, - method_with_v, - test_gradients, 
+def test_sequential_layer( + bs_c_target, + with_v, + seq_v, + dtype, + method_flags, on_device, - class_name, + compile_graph, method_name, - backend_fw, - ground_truth_backend, - init_flags, - method_flags, + class_name, ): - input_dtype, x = dtype_and_x - helpers.test_method( - backend_to_test=backend_fw, - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - init_all_as_kwargs_np={ - "device": on_device, - }, - method_input_dtypes=input_dtype, - method_all_as_kwargs_np={"x": x[0]}, - class_name=class_name, - method_name=method_name, - init_with_v=init_with_v, - method_with_v=method_with_v, - rtol_=1e-03, - atol_=1e-03, - test_gradients=test_gradients, - on_device=on_device, + dtype = dtype[0] + # smoke test + batch_shape, channels, target = bs_c_target + tolerance_dict = { + "bfloat16": 1e-2, + "float16": 1e-2, + "float32": 1e-5, + "float64": 1e-5, + None: 1e-5, + } + if method_flags.as_variable[0]: + x = _variable( + ivy.asarray( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), channels), + dtype=dtype, + ) + ) + else: + x = ivy.asarray( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), channels), + dtype=dtype, + ) + if with_v: + np.random.seed(0) + wlim = (6 / (channels + channels)) ** 0.5 + v = Container( + { + "submodules": { + "v0": { + "w": _variable( + ivy.array( + np.random.uniform(-wlim, wlim, (channels, channels)), + dtype=dtype, + device=on_device, + ) + ), + "b": _variable( + ivy.zeros([channels], device=on_device, dtype=dtype) + ), + }, + "v2": { + "w": _variable( + ivy.array( + np.random.uniform(-wlim, wlim, (channels, channels)), + dtype=dtype, + device=on_device, + ) + ), + "b": _variable( + ivy.zeros([channels], device=on_device, dtype=dtype) + ), + }, + } + } + ) + else: + v = None + if seq_v: + seq = ivy.Sequential( + ivy.Linear(channels, channels, device=on_device, dtype=dtype), + ivy.Dropout(0.0), + ivy.Linear(channels, channels, device=on_device, dtype=dtype), + device=on_device, + v=v if with_v else None, + dtype=dtype, + ) + else: + seq = ivy.Sequential( + ivy.Linear( + channels, + channels, + device=on_device, + v=v["submodules"]["v0"] if with_v else None, + dtype=dtype, + ), + ivy.Dropout(0.0), + ivy.Linear( + channels, + channels, + device=on_device, + v=v["submodules"]["v2"] if with_v else None, + dtype=dtype, + ), + device=on_device, + ) + ret = seq(x) + # type test + assert ivy.is_ivy_array(ret) + # cardinality test + assert ret.shape == ivy.Shape(batch_shape + [channels]) + # value test + if not with_v: + return + assert np.allclose( + ivy.to_numpy(seq(x)), np.array(target), rtol=tolerance_dict[dtype] ) + + +all_initializers = ( + all_constant_initializers + all_uniform_initializers + all_gaussian_initializers +) diff --git a/ivy_tests/test_ivy/test_stateful/test_losses.py b/ivy_tests/test_ivy/test_stateful/test_losses.py index 1ca40738e72ee..c8efbe04c1295 100644 --- a/ivy_tests/test_ivy/test_stateful/test_losses.py +++ b/ivy_tests/test_ivy/test_stateful/test_losses.py @@ -8,9 +8,121 @@ from ivy_tests.test_ivy.helpers import handle_method -# Log Poisson Loss +# Binary Cross Entropy Loss @handle_method( - method_tree="stateful.losses.LogPoissonLoss.__call__", + method_tree="stateful.losses.BinaryCrossEntropyLoss.__call__", + dtype_and_true=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("integer"), + min_value=1e-04, + max_value=1, + allow_inf=False, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + shape=(5,), + ), + dtype_and_pred=helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("float"), + min_value=1e-04, + max_value=1, + allow_inf=False, + exclude_min=True, + exclude_max=True, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + shape=(5,), + ), + dtype_and_pos=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=1e-04, + max_value=1, + allow_inf=False, + exclude_min=True, + exclude_max=True, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + shape=(5,), + ), + reduction=st.sampled_from(["none", "sum", "mean"]), + axis=helpers.ints(min_value=-1, max_value=0), + epsilon=helpers.floats(min_value=0, max_value=1.0), + from_logits=st.booleans(), + method_num_positional_args=helpers.num_positional_args( + fn_name="BinaryCrossEntropyLoss._forward" + ), +) +def test_binary_cross_entropy_loss( + *, + dtype_and_true, + dtype_and_pred, + dtype_and_pos, + from_logits, + reduction, + axis, + epsilon, + class_name, + method_name, + ground_truth_backend, + init_flags, + method_flags, + on_device, +): + dtype_true, true = dtype_and_true + dtype_pred, pred = dtype_and_pred + dtype_pos_weight, pos_weight = dtype_and_pos + + if from_logits: + helpers.test_method( + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + method_input_dtypes=dtype_true + dtype_pred + dtype_pos_weight, + init_all_as_kwargs_np={ + "from_logits": from_logits, + "epsilon": epsilon, + "reduction": reduction, + }, + method_all_as_kwargs_np={ + "true": true[0], + "pred": pred[0], + "pos_weight": pos_weight[0], + "axis": axis, + }, + class_name=class_name, + method_name=method_name, + rtol_=1e-2, + atol_=1e-2, + on_device=on_device, + ) + else: + helpers.test_method( + ground_truth_backend=ground_truth_backend, + init_flags=init_flags, + method_flags=method_flags, + method_input_dtypes=dtype_true + dtype_pred, + init_all_as_kwargs_np={ + "from_logits": from_logits, + "epsilon": epsilon, + "reduction": reduction, + }, + method_all_as_kwargs_np={ + "true": true[0], + "pred": pred[0], + "axis": axis, + }, + class_name=class_name, + method_name=method_name, + rtol_=1e-2, + atol_=1e-2, + on_device=on_device, + ) + + +# Cross Entropy Loss +@handle_method( + method_tree="stateful.losses.CrossEntropyLoss.__call__", dtype_and_targets=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -34,17 +146,15 @@ min_dim_size=3, ), axis=st.integers(min_value=-1, max_value=1), - compute_full_loss=st.sampled_from([True, False]), method_num_positional_args=helpers.num_positional_args( - fn_name="LogPoissonLoss._forward" + fn_name="CrossEntropyLoss._forward" ), reduction=st.sampled_from(["none", "mean", "sum"]), ) -def test_log_poisson_loss( +def test_cross_entropy_loss( *, dtype_and_targets, dtype_and_log_input, - compute_full_loss, axis, reduction, class_name, @@ -62,7 +172,6 @@ def test_log_poisson_loss( method_flags=method_flags, method_input_dtypes=targets_dtype + log_input_dtype, init_all_as_kwargs_np={ - "compute_full_loss": compute_full_loss, "axis": axis, "reduction": reduction, }, @@ -75,9 +184,9 @@ def test_log_poisson_loss( ) -# Cross Entropy Loss +# Log Poisson Loss @handle_method( - method_tree="stateful.losses.CrossEntropyLoss.__call__", + method_tree="stateful.losses.LogPoissonLoss.__call__", dtype_and_targets=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, @@ -101,15 +210,17 @@ def test_log_poisson_loss( min_dim_size=3, ), axis=st.integers(min_value=-1, max_value=1), + compute_full_loss=st.sampled_from([True, False]), 
method_num_positional_args=helpers.num_positional_args( - fn_name="CrossEntropyLoss._forward" + fn_name="LogPoissonLoss._forward" ), reduction=st.sampled_from(["none", "mean", "sum"]), ) -def test_cross_entropy_loss( +def test_log_poisson_loss( *, dtype_and_targets, dtype_and_log_input, + compute_full_loss, axis, reduction, class_name, @@ -127,6 +238,7 @@ def test_cross_entropy_loss( method_flags=method_flags, method_input_dtypes=targets_dtype + log_input_dtype, init_all_as_kwargs_np={ + "compute_full_loss": compute_full_loss, "axis": axis, "reduction": reduction, }, @@ -137,115 +249,3 @@ def test_cross_entropy_loss( atol_=1e-2, on_device=on_device, ) - - -# Binary Cross Entropy Loss -@handle_method( - method_tree="stateful.losses.BinaryCrossEntropyLoss.__call__", - dtype_and_true=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("integer"), - min_value=1e-04, - max_value=1, - allow_inf=False, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - shape=(5,), - ), - dtype_and_pred=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1e-04, - max_value=1, - allow_inf=False, - exclude_min=True, - exclude_max=True, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - shape=(5,), - ), - dtype_and_pos=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=1e-04, - max_value=1, - allow_inf=False, - exclude_min=True, - exclude_max=True, - min_num_dims=1, - max_num_dims=1, - min_dim_size=2, - shape=(5,), - ), - reduction=st.sampled_from(["none", "sum", "mean"]), - axis=helpers.ints(min_value=-1, max_value=0), - epsilon=helpers.floats(min_value=0, max_value=1.0), - from_logits=st.booleans(), - method_num_positional_args=helpers.num_positional_args( - fn_name="BinaryCrossEntropyLoss._forward" - ), -) -def test_binary_cross_entropy_loss( - *, - dtype_and_true, - dtype_and_pred, - dtype_and_pos, - from_logits, - reduction, - axis, - epsilon, - class_name, - method_name, - ground_truth_backend, - init_flags, - method_flags, - on_device, -): - dtype_true, true = dtype_and_true - dtype_pred, pred = dtype_and_pred - dtype_pos_weight, pos_weight = dtype_and_pos - - if from_logits: - helpers.test_method( - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - method_input_dtypes=dtype_true + dtype_pred + dtype_pos_weight, - init_all_as_kwargs_np={ - "from_logits": from_logits, - "epsilon": epsilon, - "reduction": reduction, - }, - method_all_as_kwargs_np={ - "true": true[0], - "pred": pred[0], - "pos_weight": pos_weight[0], - "axis": axis, - }, - class_name=class_name, - method_name=method_name, - rtol_=1e-2, - atol_=1e-2, - on_device=on_device, - ) - else: - helpers.test_method( - ground_truth_backend=ground_truth_backend, - init_flags=init_flags, - method_flags=method_flags, - method_input_dtypes=dtype_true + dtype_pred, - init_all_as_kwargs_np={ - "from_logits": from_logits, - "epsilon": epsilon, - "reduction": reduction, - }, - method_all_as_kwargs_np={ - "true": true[0], - "pred": pred[0], - "axis": axis, - }, - class_name=class_name, - method_name=method_name, - rtol_=1e-2, - atol_=1e-2, - on_device=on_device, - ) diff --git a/ivy_tests/test_ivy/test_stateful/test_modules.py b/ivy_tests/test_ivy/test_stateful/test_modules.py index 7ce71eacd3879..b84aee7c64b6d 100644 --- a/ivy_tests/test_ivy/test_stateful/test_modules.py +++ b/ivy_tests/test_ivy/test_stateful/test_modules.py @@ -33,57 +33,6 @@ def _forward(self, x): return ivy.tanh(self._linear2(x))[0] -# module training 
-@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_training(batch_shape, input_channels, output_channels, on_device): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), - "float32", - ) - module = TrainableModule(input_channels, output_channels, device=on_device) - - def loss_fn(v_): - out = module(x, v=v_) - return ivy.mean(out) - - # train - loss_tm1 = 1e12 - loss = None - grads = None - for i in range(10): - loss, grads = ivy.execute_with_gradients(loss_fn, module.v) - module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) - assert loss < loss_tm1 - loss_tm1 = loss - - # type test - assert ivy.is_array(loss) - assert isinstance(grads, ivy.Container) - # cardinality test - assert loss.shape == () - # value test - assert ivy.max(ivy.abs(grads.linear0.b)) > 0 - assert ivy.max(ivy.abs(grads.linear0.w)) > 0 - assert ivy.max(ivy.abs(grads.linear1.b)) > 0 - assert ivy.max(ivy.abs(grads.linear1.w)) > 0 - assert ivy.max(ivy.abs(grads.linear2.b)) > 0 - assert ivy.max(ivy.abs(grads.linear2.w)) > 0 - # compilation test - if ivy.current_backend_str() == "torch": - # pytest scripting does not support **kwargs - return - - class TrainableModuleWithList(ivy.Module): def __init__(self, in_size, out_size, device=None, hidden_size=64): linear0 = ivy.Linear(in_size, hidden_size, device=device) @@ -99,128 +48,6 @@ def _forward(self, x): return ivy.tanh(self._layers[2](x))[0] -# module with list training -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_w_list_training( - batch_shape, input_channels, output_channels, on_device -): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), - "float32", - ) - module = TrainableModuleWithList(input_channels, output_channels, device=on_device) - - def loss_fn(v_): - out = module(x, v=v_) - return ivy.mean(out) - - # train - loss_tm1 = 1e12 - loss = None - grads = None - for i in range(10): - loss, grads = ivy.execute_with_gradients(loss_fn, module.v) - module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) - assert loss < loss_tm1 - loss_tm1 = loss - - # type test - assert ivy.is_array(loss) - assert isinstance(grads, ivy.Container) - # cardinality test - assert loss.shape == () - # value test - assert ivy.max(ivy.abs(grads.layers.v0.b)) > 0 - assert ivy.max(ivy.abs(grads.layers.v0.w)) > 0 - assert ivy.max(ivy.abs(grads.layers.v1.b)) > 0 - assert ivy.max(ivy.abs(grads.layers.v1.w)) > 0 - assert ivy.max(ivy.abs(grads.layers.v2.b)) > 0 - assert ivy.max(ivy.abs(grads.layers.v2.w)) > 0 - # compilation test - if ivy.current_backend_str() == "torch": - # pytest scripting does not support **kwargs - return - - -# module with partial v -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def 
test_module_w_partial_v(batch_shape, input_channels, output_channels, on_device): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - - return - - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), - "float32", - ) - v = ivy.Container( - { - "linear0": { - "b": _variable(ivy.random_uniform(shape=[64])), - "w": _variable(ivy.random_uniform(shape=[64, 4])), - }, - "linear1": { - "b": _variable(ivy.random_uniform(shape=[64])), - "w": _variable(ivy.random_uniform(shape=[64, 64])), - "extra": _variable(ivy.random_uniform(shape=[64, 64])), - }, - "linear2": { - "b": _variable(ivy.random_uniform(shape=[5])), - "w": _variable(ivy.random_uniform(shape=[5, 64])), - }, - } - ) - try: - TrainableModule( - input_channels, output_channels, device=on_device, v=v, with_partial_v=True - ) - raise Exception( - "TrainableModule did not raise exception desipite being passed " - "with wrongly shaped variables." - ) - except ivy.utils.exceptions.IvyException: - pass - v = ivy.Container( - { - "linear0": { - "b": _variable(ivy.random_uniform(shape=[64])), - }, - "linear1": {"w": _variable(ivy.random_uniform(shape=[64, 64]))}, - "linear2": {"b": _variable(ivy.random_uniform(shape=[output_channels]))}, - } - ) - try: - TrainableModule(input_channels, output_channels, device=on_device, v=v) - raise Exception( - "TrainableModule did not raise exception desipite being passed " - "with wrongly shaped variables." - ) - except ivy.utils.exceptions.IvyException: - pass - module = TrainableModule( - input_channels, output_channels, device=on_device, v=v, with_partial_v=True - ) - module(x) - - class ModuleWithNoneAttribute(ivy.Module): def __init__(self, device=None, hidden_size=64): self.some_attribute = None @@ -230,30 +57,6 @@ def _forward(self, x): return x -# module with none attribute -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_w_none_attribute( - batch_shape, input_channels, output_channels, on_device -): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), - "float32", - ) - module = ModuleWithNoneAttribute(device=on_device) - module(x) - - class TrainableModuleWithDuplicate(ivy.Module): def __init__(self, channels, same_layer, device=None): if same_layer: @@ -275,54 +78,6 @@ def _forward(self, x): return self._linear1(x) -# module training with duplicate -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - channels=st.integers(min_value=1, max_value=64), - same_layer=st.booleans(), -) -def test_module_training_with_duplicate(batch_shape, channels, same_layer, on_device): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), channels), "float32" - ) - module = TrainableModuleWithDuplicate(channels, same_layer, device=on_device) - - def loss_fn(v_): - out = module(x, v=v_) - return ivy.mean(out) - - # train - loss_tm1 = 1e12 - loss = None - grads = None - for i in range(10): - loss, grads = ivy.execute_with_gradients(loss_fn, module.v) - module.v = ivy.gradient_descent_update(module.v, grads, 
1e-3) - assert loss < loss_tm1 - loss_tm1 = loss - - # type test - assert ivy.is_array(loss) - assert isinstance(grads, ivy.Container) - # cardinality test - assert loss.shape == () - # value test - assert ivy.max(ivy.abs(grads.linear0.b)) > 0 - assert ivy.max(ivy.abs(grads.linear0.w)) > 0 - if not same_layer: - assert ivy.max(ivy.abs(grads.linear1.b)) > 0 - # compilation test - if ivy.current_backend_str() == "torch": - # pytest scripting does not support **kwargs - return - - class TrainableModuleWithDict(ivy.Module): def __init__(self, in_size, out_size, device=None, hidden_size=64): linear0 = ivy.Linear(in_size, hidden_size, device=device) @@ -338,59 +93,6 @@ def _forward(self, x): return ivy.tanh(self._layers["linear2"](x))[0] -# module with dict training -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_w_dict_training( - batch_shape, input_channels, output_channels, on_device -): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), - "float32", - ) - module = TrainableModuleWithDict(input_channels, output_channels, device=on_device) - - def loss_fn(v_): - out = module(x, v=v_) - return ivy.mean(out) - - # train - loss_tm1 = 1e12 - loss = None - grads = None - for i in range(10): - loss, grads = ivy.execute_with_gradients(loss_fn, module.v) - module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) - assert loss < loss_tm1 - loss_tm1 = loss - - # type test - assert ivy.is_array(loss) - assert isinstance(grads, ivy.Container) - # cardinality test - assert loss.shape == () - # value test - assert ivy.max(ivy.abs(grads.layers.linear0.b)) > 0 - assert ivy.max(ivy.abs(grads.layers.linear0.w)) > 0 - assert ivy.max(ivy.abs(grads.layers.linear1.b)) > 0 - assert ivy.max(ivy.abs(grads.layers.linear1.w)) > 0 - assert ivy.max(ivy.abs(grads.layers.linear2.b)) > 0 - assert ivy.max(ivy.abs(grads.layers.linear2.w)) > 0 - # compilation test - if ivy.current_backend_str() == "torch": - # pytest scripting does not support **kwargs - return - - class WithCustomVarStructure(ivy.Module): def __init__(self, in_size, out_size, device=None, hidden_size=64): self._linear0 = ivy.Linear(in_size, hidden_size, device=device) @@ -405,27 +107,6 @@ def _forward(self, x): pass -# with custom var structure -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_with_custom_var_structure( - batch_shape, input_channels, output_channels, on_device -): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - module = WithCustomVarStructure(input_channels, output_channels, device=on_device) - assert "x" in module.v - assert "y" in module.v - assert "z" in module.v - - class DoubleLinear(ivy.Module): def __init__(self, in_size, out_size, device=None, hidden_size=64): self._l0 = ivy.Linear(in_size, hidden_size, device=device) @@ -451,7 +132,45 @@ def _forward(self, x): return x -# top variables +class ModuleWithBuffer(ivy.Module): + def __init__(self, *args, **kwargs): + pass + + def _forward(*args, **kwargs): + pass + + +class 
ModuleWithTrainEval(ivy.Module): + def __init__(self): + super().__init__() + + def _forward(): + pass + + +@given( + buffer=st.just( + [ + { + "var1": [ + ivy.ones((1, 2)), + ] + } + ] + ) +) +def test_get_buffers(buffer): + module = ModuleWithBuffer() + buffers = {} + for item in buffer: + buffers.update(item) + for key in item: + module.register_buffer(key, item[key]) + + assert module.buffers == buffers + + +# check submod returns @given( batch_shape=helpers.get_shape( min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 @@ -459,40 +178,93 @@ def _forward(self, x): input_channels=st.integers(min_value=2, max_value=5), output_channels=st.integers(min_value=2, max_value=5), ) -def test_top_variables(batch_shape, input_channels, output_channels, on_device): +def test_module_check_submod_rets( + batch_shape, input_channels, output_channels, on_device +): # smoke test if ivy.current_backend_str() == "numpy": # NumPy does not support gradients return + + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), + "float32", + ) module = WithNestedModules(input_channels, output_channels, device=on_device) - for key_chain in [ - "dl0", - "dl0/l0", - "dl0/l1", - "dl0/l0/b", - "dl0/l0/w", - "dl0/l1/b", - "dl0/l1/w", - "dl1", - "dl1/l0", - "dl1/l1", - "dl1/l0/b", - "dl1/l0/w", - "dl1/l1/b", - "dl1/l1/w", - ]: - # depth 1 - assert key_chain in module._dl0.top_v() - assert key_chain in module._dl1.top_v() - # depth 2 - assert key_chain in module._dl0._l0.top_v() - assert key_chain in module._dl0._l1.top_v() - assert key_chain in module._dl1._l0.top_v() - assert key_chain in module._dl1._l1.top_v() + # depth 1 + ret = module(x, track_submod_rets=True, submod_depth=1) + assert ret.shape == tuple(list(batch_shape) + [64]) + sm_rets = module.submod_rets + module(x, expected_submod_rets=sm_rets) + sm_rets.random_uniform(map_sequences=True) + try: + module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) + raise Exception( + "forward pass succeeded despite passing random expected_submod_rets, " + "assertion error expected." + ) + except ivy.utils.exceptions.IvyException: + pass + + # depth 2 (full) + ret = module(x, track_submod_rets=True) + assert ret.shape == tuple(list(batch_shape) + [64]) + sm_rets = module.submod_rets + module(x, expected_submod_rets=sm_rets) + try: + module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) + raise Exception( + "forward pass succeeded despite passing random expected_submod_rets, " + "assertion error expected." + ) + except ivy.utils.exceptions.IvyException: + pass + # partial submodules + ret = module( + x, track_submod_rets=True, submods_to_track=[module._dl1, module._dl0._l0] + ) + assert ret.shape == tuple(list(batch_shape) + [64]) + sm_rets = module.submod_rets + module(x, expected_submod_rets=sm_rets) + try: + module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) + raise Exception( + "forward pass succeeded despite passing random expected_submod_rets, " + "assertion error expected." 
+ ) + except ivy.utils.exceptions.IvyException: + pass -# top module + # with tolerances + ret = module(x, track_submod_rets=True) + assert ret.shape == tuple(list(batch_shape) + [64]) + sm_rets_orig = module.submod_rets + sm_rets = ivy.Container( + { + k: {"val": v, "atol": [1e-8] * len(v), "rtol": [1e-5] * len(v)} + for k, v in sm_rets_orig.items() + }, + **sm_rets_orig._config + ) + module(x, expected_submod_rets=sm_rets) + sm_rets = ivy.Container( + {k: {"val": v, "atol": 1e-8, "rtol": 1e-5} for k, v in sm_rets_orig.items()}, + **sm_rets_orig._config + ) + module(x, expected_submod_rets=sm_rets) + try: + module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) + raise Exception( + "forward pass succeeded despite passing random expected_submod_rets, " + "assertion error expected." + ) + except ivy.utils.exceptions.IvyException: + pass + + +# module depth @given( batch_shape=helpers.get_shape( min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 @@ -500,7 +272,7 @@ def test_top_variables(batch_shape, input_channels, output_channels, on_device): input_channels=st.integers(min_value=2, max_value=5), output_channels=st.integers(min_value=2, max_value=5), ) -def test_top_module(batch_shape, input_channels, output_channels, on_device): +def test_module_depth(batch_shape, input_channels, output_channels, on_device): # smoke test if ivy.current_backend_str() == "numpy": # NumPy does not support gradients @@ -508,23 +280,21 @@ def test_top_module(batch_shape, input_channels, output_channels, on_device): module = WithNestedModules(input_channels, output_channels, device=on_device) - # full depth - assert module._dl0.top_mod() is module - assert module._dl1.top_mod() is module - - assert module._dl0._l0.top_mod() is module - assert module._dl0._l1.top_mod() is module - assert module._dl1._l0.top_mod() is module - assert module._dl1._l1.top_mod() is module + # depth 0 + assert module.mod_depth() == 0 # depth 1 - assert module._dl0._l0.top_mod(1) is module._dl0 - assert module._dl0._l1.top_mod(1) is module._dl0 - assert module._dl1._l0.top_mod(1) is module._dl1 - assert module._dl1._l1.top_mod(1) is module._dl1 + assert module._dl0.mod_depth() == 1 + assert module._dl1.mod_depth() == 1 + # depth 2 + assert module._dl0._l0.mod_depth() == 2 + assert module._dl0._l1.mod_depth() == 2 + assert module._dl1._l0.mod_depth() == 2 + assert module._dl1._l1.mod_depth() == 2 -# v with top v key chains + +# module height @given( batch_shape=helpers.get_shape( min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 @@ -532,9 +302,7 @@ def test_top_module(batch_shape, input_channels, output_channels, on_device): input_channels=st.integers(min_value=2, max_value=5), output_channels=st.integers(min_value=2, max_value=5), ) -def test_v_with_top_v_key_chains( - batch_shape, input_channels, output_channels, on_device -): +def test_module_height(batch_shape, input_channels, output_channels, on_device): # smoke test if ivy.current_backend_str() == "numpy": # NumPy does not support gradients @@ -542,106 +310,12 @@ def test_v_with_top_v_key_chains( module = WithNestedModules(input_channels, output_channels, device=on_device) - # full depth - v = module._dl0.v_with_top_v_key_chains() - assert "dl0" in v - assert v.dl0 is module._dl0.v - - v = module._dl1.v_with_top_v_key_chains() - assert "dl1" in v - assert v.dl1 is module._dl1.v + # height 2 + assert module.mod_height() == 2 - v = module._dl0._l0.v_with_top_v_key_chains() - assert "dl0" in v - assert "l0" in v.dl0 - assert v.dl0.l0 is 
module._dl0._l0.v - - v = module._dl0._l1.v_with_top_v_key_chains() - assert "dl0" in v - assert "l1" in v.dl0 - assert v.dl0.l1 is module._dl0._l1.v - - v = module._dl1._l0.v_with_top_v_key_chains() - assert "dl1" in v - assert "l0" in v.dl1 - assert v.dl1.l0 is module._dl1._l0.v - - v = module._dl1._l1.v_with_top_v_key_chains() - assert "dl1" in v - assert "l1" in v.dl1 - assert v.dl1.l1 is module._dl1._l1.v - - # depth 1 - - v = module._dl0._l0.v_with_top_v_key_chains(depth=1) - assert "l0" in v - assert v.l0 is module._dl0._l0.v - - v = module._dl0._l1.v_with_top_v_key_chains(depth=1) - assert "l1" in v - assert v.l1 is module._dl0._l1.v - - v = module._dl1._l0.v_with_top_v_key_chains(depth=1) - assert "l0" in v - assert v.l0 is module._dl1._l0.v - - v = module._dl1._l1.v_with_top_v_key_chains(depth=1) - assert "l1" in v - assert v.l1 is module._dl1._l1.v - - -# module depth -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_depth(batch_shape, input_channels, output_channels, on_device): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - - module = WithNestedModules(input_channels, output_channels, device=on_device) - - # depth 0 - assert module.mod_depth() == 0 - - # depth 1 - assert module._dl0.mod_depth() == 1 - assert module._dl1.mod_depth() == 1 - - # depth 2 - assert module._dl0._l0.mod_depth() == 2 - assert module._dl0._l1.mod_depth() == 2 - assert module._dl1._l0.mod_depth() == 2 - assert module._dl1._l1.mod_depth() == 2 - - -# module height -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_height(batch_shape, input_channels, output_channels, on_device): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - - module = WithNestedModules(input_channels, output_channels, device=on_device) - - # height 2 - assert module.mod_height() == 2 - - # height 1 - assert module._dl0.mod_height() == 1 - assert module._dl1.mod_height() == 1 + # height 1 + assert module._dl0.mod_height() == 1 + assert module._dl1.mod_height() == 1 # height 0 assert module._dl0._l0.mod_height() == 0 @@ -650,7 +324,6 @@ def test_module_height(batch_shape, input_channels, output_channels, on_device): assert module._dl1._l1.mod_height() == 0 -# sub modules @given( batch_shape=helpers.get_shape( min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 @@ -658,35 +331,49 @@ def test_module_height(batch_shape, input_channels, output_channels, on_device): input_channels=st.integers(min_value=2, max_value=5), output_channels=st.integers(min_value=2, max_value=5), ) -def test_sub_modules(batch_shape, input_channels, output_channels, on_device): +def test_module_save_and_load_as_pickled( + batch_shape, input_channels, output_channels, on_device +): + save_filepath = "module.pickled" + # smoke test if ivy.current_backend_str() == "numpy": # NumPy does not support gradients return + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), + "float32", + ) + module = TrainableModule(input_channels, output_channels, device=on_device) - module = WithNestedModules(input_channels, output_channels, 
device=on_device) + def loss_fn(v_): + out = module(x, v=v_) + return ivy.mean(out) - # depth 0 - sub_mods = module.sub_mods(depth=0) - assert module.v is sub_mods + module.save(save_filepath) + assert os.path.exists(save_filepath) + loaded_module = ivy.Module.load(save_filepath) - # depth 1 - sub_mods = module.sub_mods(depth=1) - for v in [module._dl0.v, module._dl1.v]: - assert v in sub_mods + # train + loss, grads = ivy.execute_with_gradients(loss_fn, module.v) + module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) - # depth 2 (full) - sub_mods = module.sub_mods() - for v in [ - module._dl0._l0.v, - module._dl0._l1.v, - module._dl1._l0.v, - module._dl1._l1.v, - ]: - assert v in sub_mods + loaded_loss, loaded_grads = ivy.execute_with_gradients(loss_fn, loaded_module.v) + loaded_module.v = ivy.gradient_descent_update(loaded_module.v, loaded_grads, 1e-3) + + # type test + assert ivy.is_array(loaded_loss) + assert isinstance(loaded_grads, ivy.Container) + # cardinality test + assert loaded_loss.shape == () + # value test + assert ivy.all_equal(loaded_loss == loss) + assert ivy.Container.all(loaded_module.v == module.v).cont_all_true() + + os.remove(save_filepath) -# track submod returns +# track submod call order @given( batch_shape=helpers.get_shape( min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 @@ -694,7 +381,7 @@ def test_sub_modules(batch_shape, input_channels, output_channels, on_device): input_channels=st.integers(min_value=2, max_value=5), output_channels=st.integers(min_value=2, max_value=5), ) -def test_module_track_submod_rets( +def test_module_track_submod_call_order( batch_shape, input_channels, output_channels, on_device ): # smoke test @@ -708,219 +395,58 @@ def test_module_track_submod_rets( ) module = WithNestedModules(input_channels, output_channels, device=on_device) - # depth 1 - ret = module(x, track_submod_rets=True, submod_depth=1) - assert ret.shape == tuple(list(batch_shape) + [64]) - sm_rets = module.submod_rets - for submod in [module._dl0, module._dl1]: - for ret in sm_rets[submod.get_mod_key()]: - assert isinstance(ret, np.ndarray) - assert ret.shape == tuple(list(batch_shape) + [64]) - for submod in [module._dl0._l0, module._dl0._l1, module._dl1._l0, module._dl1._l1]: - assert ( - ivy.Container.cont_flatten_key_chain(submod.__repr__(), "_") not in sm_rets - ) + root_key_0 = ivy.Container.cont_flatten_key_chain(module.__repr__(), "_") + "_0" - # depth 2 (full) - ret = module(x, track_submod_rets=True) - assert ret.shape == tuple(list(batch_shape) + [64]) - sm_rets = module.submod_rets - for submod in [ - module._dl0, - module._dl1, - module._dl0._l0, - module._dl0._l1, - module._dl1._l0, - module._dl1._l1, - ]: - for ret in sm_rets[submod.get_mod_key()]: - assert isinstance(ret, np.ndarray) - assert ret.shape == tuple(list(batch_shape) + [64]) + dl0_key_0 = ivy.Container.cont_flatten_key_chain(module._dl0.__repr__(), "_") + "_0" + dl1_key_0 = ivy.Container.cont_flatten_key_chain(module._dl1.__repr__(), "_") + "_0" + dl1_key_1 = ivy.Container.cont_flatten_key_chain(module._dl1.__repr__(), "_") + "_1" - # partial submodules - ret = module( - x, track_submod_rets=True, submods_to_track=[module._dl1, module._dl0._l0] + dl0_l0_key_0 = ( + ivy.Container.cont_flatten_key_chain(module._dl0._l0.__repr__(), "_") + "_0" + ) + dl0_l1_key_0 = ( + ivy.Container.cont_flatten_key_chain(module._dl0._l1.__repr__(), "_") + "_0" + ) + dl1_l0_key_0 = ( + ivy.Container.cont_flatten_key_chain(module._dl1._l0.__repr__(), "_") + "_0" + ) + dl1_l1_key_0 = ( + 
ivy.Container.cont_flatten_key_chain(module._dl1._l1.__repr__(), "_") + "_0" ) + + # depth 1 + ret = module(x, track_submod_call_order=True, submod_depth=1) assert ret.shape == tuple(list(batch_shape) + [64]) - sm_rets = module.submod_rets - for submod in [module._dl1, module._dl0._l0]: - for ret in sm_rets[submod.get_mod_key()]: - assert isinstance(ret, np.ndarray) - assert ret.shape == tuple(list(batch_shape) + [64]) - for submod in [module._dl0, module._dl0._l1, module._dl1._l0, module._dl1._l1]: - assert ( - ivy.Container.cont_flatten_key_chain(submod.__repr__(), "_") not in sm_rets - ) + sm_co = module.submod_call_order -# check submod returns -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_check_submod_rets( - batch_shape, input_channels, output_channels, on_device -): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return + assert root_key_0 in sm_co - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), - "float32", + assert dl0_key_0 in sm_co[root_key_0] + assert dl1_key_0 in sm_co[root_key_0] + assert dl1_key_1 in sm_co[root_key_0] + + assert ivy.Container.cont_identical( + [ + sm_co[root_key_0][dl0_key_0], + module._dl0.v.cont_flatten_key_chains().to_numpy(), + ] + ) + assert ivy.Container.cont_identical( + [ + sm_co[root_key_0][dl1_key_0], + module._dl1.v.cont_flatten_key_chains().to_numpy(), + ] + ) + assert ivy.Container.cont_identical( + [ + sm_co[root_key_0][dl1_key_1], + module._dl1.v.cont_flatten_key_chains().to_numpy(), + ] ) - module = WithNestedModules(input_channels, output_channels, device=on_device) - # depth 1 - ret = module(x, track_submod_rets=True, submod_depth=1) - assert ret.shape == tuple(list(batch_shape) + [64]) - sm_rets = module.submod_rets - module(x, expected_submod_rets=sm_rets) - sm_rets.random_uniform(map_sequences=True) - try: - module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) - raise Exception( - "forward pass succeeded despite passing random expected_submod_rets, " - "assertion error expected." - ) - except ivy.utils.exceptions.IvyException: - pass - - # depth 2 (full) - ret = module(x, track_submod_rets=True) - assert ret.shape == tuple(list(batch_shape) + [64]) - sm_rets = module.submod_rets - module(x, expected_submod_rets=sm_rets) - try: - module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) - raise Exception( - "forward pass succeeded despite passing random expected_submod_rets, " - "assertion error expected." - ) - except ivy.utils.exceptions.IvyException: - pass - - # partial submodules - ret = module( - x, track_submod_rets=True, submods_to_track=[module._dl1, module._dl0._l0] - ) - assert ret.shape == tuple(list(batch_shape) + [64]) - sm_rets = module.submod_rets - module(x, expected_submod_rets=sm_rets) - try: - module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) - raise Exception( - "forward pass succeeded despite passing random expected_submod_rets, " - "assertion error expected." 
- ) - except ivy.utils.exceptions.IvyException: - pass - - # with tolerances - ret = module(x, track_submod_rets=True) - assert ret.shape == tuple(list(batch_shape) + [64]) - sm_rets_orig = module.submod_rets - sm_rets = ivy.Container( - { - k: {"val": v, "atol": [1e-8] * len(v), "rtol": [1e-5] * len(v)} - for k, v in sm_rets_orig.items() - }, - **sm_rets_orig._config - ) - module(x, expected_submod_rets=sm_rets) - sm_rets = ivy.Container( - {k: {"val": v, "atol": 1e-8, "rtol": 1e-5} for k, v in sm_rets_orig.items()}, - **sm_rets_orig._config - ) - module(x, expected_submod_rets=sm_rets) - try: - module(x, expected_submod_rets=sm_rets.random_uniform(map_sequences=True)) - raise Exception( - "forward pass succeeded despite passing random expected_submod_rets, " - "assertion error expected." - ) - except ivy.utils.exceptions.IvyException: - pass - - -# track submod call order -@given( - batch_shape=helpers.get_shape( - min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 - ), - input_channels=st.integers(min_value=2, max_value=5), - output_channels=st.integers(min_value=2, max_value=5), -) -def test_module_track_submod_call_order( - batch_shape, input_channels, output_channels, on_device -): - # smoke test - if ivy.current_backend_str() == "numpy": - # NumPy does not support gradients - return - - x = ivy.astype( - ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), - "float32", - ) - module = WithNestedModules(input_channels, output_channels, device=on_device) - - root_key_0 = ivy.Container.cont_flatten_key_chain(module.__repr__(), "_") + "_0" - - dl0_key_0 = ivy.Container.cont_flatten_key_chain(module._dl0.__repr__(), "_") + "_0" - dl1_key_0 = ivy.Container.cont_flatten_key_chain(module._dl1.__repr__(), "_") + "_0" - dl1_key_1 = ivy.Container.cont_flatten_key_chain(module._dl1.__repr__(), "_") + "_1" - - dl0_l0_key_0 = ( - ivy.Container.cont_flatten_key_chain(module._dl0._l0.__repr__(), "_") + "_0" - ) - dl0_l1_key_0 = ( - ivy.Container.cont_flatten_key_chain(module._dl0._l1.__repr__(), "_") + "_0" - ) - dl1_l0_key_0 = ( - ivy.Container.cont_flatten_key_chain(module._dl1._l0.__repr__(), "_") + "_0" - ) - dl1_l1_key_0 = ( - ivy.Container.cont_flatten_key_chain(module._dl1._l1.__repr__(), "_") + "_0" - ) - - # depth 1 - ret = module(x, track_submod_call_order=True, submod_depth=1) - assert ret.shape == tuple(list(batch_shape) + [64]) - - sm_co = module.submod_call_order - - assert root_key_0 in sm_co - - assert dl0_key_0 in sm_co[root_key_0] - assert dl1_key_0 in sm_co[root_key_0] - assert dl1_key_1 in sm_co[root_key_0] - - assert ivy.Container.cont_identical( - [ - sm_co[root_key_0][dl0_key_0], - module._dl0.v.cont_flatten_key_chains().to_numpy(), - ] - ) - assert ivy.Container.cont_identical( - [ - sm_co[root_key_0][dl1_key_0], - module._dl1.v.cont_flatten_key_chains().to_numpy(), - ] - ) - assert ivy.Container.cont_identical( - [ - sm_co[root_key_0][dl1_key_1], - module._dl1.v.cont_flatten_key_chains().to_numpy(), - ] - ) - - # depth 2 (full) - ret = module(x, track_submod_call_order=True) + # depth 2 (full) + ret = module(x, track_submod_call_order=True) assert ret.shape == tuple(list(batch_shape) + [64]) sm_co = module.submod_call_order @@ -1012,6 +538,7 @@ def test_module_track_submod_call_order( ) +# track submod returns @given( batch_shape=helpers.get_shape( min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 @@ -1019,11 +546,74 @@ def test_module_track_submod_call_order( input_channels=st.integers(min_value=2, max_value=5), 
output_channels=st.integers(min_value=2, max_value=5), ) -def test_module_save_and_load_as_pickled( +def test_module_track_submod_rets( batch_shape, input_channels, output_channels, on_device ): - save_filepath = "module.pickled" + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), + "float32", + ) + module = WithNestedModules(input_channels, output_channels, device=on_device) + + # depth 1 + ret = module(x, track_submod_rets=True, submod_depth=1) + assert ret.shape == tuple(list(batch_shape) + [64]) + sm_rets = module.submod_rets + for submod in [module._dl0, module._dl1]: + for ret in sm_rets[submod.get_mod_key()]: + assert isinstance(ret, np.ndarray) + assert ret.shape == tuple(list(batch_shape) + [64]) + for submod in [module._dl0._l0, module._dl0._l1, module._dl1._l0, module._dl1._l1]: + assert ( + ivy.Container.cont_flatten_key_chain(submod.__repr__(), "_") not in sm_rets + ) + + # depth 2 (full) + ret = module(x, track_submod_rets=True) + assert ret.shape == tuple(list(batch_shape) + [64]) + sm_rets = module.submod_rets + for submod in [ + module._dl0, + module._dl1, + module._dl0._l0, + module._dl0._l1, + module._dl1._l0, + module._dl1._l1, + ]: + for ret in sm_rets[submod.get_mod_key()]: + assert isinstance(ret, np.ndarray) + assert ret.shape == tuple(list(batch_shape) + [64]) + + # partial submodules + ret = module( + x, track_submod_rets=True, submods_to_track=[module._dl1, module._dl0._l0] + ) + assert ret.shape == tuple(list(batch_shape) + [64]) + sm_rets = module.submod_rets + for submod in [module._dl1, module._dl0._l0]: + for ret in sm_rets[submod.get_mod_key()]: + assert isinstance(ret, np.ndarray) + assert ret.shape == tuple(list(batch_shape) + [64]) + for submod in [module._dl0, module._dl0._l1, module._dl1._l0, module._dl1._l1]: + assert ( + ivy.Container.cont_flatten_key_chain(submod.__repr__(), "_") not in sm_rets + ) + +# module training +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_module_training(batch_shape, input_channels, output_channels, on_device): # smoke test if ivy.current_backend_str() == "numpy": # NumPy does not support gradients @@ -1038,65 +628,388 @@ def loss_fn(v_): out = module(x, v=v_) return ivy.mean(out) - module.save(save_filepath) - assert os.path.exists(save_filepath) - loaded_module = ivy.Module.load(save_filepath) - # train - loss, grads = ivy.execute_with_gradients(loss_fn, module.v) - module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) - - loaded_loss, loaded_grads = ivy.execute_with_gradients(loss_fn, loaded_module.v) - loaded_module.v = ivy.gradient_descent_update(loaded_module.v, loaded_grads, 1e-3) + loss_tm1 = 1e12 + loss = None + grads = None + for i in range(10): + loss, grads = ivy.execute_with_gradients(loss_fn, module.v) + module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) + assert loss < loss_tm1 + loss_tm1 = loss # type test - assert ivy.is_array(loaded_loss) - assert isinstance(loaded_grads, ivy.Container) + assert ivy.is_array(loss) + assert isinstance(grads, ivy.Container) # cardinality test - assert loaded_loss.shape == () + assert loss.shape == () # value test - assert ivy.all_equal(loaded_loss == loss) - assert ivy.Container.all(loaded_module.v == 
module.v).cont_all_true() - - os.remove(save_filepath) - - -class ModuleWithBuffer(ivy.Module): - def __init__(self, *args, **kwargs): - pass - - def _forward(*args, **kwargs): - pass + assert ivy.max(ivy.abs(grads.linear0.b)) > 0 + assert ivy.max(ivy.abs(grads.linear0.w)) > 0 + assert ivy.max(ivy.abs(grads.linear1.b)) > 0 + assert ivy.max(ivy.abs(grads.linear1.w)) > 0 + assert ivy.max(ivy.abs(grads.linear2.b)) > 0 + assert ivy.max(ivy.abs(grads.linear2.w)) > 0 + # compilation test + if ivy.current_backend_str() == "torch": + # pytest scripting does not support **kwargs + return +# module training with duplicate @given( - buffer=st.just( - [ - { - "var1": [ - ivy.ones((1, 2)), - ] - } - ] - ) -) -def test_get_buffers(buffer): - module = ModuleWithBuffer() - buffers = {} - for item in buffer: - buffers.update(item) - for key in item: - module.register_buffer(key, item[key]) + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + channels=st.integers(min_value=1, max_value=64), + same_layer=st.booleans(), +) +def test_module_training_with_duplicate(batch_shape, channels, same_layer, on_device): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), channels), "float32" + ) + module = TrainableModuleWithDuplicate(channels, same_layer, device=on_device) - assert module.buffers == buffers + def loss_fn(v_): + out = module(x, v=v_) + return ivy.mean(out) + + # train + loss_tm1 = 1e12 + loss = None + grads = None + for i in range(10): + loss, grads = ivy.execute_with_gradients(loss_fn, module.v) + module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) + assert loss < loss_tm1 + loss_tm1 = loss + # type test + assert ivy.is_array(loss) + assert isinstance(grads, ivy.Container) + # cardinality test + assert loss.shape == () + # value test + assert ivy.max(ivy.abs(grads.linear0.b)) > 0 + assert ivy.max(ivy.abs(grads.linear0.w)) > 0 + if not same_layer: + assert ivy.max(ivy.abs(grads.linear1.b)) > 0 + # compilation test + if ivy.current_backend_str() == "torch": + # pytest scripting does not support **kwargs + return -class ModuleWithTrainEval(ivy.Module): - def __init__(self): - super().__init__() - def _forward(): +# module with dict training +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_module_w_dict_training( + batch_shape, input_channels, output_channels, on_device +): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), + "float32", + ) + module = TrainableModuleWithDict(input_channels, output_channels, device=on_device) + + def loss_fn(v_): + out = module(x, v=v_) + return ivy.mean(out) + + # train + loss_tm1 = 1e12 + loss = None + grads = None + for i in range(10): + loss, grads = ivy.execute_with_gradients(loss_fn, module.v) + module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) + assert loss < loss_tm1 + loss_tm1 = loss + + # type test + assert ivy.is_array(loss) + assert isinstance(grads, ivy.Container) + # cardinality test + assert loss.shape == () + # value test + assert ivy.max(ivy.abs(grads.layers.linear0.b)) > 0 + assert ivy.max(ivy.abs(grads.layers.linear0.w)) 
> 0 + assert ivy.max(ivy.abs(grads.layers.linear1.b)) > 0 + assert ivy.max(ivy.abs(grads.layers.linear1.w)) > 0 + assert ivy.max(ivy.abs(grads.layers.linear2.b)) > 0 + assert ivy.max(ivy.abs(grads.layers.linear2.w)) > 0 + # compilation test + if ivy.current_backend_str() == "torch": + # pytest scripting does not support **kwargs + return + + +# module with list training +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_module_w_list_training( + batch_shape, input_channels, output_channels, on_device +): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), + "float32", + ) + module = TrainableModuleWithList(input_channels, output_channels, device=on_device) + + def loss_fn(v_): + out = module(x, v=v_) + return ivy.mean(out) + + # train + loss_tm1 = 1e12 + loss = None + grads = None + for i in range(10): + loss, grads = ivy.execute_with_gradients(loss_fn, module.v) + module.v = ivy.gradient_descent_update(module.v, grads, 1e-3) + assert loss < loss_tm1 + loss_tm1 = loss + + # type test + assert ivy.is_array(loss) + assert isinstance(grads, ivy.Container) + # cardinality test + assert loss.shape == () + # value test + assert ivy.max(ivy.abs(grads.layers.v0.b)) > 0 + assert ivy.max(ivy.abs(grads.layers.v0.w)) > 0 + assert ivy.max(ivy.abs(grads.layers.v1.b)) > 0 + assert ivy.max(ivy.abs(grads.layers.v1.w)) > 0 + assert ivy.max(ivy.abs(grads.layers.v2.b)) > 0 + assert ivy.max(ivy.abs(grads.layers.v2.w)) > 0 + # compilation test + if ivy.current_backend_str() == "torch": + # pytest scripting does not support **kwargs + return + + +# module with none attribute +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_module_w_none_attribute( + batch_shape, input_channels, output_channels, on_device +): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), + "float32", + ) + module = ModuleWithNoneAttribute(device=on_device) + module(x) + + +# module with partial v +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_module_w_partial_v(batch_shape, input_channels, output_channels, on_device): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + + return + + x = ivy.astype( + ivy.linspace(ivy.zeros(batch_shape), ivy.ones(batch_shape), input_channels), + "float32", + ) + v = ivy.Container( + { + "linear0": { + "b": _variable(ivy.random_uniform(shape=[64])), + "w": _variable(ivy.random_uniform(shape=[64, 4])), + }, + "linear1": { + "b": _variable(ivy.random_uniform(shape=[64])), + "w": _variable(ivy.random_uniform(shape=[64, 64])), + "extra": _variable(ivy.random_uniform(shape=[64, 64])), + }, + "linear2": { + "b": _variable(ivy.random_uniform(shape=[5])), + "w": _variable(ivy.random_uniform(shape=[5, 64])), + }, + } + ) + 
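+    # note: the "extra" key and the wrongly shaped "w"/"b" entries above mean this
+    # variable container cannot match the module's structure, so the construction
+    # below is expected to raise an IvyException even though with_partial_v=True
+    # permits missing (but not mis-shaped or extra) variables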
try: + TrainableModule( + input_channels, output_channels, device=on_device, v=v, with_partial_v=True + ) + raise Exception( + "TrainableModule did not raise exception desipite being passed " + "with wrongly shaped variables." + ) + except ivy.utils.exceptions.IvyException: + pass + v = ivy.Container( + { + "linear0": { + "b": _variable(ivy.random_uniform(shape=[64])), + }, + "linear1": {"w": _variable(ivy.random_uniform(shape=[64, 64]))}, + "linear2": {"b": _variable(ivy.random_uniform(shape=[output_channels]))}, + } + ) + try: + TrainableModule(input_channels, output_channels, device=on_device, v=v) + raise Exception( + "TrainableModule did not raise exception desipite being passed " + "with wrongly shaped variables." + ) + except ivy.utils.exceptions.IvyException: pass + module = TrainableModule( + input_channels, output_channels, device=on_device, v=v, with_partial_v=True + ) + module(x) + + +# sub modules +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_sub_modules(batch_shape, input_channels, output_channels, on_device): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + + module = WithNestedModules(input_channels, output_channels, device=on_device) + + # depth 0 + sub_mods = module.sub_mods(depth=0) + assert module.v is sub_mods + + # depth 1 + sub_mods = module.sub_mods(depth=1) + for v in [module._dl0.v, module._dl1.v]: + assert v in sub_mods + + # depth 2 (full) + sub_mods = module.sub_mods() + for v in [ + module._dl0._l0.v, + module._dl0._l1.v, + module._dl1._l0.v, + module._dl1._l1.v, + ]: + assert v in sub_mods + + +# top module +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_top_module(batch_shape, input_channels, output_channels, on_device): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + + module = WithNestedModules(input_channels, output_channels, device=on_device) + + # full depth + assert module._dl0.top_mod() is module + assert module._dl1.top_mod() is module + + assert module._dl0._l0.top_mod() is module + assert module._dl0._l1.top_mod() is module + assert module._dl1._l0.top_mod() is module + assert module._dl1._l1.top_mod() is module + + # depth 1 + assert module._dl0._l0.top_mod(1) is module._dl0 + assert module._dl0._l1.top_mod(1) is module._dl0 + assert module._dl1._l0.top_mod(1) is module._dl1 + assert module._dl1._l1.top_mod(1) is module._dl1 + + +# top variables +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_top_variables(batch_shape, input_channels, output_channels, on_device): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + module = WithNestedModules(input_channels, output_channels, device=on_device) + for key_chain in [ + "dl0", + "dl0/l0", + "dl0/l1", + "dl0/l0/b", + "dl0/l0/w", + "dl0/l1/b", + "dl0/l1/w", + "dl1", + "dl1/l0", + "dl1/l1", + "dl1/l0/b", + "dl1/l0/w", + "dl1/l1/b", + "dl1/l1/w", + ]: + # depth 1 + assert key_chain 
in module._dl0.top_v() + assert key_chain in module._dl1.top_v() + + # depth 2 + assert key_chain in module._dl0._l0.top_v() + assert key_chain in module._dl0._l1.top_v() + assert key_chain in module._dl1._l0.top_v() + assert key_chain in module._dl1._l1.top_v() @given(mode=st.booleans()) @@ -1106,3 +1019,90 @@ def test_train_eval(mode): assert mode == cls.training cls.eval() assert False == cls.training + + +# v with top v key chains +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_v_with_top_v_key_chains( + batch_shape, input_channels, output_channels, on_device +): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + + module = WithNestedModules(input_channels, output_channels, device=on_device) + + # full depth + v = module._dl0.v_with_top_v_key_chains() + assert "dl0" in v + assert v.dl0 is module._dl0.v + + v = module._dl1.v_with_top_v_key_chains() + assert "dl1" in v + assert v.dl1 is module._dl1.v + + v = module._dl0._l0.v_with_top_v_key_chains() + assert "dl0" in v + assert "l0" in v.dl0 + assert v.dl0.l0 is module._dl0._l0.v + + v = module._dl0._l1.v_with_top_v_key_chains() + assert "dl0" in v + assert "l1" in v.dl0 + assert v.dl0.l1 is module._dl0._l1.v + + v = module._dl1._l0.v_with_top_v_key_chains() + assert "dl1" in v + assert "l0" in v.dl1 + assert v.dl1.l0 is module._dl1._l0.v + + v = module._dl1._l1.v_with_top_v_key_chains() + assert "dl1" in v + assert "l1" in v.dl1 + assert v.dl1.l1 is module._dl1._l1.v + + # depth 1 + + v = module._dl0._l0.v_with_top_v_key_chains(depth=1) + assert "l0" in v + assert v.l0 is module._dl0._l0.v + + v = module._dl0._l1.v_with_top_v_key_chains(depth=1) + assert "l1" in v + assert v.l1 is module._dl0._l1.v + + v = module._dl1._l0.v_with_top_v_key_chains(depth=1) + assert "l0" in v + assert v.l0 is module._dl1._l0.v + + v = module._dl1._l1.v_with_top_v_key_chains(depth=1) + assert "l1" in v + assert v.l1 is module._dl1._l1.v + + +# with custom var structure +@given( + batch_shape=helpers.get_shape( + min_num_dims=2, max_num_dims=2, min_dim_size=1, max_dim_size=2 + ), + input_channels=st.integers(min_value=2, max_value=5), + output_channels=st.integers(min_value=2, max_value=5), +) +def test_with_custom_var_structure( + batch_shape, input_channels, output_channels, on_device +): + # smoke test + if ivy.current_backend_str() == "numpy": + # NumPy does not support gradients + return + module = WithCustomVarStructure(input_channels, output_channels, device=on_device) + assert "x" in module.v + assert "y" in module.v + assert "z" in module.v diff --git a/ivy_tests/test_ivy/test_stateful/test_norms.py b/ivy_tests/test_ivy/test_stateful/test_norms.py index 54791b8f92dd9..ad44f5814eb7b 100644 --- a/ivy_tests/test_ivy/test_stateful/test_norms.py +++ b/ivy_tests/test_ivy/test_stateful/test_norms.py @@ -9,19 +9,43 @@ from ivy_tests.test_ivy.helpers import handle_method +# --- Helpers --- # +# --------------- # + + +@st.composite +def _generate_batchnorm_data(draw): + batch_size = draw(st.integers(min_value=2, max_value=5)) + num_features = draw(st.integers(min_value=2, max_value=3)) + num_dims = draw(st.integers(min_value=1, max_value=3)) + dims = [draw(st.integers(1, 5)) for i in range(num_dims)] + x_shape = [batch_size] + [*dims] + [num_features] + dtype, inputs = draw( + helpers.dtype_and_values( + 
available_dtypes=helpers.get_dtypes("float", full=True), + shape=x_shape, + min_value=0, + max_value=1, + ).filter(lambda x: x[0][0] not in ["float64"]) + ) + return dtype, inputs, num_features + + +# --- Main --- # +# ------------ # + + @handle_method( - method_tree="LayerNorm.__call__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), min_num_dims=2 - ), - new_std=st.floats(min_value=0.0, max_value=1.0), + method_tree="BatchNorm2D.__call__", + dtype_and_x_features=_generate_batchnorm_data(), + momentum=st.floats(min_value=0.0, max_value=1.0, exclude_min=True), init_with_v=st.booleans(), method_with_v=st.booleans(), ) -def test_layer_norm_layer( +def test_batch_norm_2d_layer( *, - dtype_and_x, - new_std, + dtype_and_x_features, + momentum, init_with_v, method_with_v, test_gradients, @@ -33,17 +57,18 @@ def test_layer_norm_layer( init_flags, method_flags, ): - input_dtype, x = dtype_and_x + input_dtype, x, features = dtype_and_x_features helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "normalized_shape": x[0].shape, + "num_features": features, "eps": ivy.min_base, - "elementwise_affine": True, - "new_std": new_std, + "affine": True, + "momentum": momentum, + "track_running_stats": True, "device": on_device, "dtype": input_dtype[0], }, @@ -54,39 +79,25 @@ def test_layer_norm_layer( init_with_v=init_with_v, method_with_v=method_with_v, test_gradients=test_gradients, + rtol_=1e-02, + atol_=1e-02, on_device=on_device, ) -@st.composite -def _generate_batchnorm_data(draw): - batch_size = draw(st.integers(min_value=2, max_value=5)) - num_features = draw(st.integers(min_value=2, max_value=3)) - num_dims = draw(st.integers(min_value=1, max_value=3)) - dims = [draw(st.integers(1, 5)) for i in range(num_dims)] - x_shape = [batch_size] + [*dims] + [num_features] - dtype, inputs = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float", full=True), - shape=x_shape, - min_value=0, - max_value=1, - ).filter(lambda x: x[0][0] not in ["float64"]) - ) - return dtype, inputs, num_features - - @handle_method( - method_tree="BatchNorm2D.__call__", - dtype_and_x_features=_generate_batchnorm_data(), - momentum=st.floats(min_value=0.0, max_value=1.0, exclude_min=True), + method_tree="LayerNorm.__call__", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), min_num_dims=2 + ), + new_std=st.floats(min_value=0.0, max_value=1.0), init_with_v=st.booleans(), method_with_v=st.booleans(), ) -def test_batch_norm_2d_layer( +def test_layer_norm_layer( *, - dtype_and_x_features, - momentum, + dtype_and_x, + new_std, init_with_v, method_with_v, test_gradients, @@ -98,18 +109,17 @@ def test_batch_norm_2d_layer( init_flags, method_flags, ): - input_dtype, x, features = dtype_and_x_features + input_dtype, x = dtype_and_x helpers.test_method( backend_to_test=backend_fw, ground_truth_backend=ground_truth_backend, init_flags=init_flags, method_flags=method_flags, init_all_as_kwargs_np={ - "num_features": features, + "normalized_shape": x[0].shape, "eps": ivy.min_base, - "affine": True, - "momentum": momentum, - "track_running_stats": True, + "elementwise_affine": True, + "new_std": new_std, "device": on_device, "dtype": input_dtype[0], }, @@ -120,7 +130,5 @@ def test_batch_norm_2d_layer( init_with_v=init_with_v, method_with_v=method_with_v, test_gradients=test_gradients, - rtol_=1e-02, - atol_=1e-02, 
on_device=on_device, ) diff --git a/ivy_tests/test_ivy/test_stateful/test_optimizers.py b/ivy_tests/test_ivy/test_stateful/test_optimizers.py index c10c41d477d81..9b566531a7e73 100644 --- a/ivy_tests/test_ivy/test_stateful/test_optimizers.py +++ b/ivy_tests/test_ivy/test_stateful/test_optimizers.py @@ -1,4 +1,3 @@ -# For Review """Collection of tests for Ivy optimizers.""" # global @@ -12,21 +11,28 @@ ) -# sgd +# adam @handle_method( - method_tree="SGD._step", + method_tree="Adam._step", dtype_x_lr=get_gradient_arguments_with_lr( - min_value=-1e5, - max_value=1e5, + min_value=1e-05, + max_value=1e08, num_arrays=2, float_lr=True, + large_abs_safety_factor=2, + small_abs_safety_factor=2, + ), + beta1_n_beta2_n_epsilon=helpers.list_of_size( + x=helpers.floats(min_value=1e-1, max_value=1), + size=3, ), inplace=st.booleans(), stop_gradients=st.booleans(), test_gradients=st.just(True), ) -def test_sgd_optimizer( +def test_adam_optimizer( dtype_x_lr, + beta1_n_beta2_n_epsilon, inplace, stop_gradients, on_device, @@ -34,11 +40,12 @@ def test_sgd_optimizer( method_name, backend_fw, ground_truth_backend, + test_gradients, init_flags, method_flags, - test_gradients, ): input_dtype, x, lr = dtype_x_lr + beta1, beta2, epsilon = beta1_n_beta2_n_epsilon xs_grad_idxs = [[0, 0]] if method_flags.num_positional_args else [[1, "v"]] helpers.test_method( backend_to_test=backend_fw, @@ -47,6 +54,9 @@ def test_sgd_optimizer( method_flags=method_flags, init_all_as_kwargs_np={ "lr": lr, + "beta1": beta1, + "beta2": beta2, + "epsilon": epsilon, "inplace": inplace, "stop_gradients": stop_gradients, }, @@ -57,26 +67,42 @@ def test_sgd_optimizer( }, class_name=class_name, method_name=method_name, - rtol_=1e-2, - atol_=1e-2, + rtol_=1e-1, + atol_=1e-1, test_gradients=test_gradients, xs_grad_idxs=xs_grad_idxs, on_device=on_device, ) -# lars +# lamb @handle_method( - method_tree="LARS._step", - dtype_x_lr=get_gradient_arguments_with_lr(num_arrays=2, float_lr=True), + method_tree="LAMB._step", + dtype_x_lr=get_gradient_arguments_with_lr( + min_value=-1e5, + max_value=1e5, + num_arrays=2, + float_lr=True, + ), + beta1_n_beta2_n_epsilon_n_lambda=helpers.list_of_size( + x=helpers.floats( + min_value=1e-2, + max_value=1.0, + ), + size=4, + ), + mtr=st.one_of( + helpers.ints(min_value=1, max_value=10), + st.floats(min_value=1e-2, max_value=10, exclude_min=True), + ), inplace=st.booleans(), - decay_lambda=helpers.floats(min_value=1e-2, max_value=1.0), stop_gradients=st.booleans(), test_gradients=st.just(True), ) -def test_lars_optimizer( +def test_lamb_optimizer( dtype_x_lr, - decay_lambda, + beta1_n_beta2_n_epsilon_n_lambda, + mtr, inplace, stop_gradients, on_device, @@ -89,8 +115,7 @@ def test_lars_optimizer( test_gradients, ): input_dtype, x, lr = dtype_x_lr - if "bfloat16" in input_dtype: - test_gradients = False + beta1, beta2, epsilon, decay_lambda = beta1_n_beta2_n_epsilon_n_lambda xs_grad_idxs = [[0, 0]] if method_flags.num_positional_args else [[1, "v"]] helpers.test_method( backend_to_test=backend_fw, @@ -99,6 +124,10 @@ def test_lars_optimizer( method_flags=method_flags, init_all_as_kwargs_np={ "lr": lr, + "beta1": beta1, + "beta2": beta2, + "epsilon": epsilon, + "max_trust_ratio": mtr, "decay_lambda": decay_lambda, "inplace": inplace, "stop_gradients": stop_gradients, @@ -118,28 +147,18 @@ def test_lars_optimizer( ) -# adam +# lars @handle_method( - method_tree="Adam._step", - dtype_x_lr=get_gradient_arguments_with_lr( - min_value=1e-05, - max_value=1e08, - num_arrays=2, - float_lr=True, - large_abs_safety_factor=2, 
- small_abs_safety_factor=2, - ), - beta1_n_beta2_n_epsilon=helpers.list_of_size( - x=helpers.floats(min_value=1e-1, max_value=1), - size=3, - ), + method_tree="LARS._step", + dtype_x_lr=get_gradient_arguments_with_lr(num_arrays=2, float_lr=True), inplace=st.booleans(), + decay_lambda=helpers.floats(min_value=1e-2, max_value=1.0), stop_gradients=st.booleans(), test_gradients=st.just(True), ) -def test_adam_optimizer( +def test_lars_optimizer( dtype_x_lr, - beta1_n_beta2_n_epsilon, + decay_lambda, inplace, stop_gradients, on_device, @@ -147,12 +166,13 @@ def test_adam_optimizer( method_name, backend_fw, ground_truth_backend, - test_gradients, init_flags, method_flags, + test_gradients, ): input_dtype, x, lr = dtype_x_lr - beta1, beta2, epsilon = beta1_n_beta2_n_epsilon + if "bfloat16" in input_dtype: + test_gradients = False xs_grad_idxs = [[0, 0]] if method_flags.num_positional_args else [[1, "v"]] helpers.test_method( backend_to_test=backend_fw, @@ -161,9 +181,7 @@ def test_adam_optimizer( method_flags=method_flags, init_all_as_kwargs_np={ "lr": lr, - "beta1": beta1, - "beta2": beta2, - "epsilon": epsilon, + "decay_lambda": decay_lambda, "inplace": inplace, "stop_gradients": stop_gradients, }, @@ -182,34 +200,21 @@ def test_adam_optimizer( ) -# lamb +# sgd @handle_method( - method_tree="LAMB._step", + method_tree="SGD._step", dtype_x_lr=get_gradient_arguments_with_lr( min_value=-1e5, max_value=1e5, num_arrays=2, float_lr=True, ), - beta1_n_beta2_n_epsilon_n_lambda=helpers.list_of_size( - x=helpers.floats( - min_value=1e-2, - max_value=1.0, - ), - size=4, - ), - mtr=st.one_of( - helpers.ints(min_value=1, max_value=10), - st.floats(min_value=1e-2, max_value=10, exclude_min=True), - ), inplace=st.booleans(), stop_gradients=st.booleans(), test_gradients=st.just(True), ) -def test_lamb_optimizer( +def test_sgd_optimizer( dtype_x_lr, - beta1_n_beta2_n_epsilon_n_lambda, - mtr, inplace, stop_gradients, on_device, @@ -222,7 +227,6 @@ def test_lamb_optimizer( test_gradients, ): input_dtype, x, lr = dtype_x_lr - beta1, beta2, epsilon, decay_lambda = beta1_n_beta2_n_epsilon_n_lambda xs_grad_idxs = [[0, 0]] if method_flags.num_positional_args else [[1, "v"]] helpers.test_method( backend_to_test=backend_fw, @@ -231,11 +235,6 @@ def test_lamb_optimizer( method_flags=method_flags, init_all_as_kwargs_np={ "lr": lr, - "beta1": beta1, - "beta2": beta2, - "epsilon": epsilon, - "max_trust_ratio": mtr, - "decay_lambda": decay_lambda, "inplace": inplace, "stop_gradients": stop_gradients, }, @@ -246,8 +245,8 @@ def test_lamb_optimizer( }, class_name=class_name, method_name=method_name, - rtol_=1e-1, - atol_=1e-1, + rtol_=1e-2, + atol_=1e-2, test_gradients=test_gradients, xs_grad_idxs=xs_grad_idxs, on_device=on_device, diff --git a/ivy_tests/test_ivy/test_stateful/test_sequential.py b/ivy_tests/test_ivy/test_stateful/test_sequential.py index 957f0c72ce2f9..1a9223c309987 100644 --- a/ivy_tests/test_ivy/test_stateful/test_sequential.py +++ b/ivy_tests/test_ivy/test_stateful/test_sequential.py @@ -10,6 +10,27 @@ from ivy_tests.test_ivy.helpers.testing_helpers import handle_method +class TrainableModule(ivy.Module): + def __init__(self, in_size, hidden_size, out_size): + self._linear0 = ivy.Linear(in_size, hidden_size) + self._linear1 = ivy.Linear(hidden_size, out_size) + ivy.Module.__init__(self) + + def _forward(self, x): + x = self._linear0(x) + return self._linear1(x) + + +# --- Helpers --- # +# --------------- # + + +def _copy_weights(v1, v2): + # copy weights from layer1 to layer2 + v2.w = ivy.copy_array(v1.w) 
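+    # the bias is copied as well, so the target layer ends up with parameters
+    # identical to the source layer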
+ v2.b = ivy.copy_array(v1.b) + + # Helpers # ########### def _train(module, input_arr): @@ -31,10 +52,8 @@ def loss_fn(_v): return losses -def _copy_weights(v1, v2): - # copy weights from layer1 to layer2 - v2.w = ivy.copy_array(v1.w) - v2.b = ivy.copy_array(v1.b) +# --- Main --- # +# ------------ # @handle_method( @@ -69,17 +88,6 @@ def test_sequential_construction_and_value( _train(module, input_array) -class TrainableModule(ivy.Module): - def __init__(self, in_size, hidden_size, out_size): - self._linear0 = ivy.Linear(in_size, hidden_size) - self._linear1 = ivy.Linear(hidden_size, out_size) - ivy.Module.__init__(self) - - def _forward(self, x): - x = self._linear0(x) - return self._linear1(x) - - @handle_method( method_tree="Sequential.__call__", input_array=st.lists( From e9a6e88eab99aaf2aa19e47d8898d5695defe360 Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Tue, 29 Aug 2023 13:31:35 +0100 Subject: [PATCH 47/55] add(ivy): Adds paddle backend implementation to ivy.prod. --- ivy/functional/backends/paddle/statistical.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/paddle/statistical.py b/ivy/functional/backends/paddle/statistical.py index 41080a6b9032c..2d3019e03eae6 100644 --- a/ivy/functional/backends/paddle/statistical.py +++ b/ivy/functional/backends/paddle/statistical.py @@ -137,9 +137,17 @@ def prod( keepdims: bool = False, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: - raise IvyNotImplementedException() - # TODO:prod causes segmentation fault - return paddle.prod(x, axis=axis, keepdim=keepdims, dtype=dtype) + x_dtype= x.dtype + supported_dtypes = ['int32', 'int64', 'float32', 'float64'] + if str(x_dtype) not in supported_dtypes: + x = x.cast("float32") + dtype_ = dtype + if str(dtype) not in supported_dtypes: + dtype = None + ret = paddle.prod(x, axis=axis, keepdim=keepdims, dtype=dtype) + if ret.dtype != dtype_: + ret = ret.cast(dtype_) + return ret def _std(x, axis, correction, keepdim): From d90069b9ba2e35ff2a286a471858854318146d91 Mon Sep 17 00:00:00 2001 From: Bhushan Srivastava <59949692+he11owthere@users.noreply.github.com> Date: Tue, 29 Aug 2023 18:44:56 +0530 Subject: [PATCH 48/55] feat: ddded copysign_ method to the PyTorch frontend. 
(#22713) Co-authored-by: juliagsy <67888047+juliagsy@users.noreply.github.com> --- ivy/functional/frontends/torch/tensor.py | 7 ++++ .../test_frontends/test_torch/test_tensor.py | 39 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 53b313b347ab8..c5d0310722798 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1556,6 +1556,13 @@ def logdet(self): def copysign(self, other, *, out=None): return torch_frontend.copysign(self, other, out=out) + @with_supported_dtypes( + {"2.0.1 and below": ("float16", "float32", "float64")}, "torch" + ) + def copysign_(self, other, *, out=None): + self.ivy_array = self.copysign(other, out=out).ivy_array + return self + @with_unsupported_dtypes( {"2.0.1 and below": ("complex", "bfloat16", "bool")}, "torch" ) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 04e3b112fbcef..3da42e4771560 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -5135,6 +5135,45 @@ def test_torch_tensor_copysign( ) +# copysign_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="copysign_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + num_arrays=2, + ), +) +def test_torch_tensor_copysign_( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # cos @handle_frontend_method( class_tree=CLASS_TREE, From 18831b13df2bc3d3ede0d60843bd9f5e7cb78c61 Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Tue, 29 Aug 2023 14:16:08 +0100 Subject: [PATCH 49/55] fix(ivy): Fixes remaining bugs in the paddle backend of scatter_nd and improves precision (#22738) --- ivy/functional/backends/paddle/general.py | 89 +++++++++++-------- .../hypothesis_helpers/array_helpers.py | 1 + .../test_functional/test_core/test_general.py | 1 - 3 files changed, 51 insertions(+), 40 deletions(-) diff --git a/ivy/functional/backends/paddle/general.py b/ivy/functional/backends/paddle/general.py index 7d0176b332935..2aeb55012cbd2 100644 --- a/ivy/functional/backends/paddle/general.py +++ b/ivy/functional/backends/paddle/general.py @@ -405,17 +405,32 @@ def scatter_nd( if ivy.exists(out) else list(indices.shape[:-1]) + list(shape[indices.shape[-1] :]) ) - updates = _broadcast_to(updates, expected_shape)._data - - if indices.ndim > 1: - indices, unique_idxs = ivy.unique_all(indices, axis=0)[:2] - indices, unique_idxs = indices.data, unique_idxs.data - updates = ivy.gather(updates, unique_idxs, axis=0).data + updates = _broadcast_to(updates, expected_shape).data + + # remove duplicate indices + # necessary because we will be using scatter_nd_add + if indices.ndim > 1 and reduction != "sum": + indices_shape = indices.shape + indices = paddle.reshape(indices, (-1, indices.shape[-1])) + num_indices = 
indices.shape[0] + # use flip to keep the last occurrence of each value + indices, unique_idxs = ivy.unique_all(ivy.flip(indices, axis=[0]), axis=0, by_value=True)[:2] + indices = indices.data + if len(unique_idxs) < num_indices: + updates = paddle.reshape(updates, (-1, *updates.shape[len(indices_shape)-1:])) + updates = ivy.gather(ivy.flip(updates, axis=[0]), unique_idxs, axis=0).data + expected_shape = ( + list(indices.shape[:-1]) + list(out.shape[indices.shape[-1]:]) + if ivy.exists(out) + else list(indices.shape[:-1]) + list(shape[indices.shape[-1]:]) + ) + else: + indices = paddle.reshape(indices, indices_shape) # implementation target_given = ivy.exists(out) if target_given: - target = out._data + target = out.data else: shape = list(shape) if ivy.exists(shape) else out.shape target = paddle.zeros(shape=shape).astype(updates.dtype) @@ -429,26 +444,30 @@ def scatter_nd( '"sum", "min", "max" or "replace"'.format(reduction) ) if reduction == "min": - updates = ivy.minimum(ivy.gather_nd(target, indices), updates)._data - reduction = "replace" + updates = ivy.minimum(ivy.gather_nd(target, indices), updates).data elif reduction == "max": - updates = ivy.maximum(ivy.gather_nd(target, indices), updates)._data - reduction = "replace" + updates = ivy.maximum(ivy.gather_nd(target, indices), updates).data + elif reduction == "sum": + updates = ivy.add(ivy.gather_nd(target, indices), updates).data if indices.ndim <= 1: - indices = ivy.expand_dims(indices, axis=0)._data - updates = ivy.expand_dims(updates, axis=0)._data + indices = ivy.expand_dims(indices, axis=0).data + updates = ivy.expand_dims(updates, axis=0).data + updates_ = _broadcast_to(ivy.gather_nd(target, indices), expected_shape).data target_dtype = target.dtype if target_dtype in [ paddle.complex64, paddle.complex128, ]: - if reduction == "replace": - updates = paddle_backend.subtract( - updates, - paddle_backend.gather_nd(target, indices), - ) - result_real = paddle.scatter_nd_add(target.real(), indices, updates.real()) - result_imag = paddle.scatter_nd_add(target.imag(), indices, updates.imag()) + result_real = paddle.scatter_nd_add( + paddle.scatter_nd_add(target.real(), indices, -updates_.real()), + indices, + updates.real(), + ) + result_imag = paddle.scatter_nd_add( + paddle.scatter_nd_add(target.imag(), indices, -updates_.imag()), + indices, + updates.imag(), + ) ret = paddle.complex(result_real, result_imag) elif target_dtype in [ paddle.int8, @@ -457,26 +476,18 @@ def scatter_nd( paddle.float16, paddle.bool, ]: - if reduction == "replace": - updates = paddle.subtract( - updates.cast("float32"), - paddle.gather_nd(target.cast("float32"), indices), - ) - ret = paddle.scatter_nd_add(target.cast("float32"), indices, updates).cast( - target_dtype - ) + target, updates, updates_ = target.cast("float32"), updates.cast("float32"), updates_.cast("float32") + ret = paddle.scatter_nd_add( + paddle.scatter_nd_add(target, indices, -updates_), + indices, + updates, + ).cast(target_dtype) else: - if reduction == "replace": - gathered_vals = paddle.gather_nd(target, indices) - # values greater than 2^24 - 1 can only be accurately represented as float64 - if (np.abs(gathered_vals.numpy()).max() >= 2**24) or ( - np.abs(updates.numpy()).max() >= 2**24 - ): - gathered_vals = gathered_vals.cast("float64") - target = target.cast("float64") - updates = updates.cast("float64") - updates = paddle.subtract(updates, gathered_vals) - ret = paddle.scatter_nd_add(target, indices, updates).cast(target_dtype) + ret = paddle.scatter_nd_add( + 
paddle.scatter_nd_add(target, indices, -updates_), + indices, + updates, + ) if ivy.exists(out): return ivy.inplace_update(out, ret) return ret diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py index dd63e38973349..ffffb40b4fd06 100644 --- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py +++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py @@ -1881,6 +1881,7 @@ def dtype_array_query( min_value=-s + 1, max_value=s - 1, dtype=["int64"], + max_num_dims=4, ) ) new_index = new_index[0] diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_general.py b/ivy_tests/test_ivy/test_functional/test_core/test_general.py index 1c4aa8fe876a6..e2fd880aff6a5 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_general.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_general.py @@ -1643,7 +1643,6 @@ def test_set_item( on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - rtol_=1e-03, # needed only for the paddle backend x=x, query=query, val=val, From 43f429117515d18e4fcb039de8087b6c2ad19428 Mon Sep 17 00:00:00 2001 From: Bhushan Srivastava <59949692+he11owthere@users.noreply.github.com> Date: Tue, 29 Aug 2023 20:09:38 +0530 Subject: [PATCH 50/55] Fixed erf_ method in the PyTorch frontend (#22718) Co-authored-by: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> --- ivy/functional/frontends/torch/tensor.py | 7 +++++-- .../test_ivy/test_frontends/test_torch/test_tensor.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index c5d0310722798..ac680497c5b1c 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -495,9 +495,12 @@ def equal(self, other): def erf(self, *, out=None): return torch_frontend.erf(self, out=out) - @with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, "torch") + @with_supported_dtypes( + {"2.0.1 and below": ("float32", "float64", "bfloat16")}, "torch" + ) def erf_(self, *, out=None): - self.ivy_array = torch_frontend.erf(self, out=out).ivy_array + self.ivy_array = self.erf(out=out).ivy_array + return self def new_zeros( self, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 3da42e4771560..e252e70c6c6ea 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -6180,7 +6180,7 @@ def test_torch_tensor_erf( init_tree="torch.tensor", method_name="erf_", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("valid"), ), ) def test_torch_tensor_erf_( From 377bd213829d81d201e2f96cdab64074f87fa829 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Tue, 29 Aug 2023 15:25:49 +0000 Subject: [PATCH 51/55] fix(ivy): fix `ivy.quantile` not accepting float `q` in some cases. 
added float to test generation, the implementation still contains some bugs especially with paddle backend --- .../numpy/experimental/statistical.py | 41 ++++--- .../paddle/experimental/statistical.py | 102 +++++++++--------- .../tensorflow/experimental/statistical.py | 38 +++---- .../torch/experimental/statistical.py | 39 +++---- .../test_core/test_statistical.py | 28 +++-- 5 files changed, 123 insertions(+), 125 deletions(-) diff --git a/ivy/functional/backends/numpy/experimental/statistical.py b/ivy/functional/backends/numpy/experimental/statistical.py index 4c244f3b3aa58..12c3565148ec7 100644 --- a/ivy/functional/backends/numpy/experimental/statistical.py +++ b/ivy/functional/backends/numpy/experimental/statistical.py @@ -168,6 +168,8 @@ def nanmean( def _validate_quantile(q): + if isinstance(q, float): + q = np.asarray(q) if q.ndim == 1 and q.size < 10: for i in range(q.size): if not (0.0 <= q[i] <= 1.0): @@ -226,40 +228,37 @@ def _handle_axis(a, q, fn, keepdims=False, axis=None): def _quantile(a, q, axis=None): + if isinstance(q, float): + q = np.asarray(q) ret_dtype = a.dtype - if q.ndim > 2: + if q.ndim > 1: raise ValueError("q argument must be a scalar or 1-dimensional!") if axis is None: axis = 0 a = a.flatten() + elif axis != 0: + a = np.moveaxis(a, axis, 0) + axis = 0 n = a.shape[axis] - if axis != 0: - a = np.moveaxis(a, axis, 0) + indices = q * (n - 1) - indices = [] - for q_num in q: - index = q_num * (n - 1) - indices.append(index) + a.sort(axis) - a.sort(0) - outputs = [] + indices_below = np.floor(indices).astype(np.int32) + indices_upper = np.ceil(indices).astype(np.int32) - for index in indices: - indices_below = np.floor(index).astype(np.int32) - indices_upper = np.ceil(index).astype(np.int32) + weights = indices - indices_below.astype("float64") - weights = index - indices_below.astype("float64") + indices_below = np.clip(indices_below, 0, n - 1) + indices_upper = np.clip(indices_upper, 0, n - 1) + tensor_upper = np.take(a, indices_upper, axis=axis) # , mode="clip") + tensor_below = np.take(a, indices_below, axis=axis) # , mode="clip") - indices_below = np.clip(indices_below, 0, n - 1) - indices_upper = np.clip(indices_upper, 0, n - 1) - tensor_upper = np.take(a, indices_upper, axis=0) # , mode="clip") - tensor_below = np.take(a, indices_below, axis=0) # , mode="clip") + pred = weights <= 0.5 + out = np.where(pred, tensor_below, tensor_upper) - pred = weights <= 0.5 - out = np.where(pred, tensor_below, tensor_upper) - outputs.append(out) - return np.array(outputs, dtype=ret_dtype) + return out.astype(ret_dtype) def _compute_quantile_wrapper( diff --git a/ivy/functional/backends/paddle/experimental/statistical.py b/ivy/functional/backends/paddle/experimental/statistical.py index fb3dd5f6ed637..0cc4e40739733 100644 --- a/ivy/functional/backends/paddle/experimental/statistical.py +++ b/ivy/functional/backends/paddle/experimental/statistical.py @@ -97,6 +97,8 @@ def nanmean( def _validate_quantile(q): + if isinstance(q, float): + q = paddle.to_tensor(q) if q.ndim == 1 and q.size < 10: for i in range(q.size): if not (0.0 <= q[i] <= 1.0): @@ -155,66 +157,63 @@ def _handle_axis(a, q, fn, keepdims=False, axis=None, interpolation="nearest"): else: index_ret = tuple(None if i in axis else slice(None) for i in range(nd)) ret = ret[(Ellipsis,) + index_ret] - + # if keepdims: + # axis = axis if axis is not None else list(range(a.ndim)) + # ret = ret.unsqueeze(axis) return ret def _quantile(a, q, axis=None, interpolation="nearest"): + if isinstance(q, float): + q = 
paddle.to_tensor(q) ret_dtype = a.dtype - if q.ndim > 2: + if q.ndim > 1: raise ValueError("q argument must be a scalar or 1-dimensional!") if axis is None: axis = 0 a = paddle.flatten(a) + elif axis != 0: + a = a.moveaxis(axis, 0) + axis = 0 n = a.shape[axis] - if axis != 0: - a = paddle.moveaxis(a, axis, 0) - indices = [] - for q_num in q: - index = q_num * (n - 1) - indices.append(index) - - a = paddle.sort(a, 0) - outputs = [] - - for index in indices: - if interpolation == "lower": - index = paddle.floor(index) - elif interpolation == "higher": - index = paddle.ceil(index) - elif interpolation == "nearest": - index = paddle.round(index) - elif interpolation == "midpoint": - index_floor = paddle.floor(index) - index_ceil = paddle.ceil(index) - index = (index_ceil + index_floor) / 2 - - indices_below = paddle.floor(index).astype(paddle.int32) - indices_upper = paddle.ceil(index).astype(paddle.int32) - - if interpolation == "nearest_jax": - weights = index - indices_below.astype(paddle.float64) - - indices_below = paddle.clip(indices_below, 0, n - 1) - indices_upper = paddle.clip(indices_upper, 0, n - 1) - tensor_upper = paddle.gather(a, indices_upper, axis=0) - tensor_below = paddle.gather(a, indices_below, axis=0) - - pred = weights <= 0.5 - out = paddle.where(pred, tensor_below, tensor_upper) - else: - tensor_upper = paddle.gather(a, indices_upper, axis=0) - tensor_below = paddle.gather(a, indices_below, axis=0) - weights = index - indices_below.astype(paddle.float64) - out = paddle.lerp( - tensor_below.astype(paddle.float64), - tensor_upper.astype(paddle.float64), - weights.astype(paddle.float64), - ) - outputs.append(out) - return paddle.concat(outputs, axis=0).astype(ret_dtype) + indices = q * (n - 1) + + a = paddle.sort(a, axis) + + if interpolation == "lower": + indices = paddle.floor(indices) + elif interpolation == "higher": + indices = paddle.ceil(indices) + elif interpolation == "nearest": + indices = paddle.round(indices) + elif interpolation == "midpoint": + index_floor = paddle.floor(indices) + index_ceil = paddle.ceil(indices) + indices = (index_ceil + index_floor) / 2 + + indices_below = paddle.floor(indices).astype(paddle.int32) + indices_upper = paddle.ceil(indices).astype(paddle.int32) + weights = indices - indices_below.astype(paddle.float64) + if interpolation == "nearest_jax": + indices_below = paddle.clip(indices_below, 0, n - 1) + indices_upper = paddle.clip(indices_upper, 0, n - 1) + tensor_upper = paddle.gather(a, indices_upper, axis=axis) + tensor_below = paddle.gather(a, indices_below, axis=axis) + + pred = weights <= 0.5 + out = paddle.where(pred, tensor_below, tensor_upper) + else: + tensor_upper = paddle.gather(a, indices_upper, axis=axis) + tensor_below = paddle.gather(a, indices_below, axis=axis) + out = paddle.lerp( + tensor_below.astype(paddle.float64), + tensor_upper.astype(paddle.float64), + weights.astype(paddle.float64), + ) + + return out.astype(ret_dtype) def _compute_quantile_wrapper( @@ -546,10 +545,7 @@ def __find_cummax( if ( isinstance(x.tolist()[0], list) and len(x[0].shape) >= 1 - and ( - (type(x[0]) == paddle.Tensor) - or (type(x[0]) == ivy.data_classes.array.array.Array) - ) + and (isinstance(x[0], paddle.Tensor) or isinstance(x[0], ivy.Array)) ): if axis >= 1: if not isinstance(x, list): @@ -599,7 +595,7 @@ def __find_cummax( values.append(y) indices.append(n) - if type(x) == paddle.Tensor: + if isinstance(x, paddle.Tensor): return paddle.to_tensor(values, dtype=x.dtype), paddle.to_tensor( indices, dtype="int64" ) diff --git 
a/ivy/functional/backends/tensorflow/experimental/statistical.py b/ivy/functional/backends/tensorflow/experimental/statistical.py index bd4c313d107f8..9cd91437a5f4a 100644 --- a/ivy/functional/backends/tensorflow/experimental/statistical.py +++ b/ivy/functional/backends/tensorflow/experimental/statistical.py @@ -181,39 +181,35 @@ def _handle_axis(a, q, fn, keepdims=False, axis=None): def _quantile(a, q, axis=None): ret_dtype = a.dtype - if tf.experimental.numpy.ndim(q) > 2: + if tf.experimental.numpy.ndim(q) > 1: raise ValueError("q argument must be a scalar or 1-dimensional!") if axis is None: axis = 0 a = tf.reshape(a, [-1]) + elif axis != 0: + a = tf.experimental.numpy.moveaxis(a, axis, 0) + axis = 0 n = a.shape[axis] - if axis != 0: - a = tf.experimental.numpy.moveaxis(a, axis, 0) - indices = [] - for q_num in q: - index = q_num * (n - 1) - indices.append(index) + indices = q * (n - 1) + + a = tf.sort(a, axis) - a = tf.sort(a, 0) - outputs = [] + indices_below = tf.cast(tf.math.floor(indices), dtype=tf.int32) + indices_upper = tf.cast(tf.math.ceil(indices), dtype=tf.int32) - for index in indices: - indices_below = tf.cast(tf.math.floor(index), dtype=tf.int32) - indices_upper = tf.cast(tf.math.ceil(index), dtype=tf.int32) + weights = indices - tf.cast(indices_below, dtype=ret_dtype) - weights = index - tf.cast(indices_below, dtype=ret_dtype) + indices_below = tf.clip_by_value(indices_below, 0, n - 1) + indices_upper = tf.clip_by_value(indices_upper, 0, n - 1) + tensor_upper = tf.gather(a, indices_upper, axis=axis) + tensor_below = tf.gather(a, indices_below, axis=axis) - indices_below = tf.clip_by_value(indices_below, 0, n - 1) - indices_upper = tf.clip_by_value(indices_upper, 0, n - 1) - tensor_upper = tf.gather(a, indices_upper, axis=0) - tensor_below = tf.gather(a, indices_below, axis=0) + pred = weights <= 0.5 + out = tf.where(pred, tensor_below, tensor_upper) - pred = weights <= 0.5 - out = tf.where(pred, tensor_below, tensor_upper) - outputs.append(out) - return tf.convert_to_tensor(outputs, dtype=ret_dtype) + return tf.cast(out, ret_dtype) def _compute_quantile_wrapper( diff --git a/ivy/functional/backends/torch/experimental/statistical.py b/ivy/functional/backends/torch/experimental/statistical.py index 6210e1c05710a..cc1d4d5fa05ce 100644 --- a/ivy/functional/backends/torch/experimental/statistical.py +++ b/ivy/functional/backends/torch/experimental/statistical.py @@ -186,6 +186,8 @@ def nanmean( def _validate_quantile(q): + if isinstance(q, float): + q = torch.as_tensor(q) if q.ndim == 1 and torch.numel(q) < 10: for i in range(torch.numel(q)): if not (0.0 <= q[i] <= 1.0): @@ -250,39 +252,32 @@ def _handle_axis(a, q, fn, keepdims=False, axis=None): def _quantile(a, q, axis=None): ret_dtype = a.dtype - if q.ndim > 2: + if isinstance(q, float): + q = torch.as_tensor(q) + if isinstance(q, torch.Tensor) and q.ndim > 1: raise ValueError("q argument must be a scalar or 1-dimensional!") if axis is None: axis = 0 a = a.flatten() n = a.shape[axis] - if axis != 0: - a = torch.moveaxis(a, axis, 0) - indices = [] - for q_num in q: - index = q_num * (n - 1) - indices.append(index) + indices = q * (n - 1) - a = torch.sort(a, 0)[0] - outputs = [] + a = torch.sort(a, axis)[axis] + indices_below = torch.floor(indices).to(torch.int64) + indices_upper = torch.ceil(indices).to(torch.int64) - for index in indices: - indices_below = torch.floor(index).to(torch.int64) - indices_upper = torch.ceil(index).to(torch.int64) + weights = indices - indices_below.to(torch.float64) - weights = index - 
indices_below.to(torch.float64) + indices_below = torch.clip(indices_below, 0, n - 1) + indices_upper = torch.clip(indices_upper, 0, n - 1) + tensor_upper = torch.index_select(a, 0, indices_upper) + tensor_below = torch.index_select(a, 0, indices_below) - indices_below = torch.clip(indices_below, 0, n - 1) - indices_upper = torch.clip(indices_upper, 0, n - 1) - tensor_upper = torch.index_select(a, 0, indices_upper) - tensor_below = torch.index_select(a, 0, indices_below) - - pred = weights <= 0.5 - out = torch.where(pred, tensor_below, tensor_upper) - outputs.append(out) - return torch.concat(outputs, dim=0).to(ret_dtype) + pred = weights <= 0.5 + out = torch.where(pred, tensor_below, tensor_upper) + return out.to(ret_dtype) def _compute_quantile_wrapper( diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py index f43f00c3f4526..143c94db102ad 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_statistical.py @@ -239,17 +239,27 @@ def _quantile_helper(draw): ) ) q = draw( - helpers.array_values( - dtype=helpers.get_dtypes("float"), - shape=helpers.get_shape(min_dim_size=1, max_num_dims=1, min_num_dims=1), - min_value=0.0, - max_value=1.0, - exclude_max=False, - exclude_min=False, + st.one_of( + helpers.array_values( + dtype=helpers.get_dtypes("float"), + shape=helpers.get_shape(min_dim_size=1, max_num_dims=1, min_num_dims=1), + min_value=0.0, + max_value=1.0, + exclude_max=False, + exclude_min=False, + ), + st.floats(min_value=0.0, max_value=1.0), ) ) - interpolation_names = ["linear", "lower", "higher", "midpoint", "nearest"] + interpolation_names = [ + "linear", + "lower", + "higher", + "midpoint", + "nearest", + "nearest_jax", + ] interpolation = draw( helpers.list_of_size( x=st.sampled_from(interpolation_names), @@ -636,4 +646,6 @@ def test_quantile( axis=axis, interpolation=interpolation[0], keepdims=keep_dims, + atol_=1e-3, + rtol_=1e-3, ) From 33d1157b17ebe9612f1b0e70172c6f8e2b78bbc5 Mon Sep 17 00:00:00 2001 From: saeedashrraf <48128381+saeedashrraf@users.noreply.github.com> Date: Tue, 29 Aug 2023 20:23:40 +0300 Subject: [PATCH 52/55] Updating the deep dive section for Explicitly Nestable Compositional Function --- docs/overview/deep_dive/containers.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/overview/deep_dive/containers.rst b/docs/overview/deep_dive/containers.rst index a1d7110fc778b..be36ead9f099c 100644 --- a/docs/overview/deep_dive/containers.rst +++ b/docs/overview/deep_dive/containers.rst @@ -204,7 +204,7 @@ The *nestable* behaviour is added to any function which is decorated with the `h This wrapper causes the function to be applied at each leaf of any containers passed in the input. More information on this can be found in the `Function Wrapping `_ section of the Deep Dive. -Additionally, any nestable function which returns multiple arrays, will return the same number of containers for it's container counterpart. +Additionally, any nestable function which returns multiple arrays, will return the same number of containers for its container counterpart. This property makes the function symmetric with regards to the input-output behavior, irrespective of whether :class:`ivy.Array` or :class:`ivy.Container` instances are based used. 
Any argument in the input can be replaced with a container without changing the number of inputs, and the presence or absence of ivy.Container instances in the input should not change the number of return values of the function. In other words, if containers are detected in the input, then we should return a separate container for each array that the function would otherwise return. @@ -246,8 +246,10 @@ The functions :func:`ivy.clip`, :func:`ivy.log`, :func:`ivy.sum` and :func:`ivy. Therefore, our approach is to **not** wrap any compositional functions which are already *implicitly nestable* as a result of the *nestable* functions called internally. +**Explicitly Nestable Compositional Functions** + There may be some compositional functions which are not implicitly nestable for some reason, and in such cases adding the explicit `handle_nestable `_ wrapping may be necessary. -One such example is the :func:`ivy.linear` function which is not implicitly nestable despite being compositional. This is because of the use of special functions like :func:`__len__` which is not nestable and can't be made nestable. +One such example is the :func:`ivy.linear` function which is not implicitly nestable despite being compositional. This is because of the use of special functions like :func:`__len__` and :func:`__list__` which, among other functions, are not nestable and can't be made nestable. But we should try to avoid this, in order to make the flow of computation as intuitive to the user as possible. When compiling the code, the computation graph is **identical** in either case, and there will be no implications on performance whatsoever. From fd2724be7606f7d85e550a5d274f8693ef4bd331 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Tue, 29 Aug 2023 23:22:49 +0530 Subject: [PATCH 53/55] Hotfix gpu_is_available torch --- ivy/functional/backends/torch/device.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/torch/device.py b/ivy/functional/backends/torch/device.py index d115996dff2e9..f0f3607953ac9 100644 --- a/ivy/functional/backends/torch/device.py +++ b/ivy/functional/backends/torch/device.py @@ -96,8 +96,10 @@ def num_gpus() -> int: def gpu_is_available() -> bool: if hasattr(torch.backends, "mps"): - return torch.backends.mps.is_available() - return torch.cuda.is_available() + return torch.backends.mps.is_available() or torch.cuda.is_available() + elif torch.cuda.is_available(): + return True + return False # noinspection PyUnresolvedReferences From 8051fb673e8eb659f8dcab1abb9a64096cd75a62 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Wed, 30 Aug 2023 00:08:14 +0530 Subject: [PATCH 54/55] Hotfix(pre-commit) - update commit hash key --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a395e41b926f8..f4c4c298876fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,6 @@ repos: # Exclude everything in frontends except __init__.py, and func_wrapper.py exclude: 'ivy/functional/(frontends|backends)/(?!.*/func_wrapper\.py$).*(?!__init__\.py$)' - repo: https://github.com/unifyai/lint-hook - rev: a90659d806c6d65f20ec41095a2da8e8920cc96f + rev: 27646397c5390f644a645f439535b1061b9c0105 hooks: - id: ivy-lint From 7a048c1ad7193bc3033a68c1c80f0dfd5d4e74df Mon Sep 17 00:00:00 2001 From: Bhushan Srivastava 
<59949692+he11owthere@users.noreply.github.com> Date: Wed, 30 Aug 2023 01:34:27 +0530 Subject: [PATCH 55/55] Added zero_ method to the Paddle frontend (#22476) --- .../frontends/paddle/tensor/tensor.py | 7 ++++ .../test_paddle/test_tensor/test_tensor.py | 36 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 3037eacf755ab..b621b83084113 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -155,6 +155,13 @@ def sqrt_(self, name=None): self.ivy_array = self.sqrt().ivy_array return self + @with_unsupported_dtypes({"2.5.1 and below": ("bfloat16", "uint16")}, "paddle") + def zero_(self): + self.ivy_array = paddle_frontend.Tensor( + ivy.zeros_like(self._ivy_array) + ).ivy_array + return self + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def cos(self, name=None): return paddle_frontend.Tensor(ivy.cos(self._ivy_array)) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 2e9a57c783373..5a8ddd8609dfc 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -2953,6 +2953,42 @@ def test_paddle_tensor_unsqueeze( ) +# zero_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="zero_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + allow_inf=False, + ), +) +def test_paddle_tensor_zero_( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # sqrt_ @handle_frontend_method( class_tree=CLASS_TREE,