Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPT] Low-bit Quantization #2116

Merged
merged 5 commits into from
Jan 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
}
};


struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
Array<Integer> axes;

Expand All @@ -151,16 +150,16 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
}
};

// Clip
/*! \brief Attributes for Clip operator */
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min;
double a_max;

TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
}
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return map_.get<ValueType>(expr, def_value);
}


/*!
* \brief Check that an expression is a "primtive operator".
*
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from . import expr_functor
from . import module
from . import ir_pass
from .build_module import build, build_config, create_executor
from .build_module import build, build_config, create_executor, optimize
from . import parser
from . import debug

Expand All @@ -23,6 +23,7 @@
from . import image
from . import frontend
from . import backend
from . import quantize

from .scope_builder import ScopeBuilder

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _bind_params_by_name(func, params):
return expr.bind(func, bind_dict)


def optimize(func, target, params=None):
def optimize(func, target=None, params=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems this API changes recently? It breaks some codes @tqchen

"""Perform target invariant optimizations.

Parameters
Expand Down Expand Up @@ -400,7 +400,7 @@ def _make_executor(self, func):
graph_json, mod, params = build(func, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
gmodule.set_input(*params)
gmodule.set_input(**params)

def _graph_wrapper(*args, **kwargs):
args = self._convert_args(func, args, kwargs)
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/quantize/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pylint: disable=wildcard-import, redefined-builtin
"""Automatic quantization utilities."""
from __future__ import absolute_import as _abs

from .quantize import *
from ._annotate import register_annotate_function
246 changes: 246 additions & 0 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
#pylint: disable=unused-argument
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import

import topi
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig
from .quantize import _conv_counter, _set_conv_counter
from .. import expr as _expr
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from ..._ffi.function import register_func


@_reg.register_compute("relay.op.annotation.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type, target):
"""Compiler for simulated_quantize."""
assert len(inputs) == 4
assert attrs.sign
assert attrs.rounding == "round"

data, scale, clip_min, clip_max = inputs

# simulate rounding error
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)

# recover data
rdata = topi.multiply(round_data, scale)
return [rdata]


_reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.OPAQUE)


@register_relay_node
class QAnnotateExpr(_expr.TempExpr):
"""A special kind of Expr for Annotating.

Parameters
---------
expr: Expr
the original relay ir expr.

kind: QAnnotateKind
the kind of annotation field.
"""
def __init__(self, expr, kind):
self.__init_handle_by_constructor__(
_quantize.make_annotate_expr, expr, kind)


def _forward_op(ref_call, args):
"""forward the operator of ref_call with provided arguments"""
return _expr.Call(
ref_call.op, args, ref_call.attrs, ref_call.type_args)


def _get_expr_kind(anno):
"""Get the expression and QAnnotateKind from QAnnotateExpr or Expr"""
if isinstance(anno, QAnnotateExpr):
return anno.expr, anno.kind
return anno, None


def register_annotate_function(op_name, frewrite=None, level=10):
"""register a rewrite function for operator, used by annotation.

Parameters
---------
op_name: str
The name of operation

frewrite : function, optional
The function to be registered.

level : int, optional
The priority level
"""
def default_rewrite(ref_call, new_args, ctx):
# recover from QAnnotateExpr
args = [_get_expr_kind(x)[0] for x in new_args]
return _forward_op(ref_call, args)

def _register(func):
"""internal register function"""
def frewrite_with_guard(ref_call, new_args, ctx):
if not current_qconfig().guard(ref_call):
return default_rewrite(ref_call, new_args, ctx)
return func(ref_call, new_args, ctx)
_op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
return frewrite_with_guard

return _register(frewrite) if frewrite is not None else _register


@register_func("relay.quantize.attach_simulated_quantize")
def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
"""Attach a simulated quantize operation after input data expr.

Parameters
---------
data: Expr
the original data expr.

kind: QAnnotateKind
the kind of annotation field.
"""
dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
return _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)


@register_annotate_function("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
cnt = _conv_counter()
if cnt < current_qconfig().skip_k_conv:
_set_conv_counter(cnt + 1)
return None
_set_conv_counter(cnt + 1)
ZihengJiang marked this conversation as resolved.
Show resolved Hide resolved

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can / should we avoid duplicated quantization in parallel branches? e.g

   data (fp32) 
    /       \
conv1    conv2

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should and we can, possibly via memoization. In theory forward rewrite already memoize, if there is any problem, please provide a minimum test case and let us double check

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a test case, to reproduce, you need to set opt level "CombineParallelConv2D": 4 to disable this pass.

import tvm
import tvm.relay as relay
import tvm.relay.testing

def get_workload():
    data = relay.var("data", shape=(1, 3, 224, 224), dtype='float32')
    conv1 = relay.testing.layers.conv2d(data=data, channels=16, kernel_size=(1, 1), name='conv1')
    shortcut = relay.testing.layers.conv2d(data=data, channels=16, kernel_size=(1, 1), name='sc')
    net = relay.add(conv1, shortcut)
    f = relay.Function(relay.ir_pass.free_vars(net), net)
    return relay.testing.init.create_workload(f)
    
sym, params = get_workload()
with tvm.relay.quantize.qconfig(skip_k_conv=0):
    sym = relay.quantize.quantize(sym, params)
print(sym.astext(show_meta_data=False))
tvm.relay.build(sym, 'llvm', params=params)

Result:

fn (%data: Tensor[(1, 3, 224, 224), float32])
    -> Tensor[(1, 16, 224, 224), float32] {
  %0 = multiply(%data, 16f) # ty=Tensor[(1, 3, 224, 224), float32]
  %1 = round(%0) # ty=Tensor[(1, 3, 224, 224), float32]
  %2 = clip(%1, a_min=-127, a_max=127) # ty=Tensor[(1, 3, 224, 224), float32]
  %3 = cast(%2, dtype="int8") # ty=Tensor[(1, 3, 224, 224), int8]
  %4 = meta.relay.Constant(id=0) # ty=Tensor[(16, 3, 1, 1), int8]
  %5 = nn.conv2d(%3, %4, channels=16, kernel_size=[1, 1], out_dtype="int32")
  %6 = multiply(%data, 16f) # ty=Tensor[(1, 3, 224, 224), float32]
  %7 = round(%6) # ty=Tensor[(1, 3, 224, 224), float32]
  %8 = clip(%7, a_min=-127, a_max=127) # ty=Tensor[(1, 3, 224, 224), float32]
  %9 = cast(%8, dtype="int8") # ty=Tensor[(1, 3, 224, 224), int8]
  %10 = meta.relay.Constant(id=1) # ty=Tensor[(16, 3, 1, 1), int8]
  %11 = nn.conv2d(%9, %10, channels=16, kernel_size=[1, 1], out_dtype="int32")
  %12 = add(%5, %11)
  %13 = add(%12, 64)
  %14 = right_shift(%13, 7)
  %15 = clip(%14, a_min=-127, a_max=127)
  %16 = cast(%15, dtype="int8")
  %17 = cast(%16, dtype="float32")
  %18 = multiply(%17, 0.0625f)
  %18
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZihengJiang @merrymercy @vinx13 can you look into this? let us open this testcase as an issue to be fixed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because quantization of data happens during rewrite of conv2d, so this won't be memorized. We need some message passing to quantize data during forward rewrite of data.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not have things to do with ANF. The problem is that if two conv refers to the same input and they want to run the same transformation f on that input, there will be two such f.

One solution is to build a generic common subexpression combination(elimination) path to create a concise dag

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we have seen in @vinx13 's test case, there're three multiply operations. The multipliers are typically 16 or 1/16 (which is 0.0625). In order to eliminate floating-point multiplication, can we convert them into shift operation upon integers?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as i understand, this PR already do that

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought so, and i think it should be configured by disabling round_for_shift in qconfig. However, the argument doesn't actually work, or at least not for replacing the combination of multiply and round with shift operation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZihengJiang What's the status of this issue of parallel branches? Will it be future work?


assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
ZihengJiang marked this conversation as resolved.
Show resolved Hide resolved


@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind is None:
# quantize lhs to INPUT field
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError


@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if _conv_counter() <= current_qconfig().skip_k_conv:
ZihengJiang marked this conversation as resolved.
Show resolved Hide resolved
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None

x_expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return None

ret_expr = _forward_op(ref_call, [x_expr])
return QAnnotateExpr(ret_expr, x_kind)


register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)


def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
expr, x_kind = _get_expr_kind(new_args[0])

if x_kind is None:
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)


register_annotate_function("nn.max_pool2d", pool2d_rewrite)


@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
ZihengJiang marked this conversation as resolved.
Show resolved Hide resolved

input_tuple = new_args[0]
expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
kind_list = [_get_expr_kind(x)[1] for x in input_tuple]

# make sure the inputs of concatenate are all normal
# expression or annotate expression
if kind_list[0] is None:
for k in kind_list:
assert k is None
return None
for k in kind_list:
assert k is not None
expr = _forward_op(ref_call, [_expr.Tuple(expr_list)])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
6 changes: 6 additions & 0 deletions python/tvm/relay/quantize/_quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pylint: disable=unused-argument
"""Internal module for quantization."""
from __future__ import absolute_import
from tvm._ffi.function import _init_api

_init_api("relay._quantize", __name__)
Loading