diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c49f7c675d13..c6eed9c64e6c 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -19,6 +19,7 @@
 """ONNX: Open Neural Network Exchange frontend for Relay."""
 import copy
 import warnings
+from typing import Optional

 import numpy as np
 import tvm
@@ -1926,6 +1927,14 @@ def _impl_v13(cls, inputs, attr, params):
 class LogSoftmax(OnnxOpConverter):
     """Operator converter for Softmax."""

+    @classmethod
+    def run_calculation(cls, x, axes):
+        """Run the LogSoftmax calculation."""
+        m = _op.max(x, axes, keepdims=True)
+        e = _op.exp(x - m)
+        s = _op.sum(e, axes, keepdims=True)
+        return x - m - _op.log(s)
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         axis = attr.get("axis", 1)
@@ -1933,11 +1942,7 @@ def _impl_v1(cls, inputs, attr, params):
         if axis < 0:
             axis += ndim
         axes = list(range(axis, ndim))
-        x = inputs[0]
-        m = _op.max(x, axes, keepdims=True)
-        e = _op.exp(x - m)
-        s = _op.sum(e, axes, keepdims=True)
-        return x - m - _op.log(s)
+        return cls.run_calculation(inputs[0], axes)

     @classmethod
     def _impl_v13(cls, inputs, attr, params):
@@ -1946,11 +1951,7 @@ def _impl_v13(cls, inputs, attr, params):
         if axis < 0:
             axis += ndim
         axes = [axis]
-        x = inputs[0]
-        m = _op.max(x, axes, keepdims=True)
-        e = _op.exp(x - m)
-        s = _op.sum(e, axes, keepdims=True)
-        return x - m - _op.log(s)
+        return cls.run_calculation(inputs[0], axes)


 class Hardmax(OnnxOpConverter):
@@ -3611,33 +3612,30 @@ def _impl_v1(cls, inputs, attr, params):

 class NegativeLogLikelihoodLoss(OnnxOpConverter):
-    """Operator converter for random_uniform"""
+    """Operator converter for NegativeLogLikelihoodLoss"""

     VALID_REDUCTIONS = {"mean", "sum", "none"}

     @classmethod
-    def _impl_v13(cls, inputs, attr, params):
-        ignore_index = attr.get("ignore_index", None)
-        reduction = attr.get("reduction", b"mean").decode("utf-8")
-
-        if reduction not in cls.VALID_REDUCTIONS:
-            raise ValueError(
-                f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
-            )
-
-        input_tensor, target_tensor = inputs[0], inputs[1]
-
+    def run_calculation(
+        cls: "NegativeLogLikelihoodLoss",
+        input_tensor: relay.Expr,
+        target_tensor: relay.Expr,
+        weight_tensor: Optional[relay.Expr],
+        ignore_index: int,
+    ):
+        """Run the NegativeLogLikelihoodLoss calculation, returning the element-wise
+        loss tensor and the weight total used for mean-style reductions.
+        """
         # Convert negative indices --> positive indices for gather ops, note we have to
         # use the original target tensor to interact with ignore_index to have proper behavior.
         normalized_target_tensor = normalize_gather_indices(input_tensor, target_tensor, 1)

