diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py index 5f7cbff10383e..17556ac01a22f 100644 --- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py @@ -237,7 +237,13 @@ def max_pool1d( stride = kernel_size if not isinstance(padding, int): padding = [(pad, pad) for pad in padding] - return ivy.max_pool1d( + if input.ndim == 2: + without_batch_dim = True + input = ivy.expand_dims(input, axis=0) + else: + without_batch_dim = False + + ret = ivy.max_pool1d( input, kernel_size, stride, @@ -246,6 +252,9 @@ def max_pool1d( dilation=dilation, ceil_mode=ceil_mode, ) + if without_batch_dim: + ret = ret[0] + return ret @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -263,7 +272,13 @@ def max_pool2d( stride = kernel_size if not isinstance(padding, int): padding = [(pad, pad) for pad in padding] - return ivy.max_pool2d( + if input.ndim == 3: + without_batch_dim = True + input = ivy.expand_dims(input, axis=0) + else: + without_batch_dim = False + + ret = ivy.max_pool2d( input, kernel_size, stride, @@ -272,6 +287,9 @@ def max_pool2d( dilation=dilation, ceil_mode=ceil_mode, ) + if without_batch_dim: + ret = ret[0] + return ret @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -306,5 +324,4 @@ def max_pool3d( ) if without_batch_dim: ret = ret[0] - return ret