Skip to content

Commit

Permalink
PR #18105: [Add Feature] - Warn user if softmax usage is wrong
Browse files Browse the repository at this point in the history
Imported from GitHub PR #18105

This is a utility function to check if the usage of softmax makes sense (new users make this mistake a lot). Applying softmax on a single neuron will make the model output ones everytime, there are too many Stackoverflow posts about this.

In order to see this in action, please check [the gist](https://colab.research.google.com/gist/Frightera/fdcec020fff6ee9521ae2fd3eaed774d/checksoftmaxlastlayer.ipynb).

This applies for any other layers (Conv2D etc.) where the applied axis (axis=-1 default) of softmax has only one unit.
Copybara import of the project:

--
90c95b1 by Kaan Bıçakcı <[email protected]>:

Add last layer activation check for softmax

--
1cedb20 by Kaan Bıçakcı <[email protected]>:

Split logic for sequential and functional models

--
529f968 by Kaan Bıçakcı <[email protected]>:

Add tests for _check_last_layer_activation

--
d1acddb by Kaan Bıçakcı <[email protected]>:

Update sequential check

--
8363016 by Kaan Bıçakcı <[email protected]>:

Update tests, logic and reformatting

--
ebf16c3 by Kaan Bıçakcı <[email protected]>:

Update tests and the logic

--
afc156a by Kaan Bıçakcı <[email protected]>:

Make validate_softmax_activation experimental

--
3a228fb by Kaan Bıçakcı <[email protected]>:

Fix edge case for _validate_softmax_output

--
e9c950e by Kaan Bıçakcı <[email protected]>:

Check the softmax axis and raise an error if relevant

--
6355b23 by Kaan Bıçakcı <[email protected]>:

Update softmax check tests

--
a6745ee by Kaan Bıçakcı <[email protected]>:

Minor typo fix

Merging this change closes #18105

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18105 from Frightera:last_layer_softmax_warn a6745ee
PiperOrigin-RevId: 538223534
  • Loading branch information
Frightera authored and tensorflower-gardener committed Jun 6, 2023
1 parent 09614a7 commit da35e3e
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 5 deletions.
134 changes: 130 additions & 4 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Training-related part of the Keras engine."""

import copy
import inspect
import itertools
import json
import warnings
Expand All @@ -38,6 +39,7 @@
from keras.engine import base_layer_utils
from keras.engine import compile_utils
from keras.engine import data_adapter
from keras.engine import functional
from keras.engine import input_layer as input_layer_module
from keras.engine import steps_per_execution_tuning
from keras.engine import training_utils
Expand Down Expand Up @@ -191,8 +193,6 @@ def __new__(cls, *args, **kwargs):
# Signature detection
if is_functional_model_init_params(args, kwargs) and cls == Model:
# Functional model
from keras.engine import functional

return functional.Functional(skip_init=True, *args, **kwargs)
else:
return super(Model, cls).__new__(cls, *args, **kwargs)
Expand All @@ -206,8 +206,6 @@ def __init__(self, *args, **kwargs):
# Special case for Subclassed Functional Model, which we couldn't detect
# when __new__ is called. We only realize it is a functional model when
# it calls super.__init__ with input and output tensor.
from keras.engine import functional

if is_functional_model_init_params(args, kwargs) and not isinstance(
self, functional.Functional
):
Expand Down Expand Up @@ -745,6 +743,13 @@ def compile(
Defaults to `0`.
**kwargs: Arguments supported for backwards compatibility only.
"""

validate_softmax_activation = kwargs.pop(
"experimental_validate_softmax_activation", True
)
if validate_softmax_activation:
_validate_softmax_output(self)

if jit_compile and not tf_utils.can_jit_compile(warn=True):
jit_compile = False
base_layer.keras_api_gauge.get_cell("compile").set(True)
Expand Down Expand Up @@ -3855,6 +3860,7 @@ def _validate_compile(self, optimizer, metrics, **kwargs):

kwargs.pop("cloning", None) # Legacy DistStrat argument, never used.
kwargs.pop("experimental_run_tf_function", None) # Always `True`.
kwargs.pop("experimental_validate_softmax_activation", None)
distribute_arg = kwargs.pop("distribute", None)
if distribute_arg is not None:
raise ValueError(
Expand Down Expand Up @@ -4459,3 +4465,123 @@ def is_functional_model_init_params(args, kwargs):
if "inputs" in kwargs and "outputs" in kwargs:
return True
return False


def _validate_softmax_output(model_instance):
"""
Calls the related function for checking the output activations
Args:
model_instance: A `Model` instance, either functional or sequential.
"""
outputs = model_instance.outputs

output_layers = map_output_layers_with_names(model_instance, outputs)
_check_output_activation_softmax(output_layers)


def _check_output_activation_softmax(output_layers):
"""
Checks if the output activation is softmax and the applied axis has only
one unit.
Args:
output_layers: A dictionary of output layers with their names as keys
and the layers as values.
Raises:
ValueError: If the output activation is softmax and the applied axis
will make the model output 1.0 for all inputs.
"""
for layer_name, layer in output_layers.items():

# If the activation is a layer, we can check the axis, but as a
# precaution, we check if the layer has an axis attribute.
if isinstance(layer.activation, base_layer.Layer):
try:
softmax_axis = layer.activation.axis
except AttributeError:
continue

# This is the case for when user uses "softmax" or tf.nn.softmax
elif "axis=-1" in str(
inspect.signature(layer.activation)
) or "axis=None" in str(inspect.signature(layer.activation)):
softmax_axis = -1

# If above conditions are not met, we cannot check the output.
else:
continue

layer_output_shape = layer.output_shape

if layer_output_shape[softmax_axis] == 1:
raise ValueError(
f"Output layer {layer_name} has a single unit output, "
f"but the activation is softmax. This is most likely "
f"an error because softmax outputs sum to 1 therefore single "
f"unit outputs with softmax "
f"will only output 1.0. If you think that the error is raised "
f"due to an incorrect check, please file an issue on "
f"https://github.com/keras-team/keras/issues. You can "
f"disable this check by setting "
f"`experimental_validate_softmax_activation=False` when calling"
f" `compile()` on the model."
)


def map_output_layers_with_names(model_instance, outputs):
"""
Maps the output layers with their names and returns a dictionary.
Args:
model_instance: A `Model` instance, either functional or sequential.
outputs: A list of output tensors of the model, this can be None in the
case of subclassed models.
Returns:
A dictionary of output layers with their names as keys and the layers
as values.
"""
output_layers = {}

# `outputs` can be None in the case of subclassed models.
if outputs is not None:
# Iterate over each output tensor of the model
for output in model_instance.outputs:

# Get the name of the output layer, this is the KerasTensor.
# Something like: "dense_1/Softmax:0"
output_name = output.name

# If the output name is None, skip it. This is the case for
# native tf_ops. i.e if the model has an output layer like
# tf.cast(outputs, tf.float32).
if output_name is None:
continue

output_layer_name = output_name.split("/")[0]

# We use index -1 because the model can end with an input layer.
# In that case name will be something like "input_1", otherwise
# we'll have index out of range error in that kind of cases.
layer_act = output_name.split("/")[-1]

# Find the layer instance corresponding to the output tensor which
# has a softmax activation
layer_instance = None
for layer in model_instance.layers:
if (
layer.name == output_layer_name
and "softmax" in layer_act.lower()
):
layer_instance = layer
break

# Add the output layer name and the layer instance to the dictionary
if layer_instance is not None:
output_layers[output_layer_name] = layer_instance

return output_layers
60 changes: 59 additions & 1 deletion keras/engine/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Tests for training routines."""


import collections
import io
import sys
Expand Down Expand Up @@ -5086,6 +5085,65 @@ def test_sequential_model_get_weight_paths(self):
)


class TestCheckLastLayerActivation(test_combinations.TestCase):
def test_sequential_model_output(self):

for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
model = sequential.Sequential(
[
layers_module.InputLayer(input_shape=(10,)),
layers_module.Dense(1, activation=activation),
]
)
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model

def test_functional_model_output(self):
inputs = input_layer.Input(shape=(10,))
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs)
model = training_module.Model(inputs, x)
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model

def test_multi_output_model(self):
inputs = input_layer.Input(shape=(10,))
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs)
y = layers_module.Dense(1, activation=activation)(inputs)
model = training_module.Model(inputs, [x, y])
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model

def test_multi_input_output_model(self):
inputs = [
input_layer.Input(shape=(10,)),
input_layer.Input(shape=(10,)),
]
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs[0])
y = layers_module.Dense(1, activation=activation)(inputs[1])
model = training_module.Model(inputs, [x, y])
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model


def _is_oss():
"""Returns whether the test is run under OSS."""
return len(sys.argv) >= 1 and "bazel" in sys.argv[0]
Expand Down

0 comments on commit da35e3e

Please sign in to comment.