Skip to content

Commit

Permalink
[TIR][BUILD] Remove buffer params from pass config.
Browse files Browse the repository at this point in the history
Buffer configurations can be passed during construction
and does not need to be part of the build config.

This is a refactor step to simplify the BuildConfig for the PassContext migration.
  • Loading branch information
tqchen committed May 22, 2020
1 parent e55f9ff commit 19d6680
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 102 deletions.
17 changes: 0 additions & 17 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,6 @@ TVM_DLL Target hexagon(const std::vector<std::string>& options = std::vector<std
*/
class BuildConfigNode : public Object {
public:
/*!
* \brief The data alignment to use when constructing buffers. If this is set to
* -1, then TVM's internal default will be used
*/
int data_alignment = -1;
/*!
* \brief The offset factor to use when constructing buffers. If this is set to
* 0, then the offset field is not used.
*/
int offset_factor = 0;

/*!
* \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
* done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
Expand Down Expand Up @@ -217,9 +206,6 @@ class BuildConfigNode : public Object {
/*! \brief List of passes to be injected into the low-level pipeline. */
std::vector<std::pair<int, transform::Pass>> add_lower_pass;

/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;

/*! \brief Whether to instrument loads and stores with check for out of the bounds. */
bool instrument_bound_checkers = false;

Expand All @@ -233,8 +219,6 @@ class BuildConfigNode : public Object {
bool disable_assert = false;

void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
Expand All @@ -243,7 +227,6 @@ class BuildConfigNode : public Object {
v->Visit("restricted_func", &restricted_func);
v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
Expand Down
6 changes: 0 additions & 6 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def get_binds(args, compact=False, binds=None):
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
cfg = BuildConfig.current()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
Expand All @@ -66,9 +65,6 @@ def get_binds(args, compact=False, binds=None):
buf = tvm.tir.decl_buffer(
x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
buffer_type=buffer_type)
binds[x] = buf
arg_list.append(buf)
Expand Down Expand Up @@ -157,8 +153,6 @@ def lower(sch,
"""
cfg = BuildConfig.current()
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
if cfg.dump_pass_ir:
add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
Expand Down
13 changes: 0 additions & 13 deletions python/tvm/target/build_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,8 @@ class BuildConfig(Object):
"unroll_explicit": True,
"detect_global_barrier": False,
"partition_const_loop": False,
"offset_factor": 0,
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": 1,
"dump_pass_ir": False,
"instrument_bound_checkers": False,
"disable_select_rewriting": False,
"disable_vectorize": False,
Expand Down Expand Up @@ -129,14 +126,6 @@ def build_config(**kwargs):
partition_const_loop: bool, default=False
Whether partition const loop
data_alignment: int, optional
The alignment of data pointer in bytes.
If -1 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, default=0
The factor used in default buffer declaration.
If specified as 0, offset field is not used.
restricted_func: bool, default=True
Whether build restricted function.
That is each buffer argument to the function are guaranteed
Expand All @@ -152,8 +141,6 @@ def build_config(**kwargs):
phase contains an integer on which optimization pass we apply the pass.
Additional lowering passes to be applied before make_api.
dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False
Returns
-------
config: BuildConfig
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/te/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from tvm.runtime import Object, convert
from tvm.ir import Range
from tvm.target import BuildConfig
from .tensor import PlaceholderOp

from . import tensor as _tensor
Expand Down Expand Up @@ -68,7 +67,9 @@ def __call__(self, *args, **kwargs):
def decl_tensor_intrin(op,
fcompute,
name="tensor_intrin",
binds=None, scalar_params=None):
binds=None,
scalar_params=None,
default_buffer_params=None):
"""Declare a tensor intrinsic function.
Parameters
Expand Down Expand Up @@ -104,6 +105,9 @@ def decl_tensor_intrin(op,
scalar_params: a list of variables used by op, whose values will be passed
as scalar_inputs when the tensor intrinsic is called.
default_buffer_params: Optional[dict]
Dictionary of buffer arguments to be passed when constructing a buffer.
Returns
-------
intrin: TensorIntrin
Expand All @@ -122,12 +126,11 @@ def decl_tensor_intrin(op,
if not isinstance(t.op, PlaceholderOp):
raise ValueError("Do not yet support composition op")

cfg = BuildConfig.current()
default_buffer_params = {} if default_buffer_params is None else default_buffer_params
for t in tensors:
buf = (binds[t] if t in binds else
tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor))
**default_buffer_params))
binds_list.append(buf)

if scalar_params:
Expand Down
3 changes: 1 addition & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,

for (const auto& x : args) {
if (out_binds->find(x) == out_binds->end()) {
auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, config->data_alignment,
config->offset_factor, compact);
auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact);
out_binds->Set(x, buf);
out_arg_list->push_back(buf);
} else {
Expand Down
3 changes: 0 additions & 3 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
p->stream << "offset_factor=" << op->offset_factor << ", ";
p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
Expand All @@ -367,7 +365,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

namespace tvm {
namespace tir {
// TODO(tqchen): change to floormod/div

using IndexMod = tir::FloorModNode;
using IndexDiv = tir::FloorDivNode;

Expand Down
26 changes: 13 additions & 13 deletions tests/python/unittest/test_te_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def test_fuse_with_split():
assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (xo, fused)

@pytest.mark.xfail
def test_fuse_with_out_of_order_axis():
m = te.size_var('m')
n = te.size_var('n')
Expand All @@ -125,9 +124,10 @@ def test_fuse_with_out_of_order_axis():
s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
fused = s[T].fuse(xo, y) # should throw here

@pytest.mark.xfail
with pytest.raises(RuntimeError):
fused = s[T].fuse(xo, y) # should throw here

def test_fuse_with_out_of_order_axis_with_reorder():
m = te.size_var('m')
n = te.size_var('n')
Expand All @@ -144,23 +144,21 @@ def test_fuse_with_out_of_order_axis_with_reorder():
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
s[T].reorder(y, xo, xi)
fused = s[T].fuse(y, xi) # should throw here

with pytest.raises(RuntimeError):
fused = s[T].fuse(y, xi) # should throw here

def test_singleton():
print("test singleton")
A = te.placeholder((), name='A')
T = te.compute((), lambda : A() + 1)
s = te.create_schedule(T.op)
print("test singleton fin1")
fused = s[T].fuse()
assert any(isinstance(x, tvm.te.schedule.Singleton) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused,)
dump = pkl.dumps(s)
print("test singleton fin3")
s_loaded = pkl.loads(dump)
print("test singleton fin2")
assert isinstance(s_loaded, tvm.te.schedule.Schedule)
print("test singleton fin")


def test_vectorize():
m = te.size_var('m')
Expand All @@ -177,13 +175,14 @@ def test_vectorize():
assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE

@pytest.mark.xfail

def test_vectorize_commreduce():
V = te.placeholder((128,), name='V')
ax = te.reduce_axis((0, 128), name='ax')
O = te.compute((1,), lambda _: te.sum(V[ax], axis=[ax]))
s = te.create_schedule(O.op)
s[O].vectorize(ax) # should throw here
with pytest.raises(RuntimeError):
s[O].vectorize(ax) # should throw here

def test_pragma():
m = 100
Expand Down Expand Up @@ -271,8 +270,9 @@ def intrin_func(ins, outs, sp):
assert(sp[1] == w)
return tvm.tir.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])

with tvm.target.build_config(offset_factor=1):
intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w], default_buffer_params={
"offset_factor": 1
})
assert intrin.op == z.op
assert intrin.reduce_init is None
assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_te_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,9 @@ def intrin_func(ins, outs):
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, reset, update

with tvm.target.build_config(data_alignment=16,
offset_factor=16):
return te.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
buffer_params = {"data_alignment": 16, "offset_factor": 16}
return te.decl_tensor_intrin(
z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)


def test_schedule_tensor_compute1():
Expand Down Expand Up @@ -377,8 +376,9 @@ def intrin_func(ins, outs):
ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
return ib.get()

with tvm.target.build_config(offset_factor=16):
return te.decl_tensor_intrin(z.op, intrin_func, binds=binds)
return te.decl_tensor_intrin(z.op, intrin_func, binds=binds, default_buffer_params={
"offset_factor": 16
})


def test_schedule_tensor_compute2():
Expand Down
39 changes: 17 additions & 22 deletions tests/python/unittest/test_te_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def intrin_func(ins, outs):
xx, yy = ins
zz = outs[0]
return tvm.tir.call_packed("vadd", xx, yy, zz)
with tvm.target.build_config(offset_factor=16):
return te.decl_tensor_intrin(z.op, intrin_func)
buffer_params = {"offset_factor": 16}
return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params=buffer_params)

def intrin_gemv(m, n):
w = te.placeholder((m, n), name='w')
Expand All @@ -52,10 +52,9 @@ def intrin_func(ins, outs):
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, reset, update

with tvm.target.build_config(data_alignment=16,
offset_factor=16):
return te.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
buffer_params = {"offset_factor": 16, "data_alignment": 16}
return te.decl_tensor_intrin(
z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)

def intrin_gemv_no_reset(m, n):
w = te.placeholder((m, n), name='w')
Expand All @@ -79,10 +78,10 @@ def intrin_func(ins, outs):
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, None, update

with tvm.target.build_config(data_alignment=16,
offset_factor=16):
return te.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})

