Skip to content

Commit

Permalink
Fix edge case with compile loss
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jun 21, 2023
1 parent 3682bd5 commit 471f02f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
2 changes: 1 addition & 1 deletion keras_core/layers/core/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def call(self, inputs):
try:
outputs._keras_mask = ops.squeeze(boolean_mask, axis=-1)
except AttributeError:
# outputs in a C type.
# tensor is a C type.
pass
return outputs

Expand Down
36 changes: 27 additions & 9 deletions keras_core/trainers/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def build(self, y_true, y_pred):
if output_names:
num_outputs = len(output_names)

y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true)
y_pred = self._flatten_y(y_pred)
y_true = self._flatten_y(y_true)

metrics = self._user_metrics
weighted_metrics = self._user_weighted_metrics
Expand Down Expand Up @@ -314,16 +314,25 @@ def _build_metrics_set(
flat_metrics.append(None)
return flat_metrics

def _flatten_y(self, y):
if isinstance(y, dict) and self.output_names:
result = []
for name in self.output_names:
if name in y:
result.append(y[name])
return result
return nest.flatten(y)

def update_state(self, y_true, y_pred, sample_weight=None):
if not self.built:
self.build(y_true, y_pred)
y_true = nest.flatten(y_true)
y_pred = nest.flatten(y_pred)
y_true = self._flatten_y(y_true)
y_pred = self._flatten_y(y_pred)
for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred):
if m:
m.update_state(y_t, y_p)
if sample_weight is not None:
sample_weight = nest.flatten(sample_weight)
sample_weight = self._flatten_y(sample_weight)
# For multi-outputs, repeat sample weights for n outputs.
if len(sample_weight) < len(y_true):
sample_weight = [sample_weight[0] for _ in range(len(y_true))]
Expand Down Expand Up @@ -427,7 +436,7 @@ def build(self, y_true, y_pred):
if output_names:
num_outputs = len(output_names)

y_pred = nest.flatten(y_pred)
y_pred = self._flatten_y(y_pred)
loss = self._user_loss
loss_weights = self._user_loss_weights
flat_losses = []
Expand Down Expand Up @@ -577,15 +586,24 @@ def __call__(self, y_true, y_pred, sample_weight=None):
with ops.name_scope(self.name):
return self.call(y_true, y_pred, sample_weight)

def _flatten_y(self, y):
if isinstance(y, dict) and self.output_names:
result = []
for name in self.output_names:
if name in y:
result.append(y[name])
return result
return nest.flatten(y)

def call(self, y_true, y_pred, sample_weight=None):
if not self.built:
self.build(y_true, y_pred)

y_true = nest.flatten(y_true)
y_pred = nest.flatten(y_pred)
y_true = self._flatten_y(y_true)
y_pred = self._flatten_y(y_pred)

if sample_weight is not None:
sample_weight = nest.flatten(sample_weight)
sample_weight = self._flatten_y(sample_weight)
# For multi-outputs, repeat sample weights for n outputs.
if len(sample_weight) < len(y_true):
sample_weight = [sample_weight[0] for _ in range(len(y_true))]
Expand Down
16 changes: 16 additions & 0 deletions keras_core/trainers/compile_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,19 @@ def test_dict_output_case(self, broadcast):
compile_loss.build(y_true, y_pred)
value = compile_loss(y_true, y_pred, sample_weight)
self.assertAllClose(value, 1.266666, atol=1e-5)

def test_list_loss_dict_data(self):
compile_loss = CompileLoss(loss=["mse", "mae"], output_names=["b", "a"])
y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]
y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]
compile_loss.build(y_true, y_pred)
y_true = {
"a": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
"b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),
}
y_pred = {
"a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),
"b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),
}
value = compile_loss(y_true, y_pred)
self.assertAllClose(value, 1.07666, atol=1e-5)

0 comments on commit 471f02f

Please sign in to comment.