Skip to content

Commit

Permalink
Add support for Tensorflow SparseTensors: merging layers.
Browse files Browse the repository at this point in the history
Added `tf.SparseTensor` support for ops:
- add
- concatenate
- maximum
- minimum
- multiply
- subtract

Added `tf.SparseTensor` support for merging layers:
- Add
- Average
- Concatenate
- Maximum
- Minimum
- Multiply
- Subtract

Note that the `Dot` merging layer will be addressed in a separate PR.
  • Loading branch information
hertschuh committed Sep 20, 2023
1 parent cbc2e47 commit 2d246ad
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 10 deletions.
66 changes: 66 additions & 0 deletions keras_core/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import builtins
import functools
import warnings

import tensorflow as tf
Expand All @@ -7,6 +9,8 @@


def add(x1, x2):
    """Element-wise sum of `x1` and `x2` with `tf.SparseTensor` support.

    When at least one operand is a `tf.SparseTensor`, the addition is
    delegated to `tf.sparse.add`; otherwise `tfnp.add` is used.
    """
    operands = (x1, x2)
    has_sparse_operand = False
    for operand in operands:
        if isinstance(operand, tf.SparseTensor):
            has_sparse_operand = True
            break
    if not has_sparse_operand:
        return tfnp.add(x1, x2)
    return tf.sparse.add(x1, x2)


Expand Down Expand Up @@ -38,6 +42,11 @@ def einsum(subscripts, *operands, **kwargs):


def subtract(x1, x2):
    """Element-wise difference `x1 - x2` with `tf.SparseTensor` support.

    Sparse subtraction is expressed as `tf.sparse.add(x1, -x2)`. How `x2`
    is negated depends on whether `x2` itself is sparse.
    """
    if isinstance(x2, tf.SparseTensor):
        # Negate only the stored values of the sparse subtrahend.
        negated_x2 = tf.sparse.map_values(tf.negative, x2)
        return tf.sparse.add(x1, negated_x2)
    if isinstance(x1, tf.SparseTensor):
        # Only the minuend is sparse: negate `x2` directly.
        return tf.sparse.add(x1, tf.negative(x2))
    return tfnp.subtract(x1, x2)


Expand All @@ -62,6 +71,40 @@ def matmul(x1, x2):


def multiply(x1, x2):
    """Element-wise product of `x1` and `x2` with `tf.SparseTensor` support.

    Dispatch:

    - Neither operand is sparse: `tfnp.multiply`.
    - Exactly one operand is sparse: the `*` operator is applied with the
      sparse operand on the left-hand side.
    - Both operands are sparse: each operand is trimmed to the indices the
      two have in common, then the remaining values are multiplied
      pairwise.
    """
    x1_is_sparse = isinstance(x1, tf.SparseTensor)
    x2_is_sparse = isinstance(x2, tf.SparseTensor)

    if not (x1_is_sparse or x2_is_sparse):
        return tfnp.multiply(x1, x2)
    if not x2_is_sparse:
        return x1 * x2
    if not x1_is_sparse:
        # Keep the sparse operand on the left-hand side of `*`.
        return x2 * x1

    def int8_pattern(fill_like, sparse_input):
        # Same indices as `sparse_input`, with every stored value replaced
        # by the int8 output of `fill_like` (`tf.ones_like` or
        # `tf.zeros_like`).
        return tf.sparse.map_values(
            functools.partial(fill_like, dtype=tf.int8), sparse_input
        )

    # Intersection of the indices of x1 and x2, as a sparse tensor holding
    # ones. tf.sets.intersection ignores the last dimension when comparing,
    # so a dummy trailing dimension is added before the call and removed
    # afterwards by reshaping back to the shape of x1.
    shared_indices = tf.sparse.reshape(
        tf.sets.intersection(
            tf.sparse.expand_dims(int8_pattern(tf.ones_like, x1), axis=-1),
            tf.sparse.expand_dims(int8_pattern(tf.ones_like, x2), axis=-1),
        ),
        x1.dense_shape,
    )

    # Adding the intersection to an all-zeros copy of each operand gives a
    # per-value mask: non-zero where the index is shared, zero elsewhere.
    # The masks are then used to drop the non-shared entries.
    keep_in_x1 = tf.sparse.add(int8_pattern(tf.zeros_like, x1), shared_indices)
    keep_in_x2 = tf.sparse.add(int8_pattern(tf.zeros_like, x2), shared_indices)
    x1_shared = tf.sparse.retain(x1, tf.cast(keep_in_x1.values, tf.bool))
    x2_shared = tf.sparse.retain(x2, tf.cast(keep_in_x2.values, tf.bool))

    # Both trimmed operands are combined by multiplying their values.
    return tf.sparse.map_values(tf.multiply, x1_shared, x2_shared)


