diff --git a/ivy/functional/backends/jax/experimental/layers.py b/ivy/functional/backends/jax/experimental/layers.py
index bd383c0af823f..c20c5c1847df9 100644
--- a/ivy/functional/backends/jax/experimental/layers.py
+++ b/ivy/functional/backends/jax/experimental/layers.py
@@ -9,6 +9,7 @@
 import ivy
 from ivy.functional.backends.jax import JaxArray
 from ivy.functional.backends.jax.random import RNG
+from ivy.functional.ivy.layers import _handle_padding


 def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
@@ -34,7 +35,22 @@ def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
         is_single_input = True

     assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
-    y = jlax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
+
+    # doing manual padding instead of passing the padding string to reduce_window
+    if isinstance(padding, str):
+        pad_int = [
+            _handle_padding(
+                inputs.shape[i + 1], strides[i + 1], window_shape[i], padding
+            )
+            for i in range(len(dims) - 2)
+        ]
+        pad_list = [
+            (pad_int[i] // 2, pad_int[i] - pad_int[i] // 2) for i in range(len(pad_int))
+        ]
+        pad_list = [(0, 0)] + pad_list + [(0, 0)]
+    else:
+        pad_list = [(0, 0)] + padding + [(0, 0)]
+    y = jlax.reduce_window(inputs, init, reduce_fn, dims, strides, pad_list)
     if is_single_input:
         y = jnp.squeeze(y, axis=0)
     return y
diff --git a/ivy/functional/backends/numpy/experimental/layers.py b/ivy/functional/backends/numpy/experimental/layers.py
index 805bed0a46bea..0860e978038fb 100644
--- a/ivy/functional/backends/numpy/experimental/layers.py
+++ b/ivy/functional/backends/numpy/experimental/layers.py
@@ -90,14 +90,17 @@ def max_pool2d(
         x = np.transpose(x, (0, 2, 3, 1))

     x_shape = list(x.shape[1:3])
-    pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
-    pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
+    pad_list = padding
+    if isinstance(padding, str):
+        pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
+        pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
+        pad_list = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]
+
     x = np.pad(
         x,
         [
             (0, 0),
-            (pad_h // 2, pad_h - pad_h // 2),
-            (pad_w // 2, pad_w - pad_w // 2),
+            *pad_list,
             (0, 0),
         ],
         "edge",
diff --git a/ivy/functional/backends/tensorflow/experimental/layers.py b/ivy/functional/backends/tensorflow/experimental/layers.py
index 0c0885d46bd3b..70b6f806b379c 100644
--- a/ivy/functional/backends/tensorflow/experimental/layers.py
+++ b/ivy/functional/backends/tensorflow/experimental/layers.py
@@ -38,6 +38,8 @@ def max_pool2d(
 ) -> Union[tf.Tensor, tf.Variable]:
     if data_format == "NCHW":
         x = tf.transpose(x, (0, 2, 3, 1))
+    if not isinstance(padding, str):
+        padding = [(0, 0)] + padding + [(0, 0)]
     res = tf.nn.max_pool2d(x, kernel, strides, padding)
     if data_format == "NCHW":
         return tf.transpose(res, (0, 3, 1, 2))
diff --git a/ivy/functional/backends/torch/experimental/layers.py b/ivy/functional/backends/torch/experimental/layers.py
index 5a9f81d95f072..1840c4ff12f5f 100644
--- a/ivy/functional/backends/torch/experimental/layers.py
+++ b/ivy/functional/backends/torch/experimental/layers.py
@@ -78,18 +78,22 @@ def max_pool2d(
     if data_format == "NHWC":
         x = x.permute(0, 3, 1, 2)
     x_shape = list(x.shape[2:])
-    pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
-    pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
+
+    if isinstance(padding, str):
+        pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
+        pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
+        pad_list = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+    else:
+        # torch pad takes width padding first, then height padding
+        padding = (padding[1], padding[0])
+        pad_list = [item for sublist in padding for item in sublist]
+
     x = torch.nn.functional.pad(
         x,
-        [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
+        pad_list,
         value=float("-inf"),
     )
-    if padding != "VALID" and padding != "SAME":
-        raise ivy.exceptions.IvyException(
-            "Invalid padding arg {}\n"
-            'Must be one of: "VALID" or "SAME"'.format(padding)
-        )
+
     res = torch.nn.functional.max_pool2d(x, kernel, strides, 0)
     if data_format == "NHWC":
         return res.permute(0, 2, 3, 1)
diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py
index fa267e4218102..1c1588fefd618 100644
--- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py
+++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py
@@ -899,7 +899,9 @@ def array_and_broadcastable_shape(draw, dtype):


 @st.composite
-def arrays_for_pooling(draw, min_dims, max_dims, min_side, max_side):
+def arrays_for_pooling(
+    draw, min_dims, max_dims, min_side, max_side, allow_explicit_padding=False
+):
     in_shape = draw(
         nph.array_shapes(
             min_dims=min_dims, max_dims=max_dims, min_side=min_side, max_side=max_side
@@ -929,6 +931,29 @@ def arrays_for_pooling(draw, min_dims, max_dims, min_side, max_side):
         )
     if array_dim == 3:
         kernel = draw(st.tuples(st.integers(1, in_shape[1])))
-    padding = draw(st.sampled_from(["VALID", "SAME"]))
+    if allow_explicit_padding:
+        padding = []
+        for i in range(array_dim - 2):
+            max_pad = kernel[i] // 2
+            possible_pad_combos = [
+                (i, max_pad - i)
+                for i in range(0, max_pad)
+                if i + (max_pad - i) == max_pad
+            ]
+            if len(possible_pad_combos) == 0:
+                pad_selected_combo = (0, 0)
+            else:
+                pad_selected_combo = draw(st.sampled_from(possible_pad_combos))
+            padding.append(
+                draw(
+                    st.tuples(
+                        st.integers(0, pad_selected_combo[0]),
+                        st.integers(0, pad_selected_combo[1]),
+                    )
+                )
+            )
+        padding = draw(st.one_of(st.just(padding), st.sampled_from(["VALID", "SAME"])))
+    else:
+        padding = draw(st.sampled_from(["VALID", "SAME"]))
     strides = draw(st.tuples(st.integers(1, in_shape[1])))
     return dtype, x, kernel, strides, padding
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
index e6aa141adb206..f4fb66a06473a 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
@@ -9,8 +9,11 @@
 @handle_test(
     fn_tree="functional.ivy.experimental.max_pool2d",
-    x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4),
+    x_k_s_p=helpers.arrays_for_pooling(
+        min_dims=4, max_dims=4, min_side=1, max_side=4, allow_explicit_padding=True
+    ),
     test_gradients=st.just(False),
+    container_flags=st.just([False]),
 )
 def test_max_pool2d(
     *,
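For reference, a minimal standalone sketch of the padding normalization the backend changes above share: a "SAME"/"VALID" string is converted into per-dimension (pad_before, pad_after) pairs, while an explicit list of pairs is forwarded as-is. The helper names below (same_pad_total, normalize_padding) are illustrative only, not ivy's actual _handle_padding.

# Illustrative sketch only -- not ivy code. Mirrors how the backends turn a
# padding mode string into explicit (before, after) pairs per spatial dim.
import math


def same_pad_total(size, stride, kernel):
    # Total padding needed along one spatial dimension for "SAME"-style output.
    return max((math.ceil(size / stride) - 1) * stride + kernel - size, 0)


def normalize_padding(padding, spatial_shape, strides, kernel):
    if isinstance(padding, str):
        if padding == "VALID":
            return [(0, 0)] * len(spatial_shape)
        totals = [
            same_pad_total(s, st, k)
            for s, st, k in zip(spatial_shape, strides, kernel)
        ]
        # Same split the backends use: floor half before, remainder after.
        return [(t // 2, t - t // 2) for t in totals]
    return list(padding)  # explicit (before, after) pairs are used as-is


# e.g. a 5x5 image, 3x3 kernel, stride 2:
print(normalize_padding("SAME", [5, 5], [2, 2], [3, 3]))            # [(1, 1), (1, 1)]
print(normalize_padding([(0, 1), (2, 0)], [5, 5], [2, 2], [3, 3]))  # [(0, 1), (2, 0)]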