Removed x_dilations assertion from ivy.conv as it's unused, added missing tests for ivy.conv, updated the x_and_filters helper to not generate x_dilations for transposed convolutions #22000

Merged: 1 commit, Aug 16, 2023
21 changes: 11 additions & 10 deletions ivy/functional/ivy/layers.py
@@ -109,7 +109,7 @@ def linear(
ret
Result array of the linear transformation.
*[outer_batch_shape,inner_batch_shape,out_features]*

Both the description and the type hints above assume an array input for simplicity,
but this function is *nestable*, and therefore also accepts :class:`ivy.Container`
instances in place of any of the arguments.
@@ -123,7 +123,7 @@ def linear(
>>> y = ivy.linear(x, w)
>>> print(y)
ivy.array([1.])

>>> x = ivy.array([[0.666, -0.4269, 1.911]])
>>> w = ivy.array([[1., 0., 0.], [0., 0., 1.]])
>>> y = ivy.zeros((1, 2))
@@ -143,7 +143,7 @@ def linear(
ivy.array([[ 34.98495483, 101.0293808 ],
[ 28.0159359 , 83.74752808],
[ 37.20942307, 108.3205719 ]])

With :class:`ivy.Container` input:

>>> x = ivy.Container(a=ivy.array([[1., 2., 3.],
@@ -181,7 +181,7 @@ def linear(
b: ivy.array([[15.1, 32., 47.9],
[85., 196., 306.]])
}

"""
outer_batch_shape = list(weight.shape[:-2])
num_outer_batch_dims = len(outer_batch_shape)
@@ -590,8 +590,10 @@ def scaled_dot_product_attention(
... b=ivy.array([[[3.2, 1.], [2.2, 3.6], [4.0, 5.6]]]))
>>> v = ivy.Container(a=ivy.array([[[5.2, 1.], [2.1, 3.], [4.4, 5.6]]]),
... b=ivy.array([[[0.2, 1.], [2.2, 3.], [4.4, 5.6]]]))
->>> mask = ivy.Container(a=ivy.array([[[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]]),
-...                      b=ivy.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0,1.0]]]))
+>>> mask = ivy.Container(
+...     a=ivy.array([[[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]]),
+...     b=ivy.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0,1.0]]])
+... )
>>> result = ivy.scaled_dot_product_attention(q,k,v,scale=1,mask=mask)
>>> print(result)
{
@@ -1602,10 +1604,10 @@ def conv3d(
while "NCDHW" corresponds to input with shape (batch_size, channels, depth,
height, width).
filter_format
Either "channel_first" or "channel_last". "channel_first" corresponds
Either "channel_first" or "channel_last". "channel_first" corresponds
to "OIDHW",input data formats, while "channel_last" corresponds to "DHWIO".
x_dilations
The dilation factor for each dimension of input. (Default value = 1)
dilations
The dilation factor for each dimension of input. (Default value = 1)
bias
@@ -1983,8 +1985,8 @@ def conv_general_transpose(
@handle_exceptions
@handle_array_like_without_promotion
@handle_out_argument
-@handle_array_function
 @inputs_to_native_shapes
+@handle_array_function
def conv(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -2053,7 +2055,6 @@ def conv(
The result of the transpose or dilated convolution operation.
"""
     if transpose:
-        assert x_dilations == 1, "x_dilations must be 1 for transpose convolutions."
         return conv_general_transpose(
             x,
             filters,
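
For context, here is a minimal sketch of conv's dispatch after this change. It is not the verbatim Ivy source; the signature is inferred from the parameters the new tests exercise. Transposed calls go straight to conv_general_transpose, which takes no x_dilations at all, so the removed assertion had nothing to guard.

# Sketch only: signature inferred from this PR's tests, not copied
# from ivy/functional/ivy/layers.py.
import ivy

def conv_sketch(
    x,
    filters,
    strides,
    padding,
    *,
    transpose=False,
    dims=2,
    output_shape=None,
    data_format="channel_last",
    filter_format="channel_last",
    feature_group_count=1,
    x_dilations=1,
    dilations=1,
    bias=None,
):
    if transpose:
        # x_dilations is simply never forwarded here, so no assertion is needed.
        return ivy.conv_general_transpose(
            x, filters, strides, padding,
            dims=dims, output_shape=output_shape, data_format=data_format,
            dilations=dilations, feature_group_count=feature_group_count,
            bias=bias,
        )
    return ivy.conv_general_dilated(
        x, filters, strides, padding,
        dims=dims, data_format=data_format, filter_format=filter_format,
        feature_group_count=feature_group_count,
        x_dilations=x_dilations, dilations=dilations, bias=bias,
    )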
228 changes: 203 additions & 25 deletions ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
@@ -644,14 +644,14 @@ def x_and_filters(
)
if general:
data_format = "channel_first" if channel_first else "channel_last"

-        x_dilation = draw(
-            st.one_of(
-                st.integers(1, 3),
-                st.lists(st.integers(1, 3), min_size=dim, max_size=dim),
-            )
-        )
-        dilations = (dilations, x_dilation)
+        if not transpose:
+            x_dilation = draw(
+                st.one_of(
+                    st.integers(1, 3),
+                    st.lists(st.integers(1, 3), min_size=dim, max_size=dim),
+                )
+            )
+            dilations = (dilations, x_dilation)
if filter_format is not None:
filter_format = draw(filter_format)
if filter_format == "channel_first":
@@ -694,9 +694,18 @@ def _assume_tf_dilation_gt_1(backend_fw, on_device, dilations):
ground_truth_backend="jax",
)
def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        fc,
+        ff_format,
+        bias,
+    ) = x_f_d_df
# ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
@@ -730,9 +739,18 @@ def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        output_shape,
+        fc,
+        bias,
+    ) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
input_dtypes=dtype,
@@ -765,9 +783,18 @@ def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        fc,
+        ff_format,
+        bias,
+    ) = x_f_d_df
# ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
@@ -802,9 +829,18 @@ def test_conv2d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv2d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        output_shape,
+        fc,
+        bias,
+    ) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])

helpers.test_function(
@@ -870,9 +906,18 @@ def test_depthwise_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        fc,
+        ff_format,
+        bias,
+    ) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
input_dtypes=dtype,
@@ -905,9 +950,18 @@ def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
ground_truth_backend="jax",
)
def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        output_shape,
+        fc,
+        bias,
+    ) = x_f_d_df
_assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
helpers.test_function(
input_dtypes=dtype,
@@ -1026,6 +1080,130 @@ def test_conv_general_transpose(
)


# filter_format not in conv_general_transpose
# output_shape not in conv_general_dilated
@st.composite
def x_and_filters_and_transpose(
draw,
dim: int = 2,
general=False,
bias=False,
filter_format=None,
):
transpose = draw(st.booleans())
if not transpose:
filter_format = st.sampled_from(["channel_last", "channel_first"])
all_args = draw(
x_and_filters(
dim=dim,
general=general,
bias=bias,
filter_format=filter_format,
transpose=transpose,
)
)
output_shape = None
filter_format = "channel_last"
if transpose:
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
output_shape,
fc,
bias,
) = all_args
else:
(
dtype,
x,
filters,
dilations,
data_format,
stride,
pad,
fc,
filter_format,
bias,
) = all_args
return (
dtype,
x,
filters,
stride,
pad,
transpose,
output_shape,
data_format,
filter_format,
fc,
dilations,
bias,
)
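
The composite above only yields an x_dilation component when transpose is False, matching the x_and_filters change earlier in this diff. A standalone sketch of that conditional-draw pattern follows; dilation_args and test_dilation_shape are hypothetical names for illustration, not Ivy helpers.

# Standalone sketch of the conditional-draw pattern used above.
import hypothesis.strategies as st
from hypothesis import given

@st.composite
def dilation_args(draw, dim: int = 2):
    transpose = draw(st.booleans())
    dilations = draw(st.integers(1, 3))
    if not transpose:
        # Only non-transposed cases draw an input dilation, mirroring
        # the x_and_filters change in this PR.
        x_dilation = draw(
            st.one_of(
                st.integers(1, 3),
                st.lists(st.integers(1, 3), min_size=dim, max_size=dim),
            )
        )
        dilations = (dilations, x_dilation)
    return transpose, dilations

@given(args=dilation_args())
def test_dilation_shape(args):
    transpose, dilations = args
    # Transposed draws never carry an x_dilation component.
    assert transpose == (not isinstance(dilations, tuple))

test_dilation_shape()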


# conv
@handle_test(
fn_tree="functional.ivy.conv",
dims=st.shared(st.integers(1, 3), key="dims"),
x_f_d_df_tr=x_and_filters_and_transpose(
dim=st.shared(st.integers(1, 3), key="dims"),
general=True,
bias=True,
),
ground_truth_backend="jax",
)
def test_conv(*, dims, x_f_d_df_tr, test_flags, backend_fw, fn_name, on_device):
(
dtype,
x,
filters,
stride,
pad,
transpose,
output_shape,
data_format,
filter_format,
fc,
dilations,
bias,
) = x_f_d_df_tr
tf_dilations = dilations
if not transpose:
tf_dilations = tf_dilations[0]
dilations, x_dilations = dilations
else:
x_dilations = None
_assume_tf_dilation_gt_1(backend_fw, on_device, tf_dilations)
helpers.test_function(
input_dtypes=dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
fn_name=fn_name,
on_device=on_device,
rtol_=1e-2,
atol_=1e-2,
x=x,
filters=filters,
strides=stride,
padding=pad,
transpose=transpose,
dims=dims,
output_shape=output_shape,
data_format=data_format,
filter_format=filter_format,
feature_group_count=fc,
x_dilations=x_dilations,
dilations=dilations,
bias=bias,
)
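
For a sense of what this test exercises, here is a hedged usage sketch of ivy.conv in both modes. A NumPy backend and channel-last defaults are assumed, per the keyword arguments above; the shapes are illustrative, not taken from the test.

# Usage sketch: assumes an installed NumPy backend and channel-last defaults.
import ivy

ivy.set_backend("numpy")
x = ivy.random_uniform(shape=(1, 8, 8, 4))  # NHWC input, 4 channels
w = ivy.random_uniform(shape=(3, 3, 4, 4))  # 3x3 filters, 4 in / 4 out

# Standard path: x_dilations is a valid argument here.
y = ivy.conv(x, w, 1, "SAME", dims=2, x_dilations=1, dilations=1)
print(y.shape)  # (1, 8, 8, 4) with stride 1 and "SAME" padding

# Transposed path: after this PR, ivy.conv no longer asserts on x_dilations;
# the argument is simply not forwarded to conv_general_transpose.
yt = ivy.conv(x, w, 1, "SAME", transpose=True, dims=2, dilations=1)
print(yt.shape)  # (1, 8, 8, 4) with stride 1 and "SAME" padding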


# LSTM #
# -----#
