diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index d983ecd217aee..e2d35645cd832 100644 --- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -17,6 +17,7 @@ #include "emitc/Dialect/EmitC/EmitCDialect.h" #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" @@ -154,9 +155,9 @@ class GlobalLoadOpConversion : public OpConversionPattern { auto type = loadOp.getOperation()->getResultTypes(); StringAttr callee = rewriter.getStringAttr(funcName); - // TODO(simon-camp): We can't represent structs in emitc (yet maybe), so the - // buffer where globals live after code generation as well as the state - // struct argument name are hardcoded here. + // TODO(simon-camp): We can't represent structs in emitc (yet maybe), so + // the buffer where globals live after code generation as well as the + // state struct argument name are hardcoded here. ArrayAttr args = rewriter.getArrayAttr( {rewriter.getStringAttr("state->rwdata"), rewriter.getUI32IntegerAttr(static_cast( @@ -190,9 +191,9 @@ class GlobalStoreOpConversion : public OpConversionPattern { auto type = storeOp.getOperation()->getResultTypes(); StringAttr callee = rewriter.getStringAttr(funcName); - // TODO(simon-camp): We can't represent structs in emitc (yet maybe), so the - // buffer where globals live after code generation as well as the state - // struct argument name are hardcoded here. + // TODO(simon-camp): We can't represent structs in emitc (yet maybe), so + // the buffer where globals live after code generation as well as the + // state struct argument name are hardcoded here. ArrayAttr args = rewriter.getArrayAttr( {rewriter.getStringAttr("state->rwdata"), rewriter.getUI32IntegerAttr(static_cast( @@ -217,10 +218,189 @@ class ListAllocOpConversion LogicalResult matchAndRewrite( IREE::VM::ListAllocOp allocOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - return failure(); + // clang-format off + // The generated c code looks roughly like this. + // iree_vm_type_def_t element_type = iree_vm_type_def_make_value_type(IREE_VM_VALUE_TYPE_I32); + // iree_vm_type_def_t* element_type_ptr = &element_type; + // iree_vm_list_t* list = nullptr; + // iree_vm_list_t** list_ptr = &list; + // iree_vm_list_create(&element_type, {initial_capacity}, state->allocator, &list); + // clang-format on + + auto ctx = allocOp.getContext(); + auto loc = allocOp.getLoc(); + + auto listType = allocOp.getType() + .cast() + .getObjectType() + .cast(); + auto elementType = listType.getElementType(); + std::string elementTypeStr; + StringRef elementTypeConstructor; + if (elementType.isa()) { + unsigned int bitWidth = elementType.getIntOrFloatBitWidth(); + elementTypeStr = + std::string("IREE_VM_VALUE_TYPE_I") + std::to_string(bitWidth); + elementTypeConstructor = "iree_vm_type_def_make_value_type"; + } else { + return allocOp.emitError() << "Unhandeled element type " << elementType; + } + + auto elementTypeOp = rewriter.create( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"), + /*callee=*/rewriter.getStringAttr(elementTypeConstructor), + /*args=*/ArrayAttr::get(ctx, {StringAttr::get(ctx, elementTypeStr)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef{}); + + auto elementTypePtrOp = rewriter.create( + /*location=*/loc, + /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t*"), + /*operand=*/elementTypeOp.getResult(0)); + + auto listOp = rewriter.replaceOpWithNewOp( + /*op=*/allocOp, + /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"), + /*value=*/StringAttr::get(ctx, "nullptr")); + + auto listPtrOp = rewriter.create( + /*location=*/loc, + /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t**"), + /*operand=*/listOp.getResult()); + + rewriter.create( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/rewriter.getStringAttr("iree_vm_list_create"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1), + StringAttr::get(ctx, "state->allocator"), + rewriter.getIndexAttr(2)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ + ArrayRef{elementTypePtrOp.getResult(), operands[0], + listPtrOp.getResult()}); + + return success(); + } +}; + +template +class ListGetOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + private: + LogicalResult matchAndRewrite( + GetOpTy getOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto ctx = getOp.getContext(); + auto loc = getOp.getLoc(); + + Optional valueTypeEnum; + Optional valueExtractor; + + std::tie(valueTypeEnum, valueExtractor) = + TypeSwitch, Optional>>( + getOp.getOperation()) + .Case([&](auto op) { + return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I32"), + StringRef("vm_list_value_extract_i32")); + }) + .template Case([&](auto op) { + return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I64"), + StringRef("vm_list_value_extract_i64")); + }) + .Default([](Operation *) { return std::make_pair(None, None); }); + + if (!valueTypeEnum.hasValue() || !valueExtractor.hasValue()) { + return getOp.emitOpError() << " not handeled"; + } + + auto valueOp = rewriter.create( + /*location=*/loc, + /*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"), + /*value=*/StringAttr::get(ctx, "")); + + auto valuePtrOp = rewriter.create( + /*location=*/loc, + /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"), + /*operand=*/valueOp.getResult()); + + auto getValueOp = rewriter.create( + /*location=*/loc, + /*type=*/TypeRange{}, + /*callee=*/rewriter.getStringAttr("iree_vm_list_get_value_as"), + /*args=*/ + ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1), + StringAttr::get(ctx, valueTypeEnum.getValue()), + rewriter.getIndexAttr(2)}), + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ + ArrayRef{getOp.list(), getOp.index(), valuePtrOp.getResult()}); + + rewriter.replaceOpWithNewOp( + /*op=*/getOp, + /*type=*/getOp.getType(), + /*callee=*/rewriter.getStringAttr(valueExtractor.getValue()), + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef{valuePtrOp.getResult()}); + + return success(); } }; +template +class ListSetOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + private: + LogicalResult matchAndRewrite( + SetOpTy setOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto ctx = setOp.getContext(); + auto loc = setOp.getLoc(); + + Optional valueConstructor = + TypeSwitch>(setOp.getOperation()) + .Case( + [&](auto op) { return StringRef("iree_vm_value_make_i32"); }) + .template Case( + [&](auto op) { return StringRef("iree_vm_value_make_i64"); }) + .Default([](Operation *) { return None; }); + + if (!valueConstructor.hasValue()) { + return setOp.emitOpError() << " not handeled"; + } + + auto valueOp = rewriter.create( + /*location=*/loc, + /*type=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"), + /*callee=*/rewriter.getStringAttr(valueConstructor.getValue()), + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ArrayRef{setOp.value()}); + + auto valuePtrOp = rewriter.create( + /*location=*/loc, + /*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"), + /*operand=*/valueOp.getResult(0)); + + rewriter.replaceOpWithNewOp( + /*op=*/setOp, + /*type=*/TypeRange{}, + /*callee=*/rewriter.getStringAttr("iree_vm_list_set_value"), + /*args=*/ + ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ + ArrayRef{setOp.list(), setOp.index(), valuePtrOp.getResult()}); + + return success(); + } +}; } // namespace void populateVMToCPatterns(MLIRContext *context, @@ -238,8 +418,18 @@ void populateVMToCPatterns(MLIRContext *context, patterns.insert>(context); patterns.insert(context); - // Lists + // List ops + // TODO(simon-camp): We leak memory in the generated code, as we never release + // the lists. patterns.insert(context); + patterns.insert>( + context, "iree_vm_list_reserve"); + patterns.insert>( + context, "iree_vm_list_resize"); + patterns.insert>(context, + "iree_vm_list_size"); + patterns.insert>(context); + patterns.insert>(context); // Conditional assignment ops patterns.insert>(context, @@ -299,6 +489,10 @@ void populateVMToCPatterns(MLIRContext *context, patterns.insert>(context); patterns.insert>(context); + // ExtI64: List ops + patterns.insert>(context); + patterns.insert>(context); + // ExtI64: Conditional assignment ops patterns.insert>(context, "vm_select_i64"); diff --git a/iree/vm/ops.h b/iree/vm/ops.h index 7099f3847f4a2..f769b246706c3 100644 --- a/iree/vm/ops.h +++ b/iree/vm/ops.h @@ -18,6 +18,7 @@ #include #include "iree/base/api.h" +#include "iree/vm/value.h" //===------------------------------------------------------------------===// // Globals @@ -34,6 +35,14 @@ static inline void vm_global_store_i32(uint8_t* base, uint32_t byte_offset, *global_ptr = value; } +//===------------------------------------------------------------------===// +// List ops +//===------------------------------------------------------------------===// + +static inline int32_t vm_list_value_extract_i32(iree_vm_value_t* value) { + return value->i32; +} + //===------------------------------------------------------------------===// // Conditional assignment //===------------------------------------------------------------------===// @@ -139,6 +148,14 @@ static inline iree_status_t vm_fail_or_ok(int32_t status_code, return iree_ok_status(); } +//===------------------------------------------------------------------===// +// ExtI64: List ops +//===------------------------------------------------------------------===// + +static inline int64_t vm_list_value_extract_i64(iree_vm_value_t* value) { + return value->i64; +} + //===------------------------------------------------------------------===// // ExtI64: Conditional assignment //===------------------------------------------------------------------===// diff --git a/iree/vm/test/emitc/list_ops_ref.h b/iree/vm/test/emitc/list_ops_ref.h deleted file mode 100644 index 218c946a4ad9b..0000000000000 --- a/iree/vm/test/emitc/list_ops_ref.h +++ /dev/null @@ -1,121 +0,0 @@ -#include "iree/testing/status_matchers.h" -#include "iree/vm/api.h" -#include "iree/vm/ops.h" -#include "iree/vm/shims.h" - -//============================================================================= -// module "list_ops_ref" -//============================================================================= - -struct list_ops_ref_s { - iree_allocator_t allocator; -}; -struct list_ops_ref_state_s { - iree_allocator_t allocator; - uint8_t rwdata[0]; - iree_vm_ref_t refs[0]; -}; -typedef struct list_ops_ref_s list_ops_ref_t; -typedef struct list_ops_ref_state_s list_ops_ref_state_t; - -iree_status_t list_ops_ref_test_i8_impl(list_ops_ref_state_t* state) { - // %c42 = vm.const.i32 42 : i32 - int32_t c42 = 42; - - // %list = vm.list.alloc %c42 : (i32) -> !vm.list - iree_vm_type_def_t element_type = - iree_vm_type_def_make_value_type(IREE_VM_VALUE_TYPE_I8); - iree_vm_type_def_t* element_type_ptr = &element_type; - iree_vm_list_t* list = nullptr; - iree_vm_list_t** list_ptr = &list; - iree_vm_list_create(element_type_ptr, c42, state->allocator, list_ptr); - - // %sz = vm.list.size %list : (!vm.list) -> i32 - int32_t sz = iree_vm_list_size(list); - - // %sz_dno = iree.do_not_optimize(%sz) : i32 - - // vm.return - return iree_ok_status(); -} - -//============================================================================= -// The code below setups functions and lookup tables to implement the vm -// interface -//============================================================================= -//============================================================================= -// module "list_ops_ref" -//============================================================================= - -static iree_status_t list_ops_ref_test_i8(iree_vm_stack_t* stack, - list_ops_ref_t* module, - list_ops_ref_state_t* state) { - return list_ops_ref_test_i8_impl(state); -} -static const iree_vm_native_export_descriptor_t list_ops_ref_exports_[] = { - {iree_make_cstring_view("test_i8"), iree_make_cstring_view("0v_v"), 0, - NULL}, -}; - -static const iree_vm_native_import_descriptor_t list_ops_ref_imports_[] = {}; - -static const iree_vm_native_function_ptr_t list_ops_ref_funcs_[] = { - {(iree_vm_native_function_shim_t)call_0v_v_shim, - (iree_vm_native_function_target_t)list_ops_ref_test_i8}, -}; - -static const iree_vm_native_module_descriptor_t list_ops_ref_descriptor_ = { - iree_make_cstring_view("list_ops_ref"), - IREE_ARRAYSIZE(list_ops_ref_imports_), - list_ops_ref_imports_, - IREE_ARRAYSIZE(list_ops_ref_exports_), - list_ops_ref_exports_, - IREE_ARRAYSIZE(list_ops_ref_funcs_), - list_ops_ref_funcs_, - 0, - NULL, -}; -static iree_status_t list_ops_ref_alloc_state( - void* self, iree_allocator_t allocator, - iree_vm_module_state_t** out_module_state) { - list_ops_ref_state_t* state = NULL; - IREE_RETURN_IF_ERROR( - iree_allocator_malloc(allocator, sizeof(*state), (void**)&state)); - memset(state, 0, sizeof(*state)); - state->allocator = allocator; - state->allocator = iree_allocator_system(); - *out_module_state = (iree_vm_module_state_t*)state; - return iree_ok_status(); -} -static void list_ops_ref_free_state(void* self, - iree_vm_module_state_t* module_state) { - list_ops_ref_state_t* state = (list_ops_ref_state_t*)module_state; - iree_allocator_free(state->allocator, state); -} - -static void list_ops_ref_destroy(void* self) { - list_ops_ref_t* module = (list_ops_ref_t*)self; - iree_allocator_free(module->allocator, module); -} - -static iree_status_t list_ops_ref_create(iree_allocator_t allocator, - iree_vm_module_t** out_module) { - // Allocate shared module state. - list_ops_ref_t* module = NULL; - IREE_RETURN_IF_ERROR( - iree_allocator_malloc(allocator, sizeof(*module), (void**)&module)); - memset(module, 0, sizeof(*module)); - module->allocator = allocator; - - iree_vm_module_t interface; - iree_status_t status = iree_vm_module_initialize(&interface, module); - if (!iree_status_is_ok(status)) { - iree_allocator_free(allocator, module); - return status; - } - interface.destroy = list_ops_ref_destroy; - interface.alloc_state = list_ops_ref_alloc_state; - interface.free_state = list_ops_ref_free_state; - return iree_vm_native_module_create(&interface, &list_ops_ref_descriptor_, - allocator, out_module); -} diff --git a/iree/vm/test/emitc/module_test.cc b/iree/vm/test/emitc/module_test.cc index 2504a68ddcb63..6081ca7477763 100644 --- a/iree/vm/test/emitc/module_test.cc +++ b/iree/vm/test/emitc/module_test.cc @@ -28,7 +28,7 @@ #include "iree/vm/test/emitc/conversion_ops.h" #include "iree/vm/test/emitc/conversion_ops_i64.h" #include "iree/vm/test/emitc/global_ops.h" -#include "iree/vm/test/emitc/list_ops_ref.h" +#include "iree/vm/test/emitc/list_ops.h" #include "iree/vm/test/emitc/shift_ops.h" #include "iree/vm/test/emitc/shift_ops_i64.h" @@ -49,7 +49,8 @@ struct ModuleDescription { }; std::ostream& operator<<(std::ostream& os, const TestParams& params) { - return os << absl::StrReplaceAll(params.local_name, {{":", "_"}, {".", "_"}}); + std::string qualified_name = params.module_name + "." + params.local_name; + return os << absl::StrReplaceAll(qualified_name, {{":", "_"}, {".", "_"}}); } std::vector GetModuleTestParams() { @@ -67,7 +68,7 @@ std::vector GetModuleTestParams() { {conversion_ops_descriptor_, conversion_ops_create}, {conversion_ops_i64_descriptor_, conversion_ops_i64_create}, {global_ops_descriptor_, global_ops_create}, - {list_ops_ref_descriptor_, list_ops_ref_create}, + {list_ops_descriptor_, list_ops_create}, {shift_ops_descriptor_, shift_ops_create}, {shift_ops_i64_descriptor_, shift_ops_i64_create}}; @@ -132,7 +133,6 @@ class VMCModuleTest : public ::testing::Test, iree_vm_instance_t* instance_ = nullptr; iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; }; TEST_P(VMCModuleTest, Check) { diff --git a/iree/vm/test/list_ops.mlir b/iree/vm/test/list_ops.mlir index b336b6fe7f4cd..ba743a800ee63 100644 --- a/iree/vm/test/list_ops.mlir +++ b/iree/vm/test/list_ops.mlir @@ -7,9 +7,30 @@ vm.module @list_ops { vm.export @test_i8 vm.func @test_i8() { %c42 = vm.const.i32 42 : i32 + %c100 = vm.const.i32 100 : i32 + %c0 = vm.const.i32 0 : i32 %list = vm.list.alloc %c42 : (i32) -> !vm.list + vm.list.reserve %list, %c100 : (!vm.list, i32) %sz = vm.list.size %list : (!vm.list) -> i32 %sz_dno = iree.do_not_optimize(%sz) : i32 + vm.check.eq %sz_dno, %c0, "list.empty.size()=0" : i32 + vm.return + } + + //===--------------------------------------------------------------------===// + // vm.list.* with I16 types + //===--------------------------------------------------------------------===// + + vm.export @test_i16 + vm.func @test_i16() { + %c0 = vm.const.i32 0 : i32 + %c1 = vm.const.i32 1 : i32 + %c27 = vm.const.i32 27 : i32 + %list = vm.list.alloc %c1 : (i32) -> !vm.list + vm.list.resize %list, %c1 : (!vm.list, i32) + vm.list.set.i32 %list, %c0, %c27 : (!vm.list, i32, i32) + %v = vm.list.get.i32 %list, %c0 : (!vm.list, i32) -> i32 + vm.check.eq %v, %c27, "list.empty.set(0, 27).get(0)=27" : i32 vm.return } @@ -27,7 +48,7 @@ vm.module @list_ops { vm.list.resize %list, %c101 : (!vm.list, i32) vm.list.set.i32 %list, %c100, %c42 : (!vm.list, i32, i32) %v = vm.list.get.i32 %list, %c100 : (!vm.list, i32) -> i32 - vm.check.eq %v, %c42 : i32 + vm.check.eq %v, %c42, "list.empty.set(100, 42).get(100)=42" : i32 vm.return } @@ -45,7 +66,7 @@ vm.module @list_ops { vm.list.resize %list, %capacity : (!vm.list, i32) vm.list.set.i64 %list, %index, %max_int_plus_1 : (!vm.list, i32, i64) %v = vm.list.get.i64 %list, %index : (!vm.list, i32) -> i64 - vm.check.eq %v, %max_int_plus_1 : i64 + vm.check.eq %v, %max_int_plus_1, "list.empty.set(41, MAX_INT_PLUS_1).get(41)=MAX_INT_PLUS_1" : i64 vm.return }