Skip to content

Commit

Permalink
added affine_grid to the paddle frontend #18942 (#19371)
Browse files Browse the repository at this point in the history
Co-authored-by: hirwa-nshuti <[email protected]>
  • Loading branch information
so-dipe and fnhirwa authored Jul 21, 2023
1 parent 2f90ce7 commit 61f9234
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 0 deletions.
77 changes: 77 additions & 0 deletions ivy/functional/frontends/paddle/nn/functional/vision.py
Original file line number Diff line number Diff line change
@@ -1 +1,78 @@
# local

import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.functional.frontends.paddle.func_wrapper import (
to_ivy_arrays_and_back,
)


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
def affine_grid(theta, out_shape, align_corners=True):
if len(out_shape) == 4:
N, C, H, W = out_shape
base_grid = ivy.empty((N, H, W, 3))
if align_corners:
base_grid[:, :, :, 0] = ivy.linspace(-1, 1, W)
base_grid[:, :, :, 1] = ivy.expand_dims(ivy.linspace(-1, 1, H), axis=-1)
height_values = ivy.expand_dims(ivy.linspace(-1, 1, H), axis=-1)
base_grid[:, :, :, 1] = ivy.array(
[[[height_values[i]] * W for i in range(H)]]
)[:, :, :, 0]
base_grid[:, :, :, 2] = ivy.full((H, W), 1)
grid = ivy.matmul(base_grid.view((N, H * W, 3)), theta.swapaxes(1, 2))
return grid.view((N, H, W, 2))
else:
base_grid[:, :, :, 0] = ivy.linspace(-1, 1, W) * (W - 1) / W
base_grid[:, :, :, 1] = ivy.expand_dims(
ivy.linspace(-1, 1, H) * (H - 1) / H, axis=-1
)
height_values = ivy.expand_dims(
ivy.linspace(-1, 1, H) * (H - 1) / H, axis=-1
)
base_grid[:, :, :, 1] = ivy.array(
[[[height_values[i]] * W for i in range(H)]]
)[:, :, :, 0]
base_grid[:, :, :, 2] = ivy.full((H, W), 1)
grid = ivy.matmul(base_grid.view((N, H * W, 3)), ivy.swapaxes(theta, 1, 2))
return grid.view((N, H, W, 2))
else:
N, C, D, H, W = out_shape
base_grid = ivy.empty((N, D, H, W, 4))
if align_corners:
base_grid[:, :, :, :, 0] = ivy.linspace(-1, 1, W)
base_grid[:, :, :, :, 1] = ivy.expand_dims(ivy.linspace(-1, 1, H), axis=-1)
height_values = ivy.linspace(-1, 1, H)
base_grid[:, :, :, :, 1] = ivy.array(
[[[[height_values[i]] * W for i in range(H)]] * D]
)
base_grid[:, :, :, :, 2] = ivy.expand_dims(
ivy.expand_dims(ivy.linspace(-1, 1, D), axis=-1), axis=-1
)
width_values = ivy.linspace(-1, 1, D)
base_grid[:, :, :, :, 2] = ivy.array(
[[ivy.array([[width_values[i]] * W] * H) for i in range(D)]]
)
base_grid[:, :, :, :, 3] = ivy.full((D, H, W), 1)
grid = ivy.matmul(base_grid.view((N, D * H * W, 4)), theta.swapaxes(1, 2))
return grid.view((N, D, H, W, 3))
else:
base_grid[:, :, :, :, 0] = ivy.linspace(-1, 1, W) * (W - 1) / W
base_grid[:, :, :, :, 1] = ivy.expand_dims(
ivy.linspace(-1, 1, H) * (H - 1) / H, axis=-1
)
height_values = ivy.linspace(-1, 1, H) * (H - 1) / H
base_grid[:, :, :, :, 1] = ivy.array(
[[[[height_values[i]] * W for i in range(H)]] * D]
)
base_grid[:, :, :, :, 2] = ivy.expand_dims(
ivy.expand_dims(ivy.linspace(-1, 1, D) * (D - 1) / D, axis=-1), axis=-1
)
width_values = ivy.linspace(-1, 1, D) * (D - 1) / D
base_grid[:, :, :, :, 2] = ivy.array(
[[ivy.array([[width_values[i]] * W] * H) for i in range(D)]]
)
base_grid[:, :, :, :, 3] = ivy.full((D, H, W), 1)
grid = ivy.matmul(base_grid.view((N, D * H * W, 4)), theta.swapaxes(1, 2))
return grid.view((N, D, H, W, 3))
Original file line number Diff line number Diff line change
@@ -1,3 +1,70 @@
# global
from hypothesis import strategies as st

# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test


@st.composite
def _affine_grid_helper(draw):
align_corners = draw(st.booleans())
dims = draw(st.integers(4, 5))
if dims == 4:
size = draw(
st.tuples(
st.integers(1, 20),
st.integers(1, 20),
st.integers(2, 20),
st.integers(2, 20),
)
)
theta_dtype, theta = draw(
helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=0,
max_value=1,
shape=(size[0], 2, 3),
)
)
return theta_dtype, theta[0], size, align_corners
else:
size = draw(
st.tuples(
st.integers(1, 20),
st.integers(1, 20),
st.integers(2, 20),
st.integers(2, 20),
st.integers(2, 20),
)
)
theta_dtype, theta = draw(
helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=0,
max_value=1,
shape=(size[0], 3, 4),
)
)
return theta_dtype, theta[0], size, align_corners


@handle_frontend_test(
fn_tree="paddle.nn.functional.affine_grid",
dtype_and_input_and_other=_affine_grid_helper(),
)
def test_paddle_affine_grid(
*, dtype_and_input_and_other, on_device, fn_tree, frontend, test_flags
):
dtype, theta, size, align_corners = dtype_and_input_and_other

helpers.test_frontend_function(
input_dtypes=dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
theta=theta,
out_shape=size,
align_corners=align_corners,
)

0 comments on commit 61f9234

Please sign in to comment.