diff --git a/src/script/builder/builder.cc b/src/script/builder/builder.cc index c04a1b5f346a..ed64de24396c 100644 --- a/src/script/builder/builder.cc +++ b/src/script/builder/builder.cc @@ -22,30 +22,57 @@ namespace tvm { namespace script { namespace builder { +Builder::Builder() { + ObjectPtr n = make_object(); + n->frames.clear(); + n->result = NullOpt; + data_ = n; +} + std::vector* ThreadLocalBuilderStack() { thread_local std::vector stack; return &stack; } void Builder::EnterWithScope() { + BuilderNode* n = this->get(); + CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: " + << n->frames.size() + << ". Please use a fresh new builder every time building IRs"; + n->frames.push_back(IRModuleFrame()); std::vector* stack = ThreadLocalBuilderStack(); stack->push_back(*this); } void Builder::ExitWithScope() { + BuilderNode* n = this->get(); + ICHECK_EQ(n->frames.size(), 1); + IRModuleFrame frame = Downcast(n->frames.back()); + n->frames.pop_back(); std::vector* stack = ThreadLocalBuilderStack(); - CHECK(!stack->empty()); + ICHECK(!stack->empty()); stack->pop_back(); + if (!frame->stmts.empty()) { + ICHECK(frame->global_vars.empty()); + ICHECK(frame->functions.empty()); + n->result = frame->stmts; + } else { + Map func_map; + ICHECK_EQ(frame->functions.size(), frame->global_vars.size()); + int m = frame->functions.size(); + for (int i = 0; i < m; ++i) { + func_map.Set(frame->global_vars[i], frame->functions[i]); + } + } } Builder Builder::Current() { std::vector* stack = ThreadLocalBuilderStack(); - CHECK(!stack->empty()); + CHECK(!stack->empty()) << "ValueError: No builder in current scope"; return stack->back(); } TVM_REGISTER_NODE_TYPE(BuilderNode); -TVM_REGISTER_NODE_TYPE(FrameNode); } // namespace builder } // namespace script diff --git a/src/script/builder/builder.h b/src/script/builder/builder.h index 53700ba8c64d..506ba2030d69 100644 --- a/src/script/builder/builder.h +++ b/src/script/builder/builder.h @@ -21,60 +21,20 @@ #include +#include "./frame.h" + namespace tvm { namespace script { namespace builder { -class FrameNode : public runtime::Object { - public: - std::vector> callbacks; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `callbacks` is not visited. - } - - void AddCallback(runtime::TypedPackedFunc callback) { callbacks.push_back(callback); } - - static constexpr const char* _type_key = "script.builder.Frame"; - TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, runtime::Object); - - public: - virtual void EnterWithScope() {} - - virtual void ExitWithScope() {} - - virtual ~FrameNode() { - for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { - (*it)(); - } - } -}; - -class Frame : public runtime::ObjectRef { - public: - void EnterWithScope() { - ICHECK(data_ != nullptr); - static_cast(data_.get())->EnterWithScope(); - } - - void ExitWithScope() { - ICHECK(data_ != nullptr); - static_cast(data_.get())->ExitWithScope(); - data_.reset(); - } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); - - protected: - Frame() = default; -}; - class BuilderNode : public runtime::Object { public: runtime::Array frames; + Optional result; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("frames", &frames); // + v->Visit("frames", &frames); + v->Visit("result", &result); } static constexpr const char* _type_key = "script.builder.Builder"; @@ -82,29 +42,43 @@ class BuilderNode : public runtime::Object { public: template - Optional FindFrame() const { - using TFrameNode = typename TFrame::ContainerType; - for (auto it = frames.rbegin(); it != frames.rend(); ++it) { - if (const TFrameNode* p = (*it).template as()) { - return GetRef(p); - } - } - return NullOpt; - } + inline Optional FindFrame() const; + + template + inline TObjectRef Get() const; }; class Builder : public runtime::ObjectRef { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Builder, ObjectRef, BuilderNode); + Builder(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Builder, ObjectRef, BuilderNode); public: void EnterWithScope(); - void ExitWithScope(); - static Builder Current(); }; +template +inline Optional BuilderNode::FindFrame() const { + using TFrameNode = typename TFrame::ContainerType; + for (auto it = frames.rbegin(); it != frames.rend(); ++it) { + if (const TFrameNode* p = (*it).template as()) { + return GetRef(p); + } + } + return NullOpt; +} + +template +inline TObjectRef BuilderNode::Get() const { + using TObject = typename TObjectRef::ContainerType; + CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet"; + const auto* n = result.as(); + CHECK(n != nullptr) << "IndexError: IRBuilder result is not of type: " << TObject::_type_key; + return GetRef(n); +} + } // namespace builder } // namespace script } // namespace tvm diff --git a/src/script/builder/frame.cc b/src/script/builder/frame.cc new file mode 100644 index 000000000000..8db03cfbc482 --- /dev/null +++ b/src/script/builder/frame.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./builder.h" + +namespace tvm { +namespace script { +namespace builder { + +void FrameNode::EnterWithScope() { + // Push to the current builder + Builder::Current()->frames.push_back(GetRef(this)); +} + +void FrameNode::ExitWithScope() { + for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { + (*it)(); + } + this->callbacks.clear(); + Builder::Current()->frames.pop_back(); +} + +IRModuleFrame::IRModuleFrame() { + ObjectPtr n = make_object(); + n->global_vars.clear(); + n->functions.clear(); + n->stmts.clear(); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(FrameNode); + +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/frame.h b/src/script/builder/frame.h new file mode 100644 index 000000000000..04916b6b8842 --- /dev/null +++ b/src/script/builder/frame.h @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_FRAME_H_ +#define TVM_SCRIPT_BUILDER_FRAME_H_ + +#include +#include + +namespace tvm { +namespace script { +namespace builder { + +class FrameNode : public runtime::Object { + public: + std::vector> callbacks; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `callbacks` is not visited. + } + + static constexpr const char* _type_key = "script.builder.Frame"; + TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, runtime::Object); + + public: + virtual ~FrameNode() = default; + virtual void EnterWithScope(); + virtual void ExitWithScope(); +}; + +class Frame : public runtime::ObjectRef { + public: + virtual ~Frame() = default; + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); + + protected: + Frame() = default; + + public: + inline void EnterWithScope(); + inline void ExitWithScope(); +}; + +class IRModuleFrameNode : public FrameNode { + public: + Array global_vars; + Array functions; + Array stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("global_vars", &global_vars); + v->Visit("functions", &functions); + v->Visit("stmts", &stmts); + } + + static constexpr const char* _type_key = "script.builder.IRModuleFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, FrameNode); +}; + +class IRModuleFrame : public Frame { + public: + IRModuleFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, Frame, IRModuleFrameNode); +}; + +inline void Frame::EnterWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->EnterWithScope(); +} + +inline void Frame::ExitWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->ExitWithScope(); + data_.reset(); +} + +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_FRAME_H_ diff --git a/src/script/builder/tir/base.cc b/src/script/builder/tir/base.cc new file mode 100644 index 000000000000..4f764435136a --- /dev/null +++ b/src/script/builder/tir/base.cc @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./base.h" + +#include + +#include "./block_frame.h" +#include "./for_frame.h" +#include "./prim_func_frame.h" +#include "./var.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +TVM_REGISTER_NODE_TYPE(TIRFrameNode); + +void TestPOC() { + namespace T = tvm::script::builder::tir; + using namespace ::tvm::tir; + + With builder; + { + With _{T::PrimFunc_("main")}; + Buffer A = T::Buffer_({128, 128, 128}, DataType::Float(32)); + Buffer B = T::Buffer_({128, 128, 128}, DataType::Float(32)); + { + With _{T::Grid({128, 128, 128})}; + Var i = _()->vars[0]; + Var j = _()->vars[1]; + Var k = _()->vars[2]; + { + With _{T::Block_("block")}; + IterVar vi = T::axis::Spatial(Range(0, 128), i); + IterVar vj = T::axis::Spatial(Range(0, 128), j); + IterVar vk = T::axis::Spatial(Range(0, 128), k); + } + } + } +} + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/tir/tir.h b/src/script/builder/tir/base.h similarity index 61% rename from src/script/builder/tir/tir.h rename to src/script/builder/tir/base.h index d5638413b871..a56826eb0718 100644 --- a/src/script/builder/tir/tir.h +++ b/src/script/builder/tir/base.h @@ -16,10 +16,13 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_SCRIPT_BUILDER_TIR_TIR_H_ -#define TVM_SCRIPT_BUILDER_TIR_TIR_H_ +#ifndef TVM_SCRIPT_BUILDER_TIR_BASE_H_ +#define TVM_SCRIPT_BUILDER_TIR_BASE_H_ +#include +#include #include +#include #include "../builder.h" @@ -49,9 +52,33 @@ class TIRFrame : public Frame { TIRFrame() = default; }; +inline void AddToParent(tvm::tir::Stmt stmt) { + Builder builder = Builder::Current(); + ICHECK(!builder->frames.empty()); + Frame frame = builder->frames.back(); + if (const auto* tir_frame = frame.as()) { + GetRef(tir_frame)->stmts.push_back(stmt); + } else if (const auto* mod_frame = frame.as()) { + GetRef(mod_frame)->stmts.push_back(stmt); + } else { + LOG(FATAL) << "TypeError: Unsupported frame type: " << frame; + } +} + +inline tvm::tir::Stmt AsStmt(const Array& stmt) { + using namespace tvm::tir; + if (stmt.empty()) { + return Evaluate(0); + } else if (stmt.size() == 1) { + return stmt[0]; + } else { + return SeqStmt(stmt); + } +} + } // namespace tir } // namespace builder } // namespace script } // namespace tvm -#endif // TVM_SCRIPT_BUILDER_TIR_TIR_H_ +#endif // TVM_SCRIPT_BUILDER_TIR_BASE_H_ diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc index 14ae64fdf533..94adfead92e7 100644 --- a/src/script/builder/tir/block_frame.cc +++ b/src/script/builder/tir/block_frame.cc @@ -25,19 +25,33 @@ namespace script { namespace builder { namespace tir { -BlockFrame::BlockFrame(String name) { +BlockFrame Block_(String name) { ObjectPtr n = make_object(); n->name = name; n->iter_vars.clear(); - n->reads = NullOpt; - n->writes = NullOpt; + n->reads.clear(); + n->writes.clear(); n->init = NullOpt; n->alloc_buffers.clear(); n->match_buffers.clear(); n->annotations.clear(); n->iter_values.clear(); n->predicate = NullOpt; - data_ = n; + return BlockFrame(n); +} + +void BlockFrameNode::ExitWithScope() { + using namespace tvm::tir; + AddToParent(BlockRealize(iter_values, // + predicate.value_or(Bool(true)), + Block(iter_vars, // + reads, writes, // + name, // + AsStmt(stmts), // + init, // + alloc_buffers, // + match_buffers, // + annotations))); } namespace axis { diff --git a/src/script/builder/tir/block_frame.h b/src/script/builder/tir/block_frame.h index bec8db18b7ef..05e1969e5a54 100644 --- a/src/script/builder/tir/block_frame.h +++ b/src/script/builder/tir/block_frame.h @@ -19,7 +19,7 @@ #ifndef TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_ #define TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_ -#include "./tir.h" +#include "./base.h" namespace tvm { namespace script { @@ -30,8 +30,8 @@ class BlockFrameNode : public TIRFrameNode { public: String name; Array iter_vars; - Optional> reads; - Optional> writes; + Array reads; + Array writes; Optional init; Array alloc_buffers; Array match_buffers; @@ -41,6 +41,7 @@ class BlockFrameNode : public TIRFrameNode { Optional predicate; void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); v->Visit("name", &name); v->Visit("iter_vars", &iter_vars); v->Visit("reads", &reads); @@ -55,18 +56,23 @@ class BlockFrameNode : public TIRFrameNode { static constexpr const char* _type_key = "script.builder.tir.BlockFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; }; class BlockFrame : public TIRFrame { public: - explicit BlockFrame(String name); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); }; +BlockFrame Block_(String name); + namespace axis { -tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype); -tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype); -Array Remap(String kinds, Array bindings, DataType dtype); +tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Array Remap(String kinds, Array bindings, + DataType dtype = DataType::Int(32)); } // namespace axis } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc index 3aa02bf48997..82b879f7d494 100644 --- a/src/script/builder/tir/for_frame.cc +++ b/src/script/builder/tir/for_frame.cc @@ -23,14 +23,7 @@ namespace script { namespace builder { namespace tir { -ForFrame::ForFrame(Array vars, Array doms, - ForFrameNode::FMakeForLoop f_make_for_loop) { - ObjectPtr n = make_object(); - n->vars = std::move(vars); - n->doms = std::move(doms); - n->f_make_for_loop = std::move(f_make_for_loop); - data_ = std::move(n); -} +void ForFrameNode::ExitWithScope() { AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts))); } #define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ ForFrame Method(PrimExpr min, PrimExpr extent, Map attrs) { \ diff --git a/src/script/builder/tir/for_frame.h b/src/script/builder/tir/for_frame.h index 7b94645b46e8..2bff8dcd5f5e 100644 --- a/src/script/builder/tir/for_frame.h +++ b/src/script/builder/tir/for_frame.h @@ -24,7 +24,7 @@ #include #include -#include "./tir.h" +#include "./base.h" namespace tvm { namespace script { @@ -41,6 +41,7 @@ class ForFrameNode : public TIRFrameNode { FMakeForLoop f_make_for_loop; void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); v->Visit("vars", &vars); v->Visit("doms", &doms); // `f_make_for_loop` is not visited. @@ -48,20 +49,13 @@ class ForFrameNode : public TIRFrameNode { static constexpr const char* _type_key = "script.builder.tir.ForFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; }; class ForFrame : public TIRFrame { public: - explicit ForFrame(Array vars, Array doms, - ForFrameNode::FMakeForLoop f_make_for_loop); - - void EnterWithScope() { ICHECK(data_ != nullptr); } - - void ExitWithScope() { - ICHECK(data_ != nullptr); - data_.reset(); - } - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); }; diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 70fb93e0e53a..74371def630f 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -19,11 +19,32 @@ #include "./prim_func_frame.h" +#include + namespace tvm { namespace script { namespace builder { namespace tir { +void PrimFuncFrameNode::ExitWithScope() { + using namespace tvm::tir; + IRModuleFrame frame = Builder::Current()->FindFrame().value(); + frame->global_vars.push_back(GlobalVar(name)); + frame->functions.push_back(PrimFunc(/*params=*/args, + /*body=*/AsStmt(stmts), + /*ret_type=*/ret_type, + /*buffer_map=*/buffer_map)); +} + +PrimFuncFrame PrimFunc_(String name) { + ObjectPtr n = make_object(); + n->name = name; + n->args.clear(); + n->ret_type = TupleType::Empty(); + n->buffer_map.clear(); + return PrimFuncFrame(n); +} + void Arg(tvm::tir::Var var) { PrimFuncFrame frame = Builder::Current()->FindFrame().value(); frame->args.push_back(var); diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h index 0be38a587347..721bfb88dc6f 100644 --- a/src/script/builder/tir/prim_func_frame.h +++ b/src/script/builder/tir/prim_func_frame.h @@ -19,7 +19,7 @@ #ifndef TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_ #define TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_ -#include "./tir.h" +#include "./base.h" namespace tvm { namespace script { @@ -43,6 +43,9 @@ class PrimFuncFrameNode : public TIRFrameNode { static constexpr const char* _type_key = "script.builder.tir.PrimFuncFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); + + public: + void ExitWithScope() final; }; class PrimFuncFrame : public TIRFrame { @@ -50,6 +53,7 @@ class PrimFuncFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; +PrimFuncFrame PrimFunc_(String name); void Arg(tvm::tir::Var var); void Arg(tvm::tir::Buffer buffer); diff --git a/src/script/builder/tir/tir.cc b/src/script/builder/tir/var.cc similarity index 83% rename from src/script/builder/tir/tir.cc rename to src/script/builder/tir/var.cc index faef2372909e..f2e77d763e8e 100644 --- a/src/script/builder/tir/tir.cc +++ b/src/script/builder/tir/var.cc @@ -16,14 +16,16 @@ * specific language governing permissions and limitations * under the License. */ -#include "./tir.h" +#include "./var.h" namespace tvm { namespace script { namespace builder { namespace tir { -TVM_REGISTER_NODE_TYPE(TIRFrameNode); +tvm::tir::Buffer Buffer_(Array shape, DataType dtype, String name, String storage_scope) { + return tvm::tir::decl_buffer(shape, dtype, name, storage_scope); +} } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/var.h b/src/script/builder/tir/var.h new file mode 100644 index 000000000000..41257f0ca6d3 --- /dev/null +++ b/src/script/builder/tir/var.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_VAR_H_ +#define TVM_SCRIPT_BUILDER_TIR_VAR_H_ + +#include "./base.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +tvm::tir::Buffer Buffer_(Array shape, // + DataType dtype, // + String name = "buffer", // + String storage_scope = ""); + +} +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_VAR_H_