Skip to content

Commit

Permalink
fix: support return_indices=True in torch max_pool2d frontend (#28537)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Mar 11, 2024
1 parent 2301f2a commit feee5e9
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
65 changes: 59 additions & 6 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# local
import ivy
import ivy.functional.frontends.torch as torch_frontend
from ivy import with_unsupported_dtypes
from ivy.functional.frontends.torch.func_wrapper import (
to_ivy_arrays_and_back,
Expand Down Expand Up @@ -270,26 +271,78 @@ def max_pool2d(
):
if not stride:
stride = kernel_size
if not isinstance(padding, int):
padding = [(pad, pad) for pad in padding]
if input.ndim == 3:
without_batch_dim = True
input = ivy.expand_dims(input, axis=0)
else:
without_batch_dim = False

ret = ivy.max_pool2d(
output = ivy.max_pool2d(
input,
kernel_size,
stride,
padding,
([(pad, pad) for pad in padding] if not isinstance(padding, int) else padding),
data_format="NCHW",
dilation=dilation,
ceil_mode=ceil_mode,
)

if return_indices:
if isinstance(stride, (list, tuple)) and len(stride) == 1:
stride = stride[0]

in_shape = input.shape
H = in_shape[-2]
W = in_shape[-1]
n_indices = H * W

# calculate the indices within the input tensor
# for each position in the sliding window
input_indices = torch_frontend.arange(0, n_indices, dtype=torch_frontend.int64)
input_indices = input_indices.reshape((1, 1, H, W))
unfolded_indices = torch_frontend.nn.functional.unfold(
input_indices,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
stride=stride,
).permute((0, 2, 1))[0]

# find the indices of the max value for each position of the sliding window
input = torch_frontend.nn.functional.pad(
input,
[padding] * 4 if isinstance(padding, int) else padding * 2,
value=float("-inf"),
)
unfolded_values = torch_frontend.nn.functional.unfold(
input, kernel_size=kernel_size, padding=0, dilation=dilation, stride=stride
)
unfolded_values_shape = unfolded_values.shape
unfolded_indices = unfolded_indices.repeat(
unfolded_values_shape[0], unfolded_values_shape[1], 1, 1
)
unfolded_values = unfolded_values.reshape(
input.shape[0],
input.shape[1],
unfolded_values.shape[1] // input.shape[1],
unfolded_values.shape[2],
)
indices = torch_frontend.argmax(unfolded_values, dim=2)

# gather the indices within the input tensor of the max values
indices = torch_frontend.gather(
unfolded_indices, -1, torch_frontend.unsqueeze(indices, -1)
)
indices = indices.reshape(output.shape)

if without_batch_dim:
ret = ret[0]
return ret
output = output[0]
if return_indices:
indices = indices[0]

if return_indices:
return output, indices
return output


@with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,12 @@ def test_torch_max_pool1d(
),
test_with_out=st.just(False),
ceil_mode=st.booleans(),
return_indices=st.booleans(),
)
def test_torch_max_pool2d(
x_k_s_p,
ceil_mode,
return_indices,
*,
test_flags,
frontend,
Expand All @@ -506,6 +508,7 @@ def test_torch_max_pool2d(
padding=padding,
dilation=dilation,
ceil_mode=ceil_mode,
return_indices=return_indices,
)


Expand Down

0 comments on commit feee5e9

Please sign in to comment.