Skip to content

Commit

Permalink
[lang] Add is_argpack property to Expressions and Statements
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
listerily committed Jul 5, 2023
1 parent 20b5fea commit 9e3d3c2
Show file tree
Hide file tree
Showing 21 changed files with 139 additions and 115 deletions.
6 changes: 3 additions & 3 deletions cpp_examples/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,19 @@ void autograd() {
auto *i = builder.get_loop_index(loop);

auto *ext_a = builder.create_external_ptr(
builder.create_arg_load({0}, PrimitiveType::f32, true), {i});
builder.create_arg_load({0}, PrimitiveType::f32, true, false), {i});
auto *a_grad_i = builder.create_global_load(
builder.create_global_ptr(a->get_adjoint(), {i}));
builder.create_global_store(ext_a, a_grad_i);

auto *ext_b = builder.create_external_ptr(
builder.create_arg_load({1}, PrimitiveType::f32, true), {i});
builder.create_arg_load({1}, PrimitiveType::f32, true, false), {i});
auto *b_grad_i = builder.create_global_load(
builder.create_global_ptr(b->get_adjoint(), {i}));
builder.create_global_store(ext_b, b_grad_i);

auto *ext_c = builder.create_external_ptr(
builder.create_arg_load({2}, PrimitiveType::f32, true), {i});
builder.create_arg_load({2}, PrimitiveType::f32, true, false), {i});
auto *c_i = builder.create_global_load(builder.create_global_ptr(c, {i}));
builder.create_global_store(ext_c, c_i);
}
Expand Down
2 changes: 1 addition & 1 deletion cpp_examples/run_snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void run_snode() {
auto _ = builder.get_loop_guard(loop);
auto *index = builder.get_loop_index(loop);
auto *ext = builder.create_external_ptr(
builder.create_arg_load({0}, PrimitiveType::i32, true), {index});
builder.create_arg_load({0}, PrimitiveType::i32, true, false), {index});
auto *place_index =
builder.create_global_load(builder.create_global_ptr(place, {index}));
builder.create_global_store(ext, place_index);
Expand Down
13 changes: 7 additions & 6 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,13 +602,13 @@ def build_FunctionDef(ctx, node):
assert args.kw_defaults == []
assert args.kwarg is None

def decl_and_create_variable(annotation, name, arg_features):
def decl_and_create_variable(annotation, name, arg_features, is_argpack):
if not isinstance(annotation, primitive_types.RefType):
ctx.kernel_args.append(name)
if isinstance(annotation, ArgPackType):
d = {}
for j, (_name, anno) in enumerate(annotation.members.items()):
d[_name] = decl_and_create_variable(anno, _name, arg_features[j])
d[_name] = decl_and_create_variable(anno, _name, arg_features[j], True)
return kernel_arguments.decl_argpack_arg(annotation, d)
if isinstance(annotation, annotations.template):
return ctx.global_vars[name]
Expand All @@ -635,10 +635,10 @@ def decl_and_create_variable(annotation, name, arg_features):
name,
)
if isinstance(annotation, MatrixType):
return kernel_arguments.decl_matrix_arg(annotation, name)
return kernel_arguments.decl_matrix_arg(annotation, name, is_argpack)
if isinstance(annotation, StructType):
return kernel_arguments.decl_struct_arg(annotation, name)
return kernel_arguments.decl_scalar_arg(annotation, name)
return kernel_arguments.decl_struct_arg(annotation, name, is_argpack)
return kernel_arguments.decl_scalar_arg(annotation, name, is_argpack)

def transform_as_kernel():
# Treat return type
Expand All @@ -651,7 +651,7 @@ def transform_as_kernel():
d = {}
kernel_arguments.push_argpack_arg(ctx.func.arguments[i].name)
for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j], True)
ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
else:
ctx.create_variable(
Expand All @@ -660,6 +660,7 @@ def transform_as_kernel():
ctx.func.arguments[i].annotation,
ctx.func.arguments[i].name,
ctx.arg_features[i] if ctx.arg_features is not None else None,
False,
),
)

