Skip to content

Commit

Permalink
Add tests for _check_last_layer_activation
Browse files Browse the repository at this point in the history
  • Loading branch information
Frightera committed May 7, 2023
1 parent 1cedb20 commit 529f968
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions keras/engine/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5000,6 +5000,58 @@ def test_sequential_model_get_weight_paths(self):
)


class TestCheckLastLayerActivation(test_combinations.TestCase):
def test_sequential_model_output(self):

for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
model = sequential.Sequential(
[
layers_module.InputLayer(input_shape=(10,)),
layers_module.Dense(1, activation=activation),
]
)
self.assertRaisesWarning(
training_module._check_last_layer_activation(model)
)
self.assertRaisesWarning(model.compile())
del model

def test_functional_model_output(self):
inputs = input_layer.Input(shape=(10,))
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs)
model = training_module.Model(inputs, x)
self.assertRaisesWarning(
training_module._check_last_layer_activation(model)
)
self.assertRaisesWarning(model.compile())
del model

def test_multi_output_model(self):
inputs = input_layer.Input(shape=(10,))
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs)
y = layers_module.Dense(1, activation=activation)(inputs)
model = training_module.Model(inputs, [x, y])
self.assertRaisesWarning(
training_module._check_last_layer_activation(model)
)
self.assertRaisesWarning(model.compile())
del model

def test_multi_input_output_model(self):
inputs = [input_layer.Input(shape=(10,)),
input_layer.Input(shape=(10,))]
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs[0])
y = layers_module.Dense(1, activation=activation)(inputs[1])
model = training_module.Model(inputs, [x, y])
self.assertRaisesWarning(
training_module._check_last_layer_activation(model)
)
self.assertRaisesWarning(model.compile())
del model

def _is_oss():
"""Returns whether the test is run under OSS."""
return len(sys.argv) >= 1 and "bazel" in sys.argv[0]
Expand Down

0 comments on commit 529f968

Please sign in to comment.