Expand Down Expand Up @@ -202,6 +245,15 @@ def clip(x, x_min, x_max):


def concatenate(xs, axis=0):
    """Concatenate the tensors in `xs` along `axis`.

    - No input is a `tf.SparseTensor`: `tfnp.concatenate` on `xs` as given.
    - Every input is a `tf.SparseTensor`: `tf.sparse.concat`.
    - Mixed inputs: the sparse ones are converted with
      `tf.sparse.to_dense`, then `tfnp.concatenate` is used.
    """
    # `builtins.sum` is spelled out explicitly, as in the rest of this
    # function's history; presumably the module defines its own `sum`.
    num_sparse = builtins.sum(
        1 for x in xs if isinstance(x, tf.SparseTensor)
    )
    if num_sparse == 0:
        return tfnp.concatenate(xs, axis=axis)
    if num_sparse == len(xs):
        return tf.sparse.concat(axis=axis, sp_inputs=xs)

    densified = []
    for x in xs:
        if isinstance(x, tf.SparseTensor):
            x = tf.sparse.to_dense(x)
        densified.append(x)
    return tfnp.concatenate(densified, axis=axis)


Expand Down Expand Up @@ -420,6 +472,13 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):


def maximum(x1, x2):
    """Element-wise maximum of `x1` and `x2` with `tf.SparseTensor` support.

    Two sparse operands are handled by `tf.sparse.maximum`. If only one
    operand is sparse it is converted with `tf.sparse.to_dense` and the
    result comes from `tfnp.maximum`.
    """
    x1_is_sparse = isinstance(x1, tf.SparseTensor)
    x2_is_sparse = isinstance(x2, tf.SparseTensor)

    if x1_is_sparse and x2_is_sparse:
        return tf.sparse.maximum(x1, x2)
    if x1_is_sparse:
        x1 = tf.sparse.to_dense(x1)
    if x2_is_sparse:
        x2 = tf.sparse.to_dense(x2)
    return tfnp.maximum(x1, x2)


Expand Down Expand Up @@ -449,6 +508,13 @@ def min(x, axis=None, keepdims=False, initial=None):


def minimum(x1, x2):
    """Element-wise minimum of `x1` and `x2` with `tf.SparseTensor` support.

    Two sparse operands are handled by `tf.sparse.minimum`. If only one
    operand is sparse it is converted with `tf.sparse.to_dense` and the
    result comes from `tfnp.minimum`.
    """
    x1_is_sparse = isinstance(x1, tf.SparseTensor)
    x2_is_sparse = isinstance(x2, tf.SparseTensor)

    if x1_is_sparse and x2_is_sparse:
        return tf.sparse.minimum(x1, x2)
    if x1_is_sparse:
        x1 = tf.sparse.to_dense(x1)
    if x2_is_sparse:
        x2 = tf.sparse.to_dense(x2)
    return tfnp.minimum(x1, x2)


Expand Down
3 changes: 2 additions & 1 deletion keras_core/layers/merging/add.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.layers.merging.base_merge import Merge

Expand Down Expand Up @@ -32,7 +33,7 @@ class Add(Merge):
def _merge_function(self, inputs):
output = inputs[0]
for i in range(1, len(inputs)):
output = output + inputs[i]
output = ops.add(output, inputs[i])
return output


Expand Down
3 changes: 2 additions & 1 deletion keras_core/layers/merging/average.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.layers.merging.base_merge import Merge

