Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implements map_coordinates op in keras-core #784

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions keras_core/backend/jax/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def resize(
):
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
"Invalid value for argument `interpolation`. Expected one of "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
Expand Down Expand Up @@ -76,13 +76,13 @@ def affine_transform(
):
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
"Invalid value for argument `interpolation`. Expected one of "
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
)
Expand Down Expand Up @@ -162,3 +162,30 @@ def affine_transform(
if need_squeeze:
affined = jnp.squeeze(affined, axis=0)
return affined


MAP_COORDINATES_MODES = {
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


def map_coordinates(input, coordinates, order, mode="constant", cval=0.0):
if mode not in MAP_COORDINATES_MODES:
raise ValueError(
"Invalid value for argument `mode`. Expected one of "
f"{set(MAP_COORDINATES_MODES.keys())}. Received: "
f"mode={mode}"
)
if order not in range(2):
raise ValueError(
"Invalid value for argument `order`. Expected one of "
f"{[0, 1]}. Received: "
f"mode={mode}"
)
return jax.scipy.ndimage.map_coordinates(
input, coordinates, order, mode, cval
)
27 changes: 27 additions & 0 deletions keras_core/backend/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,30 @@ def affine_transform(
if need_squeeze:
affined = np.squeeze(affined, axis=0)
return affined


MAP_COORDINATES_MODES = {
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


def map_coordinates(input, coordinates, order, mode="constant", cval=0.0):
if mode not in MAP_COORDINATES_MODES:
raise ValueError(
"Invalid value for argument `mode`. Expected one of "
f"{set(MAP_COORDINATES_MODES.keys())}. Received: "
f"mode={mode}"
)
if order not in range(2):
raise ValueError(
"Invalid value for argument `order`. Expected one of "
f"{[0, 1]}. Received: "
f"mode={mode}"
)
return scipy.ndimage.map_coordinates(
input, coordinates, order, mode, cval
)
127 changes: 127 additions & 0 deletions keras_core/backend/tensorflow/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import functools
import itertools
import operator

import tensorflow as tf

from keras_core.backend.tensorflow.core import convert_to_tensor

RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
Expand Down Expand Up @@ -119,3 +125,124 @@ def affine_transform(
if need_squeeze:
affined = tf.squeeze(affined, axis=0)
return affined

def _unzip3(xyzs):
"""Unzip sequence of length-3 tuples into three tuples."""
# Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-3 output.
xs = []
ys = []
zs = []
for x, y, z in xyzs:
xs.append(x)
ys.append(y)
zs.append(z)
return tuple(xs), tuple(ys), tuple(zs)


def _nonempty_prod(arrs):
return functools.reduce(operator.mul, arrs)


def _nonempty_sum(arrs):
return functools.reduce(operator.add, arrs)


def _mirror_index_fixer(index, size):
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return tf.abs((index + s) % (2 * s) - s)


def _reflect_index_fixer(index, size):
return tf.math.floordiv(
_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2
)


_INDEX_FIXERS = {
"constant": lambda index, size: index,
"nearest": lambda index, size: tf.clip_by_value(index, 0, size - 1),
"wrap": lambda index, size: index % size,
"mirror": _mirror_index_fixer,
"reflect": _reflect_index_fixer,
}


def _round_half_away_from_zero(a):
return (
a
if a.dtype.is_integer
else tf.round(a)
)


def _nearest_indices_and_weights(coordinate):
index = tf.cast(_round_half_away_from_zero(coordinate), tf.int32)
weight = tf.constant(1, coordinate.dtype)
return [(index, weight)]


def _linear_indices_and_weights(coordinate):
lower = tf.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = tf.cast(lower, tf.int32)
return [(index, lower_weight), (index + 1, upper_weight)]


def map_coordinates(input, coordinates, order, mode, cval=0.0):
input_arr = convert_to_tensor(input)
coordinate_arrs = [convert_to_tensor(c) for c in coordinates]
cval = convert_to_tensor(tf.cast(cval, input_arr.dtype))

if len(coordinates) != input_arr.ndim:
raise ValueError(
"coordinates must be a sequence of length input.ndim, but "
"{} != {}".format(len(coordinates), input_arr.ndim)
)

index_fixer = _INDEX_FIXERS.get(mode)
if index_fixer is None:
raise NotImplementedError(
"map_coordinates does not yet support mode {}. "
"Currently supported modes are {}.".format(mode, set(_INDEX_FIXERS))
)

def is_valid(index, size):
if mode == "constant":
return (0 <= index) & (index < size)
else:
return True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
raise NotImplementedError("map_coordinates currently requires order<=1")

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = _unzip3(items)
if all(valid is True for valid in validities):
# fast path
contribution = input_arr[indices]
else:
all_valid = functools.reduce(operator.and_, validities)
contribution = tf.where(all_valid, input_arr[indices], cval)
outputs.append(_nonempty_prod(weights) * contribution)
result = _nonempty_sum(outputs)
if input_arr.dtype.is_integer:
result = _round_half_away_from_zero(result)
return tf.cast(result, input_arr.dtype)
128 changes: 128 additions & 0 deletions keras_core/backend/torch/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import functools
import itertools
import operator

import torch
import torch.nn.functional as tnn

Expand Down Expand Up @@ -263,3 +267,127 @@ def affine_transform(
if need_squeeze:
affined = affined.squeeze(dim=0)
return affined


def _unzip3(xyzs):
"""Unzip sequence of length-3 tuples into three tuples."""
# Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-3 output.
xs = []
ys = []
zs = []
for x, y, z in xyzs:
xs.append(x)
ys.append(y)
zs.append(z)
return tuple(xs), tuple(ys), tuple(zs)


def _nonempty_prod(arrs):
return functools.reduce(operator.mul, arrs)


def _nonempty_sum(arrs):
return functools.reduce(operator.add, arrs)


def _mirror_index_fixer(index, size):
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return torch.abs((index + s) % (2 * s) - s)


def _reflect_index_fixer(index, size):
return torch.floor_divide(
_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2
)


_INDEX_FIXERS = {
"constant": lambda index, size: index,
"nearest": lambda index, size: torch.clip(index, 0, size - 1),
"wrap": lambda index, size: index % size,
"mirror": _mirror_index_fixer,
"reflect": _reflect_index_fixer,
}


def _round_half_away_from_zero(a):
return (
a
if (not torch.is_floating_point(a) and not torch.is_complex(a))
else torch.round(a)
)


def _nearest_indices_and_weights(coordinate):
index = _round_half_away_from_zero(coordinate).to(torch.int32)
weight = torch.tensor(1).to(torch.int32)
return [(index, weight)]


def _linear_indices_and_weights(coordinate):
lower = torch.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = lower.to(torch.int32)
return [(index, lower_weight), (index + 1, upper_weight)]


def map_coordinates(input, coordinates, order, mode, cval=0.0):
input_arr = convert_to_tensor(input)
coordinate_arrs = [convert_to_tensor(c) for c in coordinates]
cval = convert_to_tensor(cval, input_arr.dtype)

if len(coordinates) != input_arr.ndim:
raise ValueError(
"coordinates must be a sequence of length input.ndim, but "
"{} != {}".format(len(coordinates), input_arr.ndim)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Always use f-strings rather than .format()

)

index_fixer = _INDEX_FIXERS.get(mode)
if index_fixer is None:
raise NotImplementedError(
"map_coordinates does not yet support mode {}. "
"Currently supported modes are {}.".format(mode, set(_INDEX_FIXERS))
)

def is_valid(index, size):
if mode == "constant":
return (0 <= index) & (index < size)
else:
return True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
raise NotImplementedError("map_coordinates currently requires order<=1")

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = _unzip3(items)
if all(valid is True for valid in validities):
# fast path
contribution = input_arr[indices]
else:
all_valid = functools.reduce(operator.and_, validities)
contribution = torch.where(all_valid, input_arr[indices], cval)
outputs.append(_nonempty_prod(weights) * contribution)
result = _nonempty_sum(outputs)
if not torch.is_floating_point(input_arr) and not torch.is_complex(
input_arr
):
result = _round_half_away_from_zero(result)
return result.to(input_arr.dtype)
Loading
Loading