Skip to content

Commit

Permalink
Implemented torch.where
Browse files Browse the repository at this point in the history
  • Loading branch information
fspyridakos committed Mar 3, 2023
1 parent 5b9ab2b commit 23b088d
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,10 @@ def vsplit(input, indices_or_sections=None, /, *, indices=None, sections=None):
@to_ivy_arrays_and_back
def row_stack(tensors, *, out=None):
return ivy.vstack(tensors, out=out)


@to_ivy_arrays_and_back
def where(condition, input=None, other=None):
if not ivy.exists(input) and not ivy.exists(other):
return nonzero(condition, as_tuple=True)
return ivy.where(condition, input, other)
8 changes: 4 additions & 4 deletions ivy/functional/frontends/torch/pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,19 +465,19 @@ def signbit(input, *, out=None):

@to_ivy_arrays_and_back
def angle(input, *, out=None):
return ivy.angle(input, out=out)
return ivy.angle(input, out=out)


@to_ivy_arrays_and_back
def arctan(input, *, out=None):
return ivy.arctan(input, out=out)
return ivy.arctan(input, out=out)


@to_ivy_arrays_and_back
def conj_physical(input, *, out=None):
return ivy.conj_physical(input, out=out)
return ivy.conj_physical(input, out=out)


@to_ivy_arrays_and_back
def nextafter(input, *, out=None):
return ivy.nextafter(input, out=out)
return ivy.nextafter(input, out=out)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits
from ivy_tests.test_ivy.test_functional.test_core.test_searching import (
_broadcastable_trio,
)
from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_manipulation import ( # noqa
_get_split_locations,
)
Expand Down Expand Up @@ -1138,3 +1141,42 @@ def test_torch_row_stack(
on_device=on_device,
tensors=value,
)


@handle_frontend_test(
fn_tree="torch.where",
broadcastables=_broadcastable_trio(),
only_cond=st.booleans(),
)
def test_torch_where(
*,
broadcastables,
only_cond,
frontend,
test_flags,
fn_tree,
on_device,
):
cond, xs, dtypes = broadcastables

if only_cond:
helpers.test_frontend_function(
input_dtypes=[dtypes[0]],
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
condition=xs[0],
)

else:
helpers.test_frontend_function(
input_dtypes=["bool"] + dtypes,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
condition=cond,
input=xs[0],
other=xs[1],
)

0 comments on commit 23b088d

Please sign in to comment.