Skip to content

Commit

Permalink
[QUANTIZE] Reimplement Annotate with ForwardRewrite.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang committed Nov 26, 2018
1 parent 78c9449 commit bff439a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 125 deletions.
213 changes: 88 additions & 125 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <string>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include "quantize.h"
#include "pattern_util.h"

Expand All @@ -23,24 +24,6 @@ namespace quantize {

using runtime::TypedPackedFunc;

// SimulatedQuantize
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
bool sign;
std::string rounding;
int id;
int field_type;

TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
TVM_ATTR_FIELD(sign).set_default(true);
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
TVM_ATTR_FIELD(id)
.describe("id");
TVM_ATTR_FIELD(field_type)
.describe("field_type");
}
};

TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);

bool SimulatedQuantizeRel(const Array<Type>& types,
Expand Down Expand Up @@ -106,22 +89,7 @@ TVM_REGISTER_API("relay.op._make.simulated_quantize")
// |->mul2


// qtz_field
enum QField : int {
kFloat = 0,
kQInput = 1,
kQWeight = 2,
kQActivation = 3,
};

using FQFieldSpec = TypedPackedFunc<Array<Integer>(Attrs, Array<Integer>)>;

inline QField Int2Field(Integer x) {
return static_cast<QField>(x.operator int64_t());
}


Expr MakeSimulatedQuantize(Expr x, QField field) {
Expr MakeSimulatedQuantize(Expr x, std::string field) {
static const Op& op = Op::Get("simulated_quantize");
static int cnt = 0;
std::string name_postfix = std::to_string(cnt++);
Expand All @@ -134,117 +102,112 @@ Expr MakeSimulatedQuantize(Expr x, QField field) {
auto attrs = make_node<SimulatedQuantizeAttrs>();
attrs->sign = true;
attrs->rounding = "round";
attrs->id = cnt;
attrs->field_type = field;
return CallNode::make(op, {x, dom_scale, bit, clip_min, clip_max}, Attrs(attrs), {});
}


class Annotator : public ExprMutator {
public:
Expr Annotate(Expr e) {
this->cnt_map_ = GetExprRefCount(e);
return this->Mutate(e);
}

Expr VisitExpr_(const CallNode* n) final {
static const auto& fqfield_spec =
Op::GetAttr<FQFieldSpec>("FQFieldSpec");
size_t ref_cnt = cnt_map_.at(n);

Expr new_e = ExprMutator::VisitExpr_(n);
const auto* call = new_e.as<CallNode>();
CHECK(call);

size_t num_inputs = call->args.size();
// prepare input fields
Array<Integer> ifields;
for (size_t i = 0; i < num_inputs; ++i) {
ifields.push_back(field_map_.at(call->args[i].get()));
}

auto f = GetFunc(fqfield_spec, call->op);
if (f != nullptr) {
// get fields spec
Array<Integer> fields = f(call->attrs, ifields);
// insert simulated quantize
Array<Expr> new_args;
for (size_t i = 0; i < num_inputs; ++i) {
if (Int2Field(ifields[i]) != Int2Field(fields[i])) {
new_args.push_back(MakeSimulatedQuantize(call->args[i], Int2Field(fields[i])));
} else {
new_args.push_back(call->args[i]);
}
}
// mark output's field
Call ret = CallNode::make(call->op, new_args, call->attrs, call->type_args);
field_map_[ret.get()] = Int2Field(fields[num_inputs]);
return ret;
Array<QFieldExpr> PrepareInputs(const Array<Expr>& args) {
Array<QFieldExpr> inputs;
for (Expr arg : args) {
if (const auto* n = arg.as<QFieldExprNode>()) {
inputs.push_back(QFieldExpr(arg.node_));
} else {
// default behavior for nodes like add, relu
// it will broadcast the previous node's field
QField field = SelectField(ifields);
// change to float field for multiple ref
field_map_[new_e.get()] = ref_cnt > 1 ? kFloat : field;
return new_e;
auto node = make_node<QFieldExprNode>();
node->expr = arg;
node->field = "float";
inputs.push_back(QFieldExpr(node));
}
}
return inputs;
}

Expr VisitExpr_(const VarNode* n) final {
Expr new_e = ExprMutator::VisitExpr_(n);
field_map_[new_e.get()] = kFloat;
return new_e;
}

private:
std::unordered_map<const Node*, size_t> cnt_map_;
std::unordered_map<const Node*, QField> field_map_;

QField SelectField(Array<Integer> ifields) {
for (auto field : ifields) {
// just return the first non-float field
if (Int2Field(field) != kFloat) {
return Int2Field(field);
}
}
return kFloat;
Expr Conv2DQFieldRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
auto rnode = make_node<QFieldExprNode>();
Array<QFieldExpr> inputs = PrepareInputs(new_args);
QFieldExpr lhs = inputs[0];
QFieldExpr rhs = inputs[1];

Expr lhs_expr = MakeSimulatedQuantize(lhs->expr, "input");
Expr rhs_expr = MakeSimulatedQuantize(rhs->expr, "weight");
rnode->expr = CallNode::make(ref_call->op,
{lhs_expr, rhs_expr},
ref_call->attrs, ref_call->type_args);
rnode->field = "activation";
return Expr(rnode);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQFieldRewrite", Conv2DQFieldRewrite);


Expr MulQFieldRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
auto rnode = make_node<QFieldExprNode>();
Array<QFieldExpr> inputs = PrepareInputs(new_args);
QFieldExpr lhs = inputs[0];
QFieldExpr rhs = inputs[1];

if (lhs->field == "float" && rhs->field == "float") {
// execute the op on float domain
rnode->expr = CallNode::make(ref_call->op,
{lhs->expr, rhs->expr},
ref_call->attrs, ref_call->type_args);
rnode->field = "float";
} else if (lhs->field == "activation" && rhs->field == "float"){
// quantize rhs first
Expr rhs_expr = MakeSimulatedQuantize(rhs->expr, "weight");
rnode->expr = CallNode::make(ref_call->op,
{lhs->expr, rhs_expr},
ref_call->attrs, ref_call->type_args);
rnode->field = "activation";
} else {
LOG(FATAL) << "do not handle yet.";
}
};

Expr Annotate(const Expr& e) {
return Annotator().Annotate(e);
return Expr(rnode);
}

TVM_REGISTER_API("relay._quantize.annotate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Annotate(args[0]);
});

// register attribute for annotator
Array<Integer> Conv2dQFieldSpec(Attrs attrs, Array<Integer> ifields) {
return {kQInput, kQWeight, kQActivation};
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FQFieldRewrite", MulQFieldRewrite);


// share rewrite function for now
RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQFieldRewrite", MulQFieldRewrite);


Expr ReluQFieldRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
auto rnode = make_node<QFieldExprNode>();
Array<QFieldExpr> inputs = PrepareInputs(new_args);
QFieldExpr input = inputs[0];

rnode->expr = CallNode::make(ref_call->op, {input->expr},
ref_call->attrs, ref_call->type_args);
rnode->field = input->field;
return Expr(rnode);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FQFieldSpec>("FQFieldSpec", Conv2dQFieldSpec);
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FQFieldRewrite", ReluQFieldRewrite);


Array<Integer> MulQFieldSpec(Attrs attrs, Array<Integer> ifields) {
CHECK(ifields.size() == 2);
QField lhs = Int2Field(ifields[0]);
QField rhs = Int2Field(ifields[1]);
if (lhs == kFloat && rhs == kFloat) {
return {kFloat, kFloat, kFloat};
} else if (lhs == kQActivation || rhs == kQActivation) {
return {kQActivation, kQWeight, kQActivation};
} else {
LOG(FATAL) << "wrong fields for mul";
return {};
}
Expr Annotate(Expr expr) {
return ForwardRewrite(
expr, "FQFieldRewrite");
}

RELAY_REGISTER_OP("multiply")
.set_attr<FQFieldSpec>("FQFieldSpec", MulQFieldSpec);
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>(Annotate);




// =============
Expand Down Expand Up @@ -319,7 +282,7 @@ QIntState RealizeQuantize(const Attrs attrs,
if (static_cast<int>(magnitude) == magnitude) {
// int32->int8, idom_scale < odom_scale, right_shift
data = RightShift(data, MakeConstantScalar(Int(32), static_cast<int>(magnitude)));
// TODO do we need to clip
// TODO do we need to clip?
DataType cast_dtype = Int(bit_imm);
Expr cast_data = Cast(data, cast_dtype);
return QIntStateNode::make(cast_data, odom_scale, bit);
Expand Down
35 changes: 35 additions & 0 deletions src/relay/pass/quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,41 @@ namespace tvm {
namespace relay {
namespace quantize {

// SimulatedQuantize
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
bool sign;
std::string rounding;
std::string field_type;

TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
TVM_ATTR_FIELD(sign).set_default(true);
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
TVM_ATTR_FIELD(field_type)
.describe("field_type");
}
};

Expr MakeSimulatedQuantize(Expr x, std::string field);


class QFieldExprNode : public TempExprNode {
public:
Expr expr;
std::string field;
Expr Realize() const {
// dequantize
Expr ret = MakeSimulatedQuantize(expr, "float");
return ret;
}

static constexpr const char* _type_key = "relay.QFieldExpr";
TVM_DECLARE_NODE_TYPE_INFO(QFieldExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(QFieldExpr, QFieldExprNode, TempExpr);


class QState;
class QIntState;
class QRealState;
Expand Down

0 comments on commit bff439a

Please sign in to comment.