diff --git a/keras_core/backend/tensorflow/random.py b/keras_core/backend/tensorflow/random.py index 994814a56..9dc2f22e5 100644 --- a/keras_core/backend/tensorflow/random.py +++ b/keras_core/backend/tensorflow/random.py @@ -34,9 +34,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): def categorical(logits, num_samples, dtype="int64", seed=None): seed = tf_draw_seed(seed) - output = tf.random.stateless_categorical( - logits, num_samples, seed=seed - ) + output = tf.random.stateless_categorical(logits, num_samples, seed=seed) return tf.cast(output, dtype) diff --git a/keras_core/backend/torch/random.py b/keras_core/backend/torch/random.py index e03d296e8..77e024411 100644 --- a/keras_core/backend/torch/random.py +++ b/keras_core/backend/torch/random.py @@ -31,7 +31,10 @@ def categorical(logits, num_samples, dtype="int32", seed=None): dtype = to_torch_dtype(dtype) generator = torch_seed_generator(seed, device=get_device()) return torch.multinomial( - logits, num_samples, replacement=True, generator=generator, + logits, + num_samples, + replacement=True, + generator=generator, ).type(dtype) diff --git a/keras_core/layers/preprocessing/tf_data_layer.py b/keras_core/layers/preprocessing/tf_data_layer.py index 91f4227c6..9c64a2f4a 100644 --- a/keras_core/layers/preprocessing/tf_data_layer.py +++ b/keras_core/layers/preprocessing/tf_data_layer.py @@ -1,3 +1,5 @@ +from tensorflow import nest + from keras_core import backend from keras_core.layers.layer import Layer from keras_core.utils import backend_utils @@ -22,8 +24,11 @@ def __call__(self, inputs, **kwargs): ): # We're in a TF graph, e.g. a tf.data pipeline. self.backend.set_backend("tensorflow") - inputs = self.backend.convert_to_tensor( - inputs, dtype=self.compute_dtype + inputs = nest.map_structure( + lambda x: self.backend.convert_to_tensor( + x, dtype=self.compute_dtype + ), + inputs, ) switch_convert_input_args = False if self._convert_input_args: