Skip to content

Commit

Permalink
frontends.torch.avg_pool2d: initial commit (ivy-llc#8854)
Browse files Browse the repository at this point in the history
  • Loading branch information
hmahmood24 committed Jan 19, 2023
1 parent 9118467 commit bbb0390
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
68 changes: 68 additions & 0 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import ivy
from ivy.functional.frontends.tensorflow.func_wrapper import (
to_ivy_arrays_and_back,
)


def _broadcast_pooling_helper(x, pool_dims: str = "2d", name: str = "padding"):
dims = {"1d": 1, "2d": 2, "3d": 3}

if isinstance(x, int):
return tuple([x for _ in range(dims[pool_dims])])

if len(x) == 1:
return tuple([x[0] for _ in range(dims[pool_dims])])
elif len(x) == dims[pool_dims]:
return tuple(x)
elif len(x) != dims[pool_dims]:
raise ValueError(
f"`{name}` must either be a single int, "
f"or a tuple of {dims[pool_dims]} ints. "
)


@to_ivy_arrays_and_back
def avg_pool2d(
input,
kernel_size,
stride=None,
padding=0,
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
# Figure out input dims N
input_rank = input.ndim

if input_rank == 3:
# CHW
data_format = "CHW"
elif input_rank == 4:
# NCHW
data_format = "NCHW"

kernel_size = _broadcast_pooling_helper(kernel_size, "2d", name="kernel_size")
stride = _broadcast_pooling_helper(stride, "2d", name="stride")
padding = _broadcast_pooling_helper(padding, "2d", name="padding")
kernel_pads = list(zip(kernel_size, padding))

# Padding should be less than or equal to half of kernel size
if not all([pad <= kernel / 2 for kernel, pad in kernel_pads]):
raise ValueError(
"pad should be smaller than or equal to half of kernel size, "
f"but got padding={padding}, kernel_size={kernel_size}. "
)

# Figure out padding string
if all([pad == ivy.ceil((kernel - 1) / 2) for kernel, pad in kernel_pads]):
padding_str = "SAME"
else:
padding_str = "VALID"

return ivy.avg_pool2d(
input,
kernel_size,
stride,
padding_str,
data_format=data_format,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# global
import ivy

# local
import ivy_tests.test_ivy.helpers as helpers

from ivy_tests.test_ivy.helpers import handle_frontend_test


# avg_pool2d
@handle_frontend_test(
fn_tree="torch.nn.functional.avg_pool2d",
dtype_x_k_s=helpers.arrays_for_pooling(
min_dims=4,
max_dims=4,
min_side=1,
max_side=4,
),
)
def test_torch_avg_pool2d(
dtype_x_k_s,
*,
as_variable,
num_positional_args,
native_array,
frontend,
fn_tree,
on_device,
):
input_dtype, x, kernel_size, stride, padding = dtype_x_k_s

# Torch ground truth func expects input to be consistent
# with a channels first format i.e. NCHW
x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], *x[0].shape[1:-1]))
x_shape = list(x[0].shape[2:])

# Torch ground truth func also takes padding input as an integer
# or a tuple of integers, not a string
padding = tuple(
[
ivy.handle_padding(x_shape[i], stride[0], kernel_size[i], padding)
for i in range(len(x_shape))
]
)

helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
input=x[0],
kernel_size=kernel_size,
stride=stride,
padding=padding,
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
)

0 comments on commit bbb0390

Please sign in to comment.