From c1f98b18c9969bb378713c59b74f6fa5cafabcca Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:56:51 +0000 Subject: [PATCH] fix: torch frontend max pooling to support optional batch dim (#28490) --- .../torch/nn/functional/pooling_functions.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py index 5f7cbff10383e..17556ac01a22f 100644 --- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py @@ -237,7 +237,13 @@ def max_pool1d( stride = kernel_size if not isinstance(padding, int): padding = [(pad, pad) for pad in padding] - return ivy.max_pool1d( + if input.ndim == 2: + without_batch_dim = True + input = ivy.expand_dims(input, axis=0) + else: + without_batch_dim = False + + ret = ivy.max_pool1d( input, kernel_size, stride, @@ -246,6 +252,9 @@ def max_pool1d( dilation=dilation, ceil_mode=ceil_mode, ) + if without_batch_dim: + ret = ret[0] + return ret @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -263,7 +272,13 @@ def max_pool2d( stride = kernel_size if not isinstance(padding, int): padding = [(pad, pad) for pad in padding] - return ivy.max_pool2d( + if input.ndim == 3: + without_batch_dim = True + input = ivy.expand_dims(input, axis=0) + else: + without_batch_dim = False + + ret = ivy.max_pool2d( input, kernel_size, stride, @@ -272,6 +287,9 @@ def max_pool2d( dilation=dilation, ceil_mode=ceil_mode, ) + if without_batch_dim: + ret = ret[0] + return ret @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -306,5 +324,4 @@ def max_pool3d( ) if without_batch_dim: ret = ret[0] - return ret