diff --git a/nnvm/src/top/nn/lowbits.cc b/nnvm/src/top/nn/lowbits.cc new file mode 100644 index 0000000000000..10781e7d861e0 --- /dev/null +++ b/nnvm/src/top/nn/lowbits.cc @@ -0,0 +1,68 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file lowbit.cc + * \brief Support operators for lowbit + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./nn_common.h" +#include "../op_common.h" +#include "../elemwise_op_common.h" + +namespace nnvm { +namespace top { + +struct BitPackParam : public dmlc::Parameter { + int lanes; + + DMLC_DECLARE_PARAMETER(BitPackParam) { + DMLC_DECLARE_FIELD(lanes).set_lower_bound(1) + .describe("Number of lanes packed in one element"); + } +}; + + +// dense +DMLC_REGISTER_PARAMETER(BitPackParam); + +inline bool BitPackInferShape(const nnvm::NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + const BitPackParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(out_shape->size(), 1U); + if ((*in_shape)[DenseParam::kData].ndim() != 0) { + TShape dshape = (*in_shape)[0]; + CHECK_EQ(dshape[dshape.ndim() - 1] % param.lanes, 0); + dshape[dshape.ndim() - 1] /= param.lanes; + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape); + return false; + } + return true; +} + + +NNVM_REGISTER_OP(bitpack) +.describe(R"code(Applies bit packing to innermost dimension. + +)code" NNVM_ADD_FILELINE) +.add_argument("data", "nD Tensor", "Input data.") +.add_argument("weight", "2D Tensor", "Weight matrix.") +.add_argument("bias", "1D Tensor", "Bias parameter.") +.add_arguments(BitPackParam::__FIELDS__()) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.set_num_outputs(1) +.set_num_inputs(1) +.set_support_level(5) +.set_attr("FInferShape", BitPackInferShape) +.set_attr("FInferType", ElemwiseType<-1, 1>); + +} // namespace top +} // namespace nnvm diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py index 614ed2347181e..46454ebf789f8 100644 --- a/vta/python/vta/top/__init__.py +++ b/vta/python/vta/top/__init__.py @@ -3,3 +3,4 @@ from .vta_conv2d import packed_conv2d, schedule_packed_conv2d from . import vta_conv2d from . import arm_conv2d +from .bitpack import bitpack diff --git a/vta/python/vta/top/bitpack.py b/vta/python/vta/top/bitpack.py new file mode 100644 index 0000000000000..61d3cd3c4f4a0 --- /dev/null +++ b/vta/python/vta/top/bitpack.py @@ -0,0 +1,72 @@ +"""Bit packing operators""" +from __future__ import absolute_import as _abs + +import tvm +from topi import util + +from nnvm.top import registry as reg, OpPattern +from nnvm.top import nn as _nn +from nnvm.top.tensor import _fschedule_broadcast + +def bitpack(data, bits, pack_type="int8", name="bitpack"): + """Packs lowest dimension into format needed by VTA + + Parameters + ---------- + pack_axis : int + index of the axis to pack in data + bit_axis : int + index of axis to place bit axis in resulting packed data + + Returns + ------- + packed : Tensor + The packed tensor. + """ + shape_vec = list(data.shape) + if pack_type == 'int8': + data_width = 8 + elif pack_type == 'int16': + data_width = 16 + elif pack_type == 'int32': + data_width = 32 + else: + raise RuntimeError("Unknown pack type %s" % pack_type) + assert data_width % bits == 0 + lanes = data_width // bits + + # Data must be in multiples of the data_width + assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size" + shape_vec[-1] = shape_vec[-1] // lanes + oshape = tuple(shape_vec) + + def _bitpack(*indices): + ret = None + mask = tvm.const((1 << bits) - 1, pack_type) + for k in range(lanes): + idx = list(indices) + idx[-1] = idx[-1] * lanes + k + elem = data(*idx).astype(pack_type) + if k == 0: + ret = elem & mask + else: + val = (elem & mask) << tvm.const(k * bits, pack_type) + ret = ret | val + return ret + + return tvm.compute( + oshape, _bitpack, name=name, tag='bitpack') + + +@reg.register_compute("bitpack", level=15) +def compute_bitpack(attrs, inputs, out): + lanes = attrs.get_int("lanes") + dtype = inputs[0].dtype + assert dtype == "int8" + width = 8 + assert width % lanes == 0 + bits = 8 // lanes + return bitpack(inputs[0], bits, dtype) + +reg.register_schedule("bitpack", _fschedule_broadcast) +reg.register_pattern("bitpack", OpPattern.INJECTIVE) diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 1426e5d0d5ed3..549d7144d3214 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -10,7 +10,7 @@ from nnvm.top import registry as reg, OpPattern from nnvm.top import nn as _nn from ..environment import get_env - +from ..ptr_alias import reinterpret Workload = namedtuple("Conv2DWorkload", ['batch', 'height', 'width', 'in_filter', 'out_filter', @@ -259,9 +259,23 @@ def compute_conv2d(attrs, inputs, out): groups = attrs.get_int("groups") layout = attrs["layout"] out_dtype = attrs['out_dtype'] + assert dilation == (1, 1), "not support dilate now" if is_packed_layout(layout): assert groups == 1 + env = get_env() + assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" + assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now" + inputs = list(inputs) + w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) + assert inputs[1].dtype == "int8" + + # Apply bit packing if necessary + if w_pack_factor != 1: + kshape = list(topi.util.get_const_tuple(inputs[1].shape)) + kshape[-1] *= w_pack_factor + inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype) + return packed_conv2d(inputs[0], inputs[1], padding, strides, out_dtype=out_dtype) return _nn.compute_conv2d(attrs, inputs, out) diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py index 5e5f035372e2f..b95103be182e2 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py @@ -61,7 +61,6 @@ def test_cpu_conv2d(): def run_cpu_conv2d(env, remote, key, batch_size, wl, profile=True): data_shape = (batch_size, wl.in_filter, wl.height, wl.width) kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) - fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)