Skip to content

Commit

Permalink
Add layer conv1d_transpose for paddle backend (ivy-llc#20899)
Browse files Browse the repository at this point in the history
  • Loading branch information
yopknopixx authored and Sarvesh-Kesharwani committed Aug 10, 2023
1 parent 1ebe408 commit 8fe165d
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions ivy/functional/backends/paddle/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def conv1d(
raise IvyNotImplementedException()


# noinspection PyUnresolvedReferences
@with_unsupported_device_and_dtypes(
{"2.5.0 and below": {"cpu": ("float16", "bfloat16")}},
backend_version,
)
def conv1d_transpose(
x: paddle.Tensor,
filters: paddle.Tensor,
Expand All @@ -170,7 +173,28 @@ def conv1d_transpose(
bias: Optional[paddle.Tensor] = None,
out: Optional[paddle.Tensor] = None,
):
raise IvyNotImplementedException()
if data_format == "NWC":
x = x.transpose([0, 2, 1])
strides = [strides] if isinstance(strides, int) else strides
dilations = [dilations] if isinstance(dilations, int) else dilations
filters = filters.transpose([1, 2, 0])
not_valid_pad, padding_list, output_padding = _pad_before_conv_tranpose(
x, filters, strides, padding, 1, dilations, output_shape, filters.shape[2:]
)
res = paddle.nn.functional.conv1d_transpose(
x,
filters,
stride=strides,
padding=padding_list,
output_padding=output_padding,
dilation=dilations,
data_format="NCL",
)
if not_valid_pad[0]:
res = res[:, :, 0:-1]
if data_format == "NWC":
res = res.transpose([0, 2, 1])
return res


# noinspection PyUnresolvedReferences
Expand Down

0 comments on commit 8fe165d

Please sign in to comment.