[LANG/SCHEDULE] Reduction factor, predicate in reduction. #77

Merged · 1 commit · Mar 29, 2017
2 changes: 2 additions & 0 deletions include/tvm/expr.h
@@ -40,6 +40,8 @@ using Halide::Internal::as_const_uint;
using Halide::Internal::const_true;
using Halide::Internal::const_false;
using Halide::Internal::is_no_op;
using Halide::likely;
using Halide::likely_if_innermost;

inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
2 changes: 1 addition & 1 deletion include/tvm/ir.h
@@ -41,7 +41,7 @@ struct Reduce : public ExprNode<Reduce> {
/*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src,
Array<IterVar> rdom,
Expr condition = make_const(Bool(1), true));
Expr condition = const_true());

void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
12 changes: 12 additions & 0 deletions include/tvm/schedule.h
@@ -210,6 +210,18 @@ class Schedule : public NodeRef {
* \return The created tensor.
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generates the new tensor with axis
* as the first dimension. The tensor's body will be rewritten as a reduction
* over the factored tensor.
*
* \param tensor The tensor to be factored.
* \param axis The reduction axis in tensor's schedule to be factored.
* \return The created factored tensor.
*/
Tensor rfactor(const Tensor& tensor,
const IterVar& axis);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
@@ -19,4 +19,4 @@

from ._base import TVMError
from .api import *
from .build import build
from .build import build, lower
21 changes: 15 additions & 6 deletions python/tvm/api.py
@@ -372,7 +372,7 @@ def reduce_axis(dom, name="rv"):
return _IterVar(dom, name, 2)


def sum(expr, axis):
def sum(expr, axis, where=None):
"""Create a sum expression over axis

Parameters
@@ -382,13 +382,16 @@ def sum(expr, axis):

axis : IterVar
The reduction IterVar axis

where : Expr, optional
Filtering predicate of the reduction.
"""
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Add", expr, axis)
x = _make.Reduce("Add", expr, axis, where)
return x


def min(lhs, rhs=None, axis=None):
def min(lhs, rhs=None, axis=None, where=None):
"""Create a min expression.

Parameters
@@ -401,6 +404,9 @@ def min(lhs, rhs=None, axis=None):

axis : IterVar, optional
The reduction IterVar axis

where : Expr, optional
Filtering predicate of the reduction.
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
@@ -409,11 +415,11 @@
if rhs:
return _make.Min(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
x = _make.Reduce("Min", expr, axis, where)
return x


def max(lhs, rhs=None, axis=None):
def max(lhs, rhs=None, axis=None, where=None):
"""Create a max expression.

Parameters
@@ -426,6 +432,9 @@ def max(lhs, rhs=None, axis=None):

axis : IterVar, optional
The reduction IterVar axis

where : Expr, optional
Filtering predicate of the reduction.
"""
if rhs and axis:
raise ValueError("Can only take one argument, rhs or axis")
@@ -434,7 +443,7 @@ def max(lhs, rhs=None, axis=None):
if rhs:
return _make.Max(lhs, rhs)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
x = _make.Reduce("Max", expr, axis, where)
return x


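The new where argument threads a predicate into the Reduce node, so only iterations that satisfy it contribute to the result. A minimal call-site sketch under the API shown in this diff; the tensor names and the predicate itself are illustrative, not taken from this PR:

    import tvm

    n = 128
    A = tvm.placeholder((n,), name="A")
    k = tvm.reduce_axis((0, n), name="k")
    # Sum only the positive entries of A; iterations that fail the
    # predicate are skipped by the IfThenElse guard emitted during lowering.
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(A[k] > 0)),
                    name="B")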
73 changes: 60 additions & 13 deletions python/tvm/build.py
@@ -9,16 +9,15 @@
from . import schedule
from . import expr
from . import ir_pass
from . import collections
from . import codegen

def build(sch,
def lower(sch,
args,
target,
target_host="stackvm",
name="default_function",
binds=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture.
"""Lowering step before build into target.

Parameters
----------
@@ -28,12 +27,6 @@ def build(sch,
args : list of Buffer or Tensor or Var
The argument lists to the function.

target : str
The target of the compilation.

target_host :
Host compilation target, if target is device.

name : str
The name of result function.

@@ -46,10 +39,8 @@

Returns
-------
f : Function, or pair of functions
f : LoweredFunc
The result function.
If the function requires host space allocation,
a pair of functions will be returned.
"""
binds = {} if binds is None else binds.copy()
arg_list = []
@@ -77,6 +68,62 @@ def build(sch,
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, 0)
return fapi



def build(sch,
args=None,
target="llvm",
target_host="stackvm",
name="default_function",
binds=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture.

Parameters
----------
sch : tvm.Schedule, or LoweredFunc
The schedule to be built

args : list of Buffer or Tensor or Var
The argument lists to the function.

target : str
The target of the compilation.

target_host : str
Host compilation target, if target is device.

name : str
The name of result function.

binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.

max_auto_unroll_step: int
Maximum step to perform automatic unrolling

Returns
-------
f : Function, or pair of functions
The result function.
"""
if isinstance(sch, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
fapi = lower(sch, args,
name=name,
binds=binds,
max_auto_unroll_step=max_auto_unroll_step)
elif isinstance(sch, collections.LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc")
fapi = sch
else:
raise ValueError("sch have to be Schedule or LoweredFunc")

fsplits = ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
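With build split this way, the intermediate LoweredFunc can be inspected before code generation, and build accepts either a Schedule or the LoweredFunc itself. A sketch under the signatures in this diff, reusing A and B from the where-sum sketch above; the tvm.Schedule constructor name is an assumption about this era of the API, not something this diff shows:

    # given A, B from the where-sum sketch above
    s = tvm.Schedule(B.op)  # constructor name assumed, not from this diff
    fapi = tvm.lower(s, [A, B], name="where_sum")
    f = tvm.build(fapi)  # args must stay None when passing a LoweredFunc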
23 changes: 21 additions & 2 deletions python/tvm/schedule.py
@@ -87,6 +87,27 @@ def cache_write(self, tensor, scope):
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)

def rfactor(self, tensor, axis):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.

This will create a new stage that generates the new tensor with axis
as the first dimension. The tensor's body will be rewritten as a reduction
over the factored tensor.

Parameters
----------
tensor : Tensor
The tensor to be factored.
axis : IterVar
The reduction axis in the schedule to be factored.

Returns
-------
tfactor : Tensor
The created factored tensor.
"""
return _api_internal._ScheduleRFactor(self, tensor, axis)


@register_node
class Stage(NodeBase):
@@ -114,8 +135,6 @@ def split(self, parent, factor=None, outer=None):
The inner variable of iteration.
"""
if outer is not None:
if outer.thread_tag == '':
raise ValueError("split by outer must have special thread_tag")
inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
else:
if factor is None:
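rfactor is typically paired with split: factor the inner chunk of the reduction into its own stage so the outer chunks can be computed independently. A sketch reusing s, B, and k from the sketches above; the split factor is arbitrary:

    ko, ki = s[B].split(k, factor=16)
    BF = s.rfactor(B, ki)  # ki becomes an explicit first axis of BF
    # BF holds the partial reduction over ko for each ki; B's body is
    # rewritten as a reduction of BF over that factored axis.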
2 changes: 1 addition & 1 deletion src/api/api_ir.cc
@@ -89,7 +89,7 @@ TVM_REGISTER_API(_make_Allocate)
*ret = Node::make(a, b); \
})

REGISTER_MAKE3(Reduce);
REGISTER_MAKE4(Reduce);
REGISTER_MAKE4(AttrStmt);

REGISTER_MAKE2(IntImm);
6 changes: 6 additions & 0 deletions src/api/api_lang.cc
@@ -318,4 +318,10 @@ TVM_REGISTER_API(_ScheduleCacheWrite)
.cache_write(args[1], args[2]);
});

TVM_REGISTER_API(_ScheduleRFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.rfactor(args[1], args[2]);
});

} // namespace tvm
2 changes: 1 addition & 1 deletion src/codegen/codegen_c.cc
@@ -526,8 +526,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type();
if (t.lanes() == 1) {
this->PrintIndent();
std::string value = this->PrintExpr(op->value);
this->PrintIndent();
this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
stream << " = " << value << ";\n";
} else {
5 changes: 4 additions & 1 deletion src/lang/ir.cc
@@ -28,7 +28,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->print(op->source);
p->stream << ", axis=" << op->axis;
if (!is_const(op->condition, 1)) {
p->stream << ", condition=" << op->condition;
p->stream << ", where=" << op->condition;
}
p->stream << ")";
});
@@ -45,6 +45,9 @@ Expr Reduce::make(std::string op, Expr source,
CHECK_EQ(axis[i]->iter_type, kCommReduce)
<< "Can only take axis created by reduce_axis";
}
if (!condition.defined()) {
condition = const_true();
}
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < axis.size(); ++i) {
6 changes: 0 additions & 6 deletions src/lang/operation.cc

This file was deleted.

36 changes: 6 additions & 30 deletions src/op/compute_op.cc
@@ -10,6 +10,7 @@
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
#include "../schedule/message_passing.h"

namespace tvm {

@@ -64,10 +65,7 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
args.push_back(axis.back()->var);
}

op_node->axis = Array<IterVar>(axis);
op_node->body = fcompute(args);
op_node->name = name;
return Operation(op_node).output(0);
return ComputeOpNode::make(name, axis, fcompute(args)).output(0);
}

Operation ComputeOpNode::make(std::string name,
@@ -191,6 +189,9 @@ void MakeReduction(const ComputeOpNode* op,
}
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
if (!is_one(reduce->condition)) {
*provide = IfThenElse::make(reduce->condition, *provide);
}
}

Stmt MakeProvide(const ComputeOpNode* op,
@@ -202,31 +203,6 @@ Stmt MakeProvide(const ComputeOpNode* op,
return Provide::make(t->op, t->value_index, op->body, args);
}

// message passing to find if IterVar is related to reduction.
void PassDownReduceFlag(const Stage& s,
std::unordered_map<IterVar, int>* p_state) {
auto& state = *p_state;
for (IterVarRelation rel : s->relations) {
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
int flag = state.at(s->parent);
state[s->outer] = flag;
state[s->inner] = flag;
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
int flag_outer = state.at(s->outer);
int flag_inner = state.at(s->inner);
state[s->fused] = flag_outer | flag_inner;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
int flag = state.at(s->parent);
state[s->rebased] = flag;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}

Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
Map<Var, Expr> temp;
@@ -267,7 +243,7 @@ Stmt ComputeOpNode::BuildProvide(
update_state[iv] = 1;
}
// find which iter var is related to reduction and which is related to axis.
PassDownReduceFlag(stage, &update_state);
schedule::PassDownBitMaskOr(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
std::unordered_map<IterVar, Expr> init_value_map;
// find the first loop that is related to reduction.
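The net effect of the guard added in MakeReduction, for the where-sum sketch above, can be pictured in plain Python (an illustration of the generated control flow, not actual lowered IR):

    # The init Provide zeroes the accumulator; the update Provide runs
    # only where the condition holds.
    A = [1.0, -2.0, 3.0, -4.0]
    acc = 0.0
    for kv in range(len(A)):
        if A[kv] > 0:  # reduce->condition supplied via where=
            acc += A[kv]
    assert acc == 4.0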