diff --git a/.gitignore b/.gitignore
index 35a762120aef7..8f4cff4ddfa92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,4 +27,5 @@ with_time_logs/
 .array_api_tests_k_flag*
 internal_automation_tools/
 .vscode/*
-.idea/*
\ No newline at end of file
+.idea/*
+.DS_Store*
diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py
index 900ce08e02f02..79f4180f5246a 100644
--- a/ivy/functional/frontends/torch/linalg.py
+++ b/ivy/functional/frontends/torch/linalg.py
@@ -204,3 +204,35 @@ def lu_factor(A, *, pivot=True, out=None):
 @to_ivy_arrays_and_back
 def matmul(input, other, *, out=None):
     return ivy.matmul(input, other, out=out)
+
+
+@to_ivy_arrays_and_back
+@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch")
+def vander(x, N=None):
+    if len(x.shape) < 1:
+        raise RuntimeError("Input dim must be greater than or equal to 1.")
+
+    # PyTorch always returns int64 for integer inputs
+    if "int" in x.dtype:
+        x = ivy.astype(x, ivy.int64)
+
+    if len(x.shape) == 1:
+        # torch always returns the powers in ascending order
+        return ivy.vander(x, N=N, increasing=True)
+
+    # support multi-dimensional input by flattening the batch dimensions
+    original_shape = x.shape
+    if N is None:
+        N = x.shape[-1]
+
+    # collect the Vandermonde matrix of each flattened row
+    x = ivy.reshape(x, (-1, x.shape[-1]))
+    output = []
+
+    for i in range(x.shape[0]):
+        output.append(ivy.vander(x[i], N=N, increasing=True))
+
+    output = ivy.stack(output)
+    output = ivy.reshape(output, (*original_shape, N))
+    output = ivy.astype(output, x.dtype)
+    return output
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
index a57280b1cd7fe..e31b264b07305 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
@@ -336,24 +336,24 @@ def test_torch_eigvals(
     )
 
     """
-    In "ret" we have out eigenvalues calculated with our backend and 
+    In "ret" we have the eigenvalues calculated with our backend and
     in "frontend_ret" are our eigenvalues calculated with the specified frontend
     """
 
     """
-    Depending on the chosen framework there may be small differences between our 
-    extremely small or big eigenvalues (eg: -3.62831993e-33+0.j(numpy) 
-    vs -1.9478e-32+0.j(PyTorch)). 
-    Important is that both are very very close to zero, indicating a 
+    Depending on the chosen framework there may be small differences between our
+    extremely small or large eigenvalues (e.g. -3.62831993e-33+0.j (numpy)
+    vs -1.9478e-32+0.j (PyTorch)).
+    What matters is that both are very close to zero, indicating a
     small value(very close to 0) either way.
-    To asses the correctness of our calculated eigenvalues for our initial matrix 
+    To assess the correctness of our calculated eigenvalues for our initial matrix
     we sort both numpy arrays and call assert_all_close on their modulus.
     """
 
     """
-    Supports input of float, double, cfloat and cdouble dtypes. 
-    Also supports batches of matrices, and if A is a batch of matrices then the 
+    Supports input of float, double, cfloat and cdouble dtypes.
+    Also supports batches of matrices, and if A is a batch of matrices then the
     output has the same batch dimension
     """
@@ -1094,3 +1094,56 @@ def test_torch_matmul(
         rtol=1e-03,
         atol=1e-06,
     )
+
+
+# vander
+@st.composite
+def _vander_helper(draw):
+    # generate an input matrix of shape (*, n), where '*' is one or more
+    # batch dimensions
+    N = draw(helpers.ints(min_value=2, max_value=5))
+    if draw(helpers.floats(min_value=0, max_value=1.0)) < 0.5:
+        N = None
+
+    shape = draw(
+        helpers.get_shape(
+            min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10
+        )
+    )
+    dtype = "float"
+    if draw(helpers.floats(min_value=0, max_value=1.0)) < 0.5:
+        dtype = "integer"
+
+    x = draw(
+        helpers.dtype_and_values(
+            available_dtypes=draw(helpers.get_dtypes(dtype)),
+            shape=shape,
+            min_value=-10,
+            max_value=10,
+        )
+    )
+
+    return *x, N
+
+
+@handle_frontend_test(
+    fn_tree="torch.linalg.vander",
+    dtype_and_input=_vander_helper(),
+)
+def test_torch_vander(
+    *,
+    dtype_and_input,
+    frontend,
+    fn_tree,
+    on_device,
+    test_flags,
+):
+    input_dtype, x, N = dtype_and_input
+    test_flags.num_positional_args = 1
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        frontend=frontend,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        test_flags=test_flags,
+        x=x[0],
+        N=N,
+    )
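
As a sanity check on the semantics the new frontend targets, here is a minimal
numpy-based sketch (not part of the diff; reference_vander is a hypothetical
helper). It mirrors the torch.linalg.vander behaviour the implementation above
reproduces: powers in ascending order, N defaulting to the size of the last
axis, integer inputs promoted to int64, and arbitrary batch dimensions:

import numpy as np

def reference_vander(x, N=None):
    # hypothetical reference, not PR code: emulates torch.linalg.vander
    x = np.asarray(x)
    if x.ndim < 1:
        raise RuntimeError("Input dim must be greater than or equal to 1.")
    if np.issubdtype(x.dtype, np.integer):
        # torch promotes integer inputs to int64
        x = x.astype(np.int64)
    if N is None:
        N = x.shape[-1]
    flat = x.reshape(-1, x.shape[-1])
    out = np.stack([np.vander(row, N=N, increasing=True) for row in flat])
    return out.reshape(*x.shape, N).astype(x.dtype)

# e.g. reference_vander(np.array([1, 2, 3])) ->
# [[1 1 1]
#  [1 2 4]
#  [1 3 9]]

The flatten/stack/reshape approach in the PR follows the same shape contract:
the result always has shape (*x.shape, N).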