We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
affine_transform
TF: `tf.raw_ops.ImageProjectiveTransformV3`; JAX: `jax.scipy.ndimage.map_coordinates`; Torch: `torchvision.transforms.functional.affine`, `torch.nn.functional.affine_grid`, `torch.nn.functional.grid_sample`
Supporting affine_transform might help implement geometry-related preprocessing layers.
Here is a working script:
"""Compare affine image transforms across TensorFlow, JAX and PyTorch.

Rotates a random image by 45 degrees about its centre with each framework and
checks that all three agree (bilinear interpolation, constant zero fill).

Framework imports live inside the per-backend functions so the helpers can be
imported and exercised without every backend being installed.
"""

import numpy as np


def make_rotation_matrix(angle, height, width):
    """Return a 3x3 matrix rotating an image by ``angle`` about its centre.

    The matrix maps an *output* pixel ``(x, y, 1)`` (x = column, y = row) to
    the *input* pixel it samples from. This is the convention expected by
    ``tf.raw_ops.ImageProjectiveTransformV3``.

    Args:
        angle: Rotation angle in radians.
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        A ``(3, 3)`` float64 array whose last row is ``[0, 0, 1]``.
    """
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # Offsets chosen so the centre pixel ((width-1)/2, (height-1)/2) maps to itself.
    x_offset = ((width - 1) - (cos_a * (width - 1) - sin_a * (height - 1))) / 2.0
    y_offset = ((height - 1) - (sin_a * (width - 1) + cos_a * (height - 1))) / 2.0
    return np.array(
        [
            [cos_a, -sin_a, x_offset],
            [sin_a, cos_a, y_offset],
            [0.0, 0.0, 1.0],
        ]
    )


def matrix_to_torch_theta(matrix, height, width):
    """Convert a pixel-space affine matrix into an ``affine_grid`` theta.

    ``affine_grid``/``grid_sample`` work in normalised coordinates. With
    ``align_corners=True`` a pixel column ``x`` maps to ``2 * x / (width - 1) - 1``
    (and likewise for rows), so pixel centres sit on integer positions exactly
    as in TensorFlow and ``jax.scipy.ndimage.map_coordinates``. The returned
    theta is the pixel matrix conjugated by that normalisation. It is only
    valid together with ``align_corners=True``.

    Args:
        matrix: ``(3, 3)`` affine matrix mapping output pixels to input pixels
            (last row assumed to be ``[0, 0, 1]``).
        height: Image height in pixels; must be at least 2.
        width: Image width in pixels; must be at least 2.

    Returns:
        A ``(2, 3)`` float64 array.

    Raises:
        ValueError: If ``height`` or ``width`` is smaller than 2 (the
            normalisation divides by ``size - 1``).
    """
    if height < 2 or width < 2:
        raise ValueError(f"height and width must be >= 2, got {height}x{width}")
    half_w = (width - 1) / 2.0
    half_h = (height - 1) / 2.0
    # Normalised [-1, 1] coordinates -> pixel coordinates.
    from_norm = np.array(
        [[half_w, 0.0, half_w], [0.0, half_h, half_h], [0.0, 0.0, 1.0]]
    )
    # Pixel coordinates -> normalised [-1, 1] coordinates (inverse of the above).
    to_norm = np.array(
        [[1.0 / half_w, 0.0, -1.0], [0.0, 1.0 / half_h, -1.0], [0.0, 0.0, 1.0]]
    )
    theta = to_norm @ np.asarray(matrix, dtype=np.float64) @ from_norm
    return theta[:2]


def tf_affine_transform(image, matrix):
    """Warp an ``(H, W, C)`` image with TensorFlow's ImageProjectiveTransformV3.

    Args:
        image: ``(H, W, C)`` array.
        matrix: ``(3, 3)`` output->input pixel matrix with ``matrix[2, 2] == 1``
            (the op takes the first 8 entries and assumes the 9th is 1).

    Returns:
        The warped image as a NumPy array (bilinear, zero fill).
    """
    import tensorflow as tf

    # [a0, a1, a2, b0, b1, b2, c0, c1]
    transform = tf.convert_to_tensor(
        np.asarray(matrix).flatten()[:-1], dtype=tf.float32
    )
    transform = tf.expand_dims(transform, axis=0)
    batch = tf.expand_dims(image, axis=0)
    output_shape = tf.convert_to_tensor(tf.shape(batch)[1:-1], tf.int32)
    output = tf.raw_ops.ImageProjectiveTransformV3(
        images=batch,
        transforms=transform,
        output_shape=output_shape,
        fill_value=0.0,
        interpolation="BILINEAR",
        fill_mode="CONSTANT",
    )
    return output[0].numpy()


