Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature Request] Add affine_transform #460

Closed
james77777778 opened this issue Jul 12, 2023 · 0 comments
Closed

[Feature Request] Add affine_transform #460

james77777778 opened this issue Jul 12, 2023 · 0 comments

Comments

@james77777778
Copy link
Contributor

james77777778 commented Jul 12, 2023

TF:
tf.raw_ops.ImageProjectiveTransformV3
JAX:
jax.scipy.ndimage.map_coordinates
Torch:
torchvision.transforms.functional.affine
torch.nn.functional.affine_grid
torch.nn.functional.grid_sample

Supporting an `affine_transform` op would help implement geometry-related preprocessing layers.

Here is a working script:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from torch.nn import functional as F
from torchvision.transforms.v2 import functional as TF

"""
Init
"""

height = 112
width = 224
np.random.seed(2023)
angle = np.radians(45)
image = np.random.uniform(size=(height, width, 3))
x_offset = (
    (width - 1) - (np.cos(angle) * (width - 1) - np.sin(angle) * (height - 1))
) / 2.0
y_offset = (
    (height - 1) - (np.sin(angle) * (width - 1) + np.cos(angle) * (height - 1))
) / 2.0
matrix = np.array(
    [
        [np.cos(angle), -np.sin(angle), x_offset],
        [np.sin(angle), np.cos(angle), y_offset],
        [0, 0, 1],
    ]
)


"""
TF
"""
# [a0, a1, a2, b0, b1, b2, c0, c1]
tf_transform = tf.convert_to_tensor(matrix.flatten()[:-1], dtype=tf.float32)
tf_transform = tf.expand_dims(tf_transform, axis=0)
tf_image = tf.expand_dims(image, axis=0)
tf_output_shape = tf.convert_to_tensor(tf.shape(tf_image)[1:-1], tf.int32)
tf_output = tf.raw_ops.ImageProjectiveTransformV3(
    images=tf_image,
    transforms=tf_transform,
    output_shape=tf_output_shape,
    fill_value=0,
    interpolation="bilinear".upper(),
    fill_mode="constant".upper(),
)
tf_output = tf_output[0].numpy()


"""
JAX
"""
meshgrid = jnp.meshgrid(
    *[jnp.arange(size) for size in image.shape], indexing="ij"
)
indices = jnp.concatenate(
    [jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1
)
jax_matrix = jnp.array(
    [
        [matrix[1, 1], matrix[0, 1], 0],
        [matrix[1, 0], matrix[0, 0], 0],
        [0, 0, 1],
    ],
    dtype=jnp.float32,
)
jax_offset = jnp.array([matrix[1, 2], matrix[0, 2], 0], dtype=jnp.float32)
coordinates = indices @ jax_matrix
coordinates = jnp.moveaxis(coordinates, source=-1, destination=0)
coordinates += jnp.reshape(a=jax_offset, newshape=(*jax_offset.shape, 1, 1, 1))
jax_output = jax.scipy.ndimage.map_coordinates(
    image, coordinates, order=1, mode="constant", cval=0.0
)

"""
Torch
"""
# matrix to theta
h, w = height, width
torch_theta = torch.zeros((2, 3))
torch_theta[0, 0] = matrix[0, 0]
torch_theta[0, 1] = matrix[0, 1] * h / w
torch_theta[0, 2] = (
    matrix[0, 2] * 2 / w + torch_theta[0, 0] + torch_theta[0, 1] - 1
)
torch_theta[1, 0] = matrix[1, 0] * w / h
torch_theta[1, 1] = matrix[1, 1]
torch_theta[1, 2] = (
    matrix[1, 2] * 2 / h + torch_theta[1, 0] + torch_theta[1, 1] - 1
)
torch_theta = torch_theta.reshape((1, 2, 3))
torch_image = torch.from_numpy(image).to(torch.float32)
torch_image = torch_image.permute((2, 0, 1))
torch_image = torch_image.unsqueeze(dim=0)
grid = F.affine_grid(torch_theta, torch_image.shape)
torch_output = TF._geometry._apply_grid_transform(
    torch_image, grid, "bilinear", fill=0.0
)
torch_output = torch_output[0].permute((1, 2, 0)).numpy()


"""
Plots
"""
fig, ax_dict = plt.subplot_mosaic([["A", "B", "C"]])
ax_dict["A"].set_title("TF")
ax_dict["A"].imshow(tf_output)
ax_dict["B"].set_title("JAX")
ax_dict["B"].imshow(jax_output)
ax_dict["C"].set_title("Torch")
ax_dict["C"].imshow(torch_output)
fig.tight_layout(h_pad=0.1, w_pad=0.1)
plt.savefig("affine.png")

"""
Assertion
"""
np.testing.assert_allclose(tf_output, jax_output, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(tf_output, torch_output)  # failed

affine

Even though the Torch result differs from the TensorFlow and JAX results, it would be nice to have this function in the backend.
Are there any plans for this?

BTW, I'm willing to contribute!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant