add copy_ for torch frontend #21946

Closed
wants to merge 14 commits
ivy/data_classes/array/experimental/utility.py (29 additions, 1 deletion)
@@ -1,6 +1,34 @@
# global
from typing import Optional
import abc

# local
import ivy


class _ArrayWithUtilityExperimental(abc.ABC):
    pass
    def optional_get_element(
        self: Optional[ivy.Array] = None,
        /,
        *,
        out: Optional[ivy.Array] = None,
    ) -> ivy.Array:
        """
        If the input is a tensor or sequence type, it returns the input. If the input is
        an optional type, it outputs the element in the input. It is an error if the
        input is an empty optional-type (i.e. does not have an element) and the behavior
        is undefined in this case.

        Parameters
        ----------
        self
            Input array
        out
            Optional output array, for writing the result to.

        Returns
        -------
        ret
            Input array if it is not None
        """
        return ivy.optional_get_element(self._data, out=out)
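A minimal usage sketch of the new instance method follows. It assumes this mixin is composed into ivy.Array the way the other experimental array mixins are, and that a backend such as NumPy is installed; the backend choice and the values are illustrative only.

```python
import ivy

ivy.set_backend("numpy")  # assumed backend, any installed backend should work

x = ivy.array([1.0, 2.0, 3.0])
y = x.optional_get_element()  # the array is returned unchanged, since it is not None
```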
ivy/data_classes/container/experimental/utility.py (78 additions, 1 deletion)
@@ -1,5 +1,82 @@
# global
from typing import Optional, Union, Dict, List

# local
import ivy
from ivy.data_classes.container.base import ContainerBase


class _ContainerWithUtilityExperimental(ContainerBase):
    pass
    @staticmethod
    def static_optional_get_element(
        x: Optional[Union[ivy.Array, ivy.Container]] = None,
        /,
        *,
        key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
        to_apply: bool = True,
        prune_unapplied: bool = False,
        map_sequences: bool = False,
        out: Optional[ivy.Container] = None,
    ) -> ivy.Container:
        """
        ivy.Container static method variant of ivy.optional_get_element. This method
        simply wraps the function, and so the docstring for ivy.optional_get_element
        also applies to this method with minimal changes.

        Parameters
        ----------
        x
            container with array inputs.
        key_chains
            The keychains to apply or not apply the method to. Default is ``None``.
        to_apply
            If True, the method will be applied to key_chains, otherwise key_chains
            will be skipped. Default is ``True``.
        prune_unapplied
            Whether to prune key_chains for which the function was not applied.
            Default is ``False``.
        map_sequences
            Whether to also map method to sequences (lists, tuples).
            Default is ``False``.
        out
            optional output container, for writing the result to.

        Returns
        -------
        ret
            Container with the input arrays returned unchanged at its leaves.
        """
        return ContainerBase.cont_multi_map_in_function(
            "optional_get_element",
            x,
            key_chains=key_chains,
            to_apply=to_apply,
            prune_unapplied=prune_unapplied,
            map_sequences=map_sequences,
            out=out,
        )

    def optional_get_element(
        self: ivy.Container,
        /,
        *,
        out: Optional[ivy.Container] = None,
    ) -> ivy.Container:
        """
        ivy.Container instance method variant of ivy.optional_get_element. This method
        simply wraps the function, and so the docstring for ivy.optional_get_element
        also applies to this method with minimal changes.

        Parameters
        ----------
        self
            Input container
        out
            Optional output container, for writing the result to.

        Returns
        -------
        ret
            Output container.
        """
        return self.static_optional_get_element(self, out=out)
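For completeness, a hedged sketch of calling the static and instance container variants side by side, assuming the mixin above is composed into ivy.Container as its siblings are and that a backend is set; the container contents are made up for illustration.

```python
import ivy

ivy.set_backend("numpy")  # assumed backend choice

c = ivy.Container(a=ivy.array([1, 2]), b=ivy.array([3, 4]))

# Static variant: maps optional_get_element over every leaf of the container.
out_static = ivy.Container.static_optional_get_element(c)

# Instance variant: wraps the same call on the container itself.
out_inst = c.optional_get_element()
```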
ivy/functional/frontends/torch/tensor.py (9 additions, 0 deletions)
@@ -392,6 +392,15 @@ def log2(self):
    def relu(self):
        return torch_frontend_nn.relu(self)

    @with_unsupported_dtypes({"2.0.1 and below": ("bfloat16", "uint16")}, "torch")
    def copy_(self, other, non_blocking=False):
        ivy.utils.assertions.check_one_way_broadcastable(
            self.ivy_array.shape,
            torch_frontend.tensor(other).ivy_array.shape
        )
        self._ivy_array = torch_frontend.tensor(other).ivy_array
        return self

    @numpy_to_torch_style_args
    @with_unsupported_dtypes({"2.0.1 and below": ("complex",)}, "torch")
    def amax(self, dim=None, keepdim=False):
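Below is a rough usage sketch of the new copy_ method through the torch frontend. It assumes a backend is set via ivy.set_backend and uses the frontend's tensor constructor, as the diff itself does; the shapes are chosen to satisfy the one-way broadcastability check.

```python
import ivy
import ivy.functional.frontends.torch as torch_frontend

ivy.set_backend("numpy")  # assumed backend choice

dst = torch_frontend.tensor([0.0, 0.0, 0.0])
src = torch_frontend.tensor([1.0, 2.0, 3.0])

result = dst.copy_(src)  # copies src's data into dst in place and returns dst
assert result is dst
```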
ivy/functional/ivy/experimental/utility.py (41 additions, 0 deletions)
@@ -0,0 +1,41 @@
# global
from typing import Optional

# local
import ivy
from ivy import handle_out_argument, handle_nestable
from ivy.utils.exceptions import handle_exceptions


@handle_out_argument
@handle_nestable
@handle_exceptions
def optional_get_element(
    x: Optional[ivy.Array] = None,
    /,
    *,
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    """
    If the input is a tensor or sequence type, it returns the input. If the input is an
    optional type, it outputs the element in the input. It is an error if the input is
    an empty optional-type (i.e. does not have an element) and the behavior is undefined
    in this case.

    Parameters
    ----------
    x
        Input array
    out
        Optional output array, for writing the result to.

    Returns
    -------
    ret
        Input array if it is not None
    """
    if x is None:
        raise ivy.utils.exceptions.IvyError(
            "The requested optional input has no value."
        )
    return x
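The sketch below exercises the function directly, assuming it is exposed on the top-level ivy namespace the way other experimental utilities are and that a backend is set; the None branch shows the error path implemented above.

```python
import ivy

ivy.set_backend("numpy")  # assumed backend choice

x = ivy.array([1, 2, 3])
print(ivy.optional_get_element(x))  # returns x unchanged

try:
    ivy.optional_get_element(None)
except ivy.utils.exceptions.IvyError as e:
    print(e)  # error message defined above
```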
ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py (38 additions, 0 deletions)
@@ -6042,6 +6042,44 @@ def test_torch_tensor_cumsum_(
    )


# copy_
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="torch.tensor",
method_name="copy_",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
num_arrays=2,
),
)
def test_torch_tensor_copy_(
dtype_and_x,
frontend_method_data,
init_flags,
method_flags,
frontend,
on_device,
backend_fw,
):
input_dtype, x = dtype_and_x
helpers.test_frontend_method(
init_input_dtypes=input_dtype,
backend_to_test=backend_fw,
init_all_as_kwargs_np={
"data": x[0],
},
method_input_dtypes=input_dtype,
method_all_as_kwargs_np={
"other": x[1],
},
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
frontend=frontend,
on_device=on_device,
)


# sort
@handle_frontend_method(
    class_tree=CLASS_TREE,
@@ -0,0 +1,33 @@
# global
from hypothesis import strategies as st

# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_test


@handle_test(
    fn_tree="functional.ivy.experimental.optional_get_element",
    dtype_and_x=helpers.dtype_and_values(),
    input_tensor=st.booleans(),
)
def test_optional_get_element(
    *,
    dtype_and_x,
    input_tensor,
    test_flags,
    backend_fw,
    fn_name,
    on_device,
):
    input_dtype, x = dtype_and_x
    fn_input = x[0] if input_tensor else x

    helpers.test_function(
        input_dtypes=input_dtype,
        test_flags=test_flags,
        on_device=on_device,
        backend_to_test=backend_fw,
        fn_name=fn_name,
        x=fn_input,
    )