Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add affine_transform op to all backends #477

Merged
merged 18 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 128 additions & 6 deletions keras_core/backend/jax/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import functools

import jax
import jax.numpy as jnp

from keras_core.backend.jax.core import convert_to_tensor

RESIZE_METHODS = (
RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
"lanczos3",
Expand All @@ -10,12 +15,16 @@


def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
Expand All @@ -39,4 +48,117 @@ def resize(
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
return jax.image.resize(image, size, method=method, antialias=antialias)
return jax.image.resize(
image, size, method=interpolation, antialias=antialias
)


AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
"nearest": 0,
"bilinear": 1,
}
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "grid-constant",
"nearest": "nearest",
"wrap": "grid-wrap",
"mirror": "mirror",
"reflect": "reflect",
}


def affine_transform(
image,
transform,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function signature looks good!

):
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
)

transform = convert_to_tensor(transform)

if len(image.shape) not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
if len(transform.shape) not in (1, 2):
raise ValueError(
"Invalid transform rank: expected rank 1 (single transform) "
"or rank 2 (batch of transforms). Received input with shape: "
f"transform.shape={transform.shape}"
)

# unbatched case
need_squeeze = False
if len(image.shape) == 3:
image = jnp.expand_dims(image, axis=0)
need_squeeze = True
if len(transform.shape) == 1:
transform = jnp.expand_dims(transform, axis=0)

if data_format == "channels_first":
image = jnp.transpose(image, (0, 2, 3, 1))

batch_size = image.shape[0]

# get indices
meshgrid = jnp.meshgrid(
*[jnp.arange(size) for size in image.shape[1:]], indexing="ij"
)
indices = jnp.concatenate(
[jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1
)
indices = jnp.tile(indices, (batch_size, 1, 1, 1, 1))

# swap the values
a0 = transform[:, 0]
a2 = transform[:, 2]
b1 = transform[:, 4]
b2 = transform[:, 5]
transform = transform.at[:, 0].set(b1)
transform = transform.at[:, 2].set(b2)
transform = transform.at[:, 4].set(a0)
transform = transform.at[:, 5].set(a2)

# deal with transform
transform = jnp.pad(
transform, pad_width=[[0, 0], [0, 1]], constant_values=1
)
transform = jnp.reshape(transform, (batch_size, 3, 3))
offset = transform[:, 0:2, 2]
offset = jnp.pad(offset, pad_width=[[0, 0], [0, 1]])
transform = transform.at[:, 0:2, 2].set(0)

# transform the indices
coordinates = jnp.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = jnp.moveaxis(coordinates, source=-1, destination=1)
coordinates += jnp.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1))

# apply affine transformation
_map_coordinates = functools.partial(
jax.scipy.ndimage.map_coordinates,
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
mode=fill_mode,
cval=fill_value,
)
affined = jax.vmap(_map_coordinates)(image, coordinates)

if data_format == "channels_first":
affined = jnp.transpose(affined, (0, 3, 1, 2))
if need_squeeze:
affined = jnp.squeeze(affined, axis=0)
return affined
135 changes: 129 additions & 6 deletions keras_core/backend/numpy/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import jax
import numpy as np
import scipy.ndimage

RESIZE_METHODS = (
from keras_core.backend.numpy.core import convert_to_tensor

RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
"lanczos3",
Expand All @@ -11,12 +14,16 @@


def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
Expand All @@ -41,5 +48,121 @@ def resize(
f"image.shape={image.shape}"
)
return np.array(
jax.image.resize(image, size, method=method, antialias=antialias)
jax.image.resize(image, size, method=interpolation, antialias=antialias)
)


AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
"nearest": 0,
"bilinear": 1,
}
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "grid-constant",
"nearest": "nearest",
"wrap": "grid-wrap",
"mirror": "mirror",
"reflect": "reflect",
}


def affine_transform(
image,
transform,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
):
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
)

transform = convert_to_tensor(transform)

if len(image.shape) not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
if len(transform.shape) not in (1, 2):
raise ValueError(
"Invalid transform rank: expected rank 1 (single transform) "
"or rank 2 (batch of transforms). Received input with shape: "
f"transform.shape={transform.shape}"
)

# unbatched case
need_squeeze = False
if len(image.shape) == 3:
image = np.expand_dims(image, axis=0)
need_squeeze = True
if len(transform.shape) == 1:
transform = np.expand_dims(transform, axis=0)

if data_format == "channels_first":
image = np.transpose(image, (0, 2, 3, 1))

batch_size = image.shape[0]

# get indices
meshgrid = np.meshgrid(
*[np.arange(size) for size in image.shape[1:]], indexing="ij"
)
indices = np.concatenate(
[np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1
)
indices = np.tile(indices, (batch_size, 1, 1, 1, 1))

# swap the values
a0 = transform[:, 0].copy()
a2 = transform[:, 2].copy()
b1 = transform[:, 4].copy()
b2 = transform[:, 5].copy()
transform[:, 0] = b1
transform[:, 2] = b2
transform[:, 4] = a0
transform[:, 5] = a2

# deal with transform
transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1)
transform = np.reshape(transform, (batch_size, 3, 3))
offset = transform[:, 0:2, 2].copy()
offset = np.pad(offset, pad_width=[[0, 0], [0, 1]])
transform[:, 0:2, 2] = 0

# transform the indices
coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = np.moveaxis(coordinates, source=-1, destination=1)
coordinates += np.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1))

# apply affine transformation
affined = np.stack(
[
scipy.ndimage.map_coordinates(
image[i],
coordinates[i],
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
cval=fill_value,
prefilter=False,
)
for i in range(batch_size)
],
axis=0,
)

if data_format == "channels_first":
affined = np.transpose(affined, (0, 3, 1, 2))
if need_squeeze:
affined = np.squeeze(affined, axis=0)
return affined
Loading