Skip to content

Commit

Permalink
Removed x_dilations assertion from ivy.conv as it's unused, added missing tests for ivy.conv, updated the x_and_filters helper to not generate x_dilations for transposed convolutions (#22000)
Browse files Browse the repository at this point in the history
  • Loading branch information
vedpatwardhan authored Aug 16, 2023
1 parent 9bee42a commit 83c471f
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 28 deletions.
7 changes: 4 additions & 3 deletions ivy/functional/ivy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,10 @@ def scaled_dot_product_attention(
... b=ivy.array([[[3.2, 1.], [2.2, 3.6], [4.0, 5.6]]]))
>>> v = ivy.Container(a=ivy.array([[[5.2, 1.], [2.1, 3.], [4.4, 5.6]]]),
... b=ivy.array([[[0.2, 1.], [2.2, 3.], [4.4, 5.6]]]))
>>> mask = ivy.Container(a=ivy.array([[[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]]),
... b=ivy.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0,1.0]]]))
>>> mask = ivy.Container(
... a=ivy.array([[[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]]),
... b=ivy.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0,1.0]]])
... )
>>> result = ivy.scaled_dot_product_attention(q,k,v,scale=1,mask=mask)
>>> print(result)
{
Expand Down Expand Up @@ -2053,7 +2055,6 @@ def conv(
The result of the transpose or dilated convolution operation.
"""
if transpose:
assert x_dilations == 1, "x_dilations must be 1 for transpose convolutions."
return conv_general_transpose(
x,
filters,
Expand Down
228 changes: 203 additions & 25 deletions ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,14 +644,14 @@ def x_and_filters(
)
if general:
data_format = "channel_first" if channel_first else "channel_last"

x_dilation = draw(
st.one_of(
st.integers(1, 3),
st.lists(st.integers(1, 3), min_size=dim, max_size=dim),
if not transpose:
x_dilation = draw(
st.one_of(
st.integers(1, 3),
st.lists(st.integers(1, 3), min_size=dim, max_size=dim),
)
)
)
dilations = (dilations, x_dilation)
dilations = (dilations, x_dilation)
if filter_format is not None:
filter_format = draw(filter_format)
if filter_format == "channel_first":
Expand Down Expand Up @@ -694,9 +694,18 @@ def _assume_tf_dilation_gt_1(backend_fw, on_device, dilations):
ground_truth_backend="jax",
)
def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
x_f_d_df
)
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
fc,
ff_format,
bias,
) = x_f_d_df
# ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
Expand Down Expand Up @@ -730,9 +739,18 @@ def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
x_f_d_df
)
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
output_shape,
fc,
bias,
) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
input_dtypes=dtype,
Expand Down Expand Up @@ -765,9 +783,18 @@ def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic
ground_truth_backend="jax",
)
def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
x_f_d_df
)
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
fc,
ff_format,
bias,
) = x_f_d_df
# ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
Expand Down Expand Up @@ -802,9 +829,18 @@ def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv2d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
x_f_d_df
)
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
output_shape,
fc,
bias,
) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])

helpers.test_function(
Expand Down Expand Up @@ -870,9 +906,18 @@ def test_depthwise_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic
ground_truth_backend="jax",
)
def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
x_f_d_df
)
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
fc,
ff_format,
bias,
) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
input_dtypes=dtype,
Expand Down Expand Up @@ -905,9 +950,18 @@ def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
x_f_d_df
)
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
output_shape,
fc,
bias,
) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
input_dtypes=dtype,
Expand Down Expand Up @@ -1026,6 +1080,130 @@ def test_conv_general_transpose(
)


# filter_format not in conv_general_transpose
# output_shape not in conv_general_dilated
@st.composite
def x_and_filters_and_transpose(
    draw,
    dim: int = 2,
    general=False,
    bias=False,
    filter_format=None,
):
    """Draw arguments for a conv call, randomly choosing between the
    regular and transposed variants.

    A boolean ``transpose`` flag is drawn first; the underlying
    ``x_and_filters`` helper is then invoked with that flag, and its
    result tuple is normalised into a single common layout so the test
    body does not need to care which variant was drawn.

    Note: ``filter_format`` is only drawn for the non-transposed case
    (conv_general_transpose has no filter_format argument); transposed
    draws always report "channel_last".
    """
    is_transpose = draw(st.booleans())
    if not is_transpose:
        # Only the non-transposed helper accepts a filter_format strategy.
        filter_format = st.sampled_from(["channel_last", "channel_first"])
    drawn = draw(
        x_and_filters(
            dim=dim,
            general=general,
            bias=bias,
            filter_format=filter_format,
            transpose=is_transpose,
        )
    )
    # The first seven entries are shared between both helper variants.
    dtype, x, filters, dilations, data_format, stride, pad = drawn[:7]
    if is_transpose:
        # Transposed helper tail: output_shape, feature-group count, bias.
        output_shape, fc, bias = drawn[7:]
        ff_format = "channel_last"
    else:
        # Non-transposed helper tail: feature-group count, filter_format, bias.
        fc, ff_format, bias = drawn[7:]
        output_shape = None
    return (
        dtype,
        x,
        filters,
        stride,
        pad,
        is_transpose,
        output_shape,
        data_format,
        ff_format,
        fc,
        dilations,
        bias,
    )


# conv
@handle_test(
    fn_tree="functional.ivy.conv",
    dims=st.shared(st.integers(1, 3), key="dims"),
    x_f_d_df_tr=x_and_filters_and_transpose(
        dim=st.shared(st.integers(1, 3), key="dims"),
        general=True,
        bias=True,
    ),
    ground_truth_backend="jax",
)
def test_conv(*, dims, x_f_d_df_tr, test_flags, backend_fw, fn_name, on_device):
    """Exercise ivy.conv for both regular and transposed convolutions,
    comparing each backend against the jax ground truth."""
    (
        dtype,
        x,
        filters,
        stride,
        pad,
        transpose,
        output_shape,
        data_format,
        filter_format,
        fc,
        dilations,
        bias,
    ) = x_f_d_df_tr
    if transpose:
        # Transposed draws carry a single dilations value and no x_dilations.
        x_dilations = None
        tf_dilations = dilations
    else:
        # Non-transposed draws pack (dilations, x_dilations) into one tuple.
        dilations, x_dilations = dilations
        tf_dilations = dilations
    # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
    _assume_tf_dilation_gt_1(backend_fw, on_device, tf_dilations)
    helpers.test_function(
        input_dtypes=dtype,
        test_flags=test_flags,
        backend_to_test=backend_fw,
        fn_name=fn_name,
        on_device=on_device,
        rtol_=1e-2,
        atol_=1e-2,
        x=x,
        filters=filters,
        strides=stride,
        padding=pad,
        transpose=transpose,
        dims=dims,
        output_shape=output_shape,
        data_format=data_format,
        filter_format=filter_format,
        feature_group_count=fc,
        x_dilations=x_dilations,
        dilations=dilations,
        bias=bias,
    )


# LSTM #
# -----#

Expand Down

0 comments on commit 83c471f

Please sign in to comment.