Expand Down
20 changes: 10 additions & 10 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def subscript(self, i, j):
return SparseMatrixEntry(self.ptr, i, j, self.dtype)


def decl_scalar_arg(dtype, name):
def decl_scalar_arg(dtype, name, is_argpack):
is_ref = False
if isinstance(dtype, RefType):
is_ref = True
Expand All @@ -56,7 +56,7 @@ def decl_scalar_arg(dtype, name):
arg_id = impl.get_runtime().compiling_callable.insert_pointer_param(dtype, name)
else:
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(dtype, name)
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref, True, is_argpack))


def get_type_for_kernel_args(dtype, name):
Expand All @@ -80,17 +80,17 @@ def get_type_for_kernel_args(dtype, name):
return dtype


def decl_matrix_arg(matrixtype, name):
def decl_matrix_arg(matrixtype, name, is_argpack):
arg_type = get_type_for_kernel_args(matrixtype, name)
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False))
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, is_argpack=is_argpack))
return matrixtype.from_taichi_object(arg_load)


def decl_struct_arg(structtype, name):
def decl_struct_arg(structtype, name, is_argpack):
arg_type = get_type_for_kernel_args(structtype, name)
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False))
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, is_argpack=is_argpack))
return structtype.from_taichi_object(arg_load)


Expand All @@ -108,25 +108,25 @@ def decl_sparse_matrix(dtype, name):
ptr_type = cook_dtype(u64)
# Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(ptr_type, name)
return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)
return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type, False, True, False), value_type)


def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, False, boundary))


def decl_texture_arg(num_dimensions, name):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().compiling_callable.insert_texture_param(num_dimensions, name)
return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions), num_dimensions)
return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions, False), num_dimensions)


def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().compiling_callable.insert_rw_texture_param(num_dimensions, buffer_format, name)
return RWTextureAccessor(
_ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, buffer_format, lod),
_ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, False, buffer_format, lod),
num_dimensions,
)

Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void ArgLoadExpression::type_check(const CompileConfig *) {

void ArgLoadExpression::flatten(FlattenContext *ctx) {
auto arg_load =
std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr, create_load);
std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr, create_load, is_argpack);
arg_load->ret_type = ret_type;
ctx->push_back(std::move(arg_load));
stmt = ctx->back_stmt();
Expand All @@ -163,7 +163,7 @@ void TexturePtrExpression::type_check(const CompileConfig *config) {

void TexturePtrExpression::flatten(FlattenContext *ctx) {
ctx->push_back<ArgLoadStmt>(arg_id, PrimitiveType::f32, /*is_ptr=*/true,
/*create_load*/ true);
/*create_load=*/true, /*is_argpack=*/is_argpack);
ctx->push_back<TexturePtrStmt>(ctx->back_stmt(), num_dims, is_storage, format,
lod);
stmt = ctx->back_stmt();
Expand Down Expand Up @@ -610,7 +610,7 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad);

auto ptr = Stmt::make<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
/*create_load=*/false);
/*create_load=*/false, /*is_argpack=*/is_argpack);

