Commit

Align compile metrics with
sampathweb committed Jun 22, 2023
1 parent ebe69e4 commit 2ebd3ea
Showing 2 changed files with 205 additions and 65 deletions.
137 changes: 135 additions & 2 deletions keras_core/models/model_test.py
@@ -1,4 +1,5 @@
import numpy as np
from absl.testing import parameterized

from keras_core import layers
from keras_core import testing
@@ -34,6 +35,27 @@ def _get_model_multi_outputs_list_no_output_names():
return model


def _get_model_single_output():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
model = Model(x, output_a)
return model


def _get_model_single_output_list():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
model = Model(x, [output_a])
return model


def _get_model_single_output_dict():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
model = Model(x, {"output_a": output_a})
return model


def _get_model_multi_outputs_dict():
x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
@@ -42,7 +64,7 @@ def _get_model_multi_outputs_dict():
return model


class ModelTest(testing.TestCase):
class ModelTest(testing.TestCase, parameterized.TestCase):
def test_functional_rerouting(self):
model = _get_model()
self.assertTrue(isinstance(model, Functional))
@@ -91,6 +113,61 @@ def call(self, x):
)
self.assertTrue(isinstance(new_model, Functional))

@parameterized.named_parameters(
("single_output_1", _get_model_single_output, None),
("single_output_2", _get_model_single_output, "list"),
("single_output_3", _get_model_single_output, "dict"),
("single_output_4", _get_model_single_output, "dict_list"),
("single_list_output_1", _get_model_single_output_list, None),
("single_list_output_2", _get_model_single_output_list, "list"),
("single_list_output_3", _get_model_single_output_list, "dict"),
("single_list_output_4", _get_model_single_output_list, "dict_list"),
("single_dict_output_1", _get_model_single_output_dict, None),
("single_dict_output_2", _get_model_single_output_dict, "list"),
("single_dict_output_3", _get_model_single_output_dict, "dict"),
("single_dict_output_4", _get_model_single_output_dict, "dict_list"),
)
def test_functional_single_output(self, model_fn, loss_type):
model = model_fn()
self.assertTrue(isinstance(model, Functional))
loss = "mean_squared_error"
if loss_type == "list":
loss = [loss]
elif loss_type == "dict":
loss = {"output_a": loss}
elif loss_type == "dict_list":
loss = {"output_a": [loss]}
model.compile(
optimizer="sgd",
loss=loss,
metrics={
"output_a": ["mean_squared_error", "mean_absolute_error"],
},
weighted_metrics={
"output_a": "mean_squared_error",
},
)
# Fit the model to make sure compile_metrics are built
x = np.random.rand(8, 3)
y = np.random.rand(8, 1)
hist = model.fit(
x,
y,
batch_size=2,
epochs=1,
verbose=0,
)
hist_keys = sorted(hist.history.keys())
ref_keys = sorted(
[
"loss",
"mean_absolute_error",
"mean_squared_error",
"weighted_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)

def test_functional_list_outputs_list_losses(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
@@ -101,9 +178,10 @@ def test_functional_list_outputs_list_losses(self):
optimizer="sgd",
loss=["mean_squared_error", "binary_crossentropy"],
metrics=[
["mean_squared_error"],
"mean_squared_error",
["mean_squared_error", "accuracy"],
],
loss_weights=[0.1, 2],
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
@@ -137,6 +215,10 @@ def test_functional_dict_outputs_dict_losses(self):
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
weighted_metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(
@@ -153,9 +235,12 @@ def test_functional_dict_outputs_dict_losses(self):
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
"output_b_weighted_accuracy",
"output_b_weighted_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)
Expand All @@ -176,16 +261,64 @@ def test_functional_list_outputs_dict_losses_metrics(self):
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
weighted_metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error", "accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
"output_b_weighted_accuracy",
"output_b_weighted_mean_squared_error",
]
)
self.assertListEqual(hist_keys, ref_keys)

def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):
model = _get_model_multi_outputs_list()
self.assertTrue(isinstance(model, Functional))
x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.randint(0, 2, (8, 1))
model.compile(
optimizer="sgd",
loss={
"output_a": "mean_squared_error",
"output_b": "binary_crossentropy",
},
metrics={
"output_a": ["mean_squared_error"],
"output_b": ["mean_squared_error"],
},
weighted_metrics={
"output_a": ["mean_squared_error"],
"output_b": ["accuracy"],
},
)
# Fit the model to make sure compile_metrics are built
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
hist_keys = sorted(hist.history.keys())
# TODO `tf.keras` also outputs individual losses for outputs
# `output_b_accuracy` doesn't have the `weighted_` prefix in its name.
# When a metric appears only in `weighted_metrics`, the `weighted_`
# prefix is skipped. This behavior matches `tf.keras`.
ref_keys = sorted(
[
"loss",
# "output_a_loss",
"output_a_mean_squared_error",
"output_a_weighted_mean_squared_error",
"output_b_accuracy",
# "output_b_loss",
"output_b_mean_squared_error",
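
Taken together, the new assertions pin down how compiled metric names surface in the training history: entries from `metrics` keep their plain names, entries from `weighted_metrics` gain a `weighted_` prefix, and a metric that appears only in `weighted_metrics` keeps its plain name. The following is a minimal sketch of the single-output case, not part of the commit; the import paths for `Input` and `Model` are assumptions, since the file's import block is collapsed in the diff above.

import numpy as np

from keras_core import layers
from keras_core.layers import Input  # assumed import path
from keras_core.models import Model  # assumed import path

x = Input(shape=(3,), name="input_a")
output_a = layers.Dense(1, name="output_a")(x)
model = Model(x, output_a)

# Mirrors the compile call in `test_functional_single_output` above.
model.compile(
    optimizer="sgd",
    loss="mean_squared_error",
    metrics={"output_a": ["mean_squared_error", "mean_absolute_error"]},
    weighted_metrics={"output_a": "mean_squared_error"},
)
hist = model.fit(
    np.random.rand(8, 3),
    np.random.rand(8, 1),
    batch_size=2,
    epochs=1,
    verbose=0,
)
# Per the single-output test above, the expected history keys are:
# ["loss", "mean_absolute_error", "mean_squared_error",
#  "weighted_mean_squared_error"]
print(sorted(hist.history.keys()))

For multi-output models the same names are additionally prefixed with the output name (e.g. `output_b_weighted_accuracy`), as asserted in the dict- and list-output tests.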
