Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR/PASS] Formalize argument bind and match util #214

Merged
merged 2 commits into from
Jul 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ bool HasSideEffect(const Expr& e);
*/
bool ExprUseVar(const Expr& e, const Var& v);

/*!
* \brief Whether e expression used any var in variable set..
* \param e The expression to be checked.
* \param vset The variable set.
* \return Whether e uses vset.
*/
bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);

/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
Expand All @@ -77,6 +85,24 @@ Stmt ConvertSSA(Stmt stmt);
*/
Stmt CanonicalSimplify(Stmt stmt);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from . import collections as _collections
from ._ffi.function import _init_api


@register_node
class Buffer(NodeBase):
"""Symbolic data buffer in TVM.
Expand All @@ -24,16 +23,19 @@ class Buffer(NodeBase):
"""
pass


@register_node
class Split(NodeBase):
"""Split operation on axis."""
pass


@register_node
class Fuse(NodeBase):
"""Fuse operation on axis."""
pass


@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.
Expand Down
6 changes: 5 additions & 1 deletion src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ TVM_REGISTER_API("ir_pass.Equal")
}
});

TVM_REGISTER_API("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
});

TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
Expand Down Expand Up @@ -69,7 +74,6 @@ REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(ExprUseVar);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
Expand Down
6 changes: 1 addition & 5 deletions src/codegen/stack_vm/codegen_stack_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
if (t.is_int()) {
this->PushOp(op_int64);
} else if (t.is_uint()) {
if (t.bits() <= 32) {
this->PushOp(op_int64);
} else {
LOG(FATAL) << "Cannot handle uint64_t in StackVM";
}
this->PushOp(op_int64);
} else {
this->PushOp(StackVM::CodeI64ToF64(op_int64));
}
Expand Down
196 changes: 196 additions & 0 deletions src/pass/arg_binder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*!
* Copyright (c) 2017 by Contributors
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/device_api.h>
#include "./ir_util.h"
#include "./arg_binder.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

void BinderAddAssert(Expr cond,
const std::string& arg_name,
std::vector<Stmt>* asserts) {
cond = Simplify(cond);
if (is_zero(cond)) {
LOG(FATAL) << "Bind have an unmet assertion: "
<< cond << ", " << " on argument " << arg_name;
}
if (!is_one(cond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmt::make(cond, os.str()));
}
}

bool ArgBinder::Bind_(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.type(), value.type());
if (const Variable* v = arg.as<Variable>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
Var v_arg(arg.node_);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
} else {
(*def_map_)[v] = value;
}
return true;
} else {
BinderAddAssert(it->second == value, arg_name, &asserts_);
}
} else {
BinderAddAssert(arg == value, arg_name, &asserts_);
}
return false;
}

void ArgBinder::Bind(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_let) {
Bind_(arg, value, arg_name, with_let);
}

void ArgBinder::BindArray(const Array<Expr>& arg,
const Array<Expr>& value,
const std::string& arg_name) {
CHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch";
for (size_t i = 0; i < arg.size(); ++i) {
std::ostringstream os;
os << arg_name << "[" << i << "]";
this->Bind(arg[i], value[i], os.str());
}
}

void ArgBinder::BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name) {
CHECK_EQ(arg->scope, value->scope)
<< "Argument " << arg_name
<< " Buffer bind scope mismatch";
this->Bind(arg->data, value->data, arg_name + ".data");
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset");
}

inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}

inline Stmt AssertNull(Var handle, std::string msg) {
return AssertStmt::make(Call::make(
Bool(1), intrinsic::tvm_handle_is_null,
{handle}, Call::PureIntrinsic), msg);
}

void ArgBinder::BindDLTensor(const Buffer& buffer,
const Expr& device_type,
const Expr& device_id,
const Var& handle,
const std::string& arg_name) {
const Type tvm_shape_type = TVMShapeIndexType();
const Type tvm_ndim_type = Int(32);
const Stmt nop = Evaluate::make(0);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
Expr a_ndim = make_const(tvm_ndim_type,
static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str()));
// type checks
Type dtype = buffer->dtype;
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeCode) ==
UIntImm::make(UInt(8), dtype.code()) &&
TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeBits) ==
UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// data field
if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, make_const(buffer->dtype, 0));
// mark alignment of external bufs
init_nest_.emplace_back(AttrStmt::make(
vptr, ir::attr::storage_alignment,
IntImm::make(Int(32), runtime::kAllocAlignment), nop));
}

Var v_shape(arg_name + ".shape", Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
Bind_(buffer->shape[k],
cast(buffer->shape[k].type(),
Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k), const_true(1))),
field_name.str(), true);
}
// strides field
Var v_strides(arg_name + ".strides", Handle());
def_handle_dtype_.Set(v_strides, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides),
nop));
if (buffer->strides.size() == 0) {
std::ostringstream stride_err_msg;
stride_err_msg << arg_name << ".strides:"
<< " expected to be nullptr for contiguous array";
init_nest_.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
} else {
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
Bind_(buffer->strides[k],
cast(buffer->shape[k].type(),
Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k), const_true(1))),
field_name.str(), true);
}
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
int64_t const_offset;
if (arith::GetConst(buffer->elem_offset, &const_offset)) {
Bind_(make_const(UInt(64), const_offset * data_bytes),
TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
Bind_(buffer->elem_offset,
cast(buffer->elem_offset.type(),
(TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))),
arg_name + ".elem_offset", true);
}
// device info.
Bind_(device_type,
TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceType),
arg_name + ".device_type", true);
Bind_(device_id,
TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceId),
arg_name + ".device_id", true);
}

} // namespace ir
} // namespace tvm
Loading