

Support low bit weights in nnvm (#2)
tqchen authored and tmoreau89 committed Nov 25, 2018
1 parent 93b71e2 commit adc588f
Showing 5 changed files with 156 additions and 2 deletions.
68 changes: 68 additions & 0 deletions nnvm/src/top/nn/lowbits.cc
@@ -0,0 +1,68 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \file lowbits.cc
 * \brief Support operators for low-bit packing.
 */
#include <tvm/tvm.h>
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/layout.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

struct BitPackParam : public dmlc::Parameter<BitPackParam> {
  int lanes;

  DMLC_DECLARE_PARAMETER(BitPackParam) {
    DMLC_DECLARE_FIELD(lanes).set_lower_bound(1)
      .describe("Number of lanes packed in one element");
  }
};


DMLC_REGISTER_PARAMETER(BitPackParam);

inline bool BitPackInferShape(const nnvm::NodeAttrs& attrs,
                              std::vector<TShape>* in_shape,
                              std::vector<TShape>* out_shape) {
  const BitPackParam& param = nnvm::get<BitPackParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 1U);
  CHECK_EQ(out_shape->size(), 1U);
  if ((*in_shape)[0].ndim() != 0) {
    TShape dshape = (*in_shape)[0];
    CHECK_EQ(dshape[dshape.ndim() - 1] % param.lanes, 0)
        << "Innermost dimension must be divisible by lanes";
    dshape[dshape.ndim() - 1] /= param.lanes;
    NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
    return true;
  }
  return false;
}


NNVM_REGISTER_OP(bitpack)
.describe(R"code(Applies bit packing to the innermost dimension.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "nD Tensor", "Input data.")
.add_arguments(BitPackParam::__FIELDS__())
.set_attr_parser(ParamParser<BitPackParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BitPackParam>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5)
.set_attr<FInferShape>("FInferShape", BitPackInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>);

} // namespace top
} // namespace nnvm
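
For intuition, a minimal sketch of how the new operator behaves at the graph level, assuming it is exposed through the standard NNVM symbol front end; the input shape and the lanes value here are illustrative, not taken from the diff:

import nnvm
from nnvm.compiler import graph_util

# Build a tiny graph that packs 4 low-bit lanes into each output element.
data = nnvm.sym.Variable("data")
packed = nnvm.sym.bitpack(data, lanes=4)

# BitPackInferShape divides the innermost dimension by lanes:
# an input of shape (2, 16) packs down to (2, 4).
g = nnvm.graph.create(packed)
_, out_shapes = graph_util.infer_shape(g, data=(2, 16))
print(out_shapes)  # [[2, 4]]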
1 change: 1 addition & 0 deletions vta/python/vta/top/__init__.py
@@ -3,3 +3,4 @@
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import vta_conv2d
from . import arm_conv2d
from .bitpack import bitpack
72 changes: 72 additions & 0 deletions vta/python/vta/top/bitpack.py
@@ -0,0 +1,72 @@
"""Bit packing operators"""
from __future__ import absolute_import as _abs

import tvm
from topi import util

from nnvm.top import registry as reg, OpPattern
from nnvm.top import nn as _nn
from nnvm.top.tensor import _fschedule_broadcast

def bitpack(data, bits, pack_type="int8", name="bitpack"):
    """Packs the lowest dimension into the storage format needed by VTA.

    Parameters
    ----------
    data : Tensor
        The input tensor to pack.
    bits : int
        Number of bits each packed value occupies.
    pack_type : str
        Data type of the storage word holding the packed values.
    name : str
        Name of the resulting compute operation.

    Returns
    -------
    packed : Tensor
        The packed tensor.
    """
    shape_vec = list(data.shape)
    if pack_type == 'int8':
        data_width = 8
    elif pack_type == 'int16':
        data_width = 16
    elif pack_type == 'int32':
        data_width = 32
    else:
        raise RuntimeError("Unknown pack type %s" % pack_type)
    assert data_width % bits == 0
    lanes = data_width // bits

    # The innermost dimension must hold a whole number of packed words
    assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size"
    shape_vec[-1] = shape_vec[-1] // lanes
    oshape = tuple(shape_vec)

    def _bitpack(*indices):
        ret = None
        mask = tvm.const((1 << bits) - 1, pack_type)
        for k in range(lanes):
            idx = list(indices)
            idx[-1] = idx[-1] * lanes + k
            elem = data(*idx).astype(pack_type)
            if k == 0:
                ret = elem & mask
            else:
                val = (elem & mask) << tvm.const(k * bits, pack_type)
                ret = ret | val
        return ret

    return tvm.compute(
        oshape, _bitpack, name=name, tag='bitpack')


@reg.register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs, out):
    lanes = attrs.get_int("lanes")
    dtype = inputs[0].dtype
    assert dtype == "int8"
    width = 8
    assert width % lanes == 0
    bits = width // lanes
    return bitpack(inputs[0], bits, pack_type=dtype)

reg.register_schedule("bitpack", _fschedule_broadcast)
reg.register_pattern("bitpack", OpPattern.INJECTIVE)
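
To make the packing formula concrete, here is an independent NumPy sketch of the arithmetic _bitpack performs for bits=2 on int8 storage (lanes=4); the input values are made up for illustration:

import numpy as np

bits, lanes = 2, 4                      # four 2-bit values per int8 word
mask = (1 << bits) - 1
values = np.array([1, 2, 3, 0], dtype=np.int8)

# OR together the masked values, each shifted into its own 2-bit slot.
packed = 0
for k in range(lanes):
    packed |= (int(values[k]) & mask) << (k * bits)

print(packed)  # 1 | (2 << 2) | (3 << 4) | (0 << 6) = 57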
16 changes: 15 additions & 1 deletion vta/python/vta/top/vta_conv2d.py
@@ -10,7 +10,7 @@
from nnvm.top import registry as reg, OpPattern
from nnvm.top import nn as _nn
from ..environment import get_env

from ..ptr_alias import reinterpret

Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
@@ -259,9 +259,23 @@ def compute_conv2d(attrs, inputs, out):
    groups = attrs.get_int("groups")
    layout = attrs["layout"]
    out_dtype = attrs['out_dtype']

    assert dilation == (1, 1), "does not support dilation for now"
    if is_packed_layout(layout):
        assert groups == 1
        env = get_env()
        assert env.LOG_INP_WIDTH == 3, "only support 8bit input for now"
        assert env.LOG_OUT_WIDTH == 3, "only support 8bit output for now"
        inputs = list(inputs)
        w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH)
        assert inputs[1].dtype == "int8"

        # Apply bit packing if necessary
        if w_pack_factor != 1:
            kshape = list(topi.util.get_const_tuple(inputs[1].shape))
            kshape[-1] *= w_pack_factor
            inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype)

        return packed_conv2d(inputs[0], inputs[1],
                             padding, strides, out_dtype=out_dtype)
    return _nn.compute_conv2d(attrs, inputs, out)
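
For intuition on the weight-packing branch: w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) counts how many logical weights fit in one int8 byte. A small arithmetic sketch, assuming a 2-bit weight configuration and a made-up kernel shape:

# Illustrative only: LOG_WGT_WIDTH = 1 means 2-bit weights, so 4 fit per byte.
LOG_WGT_WIDTH = 1
w_pack_factor = 1 << (3 - LOG_WGT_WIDTH)
print(w_pack_factor)            # 4

# The int8 kernel's innermost dim grows by that factor when the buffer is
# reinterpreted in the narrower weight dtype (shape below is hypothetical).
kshape = [16, 16, 3, 3, 16, 16]
kshape[-1] *= w_pack_factor
print(kshape)                   # [16, 16, 3, 3, 16, 64]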
1 change: 0 additions & 1 deletion vta/tests/python/integration/test_benchmark_topi_conv2d.py
@@ -61,7 +61,6 @@ def test_cpu_conv2d():
def run_cpu_conv2d(env, remote, key, batch_size, wl, profile=True):
    data_shape = (batch_size, wl.in_filter, wl.height, wl.width)
    kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)

    fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
    fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
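
The fout_height / fout_width expressions are the usual convolution output-size formula; with made-up workload numbers:

# Worked example of the output-size formula (numbers are illustrative).
height, hpad, hkernel, hstride = 56, 1, 3, 1
fout_height = (height + 2 * hpad - hkernel) // hstride + 1
print(fout_height)  # 56: padding 1 with a 3x3 kernel preserves the size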