buffer_params = {"offset_factor": 16, "data_alignment": 16}
return te.decl_tensor_intrin(
z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)


def test_tensorize_vadd():
Expand Down Expand Up @@ -248,8 +247,9 @@ def intrin_func(ins, outs):
zz = outs[0]
return tvm.tir.call_packed("op", xx, zz)

with tvm.target.build_config(offset_factor=2):
return te.decl_tensor_intrin(y.op, intrin_func)
return te.decl_tensor_intrin(y.op, intrin_func, default_buffer_params={
"offset_factor": 2
})

A = te.placeholder((5, 5), name='A')
B = te.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
Expand Down Expand Up @@ -286,8 +286,7 @@ def intrin_multivadd(n):
def intrin_func(ins, outs):
return tvm.tir.call_packed("multivadd")

with tvm.target.build_config():
return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd")

def intrin_vadd(n):
dtype = 'float32'
Expand All @@ -297,21 +296,17 @@ def intrin_vadd(n):
s = te.create_schedule(z.op)

def create_buffer(t):
return tvm.tir.decl_buffer(t.shape, t.dtype,
name='W'+t.name,
offset_factor=16)
return tvm.tir.decl_buffer(t.shape, t.dtype, name='W'+t.name, offset_factor=16)

def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
ib.emit(tvm.tir.call_extern("float32", 'vadd',
ins[0].access_ptr("r"), ins[1].access_ptr('r'),
outs[0].access_ptr('wr')))
return ib.get()

with tvm.target.build_config(offset_factor=16):
return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
y: create_buffer(y),
z: create_buffer(z)})
return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
y: create_buffer(y),
z: create_buffer(z)})

# cache_read, cache_write
M = 1024
Expand Down
Loading

0 comments on commit 19d6680

Please sign in to comment.