Implemented torch.split and torch.Tensor.split
fspyridakos committed Jan 27, 2023
1 parent 786abee commit f83877d
Showing 4 changed files with 112 additions and 0 deletions.
@@ -178,3 +178,20 @@ def take_along_dim(input, indices, dim, *, out=None):
@to_ivy_arrays_and_back
def vstack(tensors, *, out=None):
    return ivy.vstack(tensors, out=out)


@to_ivy_arrays_and_back
def split(tensor, split_size_or_sections, dim=0):
    if isinstance(split_size_or_sections, int):
        split_size = split_size_or_sections
        split_size_or_sections = [split_size] * (tensor.shape[dim] // split_size)
        if tensor.shape[dim] % split_size:
            split_size_or_sections.append(tensor.shape[dim] % split_size)
    return tuple(
        ivy.split(
            tensor,
            num_or_size_splits=split_size_or_sections,
            axis=dim,
            with_remainder=True,
        )
    )
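As a quick illustration of the integer-size branch above (a standalone sketch, not part of the committed diff; the lengths are made up): splitting an axis of length 10 with a split size of 3 yields sections [3, 3, 3, 1], i.e. equal chunks plus a smaller remainder, matching torch.split semantics.

# Standalone sketch of the size-to-sections conversion (illustrative values only).
axis_len = 10
split_size = 3
sections = [split_size] * (axis_len // split_size)  # [3, 3, 3]
if axis_len % split_size:
    sections.append(axis_len % split_size)  # -> [3, 3, 3, 1]
assert sections == [3, 3, 3, 1]
assert sum(sections) == axis_len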
3 changes: 3 additions & 0 deletions ivy/functional/frontends/torch/tensor.py
@@ -377,6 +377,9 @@ def unsqueeze_(self, dim):
        self._ivy_array = self.unsqueeze(dim).ivy_array
        return self

    def split(self, split_size, dim=0):
        return torch_frontend.split(self, split_size, dim)

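For context, a minimal usage sketch of the new method (hypothetical values; assumes the frontend's tensor constructor and ivy.set_backend, neither of which is part of this diff):

import ivy
import ivy.functional.frontends.torch as torch_frontend

ivy.set_backend("numpy")  # any installed backend should work here
t = torch_frontend.tensor([[1, 2, 3, 4, 5]])
chunks = t.split(2, dim=1)
# Expected: three pieces of widths 2, 2, and 1 along dim 1,
# mirroring torch.Tensor.split's remainder handling.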
    def dim(self):
        return self._ivy_array.ndim

@@ -5,6 +5,7 @@
# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits


# noinspection DuplicatedCode
@@ -843,3 +844,42 @@ def test_torch_vstack(
        on_device=on_device,
        tensors=value,
    )


# split
@handle_frontend_test(
fn_tree="torch.split",
dtype_value=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"),
),
split_size_or_sections=_get_splits().filter(lambda s: s is not None),
dim=st.shared(
helpers.get_axis(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"),
force_int=True,
),
key="target_axis",
),
)
def test_torch_split(
    *,
    dtype_value,
    split_size_or_sections,
    dim,
    on_device,
    fn_tree,
    frontend,
    test_flags,
):
    input_dtype, value = dtype_value
    helpers.test_frontend_function(
        input_dtypes=input_dtype,
        frontend=frontend,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        tensor=value[0],
        split_size_or_sections=split_size_or_sections,
        dim=dim,
    )
52 changes: 52 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
@@ -13,6 +13,7 @@
from ivy.functional.frontends.torch import Tensor
import ivy_tests.test_ivy.helpers.test_parameter_flags as pf
from ivy_tests.test_ivy.helpers import handle_frontend_method
from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits
from ivy_tests.test_ivy.test_functional.test_core.test_searching import (
    _broadcastable_trio,
)
@@ -2213,6 +2214,57 @@ def test_torch_instance_unsqueeze_(
    )


# split
@handle_frontend_method(
    class_tree=CLASS_TREE,
    init_tree="torch.tensor",
    method_name="split",
    dtype_value=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("valid"),
        shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"),
    ),
    split_size=_get_splits().filter(lambda s: s is not None),
    dim=st.shared(
        helpers.get_axis(
            shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"),
            force_int=True,
        ),
        key="target_axis",
    ),
)
def test_torch_instance_split(
    dtype_value,
    split_size,
    dim,
    init_num_positional_args: pf.NumPositionalArgFn,
    method_num_positional_args: pf.NumPositionalArgMethod,
    as_variable: pf.AsVariableFlags,
    native_array: pf.NativeArrayFlags,
    frontend_method_data,
    frontend,
):
    input_dtype, x = dtype_value
    helpers.test_frontend_method(
        init_input_dtypes=input_dtype,
        init_as_variable_flags=as_variable,
        init_num_positional_args=init_num_positional_args,
        init_native_array_flags=native_array,
        init_all_as_kwargs_np={
            "data": x[0],
        },
        method_input_dtypes=input_dtype,
        method_as_variable_flags=as_variable,
        method_num_positional_args=method_num_positional_args,
        method_native_array_flags=native_array,
        method_all_as_kwargs_np={
            "split_size": split_size,
            "dim": dim,
        },
        frontend_method_data=frontend_method_data,
        frontend=frontend,
    )


# detach
@handle_frontend_method(
    class_tree=CLASS_TREE,
