PR #18105: [Add Feature] - Throw an error if softmax is used with 1 neuron #18201

Closed · wants to merge 1 commit
136 changes: 132 additions & 4 deletions keras/engine/training.py
@@ -15,6 +15,7 @@
"""Training-related part of the Keras engine."""

import copy
import inspect
import itertools
import json
import warnings
@@ -38,9 +39,11 @@
from keras.engine import base_layer_utils
from keras.engine import compile_utils
from keras.engine import data_adapter
from keras.engine import functional
from keras.engine import input_layer as input_layer_module
from keras.engine import steps_per_execution_tuning
from keras.engine import training_utils
from keras.layers.activation import Softmax as SoftmaxLayer
from keras.metrics import base_metric
from keras.mixed_precision import loss_scale_optimizer as lso
from keras.optimizers import optimizer
@@ -191,8 +194,6 @@ def __new__(cls, *args, **kwargs):
# Signature detection
if is_functional_model_init_params(args, kwargs) and cls == Model:
# Functional model
from keras.engine import functional

return functional.Functional(skip_init=True, *args, **kwargs)
else:
return super(Model, cls).__new__(cls, *args, **kwargs)
@@ -206,8 +207,6 @@ def __init__(self, *args, **kwargs):
# Special case for Subclassed Functional Model, which we couldn't detect
# when __new__ is called. We only realize it is a functional model when
# it calls super.__init__ with input and output tensor.
from keras.engine import functional

if is_functional_model_init_params(args, kwargs) and not isinstance(
self, functional.Functional
):
@@ -745,6 +744,13 @@ def compile(
Defaults to `0`.
**kwargs: Arguments supported for backwards compatibility only.
"""

validate_softmax_activation = kwargs.pop(
"experimental_validate_softmax_activation", True
)
if validate_softmax_activation:
_validate_softmax_output(self)

if jit_compile and not tf_utils.can_jit_compile(warn=True):
jit_compile = False
base_layer.keras_api_gauge.get_cell("compile").set(True)
@@ -3855,6 +3861,7 @@ def _validate_compile(self, optimizer, metrics, **kwargs):

kwargs.pop("cloning", None) # Legacy DistStrat argument, never used.
kwargs.pop("experimental_run_tf_function", None) # Always `True`.
kwargs.pop("experimental_validate_softmax_activation", None)
distribute_arg = kwargs.pop("distribute", None)
if distribute_arg is not None:
raise ValueError(
@@ -4459,3 +4466,124 @@ def is_functional_model_init_params(args, kwargs):
if "inputs" in kwargs and "outputs" in kwargs:
return True
return False


def _validate_softmax_output(model_instance):
"""
Calls the related function for checking the output activations

Args:
model_instance: A `Model` instance, either functional or sequential.

"""
outputs = model_instance.outputs

output_layers = map_output_layers_with_names(model_instance, outputs)
_check_output_activation_softmax(output_layers)


def _check_output_activation_softmax(output_layers):
"""
Checks if the output activation is softmax and the applied axis has only
one unit.

Args:
output_layers: A dictionary of output layers with their names as keys
and the layers as values.

Raises:
ValueError: If the output activation is softmax and the applied axis
will make the model output 1.0 for all inputs.

"""
    for layer_name, layer in output_layers.items():
        # Layers without an `activation` attribute cannot be checked.
        if not hasattr(layer, "activation"):
            continue

        # If the activation is a `Softmax` layer, we can read the axis
        # directly; as a precaution, skip instances without an `axis`
        # attribute.
        if isinstance(layer.activation, SoftmaxLayer):
            try:
                softmax_axis = layer.activation.axis
            except AttributeError:
                continue

        # This is the case when the user passes "softmax" (resolved to
        # `keras.activations.softmax`) or `tf.nn.softmax`, both of which
        # default to the last axis.
        elif "axis=-1" in str(
            inspect.signature(layer.activation)
        ) or "axis=None" in str(inspect.signature(layer.activation)):
            softmax_axis = -1

        # If none of the conditions above hold, we cannot check the output.
        else:
            continue

layer_output_shape = layer.output_shape

if layer_output_shape[softmax_axis] == 1:
            raise ValueError(
                f"Output layer {layer_name} has a single unit output, "
                "but the activation is softmax. This is most likely an "
                "error because softmax outputs sum to 1, so a single-unit "
                "softmax output can only ever be 1.0. If you think this "
                "check was raised in error, please file an issue at "
                "https://github.com/keras-team/keras/issues. You can "
                "disable this check by passing "
                "`experimental_validate_softmax_activation=False` to "
                "`compile()`."
            )
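
# A quick sanity check (not part of this diff) of the two facts the
# validation above relies on, assuming TensorFlow is installed: softmax over
# an axis of size 1 always yields 1.0, and the default `axis` is visible in
# the activation's signature string.
#
#     import inspect
#     import tensorflow as tf
#
#     print(tf.nn.softmax([[3.14]]))  # -> [[1.]]: one unit always sums to 1
#     print(inspect.signature(tf.nn.softmax))  # (logits, axis=None, name=None)
#     print(inspect.signature(tf.keras.activations.softmax))  # (x, axis=-1)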


def map_output_layers_with_names(model_instance, outputs):
"""
Maps the output layers with their names and returns a dictionary.

Args:
model_instance: A `Model` instance, either functional or sequential.
outputs: A list of output tensors of the model, this can be None in the
case of subclassed models.

Returns:
A dictionary of output layers with their names as keys and the layers
as values.
"""
output_layers = {}

# `outputs` can be None in the case of subclassed models.
if outputs is not None:
        # Iterate over each output tensor of the model.
        for output in outputs:

            # Get the name of the output tensor (the KerasTensor name),
            # e.g. "dense_1/Softmax:0".
            output_name = output.name

            # If the output name is None, skip it. This is the case for raw
            # TF ops, e.g. when the model output is wrapped in
            # `tf.cast(outputs, tf.float32)`.
            if output_name is None:
                continue

            output_layer_name = output_name.split("/")[0]

            # We use index -1 because the model can end with an input layer,
            # whose name (e.g. "input_1") contains no "/" separator; indexing
            # from the end keeps the lookup safe in that case.
            layer_act = output_name.split("/")[-1]

            # Find the layer instance corresponding to the output tensor
            # that carries a softmax activation.
layer_instance = None
for layer in model_instance.layers:
if (
layer.name == output_layer_name
and "softmax" in layer_act.lower()
):
layer_instance = layer
break

# Add the output layer name and the layer instance to the dictionary
if layer_instance is not None:
output_layers[output_layer_name] = layer_instance

return output_layers
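
End-to-end, the new check surfaces like this. A minimal sketch, assuming this
PR is applied (the optimizer and loss choices are illustrative only):

    import tensorflow as tf
    from tensorflow import keras

    # A single-unit Dense output with softmax: compile() now raises, since
    # softmax over one unit can only ever produce 1.0.
    model = keras.Sequential(
        [
            keras.layers.InputLayer(input_shape=(10,)),
            keras.layers.Dense(1, activation="softmax"),
        ]
    )

    try:
        model.compile(optimizer="adam", loss="categorical_crossentropy")
    except ValueError as e:
        print(e)  # "Output layer ... has a single unit output, ..."

    # The opt-out kwarg added in this PR skips the validation entirely.
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        experimental_validate_softmax_activation=False,
    )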
60 changes: 59 additions & 1 deletion keras/engine/training_test.py
@@ -14,7 +14,6 @@
# ==============================================================================
"""Tests for training routines."""


import collections
import io
import sys
@@ -5086,6 +5085,65 @@ def test_sequential_model_get_weight_paths(self):
)


class TestCheckLastLayerActivation(test_combinations.TestCase):
def test_sequential_model_output(self):

for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
model = sequential.Sequential(
[
layers_module.InputLayer(input_shape=(10,)),
layers_module.Dense(1, activation=activation),
]
)
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model

def test_functional_model_output(self):
inputs = input_layer.Input(shape=(10,))
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs)
model = training_module.Model(inputs, x)
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model

def test_multi_output_model(self):
inputs = input_layer.Input(shape=(10,))
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs)
y = layers_module.Dense(1, activation=activation)(inputs)
model = training_module.Model(inputs, [x, y])
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model

def test_multi_input_output_model(self):
inputs = [
input_layer.Input(shape=(10,)),
input_layer.Input(shape=(10,)),
]
for activation in ["softmax", tf.nn.softmax, layers_module.Softmax()]:
x = layers_module.Dense(1, activation=activation)(inputs[0])
y = layers_module.Dense(1, activation=activation)(inputs[1])
model = training_module.Model(inputs, [x, y])
with self.assertRaisesRegex(
ValueError,
"has a single unit output, but the activation is softmax.*",
):
model.compile()
del model
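
    # For contrast, a sketch (not one of this PR's tests) of a configuration
    # the check should accept: softmax over more than one unit compiles
    # without raising.
    #
    #     inputs = input_layer.Input(shape=(10,))
    #     x = layers_module.Dense(3, activation="softmax")(inputs)
    #     model = training_module.Model(inputs, x)
    #     model.compile()  # no ValueError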


def _is_oss():
"""Returns whether the test is run under OSS."""
return len(sys.argv) >= 1 and "bazel" in sys.argv[0]