Expand Down Expand Up @@ -32,7 +33,7 @@ class Average(Merge):
def _merge_function(self, inputs):
output = inputs[0]
for i in range(1, len(inputs)):
output = output + inputs[i]
output = ops.add(output, inputs[i])
return output / len(inputs)


Expand Down
8 changes: 8 additions & 0 deletions keras_core/layers/merging/base_merge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from keras_core import backend
from keras_core import ops
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.layers.layer import Layer


Expand Down Expand Up @@ -208,6 +209,13 @@ def compute_output_shape(self, input_shape):
output_shape = (None,) + output_shape
return output_shape

def compute_output_spec(self, inputs):
    """Build the symbolic `KerasTensor` describing the merged output.

    The shape comes from `compute_output_shape` applied to the input
    shapes, the dtype is the layer's `compute_dtype`, and the output is
    flagged sparse only when every input is flagged sparse.
    """
    input_shapes = [tensor.shape for tensor in inputs]
    merged_shape = self.compute_output_shape(input_shapes)
    every_input_sparse = all(tensor.sparse for tensor in inputs)
    return KerasTensor(
        merged_shape,
        dtype=self.compute_dtype,
        sparse=every_input_sparse,
    )

def compute_mask(self, inputs, mask=None):
if mask is None:
return None
Expand Down
56 changes: 56 additions & 0 deletions keras_core/layers/merging/merging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,59 @@ def test_subtract_layer_inputs_length_errors(self):
ValueError, "layer should be called on exactly 2 inputs"
):
layers.Subtract()([input_1])

@parameterized.named_parameters(TEST_PARAMETERS)
@pytest.mark.skipif(
    not backend.SUPPORTS_SPARSE_TENSORS,
    reason="Backend does not support sparse tensors.",
)
def test_sparse(
    self,
    layer_class,
    np_op,
    init_kwargs=None,
    input_shape=(2, 4, 5),
    expected_output_shape=(2, 4, 5),
    **kwargs
):
    """Check a merging layer against sparse and mixed sparse/dense inputs.

    Args:
        layer_class: Merging layer class under test.
        np_op: Reference NumPy operation the layer output is compared to.
        init_kwargs: Optional constructor kwargs for the layer; they are
            also forwarded to `np_op`. Defaults to no kwargs.
        input_shape: Shape used for both symbolic inputs.
        expected_output_shape: Expected shape of the symbolic output.
        **kwargs: Extra test parameters, unused here.
    """
    import tensorflow as tf

    # Use a fresh dict per call instead of a shared mutable default.
    init_kwargs = {} if init_kwargs is None else init_kwargs

    if layer_class == layers.Dot:
        pytest.skip("Dot layer does not support sparse tensors.")

    self.run_layer_test(
        layer_class,
        init_kwargs=init_kwargs,
        input_shape=[input_shape, input_shape],
        input_sparse=True,
        expected_output_shape=expected_output_shape,
        expected_output_sparse=True,
        expected_num_trainable_weights=0,
        expected_num_non_trainable_weights=0,
        expected_num_seed_generators=0,
        expected_num_losses=0,
        supports_masking=True,
        run_training_check=False,
        run_mixed_precision_check=False,
    )

    layer = layer_class(**init_kwargs)

    # Merging a sparse tensor with a dense tensor, or a dense tensor with a
    # sparse tensor produces a dense tensor
    x1 = tf.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
    )
    x1_np = tf.sparse.to_dense(x1).numpy()
    x2 = np.random.rand(2, 3)
    self.assertAllClose(layer([x1, x2]), np_op(x1_np, x2, **init_kwargs))
    self.assertAllClose(layer([x2, x1]), np_op(x2, x1_np, **init_kwargs))

    # Merging a sparse tensor with a sparse tensor produces a sparse tensor
    x3 = tf.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[4.0, 5.0], dense_shape=(2, 3)
    )
    x3_np = tf.sparse.to_dense(x3).numpy()

    # Run the sparse-sparse merge once and check both its type and values.
    sparse_output = layer([x1, x3])
    self.assertIsInstance(sparse_output, tf.SparseTensor)
    self.assertAllClose(sparse_output, np_op(x1_np, x3_np, **init_kwargs))
