Added vander to the torch frontend #13000

Merged · 3 commits · Mar 23, 2023
3 changes: 2 additions & 1 deletion .gitignore
@@ -27,4 +27,5 @@ with_time_logs/
.array_api_tests_k_flag*
internal_automation_tools/
.vscode/*
.idea/*
.DS_Store*
32 changes: 32 additions & 0 deletions ivy/functional/frontends/torch/linalg.py
@@ -204,3 +204,35 @@ def lu_factor(A, *, pivot=True, out=None):
@to_ivy_arrays_and_back
def matmul(input, other, *, out=None):
    return ivy.matmul(input, other, out=out)


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch")
def vander(x, N=None):
    if len(x.shape) < 1:
        raise RuntimeError("Input dim must be greater than or equal to 1.")

    # PyTorch always returns int64 for integer inputs
    if "int" in x.dtype:
        x = ivy.astype(x, ivy.int64)

    if len(x.shape) == 1:
        # torch always returns the powers in ascending order
        return ivy.vander(x, N=N, increasing=True)

    # support multi-dimensional input of shape (*, n), where '*' is one or
    # more batch dimensions
    original_shape = x.shape
    if N is None:
        N = x.shape[-1]

    # collapse the batch dimensions into a single leading axis
    x = ivy.reshape(x, (-1, x.shape[-1]))
    output = []

    # build the Vandermonde matrix row by row, then restore the batch shape
    for i in range(x.shape[0]):
        output.append(ivy.vander(x[i], N=N, increasing=True))

    output = ivy.stack(output)
    output = ivy.reshape(output, (*original_shape, N))
    output = ivy.astype(output, x.dtype)
    return output
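
For intuition, here is a minimal NumPy sketch of the batched behaviour implemented above; np.vander stands in for ivy.vander and all names are illustrative only, not part of the PR:

import numpy as np

# shape (*, n): one batch dimension of size 2, rows of length 3
x = np.array([[1, 2, 3], [4, 5, 6]])
N = x.shape[-1]  # default number of columns, as in the frontend above

flat = x.reshape(-1, x.shape[-1])  # collapse batch dimensions
out = np.stack([np.vander(row, N, increasing=True) for row in flat])
out = out.reshape(*x.shape, N)  # final shape (*, n, N)

# first row [1, 2, 3] -> [[1, 1, 1], [1, 2, 4], [1, 3, 9]] (ascending powers)
print(out[0])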
69 changes: 61 additions & 8 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
@@ -336,24 +336,24 @@ def test_torch_eigvals(
    )

"""
In "ret" we have out eigenvalues calculated with our backend and
In "ret" we have out eigenvalues calculated with our backend and
in "frontend_ret" are our eigenvalues calculated with the specified frontend
"""

"""
Depending on the chosen framework there may be small differences between our
extremely small or big eigenvalues (eg: -3.62831993e-33+0.j(numpy)
vs -1.9478e-32+0.j(PyTorch)).
Important is that both are very very close to zero, indicating a
Depending on the chosen framework there may be small differences between our
extremely small or big eigenvalues (eg: -3.62831993e-33+0.j(numpy)
vs -1.9478e-32+0.j(PyTorch)).
Important is that both are very very close to zero, indicating a
small value(very close to 0) either way.

To asses the correctness of our calculated eigenvalues for our initial matrix
To asses the correctness of our calculated eigenvalues for our initial matrix
we sort both numpy arrays and call assert_all_close on their modulus.
"""

"""
Supports input of float, double, cfloat and cdouble dtypes.
Also supports batches of matrices, and if A is a batch of matrices then the
Supports input of float, double, cfloat and cdouble dtypes.
Also supports batches of matrices, and if A is a batch of matrices then the
output has the same batch dimension
"""

@@ -1094,3 +1094,56 @@ def test_torch_matmul(
        rtol=1e-03,
        atol=1e-06,
    )


# vander
@st.composite
def _vander_helper(draw):
    # generate an input matrix of shape (*, n), where '*' is one or more
    # batch dimensions
    N = draw(helpers.ints(min_value=2, max_value=5))
    if draw(helpers.floats(min_value=0, max_value=1.0)) < 0.5:
        N = None

    shape = draw(
        helpers.get_shape(
            min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10
        )
    )
    dtype = "float"
    if draw(helpers.floats(min_value=0, max_value=1.0)) < 0.5:
        dtype = "integer"

    x = draw(
        helpers.dtype_and_values(
            available_dtypes=draw(helpers.get_dtypes(dtype)),
            shape=shape,
            min_value=-10,
            max_value=10,
        )
    )

    return *x, N


@handle_frontend_test(
    fn_tree="torch.linalg.vander",
    dtype_and_input=_vander_helper(),
)
def test_torch_vander(
    *,
    dtype_and_input,
    frontend,
    fn_tree,
    on_device,
    test_flags,
):
    input_dtype, x, N = dtype_and_input
    test_flags.num_positional_args = 1
    helpers.test_frontend_function(
        input_dtypes=input_dtype,
        frontend=frontend,
        fn_tree=fn_tree,
        on_device=on_device,
        test_flags=test_flags,
        x=x[0],
        N=N,
    )
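
To see the kind of inputs the helper generates, one can draw a sample interactively; this is a hedged sketch (Hypothesis's .example() is meant for exploration only, not for use inside tests, and the example values shown are assumptions):

# Illustrative only: inspect one draw from the composite strategy above.
dtypes, values, N = _vander_helper().example()
# e.g. dtypes == ['int32'], values[0].shape == (3, 4), N is None or an int in [2, 5]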