Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Instantiate a runtime ArgPack object when a python ArgPack is created #8241

Merged
merged 9 commits into from
Jul 11, 2023
Merged
25 changes: 15 additions & 10 deletions python/taichi/lang/argpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class ArgPack:
Args:
annotations (Dict[str, Union[Dict, Matrix, Struct]]): \
The keys and types for `ArgPack` members.
dtype (ArgPackType): \
The ArgPackType class of this ArgPack object.
entries (Dict[str, Union[Dict, Matrix, Struct]]): \
The keys and corresponding values for `ArgPack` members.

Expand All @@ -40,7 +42,7 @@ class ArgPack:

_instance_count = 0

def __init__(self, annotations, *args, **kwargs):
def __init__(self, annotations, dtype, *args, **kwargs):
# converts dicts to argument packs
if len(args) == 1 and kwargs == {} and isinstance(args[0], dict):
self.__entries = args[0]
Expand All @@ -56,7 +58,12 @@ def __init__(self, annotations, *args, **kwargs):
for k, v in self.__entries.items():
self.__entries[k] = v if in_python_scope() else impl.expr_init(v)
self._register_members()
self.__dtype = None
self.__dtype = dtype
self.__argpack = impl.get_runtime().prog.create_argpack(self.__dtype)

def __del__(self):
if impl is not None and impl.get_runtime() is not None and impl.get_runtime().prog is not None:
impl.get_runtime().prog.delete_argpack(self.__argpack)

