diff --git a/keras_core/backend/common/backend_utils.py b/keras_core/backend/common/backend_utils.py index ace6d9ccf..64d008a5c 100644 --- a/keras_core/backend/common/backend_utils.py +++ b/keras_core/backend/common/backend_utils.py @@ -255,3 +255,48 @@ def compute_conv_transpose_output_shape( else: output_shape = [input_shape[0], filters] + output_shape return output_shape + + +def encode_categorical_inputs( + inputs, + output_mode, + depth, + dtype="float32", + sparse=False, + count_weights=None, +): + from keras_core import ops + + """Encodes categoical inputs according to output_mode.""" + if output_mode == "int": + return ops.cast(inputs, dtype=dtype) + + original_shape = inputs.shape + # In all cases, we should uprank scalar input to a single sample. + if len(ops.shape(inputs)) == 0: + inputs = ops.expand_dims(inputs, -1) + # One hot will unprank only if the final output dimension is not already 1. + if output_mode == "one_hot": + if ops.shape(inputs)[-1] != 1: + inputs = ops.expand_dims(inputs, -1) + + if len(ops.shape(inputs)) > 2: + raise ValueError( + "When output_mode is not `'int'`, maximum supported output rank " + f"is 2. Received output_mode {output_mode} and input shape " + f"{original_shape}, " + f"which would result in output rank {inputs.shape.rank}." + ) + + binary_output = output_mode in ("multi_hot", "one_hot") + bincounts = ops.bincount( + inputs, + weights=count_weights, + minlength=depth, + ) + if binary_output: + one_hot_input = ops.one_hot(inputs, depth) + bincounts = ops.where(ops.any(one_hot_input, axis=-2), 1, 0) + bincounts = ops.cast(bincounts, dtype) + + return bincounts diff --git a/keras_core/layers/preprocessing/discretization.py b/keras_core/layers/preprocessing/discretization.py index 963d7c513..9af723c11 100644 --- a/keras_core/layers/preprocessing/discretization.py +++ b/keras_core/layers/preprocessing/discretization.py @@ -1,30 +1,23 @@ import numpy as np from keras_core import backend +from keras_core import ops from keras_core.api_export import keras_core_export -from keras_core.layers.layer import Layer +from keras_core.backend.common.backend_utils import encode_categorical_inputs +from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer from keras_core.utils import argument_validation from keras_core.utils import backend_utils -from keras_core.utils import tf_utils from keras_core.utils.module_utils import tensorflow as tf @keras_core_export("keras_core.layers.Discretization") -class Discretization(Layer): +class Discretization(TFDataLayer): """A preprocessing layer which buckets continuous features by ranges. This layer will place each element of its input data into one of several contiguous ranges and output an integer index indicating which range each element was placed in. - **Note:** This layer uses TensorFlow internally. It cannot - be used as part of the compiled computation graph of a model with - any backend other than TensorFlow. - It can however be used with any backend when running eagerly. - It can also always be used as part of an input preprocessing pipeline - with any backend (outside the model itself), which is how we recommend - to use this layer. - **Note:** This layer is safe to use inside a `tf.data` pipeline (independently of which backend you're using). @@ -78,14 +71,14 @@ class Discretization(Layer): Examples: - Bucketize float values based on provided buckets. + Discretize float values based on provided buckets. >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]]) >>> layer = Discretization(bin_boundaries=[0., 1., 2.]) >>> layer(input) array([[0, 2, 3, 1], [1, 3, 2, 1]]) - Bucketize float values based on a number of buckets to compute. + Discretize float values based on a number of buckets to compute. >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]]) >>> layer = Discretization(num_bins=4, epsilon=0.01) >>> layer.adapt(input) @@ -161,8 +154,8 @@ def __init__( self.summary = None else: self.summary = np.array([[], []], dtype="float32") - self._convert_input_args = False self._allow_non_tensor_positional_args = True + self._convert_input_args = True def build(self, input_shape=None): self.built = True @@ -238,29 +231,11 @@ def load_own_variables(self, store): return def call(self, inputs): - if not isinstance( - inputs, - ( - tf.Tensor, - tf.SparseTensor, - tf.RaggedTensor, - np.ndarray, - backend.KerasTensor, - ), - ): - inputs = tf.convert_to_tensor( - backend.convert_to_numpy(inputs), dtype=self.input_dtype - ) - - from keras_core.backend.tensorflow.numpy import digitize - - indices = digitize(inputs, self.bin_boundaries) - - outputs = tf_utils.encode_categorical_inputs( + indices = ops.digitize(inputs, self.bin_boundaries) + outputs = encode_categorical_inputs( indices, output_mode=self.output_mode, depth=len(self.bin_boundaries) + 1, - sparse=self.sparse, dtype=self.compute_dtype, ) if ( @@ -370,7 +345,3 @@ def compress_summary(summary, epsilon): ) summary = np.stack((new_bins, new_weights)) return summary.astype("float32") - - -def bucketize(inputs, boundaries): - return tf.raw_ops.Bucketize(input=inputs, boundaries=boundaries) diff --git a/keras_core/layers/preprocessing/discretization_test.py b/keras_core/layers/preprocessing/discretization_test.py index b7f99a096..3af625a1b 100644 --- a/keras_core/layers/preprocessing/discretization_test.py +++ b/keras_core/layers/preprocessing/discretization_test.py @@ -70,7 +70,9 @@ def test_correctness(self): def test_tf_data_compatibility(self): # With fixed bins - layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0]) + layer = layers.Discretization( + bin_boundaries=[0.0, 0.35, 0.5, 1.0], dtype="float32" + ) x = np.array([[-1.0, 0.0, 0.1, 0.2, 0.4, 0.5, 1.0, 1.2, 0.98]]) self.assertAllClose(layer(x), np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]])) ds = tf_data.Dataset.from_tensor_slices(x).batch(1).map(layer)