diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index e4a6885efeb7..f3282f03c813 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -618,6 +618,7 @@ def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 3 or len(inputs) == 2, "Gemm op take 2 or 3 inputs, {} given".format( len(inputs) ) + dtype = infer_type(inputs[0]).checked_type.dtype # Y = alpha * A * B + beta * C alpha = float(attr.get("alpha", 1.0)) beta = float(attr.get("beta", 1.0)) @@ -631,10 +632,10 @@ def _impl_v1(cls, inputs, attr, params): inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) inputs[0] = _op.nn.batch_flatten(inputs[0]) if alpha != 1.0: - inputs[0] *= _expr.const(alpha) + inputs[0] *= _expr.const(alpha, dtype=dtype) out = _op.nn.dense(inputs[0], inputs[1], units=channels) if len(inputs) == 3: - out = out + _expr.const(beta) * inputs[2] + out = out + _expr.const(beta, dtype=dtype) * inputs[2] return out diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 89655840da2a..3ffeb3e4f788 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1055,20 +1055,21 @@ def test_onehot(): tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) -def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False): +def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False, dtype="float32"): out_shape = [a_shape[0], b_shape[1]] - a_array = np.random.uniform(size=a_shape).astype("float32") - b_array = np.random.uniform(size=b_shape).astype("float32") + a_array = np.random.uniform(size=a_shape).astype(dtype) + b_array = np.random.uniform(size=b_shape).astype(dtype) input_names = ["a", "b"] + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] input_nodes = [ - helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), - helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + helper.make_tensor_value_info("a", ONNX_DTYPE, list(a_shape)), + helper.make_tensor_value_info("b", ONNX_DTYPE, list(b_shape)), ] input_values = [a_array, b_array] if c_shape is not None: - c_array = np.random.uniform(size=c_shape).astype("float32") + c_array = np.random.uniform(size=c_shape).astype(dtype) input_names.append("c") - input_nodes.append(helper.make_tensor_value_info("c", TensorProto.FLOAT, list(c_shape))) + input_nodes.append(helper.make_tensor_value_info("c", ONNX_DTYPE, list(c_shape))) input_values.append(c_array) gemm_node = helper.make_node("Gemm", input_names, ["out"]) @@ -1077,11 +1078,11 @@ def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False): [gemm_node], "gemm_test", inputs=input_nodes, - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))], ) model = helper.make_model(graph, producer_name="gemm_test") - verify_with_ort_with_inputs(model, input_values, freeze_params=freeze_params) + verify_with_ort_with_inputs(model, input_values, freeze_params=freeze_params, dtype=dtype) @tvm.testing.uses_gpu @@ -1089,6 +1090,7 @@ def test_gemm(): verify_gemm(a_shape=(4, 3), b_shape=(3, 4)) verify_gemm(a_shape=(4, 3), b_shape=(3, 4), c_shape=(4,)) verify_gemm(a_shape=(4, 3), b_shape=(3, 4), c_shape=(4,), freeze_params=True) + verify_gemm(a_shape=(4, 3), b_shape=(3, 4), c_shape=(4,), freeze_params=True, dtype="float16") @tvm.testing.uses_gpu