Skip to content

Commit

Permalink
new Buffer (apache#57)
Browse files Browse the repository at this point in the history
* new `Buffer`

* fix `BufferProxy` args

* apply code review suggestions
  • Loading branch information
cyx-6 authored Jul 1, 2022
1 parent e9c5662 commit 541ac01
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 132 deletions.
119 changes: 111 additions & 8 deletions python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,127 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Buffer"""
from tvm import tir

from tvm.ir import PrimExpr, Array, Range
from tvm.tir import Var, IntImm, BufferLoad, BufferRegion
from tvm._ffi import register_object as _register_object
from tvm.runtime import Object, DataType
from . import _ffi_api


@_register_object("script.builder.tir.Buffer")
class Buffer_(Object):
def __init__(
self,
shape,
dtype="float32",
name="buffer",
data=None,
strides=None,
elem_offset=None,
scope="",
data_alignment=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.Buffer,
shape,
dtype,
name,
data,
strides,
elem_offset,
scope,
data_alignment,
offset_factor,
buffer_type,
axis_separators,
)

@property
def data(self) -> Var:
return self.buffer.data

@property
def dtype(self) -> DataType:
return self.buffer.dtype

@property
def shape(self) -> Array:
return self.buffer.shape

@property
def axis_separators(self) -> Array:
return self.buffer.axis_separators

@property
def strides(self) -> Array:
return self.buffer.strides

@property
def elem_offset(self) -> PrimExpr:
return self.buffer.elem_offset

@property
def name(self) -> str:
return self.buffer.name

@property
def data_alignment(self) -> int:
return self.buffer.data_alignment

@property
def offset_factor(self) -> int:
return self.buffer.offset_factor

@property
def buffer_type(self) -> int:
return self.buffer.buffer_type

def __getitem__(self, indices):
if any(isinstance(index, slice) for index in indices):
region = []
for index in indices:
if isinstance(index, slice):
region.append(Range(index.start, index.stop))
else:
region.append(Range.from_min_extent(index, 1))
return BufferRegion(self.buffer, region)
else:
return BufferLoad(self.buffer, indices)


class BufferProxy:
def __call__(
self,
shape,
dtype="float32",
*,
storage_scope="",
) -> tir.Buffer:
return _ffi_api.Buffer( # pylint: disable=no-member # type: ignore
shape, dtype, "", storage_scope
name="buffer",
data=None,
strides=None,
elem_offset=None,
scope="",
data_alignment=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Buffer_:
return Buffer_(
shape,
dtype,
name,
data,
strides,
elem_offset,
scope,
data_alignment,
offset_factor,
buffer_type,
axis_separators,
)

def __getitem__(self, keys) -> tir.Buffer:
def __getitem__(self, keys) -> Buffer_:
return self(*keys) # pylint: disable=no-member # type: ignore


Expand Down
22 changes: 11 additions & 11 deletions src/script/builder/tir/block_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

#include "./for_frame.h"
#include "./utils.h"
#include "./var.h"

namespace tvm {
namespace script {
Expand All @@ -47,17 +46,20 @@ BlockFrame Block(String name, bool no_realize) {
}

void BlockFrameNode::ExitWithScope() {
using namespace tvm::tir;
TIRFrameNode::ExitWithScope();
tvm::tir::Block block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers,
Array<tvm::tir::Buffer> tir_alloc_buffers;
for (const Buffer& buffer : alloc_buffers) {
tir_alloc_buffers.push_back(buffer->buffer);
}
tvm::tir::Block block(iter_vars, reads, writes, name, AsStmt(stmts), init, tir_alloc_buffers,
match_buffers, annotations);
if (no_realize) {
CHECK(iter_values.empty())
<< "ValueError: Block bindings are not allowed when `no_realize=True`";
CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
AddToParent(block);
} else {
AddToParent(BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
}
}

Expand Down Expand Up @@ -142,13 +144,11 @@ void BlockAttrs(Map<String, ObjectRef> attrs) {
frame->annotations = attrs;
}

tvm::tir::Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<tvm::tir::Var> data,
Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope,
int align, int offset_factor, String buffer_type_str,
Array<IntImm> axis_separators) {
using namespace tvm::tir;
tvm::tir::Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope,
align, offset_factor, buffer_type_str, axis_separators);
Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<tvm::tir::Var> data,
Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
Buffer buffer(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor,
buffer_type_str, axis_separators);
BlockFrame frame = FindBlockFrame("T.alloc_buffer");
frame->alloc_buffers.push_back(buffer);
return buffer;
Expand Down
14 changes: 7 additions & 7 deletions src/script/builder/tir/block_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_

#include "./base.h"
#include "./var.h"

namespace tvm {
namespace script {
Expand All @@ -33,7 +34,7 @@ class BlockFrameNode : public TIRFrameNode {
Array<tvm::tir::BufferRegion> reads;
Array<tvm::tir::BufferRegion> writes;
Optional<tvm::tir::Stmt> init;
Array<tvm::tir::Buffer> alloc_buffers;
Array<Buffer> alloc_buffers;
Array<tvm::tir::MatchBufferRegion> match_buffers;
Map<String, ObjectRef> annotations;

Expand Down Expand Up @@ -93,12 +94,11 @@ void Where(PrimExpr predicate);
void Reads(Array<ObjectRef> buffer_slices);
void Writes(Array<ObjectRef> buffer_slices);
void BlockAttrs(Map<String, ObjectRef> attrs);
tvm::tir::Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
Optional<tvm::tir::Var> data = NullOpt, Array<PrimExpr> strides = {},
PrimExpr elem_offset = PrimExpr(), String storage_scope = "",
int align = -1, int offset_factor = 0,
String buffer_type_str = "default",
Array<IntImm> axis_separators = {});
Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
Optional<tvm::tir::Var> data = NullOpt, Array<PrimExpr> strides = {},
PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
int offset_factor = 0, String buffer_type_str = "default",
Array<IntImm> axis_separators = {});

namespace axis {
tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
Expand Down
68 changes: 36 additions & 32 deletions src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,34 @@ void PrimFuncFrameNode::EnterWithScope() {
}

void PrimFuncFrameNode::ExitWithScope() {
using namespace tvm::tir;
using ir::IRModuleFrame;
root_block_frame->ExitWithScope();
TIRFrameNode::ExitWithScope();
Builder builder = Builder::Current();
if (!(stmts.size() == 1 && stmts[0]->IsInstance<BlockRealizeNode>())) {
if (!(stmts.size() == 1 && stmts[0]->IsInstance<tvm::tir::BlockRealizeNode>())) {
LOG(FATAL) << "ValueError: PrimFuncFrame shoulde have one and only one root block.";
}
BlockRealize root_block_realize = Downcast<BlockRealize>(stmts[0]);
tvm::tir::BlockRealize root_block_realize = Downcast<tvm::tir::BlockRealize>(stmts[0]);
tvm::tir::Block root_block = root_block_realize->block;
// remove redundant implicit root block
if (root_block->alloc_buffers.empty() && root_block->body->IsInstance<BlockRealizeNode>() &&
if (root_block->alloc_buffers.empty() &&
root_block->body->IsInstance<tvm::tir::BlockRealizeNode>() &&
root_block->annotations.empty() && root_block->reads.empty() && root_block->writes.empty()) {
stmts.Set(0, root_block->body);
}
Map<tvm::tir::Var, tvm::tir::Buffer> tir_buffer_map;
Map<tvm::tir::Var, tvm::tir::Buffer> tir_preflattened_buffer_map;
for (auto const& p : buffer_map) {
tir_buffer_map.Set(p.first, p.second->buffer);
}
for (auto const& p : preflattened_buffer_map) {
tir_preflattened_buffer_map.Set(p.first, p.second->buffer);
}
tvm::tir::PrimFunc func(/*params=*/args,
/*body=*/AsStmt(stmts),
/*ret_type=*/ret_type.value_or(TupleType::Empty()),
/*buffer_map=*/buffer_map,
/*preflattened_buffer_map=*/preflattened_buffer_map,
/*buffer_map=*/tir_buffer_map,
/*preflattened_buffer_map=*/tir_preflattened_buffer_map,
/*attrs=*/DictAttrs(attrs));
if (builder->frames.empty()) {
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
Expand Down Expand Up @@ -105,11 +113,10 @@ tvm::tir::Var Arg(String name, tvm::tir::Var var) {
return var;
}

tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) {
using namespace tvm::tir;
Buffer Arg(String name, Buffer buffer) {
PrimFuncFrame frame = FindPrimFuncFrame("T.Arg");
Namer::Name(buffer, name);
Var handle(buffer->name + "_handle", DataType::Handle());
tvm::tir::Var handle(buffer->buffer->name + "_handle", DataType::Handle());
frame->args.push_back(handle);
frame->buffer_map.Set(handle, buffer);
return buffer;
Expand Down Expand Up @@ -142,51 +149,48 @@ tvm::Type FuncRet(tvm::Type ret_type) {
return ret_type;
}

tvm::tir::Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype,
Optional<tvm::tir::Var> data, Array<PrimExpr> strides,
PrimExpr elem_offset, String storage_scope, int align,
int offset_factor, String buffer_type_str,
Array<IntImm> axis_separators) {
using namespace tvm::tir;
tvm::tir::Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope,
align, offset_factor, buffer_type_str, axis_separators);
if (const auto* var = param.as<VarNode>()) {
Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype,
Optional<tvm::tir::Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
String storage_scope, int align, int offset_factor, String buffer_type_str,
Array<IntImm> axis_separators) {
Buffer buffer(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor,
buffer_type_str, axis_separators);
if (const auto* var = param.as<tvm::tir::VarNode>()) {
PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer");
Var v = GetRef<Var>(var);
tvm::tir::Var v = GetRef<tvm::tir::Var>(var);
for (auto const& arg : frame->args) {
if (arg.same_as(v)) {
frame->buffer_map.Set(v, buffer);
return buffer;
}
}
LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
} else if (const auto* buffer_region = param.as<BufferRegionNode>()) {
} else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) {
BlockFrame frame = FindBlockFrame("T.match_buffer");
frame->match_buffers.push_back(MatchBufferRegion(buffer, GetRef<BufferRegion>(buffer_region)));
frame->match_buffers.push_back(
tvm::tir::MatchBufferRegion(buffer->buffer, GetRef<tvm::tir::BufferRegion>(buffer_region)));
} else {
LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
}
return buffer;
};

void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array<PrimExpr> shape,
DataType dtype, Optional<tvm::tir::Var> data, Array<PrimExpr> strides,
PrimExpr elem_offset, String storage_scope, int align, int offset_factor,
String buffer_type_str, Array<IntImm> axis_separators) {
using namespace tvm::tir;
void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape, DataType dtype,
Optional<tvm::tir::Var> data, Array<PrimExpr> strides, PrimExpr elem_offset,
String storage_scope, int align, int offset_factor, String buffer_type_str,
Array<IntImm> axis_separators) {
PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer");
for (auto const& p : frame->buffer_map) {
if (p.second.same_as(postflattened_buffer)) {
String buffer_name(postflattened_buffer->name + "_preflatten");
tvm::tir::Buffer buffer =
DeclBuffer(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, align,
offset_factor, buffer_type_str, axis_separators);
String buffer_name(postflattened_buffer->buffer->name + "_preflatten");
Buffer buffer(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, align,
offset_factor, buffer_type_str, axis_separators);
Namer::Name(buffer, buffer_name);
frame->preflattened_buffer_map.Set(p.first, buffer);
return;
}
}
LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->buffer->name
<< " does not exist.";
};

Expand All @@ -199,7 +203,7 @@ TVM_REGISTER_GLOBAL("script.builder.tir.Arg")
return Arg(name, GetRef<tvm::tir::Var>(var));
}
if (const auto* buffer = obj.as<BufferNode>()) {
return Arg(name, GetRef<tvm::tir::Buffer>(buffer));
return Arg(name, GetRef<Buffer>(buffer));
}
LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey();
throw;
Expand Down
Loading

0 comments on commit 541ac01

Please sign in to comment.