From 9e3d3c24de5aa4a2cd1727b82e1175eb9957982d Mon Sep 17 00:00:00 2001 From: listerily Date: Wed, 5 Jul 2023 12:55:30 +0800 Subject: [PATCH] [lang] Add is_argpack property to Expressions and Statements [ghstack-poisoned] --- cpp_examples/autograd.cpp | 6 +- cpp_examples/run_snode.cpp | 2 +- python/taichi/lang/ast/ast_transformer.py | 13 +-- python/taichi/lang/kernel_arguments.py | 20 ++--- taichi/ir/frontend_ir.cpp | 6 +- taichi/ir/frontend_ir.h | 23 ++++-- taichi/ir/ir_builder.cpp | 11 ++- taichi/ir/ir_builder.h | 6 +- taichi/ir/statements.h | 12 ++- taichi/python/export_lang.cpp | 11 +-- taichi/transforms/scalarize.cpp | 2 +- tests/cpp/analysis/alias_analysis_test.cpp | 32 ++++---- tests/cpp/ir/ir_builder_test.cpp | 4 +- tests/cpp/ir/ir_type_promotion_test.cpp | 2 +- tests/cpp/ir/ndarray_kernel.cpp | 6 +- .../transforms/binary_op_simplify_test.cpp | 4 +- tests/cpp/transforms/constant_fold_test.cpp | 80 +++++++++---------- .../determine_ad_stack_size_test.cpp | 4 +- .../transforms/half2_vectorization_test.cpp | 2 +- tests/cpp/transforms/inlining_test.cpp | 4 +- tests/cpp/transforms/scalarize_test.cpp | 4 +- 21 files changed, 139 insertions(+), 115 deletions(-) diff --git a/cpp_examples/autograd.cpp b/cpp_examples/autograd.cpp index 411ffd9a1420d..27c1be3d054c7 100644 --- a/cpp_examples/autograd.cpp +++ b/cpp_examples/autograd.cpp @@ -159,19 +159,19 @@ void autograd() { auto *i = builder.get_loop_index(loop); auto *ext_a = builder.create_external_ptr( - builder.create_arg_load({0}, PrimitiveType::f32, true), {i}); + builder.create_arg_load({0}, PrimitiveType::f32, true, false), {i}); auto *a_grad_i = builder.create_global_load( builder.create_global_ptr(a->get_adjoint(), {i})); builder.create_global_store(ext_a, a_grad_i); auto *ext_b = builder.create_external_ptr( - builder.create_arg_load({1}, PrimitiveType::f32, true), {i}); + builder.create_arg_load({1}, PrimitiveType::f32, true, false), {i}); auto *b_grad_i = builder.create_global_load( builder.create_global_ptr(b->get_adjoint(), {i})); builder.create_global_store(ext_b, b_grad_i); auto *ext_c = builder.create_external_ptr( - builder.create_arg_load({2}, PrimitiveType::f32, true), {i}); + builder.create_arg_load({2}, PrimitiveType::f32, true, false), {i}); auto *c_i = builder.create_global_load(builder.create_global_ptr(c, {i})); builder.create_global_store(ext_c, c_i); } diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp index f05a03f894afb..b8d6e6e39e231 100644 --- a/cpp_examples/run_snode.cpp +++ b/cpp_examples/run_snode.cpp @@ -116,7 +116,7 @@ void run_snode() { auto _ = builder.get_loop_guard(loop); auto *index = builder.get_loop_index(loop); auto *ext = builder.create_external_ptr( - builder.create_arg_load({0}, PrimitiveType::i32, true), {index}); + builder.create_arg_load({0}, PrimitiveType::i32, true, false), {index}); auto *place_index = builder.create_global_load(builder.create_global_ptr(place, {index})); builder.create_global_store(ext, place_index); diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 7d3514e57a29e..4aff34d0b0af1 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -602,13 +602,13 @@ def build_FunctionDef(ctx, node): assert args.kw_defaults == [] assert args.kwarg is None - def decl_and_create_variable(annotation, name, arg_features): + def decl_and_create_variable(annotation, name, arg_features, is_argpack): if not isinstance(annotation, primitive_types.RefType): ctx.kernel_args.append(name) if isinstance(annotation, ArgPackType): d = {} for j, (_name, anno) in enumerate(annotation.members.items()): - d[_name] = decl_and_create_variable(anno, _name, arg_features[j]) + d[_name] = decl_and_create_variable(anno, _name, arg_features[j], True) return kernel_arguments.decl_argpack_arg(annotation, d) if isinstance(annotation, annotations.template): return ctx.global_vars[name] @@ -635,10 +635,10 @@ def decl_and_create_variable(annotation, name, arg_features): name, ) if isinstance(annotation, MatrixType): - return kernel_arguments.decl_matrix_arg(annotation, name) + return kernel_arguments.decl_matrix_arg(annotation, name, is_argpack) if isinstance(annotation, StructType): - return kernel_arguments.decl_struct_arg(annotation, name) - return kernel_arguments.decl_scalar_arg(annotation, name) + return kernel_arguments.decl_struct_arg(annotation, name, is_argpack) + return kernel_arguments.decl_scalar_arg(annotation, name, is_argpack) def transform_as_kernel(): # Treat return type @@ -651,7 +651,7 @@ def transform_as_kernel(): d = {} kernel_arguments.push_argpack_arg(ctx.func.arguments[i].name) for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()): - d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j]) + d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j], True) ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d)) else: ctx.create_variable( @@ -660,6 +660,7 @@ def transform_as_kernel(): ctx.func.arguments[i].annotation, ctx.func.arguments[i].name, ctx.arg_features[i] if ctx.arg_features is not None else None, + False, ), ) diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index 08cd8d4fb3243..99737364f85d9 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -46,7 +46,7 @@ def subscript(self, i, j): return SparseMatrixEntry(self.ptr, i, j, self.dtype) -def decl_scalar_arg(dtype, name): +def decl_scalar_arg(dtype, name, is_argpack): is_ref = False if isinstance(dtype, RefType): is_ref = True @@ -56,7 +56,7 @@ def decl_scalar_arg(dtype, name): arg_id = impl.get_runtime().compiling_callable.insert_pointer_param(dtype, name) else: arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(dtype, name) - return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref)) + return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref, True, is_argpack)) def get_type_for_kernel_args(dtype, name): @@ -80,17 +80,17 @@ def get_type_for_kernel_args(dtype, name): return dtype -def decl_matrix_arg(matrixtype, name): +def decl_matrix_arg(matrixtype, name, is_argpack): arg_type = get_type_for_kernel_args(matrixtype, name) arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name) - arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False)) + arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, is_argpack=is_argpack)) return matrixtype.from_taichi_object(arg_load) -def decl_struct_arg(structtype, name): +def decl_struct_arg(structtype, name, is_argpack): arg_type = get_type_for_kernel_args(structtype, name) arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name) - arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False)) + arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, is_argpack=is_argpack)) return structtype.from_taichi_object(arg_load) @@ -108,25 +108,25 @@ def decl_sparse_matrix(dtype, name): ptr_type = cook_dtype(u64) # Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(ptr_type, name) - return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type) + return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type, False, True, False), value_type) def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary): arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad) - return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary)) + return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, False, boundary)) def decl_texture_arg(num_dimensions, name): # FIXME: texture_arg doesn't have element_shape so better separate them arg_id = impl.get_runtime().compiling_callable.insert_texture_param(num_dimensions, name) - return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions), num_dimensions) + return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions, False), num_dimensions) def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name): # FIXME: texture_arg doesn't have element_shape so better separate them arg_id = impl.get_runtime().compiling_callable.insert_rw_texture_param(num_dimensions, buffer_format, name) return RWTextureAccessor( - _ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, buffer_format, lod), + _ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, False, buffer_format, lod), num_dimensions, ) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index c5c741d983202..18906537181e1 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -152,7 +152,7 @@ void ArgLoadExpression::type_check(const CompileConfig *) { void ArgLoadExpression::flatten(FlattenContext *ctx) { auto arg_load = - std::make_unique(arg_id, dt, is_ptr, create_load); + std::make_unique(arg_id, dt, is_ptr, create_load, is_argpack); arg_load->ret_type = ret_type; ctx->push_back(std::move(arg_load)); stmt = ctx->back_stmt(); @@ -163,7 +163,7 @@ void TexturePtrExpression::type_check(const CompileConfig *config) { void TexturePtrExpression::flatten(FlattenContext *ctx) { ctx->push_back(arg_id, PrimitiveType::f32, /*is_ptr=*/true, - /*create_load*/ true); + /*create_load=*/true, /*is_argpack=*/is_argpack); ctx->push_back(ctx->back_stmt(), num_dims, is_storage, format, lod); stmt = ctx->back_stmt(); @@ -610,7 +610,7 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) { TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad); auto ptr = Stmt::make(arg_id, type, /*is_ptr=*/true, - /*create_load=*/false); + /*create_load=*/false, /*is_argpack=*/is_argpack); ptr->tb = tb; ctx->push_back(std::move(ptr)); diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 44834e2452150..b506a4b46c159 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -324,11 +324,15 @@ class ArgLoadExpression : public Expression { */ bool create_load; + bool is_argpack; + ArgLoadExpression(const std::vector &arg_id, DataType dt, bool is_ptr = false, - bool create_load = true) - : arg_id(arg_id), dt(dt), is_ptr(is_ptr), create_load(create_load) { + bool create_load = true, + bool is_argpack = false) + : arg_id(arg_id), dt(dt), is_ptr(is_ptr), create_load(create_load), + is_argpack(is_argpack) { } void type_check(const CompileConfig *config) override; @@ -349,26 +353,30 @@ class TexturePtrExpression : public Expression { const std::vector arg_id; int num_dims; bool is_storage{false}; + bool is_argpack; // Optional, for storage textures BufferFormat format{BufferFormat::unknown}; int lod{0}; - explicit TexturePtrExpression(const std::vector &arg_id, int num_dims) + explicit TexturePtrExpression(const std::vector &arg_id, int num_dims, bool is_argpack) : arg_id(arg_id), num_dims(num_dims), is_storage(false), + is_argpack(is_argpack), format(BufferFormat::rgba8), lod(0) { } TexturePtrExpression(const std::vector &arg_id, int num_dims, + bool is_argpack, BufferFormat format, int lod) : arg_id(arg_id), num_dims(num_dims), is_storage(true), + is_argpack(is_argpack), format(format), lod(lod) { } @@ -480,19 +488,22 @@ class ExternalTensorExpression : public Expression { std::vector arg_id; bool needs_grad{false}; bool is_grad{false}; + bool is_argpack{false}; BoundaryMode boundary{BoundaryMode::kUnsafe}; ExternalTensorExpression(const DataType &dt, int ndim, const std::vector &arg_id, bool needs_grad = false, + bool is_argpack = false, BoundaryMode boundary = BoundaryMode::kUnsafe) { - init(dt, ndim, arg_id, needs_grad, boundary); + init(dt, ndim, arg_id, needs_grad, is_argpack, boundary); } explicit ExternalTensorExpression(Expr *expr) : is_grad(true) { auto ptr = expr->cast(); - init(ptr->dt, ptr->ndim, ptr->arg_id, ptr->needs_grad, ptr->boundary); + init(ptr->dt, ptr->ndim, ptr->arg_id, ptr->needs_grad, ptr->is_argpack, + ptr->boundary); } void flatten(FlattenContext *ctx) override; @@ -517,11 +528,13 @@ class ExternalTensorExpression : public Expression { int ndim, const std::vector &arg_id, bool needs_grad, + bool is_argpack, BoundaryMode boundary) { this->dt = dt; this->ndim = ndim; this->arg_id = arg_id; this->needs_grad = needs_grad; + this->is_argpack = is_argpack; this->boundary = boundary; } }; diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index 9f00a4e6ee8b2..85653799c5bfc 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -180,9 +180,11 @@ RandStmt *IRBuilder::create_rand(DataType value_type) { ArgLoadStmt *IRBuilder::create_arg_load(const std::vector &arg_id, DataType dt, - bool is_ptr) { + bool is_ptr, + bool is_argpack) { return insert( - Stmt::make_typed(arg_id, dt, is_ptr, /*create_load*/ true)); + Stmt::make_typed(arg_id, dt, is_ptr, /*create_load*/ true, + is_argpack)); } ReturnStmt *IRBuilder::create_return(Stmt *value) { @@ -501,11 +503,12 @@ MeshPatchIndexStmt *IRBuilder::get_patch_index() { } ArgLoadStmt *IRBuilder::create_ndarray_arg_load(const std::vector &arg_id, DataType dt, - int ndim) { + int ndim, + bool is_argpack) { auto type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim); return insert(Stmt::make_typed(arg_id, type, /*is_ptr=*/true, - /*create_load=*/false)); + /*create_load=*/false, /*is_argpack=*/is_argpack)); } } // namespace taichi::lang diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index 41065df6ded47..a234c83b1385e 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -147,11 +147,13 @@ class IRBuilder { // Load kernel arguments. ArgLoadStmt *create_arg_load(const std::vector &arg_id, DataType dt, - bool is_ptr); + bool is_ptr, + bool is_argpack); // Load kernel arguments. ArgLoadStmt *create_ndarray_arg_load(const std::vector &arg_id, DataType dt, - int total_dim); + int total_dim, + bool is_argpack); // The return value of the kernel. ReturnStmt *create_return(Stmt *value); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 5151c9076f473..aa18aeb47f9fc 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -172,7 +172,8 @@ class UnaryOpStmt : public Stmt { /** * Load a kernel argument. The data type should be known when constructing this * statement. |is_ptr| should be true iff the result can be used as a base - * pointer of an ExternalPtrStmt. + * pointer of an ExternalPtrStmt. |is_argpack| should be true iif the value is + * in an argpack. */ class ArgLoadStmt : public Stmt { public: @@ -192,11 +193,14 @@ class ArgLoadStmt : public Stmt { bool create_load; + bool is_argpack; + ArgLoadStmt(const std::vector &arg_id, const DataType &dt, bool is_ptr, - bool create_load) - : arg_id(arg_id), is_ptr(is_ptr), create_load(create_load) { + bool create_load, + bool is_argpack) + : arg_id(arg_id), is_ptr(is_ptr), create_load(create_load), is_argpack(is_argpack) { this->ret_type = dt; TI_STMT_REG_FIELDS; } @@ -205,7 +209,7 @@ class ArgLoadStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type, arg_id, is_ptr); + TI_STMT_DEF_FIELDS(ret_type, arg_id, is_ptr, is_argpack); TI_DEFINE_ACCEPT_AND_CLONE }; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index b0629c8dfe93e..740c65f9b6571 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -951,14 +951,15 @@ void export_lang(py::module &m) { m.def("make_arg_load_expr", Expr::make &, - const DataType &, bool, bool>, - "arg_id"_a, "dt"_a, "is_ptr"_a = false, "create_load"_a = true); + const DataType &, bool, bool, bool>, + "arg_id"_a, "dt"_a, "is_ptr"_a = false, "create_load"_a = true, + "is_argpack"_a = false); m.def("make_reference", Expr::make); m.def("make_external_tensor_expr", Expr::make &, bool, const BoundaryMode &>); + const std::vector &, bool, bool, const BoundaryMode &>); m.def("make_external_tensor_grad_expr", Expr::make); @@ -975,9 +976,9 @@ void export_lang(py::module &m) { Expr::make); m.def("make_texture_ptr_expr", - Expr::make &, int>); + Expr::make &, int, bool>); m.def("make_rw_texture_ptr_expr", - Expr::make &, int, + Expr::make &, int, bool, const BufferFormat &, int>); auto &&texture = diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 1687b7a48bba3..887222f222421 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -702,7 +702,7 @@ class Scalarize : public BasicStmtVisitor { auto ret_type = stmt->ret_type.ptr_removed().get_element_type(); ret_type = TypeFactory::get_instance().get_pointer_type(ret_type); auto arg_load = std::make_unique( - stmt->arg_id, ret_type, stmt->is_ptr, stmt->create_load); + stmt->arg_id, ret_type, stmt->is_ptr, stmt->create_load, stmt->is_argpack); immediate_modifier_.replace_usages_with(stmt, arg_load.get()); diff --git a/tests/cpp/analysis/alias_analysis_test.cpp b/tests/cpp/analysis/alias_analysis_test.cpp index 3165b6bde262c..52e6bce9db860 100644 --- a/tests/cpp/analysis/alias_analysis_test.cpp +++ b/tests/cpp/analysis/alias_analysis_test.cpp @@ -95,8 +95,8 @@ TEST(AliasAnalysis, GlobalPtr_DiffSNodes) { TEST(AliasAnalysis, ExternalPtr_Same) { IRBuilder builder; - auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, false); + auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, false, false); const auto indices = std::vector{arg2, arg2}; auto *eptr1 = builder.create_external_ptr(arg1, indices); auto *eptr2 = builder.create_external_ptr(arg1, indices); @@ -107,8 +107,8 @@ TEST(AliasAnalysis, ExternalPtr_Same) { TEST(AliasAnalysis, ExternalPtr_Different) { IRBuilder builder; - auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, false); + auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, false, false); const auto indices1 = std::vector{arg2, builder.get_int32(1)}; const auto indices2 = std::vector{arg2, builder.get_int32(2)}; auto *eptr1 = builder.create_external_ptr(arg1, indices1); @@ -120,9 +120,9 @@ TEST(AliasAnalysis, ExternalPtr_Different) { TEST(AliasAnalysis, ExternalPtr_Uncertain) { IRBuilder builder; - auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, false); - auto *arg3 = builder.create_arg_load({3}, PrimitiveType::i32, false); + auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, false, false); + auto *arg3 = builder.create_arg_load({3}, PrimitiveType::i32, false, false); const auto indices1 = std::vector{arg2, arg2}; const auto indices2 = std::vector{arg2, arg3}; auto *eptr1 = builder.create_external_ptr(arg1, indices1); @@ -134,9 +134,9 @@ TEST(AliasAnalysis, ExternalPtr_Uncertain) { TEST(AliasAnalysis, ExternalPtr_DiffPtr) { IRBuilder builder; - auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, true); - auto *arg3 = builder.create_arg_load({3}, PrimitiveType::i32, false); + auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg2 = builder.create_arg_load({2}, PrimitiveType::i32, true, false); + auto *arg3 = builder.create_arg_load({3}, PrimitiveType::i32, false, false); const auto indices = std::vector{arg3, arg3}; auto *eptr1 = builder.create_external_ptr(arg1, indices); auto *eptr2 = builder.create_external_ptr(arg2, indices); @@ -147,9 +147,9 @@ TEST(AliasAnalysis, ExternalPtr_DiffPtr) { TEST(AliasAnalysis, ExternalPtr_GradSame) { IRBuilder builder; - auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg2 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg3 = builder.create_arg_load({2}, PrimitiveType::i32, false); + auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg2 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg3 = builder.create_arg_load({2}, PrimitiveType::i32, false, false); const auto indices = std::vector{arg3, arg3}; auto *eptr1 = builder.create_external_ptr(arg1, indices); auto *eptr2 = builder.create_external_ptr(arg2, indices); @@ -160,9 +160,9 @@ TEST(AliasAnalysis, ExternalPtr_GradSame) { TEST(AliasAnalysis, ExternalPtr_GradDiff) { IRBuilder builder; - auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg2 = builder.create_arg_load({1}, PrimitiveType::i32, true); - auto *arg3 = builder.create_arg_load({2}, PrimitiveType::i32, false); + auto *arg1 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg2 = builder.create_arg_load({1}, PrimitiveType::i32, true, false); + auto *arg3 = builder.create_arg_load({2}, PrimitiveType::i32, false, false); const auto indices = std::vector{arg3, arg3}; auto *eptr1 = builder.create_external_ptr(arg1, indices, /*is_grad=*/false); auto *eptr2 = builder.create_external_ptr(arg2, indices, /*is_grad=*/true); diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp index 9cbdd5ddb8e14..58052b6301f8c 100644 --- a/tests/cpp/ir/ir_builder_test.cpp +++ b/tests/cpp/ir/ir_builder_test.cpp @@ -97,7 +97,7 @@ TEST(IRBuilder, ExternalPtr) { array[0] = 2; array[2] = 40; auto *arg = - builder.create_ndarray_arg_load(/*arg_id=*/{0}, get_data_type(), 1); + builder.create_ndarray_arg_load(/*arg_id=*/{0}, get_data_type(), 1, false); auto *zero = builder.get_int32(0); auto *one = builder.get_int32(1); auto *two = builder.get_int32(2); @@ -174,7 +174,7 @@ TEST(IRBuilder, AtomicOp) { array[0] = 2; array[2] = 40; auto *arg = - builder.create_ndarray_arg_load(/*arg_id=*/{0}, get_data_type(), 1); + builder.create_ndarray_arg_load(/*arg_id=*/{0}, get_data_type(), 1, false); auto *zero = builder.get_int32(0); auto *one = builder.get_int32(1); auto *a0ptr = builder.create_external_ptr(arg, {zero}); diff --git a/tests/cpp/ir/ir_type_promotion_test.cpp b/tests/cpp/ir/ir_type_promotion_test.cpp index 063a7f4e06d97..1d8e918f0c164 100644 --- a/tests/cpp/ir/ir_type_promotion_test.cpp +++ b/tests/cpp/ir/ir_type_promotion_test.cpp @@ -11,7 +11,7 @@ TEST(IRTypePromotionTest, ShiftOp) { IRBuilder builder; // (u8)x << (i32)1 -> (u8)res - auto *lhs = builder.create_arg_load({0}, get_data_type(), false); + auto *lhs = builder.create_arg_load({0}, get_data_type(), false, false); builder.create_shl(lhs, builder.get_int32(1)); auto ir = builder.extract_ir(); diff --git a/tests/cpp/ir/ndarray_kernel.cpp b/tests/cpp/ir/ndarray_kernel.cpp index 9cc3f6b89b664..c02fc45d63bab 100644 --- a/tests/cpp/ir/ndarray_kernel.cpp +++ b/tests/cpp/ir/ndarray_kernel.cpp @@ -6,7 +6,7 @@ std::unique_ptr setup_kernel1(Program *prog) { IRBuilder builder1; { auto *arg = builder1.create_ndarray_arg_load(/*arg_id=*/{0}, - get_data_type(), 1); + get_data_type(), 1, false); auto *zero = builder1.get_int32(0); auto *one = builder1.get_int32(1); auto *two = builder1.get_int32(2); @@ -32,9 +32,9 @@ std::unique_ptr setup_kernel2(Program *prog) { { auto *arg0 = builder2.create_ndarray_arg_load(/*arg_id=*/{0}, - get_data_type(), 1); + get_data_type(), 1, false); auto *arg1 = builder2.create_arg_load(/*arg_id=*/{1}, get_data_type(), - /*is_ptr=*/false); + /*is_ptr=*/false, /*is_argpack=*/false); auto *one = builder2.get_int32(1); auto *a1ptr = builder2.create_external_ptr(arg0, {one}); builder2.create_global_store(a1ptr, arg1); // a[1] = arg1 diff --git a/tests/cpp/transforms/binary_op_simplify_test.cpp b/tests/cpp/transforms/binary_op_simplify_test.cpp index 183e9da9147e8..75178d2db86b1 100644 --- a/tests/cpp/transforms/binary_op_simplify_test.cpp +++ b/tests/cpp/transforms/binary_op_simplify_test.cpp @@ -19,7 +19,7 @@ class BinaryOpSimplifyTest : public ::testing::Test { TEST_F(BinaryOpSimplifyTest, MultiplyPOT) { IRBuilder builder; // (x * 32) << 3 - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *product = builder.create_mul(x, builder.get_int32(32)); auto *result = builder.create_shl(product, builder.get_int32(3)); builder.create_return(result); @@ -56,7 +56,7 @@ TEST_F(BinaryOpSimplifyTest, ModPOT) { IRBuilder builder; // x % 8 in the Python frontend is transformed into: // x - x / 8 * 8 - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *division = builder.create_div(x, builder.get_uint32(8)); auto *product = builder.create_mul(division, builder.get_uint32(8)); auto *result = builder.create_sub(x, product); diff --git a/tests/cpp/transforms/constant_fold_test.cpp b/tests/cpp/transforms/constant_fold_test.cpp index 750913c3abb60..3485f75a495ec 100644 --- a/tests/cpp/transforms/constant_fold_test.cpp +++ b/tests/cpp/transforms/constant_fold_test.cpp @@ -27,7 +27,7 @@ class ConstantFoldTest : public ::testing::Test { }; TEST_F(ConstantFoldTest, UnaryNeg) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_int32(1); auto *out = builder.create_neg(op); auto *result = builder.create_sub(x, out); @@ -40,7 +40,7 @@ TEST_F(ConstantFoldTest, UnaryNeg) { } TEST_F(ConstantFoldTest, UnarySqrt) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_int32(4); auto *out = builder.create_sqrt(op); auto *result = builder.create_sub(x, out); @@ -53,7 +53,7 @@ TEST_F(ConstantFoldTest, UnarySqrt) { } TEST_F(ConstantFoldTest, UnaryRound) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(3.4); auto *out = builder.create_round(op); auto *result = builder.create_sub(x, out); @@ -66,7 +66,7 @@ TEST_F(ConstantFoldTest, UnaryRound) { } TEST_F(ConstantFoldTest, UnaryFloor) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(3.4); auto *out = builder.create_floor(op); auto *result = builder.create_sub(x, out); @@ -79,7 +79,7 @@ TEST_F(ConstantFoldTest, UnaryFloor) { } TEST_F(ConstantFoldTest, UnaryCeil) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(3.4); auto *out = builder.create_ceil(op); auto *result = builder.create_sub(x, out); @@ -92,7 +92,7 @@ TEST_F(ConstantFoldTest, UnaryCeil) { } TEST_F(ConstantFoldTest, UnaryBitCast) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_int32(1); auto *out = builder.create_bit_cast(op, PrimitiveType::f32); auto *result = builder.create_sub(x, out); @@ -105,7 +105,7 @@ TEST_F(ConstantFoldTest, UnaryBitCast) { } TEST_F(ConstantFoldTest, UnaryAbs) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(-3.4); auto *out = builder.create_abs(op); auto *result = builder.create_sub(x, out); @@ -118,7 +118,7 @@ TEST_F(ConstantFoldTest, UnaryAbs) { } TEST_F(ConstantFoldTest, UnarySin) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(1); auto *out = builder.create_sin(op); auto *result = builder.create_sub(x, out); @@ -131,7 +131,7 @@ TEST_F(ConstantFoldTest, UnarySin) { } TEST_F(ConstantFoldTest, UnaryAsin) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(1); auto *out = builder.create_asin(op); auto *result = builder.create_sub(x, out); @@ -144,7 +144,7 @@ TEST_F(ConstantFoldTest, UnaryAsin) { } TEST_F(ConstantFoldTest, UnaryCos) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(0.5); auto *out = builder.create_cos(op); auto *result = builder.create_sub(x, out); @@ -157,7 +157,7 @@ TEST_F(ConstantFoldTest, UnaryCos) { } TEST_F(ConstantFoldTest, UnaryAcos) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(0.5); auto *out = builder.create_acos(op); auto *result = builder.create_sub(x, out); @@ -170,7 +170,7 @@ TEST_F(ConstantFoldTest, UnaryAcos) { } TEST_F(ConstantFoldTest, UnaryTan) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(0.5); auto *out = builder.create_tan(op); auto *result = builder.create_sub(x, out); @@ -183,7 +183,7 @@ TEST_F(ConstantFoldTest, UnaryTan) { } TEST_F(ConstantFoldTest, UnaryTanh) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(0.5); auto *out = builder.create_tanh(op); auto *result = builder.create_sub(x, out); @@ -196,7 +196,7 @@ TEST_F(ConstantFoldTest, UnaryTanh) { } TEST_F(ConstantFoldTest, UnaryExp) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(0.5); auto *out = builder.create_exp(op); auto *result = builder.create_sub(x, out); @@ -209,7 +209,7 @@ TEST_F(ConstantFoldTest, UnaryExp) { } TEST_F(ConstantFoldTest, UnaryLog) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(4); auto *out = builder.create_log(op); auto *result = builder.create_sub(x, out); @@ -222,7 +222,7 @@ TEST_F(ConstantFoldTest, UnaryLog) { } TEST_F(ConstantFoldTest, UnaryBitNot) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_int32(1); auto *out = builder.create_not(op); auto *result = builder.create_sub(x, out); @@ -235,7 +235,7 @@ TEST_F(ConstantFoldTest, UnaryBitNot) { } TEST_F(ConstantFoldTest, UnaryLogicNot) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_int32(1); auto *out = builder.create_logical_not(op); auto *result = builder.create_sub(x, out); @@ -248,7 +248,7 @@ TEST_F(ConstantFoldTest, UnaryLogicNot) { } TEST_F(ConstantFoldTest, UnaryCastValue) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_int32(1); auto *out = builder.create_cast(op, PrimitiveType::f32); auto *result = builder.create_sub(x, out); @@ -261,7 +261,7 @@ TEST_F(ConstantFoldTest, UnaryCastValue) { } TEST_F(ConstantFoldTest, UnaryRsqrt) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *op = builder.get_float32(4); auto *out = builder.create_rsqrt(op); auto *result = builder.create_sub(x, out); @@ -274,7 +274,7 @@ TEST_F(ConstantFoldTest, UnaryRsqrt) { } TEST_F(ConstantFoldTest, BinaryMul) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(1); auto *rhs = builder.get_int32(2); auto *out = builder.create_mul(lhs, rhs); @@ -288,7 +288,7 @@ TEST_F(ConstantFoldTest, BinaryMul) { } TEST_F(ConstantFoldTest, BinaryAdd) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *one = builder.get_int32(1); auto *rhs = builder.get_int32(2); auto *out = builder.create_add(one, rhs); @@ -302,7 +302,7 @@ TEST_F(ConstantFoldTest, BinaryAdd) { } TEST_F(ConstantFoldTest, BinarySub) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *one = builder.get_int32(1); auto *rhs = builder.get_int32(2); auto *out = builder.create_sub(one, rhs); @@ -316,7 +316,7 @@ TEST_F(ConstantFoldTest, BinarySub) { } TEST_F(ConstantFoldTest, BinaryFloorDiv) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *one = builder.get_float32(1); auto *rhs = builder.get_float32(2); auto *out = builder.create_floordiv(one, rhs); @@ -330,7 +330,7 @@ TEST_F(ConstantFoldTest, BinaryFloorDiv) { } TEST_F(ConstantFoldTest, BinaryDiv) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *one = builder.get_float32(1); auto *rhs = builder.get_float32(2); auto *out = builder.create_div(one, rhs); @@ -344,7 +344,7 @@ TEST_F(ConstantFoldTest, BinaryDiv) { } TEST_F(ConstantFoldTest, BinaryMod) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(3); auto *rhs = builder.get_int32(2); auto *out = builder.create_mod(lhs, rhs); @@ -358,7 +358,7 @@ TEST_F(ConstantFoldTest, BinaryMod) { } TEST_F(ConstantFoldTest, BinaryMax) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(3); auto *rhs = builder.get_int32(2); auto *out = builder.create_max(lhs, rhs); @@ -372,7 +372,7 @@ TEST_F(ConstantFoldTest, BinaryMax) { } TEST_F(ConstantFoldTest, BinaryMin) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(3); auto *rhs = builder.get_int32(2); auto *out = builder.create_min(lhs, rhs); @@ -386,7 +386,7 @@ TEST_F(ConstantFoldTest, BinaryMin) { } TEST_F(ConstantFoldTest, BinaryBitAnd) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(3); auto *rhs = builder.get_int32(2); auto *out = builder.create_and(lhs, rhs); @@ -400,7 +400,7 @@ TEST_F(ConstantFoldTest, BinaryBitAnd) { } TEST_F(ConstantFoldTest, BinaryBitOr) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(3); auto *rhs = builder.get_int32(2); auto *out = builder.create_or(lhs, rhs); @@ -414,7 +414,7 @@ TEST_F(ConstantFoldTest, BinaryBitOr) { } TEST_F(ConstantFoldTest, BinaryBitShl) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(3); auto *rhs = builder.get_int32(2); auto *out = builder.create_shl(lhs, rhs); @@ -428,7 +428,7 @@ TEST_F(ConstantFoldTest, BinaryBitShl) { } TEST_F(ConstantFoldTest, BinaryBitShrInt32) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(-1); auto *rhs = builder.get_uint32(2); auto *out = builder.create_shr(lhs, rhs); @@ -442,7 +442,7 @@ TEST_F(ConstantFoldTest, BinaryBitShrInt32) { } TEST_F(ConstantFoldTest, BinaryBitShrInt64) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int64(-1); auto *rhs = builder.get_uint32(2); auto *out = builder.create_shr(lhs, rhs); @@ -456,7 +456,7 @@ TEST_F(ConstantFoldTest, BinaryBitShrInt64) { } TEST_F(ConstantFoldTest, BinaryBitSar) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(-1); auto *rhs = builder.get_uint32(2); auto *out = builder.create_sar(lhs, rhs); @@ -469,7 +469,7 @@ TEST_F(ConstantFoldTest, BinaryBitSar) { } TEST_F(ConstantFoldTest, BinaryCmpLt) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(1); auto *rhs = builder.get_int32(2); auto *out = builder.create_cmp_lt(lhs, rhs); @@ -483,7 +483,7 @@ TEST_F(ConstantFoldTest, BinaryCmpLt) { } TEST_F(ConstantFoldTest, BinaryCmpGt) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(2); auto *rhs = builder.get_int32(2); auto *out = builder.create_cmp_gt(lhs, rhs); @@ -497,7 +497,7 @@ TEST_F(ConstantFoldTest, BinaryCmpGt) { } TEST_F(ConstantFoldTest, BinaryCmpGe) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(2); auto *rhs = builder.get_int32(2); auto *out = builder.create_cmp_ge(lhs, rhs); @@ -511,7 +511,7 @@ TEST_F(ConstantFoldTest, BinaryCmpGe) { } TEST_F(ConstantFoldTest, BinaryCmpEq) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(2); auto *rhs = builder.get_int32(2); auto *out = builder.create_cmp_eq(lhs, rhs); @@ -525,7 +525,7 @@ TEST_F(ConstantFoldTest, BinaryCmpEq) { } TEST_F(ConstantFoldTest, BinaryCmpNes) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(2); auto *rhs = builder.get_int32(2); auto *out = builder.create_cmp_ne(lhs, rhs); @@ -539,7 +539,7 @@ TEST_F(ConstantFoldTest, BinaryCmpNes) { } TEST_F(ConstantFoldTest, BinaryPow) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(2); auto *rhs = builder.get_int32(2); auto *out = builder.create_pow(lhs, rhs); @@ -553,7 +553,7 @@ TEST_F(ConstantFoldTest, BinaryPow) { } TEST_F(ConstantFoldTest, BinaryAtan2) { - auto *x = builder.create_arg_load({0}, get_data_type(), false); + auto *x = builder.create_arg_load({0}, get_data_type(), false, false); auto *lhs = builder.get_int32(0); auto *rhs = builder.get_int32(2); auto *out = builder.create_atan2(lhs, rhs); diff --git a/tests/cpp/transforms/determine_ad_stack_size_test.cpp b/tests/cpp/transforms/determine_ad_stack_size_test.cpp index ca905091ab453..ea2342da84c97 100644 --- a/tests/cpp/transforms/determine_ad_stack_size_test.cpp +++ b/tests/cpp/transforms/determine_ad_stack_size_test.cpp @@ -108,7 +108,7 @@ TEST_P(DetermineAdStackSizeTest, If) { bool has_false_branch = (kFalseBranchPushes > 0); IRBuilder builder; - auto *arg = builder.create_arg_load({0}, get_data_type(), false); + auto *arg = builder.create_arg_load({0}, get_data_type(), false, false); auto *stack = builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); auto *if_stmt = builder.create_if(arg); @@ -157,7 +157,7 @@ INSTANTIATE_TEST_SUITE_P( TEST_F(DetermineAdStackSizeTest, EmptyNodes) { IRBuilder builder; - auto *arg = builder.create_arg_load({0}, get_data_type(), false); + auto *arg = builder.create_arg_load({0}, get_data_type(), false, false); auto *stack = builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); auto *one = builder.get_int32(1); diff --git a/tests/cpp/transforms/half2_vectorization_test.cpp b/tests/cpp/transforms/half2_vectorization_test.cpp index 6d5861df8f305..769d7a2e23dd0 100644 --- a/tests/cpp/transforms/half2_vectorization_test.cpp +++ b/tests/cpp/transforms/half2_vectorization_test.cpp @@ -36,7 +36,7 @@ TEST(Half2Vectorization, Ndarray) { auto argload_stmt = block->push_back( std::vector{0} /*arg_id*/, PrimitiveType::f16, /*is_ptr*/ true, - /*create_load*/ false); + /*create_load*/ false, /*is_argpack*/ false); argload_stmt->ret_type = half2_type; auto const_0_stmt = block->push_back(TypedConstant(0)); diff --git a/tests/cpp/transforms/inlining_test.cpp b/tests/cpp/transforms/inlining_test.cpp index 7f241b0ff8583..a102707e12791 100644 --- a/tests/cpp/transforms/inlining_test.cpp +++ b/tests/cpp/transforms/inlining_test.cpp @@ -23,7 +23,7 @@ TEST_F(InliningTest, ArgLoadOfArgLoad) { // def test_func(x: ti.i32) -> ti.i32: // return x + 1 auto *arg = builder.create_arg_load(/*arg_id=*/{0}, get_data_type(), - /*is_ptr=*/false); + /*is_ptr=*/false, /*is_argpack=*/false); auto *sum = builder.create_add(arg, builder.get_int32(1)); builder.create_return(sum); auto func_body = builder.extract_ir(); @@ -43,7 +43,7 @@ TEST_F(InliningTest, ArgLoadOfArgLoad) { // return test_func(x) auto *kernel_arg = builder.create_arg_load(/*arg_id=*/{0}, get_data_type(), - /*is_ptr=*/false); + /*is_ptr=*/false, /*is_argpack=*/false); auto *func_call = builder.create_func_call(func, {kernel_arg}); builder.create_return(func_call); auto kernel_body = builder.extract_ir(); diff --git a/tests/cpp/transforms/scalarize_test.cpp b/tests/cpp/transforms/scalarize_test.cpp index b65ad3e2afd6c..629c2c163b5ff 100644 --- a/tests/cpp/transforms/scalarize_test.cpp +++ b/tests/cpp/transforms/scalarize_test.cpp @@ -33,7 +33,7 @@ TEST(Scalarize, ScalarizeGlobalStore) { auto argload_stmt = block->push_back( std::vector{0} /*arg_id*/, type, /*is_ptr*/ true, - /*create_load*/ false); + /*create_load*/ false, /*is_argpack*/ false); std::vector indices = {}; Stmt *dest_stmt = block->push_back( @@ -98,7 +98,7 @@ TEST(Scalarize, ScalarizeGlobalLoad) { auto argload_stmt = block->push_back( std::vector{0} /*arg_id*/, type, /*is_ptr*/ true, - /*create_load*/ false); + /*create_load*/ false, /*is_argpack*/ false); std::vector indices = {}; Stmt *src_stmt = block->push_back(