From bab6ea6bcd88a39011155c96a7a049703b29c271 Mon Sep 17 00:00:00 2001 From: listerily Date: Mon, 3 Jul 2023 14:10:17 +0800 Subject: [PATCH] [lang] Instantiate a runtime ArgPack object when a python ArgPack is created ghstack-source-id: 9bb67afc51c41ea3f69d893d7c56240fcf09841b Pull Request resolved: https://github.com/taichi-dev/taichi/pull/8241 --- python/taichi/lang/argpack.py | 25 ++++--- taichi/program/argpack.cpp | 46 ++++++++++++ taichi/program/argpack.h | 34 +++++++++ taichi/program/program.cpp | 25 +++++++ taichi/program/program.h | 14 +++- taichi/program/program_impl.h | 6 ++ taichi/python/export_lang.cpp | 14 ++++ taichi/runtime/gfx/runtime.cpp | 74 +++++++++++++++++++ taichi/runtime/gfx/runtime.h | 8 ++ taichi/runtime/llvm/llvm_context.cpp | 34 +++++++++ taichi/runtime/llvm/llvm_context.h | 4 + .../runtime/program_impls/gfx/gfx_program.h | 6 ++ .../runtime/program_impls/llvm/llvm_program.h | 8 ++ 13 files changed, 287 insertions(+), 11 deletions(-) create mode 100644 taichi/program/argpack.cpp create mode 100644 taichi/program/argpack.h diff --git a/python/taichi/lang/argpack.py b/python/taichi/lang/argpack.py index 6e4526ee2a2925..0f458d08ffd1c8 100644 --- a/python/taichi/lang/argpack.py +++ b/python/taichi/lang/argpack.py @@ -23,6 +23,8 @@ class ArgPack: Args: annotations (Dict[str, Union[Dict, Matrix, Struct]]): \ The keys and types for `ArgPack` members. + dtype (ArgPackType): \ + The ArgPackType class of this ArgPack object. entries (Dict[str, Union[Dict, Matrix, Struct]]): \ The keys and corresponding values for `ArgPack` members. @@ -40,7 +42,7 @@ class ArgPack: _instance_count = 0 - def __init__(self, annotations, *args, **kwargs): + def __init__(self, annotations, dtype, *args, **kwargs): # converts dicts to argument packs if len(args) == 1 and kwargs == {} and isinstance(args[0], dict): self.__entries = args[0] @@ -56,7 +58,12 @@ def __init__(self, annotations, *args, **kwargs): for k, v in self.__entries.items(): self.__entries[k] = v if in_python_scope() else impl.expr_init(v) self._register_members() - self.__dtype = None + self.__dtype = dtype + self.__argpack = impl.get_runtime().prog.create_argpack(self.__dtype) + + def __del__(self): + if impl is not None and impl.get_runtime() is not None and impl.get_runtime().prog is not None: + impl.get_runtime().prog.delete_argpack(self.__argpack) @property def keys(self): @@ -181,7 +188,7 @@ class _IntermediateArgPack(ArgPack): entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members. """ - def __init__(self, annotations, *args, **kwargs): + def __init__(self, annotations, dtype, *args, **kwargs): # converts dicts to argument packs if len(args) == 1 and kwargs == {} and isinstance(args[0], dict): self._ArgPack__entries = args[0] @@ -195,7 +202,8 @@ def __init__(self, annotations, *args, **kwargs): raise TaichiSyntaxError("ArgPack annotations keys not equals to entries keys.") self._ArgPack__annotations = annotations self._register_members() - self._ArgPack__dtype = None + self._ArgPack__dtype = dtype + self._ArgPack__argpack = impl.get_runtime().prog.create_argpack(dtype) class ArgPackType(CompoundType): @@ -263,10 +271,8 @@ def __call__(self, *args, **kwargs): d[name] = data - entries = ArgPack(self.members, d) - entries._ArgPack__dtype = self.dtype + entries = ArgPack(self.members, self.dtype, d) pack = self.cast(entries) - pack._ArgPack__dtype = self.dtype return pack def __instancecheck__(self, instance): @@ -308,8 +314,7 @@ def cast(self, pack): entries[k] = int(v) if dtype in primitive_types.integer_types else float(v) else: entries[k] = ops.cast(pack._ArgPack__entries[k], dtype) - pack = ArgPack(self.members, entries) - pack._ArgPack__dtype = self.dtype + pack = ArgPack(self.members, self.dtype, entries) return pack def from_taichi_object(self, arg_load_dict: dict): @@ -318,7 +323,7 @@ def from_taichi_object(self, arg_load_dict: dict): for index, pair in enumerate(items): name, dtype = pair d[name] = arg_load_dict[name] - pack = _IntermediateArgPack(self.members, d) + pack = _IntermediateArgPack(self.members, self.dtype, d) pack._ArgPack__dtype = self.dtype return pack diff --git a/taichi/program/argpack.cpp b/taichi/program/argpack.cpp new file mode 100644 index 00000000000000..8aa7656594b919 --- /dev/null +++ b/taichi/program/argpack.cpp @@ -0,0 +1,46 @@ +#include + +#include "taichi/program/argpack.h" +#include "taichi/program/program.h" + +#ifdef TI_WITH_LLVM +#include "taichi/runtime/llvm/llvm_context.h" +#include "taichi/runtime/program_impls/llvm/llvm_program.h" +#endif + +namespace taichi::lang { + +ArgPack::ArgPack(Program *prog, const DataType type) : prog_(prog) { + auto *old_type = type->get_type()->as(); + auto [argpack_type, alloc_size] = prog->get_argpack_type_with_data_layout( + old_type, prog->get_kernel_argument_data_layout()); + dtype = DataType(argpack_type); + argpack_alloc_ = + prog->allocate_memory_on_device(alloc_size, prog->result_buffer); +} + +ArgPack::~ArgPack() { + if (prog_) { + argpack_alloc_.device->dealloc_memory(argpack_alloc_); + } +} + +intptr_t ArgPack::get_device_allocation_ptr_as_int() const { + // taichi's own argpack's ptr points to its |DeviceAllocation| on the + // specified device. + return reinterpret_cast(&argpack_alloc_); +} + +DeviceAllocation ArgPack::get_device_allocation() const { + return argpack_alloc_; +} + +std::size_t ArgPack::get_nelement() const { + return dtype->as()->elements().size(); +} + +DataType ArgPack::get_data_type() const { + return dtype; +} + +} // namespace taichi::lang diff --git a/taichi/program/argpack.h b/taichi/program/argpack.h new file mode 100644 index 00000000000000..59784fc432106f --- /dev/null +++ b/taichi/program/argpack.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#include "taichi/inc/constants.h" +#include "taichi/ir/type_utils.h" +#include "taichi/rhi/device.h" + +namespace taichi::lang { + +class Program; + +class TI_DLL_EXPORT ArgPack { + public: + /* Constructs a ArgPack managed by Program. + * Memory allocation and deallocation is handled by Program. + */ + explicit ArgPack(Program *prog, const DataType type); + + DeviceAllocation argpack_alloc_{kDeviceNullAllocation}; + DataType dtype; + + DataType get_data_type() const; + intptr_t get_device_allocation_ptr_as_int() const; + DeviceAllocation get_device_allocation() const; + std::size_t get_nelement() const; + + ~ArgPack(); + + private: + Program *prog_{nullptr}; +}; +} // namespace taichi::lang diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 4bef97a8456087..4bfd1aaab976e3 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -398,6 +398,13 @@ Ndarray *Program::create_ndarray(const DataType type, return arr_ptr; } +ArgPack *Program::create_argpack(const DataType dt) { + auto pack = std::make_unique(this, dt); + auto pack_ptr = pack.get(); + argpacks_.insert({pack_ptr, std::move(pack)}); + return pack_ptr; +} + void Program::delete_ndarray(Ndarray *ndarray) { // [Note] Ndarray memory deallocation // Ndarray's memory allocation is managed by Taichi and Python can control @@ -416,6 +423,24 @@ void Program::delete_ndarray(Ndarray *ndarray) { } } +void Program::delete_argpack(ArgPack *argpack) { + // [Note] Argpack memory deallocation + // Argpack's memory allocation is managed by Taichi and Python can control + // this via Taichi indirectly. For example, when an argpack is GC-ed in + // Python, it signals Taichi to free its memory allocation. But Taichi will + // make sure **no pending kernels to be executed needs the argpack** before it + // actually frees the memory. When `ti.reset()` is called, all argpack + // allocated in this program should be gone and no longer valid in Python. + // This isn't the best implementation, argpacks should be managed by taichi + // runtime instead of this giant program and it should be freed when: + // - Python GC signals taichi that it's no longer useful + // - All kernels using it are executed. + if (argpacks_.count(argpack) && + !program_impl_->used_in_kernel(argpack->argpack_alloc_.alloc_id)) { + argpacks_.erase(argpack); + } +} + Texture *Program::create_texture(BufferFormat buffer_format, const std::vector &shape) { if (shape.size() == 1) { diff --git a/taichi/program/program.h b/taichi/program/program.h index 53ed610c046ad7..4f63313532ea72 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -15,6 +15,7 @@ #include "taichi/ir/type_factory.h" #include "taichi/ir/snode.h" #include "taichi/util/lang_util.h" +#include "taichi/program/argpack.h" #include "taichi/program/program_impl.h" #include "taichi/program/callable.h" #include "taichi/program/function.h" @@ -255,6 +256,8 @@ class TI_DLL_EXPORT Program { ExternalArrayLayout layout = ExternalArrayLayout::kNull, bool zero_fill = false); + ArgPack *create_argpack(const DataType dt); + std::string get_kernel_return_data_layout() { return program_impl_->get_kernel_return_data_layout(); }; @@ -269,8 +272,16 @@ class TI_DLL_EXPORT Program { return program_impl_->get_struct_type_with_data_layout(old_ty, layout); } + std::pair get_argpack_type_with_data_layout( + const ArgPackType *old_ty, + const std::string &layout) { + return program_impl_->get_argpack_type_with_data_layout(old_ty, layout); + } + void delete_ndarray(Ndarray *ndarray); + void delete_argpack(ArgPack *argpack); + Texture *create_texture(BufferFormat buffer_format, const std::vector &shape); @@ -335,8 +346,9 @@ class TI_DLL_EXPORT Program { static std::atomic num_instances_; bool finalized_{false}; - // TODO: Move ndarrays_ and textures_ to be managed by runtime + // TODO: Move ndarrays_, argpacks_ and textures_ to be managed by runtime std::unordered_map> ndarrays_; + std::unordered_map> argpacks_; std::vector> textures_; }; diff --git a/taichi/program/program_impl.h b/taichi/program/program_impl.h index 258843e92afabf..14df0c7fac3419 100644 --- a/taichi/program/program_impl.h +++ b/taichi/program/program_impl.h @@ -163,6 +163,12 @@ class ProgramImpl { return {old_ty, 0}; } + virtual std::pair + get_argpack_type_with_data_layout(const ArgPackType *old_ty, + const std::string &layout) { + return {old_ty, 0}; + } + KernelCompilationManager &get_kernel_compilation_manager(); KernelLauncher &get_kernel_launcher(); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 96111587414eeb..b2ed35b9fef5b4 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -428,6 +428,13 @@ void export_lang(py::module &m) { py::arg("layout") = ExternalArrayLayout::kNull, py::arg("zero_fill") = false, py::return_value_policy::reference) .def("delete_ndarray", &Program::delete_ndarray) + .def( + "create_argpack", + [&](Program *program, const DataType &dt) -> ArgPack * { + return program->create_argpack(dt); + }, + py::arg("dt"), py::return_value_policy::reference) + .def("delete_argpack", &Program::delete_argpack) .def( "create_texture", [&](Program *program, BufferFormat fmt, const std::vector &shape) @@ -568,6 +575,13 @@ void export_lang(py::module &m) { .def_readonly("dtype", &Ndarray::dtype) .def_readonly("shape", &Ndarray::shape); + py::class_(m, "ArgPack") + .def("device_allocation_ptr", &ArgPack::get_device_allocation_ptr_as_int) + .def("device_allocation", &ArgPack::get_device_allocation) + .def("nelement", &ArgPack::get_nelement) + .def("data_type", &ArgPack::get_data_type) + .def_readonly("dtype", &ArgPack::dtype); + py::enum_(m, "Format") #define PER_BUFFER_FORMAT(x) .value(#x, BufferFormat::x) #include "taichi/inc/rhi_constants.inc.h" diff --git a/taichi/runtime/gfx/runtime.cpp b/taichi/runtime/gfx/runtime.cpp index f489053c6237e5..ecc1ecc50fe813 100644 --- a/taichi/runtime/gfx/runtime.cpp +++ b/taichi/runtime/gfx/runtime.cpp @@ -856,5 +856,79 @@ GfxRuntime::get_struct_type_with_data_layout_impl( bytes, align}; } +std::pair +GfxRuntime::get_argpack_type_with_data_layout(const lang::ArgPackType *old_ty, + const std::string &layout) { + auto [new_ty, size, align] = + get_argpack_type_with_data_layout_impl(old_ty, layout); + return {new_ty, size}; +} + +std::tuple +GfxRuntime::get_argpack_type_with_data_layout_impl( + const lang::ArgPackType *old_ty, + const std::string &layout) { + TI_TRACE("get_argpack_type_with_data_layout: {}", layout); + TI_ASSERT(layout.size() == 2); + auto is_430 = layout[0] == '4'; + auto has_buffer_ptr = layout[1] == 'b'; + auto members = old_ty->elements(); + size_t bytes = 0; + size_t align = 0; + for (int i = 0; i < members.size(); i++) { + auto &member = members[i]; + size_t member_align; + size_t member_size; + if (auto struct_type = member.type->cast()) { + auto [new_ty, size, member_align_] = + get_struct_type_with_data_layout_impl(struct_type, layout); + members[i].type = new_ty; + member_align = member_align_; + member_size = size; + } else if (auto tensor_type = member.type->cast()) { + size_t element_size = data_type_size_gfx(tensor_type->get_element_type()); + size_t num_elements = tensor_type->get_num_elements(); + if (!is_430) { + if (num_elements == 2) { + member_align = element_size * 2; + } else { + member_align = element_size * 4; + } + member_size = member_align; + } else { + member_align = element_size; + member_size = tensor_type->get_num_elements() * element_size; + } + } else if (auto pointer_type = member.type->cast()) { + if (has_buffer_ptr) { + member_size = sizeof(uint64_t); + member_align = member_size; + } else { + // Use u32 as placeholder + member_size = sizeof(uint32_t); + member_align = member_size; + } + } else { + TI_ASSERT(member.type->is()); + member_size = data_type_size_gfx(member.type); + member_align = member_size; + } + bytes = align_up(bytes, member_align); + members[i].offset = bytes; + bytes += member_size; + align = std::max(align, member_align); + } + + if (!is_430) { + align = align_up(align, sizeof(float) * 4); + bytes = align_up(bytes, 4 * sizeof(float)); + } + TI_TRACE(" total_bytes={}", bytes); + return {TypeFactory::get_instance() + .get_argpack_type(members, layout) + ->as(), + bytes, align}; +} + } // namespace gfx } // namespace taichi::lang diff --git a/taichi/runtime/gfx/runtime.h b/taichi/runtime/gfx/runtime.h index 7279d66c30e7c2..c55fbb53b107ad 100644 --- a/taichi/runtime/gfx/runtime.h +++ b/taichi/runtime/gfx/runtime.h @@ -136,6 +136,14 @@ class TI_DLL_EXPORT GfxRuntime { get_struct_type_with_data_layout_impl(const lang::StructType *old_ty, const std::string &layout); + static std::pair + get_argpack_type_with_data_layout(const lang::ArgPackType *old_ty, + const std::string &layout); + + static std::tuple + get_argpack_type_with_data_layout_impl(const lang::ArgPackType *old_ty, + const std::string &layout); + private: friend class taichi::lang::gfx::SNodeTreeManager; diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index cd18f250cc3570..72803f92432d8d 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -179,6 +179,12 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { types.push_back(get_data_type(element.type)); } return llvm::StructType::get(*ctx, types); + } else if (const auto *argpack_type = dt->cast()) { + std::vector types; + for (const auto &element : argpack_type->elements()) { + types.push_back(get_data_type(element.type)); + } + return llvm::StructType::get(*ctx, types); } else if (const auto *pointer_type = dt->cast()) { return llvm::PointerType::get( get_data_type(pointer_type->get_pointee_type()), 0); @@ -1155,6 +1161,34 @@ TaichiLLVMContext::get_struct_type_with_data_layout(const StructType *old_ty, struct_size}; } +std::pair +TaichiLLVMContext::get_argpack_type_with_data_layout( + const ArgPackType *old_ty, + const std::string &layout) { + auto *llvm_struct_type = llvm::cast(get_data_type(old_ty)); + auto data_layout = llvm::DataLayout::parse(layout); + TI_ASSERT(data_layout); + size_t struct_size = data_layout->getTypeAllocSize(llvm_struct_type); + if (old_ty->get_layout() == layout) { + return {old_ty, struct_size}; + } + std::vector elements = old_ty->elements(); + for (auto &element : elements) { + if (auto struct_type = element.type->cast()) { + element.type = + get_struct_type_with_data_layout(struct_type, layout).first; + } + } + auto struct_layout = data_layout->getStructLayout(llvm_struct_type); + for (int i = 0; i < elements.size(); i++) { + elements[i].offset = struct_layout->getElementOffset(i); + } + return {TypeFactory::get_instance() + .get_argpack_type(elements, layout) + ->cast(), + struct_size}; +} + TI_REGISTER_TASK(make_slim_libdevice); } // namespace taichi::lang diff --git a/taichi/runtime/llvm/llvm_context.h b/taichi/runtime/llvm/llvm_context.h index 6a4b8ac69cd7e5..c7e4c79b9bc711 100644 --- a/taichi/runtime/llvm/llvm_context.h +++ b/taichi/runtime/llvm/llvm_context.h @@ -82,6 +82,10 @@ class TaichiLLVMContext { const StructType *old_ty, const std::string &layout); + std::pair get_argpack_type_with_data_layout( + const ArgPackType *old_ty, + const std::string &layout); + template llvm::Value *get_constant(T t); diff --git a/taichi/runtime/program_impls/gfx/gfx_program.h b/taichi/runtime/program_impls/gfx/gfx_program.h index 689beb9d1f2e04..98965ca2fa6e87 100644 --- a/taichi/runtime/program_impls/gfx/gfx_program.h +++ b/taichi/runtime/program_impls/gfx/gfx_program.h @@ -72,6 +72,12 @@ class GfxProgramImpl : public ProgramImpl { return gfx::GfxRuntime::get_struct_type_with_data_layout(old_ty, layout); } + std::pair get_argpack_type_with_data_layout( + const ArgPackType *old_ty, + const std::string &layout) override { + return gfx::GfxRuntime::get_argpack_type_with_data_layout(old_ty, layout); + } + std::string get_kernel_return_data_layout() override { return "4-"; }; diff --git a/taichi/runtime/program_impls/llvm/llvm_program.h b/taichi/runtime/program_impls/llvm/llvm_program.h index 7f1b5502c593b4..98d760abb60954 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.h +++ b/taichi/runtime/program_impls/llvm/llvm_program.h @@ -221,12 +221,20 @@ class LlvmProgramImpl : public ProgramImpl { std::string get_kernel_argument_data_layout() override { return get_llvm_context()->get_data_layout_string(); }; + std::pair get_struct_type_with_data_layout( const StructType *old_ty, const std::string &layout) override { return get_llvm_context()->get_struct_type_with_data_layout(old_ty, layout); } + std::pair get_argpack_type_with_data_layout( + const ArgPackType *old_ty, + const std::string &layout) override { + return get_llvm_context()->get_argpack_type_with_data_layout(old_ty, + layout); + } + // TODO(zhanlue): Rearrange llvm::Context's ownership // // In LLVM backend, most of the compiled information are stored in