From feee5e917da3e89e185293e80895915a8adb1147 Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Mon, 11 Mar 2024 04:26:36 +0000 Subject: [PATCH] fix: support return_indices=True in torch max_pool2d frontend (#28537) --- .../torch/nn/functional/pooling_functions.py | 65 +++++++++++++++++-- .../test_functional/test_pooling_functions.py | 3 + 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py index 17556ac01a22f..c442c55afc917 100644 --- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py @@ -3,6 +3,7 @@ # local import ivy +import ivy.functional.frontends.torch as torch_frontend from ivy import with_unsupported_dtypes from ivy.functional.frontends.torch.func_wrapper import ( to_ivy_arrays_and_back, @@ -270,26 +271,78 @@ def max_pool2d( ): if not stride: stride = kernel_size - if not isinstance(padding, int): - padding = [(pad, pad) for pad in padding] if input.ndim == 3: without_batch_dim = True input = ivy.expand_dims(input, axis=0) else: without_batch_dim = False - ret = ivy.max_pool2d( + output = ivy.max_pool2d( input, kernel_size, stride, - padding, + ([(pad, pad) for pad in padding] if not isinstance(padding, int) else padding), data_format="NCHW", dilation=dilation, ceil_mode=ceil_mode, ) + + if return_indices: + if isinstance(stride, (list, tuple)) and len(stride) == 1: + stride = stride[0] + + in_shape = input.shape + H = in_shape[-2] + W = in_shape[-1] + n_indices = H * W + + # calculate the indices within the input tensor + # for each position in the sliding window + input_indices = torch_frontend.arange(0, n_indices, dtype=torch_frontend.int64) + input_indices = input_indices.reshape((1, 1, H, W)) + unfolded_indices = torch_frontend.nn.functional.unfold( + input_indices, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + stride=stride, + ).permute((0, 2, 1))[0] + + # find the indices of the max value for each position of the sliding window + input = torch_frontend.nn.functional.pad( + input, + [padding] * 4 if isinstance(padding, int) else padding * 2, + value=float("-inf"), + ) + unfolded_values = torch_frontend.nn.functional.unfold( + input, kernel_size=kernel_size, padding=0, dilation=dilation, stride=stride + ) + unfolded_values_shape = unfolded_values.shape + unfolded_indices = unfolded_indices.repeat( + unfolded_values_shape[0], unfolded_values_shape[1], 1, 1 + ) + unfolded_values = unfolded_values.reshape( + input.shape[0], + input.shape[1], + unfolded_values.shape[1] // input.shape[1], + unfolded_values.shape[2], + ) + indices = torch_frontend.argmax(unfolded_values, dim=2) + + # gather the indices within the input tensor of the max values + indices = torch_frontend.gather( + unfolded_indices, -1, torch_frontend.unsqueeze(indices, -1) + ) + indices = indices.reshape(output.shape) + if without_batch_dim: - ret = ret[0] - return ret + output = output[0] + if return_indices: + indices = indices[0] + + if return_indices: + return output, indices + return output @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py index a4c3e9bf7de99..665e54d9156c5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py @@ -479,10 +479,12 @@ def test_torch_max_pool1d( ), test_with_out=st.just(False), ceil_mode=st.booleans(), + return_indices=st.booleans(), ) def test_torch_max_pool2d( x_k_s_p, ceil_mode, + return_indices, *, test_flags, frontend, @@ -506,6 +508,7 @@ def test_torch_max_pool2d( padding=padding, dilation=dilation, ceil_mode=ceil_mode, + return_indices=return_indices, )