-        if len(inputs) == 3:
-            weight_tensor = inputs[2]
-        else:
+        if weight_tensor is None:
             channels = infer_shape(input_tensor)[1]
             weight_tensor = relay.ones(
                 [channels],
-                dtype=input_tensor.type_annotation.dtype,
+                dtype=infer_type(input_tensor).checked_type.dtype,
             )

         loss = -relay.gather(
@@ -3670,7 +3668,30 @@ def _impl_v13(cls, inputs, attr, params):
             select_weights *= relay.cast_like(mask_tensor, select_weights)

         weight_total = relay.sum(select_weights)
+        return loss, weight_total
+
+    @classmethod
+    def _impl_v13(cls, inputs, attr, params):
+        ignore_index = attr.get("ignore_index", None)
+        reduction = attr.get("reduction", b"mean").decode("utf-8")
+
+        if reduction not in cls.VALID_REDUCTIONS:
+            raise ValueError(
+                f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
+            )

+        input_tensor, target_tensor = inputs[0], inputs[1]
+        if len(inputs) == 3:
+            weight_tensor = inputs[2]
+        else:
+            weight_tensor = None
+
+        loss, weight_total = cls.run_calculation(
+            input_tensor,
+            target_tensor,
+            weight_tensor=weight_tensor,
+            ignore_index=ignore_index,
+        )
         if reduction == "mean":
             return relay.sum(loss) / weight_total
         if reduction == "sum":
@@ -3679,6 +3700,39 @@ def _impl_v13(cls, inputs, attr, params):
         return loss


+class SoftmaxCrossEntropyLoss(OnnxOpConverter):
+    """Operator converter for SoftmaxCrossEntropyLoss"""
+
+    @classmethod
+    def _impl_v13(cls, inputs, attr, params):
+        ignore_index = attr.get("ignore_index", None)
+        reduction = attr.get("reduction", b"mean").decode("utf-8")
+        input_tensor, target_tensor = inputs[0], inputs[1]
+        if len(inputs) == 3:
+            weight_tensor = inputs[2]
+        else:
+            weight_tensor = None
+
+        get_log_prob = attr["tvm_custom"]["num_outputs"] == 2
+        log_softmax_tensor = LogSoftmax.run_calculation(input_tensor, axes=[1])
+
+        loss, weight_total = NegativeLogLikelihoodLoss.run_calculation(
+            log_softmax_tensor,
+            target_tensor,
+            weight_tensor,
+            ignore_index=ignore_index,
+        )
+
+        if reduction == "mean":
+            loss = relay.sum(loss) / weight_total
+        elif reduction == "sum":
+            loss = relay.sum(loss)
+
+        if get_log_prob:
+            return relay.TupleWrapper(relay.Tuple((loss, log_softmax_tensor)), 2)
+        return loss
+
+
 class Adagrad(OnnxOpConverter):
     """Operator converter for adagrad op."""

@@ -4037,6 +4091,7 @@ def _get_convert_map(opset):
         "RandomUniform": RandomUniform.get_converter(opset),
         # Loss functions / training
         "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
+        "SoftmaxCrossEntropyLoss": SoftmaxCrossEntropyLoss.get_converter(opset),
         "Adagrad": Adagrad.get_converter(opset),
         "Adam": Adam.get_converter(opset),
         "Momentum": Momentum.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 35abc6d896b3..3aef9a2a2ceb 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4944,73 +4944,40 @@ def verify_eyelike(indata):
     "test_round",
     "test_scan9_sum",
     "test_scan_sum",
-    "test_sce_NCd1_mean_weight_negative_ii",
+    # With reduce_sum fully supported, these expanded tests should pass
     "test_sce_NCd1_mean_weight_negative_ii_expanded",
-    "test_sce_NCd1_mean_weight_negative_ii_log_prob",
     "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded",
-    "test_sce_NCd1d2d3_none_no_weight_negative_ii",
     "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded",
-    "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded", - "test_sce_NCd1d2d3_sum_weight_high_ii", "test_sce_NCd1d2d3_sum_weight_high_ii_expanded", - "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob", "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded", - "test_sce_NCd1d2d3d4d5_mean_weight", "test_sce_NCd1d2d3d4d5_mean_weight_expanded", - "test_sce_NCd1d2d3d4d5_mean_weight_log_prob", "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded", - "test_sce_NCd1d2d3d4d5_none_no_weight", "test_sce_NCd1d2d3d4d5_none_no_weight_expanded", - "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob", "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", - "test_sce_mean", - "test_sce_mean_3d", "test_sce_mean_3d_expanded", - "test_sce_mean_3d_log_prob", "test_sce_mean_3d_log_prob_expanded", "test_sce_mean_expanded", - "test_sce_mean_log_prob", "test_sce_mean_log_prob_expanded", - "test_sce_mean_no_weight_ii", - "test_sce_mean_no_weight_ii_3d", "test_sce_mean_no_weight_ii_3d_expanded", - "test_sce_mean_no_weight_ii_3d_log_prob", "test_sce_mean_no_weight_ii_3d_log_prob_expanded", - "test_sce_mean_no_weight_ii_4d", "test_sce_mean_no_weight_ii_4d_expanded", - "test_sce_mean_no_weight_ii_4d_log_prob", "test_sce_mean_no_weight_ii_4d_log_prob_expanded", "test_sce_mean_no_weight_ii_expanded", - "test_sce_mean_no_weight_ii_log_prob", "test_sce_mean_no_weight_ii_log_prob_expanded", - "test_sce_mean_weight", "test_sce_mean_weight_expanded", - "test_sce_mean_weight_ii", - "test_sce_mean_weight_ii_3d", "test_sce_mean_weight_ii_3d_expanded", - "test_sce_mean_weight_ii_3d_log_prob", "test_sce_mean_weight_ii_3d_log_prob_expanded", - "test_sce_mean_weight_ii_4d", "test_sce_mean_weight_ii_4d_expanded", - "test_sce_mean_weight_ii_4d_log_prob", "test_sce_mean_weight_ii_4d_log_prob_expanded", "test_sce_mean_weight_ii_expanded", - "test_sce_mean_weight_ii_log_prob", "test_sce_mean_weight_ii_log_prob_expanded", - "test_sce_mean_weight_log_prob", "test_sce_mean_weight_log_prob_expanded", - "test_sce_none", "test_sce_none_expanded", - "test_sce_none_log_prob", "test_sce_none_log_prob_expanded", - "test_sce_none_weights", "test_sce_none_weights_expanded", - "test_sce_none_weights_log_prob", "test_sce_none_weights_log_prob_expanded", - "test_sce_sum", "test_sce_sum_expanded", - "test_sce_sum_log_prob", "test_sce_sum_log_prob_expanded", "test_sequence_insert_at_back", "test_sequence_insert_at_front", @@ -5093,6 +5060,12 @@ def test_onnx_nodes(target, dev, onnx_test): # for some reason the ONNX test crops the # roialign results to 4 decimal places atol = 1e-4 + + if "_sce_" in test_dir: + # complicated loss functions like SoftmaxCrossEntropy can have minor variations + # in accuracy depending on implementation + atol = 1e-4 + onnx_model = onnx.load(test_dir + "/model.onnx") inputs = [] outputs = []