Skip to content

Commit

Permalink
Revert "Fuse MatMulIntegerToFloat only when scales are scalar (#6008)" (
Browse files Browse the repository at this point in the history
#6169)

This reverts commit f2dcba7.
  • Loading branch information
yufenglee authored Dec 18, 2020
1 parent 34725ae commit 98d8a3e
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 21 deletions.
9 changes: 1 addition & 8 deletions onnxruntime/core/optimizer/matmul_integer_to_float.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) {
/**
MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat:
A A_Zero B B_Zero A_Scale B_Scale Bias (Const, Optional)
A A_Zero B B_Zero A_Scale) B_Scale Bias (Const, Optional)
\ | | / \ / |
\ | | / \ / |
\ | | / \ / |
Expand Down Expand Up @@ -84,13 +84,6 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
continue;
}

// A_Scale is scalar and B_Scale is scalar or 1D tensor
auto mul_node_input_defs = p_mul_node_right->InputDefs();
if (!optimizer_utils::IsScalar(*mul_node_input_defs[0]) ||
!optimizer_utils::IsScalar(*mul_node_input_defs[1])) {
continue;
}

Node& cast_node = *graph.GetNode(p_cast_node->Index());
Node& matmulinteger_node = *graph.GetNode(p_matmulinteger_node->Index());
Node& mul_node_right = *graph.GetNode(p_mul_node_right->Index());
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
}

bool IsScalar(const NodeArg& input_arg) {
inline bool IsScalar(const NodeArg& input_arg) {
auto shape = input_arg.Shape();
if (shape == nullptr) {
// shape inferencing wasn't able to populate shape information for this NodeArg
Expand Down
5 changes: 1 addition & 4 deletions onnxruntime/core/optimizer/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ namespace optimizer_utils {
// Check if TensorProto contains a floating point type.
bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto);

// Check if input is a scalar
bool IsScalar(const NodeArg& input_arg);

/** Check whether a input is initializer with specified float value.
@param expected_value is the expected value of the initializer.
@param is_constant means whether the initializer is required to be constant.
Expand Down Expand Up @@ -63,7 +60,7 @@ bool ValidateShape(const NodeArg& node_arg, const std::initializer_list<int64_t>
*/
bool CompareShape(const ONNX_NAMESPACE::TensorShapeProto& node_arg_shape, const ONNX_NAMESPACE::TensorShapeProto& node_arg_other_shape);

/** Check whether each dimension is known for shape of node_arg
/** Check check whether each dimension is known for shape of node_arg
@returns false when shape is nullptr, or total dimension is not same as expected_dim_size length,
or any dim is unknown (without dim value).
*/
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3069,9 +3069,9 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) {

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["DynamicQuantizeLinear"], 1);
EXPECT_EQ(op_to_count["MatMulInteger"], 1);
EXPECT_EQ(op_to_count["Cast"], 1);
EXPECT_EQ(op_to_count["Mul"], 2);
EXPECT_EQ(op_to_count["MatMulInteger"], 0);
EXPECT_EQ(op_to_count["Cast"], 0);
EXPECT_EQ(op_to_count["Mul"], 0);
EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 3);
EXPECT_EQ(op_to_count["Add"], 1);
}
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def GenerateModel(model_name):
nodes.extend(MakeSubGraph("_1", True))
nodes.extend(MakeSubGraph("_2", True))
nodes.extend(MakeSubGraph("_3", False))
nodes.extend(MakeSubGraph("_4", False))

initializers = []
initializers.extend(MakeInitializer("_1"))
Expand All @@ -49,15 +48,11 @@ def GenerateModel(model_name):
helper.make_tensor_value_info('b_quantized_2', TensorProto.UINT8, [2, 3]),
helper.make_tensor_value_info('b_zp_2', TensorProto.UINT8, [1]),
helper.make_tensor_value_info('b_scale_2', TensorProto.FLOAT, [1]),
helper.make_tensor_value_info('b_quantized_4', TensorProto.UINT8, [2, 3]),
helper.make_tensor_value_info('b_zp_4', TensorProto.UINT8, [3]),
helper.make_tensor_value_info('b_scale_4', TensorProto.FLOAT, [3]),
],
[ # outputs
helper.make_tensor_value_info('output_1', TensorProto.FLOAT, [3, 3]),
helper.make_tensor_value_info('output_2', TensorProto.FLOAT, [3, 3]),
helper.make_tensor_value_info('output_3', TensorProto.FLOAT, [3, 3]),
helper.make_tensor_value_info('output_4', TensorProto.FLOAT, [3, 3]),
],
initializers)

Expand Down

0 comments on commit 98d8a3e

Please sign in to comment.