ptr->tb = tb;
ctx->push_back(std::move(ptr));
Expand Down
23 changes: 18 additions & 5 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,15 @@ class ArgLoadExpression : public Expression {
*/
bool create_load;

bool is_argpack;

ArgLoadExpression(const std::vector<int> &arg_id,
DataType dt,
bool is_ptr = false,
bool create_load = true)
: arg_id(arg_id), dt(dt), is_ptr(is_ptr), create_load(create_load) {
bool create_load = true,
bool is_argpack = false)
: arg_id(arg_id), dt(dt), is_ptr(is_ptr), create_load(create_load),
is_argpack(is_argpack) {
}

void type_check(const CompileConfig *config) override;
Expand All @@ -349,26 +353,30 @@ class TexturePtrExpression : public Expression {
const std::vector<int> arg_id;
int num_dims;
bool is_storage{false};
bool is_argpack;

// Optional, for storage textures
BufferFormat format{BufferFormat::unknown};
int lod{0};

explicit TexturePtrExpression(const std::vector<int> &arg_id, int num_dims)
explicit TexturePtrExpression(const std::vector<int> &arg_id, int num_dims, bool is_argpack)
: arg_id(arg_id),
num_dims(num_dims),
is_storage(false),
is_argpack(is_argpack),
format(BufferFormat::rgba8),
lod(0) {
}

TexturePtrExpression(const std::vector<int> &arg_id,
int num_dims,
bool is_argpack,
BufferFormat format,
int lod)
: arg_id(arg_id),
num_dims(num_dims),
is_storage(true),
is_argpack(is_argpack),
format(format),
lod(lod) {
}
Expand Down Expand Up @@ -480,19 +488,22 @@ class ExternalTensorExpression : public Expression {
std::vector<int> arg_id;
bool needs_grad{false};
bool is_grad{false};
bool is_argpack{false};
BoundaryMode boundary{BoundaryMode::kUnsafe};

ExternalTensorExpression(const DataType &dt,
int ndim,
const std::vector<int> &arg_id,
bool needs_grad = false,
bool is_argpack = false,
BoundaryMode boundary = BoundaryMode::kUnsafe) {
init(dt, ndim, arg_id, needs_grad, boundary);
init(dt, ndim, arg_id, needs_grad, is_argpack, boundary);
}

explicit ExternalTensorExpression(Expr *expr) : is_grad(true) {
auto ptr = expr->cast<ExternalTensorExpression>();
init(ptr->dt, ptr->ndim, ptr->arg_id, ptr->needs_grad, ptr->boundary);
init(ptr->dt, ptr->ndim, ptr->arg_id, ptr->needs_grad, ptr->is_argpack,
ptr->boundary);
}

void flatten(FlattenContext *ctx) override;
Expand All @@ -517,11 +528,13 @@ class ExternalTensorExpression : public Expression {
int ndim,
const std::vector<int> &arg_id,
bool needs_grad,
bool is_argpack,
BoundaryMode boundary) {
this->dt = dt;
this->ndim = ndim;
this->arg_id = arg_id;
this->needs_grad = needs_grad;
this->is_argpack = is_argpack;
this->boundary = boundary;
}
};
Expand Down
11 changes: 7 additions & 4 deletions taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,11 @@ RandStmt *IRBuilder::create_rand(DataType value_type) {

ArgLoadStmt *IRBuilder::create_arg_load(const std::vector<int> &arg_id,
DataType dt,
bool is_ptr) {
bool is_ptr,
bool is_argpack) {
return insert(
Stmt::make_typed<ArgLoadStmt>(arg_id, dt, is_ptr, /*create_load*/ true));
Stmt::make_typed<ArgLoadStmt>(arg_id, dt, is_ptr, /*create_load*/ true,
is_argpack));
}

ReturnStmt *IRBuilder::create_return(Stmt *value) {
Expand Down Expand Up @@ -501,11 +503,12 @@ MeshPatchIndexStmt *IRBuilder::get_patch_index() {
}
ArgLoadStmt *IRBuilder::create_ndarray_arg_load(const std::vector<int> &arg_id,
DataType dt,
int ndim) {
int ndim,
bool is_argpack) {
auto type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim);

return insert(Stmt::make_typed<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
/*create_load=*/false));
/*create_load=*/false, /*is_argpack=*/is_argpack));
}

} // namespace taichi::lang
6 changes: 4 additions & 2 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,13 @@ class IRBuilder {
// Load kernel arguments.
ArgLoadStmt *create_arg_load(const std::vector<int> &arg_id,
DataType dt,
bool is_ptr);
bool is_ptr,
bool is_argpack);
// Load kernel arguments.
ArgLoadStmt *create_ndarray_arg_load(const std::vector<int> &arg_id,
DataType dt,
int total_dim);
int total_dim,
bool is_argpack);

