Skip to content

Commit

Permalink
[QUANTIZE] Clean code.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang committed Dec 4, 2018
1 parent 034f1e0 commit 37ee257
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 279 deletions.
17 changes: 0 additions & 17 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,23 +551,6 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return map_.get<ValueType>(expr, def_value);
}

/*!
* \param Get function from op_map.
* \param op_map The OpMap.
* \param op The operator being called.
* \tparam ValueType the content value type.
* \return The result value map.
*/
template<typename ValueType>
ValueType GetFunc(const OpMap<ValueType>& op_map,
const Expr& op) {
if (const OpNode* opnode = op.as<OpNode>()) {
return op_map.get(GetRef<Op>(opnode), ValueType());
} else {
return ValueType();
}
}


/*!
* \brief Check that an expression is a "primitive operator".
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/quantize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from __future__ import absolute_import as _abs

from .quantize import *
from . import quantize_ops
from . import annotate_ops
25 changes: 25 additions & 0 deletions python/tvm/relay/quantize/_quantize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,30 @@
"""FFI exposing the Relay type inference and checking."""
from __future__ import absolute_import
import topi
from tvm._ffi.function import _init_api
from ..op import op as _reg


@_reg.register_compute("simulated_quantize")
def simulated_quantize_compute(attrs, inputs, output_type, target):
    """Compute rule for simulated_quantize.

    Divides the data by the scale, clips the result to
    [clip_min, clip_max], rounds it, and multiplies by the scale again.

    `inputs` must hold exactly five entries, in this order:
    data, scale, bit, clip_min, clip_max. Only signed quantization with
    "round" rounding is accepted; both are enforced by assertions.
    `bit`, `output_type` and `target` are not used by this function.

    Returns a one-element list holding the resulting tensor.
    """
    assert len(inputs) == 5
    assert attrs.sign
    assert attrs.rounding == "round"

    data, scale, _bit, lower, upper = inputs

    # simulate rounding error: scale down, clamp to the clip range, round
    quotient = topi.divide(data, scale)
    capped = topi.minimum(quotient, upper)
    clamped = topi.maximum(capped, lower)
    rounded = topi.round(clamped)

    # recover data by scaling back up
    return [topi.multiply(rounded, scale)]


# Schedule simulated_quantize with _reg.schedule_injective.
_reg.register_schedule("simulated_quantize", _reg.schedule_injective)
# Register the op pattern as OPAQUE.
# NOTE(review): presumably this keeps the op out of operator fusion - confirm.
_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE)

# NOTE(review): presumably binds the "relay._quantize" FFI functions into this
# module's namespace - confirm against _init_api.
_init_api("relay._quantize", __name__)
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import absolute_import
import topi
from .. import expr as _expr
from ..op import op as _reg
from .quantize import QFieldKind, QFieldExpr, register_qfield_rewrite
from .quantize import attach_simulated_quantize, get_current_qconfig

Expand Down Expand Up @@ -96,29 +94,3 @@ def relu_rewrite(ref_call, new_args, ctx):
return QFieldExpr(expr, x.kind)
else:
return None


@_reg.register_compute("simulated_quantize")
def simulated_quantize_compute(attrs, inputs, output_type, target):
    """Compute rule for simulated_quantize.

    When `attrs.kind` equals `QFieldKind.REAL` the data is passed through
    `topi.identity` unchanged. Otherwise the data is divided by the scale,
    clipped to [clip_min, clip_max], rounded, and multiplied by the scale
    again.

    `inputs` must hold exactly five entries, in this order:
    data, scale, bit, clip_min, clip_max. Only signed quantization with
    "round" rounding is accepted; both are enforced by assertions.
    `bit`, `output_type` and `target` are not used by this function.

    Returns a one-element list holding the resulting tensor.
    """
    assert len(inputs) == 5
    assert attrs.sign
    assert attrs.rounding == "round"

    data, scale, _bit, lower, upper = inputs

    # REAL fields are left untouched
    if attrs.kind == QFieldKind.REAL:
        return [topi.identity(data)]

    # simulate rounding error: scale down, clamp to the clip range, round
    quotient = topi.divide(data, scale)
    capped = topi.minimum(quotient, upper)
    clamped = topi.maximum(capped, lower)
    rounded = topi.round(clamped)

    # recover data by scaling back up
    return [topi.multiply(rounded, scale)]


# Schedule simulated_quantize with _reg.schedule_injective.
_reg.register_schedule("simulated_quantize", _reg.schedule_injective)
# Register the op pattern as OPAQUE.
# NOTE(review): presumably this keeps the op out of operator fusion - confirm.
_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE)
24 changes: 13 additions & 11 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@
from ..base import register_relay_node
from ..._ffi.function import register_func

# `annotate` will construct a simulated quantized graph
# center around some ops like `conv` and broadcast the quantize part
# `calibrate` will find a set of suitable bit and scale,
# also make them as constant in graph
# `realize` will realize the graph based on this constant
# information

# TODO:
# - gpu


class QFieldKind(object):
INPUT = 1
Expand All @@ -30,7 +20,7 @@ class QFieldKind(object):
class QConfig(object):
current = None

def __init__(self, bit_dict=None, global_scale=2.0, skip_k_conv=0):
def __init__(self, bit_dict=None, global_scale=8.0, skip_k_conv=0):
if bit_dict is None:
bit_dict = {
QFieldKind.INPUT: 8,
Expand Down Expand Up @@ -97,10 +87,18 @@ def register_qfield_rewrite(op_name, frewrite=None, level=10):


def annotate(graph):
"""
Construct a simulated quantized graph, centered around ops such as
`conv`, and broadcast the quantize part.

Thin wrapper: forwards `graph` to `_quantize.annotate` and returns
its result.
"""
return _quantize.annotate(graph)


def calibrate(graph, dataset=None):
"""
`calibrate` will find a set of suitable bit and scale,
also make them as constant in graph
"""
def _scalar(x, dtype):
return _expr.const(np.array(x).astype(dtype))

Expand Down Expand Up @@ -145,6 +143,10 @@ def visit_func(e):


def realize(graph):
"""
Realize the graph based on the constant information it carries
(presumably the bit and scale constants that `calibrate` places in
the graph).

Thin wrapper: forwards `graph` to `_quantize.realize` and returns
its result.
"""
return _quantize.realize(graph)


Expand Down
3 changes: 0 additions & 3 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ bool Conv2DRel(const Array<Type>& types,
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
if (data->dtype == Int(8)) {
out_dtype = Int(32);
}
}
oshape = ConvertLayout(oshape, kNCHW, out_layout);
// assign output type
Expand Down
1 change: 0 additions & 1 deletion src/relay/pass/forward_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ class ForwardRewriter : private ExprMutator {
}
// try to rewrite.
if (frewrite != nullptr) {
//LOG(INFO) << "rewrite op: " << ref_call->op;
Expr res = frewrite(
ref_call, call_args,
fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
Expand Down
17 changes: 9 additions & 8 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize")
auto attrs = make_node<SimulatedQuantizeAttrs>();
attrs->sign = sign;
attrs->rounding = rounding;
attrs->kind = Int2Kind(kind);
attrs->kind = kind;
static const Op& op = Op::Get("simulated_quantize");
return CallNode::make(op, {data, dom_scale, bit, clip_min, clip_max}, Attrs(attrs), {});
});
Expand All @@ -84,7 +84,8 @@ QFieldExpr QFieldExprNode::make(Expr expr, QFieldKind kind) {

TVM_REGISTER_API("relay._quantize.make_qfield_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QFieldExprNode::make(args[0], Int2Kind(args[1]));
*ret = QFieldExprNode::make(args[0],
static_cast<QFieldKind>(args[1].operator int()));
});


Expand Down Expand Up @@ -119,11 +120,11 @@ Expr QIntStateNode::Realize() const {
return data;
}

QIntState QIntStateNode::make(Expr data, Expr dom_scale, int safe_nbit, DataType dtype) {
QIntState QIntStateNode::make(Expr data, Expr dom_scale, int nbit, DataType dtype) {
NodePtr<QIntStateNode> n = make_node<QIntStateNode>();
n->data = std::move(data);
n->dom_scale = std::move(dom_scale);
n->safe_nbit = std::move(safe_nbit);
n->nbit = std::move(nbit);
n->dtype = std::move(dtype);
return QIntState(n);
}
Expand Down Expand Up @@ -213,9 +214,9 @@ Expr Conv2dQStateRewrite(const Call& ref_call,
const auto* rhs = new_args[1].as<QIntStateNode>();
CHECK(rhs);

CHECK_EQ(lhs->safe_nbit, rhs->safe_nbit);
Expr ldata = Cast(lhs->data, Int(lhs->safe_nbit));
Expr rdata = Cast(rhs->data, Int(rhs->safe_nbit));
CHECK_EQ(lhs->nbit, rhs->nbit);
Expr ldata = Cast(lhs->data, Int(lhs->nbit));
Expr rdata = Cast(rhs->data, Int(rhs->nbit));

const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
auto attrs = make_node<Conv2DAttrs>();
Expand Down Expand Up @@ -318,7 +319,7 @@ Expr ReluQStateRewrite(const Call& ref_call,
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QIntStateNode>()) {
Expr ret = ForwardOp(ref_call, {n->data});
return QIntStateNode::make(ret, n->dom_scale, n->safe_nbit, n->dtype);
return QIntStateNode::make(ret, n->dom_scale, n->nbit, n->dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
Expand Down
10 changes: 3 additions & 7 deletions src/relay/pass/quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ enum QFieldKind : int {
kQActivation = 3,
};

// Convert a raw integer to QFieldKind with a static_cast.
// The value is not range-checked against the declared enumerators.
inline QFieldKind Int2Kind(int x) {
return static_cast<QFieldKind>(x);
}

// SimulatedQuantize
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
bool sign;
Expand Down Expand Up @@ -79,19 +75,19 @@ RELAY_DEFINE_NODE_REF(QState, QStateNode, TempExpr);
class QIntStateNode : public QStateNode {
public:
Expr dom_scale;
int safe_nbit; // number of bit which can be cast safely.
int nbit; // number of bit
DataType dtype; // current data type, realize use this information for final data type casting

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale);
v->Visit("safe_nbit", &safe_nbit);
v->Visit("nbit", &nbit);
v->Visit("dtype", &dtype);
}

Expr Realize() const final;

TVM_DLL static QIntState make(Expr data, Expr dom_scale, int safe_nbit, DataType dtype);
TVM_DLL static QIntState make(Expr data, Expr dom_scale, int nbit, DataType dtype);

static constexpr const char * _type_key = "relay.quantize.QIntState";
TVM_DECLARE_NODE_TYPE_INFO(QIntStateNode, QStateNode);
Expand Down
Loading

0 comments on commit 37ee257

Please sign in to comment.