3 changes: 2 additions & 1 deletion keras_core/layers/merging/multiply.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.layers.merging.base_merge import Merge

Expand Down Expand Up @@ -32,7 +33,7 @@ class Multiply(Merge):
def _merge_function(self, inputs):
output = inputs[0]
for i in range(1, len(inputs)):
output = output * inputs[i]
output = ops.multiply(output, inputs[i])
return output


Expand Down
3 changes: 2 additions & 1 deletion keras_core/layers/merging/subtract.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.layers.merging.base_merge import Merge

Expand Down Expand Up @@ -44,7 +45,7 @@ def _merge_function(self, inputs):
"A `Subtract` layer should be called on exactly 2 inputs. "
f"Received: inputs={inputs}"
)
return inputs[0] - inputs[1]
return ops.subtract(inputs[0], inputs[1])


@keras_core_export("keras_core.layers.subtract")
Expand Down
30 changes: 24 additions & 6 deletions keras_core/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
x1_sparse = getattr(x1, "sparse", True)
x2_sparse = getattr(x2, "sparse", True)
output_sparse = x1_sparse and x2_sparse
return KerasTensor(output_shape, dtype=x1.dtype, sparse=output_sparse)


@keras_core_export(["keras_core.ops.add", "keras_core.ops.numpy.add"])
Expand Down Expand Up @@ -1386,6 +1389,7 @@ def call(self, xs):
def compute_output_spec(self, xs):
first_shape = xs[0].shape
total_size_on_axis = 0
all_sparse = True
for x in xs:
if not shape_equal(
x.shape, first_shape, axis=[self.axis], allow_none=True
Expand All @@ -1400,9 +1404,11 @@ def compute_output_spec(self, xs):
total_size_on_axis = None
else:
total_size_on_axis += x.shape[self.axis]
if not x.sparse:
all_sparse = False
output_shape = list(first_shape)
output_shape[self.axis] = total_size_on_axis
return KerasTensor(output_shape, dtype=x.dtype)
return KerasTensor(output_shape, dtype=x.dtype, sparse=all_sparse)


@keras_core_export(
Expand Down Expand Up @@ -3443,7 +3449,10 @@ def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
x1_sparse = getattr(x1, "sparse", True)
x2_sparse = getattr(x2, "sparse", True)
output_sparse = x1_sparse and x2_sparse
return KerasTensor(output_shape, dtype=x1.dtype, sparse=output_sparse)


@keras_core_export(["keras_core.ops.maximum", "keras_core.ops.numpy.maximum"])
Expand Down Expand Up @@ -3582,7 +3591,10 @@ def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
x1_sparse = getattr(x1, "sparse", True)
x2_sparse = getattr(x2, "sparse", True)
output_sparse = x1_sparse and x2_sparse
return KerasTensor(output_shape, dtype=x1.dtype, sparse=output_sparse)


@keras_core_export(["keras_core.ops.minimum", "keras_core.ops.numpy.minimum"])
Expand Down Expand Up @@ -5080,7 +5092,10 @@ def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
x1_sparse = getattr(x1, "sparse", True)
x2_sparse = getattr(x2, "sparse", True)
output_sparse = x1_sparse and x2_sparse
return KerasTensor(output_shape, dtype=x1.dtype, sparse=output_sparse)


@keras_core_export(["keras_core.ops.subtract", "keras_core.ops.numpy.subtract"])
Expand All @@ -5107,7 +5122,10 @@ def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
x1_sparse = getattr(x1, "sparse", True)
x2_sparse = getattr(x2, "sparse", True)
output_sparse = x1_sparse or x2_sparse
return KerasTensor(output_shape, dtype=x1.dtype, sparse=output_sparse)


@keras_core_export(["keras_core.ops.multiply", "keras_core.ops.numpy.multiply"])
Expand Down
Loading

0 comments on commit 2d246ad

Please sign in to comment.