@property
def keys(self):
Expand Down Expand Up @@ -181,7 +188,7 @@ class _IntermediateArgPack(ArgPack):
entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
"""

def __init__(self, annotations, *args, **kwargs):
def __init__(self, annotations, dtype, *args, **kwargs):
# converts dicts to argument packs
if len(args) == 1 and kwargs == {} and isinstance(args[0], dict):
self._ArgPack__entries = args[0]
Expand All @@ -195,7 +202,8 @@ def __init__(self, annotations, *args, **kwargs):
raise TaichiSyntaxError("ArgPack annotations keys not equals to entries keys.")
self._ArgPack__annotations = annotations
self._register_members()
self._ArgPack__dtype = None
self._ArgPack__dtype = dtype
self._ArgPack__argpack = impl.get_runtime().prog.create_argpack(dtype)


class ArgPackType(CompoundType):
Expand Down Expand Up @@ -263,10 +271,8 @@ def __call__(self, *args, **kwargs):

d[name] = data

entries = ArgPack(self.members, d)
entries._ArgPack__dtype = self.dtype
entries = ArgPack(self.members, self.dtype, d)
pack = self.cast(entries)
pack._ArgPack__dtype = self.dtype
return pack

def __instancecheck__(self, instance):
Expand Down Expand Up @@ -308,8 +314,7 @@ def cast(self, pack):
entries[k] = int(v) if dtype in primitive_types.integer_types else float(v)
else:
entries[k] = ops.cast(pack._ArgPack__entries[k], dtype)
pack = ArgPack(self.members, entries)
pack._ArgPack__dtype = self.dtype
pack = ArgPack(self.members, self.dtype, entries)
return pack

def from_taichi_object(self, arg_load_dict: dict):
Expand All @@ -318,7 +323,7 @@ def from_taichi_object(self, arg_load_dict: dict):
for index, pair in enumerate(items):
name, dtype = pair
d[name] = arg_load_dict[name]
pack = _IntermediateArgPack(self.members, d)
pack = _IntermediateArgPack(self.members, self.dtype, d)
pack._ArgPack__dtype = self.dtype
return pack

Expand Down
46 changes: 46 additions & 0 deletions taichi/program/argpack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <numeric>

#include "taichi/program/argpack.h"
#include "taichi/program/program.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/llvm/llvm_context.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#endif

namespace taichi::lang {

ArgPack::ArgPack(Program *prog, const DataType type) : prog_(prog) {
auto *old_type = type->get_type()->as<ArgPackType>();
auto [argpack_type, alloc_size] = prog->get_argpack_type_with_data_layout(
old_type, prog->get_kernel_argument_data_layout());
dtype = DataType(argpack_type);
argpack_alloc_ =
prog->allocate_memory_on_device(alloc_size, prog->result_buffer);
}

ArgPack::~ArgPack() {
if (prog_) {
argpack_alloc_.device->dealloc_memory(argpack_alloc_);
}
}

intptr_t ArgPack::get_device_allocation_ptr_as_int() const {
// taichi's own argpack's ptr points to its |DeviceAllocation| on the
// specified device.
return reinterpret_cast<intptr_t>(&argpack_alloc_);
}

DeviceAllocation ArgPack::get_device_allocation() const {
return argpack_alloc_;
}

std::size_t ArgPack::get_nelement() const {
return dtype->as<ArgPackType>()->elements().size();
}

DataType ArgPack::get_data_type() const {
return dtype;
}

} // namespace taichi::lang
34 changes: 34 additions & 0 deletions taichi/program/argpack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <cstdint>
#include <vector>

#include "taichi/inc/constants.h"
#include "taichi/ir/type_utils.h"
#include "taichi/rhi/device.h"

namespace taichi::lang {

class Program;

class TI_DLL_EXPORT ArgPack {
public:
/* Constructs a ArgPack managed by Program.
* Memory allocation and deallocation is handled by Program.
*/
explicit ArgPack(Program *prog, const DataType type);

DeviceAllocation argpack_alloc_{kDeviceNullAllocation};
DataType dtype;

DataType get_data_type() const;
intptr_t get_device_allocation_ptr_as_int() const;
DeviceAllocation get_device_allocation() const;
std::size_t get_nelement() const;

~ArgPack();

private:
Program *prog_{nullptr};
};
} // namespace taichi::lang
25 changes: 25 additions & 0 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,13 @@ Ndarray *Program::create_ndarray(const DataType type,
return arr_ptr;
}

ArgPack *Program::create_argpack(const DataType dt) {
auto pack = std::make_unique<ArgPack>(this, dt);
auto pack_ptr = pack.get();
argpacks_.insert({pack_ptr, std::move(pack)});
return pack_ptr;
}

void Program::delete_ndarray(Ndarray *ndarray) {
// [Note] Ndarray memory deallocation
// Ndarray's memory allocation is managed by Taichi and Python can control
Expand All @@ -416,6 +423,24 @@ void Program::delete_ndarray(Ndarray *ndarray) {
}
}

void Program::delete_argpack(ArgPack *argpack) {
// [Note] Argpack memory deallocation
// Argpack's memory allocation is managed by Taichi and Python can control
// this via Taichi indirectly. For example, when an argpack is GC-ed in
// Python, it signals Taichi to free its memory allocation. But Taichi will
// make sure **no pending kernels to be executed needs the argpack** before it
// actually frees the memory. When `ti.reset()` is called, all argpack
// allocated in this program should be gone and no longer valid in Python.
// This isn't the best implementation, argpacks should be managed by taichi
// runtime instead of this giant program and it should be freed when:
// - Python GC signals taichi that it's no longer useful
// - All kernels using it are executed.
if (argpacks_.count(argpack) &&
!program_impl_->used_in_kernel(argpack->argpack_alloc_.alloc_id)) {
argpacks_.erase(argpack);
}
}

Texture *Program::create_texture(BufferFormat buffer_format,
const std::vector<int> &shape) {
if (shape.size() == 1) {
Expand Down
14 changes: 13 additions & 1 deletion taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "taichi/ir/type_factory.h"
#include "taichi/ir/snode.h"
#include "taichi/util/lang_util.h"
#include "taichi/program/argpack.h"
#include "taichi/program/program_impl.h"
#include "taichi/program/callable.h"
#include "taichi/program/function.h"
Expand Down Expand Up @@ -255,6 +256,8 @@ class TI_DLL_EXPORT Program {
ExternalArrayLayout layout = ExternalArrayLayout::kNull,
bool zero_fill = false);

ArgPack *create_argpack(const DataType dt);

std::string get_kernel_return_data_layout() {
return program_impl_->get_kernel_return_data_layout();
};
Expand All @@ -269,8 +272,16 @@ class TI_DLL_EXPORT Program {
return program_impl_->get_struct_type_with_data_layout(old_ty, layout);
}

std::pair<const ArgPackType *, size_t> get_argpack_type_with_data_layout(
const ArgPackType *old_ty,
const std::string &layout) {
return program_impl_->get_argpack_type_with_data_layout(old_ty, layout);
}

void delete_ndarray(Ndarray *ndarray);

void delete_argpack(ArgPack *argpack);

Texture *create_texture(BufferFormat buffer_format,
const std::vector<int> &shape);

Expand Down Expand Up @@ -335,8 +346,9 @@ class TI_DLL_EXPORT Program {
static std::atomic<int> num_instances_;
bool finalized_{false};

// TODO: Move ndarrays_ and textures_ to be managed by runtime
// TODO: Move ndarrays_, argpacks_ and textures_ to be managed by runtime
std::unordered_map<void *, std::unique_ptr<Ndarray>> ndarrays_;
std::unordered_map<void *, std::unique_ptr<ArgPack>> argpacks_;
std::vector<std::unique_ptr<Texture>> textures_;
};

Expand Down
6 changes: 6 additions & 0 deletions taichi/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ class ProgramImpl {
return {old_ty, 0};
}

virtual std::pair<const ArgPackType *, size_t>
get_argpack_type_with_data_layout(const ArgPackType *old_ty,
const std::string &layout) {
return {old_ty, 0};
}

KernelCompilationManager &get_kernel_compilation_manager();

KernelLauncher &get_kernel_launcher();
Expand Down
14 changes: 14 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ void export_lang(py::module &m) {
py::arg("layout") = ExternalArrayLayout::kNull,
py::arg("zero_fill") = false, py::return_value_policy::reference)
.def("delete_ndarray", &Program::delete_ndarray)
.def(
"create_argpack",
[&](Program *program, const DataType &dt) -> ArgPack * {
return program->create_argpack(dt);
},
py::arg("dt"), py::return_value_policy::reference)
.def("delete_argpack", &Program::delete_argpack)
.def(
"create_texture",
[&](Program *program, BufferFormat fmt, const std::vector<int> &shape)
Expand Down Expand Up @@ -577,6 +584,13 @@ void export_lang(py::module &m) {
.def_readonly("dtype", &Ndarray::dtype)
.def_readonly("shape", &Ndarray::shape);

py::class_<ArgPack>(m, "ArgPack")
.def("device_allocation_ptr", &ArgPack::get_device_allocation_ptr_as_int)
.def("device_allocation", &ArgPack::get_device_allocation)
.def("nelement", &ArgPack::get_nelement)
.def("data_type", &ArgPack::get_data_type)
.def_readonly("dtype", &ArgPack::dtype);

py::enum_<BufferFormat>(m, "Format")
#define PER_BUFFER_FORMAT(x) .value(#x, BufferFormat::x)
#include "taichi/inc/rhi_constants.inc.h"
Expand Down
74 changes: 74 additions & 0 deletions taichi/runtime/gfx/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,5 +856,79 @@ GfxRuntime::get_struct_type_with_data_layout_impl(
bytes, align};
}

std::pair<const lang::ArgPackType *, size_t>
GfxRuntime::get_argpack_type_with_data_layout(const lang::ArgPackType *old_ty,
const std::string &layout) {
auto [new_ty, size, align] =
get_argpack_type_with_data_layout_impl(old_ty, layout);
return {new_ty, size};
}

std::tuple<const lang::ArgPackType *, size_t, size_t>
GfxRuntime::get_argpack_type_with_data_layout_impl(
const lang::ArgPackType *old_ty,
const std::string &layout) {
TI_TRACE("get_argpack_type_with_data_layout: {}", layout);
TI_ASSERT(layout.size() == 2);
auto is_430 = layout[0] == '4';
auto has_buffer_ptr = layout[1] == 'b';
auto members = old_ty->elements();
size_t bytes = 0;
size_t align = 0;
for (int i = 0; i < members.size(); i++) {
auto &member = members[i];
size_t member_align;
size_t member_size;
if (auto struct_type = member.type->cast<lang::StructType>()) {
auto [new_ty, size, member_align_] =
get_struct_type_with_data_layout_impl(struct_type, layout);
members[i].type = new_ty;
member_align = member_align_;
member_size = size;
} else if (auto tensor_type = member.type->cast<lang::TensorType>()) {
size_t element_size = data_type_size_gfx(tensor_type->get_element_type());
size_t num_elements = tensor_type->get_num_elements();
if (!is_430) {
if (num_elements == 2) {
member_align = element_size * 2;
} else {
member_align = element_size * 4;
}
member_size = member_align;
} else {
member_align = element_size;
member_size = tensor_type->get_num_elements() * element_size;
}
} else if (auto pointer_type = member.type->cast<PointerType>()) {
if (has_buffer_ptr) {
member_size = sizeof(uint64_t);
member_align = member_size;
} else {
// Use u32 as placeholder
member_size = sizeof(uint32_t);
member_align = member_size;
}
} else {
TI_ASSERT(member.type->is<PrimitiveType>());
member_size = data_type_size_gfx(member.type);
member_align = member_size;
}
bytes = align_up(bytes, member_align);
members[i].offset = bytes;
bytes += member_size;
align = std::max(align, member_align);
}

if (!is_430) {
align = align_up(align, sizeof(float) * 4);
bytes = align_up(bytes, 4 * sizeof(float));
}
TI_TRACE(" total_bytes={}", bytes);
return {TypeFactory::get_instance()
.get_argpack_type(members, layout)
->as<lang::ArgPackType>(),
bytes, align};
}

} // namespace gfx
} // namespace taichi::lang
8 changes: 8 additions & 0 deletions taichi/runtime/gfx/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ class TI_DLL_EXPORT GfxRuntime {
get_struct_type_with_data_layout_impl(const lang::StructType *old_ty,
const std::string &layout);

static std::pair<const lang::ArgPackType *, size_t>
get_argpack_type_with_data_layout(const lang::ArgPackType *old_ty,
const std::string &layout);

static std::tuple<const lang::ArgPackType *, size_t, size_t>
get_argpack_type_with_data_layout_impl(const lang::ArgPackType *old_ty,
const std::string &layout);

private:
friend class taichi::lang::gfx::SNodeTreeManager;

Expand Down
Loading
Loading