Skip to content

Commit

Permalink
prim_func methods (apache#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 authored and junrushao committed Jun 25, 2022
1 parent 01f7bbc commit ad73df3
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 32 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@
unroll,
vectorized,
)
from .prim_func_frame import arg, prim_func
from .prim_func_frame import arg, func_attr, func_ret, prim_func, match_buffer, preflattened_buffer
from .var import Buffer
from .op import *
4 changes: 4 additions & 0 deletions python/tvm/script/builder/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def float64(expr):
return _ffi_api.PrimType("float64", expr)


def handle():
return _ffi_api.Handle()


def min(a, b, span=None):
"""Compute the minimum value of two expressions.
Expand Down
71 changes: 70 additions & 1 deletion python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Prim Func Frame"""
from typing import Union
from typing import Union, Dict, Any

from tvm._ffi import register_object as _register_object
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
from tvm.ir import Type

from ..builder import Builder
from . import _ffi_api
Expand All @@ -40,3 +41,71 @@ def arg(name, obj) -> Union[Var, Buffer]:


setattr(prim_func, "dispatch_token", "tir")


def func_attr(attrs: Dict[str, Any]) -> None:
return _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore


def func_ret(ret_type) -> Type:
return _ffi_api.FuncRet(ret_type) # pylint: disable=no-member # type: ignore


def match_buffer(
param,
shape,
dtype="float32",
data=None,
strides=[],
elem_offset=None,
storage_scope="",
align=-1,
offset_factor=0,
buffer_type="default",
axis_separators=None,
span=None,
) -> Buffer:
return _ffi_api.MatchBuffer(
param,
shape,
dtype,
data,
strides,
elem_offset,
storage_scope,
align,
offset_factor,
buffer_type,
axis_separators,
span,
)


def preflattened_buffer(
postflattened,
shape,
dtype="float32",
data=None,
strides=[],
elem_offset=None,
storage_scope="",
align=-1,
offset_factor=0,
buffer_type="default",
axis_separators=None,
span=None,
) -> None:
_ffi_api.PreflattenedBuffer(
postflattened,
shape,
dtype,
data,
strides,
elem_offset,
storage_scope,
align,
offset_factor,
buffer_type,
axis_separators,
span,
)
2 changes: 1 addition & 1 deletion python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

def Buffer( # pylint: disable=invalid-name
shape,
dtype,
dtype="float32",
name="buffer",
storage_scope="",
) -> tir.Buffer:
Expand Down
27 changes: 4 additions & 23 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,36 +44,17 @@

from .function import PrimFunc, TensorIntrin, IndexMap

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import (
call_llvm_intrin,
call_llvm_pure_intrin,
ret,
all,
any,
min_value,
max_value,
trace,
)
from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import (
trunc,
abs,
round,
nextafter,
nearbyint,
power,
popcount,
fmod,
if_then_else,
)
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum, infinity, reinterpret
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
1 change: 1 addition & 0 deletions src/script/builder/tir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ PrimExpr prim_type(String type_name, PrimExpr expr) {
}

TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(prim_type);
TVM_REGISTER_GLOBAL("script.builder.tir.Handle").set_body_typed(handle);
TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
return tvm::min(a, b, span);
});
Expand Down
2 changes: 2 additions & 0 deletions src/script/builder/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ PrimExpr bool_(PrimExpr expr) { return cast(DataType::Bool(), expr); }

PrimExpr prim_type(String type_name, PrimExpr expr);

tvm::tir::Var handle() { return tvm::tir::Var("", DataType::Handle()); }

using tvm::cast;
using tvm::if_then_else;
using tvm::infinity;
Expand Down
85 changes: 84 additions & 1 deletion src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

#include <tvm/tir/function.h>

#include "./block_frame.h"

