Skip to content

Commit

Permalink
[REFACTOR/PASS] Formalize argument bind and match util (#214)
Browse files Browse the repository at this point in the history
* [REFACTOR/PASS] Formalize argument bind and match util

* grammar
  • Loading branch information
tqchen authored Jul 4, 2017
1 parent 3c19159 commit 4bb3c35
Show file tree
Hide file tree
Showing 11 changed files with 549 additions and 300 deletions.
26 changes: 26 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ bool HasSideEffect(const Expr& e);
*/
bool ExprUseVar(const Expr& e, const Var& v);

/*!
* \brief Whether e expression used any var in variable set..
* \param e The expression to be checked.
* \param vset The variable set.
* \return Whether e uses vset.
*/
bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);

/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
Expand All @@ -77,6 +85,24 @@ Stmt ConvertSSA(Stmt stmt);
*/
Stmt CanonicalSimplify(Stmt stmt);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from . import collections as _collections
from ._ffi.function import _init_api


@register_node
class Buffer(NodeBase):
"""Symbolic data buffer in TVM.
Expand All @@ -24,16 +23,19 @@ class Buffer(NodeBase):
"""
pass


@register_node
class Split(NodeBase):
"""Split operation on axis."""
pass


@register_node
class Fuse(NodeBase):
"""Fuse operation on axis."""
pass


@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.
Expand Down
6 changes: 5 additions & 1 deletion src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ TVM_REGISTER_API("ir_pass.Equal")
}
});

TVM_REGISTER_API("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
});

TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
Expand Down Expand Up @@ -69,7 +74,6 @@ REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(ExprUseVar);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
Expand Down
6 changes: 1 addition & 5 deletions src/codegen/stack_vm/codegen_stack_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
if (t.is_int()) {
this->PushOp(op_int64);
} else if (t.is_uint()) {
if (t.bits() <= 32) {
this->PushOp(op_int64);
} else {
LOG(FATAL) << "Cannot handle uint64_t in StackVM";
}
this->PushOp(op_int64);
} else {
this->PushOp(StackVM::CodeI64ToF64(op_int64));
}
Expand Down
196 changes: 196 additions & 0 deletions src/pass/arg_binder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*!
* Copyright (c) 2017 by Contributors
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/device_api.h>
#include "./ir_util.h"
#include "./arg_binder.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

void BinderAddAssert(Expr cond,
const std::string& arg_name,
std::vector<Stmt>* asserts) {
cond = Simplify(cond);
if (is_zero(cond)) {
LOG(FATAL) << "Bind have an unmet assertion: "
<< cond << ", " << " on argument " << arg_name;
}
if (!is_one(cond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmt::make(cond, os.str()));
}
}

bool ArgBinder::Bind_(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.type(), value.type());
if (const Variable* v = arg.as<Variable>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
Var v_arg(arg.node_);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
} else {
(*def_map_)[v] = value;
}
return true;
} else {
BinderAddAssert(it->second == value, arg_name, &asserts_);
}
} else {
BinderAddAssert(arg == value, arg_name, &asserts_);
}
return false;
}

void ArgBinder::Bind(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_let) {
Bind_(arg, value, arg_name, with_let);
}

void ArgBinder::BindArray(const Array<Expr>& arg,
const Array<Expr>& value,
const std::string& arg_name) {
CHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch";
for (size_t i = 0; i < arg.size(); ++i) {
std::ostringstream os;
os << arg_name << "[" << i << "]";
this->Bind(arg[i], value[i], os.str());
}
}

void ArgBinder::BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name) {
CHECK_EQ(arg->scope, value->scope)
<< "Argument " << arg_name
<< " Buffer bind scope mismatch";
this->Bind(arg->data, value->data, arg_name + ".data");
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset");
}

inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}

inline Stmt AssertNull(Var handle, std::string msg) {
return AssertStmt::make(Call::make(
Bool(1), intrinsic::tvm_handle_is_null,
{handle}, Call::PureIntrinsic), msg);
}

void ArgBinder::BindDLTensor(const Buffer& buffer,
const Expr& device_type,
const Expr& device_id,
const Var& handle,
const std::string& arg_name) {
const Type tvm_shape_type = TVMShapeIndexType();
const Type tvm_ndim_type = Int(32);
const Stmt nop = Evaluate::make(0);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
Expr a_ndim = make_const(tvm_ndim_type,
static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str()));
// type checks
Type dtype = buffer->dtype;
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeCode) ==
UIntImm::make(UInt(8), dtype.code()) &&
TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeBits) ==
UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// data field
if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, make_const(buffer->dtype, 0));
// mark alignment of external bufs
init_nest_.emplace_back(AttrStmt::make(
vptr, ir::attr::storage_alignment,
IntImm::make(Int(32), runtime::kAllocAlignment), nop));
}

Var v_shape(arg_name + ".shape", Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
Bind_(buffer->shape[k],
cast(buffer->shape[k].type(),
Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k), const_true(1))),
field_name.str(), true);
}
// strides field
Var v_strides(arg_name + ".strides", Handle());
def_handle_dtype_.Set(v_strides, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides),
nop));
if (buffer->strides.size() == 0) {
std::ostringstream stride_err_msg;
stride_err_msg << arg_name << ".strides:"
<< " expected to be nullptr for contiguous array";
init_nest_.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
} else {
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
Bind_(buffer->strides[k],
cast(buffer->shape[k].type(),
Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k), const_true(1))),
field_name.str(), true);
}
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
int64_t const_offset;
if (arith::GetConst(buffer->elem_offset, &const_offset)) {
Bind_(make_const(UInt(64), const_offset * data_bytes),
TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
Bind_(buffer->elem_offset,
cast(buffer->elem_offset.type(),
(TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))),
arg_name + ".elem_offset", true);
}
// device info.
Bind_(device_type,
TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceType),
arg_name + ".device_type", true);
Bind_(device_id,
TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceId),
arg_name + ".device_id", true);
}

} // namespace ir
} // namespace tvm
Loading

0 comments on commit 4bb3c35

Please sign in to comment.