forked from ivy-llc/ivy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
frontends.torch.avg_pool2d: initial commit (ivy-llc#8854)
- Loading branch information
1 parent
9118467
commit bbb0390
Showing
2 changed files
with
130 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import ivy | ||
from ivy.functional.frontends.tensorflow.func_wrapper import ( | ||
to_ivy_arrays_and_back, | ||
) | ||
|
||
|
||
def _broadcast_pooling_helper(x, pool_dims: str = "2d", name: str = "padding"): | ||
dims = {"1d": 1, "2d": 2, "3d": 3} | ||
|
||
if isinstance(x, int): | ||
return tuple([x for _ in range(dims[pool_dims])]) | ||
|
||
if len(x) == 1: | ||
return tuple([x[0] for _ in range(dims[pool_dims])]) | ||
elif len(x) == dims[pool_dims]: | ||
return tuple(x) | ||
elif len(x) != dims[pool_dims]: | ||
raise ValueError( | ||
f"`{name}` must either be a single int, " | ||
f"or a tuple of {dims[pool_dims]} ints. " | ||
) | ||
|
||
|
||
@to_ivy_arrays_and_back | ||
def avg_pool2d( | ||
input, | ||
kernel_size, | ||
stride=None, | ||
padding=0, | ||
ceil_mode=False, | ||
count_include_pad=True, | ||
divisor_override=None, | ||
): | ||
# Figure out input dims N | ||
input_rank = input.ndim | ||
|
||
if input_rank == 3: | ||
# CHW | ||
data_format = "CHW" | ||
elif input_rank == 4: | ||
# NCHW | ||
data_format = "NCHW" | ||
|
||
kernel_size = _broadcast_pooling_helper(kernel_size, "2d", name="kernel_size") | ||
stride = _broadcast_pooling_helper(stride, "2d", name="stride") | ||
padding = _broadcast_pooling_helper(padding, "2d", name="padding") | ||
kernel_pads = list(zip(kernel_size, padding)) | ||
|
||
# Padding should be less than or equal to half of kernel size | ||
if not all([pad <= kernel / 2 for kernel, pad in kernel_pads]): | ||
raise ValueError( | ||
"pad should be smaller than or equal to half of kernel size, " | ||
f"but got padding={padding}, kernel_size={kernel_size}. " | ||
) | ||
|
||
# Figure out padding string | ||
if all([pad == ivy.ceil((kernel - 1) / 2) for kernel, pad in kernel_pads]): | ||
padding_str = "SAME" | ||
else: | ||
padding_str = "VALID" | ||
|
||
return ivy.avg_pool2d( | ||
input, | ||
kernel_size, | ||
stride, | ||
padding_str, | ||
data_format=data_format, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# global | ||
import ivy | ||
|
||
# local | ||
import ivy_tests.test_ivy.helpers as helpers | ||
|
||
from ivy_tests.test_ivy.helpers import handle_frontend_test | ||
|
||
|
||
# avg_pool2d | ||
@handle_frontend_test( | ||
fn_tree="torch.nn.functional.avg_pool2d", | ||
dtype_x_k_s=helpers.arrays_for_pooling( | ||
min_dims=4, | ||
max_dims=4, | ||
min_side=1, | ||
max_side=4, | ||
), | ||
) | ||
def test_torch_avg_pool2d( | ||
dtype_x_k_s, | ||
*, | ||
as_variable, | ||
num_positional_args, | ||
native_array, | ||
frontend, | ||
fn_tree, | ||
on_device, | ||
): | ||
input_dtype, x, kernel_size, stride, padding = dtype_x_k_s | ||
|
||
# Torch ground truth func expects input to be consistent | ||
# with a channels first format i.e. NCHW | ||
x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], *x[0].shape[1:-1])) | ||
x_shape = list(x[0].shape[2:]) | ||
|
||
# Torch ground truth func also takes padding input as an integer | ||
# or a tuple of integers, not a string | ||
padding = tuple( | ||
[ | ||
ivy.handle_padding(x_shape[i], stride[0], kernel_size[i], padding) | ||
for i in range(len(x_shape)) | ||
] | ||
) | ||
|
||
helpers.test_frontend_function( | ||
input_dtypes=input_dtype, | ||
as_variable_flags=as_variable, | ||
with_out=False, | ||
num_positional_args=num_positional_args, | ||
native_array_flags=native_array, | ||
frontend=frontend, | ||
fn_tree=fn_tree, | ||
on_device=on_device, | ||
input=x[0], | ||
kernel_size=kernel_size, | ||
stride=stride, | ||
padding=padding, | ||
ceil_mode=False, | ||
count_include_pad=True, | ||
divisor_override=None, | ||
) |