[FRONTEND][Keras] Fix softmax axis (apache#503)
kazum authored and tqchen committed May 26, 2018
1 parent 4fea98b commit 1dd8caa
Showing 2 changed files with 11 additions and 1 deletion.
nnvm/python/nnvm/frontend/keras.py (2 changes: 1 addition, 1 deletion)

@@ -40,7 +40,7 @@ def _convert_activation(insym, keras_layer, _):
         return _sym.__add_scalar__(_sym.__mul_scalar__(insym, \
                                                        scalar=alpha), scalar=beta)
     elif act_type == 'softmax':
-        return _sym.softmax(insym)
+        return _sym.softmax(insym, axis=1)
     elif act_type == 'sigmoid':
         return _sym.sigmoid(insym)
     elif act_type == 'tanh':
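
The explicit axis=1 suggests the Keras frontend operates on channel-first (NCHW) tensors, so the dimension that Keras' softmax normalizes by default (the last axis in Keras' NHWC convention) ends up at position 1 after conversion. The NumPy-only sketch below illustrates that layout reasoning; it assumes an NCHW tensor and is not code from the frontend itself.

# Illustration only (assumes NumPy and an NCHW layout): softmax over the
# channel axis versus a default last-axis softmax.
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.randn(1, 3, 32, 32)      # batch, channels, height, width (NCHW)

over_channels = softmax(x, axis=1)     # normalize across the 3 channels
over_last = softmax(x, axis=-1)        # normalize across width instead

# Per-pixel probabilities sum to 1 across channels only in the first case.
print(np.allclose(over_channels.sum(axis=1), 1.0))  # True
print(np.allclose(over_last.sum(axis=1), 1.0))      # False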
nnvm/tests/python/frontend/keras/test_forward.py (10 changes: 10 additions, 0 deletions)

@@ -59,6 +59,15 @@ def test_forward_elemwise_add():
     verify_keras_frontend(keras_model)
 
 
+def test_forward_softmax():
+    data = keras.layers.Input(shape=(32,32,3))
+    x = keras.layers.Activation('softmax')(data)
+    x = keras.layers.Concatenate()([x, x])
+    x = keras.layers.GlobalMaxPooling2D()(x)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+
+
 def test_forward_softrelu():
     data = keras.layers.Input(shape=(32,32,3))
     x = keras.layers.Activation('softplus')(data)
@@ -145,6 +154,7 @@ def test_forward_resnet50():
 
 if __name__ == '__main__':
     test_forward_elemwise_add()
+    test_forward_softmax()
     test_forward_softrelu()
     test_forward_leaky_relu()
     test_forward_dense()
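
Because the __main__ block now calls test_forward_softmax(), the new case runs whenever the test file is executed directly (e.g. python test_forward.py). A minimal sketch for running only this test, assuming TVM/NNVM and Keras are installed and the command is run from the test file's directory so the module is importable:

# Run just the new softmax test in isolation.
# Assumes nnvm, tvm and keras are installed and that this is executed from
# nnvm/tests/python/frontend/keras/.
from test_forward import test_forward_softmax

test_forward_softmax()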
