From 37ee257b6833aea6477920db930a4e21eeb271bc Mon Sep 17 00:00:00 2001
From: ZihengJiang
Date: Tue, 4 Dec 2018 13:03:24 -0800
Subject: [PATCH] [QUANTIZE] Clean code.

---
 include/tvm/relay/op.h                         |  17 --
 python/tvm/relay/quantize/__init__.py          |   2 +-
 python/tvm/relay/quantize/_quantize.py         |  25 +++
 .../{quantize_ops.py => annotate_ops.py}       |  28 ---
 python/tvm/relay/quantize/quantize.py          |  24 +--
 src/relay/op/nn/convolution.cc                 |   3 -
 src/relay/pass/forward_rewrite.cc              |   1 -
 src/relay/pass/quantize.cc                     |  17 +-
 src/relay/pass/quantize.h                      |  10 +-
 tests/python/quantize/evaluate_gluon_model.py  | 177 ------------------
 tests/python/quantize/test.py                  |  26 ---
 11 files changed, 51 insertions(+), 279 deletions(-)
 rename python/tvm/relay/quantize/{quantize_ops.py => annotate_ops.py} (81%)
 delete mode 100644 tests/python/quantize/evaluate_gluon_model.py
 delete mode 100644 tests/python/quantize/test.py

diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h
index bb53c48be5d5b..0fd54ff5b8fa7 100644
--- a/include/tvm/relay/op.h
+++ b/include/tvm/relay/op.h
@@ -551,23 +551,6 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
   return map_.get(expr, def_value);
 }
 
-/*!
- * \param Get function from op_map.
- * \param op_map The OpMap.
- * \param op The operator being called.
- * \tparam ValueType the content value type.
- * \return The result value map.
- */
-template<typename ValueType>
-ValueType GetFunc(const OpMap<ValueType>& op_map,
-                  const Expr& op) {
-  if (const OpNode* opnode = op.as<OpNode>()) {
-    return op_map.get(GetRef<Op>(opnode), ValueType());
-  } else {
-    return ValueType();
-  }
-}
-
 /*!
  * \brief Check that an expression is a "primtive operator".
diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py
index 10f3457ff6dcc..186b9e95d9bff 100644
--- a/python/tvm/relay/quantize/__init__.py
+++ b/python/tvm/relay/quantize/__init__.py
@@ -3,4 +3,4 @@
 from __future__ import absolute_import as _abs
 
 from .quantize import *
-from . import quantize_ops
+from . import annotate_ops
diff --git a/python/tvm/relay/quantize/_quantize.py b/python/tvm/relay/quantize/_quantize.py
index d216c7c817789..b47b2373ee260 100644
--- a/python/tvm/relay/quantize/_quantize.py
+++ b/python/tvm/relay/quantize/_quantize.py
@@ -1,5 +1,30 @@
 """FFI exposing the Relay type inference and checking."""
 from __future__ import absolute_import
+import topi
 from tvm._ffi.function import _init_api
+from ..op import op as _reg
+
+
+@_reg.register_compute("simulated_quantize")
+def simulated_quantize_compute(attrs, inputs, output_type, target):
+    """Compiler for simulated_quantize."""
+    assert len(inputs) == 5
+    assert attrs.sign
+    assert attrs.rounding == "round"
+
+    data, scale, bit, clip_min, clip_max = inputs
+
+    # simulate rounding error
+    scaled_data = topi.divide(data, scale)
+    clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
+    round_data = topi.round(clipped_data)
+
+    # recover data
+    rdata = topi.multiply(round_data, scale)
+    return [rdata]
+
+
+_reg.register_schedule("simulated_quantize", _reg.schedule_injective)
+_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE)
 
 _init_api("relay._quantize", __name__)
diff --git a/python/tvm/relay/quantize/quantize_ops.py b/python/tvm/relay/quantize/annotate_ops.py
similarity index 81%
rename from python/tvm/relay/quantize/quantize_ops.py
rename to python/tvm/relay/quantize/annotate_ops.py
index 0e898bf026b99..0857878a0e2a5 100644
--- a/python/tvm/relay/quantize/quantize_ops.py
+++ b/python/tvm/relay/quantize/annotate_ops.py
@@ -1,7 +1,5 @@
 from __future__ import absolute_import
-import topi
 from .. import expr as _expr
-from ..op import op as _reg
 from .quantize import QFieldKind, QFieldExpr, register_qfield_rewrite
 from .quantize import attach_simulated_quantize, get_current_qconfig
 
@@ -96,29 +94,3 @@ def relu_rewrite(ref_call, new_args, ctx):
         return QFieldExpr(expr, x.kind)
     else:
         return None
-
-
-@_reg.register_compute("simulated_quantize")
-def simulated_quantize_compute(attrs, inputs, output_type, target):
-    """Compiler for simulated_quantize."""
-    assert len(inputs) == 5
-    assert attrs.sign
-    assert attrs.rounding == "round"
-
-    data, scale, bit, clip_min, clip_max = inputs
-
-    if attrs.kind == QFieldKind.REAL:
-        return [topi.identity(data)]
-
-    # simulate rounding error
-    scaled_data = topi.divide(data, scale)
-    clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
-    round_data = topi.round(clipped_data)
-
-    # recover data
-    rdata = topi.multiply(round_data, scale)
-    return [rdata]
-
-
-_reg.register_schedule("simulated_quantize", _reg.schedule_injective)
-_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE)
diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py
index 550ff63d06781..c1901a0da68ce 100644
--- a/python/tvm/relay/quantize/quantize.py
+++ b/python/tvm/relay/quantize/quantize.py
@@ -10,16 +10,6 @@
 from ..base import register_relay_node
 from ..._ffi.function import register_func
 
-# `annotate` will construct a simulated quantized graph
-# center around some ops like `conv` and broadcast the quantize part
-# `calibrate` will find a set of suitable bit and scale,
-# also make them as constant in graph
-# `realize` will realize the graph based on this constant
-# information
-
-# TODO:
-# - gpu
-
 
 class QFieldKind(object):
     INPUT = 1
@@ -30,7 +20,7 @@ class QFieldKind(object):
 class QConfig(object):
     current = None
 
-    def __init__(self, bit_dict=None, global_scale=2.0, skip_k_conv=0):
+    def __init__(self, bit_dict=None, global_scale=8.0, skip_k_conv=0):
         if bit_dict is None:
             bit_dict = {
                 QFieldKind.INPUT: 8,
@@ -97,10 +87,18 @@ def register_qfield_rewrite(op_name, frewrite=None, level=10):
 
 
 def annotate(graph):
+    """
+    `annotate` will construct a simulated quantized graph
+    center around some ops like `conv` and broadcast the quantize part
+    """
     return _quantize.annotate(graph)
 
 
 def calibrate(graph, dataset=None):
+    """
+    `calibrate` will find a set of suitable bit and scale,
+    also make them as constant in graph
+    """
     def _scalar(x, dtype):
         return _expr.const(np.array(x).astype(dtype))
 
@@ -145,6 +143,10 @@ def visit_func(e):
 
 
 def realize(graph):
+    """
+    `realize` will realize the graph based on this constant
+    information
+    """
     return _quantize.realize(graph)
 
 
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index bc086c95e12c5..170b6b6d13c5c 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -96,9 +96,6 @@ bool Conv2DRel(const Array<Type>& types,
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
-    if (data->dtype == Int(8)) {
-      out_dtype = Int(32);
-    }
   }
   oshape = ConvertLayout(oshape, kNCHW, out_layout);
   // assign output type
diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc
index 5499df7e3d1f1..4f33d4a053b75 100644
--- a/src/relay/pass/forward_rewrite.cc
+++ b/src/relay/pass/forward_rewrite.cc
@@ -152,7 +152,6 @@ class ForwardRewriter : private ExprMutator {
     }
     // try to rewrite.
     if (frewrite != nullptr) {
-      //LOG(INFO) << "rewrite op: " << ref_call->op;
       Expr res = frewrite(
           ref_call, call_args,
           fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr));
diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc
index 65a1a107af1ac..f0652845853d8 100644
--- a/src/relay/pass/quantize.cc
+++ b/src/relay/pass/quantize.cc
@@ -60,7 +60,7 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize")
     auto attrs = make_node<SimulatedQuantizeAttrs>();
     attrs->sign = sign;
     attrs->rounding = rounding;
-    attrs->kind = Int2Kind(kind);
+    attrs->kind = kind;
     static const Op& op = Op::Get("simulated_quantize");
     return CallNode::make(op, {data, dom_scale, bit, clip_min, clip_max}, Attrs(attrs), {});
   });
@@ -84,7 +84,8 @@ QFieldExpr QFieldExprNode::make(Expr expr, QFieldKind kind) {
 
 TVM_REGISTER_API("relay._quantize.make_qfield_expr")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = QFieldExprNode::make(args[0], Int2Kind(args[1]));
+    *ret = QFieldExprNode::make(args[0],
+                                static_cast<QFieldKind>(args[1].operator int()));
   });
 
 
@@ -119,11 +120,11 @@ Expr QIntStateNode::Realize() const {
   return data;
 }
 
-QIntState QIntStateNode::make(Expr data, Expr dom_scale, int safe_nbit, DataType dtype) {
+QIntState QIntStateNode::make(Expr data, Expr dom_scale, int nbit, DataType dtype) {
   NodePtr<QIntStateNode> n = make_node<QIntStateNode>();
   n->data = std::move(data);
   n->dom_scale = std::move(dom_scale);
-  n->safe_nbit = std::move(safe_nbit);
+  n->nbit = std::move(nbit);
   n->dtype = std::move(dtype);
   return QIntState(n);
 }
@@ -213,9 +214,9 @@ Expr Conv2dQStateRewrite(const Call& ref_call,
   const auto* rhs = new_args[1].as<QIntStateNode>();
   CHECK(rhs);
 
-  CHECK_EQ(lhs->safe_nbit, rhs->safe_nbit);
-  Expr ldata = Cast(lhs->data, Int(lhs->safe_nbit));
-  Expr rdata = Cast(rhs->data, Int(rhs->safe_nbit));
+  CHECK_EQ(lhs->nbit, rhs->nbit);
+  Expr ldata = Cast(lhs->data, Int(lhs->nbit));
+  Expr rdata = Cast(rhs->data, Int(rhs->nbit));
 
   const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
   auto attrs = make_node<Conv2DAttrs>();
@@ -318,7 +319,7 @@ Expr ReluQStateRewrite(const Call& ref_call,
   CHECK_EQ(new_args.size(), 1);
   if (const auto* n = new_args[0].as<QIntStateNode>()) {
     Expr ret = ForwardOp(ref_call, {n->data});
-    return QIntStateNode::make(ret, n->dom_scale, n->safe_nbit, n->dtype);
+    return QIntStateNode::make(ret, n->dom_scale, n->nbit, n->dtype);
   }
   CHECK(!new_args[0]->derived_from<TempExprNode>());
   return Expr(nullptr);
diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h
index 3aaeda316dc27..d087c043df868 100644
--- a/src/relay/pass/quantize.h
+++ b/src/relay/pass/quantize.h
@@ -21,10 +21,6 @@ enum QFieldKind : int {
   kQActivation = 3,
 };
 
-inline QFieldKind Int2Kind(int x) {
-  return static_cast<QFieldKind>(x);
-}
-
 // SimulatedQuantize
 struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
   bool sign;
@@ -79,19 +75,19 @@ RELAY_DEFINE_NODE_REF(QState, QStateNode, TempExpr);
 class QIntStateNode : public QStateNode {
  public:
   Expr dom_scale;
-  int safe_nbit;   // number of bit which can be cast safely.
+  int nbit;        // number of bit
   DataType dtype;  // current data type, realize use this information for final data type casting
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("data", &data);
     v->Visit("dom_scale", &dom_scale);
-    v->Visit("safe_nbit", &safe_nbit);
+    v->Visit("nbit", &nbit);
     v->Visit("dtype", &dtype);
   }
 
   Expr Realize() const final;
 
-  TVM_DLL static QIntState make(Expr data, Expr dom_scale, int safe_nbit, DataType dtype);
+  TVM_DLL static QIntState make(Expr data, Expr dom_scale, int nbit, DataType dtype);
 
   static constexpr const char * _type_key = "relay.quantize.QIntState";
   TVM_DECLARE_NODE_TYPE_INFO(QIntStateNode, QStateNode);
diff --git a/tests/python/quantize/evaluate_gluon_model.py b/tests/python/quantize/evaluate_gluon_model.py
deleted file mode 100644
index 565cb841caef4..0000000000000
--- a/tests/python/quantize/evaluate_gluon_model.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import logging
-import argparse
-import os
-import mxnet as mx
-from mxnet import gluon
-from mxnet.gluon.model_zoo import vision
-from gluoncv.data import imagenet
-
-# Two functions for reading data from record file or raw images
-def get_val_data(rec_val,
-                 batch_size,
-                 num_workers=4):
-    rec_val = os.path.expanduser(rec_val)
-    mean_rgb = [123.68, 116.779, 103.939]
-    std_rgb = [58.393, 57.12, 57.375]
-    def batch_fn(batch, ctx):
-        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
-        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
-        return data, label
-
-    val_data = mx.io.ImageRecordIter(
-        path_imgrec        = rec_val,
-        preprocess_threads = num_workers,
-        shuffle            = True,
-        batch_size         = batch_size,
-        resize             = 256,
-        data_shape         = (3, 224, 224),
-        mean_r             = mean_rgb[0],
-        mean_g             = mean_rgb[1],
-        mean_b             = mean_rgb[2],
-        std_r              = std_rgb[0],
-        std_g              = std_rgb[1],
-        std_b              = std_rgb[2],
-    )
-    return val_data, batch_fn
-
-
-def evaluate(args, graph, lib, params, ctx):
-    """Evaluate on the validation set."""
-    import tvm
-    from tvm.contrib import graph_runtime
-
-    # tetup dataset.
-    batch_size = args.batch_size
-    val_data, batch_fn = get_val_data(args.rec_val, batch_size)
-    # create runtime module
-    m = graph_runtime.create(graph, lib, ctx)
-    m.set_input(**params)
-    oshape = (batch_size, args.num_classes)
-    out_arr = tvm.nd.empty(oshape, "float32")
-    # setup evaluaiton metric
-    acc_top1 = mx.metric.Accuracy()
-    acc_top5 = mx.metric.TopKAccuracy(5)
-    val_data.reset()
-    acc_top1.reset()
-    acc_top5.reset()
-    # Execute
-    for i, batch in enumerate(val_data):
-        data, label = batch_fn(batch, [mx.cpu(0)])
-        m.run(data=data[0].asnumpy())
-        m.get_output(0, out_arr)
-        acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())])
-        acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())])
-
-        if args.log_interval and not (i + 1) % args.log_interval:
-            _, top1 = acc_top1.get()
-            _, top5 = acc_top5.get()
-            nsamples = (i + 1) * batch_size
-            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
-    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
-
-
-def build_nnvm(args, gluon_model):
-    """Build with nnvm path"""
-    import tvm
-    import nnvm
-    import nnvm.compiler
-    net, params = nnvm.frontend.from_mxnet(gluon_model)
-    data_shape = (args.batch_size, 3, 224, 224)
-    shape_dict = {'data': data_shape}
-    shape_dict.update({k: v.shape for k, v in params.items()})
-    dtype_dict = {"data": "float32"}
-    target = args.target
-
-    with nnvm.compiler.build_config(opt_level=3):
-        graph, lib, params = nnvm.compiler.build(
-            net, target, shape_dict, dtype_dict, params=params)
-    ctx = tvm.nd.context(target, 0)
-    return graph, lib,params, ctx
-
-
-def build_relay(args, gluon_model):
-    """Build with relay."""
-    import tvm
-    from tvm import relay
-    data_shape = (args.batch_size, 3, 224, 224)
-    net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
-    target = args.target
-    with relay.build_config(opt_level=3):
-        graph, lib, params = relay.build(
-            net, target, params=params)
-    ctx = tvm.nd.context(target, 0)
-    return graph, lib, params, ctx
-
-
-def build_quantize(args, gluon_model):
-    print('build quantize')
-    """Build with relay."""
-    import tvm
-    from tvm import relay
-    from tvm.relay import quantize as qtz
-    from tvm.relay import ir_pass
-    data_shape = (args.batch_size, 3, 224, 224)
-    net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
-    target = args.target
-
-    with qtz.qconfig(skip_k_conv=2, global_scale=2.0):
-        # print(net.astext())
-
-        graph = net
-        # graph = ir_pass.infer_type(graph)
-        # graph = ir_pass.simplify_inference(graph)
-        # var_map = {arg.name_hint: arg for arg in graph.params}
-        # const_map = {var_map[key]: tvm.relay.const(params[key]) for key in params}
-        # graph = tvm.relay.bind(graph, const_map)
-        # graph = ir_pass.fold_constant(graph)
-        # print('after const folding\n')
-        # print(graph.astext())
-        qgraph = qtz.annotate(graph, params)
-        # print('after annotate\n')
-        # print(qgraph.astext())
-        qgraph = qtz.calibrate(qgraph)
-        # print('after calibrate\n')
-        # print(qgraph.astext())
-        # qgraph = qtz.realize(qgraph)
-        # print('after realize\n')
-        # print(qgraph.astext())
-
-    with relay.build_config(opt_level=3):
-        graph, lib, params = relay.build(
-            qgraph, target, params=params)
-    ctx = tvm.nd.context(target, 0)
-    return graph, lib, params, ctx
-
-
-def main(args):
-    gluon_model = vision.get_model(args.model, pretrained=True)
-    if args.use_nnvm:
-        graph, lib, params, ctx = build_nnvm(args, gluon_model)
-    else:
-        # graph, lib, params, ctx = build_relay(args, gluon_model)
-        graph, lib, params, ctx = build_quantize(args, gluon_model)
-
-    logging.info("Finish building model %s...", args.model)
-    evaluate(args, graph, lib, params, ctx)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Evaluate ImageNet validation accuracy")
-    parser.add_argument("--rec-val", type=str, default="~/.mxnet/datasets/imagenet/rec/val.rec",
-                        help="the validation data")
-    parser.add_argument("--num-classes", type=int, default=1000,
-                        help="batch size")
-    parser.add_argument("--model", type=str, default="resnet18_v1",
-                        help="Name of the model")
-    parser.add_argument("--log-interval", type=int, default=100,
-                        help="log interval")
-    parser.add_argument("--batch-size", type=int, default=1,
-                        help="batch size")
-    parser.add_argument("--target", type=str, default="cuda",
-                        help="target option")
-    parser.add_argument("--use-nnvm", action="store_true",
-                        help="Use legacy nnvm compiler")
-    args = parser.parse_args()
-    logging.basicConfig(level=logging.INFO)
-    logging.info(args)
-    main(args)
diff --git a/tests/python/quantize/test.py b/tests/python/quantize/test.py
deleted file mode 100644
index 462e07e54ec27..0000000000000
--- a/tests/python/quantize/test.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import numpy as np
-import tvm
-from tvm import relay
-from tvm.relay import testing
-from tvm.contrib import graph_runtime
-
-
-network = 'resnet-18'
-target = 'llvm'
-target_host = None
-batch_size = 1
-input_shape = (batch_size, 3, 224, 224)
-ctx = tvm.cpu()
-
-net, params = testing.resnet.get_workload(num_layers=18, batch_size=1, dtype='float32')
-print("%-20s relay building..." % network)
-print('net: {0}'.format(net.astext()))
-with relay.build_module.build_config(opt_level=3):
-    model, lib, params = relay.build(net, target=target, target_host=target_host, params=params)
-module = graph_runtime.create(model, lib, ctx)
-data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype('float32'))
-module.set_input('data', data_tvm)
-module.set_input(**params)
-ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=3)
-prof_res = np.array(ftimer().results) * 1000  # multiply 1000 for converting to millisecond
-print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
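
Usage sketch (reviewer note, not part of the diff): the deleted evaluate_gluon_model.py was the only end-to-end driver of the annotate -> calibrate -> realize flow touched by this patch. The snippet below is a minimal, hedged sketch of how that flow might still be exercised afterwards, following the Python signatures in python/tvm/relay/quantize/quantize.py from this patch; the resnet18_v1 model, the (1, 3, 224, 224) input shape, the llvm target, and the skip_k_conv value are illustrative assumptions carried over from the removed scripts, not requirements of the change.

    # Hedged sketch: drive annotate -> calibrate -> realize, then build with Relay.
    from mxnet.gluon.model_zoo import vision  # pretrained model source, as in the removed script

    from tvm import relay
    from tvm.relay import quantize as qtz

    # assumed model and input shape, mirroring the deleted evaluate_gluon_model.py
    gluon_model = vision.get_model("resnet18_v1", pretrained=True)
    net, params = relay.frontend.from_mxnet(gluon_model, {"data": (1, 3, 224, 224)})

    # qconfig fields come from QConfig in quantize.py; global_scale defaults to 8.0 after this patch
    with qtz.qconfig(skip_k_conv=2, global_scale=8.0):
        qgraph = qtz.annotate(net)      # build the simulated quantized graph
        qgraph = qtz.calibrate(qgraph)  # pick bits/scales and fold them in as constants
        qgraph = qtz.realize(qgraph)    # rewrite to the final low-precision graph

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(qgraph, "llvm", params=params)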