Skip to content

Commit

Permalink
Add affine_transform op to all backends (keras-team#477)
Browse files Browse the repository at this point in the history
* Add affine op

* Sync import convention

* Use `np.random.random`

* Refactor jax implementation

* Fix

* Address fchollet's comments

* Update docstring

* Fix test

* Replace method with interpolation

* Replace method with interpolation

* Replace method with interpolation

* Update test
  • Loading branch information
james77777778 authored and adi-kmt committed Jul 21, 2023
1 parent 0ba17cf commit 9c75bd0
Show file tree
Hide file tree
Showing 9 changed files with 776 additions and 48 deletions.
134 changes: 128 additions & 6 deletions keras_core/backend/jax/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import functools

import jax
import jax.numpy as jnp

from keras_core.backend.jax.core import convert_to_tensor

RESIZE_METHODS = (
RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
"lanczos3",
Expand All @@ -10,12 +15,16 @@


def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
Expand All @@ -39,4 +48,117 @@ def resize(
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
return jax.image.resize(image, size, method=method, antialias=antialias)
return jax.image.resize(
image, size, method=interpolation, antialias=antialias
)


# Interpolation name -> value passed as `order` to `map_coordinates`.
AFFINE_TRANSFORM_INTERPOLATIONS = dict(nearest=0, bilinear=1)
# Accepted `fill_mode` names (keys), each paired with a boundary-mode string.
AFFINE_TRANSFORM_FILL_MODES = dict(
    constant="grid-constant",
    nearest="nearest",
    wrap="grid-wrap",
    mirror="mirror",
    reflect="reflect",
)


def affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
):
    """Resample `image` at coordinates produced by an 8-value transform.

    Each transform row `[a0, a1, a2, b0, b1, b2, c0, c1]` is rearranged into
    a 3x3 matrix plus an offset. For an output element at index `(i, j, k)`
    of the channels-last image, the sampled input location is::

        axis 0: b1 * i + b0 * j + c0 * k + b2
        axis 1: a1 * i + a0 * j + c1 * k + a2
        axis 2: k

    NOTE(review): `c0` and `c1` scale the last-axis index instead of acting
    as projective terms; presumably callers pass 0 for both -- confirm.

    Args:
        image: Array of rank 3 (single image) or rank 4 (batch of images).
        transform: Rank 1 (single transform) or rank 2 (batch of
            transforms) input with 8 values per row. The number of rows
            must equal the image batch size.
        interpolation: One of the keys of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: One of the keys of `AFFINE_TRANSFORM_FILL_MODES`. The
            name itself is forwarded as `mode` to
            `jax.scipy.ndimage.map_coordinates`.
        fill_value: Forwarded as `cval` to `map_coordinates`.
        data_format: `"channels_first"` inputs are transposed to
            channels-last for processing and transposed back on return.

    Returns:
        The resampled array, with the same rank and layout as `image`.

    Raises:
        ValueError: If `interpolation` or `fill_mode` is not recognised, or
            if `image` / `transform` has an unsupported rank.
    """
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected of one "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected of one "
            f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
            f"Received: fill_mode={fill_mode}"
        )

    transform = convert_to_tensor(transform)

    if len(image.shape) not in (3, 4):
        raise ValueError(
            "Invalid image rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )
    if len(transform.shape) not in (1, 2):
        raise ValueError(
            "Invalid transform rank: expected rank 1 (single transform) "
            "or rank 2 (batch of transforms). Received input with shape: "
            f"transform.shape={transform.shape}"
        )

    # Unbatched inputs get a temporary leading batch axis.
    need_squeeze = False
    if len(image.shape) == 3:
        image = jnp.expand_dims(image, axis=0)
        need_squeeze = True
    if len(transform.shape) == 1:
        transform = jnp.expand_dims(transform, axis=0)

    if data_format == "channels_first":
        image = jnp.transpose(image, (0, 2, 3, 1))

    batch_size = image.shape[0]

    # Index grid over the three non-batch axes; last axis holds (i, j, k).
    # It is shared by every batch element, so the einsum below broadcasts it
    # instead of materialising one copy per image.
    meshgrid = jnp.meshgrid(
        *[jnp.arange(size) for size in image.shape[1:]], indexing="ij"
    )
    indices = jnp.stack(meshgrid, axis=-1)

    # Swap a0 <-> b1 and a2 <-> b2 so that, once reshaped to 3x3, the matrix
    # maps (i, j, k) indices in the axis order of the image array.
    a0 = transform[:, 0]
    a2 = transform[:, 2]
    b1 = transform[:, 4]
    b2 = transform[:, 5]
    transform = transform.at[:, 0].set(b1)
    transform = transform.at[:, 2].set(b2)
    transform = transform.at[:, 4].set(a0)
    transform = transform.at[:, 5].set(a2)

    # Append the implicit trailing 1, form the 3x3 matrix, and split the
    # translation column off so it can be added after the matrix product.
    transform = jnp.pad(
        transform, pad_width=[[0, 0], [0, 1]], constant_values=1
    )
    transform = jnp.reshape(transform, (batch_size, 3, 3))
    offset = transform[:, 0:2, 2]
    offset = jnp.pad(offset, pad_width=[[0, 0], [0, 1]])
    transform = transform.at[:, 0:2, 2].set(0)

    # Transform the indices, then move the coordinate axis next to the batch
    # axis, which is the layout `map_coordinates` expects per image.
    coordinates = jnp.einsum("hwij, Bjk -> Bhwik", indices, transform)
    coordinates = jnp.moveaxis(coordinates, source=-1, destination=1)
    # Positional arguments: the `newshape` keyword is deprecated upstream.
    coordinates = coordinates + jnp.reshape(
        offset, (*offset.shape, 1, 1, 1)
    )

    # Resample every image of the batch at its own coordinates.
    _map_coordinates = functools.partial(
        jax.scipy.ndimage.map_coordinates,
        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
        mode=fill_mode,
        cval=fill_value,
    )
    affined = jax.vmap(_map_coordinates)(image, coordinates)

    if data_format == "channels_first":
        affined = jnp.transpose(affined, (0, 3, 1, 2))
    if need_squeeze:
        affined = jnp.squeeze(affined, axis=0)
    return affined
135 changes: 129 additions & 6 deletions keras_core/backend/numpy/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import jax
import numpy as np
import scipy.ndimage

RESIZE_METHODS = (
from keras_core.backend.numpy.core import convert_to_tensor

RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
"lanczos3",
Expand All @@ -11,12 +14,16 @@


def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
Expand All @@ -41,5 +48,121 @@ def resize(
f"image.shape={image.shape}"
)
return np.array(
jax.image.resize(image, size, method=method, antialias=antialias)
jax.image.resize(image, size, method=interpolation, antialias=antialias)
)


# Interpolation name -> value passed as `order` to `map_coordinates`.
AFFINE_TRANSFORM_INTERPOLATIONS = dict(nearest=0, bilinear=1)
# `fill_mode` name -> value passed as `mode` to `map_coordinates`.
AFFINE_TRANSFORM_FILL_MODES = dict(
    constant="grid-constant",
    nearest="nearest",
    wrap="grid-wrap",
    mirror="mirror",
    reflect="reflect",
)


def affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
):
    """Resample `image` at coordinates produced by an 8-value transform.

    Each transform row `[a0, a1, a2, b0, b1, b2, c0, c1]` is rearranged into
    a 3x3 matrix plus an offset. For an output element at index `(i, j, k)`
    of the channels-last image, the sampled input location is::

        axis 0: b1 * i + b0 * j + c0 * k + b2
        axis 1: a1 * i + a0 * j + c1 * k + a2
        axis 2: k

    NOTE(review): `c0` and `c1` scale the last-axis index instead of acting
    as projective terms; presumably callers pass 0 for both -- confirm.

    Args:
        image: Array of rank 3 (single image) or rank 4 (batch of images).
        transform: Rank 1 (single transform) or rank 2 (batch of
            transforms) input with 8 values per row. The number of rows
            must equal the image batch size. It is not modified.
        interpolation: One of the keys of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: One of the keys of `AFFINE_TRANSFORM_FILL_MODES`; the
            mapped value is forwarded as `mode` to
            `scipy.ndimage.map_coordinates`.
        fill_value: Forwarded as `cval` to `map_coordinates`.
        data_format: `"channels_first"` inputs are transposed to
            channels-last for processing and transposed back on return.

    Returns:
        The resampled array, with the same rank and layout as `image`.

    Raises:
        ValueError: If `interpolation` or `fill_mode` is not recognised, or
            if `image` / `transform` has an unsupported rank.
    """
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected of one "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected of one "
            f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
            f"Received: fill_mode={fill_mode}"
        )

    transform = convert_to_tensor(transform)

    if len(image.shape) not in (3, 4):
        raise ValueError(
            "Invalid image rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )
    if len(transform.shape) not in (1, 2):
        raise ValueError(
            "Invalid transform rank: expected rank 1 (single transform) "
            "or rank 2 (batch of transforms). Received input with shape: "
            f"transform.shape={transform.shape}"
        )

    # Unbatched inputs get a temporary leading batch axis.
    need_squeeze = False
    if len(image.shape) == 3:
        image = np.expand_dims(image, axis=0)
        need_squeeze = True
    if len(transform.shape) == 1:
        transform = np.expand_dims(transform, axis=0)

    # The coefficients are rearranged in place below. Work on a private copy
    # so the caller's array is never written to, whatever
    # `convert_to_tensor` / `expand_dims` returned (possibly a view).
    transform = np.array(transform)

    if data_format == "channels_first":
        image = np.transpose(image, (0, 2, 3, 1))

    batch_size = image.shape[0]

    # Index grid over the three non-batch axes; last axis holds (i, j, k).
    meshgrid = np.meshgrid(
        *[np.arange(size) for size in image.shape[1:]], indexing="ij"
    )
    indices = np.stack(meshgrid, axis=-1)
    indices = np.tile(indices, (batch_size, 1, 1, 1, 1))

    # Swap a0 <-> b1 and a2 <-> b2 so that, once reshaped to 3x3, the matrix
    # maps (i, j, k) indices in the axis order of the image array. The
    # right-hand side is a fancy-indexed copy, so the swap cannot alias.
    transform[:, [0, 2, 4, 5]] = transform[:, [4, 5, 0, 2]]

    # Append the implicit trailing 1, form the 3x3 matrix, and split the
    # translation column off so it can be added after the matrix product.
    transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1)
    transform = np.reshape(transform, (batch_size, 3, 3))
    offset = transform[:, 0:2, 2].copy()
    offset = np.pad(offset, pad_width=[[0, 0], [0, 1]])
    transform[:, 0:2, 2] = 0

    # Transform the indices, then move the coordinate axis next to the batch
    # axis, which is the layout `map_coordinates` expects per image.
    coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
    coordinates = np.moveaxis(coordinates, source=-1, destination=1)
    # Positional arguments: newer NumPy makes `a` positional-only and
    # deprecates the `newshape` keyword.
    coordinates += np.reshape(offset, (*offset.shape, 1, 1, 1))

    # Resample every image of the batch at its own coordinates.
    order = AFFINE_TRANSFORM_INTERPOLATIONS[interpolation]
    mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]
    affined = np.stack(
        [
            scipy.ndimage.map_coordinates(
                single_image,
                single_coordinates,
                order=order,
                mode=mode,
                cval=fill_value,
                prefilter=False,
            )
            for single_image, single_coordinates in zip(image, coordinates)
        ],
        axis=0,
    )

    if data_format == "channels_first":
        affined = np.transpose(affined, (0, 3, 1, 2))
    if need_squeeze:
        affined = np.squeeze(affined, axis=0)
    return affined
Loading

0 comments on commit 9c75bd0

Please sign in to comment.