namespace tvm {
namespace script {
namespace builder {
Expand All @@ -33,7 +35,9 @@ void PrimFuncFrameNode::ExitWithScope() {
PrimFunc func(/*params=*/args,
/*body=*/AsStmt(stmts),
/*ret_type=*/ret_type,
/*buffer_map=*/buffer_map);
/*buffer_map=*/buffer_map,
/*preflattened_buffer_map=*/preflattened_buffer_map,
/*attrs=*/DictAttrs(attrs));
if (builder->frames.empty()) {
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = func;
Expand All @@ -52,6 +56,8 @@ PrimFuncFrame PrimFunc_(String name) {
n->args.clear();
n->ret_type = TupleType::Empty();
n->buffer_map.clear();
n->preflattened_buffer_map.clear();
n->attrs.clear();
return PrimFuncFrame(n);
}

Expand All @@ -72,6 +78,79 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) {
return buffer;
}

void FuncAttrs(Map<String, ObjectRef> attrs) {
using namespace tvm::tir;
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
frame->attrs = attrs;
}

tvm::Type FuncRet(tvm::Type ret_type) {
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
frame->ret_type = ret_type;
return ret_type;
}

tvm::tir::Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype,
Optional<tvm::tir::Var> data, Array<PrimExpr> strides,
PrimExpr elem_offset, String storage_scope, int align,
int offset_factor, String buffer_type_str,
Array<IntImm> axis_separators, Span span) {
using namespace tvm::tir;
Var buffer_data;
if (!data.defined()) {
DataType storage_dtype = dtype;
if (storage_dtype == DataType::Bool()) {
storage_dtype = DataType::Int(8);
}
buffer_data = Var("", PointerType(PrimType(storage_dtype), storage_scope), span);
} else {
buffer_data = data.value();
}
BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault;
Buffer buffer(buffer_data, dtype, shape, strides, elem_offset, "", align, offset_factor,
buffer_type, axis_separators, span);
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
if (const auto* var = param.as<VarNode>()) {
Var v = GetRef<Var>(var);
for (auto const& arg : frame->args) {
if (arg.same_as(v)) {
frame->buffer_map.Set(v, buffer);
return buffer;
}
}
LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
} else if (const auto* buffer_region = param.as<BufferRegionNode>()) {
BlockFrame block_frame = Builder::Current()->FindFrame<BlockFrame>().value();
block_frame->match_buffers.push_back(
MatchBufferRegion(buffer, GetRef<BufferRegion>(buffer_region)));
} else {
LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
}
return buffer;
};

void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array<PrimExpr> shape,
DataType dtype, Optional<tvm::tir::Var> data, Array<PrimExpr> strides,
PrimExpr elem_offset, String storage_scope, int align, int offset_factor,
String buffer_type_str, Array<IntImm> axis_separators, Span span) {
using namespace tvm::tir;
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
for (auto const& p : frame->buffer_map) {
if (p.second.same_as(postflattened_buffer)) {
Var buffer_data = (data.defined()) ? data.value() : frame->buffer_map.at(p.first)->data;
String buffer_name(postflattened_buffer->name + "_preflatten");
BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault;
Buffer buffer(buffer_data, dtype, shape, strides, elem_offset, buffer_name, align,
offset_factor, buffer_type, axis_separators, span);
Namer::Name(buffer, buffer_name);
frame->preflattened_buffer_map.Set(p.first, buffer);
return;
}
}
LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name
<< " does not exist.";
};

TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc_);
TVM_REGISTER_GLOBAL("script.builder.tir.Arg")
Expand All @@ -86,6 +165,10 @@ TVM_REGISTER_GLOBAL("script.builder.tir.Arg")
LOG(FATAL) << "ValueError: Unexpected type for TIR Arg.";
throw;
});
TVM_REGISTER_GLOBAL("script.builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
TVM_REGISTER_GLOBAL("script.builder.tir.FuncRet").set_body_typed(FuncRet);
TVM_REGISTER_GLOBAL("script.builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
TVM_REGISTER_GLOBAL("script.builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);

} // namespace tir
} // namespace builder
Expand Down
21 changes: 21 additions & 0 deletions src/script/builder/tir/prim_func_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ class PrimFuncFrameNode : public TIRFrameNode {
Array<tvm::tir::Var> args;
Type ret_type;
Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
Map<String, ObjectRef> attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
v->Visit("attrs", &attrs);
}

static constexpr const char* _type_key = "script.builder.tir.PrimFuncFrame";
Expand All @@ -56,6 +60,23 @@ class PrimFuncFrame : public TIRFrame {
PrimFuncFrame PrimFunc_(String name);
tvm::tir::Var Arg(String name, tvm::tir::Var var);
tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer);
void FuncAttrs(Map<String, ObjectRef> attrs);
tvm::Type FuncRet(tvm::Type ret_type);

tvm::tir::Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape,
DataType dtype = DataType::Float(32),
Optional<tvm::tir::Var> data = NullOpt, Array<PrimExpr> strides = {},
PrimExpr elem_offset = PrimExpr(), String storage_scope = "",
int align = -1, int offset_factor = 0,
String buffer_type_str = "default", Array<IntImm> axis_separators = {},
Span span = Span());

void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array<PrimExpr> shape,
DataType dtype = DataType::Float(32),
Optional<tvm::tir::Var> data = NullOpt, Array<PrimExpr> strides = {},
PrimExpr elem_offset = PrimExpr(), String storage_scope = "",
int align = -1, int offset_factor = 0, String buffer_type_str = "default",
Array<IntImm> axis_separators = {}, Span span = Span());

} // namespace tir
} // namespace builder
Expand Down
6 changes: 3 additions & 3 deletions src/script/builder/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ namespace script {
namespace builder {
namespace tir {

tvm::tir::Buffer Buffer_(Array<PrimExpr> shape, //
DataType dtype, //
String name = "buffer", //
tvm::tir::Buffer Buffer_(Array<PrimExpr> shape, //
DataType dtype = DataType::Float(32), //
String name = "buffer", //
String storage_scope = "");

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,18 @@
def test_builder_basic():
with Builder() as b:
with T.prim_func(name="main"):
A = T.arg("A", T.Buffer((128, 128, 128), "float32"))
B = T.arg("B", T.Buffer((128, 128, 128), "float32"))
T.func_attr({"global_symbol": "main"})
arg_a = T.arg("a", T.handle())
arg_b = T.arg("b", T.handle())
buffer_c = T.Buffer((128,), "float32")
buffer_d = T.Buffer((128,), "float32")
arg_c = T.arg("c", buffer_c)
arg_d = T.arg("d", buffer_d)
T.func_ret(tvm.ir.PrimType("int8"))
A = def_("A", T.match_buffer(arg_a, (128, 128, 128)))
B = def_("B", T.match_buffer(arg_b, (128, 128, 128)))
T.preflattened_buffer(buffer_c, (128,), data=buffer_c.data)
T.preflattened_buffer(buffer_d, (128,), data=buffer_d.data)
with T.grid(128, 128, 128) as (i, j, k):
def_many(["i", "j", "k"], [i, j, k])
with T.block(name="block"):
Expand Down

0 comments on commit ad73df3

Please sign in to comment.