diff --git a/keras_core/backend/tensorflow/random.py b/keras_core/backend/tensorflow/random.py
index 994814a56..9dc2f22e5 100644
--- a/keras_core/backend/tensorflow/random.py
+++ b/keras_core/backend/tensorflow/random.py
@@ -34,9 +34,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
 
 def categorical(logits, num_samples, dtype="int64", seed=None):
     seed = tf_draw_seed(seed)
-    output = tf.random.stateless_categorical(
-        logits, num_samples, seed=seed
-    )
+    output = tf.random.stateless_categorical(logits, num_samples, seed=seed)
     return tf.cast(output, dtype)
 
 
diff --git a/keras_core/backend/torch/random.py b/keras_core/backend/torch/random.py
index e03d296e8..77e024411 100644
--- a/keras_core/backend/torch/random.py
+++ b/keras_core/backend/torch/random.py
@@ -31,7 +31,10 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
     dtype = to_torch_dtype(dtype)
     generator = torch_seed_generator(seed, device=get_device())
     return torch.multinomial(
-        logits, num_samples, replacement=True, generator=generator,
+        logits,
+        num_samples,
+        replacement=True,
+        generator=generator,
     ).type(dtype)
 
 
diff --git a/keras_core/layers/core/masking.py b/keras_core/layers/core/masking.py
index abfa3c757..716d31a59 100644
--- a/keras_core/layers/core/masking.py
+++ b/keras_core/layers/core/masking.py
@@ -61,7 +61,7 @@ def call(self, inputs):
         try:
             outputs._keras_mask = ops.squeeze(boolean_mask, axis=-1)
         except AttributeError:
-            # outputs in a C type.
+            # tensor is a C type.
             pass
         return outputs
 
diff --git a/keras_core/layers/preprocessing/tf_data_layer.py b/keras_core/layers/preprocessing/tf_data_layer.py
index 91f4227c6..9c64a2f4a 100644
--- a/keras_core/layers/preprocessing/tf_data_layer.py
+++ b/keras_core/layers/preprocessing/tf_data_layer.py
@@ -1,3 +1,5 @@
+from tensorflow import nest
+
 from keras_core import backend
 from keras_core.layers.layer import Layer
 from keras_core.utils import backend_utils
@@ -22,8 +24,11 @@ def __call__(self, inputs, **kwargs):
         ):
             # We're in a TF graph, e.g. a tf.data pipeline.
             self.backend.set_backend("tensorflow")
-            inputs = self.backend.convert_to_tensor(
-                inputs, dtype=self.compute_dtype
+            inputs = nest.map_structure(
+                lambda x: self.backend.convert_to_tensor(
+                    x, dtype=self.compute_dtype
+                ),
+                inputs,
             )
             switch_convert_input_args = False
             if self._convert_input_args:
diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py
index 8d7895bf1..20ea429a8 100644
--- a/keras_core/models/functional.py
+++ b/keras_core/models/functional.py
@@ -155,6 +155,8 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
         # We will convert directly (to the correct dtype per input).
         self._convert_input_args = False
         self._allow_non_tensor_positional_args = True
+        output_layers = [x._keras_history[0] for x in self.outputs]
+        self.output_names = [x.name for x in output_layers]
         self._post_build()
 
     @property
diff --git a/keras_core/models/model_test.py b/keras_core/models/model_test.py
index a91db7fb7..abed61432 100644
--- a/keras_core/models/model_test.py
+++ b/keras_core/models/model_test.py
@@ -8,22 +8,47 @@
 from keras_core.models.model import model_from_json
 
 
-class ModelTest(testing.TestCase):
-    def _get_model(self):
-        input_a = Input(shape=(3,), batch_size=2, name="input_a")
-        input_b = Input(shape=(3,), batch_size=2, name="input_b")
-        x = input_a + input_b
-        x = layers.Dense(5)(x)
-        outputs = layers.Dense(4)(x)
-        model = Model([input_a, input_b], outputs)
-        return model
+def _get_model():
+    input_a = Input(shape=(3,), batch_size=2, name="input_a")
+    input_b = Input(shape=(3,), batch_size=2, name="input_b")
+    x = input_a + input_b
+    x = layers.Dense(5)(x)
+    outputs = layers.Dense(4)(x)
+    model = Model([input_a, input_b], outputs)
+    return model
+
+
+def _get_model_multi_outputs_list():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1, name="output_a")(x)
+    output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
+    model = Model(x, [output_a, output_b])
+    return model
+
+
+def _get_model_multi_outputs_list_no_output_names():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1)(x)
+    output_b = layers.Dense(1, activation="sigmoid")(x)
+    model = Model(x, [output_a, output_b])
+    return model
+
+
+def _get_model_multi_outputs_dict():
+    x = Input(shape=(3,), name="input_a")
+    output_a = layers.Dense(1, name="output_a")(x)
+    output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
+    model = Model(x, {"output_a": output_a, "output_b": output_b})
+    return model
+
+
+class ModelTest(testing.TestCase):
 
     def test_functional_rerouting(self):
-        model = self._get_model()
+        model = _get_model()
         self.assertTrue(isinstance(model, Functional))
 
     def test_json_serialization(self):
-        model = self._get_model()
+        model = _get_model()
         json_string = model.to_json()
         new_model = model_from_json(json_string)
         self.assertEqual(json_string, new_model.to_json())
@@ -65,3 +90,244 @@ def call(self, x):
             config, custom_objects={"CustomDense": CustomDense}
         )
         self.assertTrue(isinstance(new_model, Functional))
+
+    def test_functional_list_outputs_list_losses(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss=["mean_squared_error", "binary_crossentropy"],
+            metrics=[
+                ["mean_squared_error"],
+                ["mean_squared_error", "accuracy"],
+            ],
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_dict_outputs_dict_losses(self):
+        model = _get_model_multi_outputs_dict()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(
+            x,
+            {"output_a": y1, "output_b": y2},
+            batch_size=2,
+            epochs=1,
+            verbose=0,
+        )
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_dict_losses_metrics(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_a": ["mean_squared_error"],
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_a_mean_squared_error",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_dict_losses_partial_metrics(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_b": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        # TODO `tf.keras` also outputs individual losses for outputs
+        ref_keys = sorted(
+            [
+                "loss",
+                # "output_a_loss",
+                "output_b_accuracy",
+                # "output_b_loss",
+                "output_b_mean_squared_error",
+            ]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
+
+    def test_functional_list_outputs_dict_losses_invalid_keys(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_c": "binary_crossentropy",
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `loss`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_list_outputs_dict_losses_no_output_names(self):
+        model = _get_model_multi_outputs_list_no_output_names()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={"output_a": "mean_squared_error"},
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `loss`, "
+            "key 'output_a' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_list_outputs_dict_metrics_invalid_keys(self):
+        model = _get_model_multi_outputs_list()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_c": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `metrics`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_dict_outputs_dict_losses_invalid_keys(self):
+        model = _get_model_multi_outputs_dict()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_c": "binary_crossentropy",
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `loss`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+
+    def test_functional_dict_outputs_dict_metrics_invalid_keys(self):
+        model = _get_model_multi_outputs_dict()
+        self.assertTrue(isinstance(model, Functional))
+        x = np.random.rand(8, 3)
+        y1 = np.random.rand(8, 1)
+        y2 = np.random.randint(0, 2, (8, 1))
+        model.compile(
+            optimizer="sgd",
+            loss={
+                "output_a": "mean_squared_error",
+                "output_b": "binary_crossentropy",
+            },
+            metrics={
+                "output_c": ["mean_squared_error", "accuracy"],
+            },
+        )
+        # Fit the model to make sure compile_metrics are built
+        with self.assertRaisesRegex(
+            ValueError,
+            "In the dict argument `metrics`, "
+            "key 'output_c' does not correspond to any model output",
+        ):
+            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
diff --git a/keras_core/trainers/compile_utils.py b/keras_core/trainers/compile_utils.py
index 6a60e0918..cb60836ea 100644
--- a/keras_core/trainers/compile_utils.py
+++ b/keras_core/trainers/compile_utils.py
@@ -54,7 +54,7 @@ def is_binary_or_sparse_categorical(y_true, y_pred):
     return is_binary, is_sparse_categorical
 
 
-def get_metric(identifier, y_true, y_pred):
+def get_metric(identifier, y_true, y_pred, name_prefix=None):
     if identifier is None:
         return None  # Ok to have no metric for an output.
 
@@ -85,6 +85,8 @@ def get_metric(identifier, y_true, y_pred):
         metric_obj = metrics_module.MeanMetricWrapper(
             metric_obj, name=metric_name
         )
+    if name_prefix and not metric_obj.name.startswith(name_prefix):
+        metric_obj.name = "_".join([name_prefix, metric_obj.name])
     return metric_obj
 
 
@@ -117,7 +119,13 @@ def get_loss(identifier, y_true, y_pred):
 
 
 class CompileMetrics(metrics_module.Metric):
-    def __init__(self, metrics, weighted_metrics, name="compile_metric"):
+    def __init__(
+        self,
+        metrics,
+        weighted_metrics,
+        name="compile_metric",
+        output_names=None,
+    ):
         super().__init__(name=name)
         if metrics and not isinstance(metrics, (list, tuple, dict)):
             raise ValueError(
@@ -136,6 +144,7 @@ def __init__(self, metrics, weighted_metrics, name="compile_metric"):
         self._user_weighted_metrics = weighted_metrics
         self.built = False
         self.name = "compile_metrics"
+        self.output_names = output_names
 
     @property
     def variables(self):
@@ -150,9 +159,10 @@ def variables(self):
         return vars
 
     def build(self, y_true, y_pred):
-        if isinstance(y_pred, dict):
+        if self.output_names:
+            output_names = self.output_names
+        elif isinstance(y_pred, dict):
             output_names = sorted(list(y_pred.keys()))
-            num_outputs = len(output_names)
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -162,15 +172,14 @@ def build(self, y_true, y_pred):
         else:
             output_names = None
             num_outputs = 1
+        if output_names:
+            num_outputs = len(output_names)
 
-        y_pred = nest.flatten(y_pred)
-        y_true = nest.flatten(y_true)
+        y_pred = self._flatten_y(y_pred)
+        y_true = self._flatten_y(y_true)
 
         metrics = self._user_metrics
         weighted_metrics = self._user_weighted_metrics
-        if output_names and not num_outputs:
-            num_outputs = len(output_names)
-
         self._flat_metrics = self._build_metrics_set(
             metrics,
             num_outputs,
@@ -238,7 +247,12 @@ def _build_metrics_set(
                     "(the list of metrics corresponding to that output). "
                     f"Received:\n{argument_name}={metrics}"
                 )
-            for mls, yt, yp in zip(metrics, y_true, y_pred):
+            name = None
+            for idx, (mls, yt, yp) in enumerate(
+                zip(metrics, y_true, y_pred)
+            ):
+                if output_names:
+                    name = output_names[idx]
                 if not all(is_function_like(e) for e in mls):
                     raise ValueError(
                         f"All entries in the sublists of the "
@@ -249,7 +263,7 @@ def _build_metrics_set(
                 flat_metrics.append(
                     MetricsList(
                         [
-                            get_metric(m, yt, yp)
+                            get_metric(m, yt, yp, name)
                             for m in mls
                             if m is not None
                         ]
@@ -290,7 +304,7 @@ def _build_metrics_set(
                     flat_metrics.append(
                         MetricsList(
                             [
-                                get_metric(m, yt, yp)
+                                get_metric(m, yt, yp, name)
                                 for m in metrics[name]
                                 if m is not None
                             ]
@@ -300,16 +314,28 @@ def _build_metrics_set(
                     flat_metrics.append(None)
         return flat_metrics
 
+    def _flatten_y(self, y):
+        if isinstance(y, dict) and self.output_names:
+            result = []
+            for name in self.output_names:
+                if name in y:
+                    result.append(y[name])
+            return result
+        return nest.flatten(y)
+
     def update_state(self, y_true, y_pred, sample_weight=None):
         if not self.built:
             self.build(y_true, y_pred)
-        y_true = nest.flatten(y_true)
-        y_pred = nest.flatten(y_pred)
+        y_true = self._flatten_y(y_true)
+        y_pred = self._flatten_y(y_pred)
         for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred):
             if m:
                 m.update_state(y_t, y_p)
         if sample_weight is not None:
-            sample_weight = nest.flatten(sample_weight)
+            sample_weight = self._flatten_y(sample_weight)
+            # For multi-outputs, repeat sample weights for n outputs.
+            if len(sample_weight) < len(y_true):
+                sample_weight = [sample_weight[0] for _ in range(len(y_true))]
         else:
             sample_weight = [None for _ in range(len(y_true))]
         for m, y_t, y_p, s_w in zip(
@@ -375,7 +401,11 @@ def from_config(cls, config):
 
 class CompileLoss(losses_module.Loss):
     def __init__(
-        self, loss, loss_weights=None, reduction="sum_over_batch_size"
+        self,
+        loss,
+        loss_weights=None,
+        reduction="sum_over_batch_size",
+        output_names=None,
     ):
         if loss_weights and not isinstance(loss_weights, (list, tuple, dict)):
             raise ValueError(
@@ -386,12 +416,14 @@ def __init__(
         self._user_loss = loss
         self._user_loss_weights = loss_weights
         self.built = False
+        self.output_names = output_names
         super().__init__(name="compile_loss", reduction=reduction)
 
     def build(self, y_true, y_pred):
-        if isinstance(y_pred, dict):
+        if self.output_names:
+            output_names = self.output_names
+        elif isinstance(y_pred, dict):
             output_names = sorted(list(y_pred.keys()))
-            num_outputs = len(output_names)
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -401,8 +433,10 @@ def build(self, y_true, y_pred):
         else:
             output_names = None
             num_outputs = 1
+        if output_names:
+            num_outputs = len(output_names)
 
-        y_pred = nest.flatten(y_pred)
+        y_pred = self._flatten_y(y_pred)
         loss = self._user_loss
         loss_weights = self._user_loss_weights
         flat_losses = []
@@ -552,20 +586,31 @@ def __call__(self, y_true, y_pred, sample_weight=None):
         with ops.name_scope(self.name):
             return self.call(y_true, y_pred, sample_weight)
 
+    def _flatten_y(self, y):
+        if isinstance(y, dict) and self.output_names:
+            result = []
+            for name in self.output_names:
+                if name in y:
+                    result.append(y[name])
+            return result
+        return nest.flatten(y)
+
     def call(self, y_true, y_pred, sample_weight=None):
         if not self.built:
             self.build(y_true, y_pred)
-        y_true = nest.flatten(y_true)
-        y_pred = nest.flatten(y_pred)
+        y_true = self._flatten_y(y_true)
+        y_pred = self._flatten_y(y_pred)
         if sample_weight is not None:
-            sample_weight = nest.flatten(sample_weight)
+            sample_weight = self._flatten_y(sample_weight)
+            # For multi-outputs, repeat sample weights for n outputs.
+            if len(sample_weight) < len(y_true):
+                sample_weight = [sample_weight[0] for _ in range(len(y_true))]
         else:
             sample_weight = [None for _ in y_true]
 
         loss_values = []
-        for loss, y_t, y_p, loss_weight, sample_weight in zip(
             self.flat_losses,
             y_true,
diff --git a/keras_core/trainers/compile_utils_test.py b/keras_core/trainers/compile_utils_test.py
index 95e482799..d097ef4ec 100644
--- a/keras_core/trainers/compile_utils_test.py
+++ b/keras_core/trainers/compile_utils_test.py
@@ -115,21 +115,21 @@ def test_dict_output_case(self):
             metrics={
                 "output_1": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
                 "output_2": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
             },
             weighted_metrics={
                 "output_1": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
                 "output_2": [
                     metrics_module.MeanSquaredError(),
-                    metrics_module.MeanSquaredError(),
+                    metrics_module.MeanSquaredError(name="mse"),
                 ],
             },
         )
@@ -169,15 +169,32 @@ def test_dict_output_case(self):
         result = compile_metrics.result()
         self.assertTrue(isinstance(result, dict))
         self.assertEqual(len(result), 8)
-        self.assertAllClose(result["mean_squared_error"], 0.055833336)
-        self.assertAllClose(result["weighted_mean_squared_error"], 0.0725)
+        # Result values obtained from `tf.keras`
+        # m = tf.keras.metrics.MeanSquaredError()
+        # m.update_state(y_true, y_pred1, sample_weight=weight)
+        # m.update_state(y_true, y_pred2, sample_weight=weight)
+        # m.result().numpy()
+        self.assertAllClose(result["output_1_mean_squared_error"], 0.055833336)
+        self.assertAllClose(result["output_2_mean_squared_error"], 0.055833336)
+        self.assertAllClose(result["output_1_mse"], 0.055833336)
+        self.assertAllClose(result["output_2_mse"], 0.055833336)
+        self.assertAllClose(
+            result["weighted_output_1_mean_squared_error"], 0.0725
+        )
+        self.assertAllClose(
+            result["weighted_output_2_mean_squared_error"], 0.0725
+        )
+        self.assertAllClose(result["weighted_output_1_mse"], 0.0725)
+        self.assertAllClose(result["weighted_output_2_mse"], 0.0725)
 
         compile_metrics.reset_state()
         result = compile_metrics.result()
         self.assertTrue(isinstance(result, dict))
         self.assertEqual(len(result), 8)
-        self.assertAllClose(result["mean_squared_error"], 0.0)
-        self.assertAllClose(result["weighted_mean_squared_error"], 0.0)
+        self.assertAllClose(result["output_1_mean_squared_error"], 0.0)
+        self.assertAllClose(result["output_2_mean_squared_error"], 0.0)
+        self.assertAllClose(result["weighted_output_1_mean_squared_error"], 0.0)
+        self.assertAllClose(result["weighted_output_2_mean_squared_error"], 0.0)
 
     def test_name_conversions(self):
         compile_metrics = CompileMetrics(
@@ -277,3 +294,19 @@ def test_dict_output_case(self, broadcast):
         compile_loss.build(y_true, y_pred)
         value = compile_loss(y_true, y_pred, sample_weight)
         self.assertAllClose(value, 1.266666, atol=1e-5)
+
+    def test_list_loss_dict_data(self):
+        compile_loss = CompileLoss(loss=["mse", "mae"], output_names=["b", "a"])
+        y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]
+        y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]
+        compile_loss.build(y_true, y_pred)
+        y_true = {
+            "a": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
+            "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),
+        }
+        y_pred = {
+            "a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),
+            "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),
+        }
+        value = compile_loss(y_true, y_pred)
+        self.assertAllClose(value, 1.07666, atol=1e-5)
diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py
index 5f8c29a0e..b15ad1611 100644
--- a/keras_core/trainers/trainer.py
+++ b/keras_core/trainers/trainer.py
@@ -34,12 +34,20 @@ def compile(
         jit_compile="auto",
     ):
         self.optimizer = optimizers.get(optimizer)
+        if hasattr(self, "output_names"):
+            output_names = self.output_names
+        else:
+            output_names = None
        if loss is not None:
-            self._compile_loss = CompileLoss(loss, loss_weights)
+            self._compile_loss = CompileLoss(
+                loss, loss_weights, output_names=output_names
+            )
         else:
             self._compile_loss = None
         if metrics is not None:
-            self._compile_metrics = CompileMetrics(metrics, weighted_metrics)
+            self._compile_metrics = CompileMetrics(
+                metrics, weighted_metrics, output_names=output_names
+            )
         else:
             self._compile_metrics = None
         if jit_compile == "auto":
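Taken together, the changes above let a multi-output functional model match dict-style losses and metrics to outputs by name (via the new `output_names` attribute) and report per-output metric keys in the training history. The following is a minimal usage sketch, not part of the patch itself; it mirrors `test_functional_dict_outputs_dict_losses` above, and the import paths are assumed to be the same ones used in the test files.

import numpy as np

from keras_core import layers
from keras_core.layers import Input
from keras_core.models import Model

# Two named outputs, mirroring _get_model_multi_outputs_dict() above.
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x)
model = Model(x, {"output_a": output_a, "output_b": output_b})

# Dict losses/metrics are routed to outputs by name rather than by
# flattening order.
model.compile(
    optimizer="sgd",
    loss={
        "output_a": "mean_squared_error",
        "output_b": "binary_crossentropy",
    },
    metrics={
        "output_a": ["mean_squared_error"],
        "output_b": ["mean_squared_error", "accuracy"],
    },
)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
hist = model.fit(
    np.random.rand(8, 3),
    {"output_a": y1, "output_b": y2},
    batch_size=2,
    epochs=1,
    verbose=0,
)
# Per the test above, history keys now include "loss",
# "output_a_mean_squared_error", "output_b_mean_squared_error" and
# "output_b_accuracy" (per-output loss entries remain a TODO).
print(sorted(hist.history.keys()))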