From 471f02ff43be8f2443674fc5c692319489afad61 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 21 Jun 2023 13:20:07 -0700 Subject: [PATCH] Fix edge case with compile loss --- keras_core/layers/core/masking.py | 2 +- keras_core/trainers/compile_utils.py | 36 +++++++++++++++++------ keras_core/trainers/compile_utils_test.py | 16 ++++++++++ 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/keras_core/layers/core/masking.py b/keras_core/layers/core/masking.py index abfa3c757..716d31a59 100644 --- a/keras_core/layers/core/masking.py +++ b/keras_core/layers/core/masking.py @@ -61,7 +61,7 @@ def call(self, inputs): try: outputs._keras_mask = ops.squeeze(boolean_mask, axis=-1) except AttributeError: - # outputs in a C type. + # tensor is a C type. pass return outputs diff --git a/keras_core/trainers/compile_utils.py b/keras_core/trainers/compile_utils.py index 9980833f2..cb60836ea 100644 --- a/keras_core/trainers/compile_utils.py +++ b/keras_core/trainers/compile_utils.py @@ -175,8 +175,8 @@ def build(self, y_true, y_pred): if output_names: num_outputs = len(output_names) - y_pred = nest.flatten(y_pred) - y_true = nest.flatten(y_true) + y_pred = self._flatten_y(y_pred) + y_true = self._flatten_y(y_true) metrics = self._user_metrics weighted_metrics = self._user_weighted_metrics @@ -314,16 +314,25 @@ def _build_metrics_set( flat_metrics.append(None) return flat_metrics + def _flatten_y(self, y): + if isinstance(y, dict) and self.output_names: + result = [] + for name in self.output_names: + if name in y: + result.append(y[name]) + return result + return nest.flatten(y) + def update_state(self, y_true, y_pred, sample_weight=None): if not self.built: self.build(y_true, y_pred) - y_true = nest.flatten(y_true) - y_pred = nest.flatten(y_pred) + y_true = self._flatten_y(y_true) + y_pred = self._flatten_y(y_pred) for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred): if m: m.update_state(y_t, y_p) if sample_weight is not None: - sample_weight = nest.flatten(sample_weight) + sample_weight = self._flatten_y(sample_weight) # For multi-outputs, repeat sample weights for n outputs. if len(sample_weight) < len(y_true): sample_weight = [sample_weight[0] for _ in range(len(y_true))] @@ -427,7 +436,7 @@ def build(self, y_true, y_pred): if output_names: num_outputs = len(output_names) - y_pred = nest.flatten(y_pred) + y_pred = self._flatten_y(y_pred) loss = self._user_loss loss_weights = self._user_loss_weights flat_losses = [] @@ -577,15 +586,24 @@ def __call__(self, y_true, y_pred, sample_weight=None): with ops.name_scope(self.name): return self.call(y_true, y_pred, sample_weight) + def _flatten_y(self, y): + if isinstance(y, dict) and self.output_names: + result = [] + for name in self.output_names: + if name in y: + result.append(y[name]) + return result + return nest.flatten(y) + def call(self, y_true, y_pred, sample_weight=None): if not self.built: self.build(y_true, y_pred) - y_true = nest.flatten(y_true) - y_pred = nest.flatten(y_pred) + y_true = self._flatten_y(y_true) + y_pred = self._flatten_y(y_pred) if sample_weight is not None: - sample_weight = nest.flatten(sample_weight) + sample_weight = self._flatten_y(sample_weight) # For multi-outputs, repeat sample weights for n outputs. if len(sample_weight) < len(y_true): sample_weight = [sample_weight[0] for _ in range(len(y_true))] diff --git a/keras_core/trainers/compile_utils_test.py b/keras_core/trainers/compile_utils_test.py index 853b95574..d097ef4ec 100644 --- a/keras_core/trainers/compile_utils_test.py +++ b/keras_core/trainers/compile_utils_test.py @@ -294,3 +294,19 @@ def test_dict_output_case(self, broadcast): compile_loss.build(y_true, y_pred) value = compile_loss(y_true, y_pred, sample_weight) self.assertAllClose(value, 1.266666, atol=1e-5) + + def test_list_loss_dict_data(self): + compile_loss = CompileLoss(loss=["mse", "mae"], output_names=["b", "a"]) + y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + compile_loss.build(y_true, y_pred) + y_true = { + "a": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 1.07666, atol=1e-5)