Adding fixed-point compute handling for requantization.
anijain2305 committed Jul 6, 2019
1 parent f365ea7 commit 155ccc1
Showing 15 changed files with 944 additions and 421 deletions.
include/tvm/relay/attrs/nn_quantize.h
@@ -30,9 +30,8 @@
namespace tvm {
namespace relay {

// TODO(anijain2305) - Copy of QuantizedConv2DAttrs. Should we inherit?
/*! \brief Attribute for quantized conv2d operator */
struct QuantizedConv2DAttrs : public tvm::AttrsNode<QuantizedConv2DAttrs> {
struct QConv2DAttrs : public tvm::AttrsNode<QConv2DAttrs> {
// Traditional conv2d attributes.
Array<IndexExpr> strides;
Array<IndexExpr> padding;
@@ -48,13 +47,8 @@ struct QuantizedConv2DAttrs : public tvm::AttrsNode<QuantizedConv2DAttrs> {
// Quantization related attributes.
int32_t input_zero_point;
int32_t kernel_zero_point;
int32_t output_zero_point;
double input_scale;
double kernel_scale;
double output_scale;
bool use_integer_computation_for_scale_handling;

TVM_DECLARE_ATTRS(QuantizedConv2DAttrs, "relay.attrs.QuantizedConv2DAttrs") {
TVM_DECLARE_ATTRS(QConv2DAttrs, "relay.attrs.QConv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
@@ -88,32 +82,44 @@
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");


TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(kernel_zero_point)
.describe("The zero point of the kernel tensor.");
}
};


/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
double input_scale;
int32_t input_zero_point;
double output_scale;
int32_t output_zero_point;
bool use_int_compute;
DataType out_dtype;

TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero point of the output tensor.");
TVM_ATTR_FIELD(input_scale)
.describe("The scale of the input tensor.");
TVM_ATTR_FIELD(kernel_scale)
.describe("The scale of the kernel tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale of the output tensor.");
TVM_ATTR_FIELD(use_integer_computation_for_scale_handling).set_default(false)
TVM_ATTR_FIELD(use_int_compute).set_default(false)
.describe("When true, the integer computation is used to handle output scale");


TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};


} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_QUANTIZE_H_
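
The `RequantizeAttrs` added above carries the new `use_int_compute` flag: requantize maps a quantized value from (input_scale, input_zero_point) to (output_scale, output_zero_point), and when the flag is set the real multiplier input_scale / output_scale is applied with integer-only arithmetic, as a 32-bit fixed-point multiplier plus a right shift, instead of a float multiply. A minimal, self-contained sketch of that gemmlowp-style idea (the helper names `QuantizeMultiplier`, `MultiplyByQuantizedMultiplier`, and `Requantize` are illustrative, not this patch's API):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Decompose a real multiplier in (0, 1) into significand * 2^exponent,
// storing the significand as a Q31 fixed-point integer.
void QuantizeMultiplier(double real_multiplier, int32_t* quantized_multiplier,
                        int* right_shift) {
  assert(real_multiplier > 0.0 && real_multiplier < 1.0);
  int exponent;
  double significand = std::frexp(real_multiplier, &exponent);  // in [0.5, 1)
  int64_t q = static_cast<int64_t>(std::llround(significand * (1LL << 31)));
  if (q == (1LL << 31)) {  // rounding pushed the significand up to 1.0
    q /= 2;
    ++exponent;
  }
  *quantized_multiplier = static_cast<int32_t>(q);
  *right_shift = -exponent;  // exponent <= 0 for multipliers below 1.0
}

// x * multiplier, computed as (x * q31) >> (31 + right_shift) with rounding.
int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier,
                                      int right_shift) {
  int64_t prod = static_cast<int64_t>(x) * quantized_multiplier;
  int total_shift = 31 + right_shift;
  int64_t rounding = 1LL << (total_shift - 1);
  return static_cast<int32_t>((prod + rounding) >> total_shift);
}

// Requantize one int32 value to uint8 using only integer operations.
uint8_t Requantize(int32_t q_in, int32_t in_zp, int32_t out_zp,
                   int32_t quantized_multiplier, int right_shift) {
  int32_t scaled = MultiplyByQuantizedMultiplier(q_in - in_zp,
                                                 quantized_multiplier,
                                                 right_shift);
  return static_cast<uint8_t>(std::min(255, std::max(0, scaled + out_zp)));
}
```

For example, with input_scale = 0.5 and output_scale = 4.0 the real multiplier is 0.125; `QuantizeMultiplier` yields quantized_multiplier = 2^30 with right_shift = 2, and `Requantize(80, 0, 0, 1 << 30, 2)` returns 10, matching round(80 * 0.125).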
139 changes: 139 additions & 0 deletions include/tvm/relay/quantize_util.h
@@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/quantize_util.h
* \brief Utility methods needed by quantized ops that can be shared
*/

#ifndef TVM_QUANTIZE_UTIL_H
#define TVM_QUANTIZE_UTIL_H

#include <tvm/expr.h>
#include <limits>  // for std::numeric_limits used below
#include "./base.h"

namespace tvm {
namespace relay {

inline bool is_Int8(const DataType& dtype) {
return dtype == Int(8);
}

inline bool is_UInt8(const DataType& dtype) {
return dtype == UInt(8);
}


inline bool is_Int16(const DataType& dtype) {
return dtype == Int(16);
}

inline bool is_UInt16(const DataType& dtype) {
return dtype == UInt(16);
}

inline bool is_Int32(const DataType& dtype) {
return dtype == Int(32);
}

inline bool is_UInt32(const DataType& dtype) {
return dtype == UInt(32);
}



inline bool is_Float32(const DataType& dtype) {
return dtype == Float(32);
}

inline bool is_quantized_type(const DataType& dtype) {
return is_Int8(dtype) || is_UInt8(dtype)
|| is_Int16(dtype) || is_UInt16(dtype);
}

enum class QuantizeOpType : uint8_t {
Quantize_Requantize,
Dequantize,
Requantize
};

inline bool is_valid_quantized_op_input_type(const QuantizeOpType &op_type, const DataType &in_dtype) {
switch(op_type) {
case QuantizeOpType::Quantize_Requantize:
return is_Float32(in_dtype) || is_quantized_type(in_dtype);
case QuantizeOpType::Dequantize:
return is_quantized_type(in_dtype);
case QuantizeOpType::Requantize:
return is_Int16(in_dtype) || is_Int32(in_dtype);
default:
return false;
}
}

inline bool is_valid_quantized_op_output_type(const QuantizeOpType &op_type, const DataType &in_dtype) {
switch(op_type) {
case QuantizeOpType::Quantize_Requantize:
return is_quantized_type(in_dtype);
case QuantizeOpType::Dequantize:
return is_Float32(in_dtype);
default:
return false;
}
}

inline int32_t get_qmin(const DataType& dtype) {
if (is_Int8(dtype)) {
return std::numeric_limits<int8_t>::min();
} else if (is_UInt8(dtype)) {
return std::numeric_limits<uint8_t>::min();
} else if (is_Int16(dtype)) {
return std::numeric_limits<int16_t>::min();
} else if (is_UInt16(dtype)) {
return std::numeric_limits<uint16_t>::min();
} else if (is_Int32(dtype)) {
return std::numeric_limits<int32_t>::min();
} else if (is_UInt32(dtype)) {
return std::numeric_limits<uint32_t>::min();
}
LOG(FATAL) << "Type not supported\n";
return -1;
}


inline int32_t get_qmax(const DataType& dtype) {
if (is_Int8(dtype)) {
return std::numeric_limits<int8_t>::max();
} else if (is_UInt8(dtype)) {
return std::numeric_limits<uint8_t>::max();
} else if (is_Int16(dtype)) {
return std::numeric_limits<int16_t>::max();
} else if (is_UInt16(dtype)) {
return std::numeric_limits<uint16_t>::max();
} else if (is_Int32(dtype)) {
return std::numeric_limits<int32_t>::max();
} else if (is_UInt32(dtype)) {
// Note: uint32's max value does not fit in int32_t, so this cast wraps.
return std::numeric_limits<uint32_t>::max();
}
LOG(FATAL) << "Type not supported\n";
return -1;
}

} // namespace relay
} // namespace tvm
#endif  // TVM_QUANTIZE_UTIL_H
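
Downstream, these bounds are what a requantize lowering clamps to: after the fixed-point rescale, the result is saturated into [get_qmin(out_dtype), get_qmax(out_dtype)]. A standalone analogue of that use, assuming plain std::numeric_limits in place of TVM's DataType (the name `SaturateToRange` is illustrative, not part of this patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>

// Clamp an int32 intermediate into the representable range of OutT.
// Mirrors get_qmin/get_qmax for the 8- and 16-bit cases; like the helpers
// above, the uint32 case would need care because its max exceeds int32_t.
template <typename OutT>
OutT SaturateToRange(int32_t value) {
  const int32_t qmin = std::numeric_limits<OutT>::min();  // cf. get_qmin(out_dtype)
  const int32_t qmax = std::numeric_limits<OutT>::max();  // cf. get_qmax(out_dtype)
  return static_cast<OutT>(std::min(qmax, std::max(qmin, value)));
}

int main() {
  uint8_t a = SaturateToRange<uint8_t>(300);  // saturates to 255
  int8_t b = SaturateToRange<int8_t>(-500);   // saturates to -128
  return (a == 255 && b == -128) ? 0 : 1;
}
```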
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
@@ -26,6 +26,7 @@
from .transform import *
from .algorithm import *
from . import nn
from . import qnn
from . import annotation
from . import image
from . import vision
1 change: 0 additions & 1 deletion python/tvm/relay/op/nn/__init__.py
@@ -18,5 +18,4 @@
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .nn import *
from . import _quantize
from . import _nn
21 changes: 21 additions & 0 deletions python/tvm/relay/op/qnn/__init__.py
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from __future__ import absolute_import as _abs
from .qnn import *
# from . import _nn
@@ -17,4 +17,4 @@
"""Constructor APIs"""
from ...._ffi.function import _init_api

_init_api("relay.op.nn._quantize._make", __name__)
_init_api("relay.op.qnn._make", __name__)