diff --git a/ivy/data_classes/array/experimental/utility.py b/ivy/data_classes/array/experimental/utility.py index 421c007589869..46aaccd3510ad 100644 --- a/ivy/data_classes/array/experimental/utility.py +++ b/ivy/data_classes/array/experimental/utility.py @@ -1,6 +1,34 @@ # global +from typing import Optional import abc +# local +import ivy + class _ArrayWithUtilityExperimental(abc.ABC): - pass + def optional_get_element( + self: Optional[ivy.Array] = None, + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + If the input is a tensor or sequence type, it returns the input. If the input is + an optional type, it outputs the element in the input. It is an error if the + input is an empty optional-type (i.e. does not have an element) and the behavior + is undefined in this case. + + Parameters + ---------- + self + Input array + out + Optional output array, for writing the result to. + + Returns + ------- + ret + Input array if it is not None + """ + return ivy.optional_get_element(self._data, out=out) diff --git a/ivy/data_classes/container/experimental/utility.py b/ivy/data_classes/container/experimental/utility.py index b21273ff50669..876a9fad339e5 100644 --- a/ivy/data_classes/container/experimental/utility.py +++ b/ivy/data_classes/container/experimental/utility.py @@ -1,5 +1,82 @@ +# global +from typing import Optional, Union, Dict, List + +# local +import ivy from ivy.data_classes.container.base import ContainerBase class _ContainerWithUtilityExperimental(ContainerBase): - pass + @staticmethod + def static_optional_get_element( + x: Optional[Union[ivy.Array, ivy.Container]] = None, + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.optional_get_element. This method + simply wraps the function, and so the docstring for ivy.optional_get_element + also applies to this method with minimal changes. + + Parameters + ---------- + x + container with array inputs. + key_chains + The keychains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. + + Returns + ------- + ret + Container with arrays flattened at leaves. + """ + return ContainerBase.cont_multi_map_in_function( + "optional_get_element", + x, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def optional_get_element( + self: ivy.Container, + /, + *, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of ivy.optional_get_element. This method + simply wraps the function, and so the docstring for ivy.optional_get_element + also applies to this method with minimal changes. + + Parameters + ---------- + self + Input container + out + Optional output container, for writing the result to. + + Returns + ------- + ret + Output container. + """ + return self.static_optional_get_element(self, out=out) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index d3d18a0b4f498..a6915fb7a1ec9 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -392,6 +392,15 @@ def log2(self): def relu(self): return torch_frontend_nn.relu(self) + @with_unsupported_dtypes({"2.0.1 and below": ("bfloat16", "uint16")}, "torch") + def copy_(self, other, non_blocking=False): + ivy.utils.assertions.check_one_way_broadcastable( + self.ivy_array.shape, + torch_frontend.tensor(other).ivy_array.shape + ) + self._ivy_array = torch_frontend.tensor(other).ivy_array + return self + @numpy_to_torch_style_args @with_unsupported_dtypes({"2.0.1 and below": ("complex",)}, "torch") def amax(self, dim=None, keepdim=False): diff --git a/ivy/functional/ivy/experimental/utility.py b/ivy/functional/ivy/experimental/utility.py index e69de29bb2d1d..e33383151b01c 100644 --- a/ivy/functional/ivy/experimental/utility.py +++ b/ivy/functional/ivy/experimental/utility.py @@ -0,0 +1,41 @@ +# global +from typing import Optional + +# local +import ivy +from ivy import handle_out_argument, handle_nestable +from ivy.utils.exceptions import handle_exceptions + + +@handle_out_argument +@handle_nestable +@handle_exceptions +def optional_get_element( + x: Optional[ivy.Array] = None, + /, + *, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """ + If the input is a tensor or sequence type, it returns the input. If the input is an + optional type, it outputs the element in the input. It is an error if the input is + an empty optional-type (i.e. does not have an element) and the behavior is undefined + in this case. + + Parameters + ---------- + x + Input array + out + Optional output array, for writing the result to. + + Returns + ------- + ret + Input array if it is not None + """ + if x is None: + raise ivy.utils.exceptions.IvyError( + "The requested optional input has no value." + ) + return x diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 07a418658f32e..2f7d9eca67129 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -6042,6 +6042,44 @@ def test_torch_tensor_cumsum_( ) +# copy_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="copy_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + ), +) +def test_torch_tensor_copy_( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # sort @handle_frontend_method( class_tree=CLASS_TREE, diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_utility.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_utility.py index e69de29bb2d1d..8789c46a529eb 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_utility.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_utility.py @@ -0,0 +1,33 @@ +# global +from hypothesis import strategies as st + +# local +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_test + + +@handle_test( + fn_tree="functional.ivy.experimental.optional_get_element", + dtype_and_x=helpers.dtype_and_values(), + input_tensor=st.booleans(), +) +def test_optional_get_element( + *, + dtype_and_x, + input_tensor, + test_flags, + backend_fw, + fn_name, + on_device, +): + input_dtype, x = dtype_and_x + fn_input = x[0] if input_tensor else x + + helpers.test_function( + input_dtypes=input_dtype, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + x=fn_input, + )