Skip to content

Commit

Permalink
[TOP] Add dense, batchnorm (#22)
Browse files Browse the repository at this point in the history
* [TOP] Add dense, batchnorm

* update tvm
  • Loading branch information
tqchen committed May 29, 2018
1 parent b37e5c2 commit 02a60d0
Show file tree
Hide file tree
Showing 14 changed files with 401 additions and 213 deletions.
7 changes: 5 additions & 2 deletions nnvm/include/nnvm/compiler/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ using TOpPattern = int;
* \brief Computation description interface
* \param attrs The attribute of the node.
* \param inputs The input tensors(placeholders)
* \param out_info Tensors holding shape/type information about output,
* these are always placeholders.
* \return The output description of the tensor.
*/
using FTVMCompute = std::function<
Array<Tensor>
(const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
Array<Tensor>(const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info)>;

/*!
* \brief Build the computation schedule for
Expand Down
13 changes: 11 additions & 2 deletions nnvm/python/nnvm/compiler/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,12 @@ def optimize(graph, shape, dtype="float32"):
"""
# pylint: disable=unused-argument
cfg = BuildConfig.current
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
if graph.json_attr("shape_num_unknown_nodes"):
raise ValueError("InferShape fails..")
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
graph = graph.apply("SimplifyBatchNormInference")
return graph


Expand Down Expand Up @@ -164,6 +167,12 @@ def build(graph, target, shape, dtype="float32", params=None):
cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# Initial pass: do shape and type inference
ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape))
if not isinstance(dtype, str):
idtype, _ = graph_util.infer_dtype(graph, **dtype)
dtype.update(zip(graph.index.input_names, idtype))
# Apply optimization
graph = optimize(graph, shape, dtype)
# Precompute prune
Expand Down
4 changes: 3 additions & 1 deletion nnvm/python/nnvm/compiler/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
class OpPattern(object):
ELEM_WISE = 0
BROADCAST = 1
# Complex means we can fuse elemwise to it
COMPLEX = 2
EXTERN = 2
# Extern means the op is not fusable
EXTERN = 3

_register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule")
Expand Down
1 change: 1 addition & 0 deletions nnvm/python/nnvm/top/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .attr_dict import AttrDict
from . import tensor
from . import nn
from . import transform
49 changes: 39 additions & 10 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
# pylint: disable=invalid-name, unused-argument
"""Definition of nn ops"""
from __future__ import absolute_import

import tvm
import topi
from topi.util import get_const_int
from .tensor import schedule_elemwise
from .tensor import _fschedule_broadcast
from ..compiler import registry as reg
from ..compiler import OpPattern

# relu
@reg.register_compute("relu")
def compute_relu(_, inputs):
def compute_relu(attrs, inputs, _):
"""Compute definition of relu"""
return topi.nn.relu(inputs[0])

@reg.register_schedule("relu")
def schedule_relu(_, outs, target):
"""Schedule definition of relu"""
return schedule_elemwise(_, outs, target)

reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEM_WISE)


# flatten
@reg.register_compute("flatten")
def compute_flatten(attrs, inputs, _):
    """Build the flatten compute by handing the first input to topi.nn.flatten.

    Neither the attributes nor the output info (third argument) are used.
    """
    data = inputs[0]
    return topi.nn.flatten(data)

reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.COMPLEX)


# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs):
def compute_softmax(attrs, inputs, _):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
Expand All @@ -38,12 +45,34 @@ def schedule_softmax(_, outs, target):
# naive schedule
return tvm.create_schedule([x.op for x in outs])

reg.register_pattern("softmax", OpPattern.COMPLEX)
# Mark softmax as extern as we do not fuse it in all cases
reg.register_pattern("softmax", OpPattern.EXTERN)


# dense
@reg.register_compute("dense")
def compute_dense(attrs, inputs, _):
    """Build the dense compute, with or without a bias term.

    The boolean attribute "use_bias" picks the topi kernel. The third
    input is only read when that attribute is set. The output info
    (third argument) is not used.
    """
    with_bias = attrs.get_bool("use_bias")
    data, weight = inputs[0], inputs[1]
    if not with_bias:
        return topi.nn.fully_connected(data, weight)
    # presumably inputs[2] is the bias tensor -- confirm against the dense op definition
    return topi.nn.fully_connected_with_bias(data, weight, inputs[2])

@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
    """Return a default schedule over the ops of `outs` for dense.

    Raises ValueError when `target` is exactly "cuda"; every other target
    gets the schedule produced by tvm.create_schedule.
    """
    if target != "cuda":
        # naive schedule
        output_ops = [out.op for out in outs]
        return tvm.create_schedule(output_ops)
    raise ValueError("fully_connected not yet implemented")

# register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.EXTERN)


# conv
@reg.register_compute("conv2d")
def compute_conv2d(attrs, inputs):
def compute_conv2d(attrs, inputs, _):
"""Compute definition of conv2d"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
Expand Down
121 changes: 93 additions & 28 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=invalid-name
# pylint: disable=invalid-name, unused-argument
"""Tensor ops"""
from __future__ import absolute_import

Expand All @@ -8,15 +8,6 @@
from ..compiler import registry as reg
from ..compiler import OpPattern

def schedule_elemwise(_, outs, target):
    """Generic schedule for elemwise operation

    "cuda" is delegated to topi.cuda.schedule_elemwise. Any other target
    must start with "llvm" (asserted) and gets a default schedule with
    AutoInlineInjective applied to it.
    """
    if target != "cuda":
        assert target.startswith("llvm")
        sch = tvm.create_schedule([out.op for out in outs])
        tvm.schedule.AutoInlineInjective(sch)
        return sch
    return topi.cuda.schedule_elemwise(outs)

def _schedule_broadcast(_, outs, target):
"""Generic schedule for binary bcast"""
if target == "cuda":
Expand All @@ -29,66 +20,140 @@ def _schedule_broadcast(_, outs, target):
def _compute_binary_scalar(f):
"""auxiliary function"""
@tvm.tag_scope("ewise")
def _compute(attrs, x):
def _compute(attrs, x, _):
x = x[0]
scalar = attrs.get_float("scalar")
scalar = tvm.const(scalar, x.dtype)
return tvm.compute(x.shape, lambda *i: f(x(*i), scalar))
return _compute


def _compute_unary(f):
"""auxiliary function"""
def _compute(attrs, x, _):
return f(x[0])
return _compute


def _compute_binary(f):
"""auxiliary function"""
def _compute(attrs, x, _):
return f(x[0], x[1])
return _compute


_fschedule_broadcast = tvm.convert(_schedule_broadcast)

# exp
reg.register_compute("exp",
lambda _, x: topi.exp(x[0]))
reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEM_WISE)
reg.register_schedule("exp", _fschedule_broadcast)

# sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
reg.register_schedule("sqrt", _fschedule_broadcast)

# log
reg.register_compute("log",
lambda _, x: topi.log(x[0]))
reg.register_compute("log", _compute_unary(topi.log))
reg.register_pattern("log", OpPattern.ELEM_WISE)
reg.register_schedule("log", _fschedule_broadcast)

# tanh
reg.register_compute("tanh",
lambda _, x: topi.tanh(x[0]))
reg.register_compute("tanh", _compute_unary(topi.tanh))
reg.register_pattern("tanh", OpPattern.ELEM_WISE)
reg.register_schedule("tanh", _fschedule_broadcast)

# negative
reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEM_WISE)
reg.register_schedule("negative", _fschedule_broadcast)

# sigmoid
reg.register_compute("sigmoid",
lambda _, x: topi.sigmoid(x[0]))
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
reg.register_schedule("sigmoid", _fschedule_broadcast)

# add scalar
# add_scalar
reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast)

# sub_scalar
reg.register_compute("__sub_scalar__",
_compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)

# rsub_scalar
reg.register_compute("__rsub_scalar__",
_compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)

# mul_scalar
reg.register_compute("__mul_scalar__",
_compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)

# div_scalar
reg.register_compute("__div_scalar__",
_compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast)

# rdiv_scalar
reg.register_compute("__rdiv_scalar__",
_compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)

# elemwise_add
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)

# elemwise_sub
reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
reg.register_schedule("elemwise_sub", _fschedule_broadcast)

# elemwise_mul
reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
reg.register_schedule("elemwise_mul", _fschedule_broadcast)

# elemwise_div
reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
reg.register_schedule("elemwise_div", _fschedule_broadcast)

# broadcast_add
reg.register_compute("broadcast_add",
lambda _, x: topi.broadcast_add(x[0], x[1]))
reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
reg.register_schedule("broadcast_add", _fschedule_broadcast)

# broadcast_sub
reg.register_compute("broadcast_sub",
lambda _, x: topi.broadcast_sub(x[0], x[1]))
reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
reg.register_schedule("broadcast_sub", _fschedule_broadcast)

# broadcast_mul
reg.register_compute("broadcast_mul",
lambda _, x: topi.broadcast_mul(x[0], x[1]))
reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
reg.register_schedule("broadcast_mul", _fschedule_broadcast)

# broadcast_div
reg.register_compute("broadcast_div",
lambda _, x: topi.broadcast_div(x[0], x[1]))
reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
reg.register_schedule("broadcast_div", _fschedule_broadcast)

# broadcast_to
@reg.register_compute("broadcast_to")
def compute_softmax(attrs, inputs, out_info):
    """Compute definition of broadcast_to

    Broadcasts inputs[0] to the shape carried by out_info[0]; the
    attributes are not consulted.
    """
    # NOTE(review): the function is named compute_softmax but is registered for
    # "broadcast_to" -- looks like a copy-paste leftover; consider renaming.
    return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast)
31 changes: 31 additions & 0 deletions nnvm/python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# pylint: disable=invalid-name, unused-argument
"""Tensor transformation ops"""
from __future__ import absolute_import

import tvm
from .tensor import _fschedule_broadcast
from ..compiler import registry as reg
from ..compiler import OpPattern

# Need to add reshape, transpose

def _flatten_index(indices, shape):
"""flatten the index to 1D"""
idx = 0
for i, value in enumerate(shape):
if i != 0:
idx *= value
idx = idx + indices[i]
return idx

# reshape
@reg.register_compute("reshape")
def compute_reshape(attrs, inputs, out_info):
    """Compute definition of reshape

    Only a 1-D inputs[0] is accepted (asserted below). The output shape is
    taken from out_info[0]; the attributes are not consulted.
    """
    # TODO(sxj) add support for general reshape
    assert len(inputs[0].shape) == 1, "Only support 1d input for now"
    oshape = out_info[0].shape
    x = inputs[0]
    # Each output coordinate is flattened to a single offset, which then
    # indexes the 1-D input.
    return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
reg.register_pattern("reshape", OpPattern.COMPLEX)
reg.register_schedule("reshape", _fschedule_broadcast)
17 changes: 14 additions & 3 deletions nnvm/src/compiler/graph_fuse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
if (inode.source->is_variable()) continue;
int root_id = group_vec[nid];
FuseEntry& fe = fuse_vec[root_id];
Array<Tensor> inputs;
Array<Tensor> inputs, out_info;
// input loading
for (const auto& e : inode.inputs) {
if (group_vec[e.node_id] != root_id) {
Expand All @@ -274,11 +274,21 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
inputs.push_back(t);
}
}
// output hint
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
Array<Expr> shape;
for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
shape.push_back(make_const(Int(32), x));
}
out_info.push_back(
placeholder(shape,
TVMType2Type(dltype_vec[idx.entry_id(nid, i)])));
}
// get default
Array<Tensor> out = fcompute[inode.source->op()](
inode.source->attrs, inputs);
inode.source->attrs, inputs, out_info);
CHECK_EQ(out.size(), inode.source->num_outputs());

// schedule on root node, and use master's schedule
if (nid != root_id) {
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
Expand Down Expand Up @@ -312,6 +322,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
}
}
}

tvm::runtime::Module module = fbuild(funcs, target);
// Final step: Remap the node, with given attribute
const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");
Expand Down
Loading

0 comments on commit 02a60d0

Please sign in to comment.