adding explicit padding in max_pool2d. also adding `allow_explicit_padding` in `arrays_for_pooling` helper (#10143)
sherry30 authored Jan 27, 2023
1 parent ba81802 commit 786abee
Showing 6 changed files with 69 additions and 16 deletions.
18 changes: 17 additions & 1 deletion ivy/functional/backends/jax/experimental/layers.py
@@ -9,6 +9,7 @@
import ivy
from ivy.functional.backends.jax import JaxArray
from ivy.functional.backends.jax.random import RNG
from ivy.functional.ivy.layers import _handle_padding


def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
@@ -34,7 +35,22 @@ def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
is_single_input = True

assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
y = jlax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)

# compute the padding manually when a "SAME"/"VALID" string is given
if isinstance(padding, str):
pad_int = [
_handle_padding(
inputs.shape[i + 1], strides[i + 1], window_shape[i], padding
)
for i in range(len(dims) - 2)
]
pad_list = [
(pad_int[i] // 2, pad_int[i] - pad_int[i] // 2) for i in range(len(pad_int))
]
pad_list = [(0, 0)] + pad_list + [(0, 0)]
else:
pad_list = [(0, 0)] + padding + [(0, 0)]
y = jlax.reduce_window(inputs, init, reduce_fn, dims, strides, pad_list)
if is_single_input:
y = jnp.squeeze(y, axis=0)
return y
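For context, here is a minimal standalone sketch of what a "SAME" helper like `_handle_padding` typically computes for one spatial dimension, and how the new JAX code splits the total into per-side amounts. The body of `handle_padding` below is an assumption based on the standard SAME formula, not copied from ivy.

```python
# A minimal sketch (assumption): the standard "SAME" formula a helper like
# ivy's _handle_padding typically implements for one spatial dimension.
def handle_padding(dim_size, stride, kernel, padding):
    if padding == "SAME":
        if dim_size % stride == 0:
            total_pad = max(kernel - stride, 0)
        else:
            total_pad = max(kernel - dim_size % stride, 0)
    else:  # "VALID"
        total_pad = 0
    return total_pad

# Split the total the same way the diff does: floor half before, the rest
# after, and no padding on the batch/channel dims.
pad_h = handle_padding(7, 2, 3, "SAME")              # -> 2
pad_list = [(0, 0), (pad_h // 2, pad_h - pad_h // 2), (0, 0)]
print(pad_list)                                      # [(0, 0), (1, 1), (0, 0)]
```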
11 changes: 7 additions & 4 deletions ivy/functional/backends/numpy/experimental/layers.py
@@ -90,14 +90,17 @@ def max_pool2d(
x = np.transpose(x, (0, 2, 3, 1))

x_shape = list(x.shape[1:3])
pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
pad_list = padding
if isinstance(padding, str):
pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
pad_list = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]

x = np.pad(
x,
[
(0, 0),
(pad_h // 2, pad_h - pad_h // 2),
(pad_w // 2, pad_w - pad_w // 2),
*pad_list,
(0, 0),
],
"edge",
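For illustration, a hedged sketch (shapes and values here are arbitrary, not from the PR) of what the `np.pad` call above expands to once an explicit `pad_list` such as `[(1, 1), (2, 2)]` is spliced between the batch and channel axes of an NHWC tensor:

```python
import numpy as np

# Illustrative only: splicing an explicit spatial pad_list into the
# per-axis pad spec of an NHWC tensor, as the backend does above.
x = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
pad_list = [(1, 1), (2, 2)]                      # (top, bottom), (left, right)
padded = np.pad(x, [(0, 0), *pad_list, (0, 0)], "edge")
print(padded.shape)                              # (2, 6, 8, 3)
```

Replicating edge values rather than padding with negative infinity leaves each pooling window's maximum unchanged, as long as every window still overlaps the original array.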
2 changes: 2 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/layers.py
@@ -38,6 +38,8 @@ def max_pool2d(
) -> Union[tf.Tensor, tf.Variable]:
if data_format == "NCHW":
x = tf.transpose(x, (0, 2, 3, 1))
if not isinstance(padding, str):
padding = [(0, 0)] + padding + [(0, 0)]
res = tf.nn.max_pool2d(x, kernel, strides, padding)
if data_format == "NCHW":
return tf.transpose(res, (0, 3, 1, 2))
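In recent TensorFlow versions, `tf.nn.max_pool2d` accepts explicit per-dimension (low, high) padding pairs directly, so the backend only needs to wrap the spatial pairs with `(0, 0)` for the batch and channel axes. A small illustrative sketch (shapes and pad values chosen arbitrarily):

```python
import tensorflow as tf

# Illustrative sketch: passing explicit per-dimension padding pairs to
# tf.nn.max_pool2d for an NHWC tensor, as the backend does above.
x = tf.random.uniform((1, 4, 4, 3))
padding = [(1, 1), (0, 0)]                       # explicit (H, W) padding
full_padding = [(0, 0)] + padding + [(0, 0)]     # [N, H, W, C] pairs
res = tf.nn.max_pool2d(x, ksize=2, strides=2, padding=full_padding)
print(res.shape)                                 # (1, 3, 2, 3)
```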
20 changes: 12 additions & 8 deletions ivy/functional/backends/torch/experimental/layers.py
@@ -78,18 +78,22 @@ def max_pool2d(
if data_format == "NHWC":
x = x.permute(0, 3, 1, 2)
x_shape = list(x.shape[2:])
pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)

if isinstance(padding, str):
pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
pad_list = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
else:
# torch pad takes width padding first, then height padding
padding = (padding[1], padding[0])
pad_list = [item for sublist in padding for item in sublist]

x = torch.nn.functional.pad(
x,
[pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
pad_list,
value=float("-inf"),
)
if padding != "VALID" and padding != "SAME":
raise ivy.exceptions.IvyException(
"Invalid padding arg {}\n"
'Must be one of: "VALID" or "SAME"'.format(padding)
)

res = torch.nn.functional.max_pool2d(x, kernel, strides, 0)
if data_format == "NHWC":
return res.permute(0, 2, 3, 1)
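`torch.nn.functional.pad` takes a flat list that pads the last dimension first, so for an NCHW tensor the width pair must precede the height pair; that is why the diff swaps `(padding[1], padding[0])` before flattening. A small standalone sketch (shapes and pad values are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

# Illustrative sketch: F.pad's flat list pads the *last* dimension first,
# so width padding comes before height padding for NCHW input.
x = torch.arange(16.0).reshape(1, 1, 4, 4)
padding = [(1, 1), (2, 2)]                            # explicit (H, W) pairs
w_pad, h_pad = padding[1], padding[0]
pad_list = [w_pad[0], w_pad[1], h_pad[0], h_pad[1]]   # [left, right, top, bottom]
padded = F.pad(x, pad_list, value=float("-inf"))      # -inf keeps maxima intact
print(padded.shape)                                   # torch.Size([1, 1, 6, 8])
res = F.max_pool2d(padded, kernel_size=2, stride=2)
print(res.shape)                                      # torch.Size([1, 1, 3, 4])
```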
29 changes: 27 additions & 2 deletions ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py
@@ -899,7 +899,9 @@ def array_and_broadcastable_shape(draw, dtype):


@st.composite
def arrays_for_pooling(draw, min_dims, max_dims, min_side, max_side):
def arrays_for_pooling(
draw, min_dims, max_dims, min_side, max_side, allow_explicit_padding=False
):
in_shape = draw(
nph.array_shapes(
min_dims=min_dims, max_dims=max_dims, min_side=min_side, max_side=max_side
@@ -929,6 +931,29 @@ def arrays_for_pooling(draw, min_dims, max_dims, min_side, max_side):
)
if array_dim == 3:
kernel = draw(st.tuples(st.integers(1, in_shape[1])))
padding = draw(st.sampled_from(["VALID", "SAME"]))
if allow_explicit_padding:
padding = []
for i in range(array_dim - 2):
max_pad = kernel[i] // 2
possible_pad_combos = [
(i, max_pad - i)
for i in range(0, max_pad)
if i + (max_pad - i) == max_pad
]
if len(possible_pad_combos) == 0:
pad_selected_combo = (0, 0)
else:
pad_selected_combo = draw(st.sampled_from(possible_pad_combos))
padding.append(
draw(
st.tuples(
st.integers(0, pad_selected_combo[0]),
st.integers(0, pad_selected_combo[1]),
)
)
)
padding = draw(st.one_of(st.just(padding), st.sampled_from(["VALID", "SAME"])))
else:
padding = draw(st.sampled_from(["VALID", "SAME"]))
strides = draw(st.tuples(st.integers(1, in_shape[1])))
return dtype, x, kernel, strides, padding
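To make the candidate-pair logic above concrete, here is a small standalone sketch of the (before, after) combinations it produces for a few kernel sizes. `pad_combos` is a hypothetical extraction of the list comprehension in the helper; note that the `if` condition is always true, and that the list is empty for kernels smaller than 2, which is why the helper falls back to `(0, 0)`.

```python
# Standalone sketch of the candidate (before, after) pairs drawn above.
def pad_combos(kernel_size):
    max_pad = kernel_size // 2
    return [(i, max_pad - i) for i in range(0, max_pad) if i + (max_pad - i) == max_pad]

print(pad_combos(1))  # []        -> helper uses (0, 0)
print(pad_combos(3))  # [(0, 1)]
print(pad_combos(5))  # [(0, 2), (1, 1)]
```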
@@ -9,8 +9,11 @@

@handle_test(
fn_tree="functional.ivy.experimental.max_pool2d",
x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4),
x_k_s_p=helpers.arrays_for_pooling(
min_dims=4, max_dims=4, min_side=1, max_side=4, allow_explicit_padding=True
),
test_gradients=st.just(False),
container_flags=st.just([False]),
)
def test_max_pool2d(
*,