// The return value of the kernel.
ReturnStmt *create_return(Stmt *value);
Expand Down
12 changes: 8 additions & 4 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ class UnaryOpStmt : public Stmt {
/**
* Load a kernel argument. The data type should be known when constructing this
* statement. |is_ptr| should be true iff the result can be used as a base
* pointer of an ExternalPtrStmt.
* pointer of an ExternalPtrStmt. |is_argpack| should be true iif the value is
* in an argpack.
*/
class ArgLoadStmt : public Stmt {
public:
Expand All @@ -192,11 +193,14 @@ class ArgLoadStmt : public Stmt {

bool create_load;

bool is_argpack;

ArgLoadStmt(const std::vector<int> &arg_id,
const DataType &dt,
bool is_ptr,
bool create_load)
: arg_id(arg_id), is_ptr(is_ptr), create_load(create_load) {
bool create_load,
bool is_argpack)
: arg_id(arg_id), is_ptr(is_ptr), create_load(create_load), is_argpack(is_argpack) {
this->ret_type = dt;
TI_STMT_REG_FIELDS;
}
Expand All @@ -205,7 +209,7 @@ class ArgLoadStmt : public Stmt {
return false;
}

TI_STMT_DEF_FIELDS(ret_type, arg_id, is_ptr);
TI_STMT_DEF_FIELDS(ret_type, arg_id, is_ptr, is_argpack);
TI_DEFINE_ACCEPT_AND_CLONE
};

Expand Down
11 changes: 6 additions & 5 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,14 +951,15 @@ void export_lang(py::module &m) {

m.def("make_arg_load_expr",
Expr::make<ArgLoadExpression, const std::vector<int> &,
const DataType &, bool, bool>,
"arg_id"_a, "dt"_a, "is_ptr"_a = false, "create_load"_a = true);
const DataType &, bool, bool, bool>,
"arg_id"_a, "dt"_a, "is_ptr"_a = false, "create_load"_a = true,
"is_argpack"_a = false);

m.def("make_reference", Expr::make<ReferenceExpression, const Expr &>);

m.def("make_external_tensor_expr",
Expr::make<ExternalTensorExpression, const DataType &, int,
const std::vector<int> &, bool, const BoundaryMode &>);
const std::vector<int> &, bool, bool, const BoundaryMode &>);

m.def("make_external_tensor_grad_expr",
Expr::make<ExternalTensorExpression, Expr *>);
Expand All @@ -975,9 +976,9 @@ void export_lang(py::module &m) {
Expr::make<ConstExpression, const DataType &, float64>);

m.def("make_texture_ptr_expr",
Expr::make<TexturePtrExpression, const std::vector<int> &, int>);
Expr::make<TexturePtrExpression, const std::vector<int> &, int, bool>);
m.def("make_rw_texture_ptr_expr",
Expr::make<TexturePtrExpression, const std::vector<int> &, int,
Expr::make<TexturePtrExpression, const std::vector<int> &, int, bool,
const BufferFormat &, int>);

auto &&texture =
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ class Scalarize : public BasicStmtVisitor {
auto ret_type = stmt->ret_type.ptr_removed().get_element_type();
ret_type = TypeFactory::get_instance().get_pointer_type(ret_type);
auto arg_load = std::make_unique<ArgLoadStmt>(
stmt->arg_id, ret_type, stmt->is_ptr, stmt->create_load);
stmt->arg_id, ret_type, stmt->is_ptr, stmt->create_load, stmt->is_argpack);

immediate_modifier_.replace_usages_with(stmt, arg_load.get());

Expand Down
Loading

0 comments on commit 9e3d3c2

Please sign in to comment.