Skip to content

Commit

Permalink
fix: added support for dilation in torch.nn.functional.conv_transpose2d (ivy-llc#27411)
Browse files Browse the repository at this point in the history

Co-authored-by: @AnnaTz
  • Loading branch information
vismaysur authored Nov 30, 2023
1 parent 99f184d commit 423ebed
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def _conv(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
return ret


# ToDo: add support for dilation > 1
# ToDo: add support for output_padding > padding
def _conv_transpose(
input,
weight,
Expand All @@ -54,10 +52,20 @@ def _conv_transpose(
weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1))
for i in range(dims):
weight = ivy.flip(weight, axis=i)
padding, output_padding = map(
lambda x: [x] * dims if isinstance(x, int) else x, [padding, output_padding]
padding, output_padding, stride, dilation = map(
lambda x: [x] * dims if isinstance(x, int) else x,
[padding, output_padding, stride, dilation],
)
pad_widths = [(weight.shape[i] - 1,) * 2 for i in range(dims)]

pad_widths = [
(
(weight.shape[i] - 1) * dilation[i]
+ max([output_padding[i] - padding[i], 0]),
)
* 2
for i in range(dims)
]

ret = ivy.conv_general_dilated(
input,
weight,
Expand All @@ -67,12 +75,17 @@ def _conv_transpose(
data_format="channel_first",
feature_group_count=groups,
x_dilations=stride,
dilations=dilation,
bias=bias,
)
unpad_slice = (slice(None),) * 2
for i in range(dims):
unpad_slice += (
slice(padding[i], ret.shape[2 + i] - padding[i] + output_padding[i], 1),
slice(
max([padding[i] - (dilation[i] // 2), padding[i], output_padding[i]]),
ret.shape[2 + i] - padding[i] + output_padding[i] + (dilation[i] // 2),
1,
),
)
ret = ret[unpad_slice]
return ret
Expand Down Expand Up @@ -208,16 +221,31 @@ def conv_transpose1d(
groups=1,
dilation=1,
):
return _conv_transpose(
input,
weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
if ivy.current_backend_str() in ["torch"]:
# this backend supports explicit padding, no need for conv_general_dilated
return ivy.conv_general_transpose(
input,
weight,
stride,
_get_transpose_pad(padding, output_padding, 1),
dims=1,
filter_format="channel_first",
data_format="channel_first",
dilations=dilation,
feature_group_count=groups,
bias=bias,
)
else:
return _conv_transpose(
input,
weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)


@with_unsupported_dtypes({"2.1.1 and below": ("float16", "bfloat16")}, "torch")
Expand Down Expand Up @@ -271,16 +299,31 @@ def conv_transpose3d(
groups=1,
dilation=1,
):
return _conv_transpose(
input,
weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
if ivy.current_backend_str() in ["torch"]:
# this backend supports explicit padding, no need for conv_general_dilated
return ivy.conv_general_transpose(
input,
weight,
stride,
_get_transpose_pad(padding, output_padding, 3),
dims=3,
filter_format="channel_first",
data_format="channel_first",
dilations=dilation,
feature_group_count=groups,
bias=bias,
)
else:
return _conv_transpose(
input,
weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)


@to_ivy_arrays_and_back
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,6 @@ def _x_and_filters(draw, dim: int = 2, transpose: bool = False, max_dilation=3):
)
padding = [padding] * dim if isinstance(padding, int) else padding
for i in range(len(output_padding)):
# ToDo: remove this when support for output_padding > padding is added
if dim != 2:
output_padding[i] = min(padding[i], output_padding[i])
m = min(fstrides[i], fdilations[i])
output_padding[i] = min(output_padding[i], m - 1)
if draw(st.booleans()):
Expand Down Expand Up @@ -364,7 +361,7 @@ def test_torch_conv3d(

@handle_frontend_test(
fn_tree="torch.nn.functional.conv_transpose1d",
dtype_vals=_x_and_filters(dim=1, transpose=True, max_dilation=1),
dtype_vals=_x_and_filters(dim=1, transpose=True),
)
def test_torch_conv_transpose1d(
*,
Expand Down Expand Up @@ -444,7 +441,7 @@ def test_torch_conv_transpose2d(

@handle_frontend_test(
fn_tree="torch.nn.functional.conv_transpose3d",
dtype_vals=_x_and_filters(dim=3, transpose=True, max_dilation=1),
dtype_vals=_x_and_filters(dim=3, transpose=True),
)
def test_torch_conv_transpose3d(
*,
Expand Down

0 comments on commit 423ebed

Please sign in to comment.