def jax_affine_transform(image, matrix):
    """Warp an ``(H, W, C)`` image with ``jax.scipy.ndimage.map_coordinates``.

    Args:
        image: ``(H, W, C)`` array.
        matrix: ``(3, 3)`` output->input pixel matrix (last row ``[0, 0, 1]``).

    Returns:
        The warped image as a NumPy array (bilinear, zero fill).
    """
    import jax.numpy as jnp
    from jax.scipy.ndimage import map_coordinates

    # Every (row, col, channel) index of the output image.
    grids = jnp.meshgrid(*[jnp.arange(size) for size in image.shape], indexing="ij")
    indices = jnp.stack(grids, axis=-1)
    # map_coordinates indexes (row, col, channel), so the (x, y) = (col, row)
    # matrix is permuted into (row, col) order; the channel axis is identity.
    jax_matrix = jnp.array(
        [
            [matrix[1, 1], matrix[0, 1], 0],
            [matrix[1, 0], matrix[0, 0], 0],
            [0, 0, 1],
        ],
        dtype=jnp.float32,
    )
    jax_offset = jnp.array([matrix[1, 2], matrix[0, 2], 0], dtype=jnp.float32)
    coordinates = jnp.moveaxis(indices @ jax_matrix, source=-1, destination=0)
    # Broadcast the per-axis offset over (H, W, C); avoids the keyword form of
    # jnp.reshape, whose `a`/`newshape` keywords newer JAX no longer accepts.
    coordinates = coordinates + jax_offset[:, None, None, None]
    output = map_coordinates(image, coordinates, order=1, mode="constant", cval=0.0)
    return np.asarray(output)


def torch_affine_transform(image, matrix):
    """Warp an ``(H, W, C)`` image with ``affine_grid`` + ``grid_sample``.

    Uses ``align_corners=True`` with a theta from :func:`matrix_to_torch_theta`,
    so sampling positions match TensorFlow/JAX pixel-for-pixel. Computation is
    done in float64 so Torch is not the precision bottleneck when comparing.

    Args:
        image: ``(H, W, C)`` array with ``H, W >= 2``.
        matrix: ``(3, 3)`` output->input pixel matrix (last row ``[0, 0, 1]``).

    Returns:
        The warped image as a float64 NumPy array (bilinear, zero fill).
    """
    import torch
    from torch.nn import functional as F

    height, width = image.shape[:2]
    theta = torch.as_tensor(
        matrix_to_torch_theta(matrix, height, width), dtype=torch.float64
    ).unsqueeze(0)
    batch = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float64))
    batch = batch.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
    grid = F.affine_grid(theta, batch.shape, align_corners=True)
    # Zero padding matches TF's CONSTANT fill with fill_value=0 for bilinear sampling.
    output = F.grid_sample(
        batch, grid, mode="bilinear", padding_mode="zeros", align_corners=True
    )
    return output[0].permute(1, 2, 0).numpy()


def main():
    """Rotate a random image with all three backends, plot and compare them."""
    import matplotlib.pyplot as plt

    height, width = 112, 224
    np.random.seed(2023)
    image = np.random.uniform(size=(height, width, 3))
    matrix = make_rotation_matrix(np.radians(45), height, width)

    tf_output = tf_affine_transform(image, matrix)
    jax_output = jax_affine_transform(image, matrix)
    torch_output = torch_affine_transform(image, matrix)

    fig, ax_dict = plt.subplot_mosaic([["A", "B", "C"]])
    panels = (("A", "TF", tf_output), ("B", "JAX", jax_output), ("C", "Torch", torch_output))
    for key, title, output in panels:
        ax_dict[key].set_title(title)
        ax_dict[key].imshow(output)
    fig.tight_layout(h_pad=0.1, w_pad=0.1)
    plt.savefig("affine.png")

    np.testing.assert_allclose(tf_output, jax_output, rtol=1e-5, atol=1e-5)
    # The TF transform parameters are float32, so allow a slightly looser bound
    # against the float64 Torch path; a half-pixel misalignment would still
    # produce errors orders of magnitude larger than this.
    np.testing.assert_allclose(tf_output, torch_output, rtol=1e-5, atol=1e-4)


if __name__ == "__main__":
    main()
Even though the Torch result differs from the TensorFlow and JAX results, it would be nice to have this function in the backend. Is there any plan for this?
backend
BTW, I'm willing to contribute!
The text was updated successfully, but these errors were encountered:
No branches or pull requests
TF:
tf.raw_ops.ImageProjectiveTransformV3
JAX:
jax.scipy.ndimage.map_coordinates
Torch:
torchvision.transforms.functional.affine
torch.nn.functional.affine_grid
torch.nn.functional.grid_sample
Supporting affine_transform might help implement geometry-related preprocessing layers.
Here is a working script:
Even though the Torch result differs from the TensorFlow and JAX results, it would be nice to have this function in the backend. Is there any plan for this?
BTW, I'm willing to contribute!
The text was updated successfully, but these errors were encountered: