Skip to content

Commit

Permalink
[IRBuilder] Attempt to Cover the POC (apache#33)
Browse files Browse the repository at this point in the history
* Add ForFrame

* move static methods to global scope

* Builder is a final object

* Add PrimFuncFrame
  • Loading branch information
junrushao committed Jul 4, 2022
1 parent 472130c commit f224597
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/script/builder/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Builder Builder::Current() {
}

TVM_REGISTER_NODE_TYPE(BuilderNode);
TVM_REGISTER_NODE_TYPE(FrameNode);

} // namespace builder
} // namespace script
Expand Down
48 changes: 44 additions & 4 deletions src/script/builder/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,39 @@

#include <tvm/node/node.h>

#include "./frame.h"

namespace tvm {
namespace script {
namespace builder {

class FrameNode : public runtime::Object {
public:
std::vector<runtime::TypedPackedFunc<void()>> callbacks;

void VisitAttrs(tvm::AttrVisitor* v) {
// `callbacks` is not visited.
}

void AddCallback(runtime::TypedPackedFunc<void()> callback) { callbacks.push_back(callback); }

static constexpr const char* _type_key = "script.builder.Frame";
TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, runtime::Object);

public:
virtual ~FrameNode() {
for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
(*it)();
}
}
};

class Frame : public runtime::ObjectRef {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);

protected:
Frame() = default;
};

class BuilderNode : public runtime::Object {
public:
runtime::Array<Frame> frames;
Expand All @@ -35,8 +62,21 @@ class BuilderNode : public runtime::Object {
v->Visit("frames", &frames); //
}

static constexpr const char* _type_key = "script.Builder";
TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, runtime::Object);
static constexpr const char* _type_key = "script.builder.Builder";
TVM_DECLARE_FINAL_OBJECT_INFO(BuilderNode, runtime::Object);

public:
template <typename TFrame>
TFrame FindFrame() const {
using TFrameNode = typename TFrame::ContainerType;
for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
return GetRef<TFrame>(p);
}
}
LOG(FATAL) << "IndexError: Cannot find frame: " << TFrameNode::_type_key;
throw;
}
};

class Builder : public runtime::ObjectRef {
Expand Down
93 changes: 93 additions & 0 deletions src/script/builder/tir/for_frame.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./for_frame.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

ForFrame::ForFrame(Array<tvm::tir::Var> loop_vars, ForFrame::FMakeForLoop f_make_for_loop) {
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->loop_vars = std::move(loop_vars);
n->f_make_for_loop = std::move(f_make_for_loop);
data_ = std::move(n);
}

#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \
With<ForFrame> Method(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> attrs) { \
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \
int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
n->loop_vars = {tvm::tir::Var("v", DataType::Int(bits))}; \
n->f_make_for_loop = [=](Array<tvm::tir::Var> vars, tvm::tir::Stmt body) -> tvm::tir::For { \
ICHECK_EQ(vars.size(), 1); \
return tvm::tir::For(/*loop_var=*/vars[0], min, extent, Kind, body, \
/*thread_binding=*/NullOpt, attrs); \
}; \
return With<ForFrame>(n); \
}

TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Serial, tvm::tir::ForKind::kSerial);
TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Parallel, tvm::tir::ForKind::kParallel);
TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Vectorized, tvm::tir::ForKind::kVectorized);
TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Unroll, tvm::tir::ForKind::kUnrolled);

#undef TVM_SCRIPT_BUILDER_TIR_FOR_CREATE

With<ForFrame> ThreadBinding(PrimExpr min, PrimExpr extent, String thread,
Map<String, ObjectRef> attrs) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
int bits = std::max(min.dtype().bits(), extent.dtype().bits());
n->loop_vars = {Var("v", DataType::Int(bits))};
n->f_make_for_loop = [=](Array<Var> vars, Stmt body) -> For {
ICHECK_EQ(vars.size(), 1);
IterVar iter_var(Range(nullptr), Var(ObjectPtr<Object>(nullptr)), IterVarType::kThreadIndex,
thread);
return For(vars[0], min, extent, tvm::tir::ForKind::kThreadBinding, body, iter_var, attrs);
};
return With<ForFrame>(n);
}

With<ForFrame> Grid(Array<PrimExpr> extents) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->loop_vars.reserve(extents.size());
for (const auto& extent : extents) {
n->loop_vars.push_back(Var("v", extent.dtype()));
}
n->f_make_for_loop = [=](Array<Var> vars, Stmt body) -> Stmt {
ICHECK_EQ(extents.size(), vars.size());
int n = extents.size();
for (int i = n - 1; i >= 0; --i) {
Var var = vars[i];
PrimExpr extent = extents[i];
body = For(var, Integer(0), extent, ForKind::kSerial, body, /*thread_binding=*/NullOpt, {});
}
return body;
};
return With<ForFrame>(n);
}

TVM_REGISTER_NODE_TYPE(ForFrameNode);

} // namespace tir
} // namespace builder
} // namespace script
} // namespace tvm
80 changes: 80 additions & 0 deletions src/script/builder/tir/for_frame.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_BUILDER_TIR_FOR_FRAME_H_
#define TVM_SCRIPT_BUILDER_TIR_FOR_FRAME_H_

#include <tvm/support/with.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>

#include "./tir.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

class ForFrameNode : public TIRFrameNode {
public:
using FMakeForLoop =
runtime::TypedPackedFunc<tvm::tir::Stmt(Array<tvm::tir::Var>, tvm::tir::Stmt)>;

Array<tvm::tir::Var> loop_vars;
FMakeForLoop f_make_for_loop;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("loop_vars", &loop_vars);
// `f_make_for_loop` is not visited.
}

static constexpr const char* _type_key = "script.builder.tir.ForFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
};

class ForFrame : public TIRFrame {
public:
using FMakeForLoop = ForFrameNode::FMakeForLoop;

explicit ForFrame(Array<tvm::tir::Var> loop_vars, FMakeForLoop f_make_for_loop);

void EnterWithScope() { ICHECK(data_ != nullptr); }

void ExitWithScope() {
ICHECK(data_ != nullptr);
data_.reset();
}

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
};

With<ForFrame> Serial(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> annotations);
With<ForFrame> Parallel(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> annotations);
With<ForFrame> Vectorized(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> annotations);
With<ForFrame> Unroll(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> annotations);
With<ForFrame> ThreadBinding(PrimExpr min, PrimExpr extent, String thread,
Map<String, ObjectRef> annotations);
With<ForFrame> Grid(Array<PrimExpr> extents);

} // namespace tir
} // namespace builder
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_BUILDER_TIR_FOR_FRAME_H_
45 changes: 45 additions & 0 deletions src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "./prim_func_frame.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

void Arg(tvm::tir::Var var) {
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>();
frame->args.push_back(var);
}

void Arg(tvm::tir::Buffer buffer) {
using namespace tvm::tir;
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>();
Var handle(buffer->name + "_handle", DataType::Handle());
frame->args.push_back(handle);
frame->buffer_map.Set(handle, buffer);
}

TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);

} // namespace tir
} // namespace builder
} // namespace script
} // namespace tvm
61 changes: 61 additions & 0 deletions src/script/builder/tir/prim_func_frame.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_
#define TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_

#include "./tir.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

class PrimFuncFrameNode : public TIRFrameNode {
public:
String name;
Array<tvm::tir::Var> args;
Type ret_type;
Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
}

static constexpr const char* _type_key = "script.builder.tir.PrimFuncFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);
};

class PrimFuncFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
};

void Arg(tvm::tir::Var var);
void Arg(tvm::tir::Buffer buffer);

} // namespace tir
} // namespace builder
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_
6 changes: 4 additions & 2 deletions src/script/builder/frame.cc → src/script/builder/tir/tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "./frame.h"
#include "./tir.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(TIRFrameNode);

} // namespace tir
} // namespace builder
} // namespace script
} // namespace tvm
Loading

0 comments on commit f224597

Please sign in to comment.