Skip to content

Commit

Permalink
Add support for control flow lowering in the VM to emitc target (#5208)
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-camp authored Mar 24, 2021
1 parent c8a7b2f commit c9f3742
Show file tree
Hide file tree
Showing 14 changed files with 443 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,6 @@ void populateVMToCPatterns(MLIRContext *context,
patterns.insert<CallOpConversion<IREE::VM::CmpNZI32Op>>(context,
"vm_cmp_nz_i32");

// Check
// TODO(simon-camp): These conversions to macro calls should be deleted once
// support for control flow ops has landed in the c module target
patterns.insert<CallOpConversion<IREE::VM::CheckEQOp>>(context,
"VM_CHECK_EQ");

// ExtI64: Constants
patterns.insert<ConstOpConversion<IREE::VM::ConstI64Op>>(context);
patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI64ZeroOp>>(context);
Expand Down Expand Up @@ -369,7 +363,12 @@ class ConvertVMToEmitCPass
target.addLegalOp<IREE::VM::ExportOp>();

// Control flow ops
target.addLegalOp<IREE::VM::BranchOp>();
target.addLegalOp<IREE::VM::CallOp>();
target.addLegalOp<IREE::VM::CondBranchOp>();
// Note: We translate the fail op to two function calls in the end, but we
// can't simply convert it here because it is a terminator.
target.addLegalOp<IREE::VM::FailOp>();
target.addLegalOp<IREE::VM::ReturnOp>();

if (failed(
Expand Down
139 changes: 127 additions & 12 deletions iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,88 @@ static LogicalResult initializeGlobals(IREE::VM::ModuleOp moduleOp,
return success();
}

static LogicalResult translateBranchOp(IREE::VM::BranchOp branchOp,
mlir::emitc::CppEmitter &emitter) {
auto &output = emitter.ostream();
Block &successor = *branchOp.getSuccessor();

for (auto pair :
llvm::zip(branchOp.getOperands(), successor.getArguments())) {
auto &operand = std::get<0>(pair);
auto &argument = std::get<1>(pair);
output << emitter.getOrCreateName(argument) << " = "
<< emitter.getOrCreateName(operand) << ";\n";
}

output << "goto ";
if (!(emitter.hasBlockLabel(successor))) {
return branchOp.emitOpError() << "Unable to find label for successor block";
}
output << emitter.getOrCreateName(successor) << ";\n";
return success();
}

static LogicalResult translateCallOpToC(IREE::VM::CallOp callOp,
mlir::emitc::CppEmitter &emitter) {
return success();
}

static LogicalResult translateCondBranchOp(IREE::VM::CondBranchOp condBranchOp,
mlir::emitc::CppEmitter &emitter) {
llvm::raw_ostream &output = emitter.ostream();

Block &trueSuccessor = *condBranchOp.getTrueDest();
Block &falseSuccessor = *condBranchOp.getFalseDest();

output << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
<< ") {\n";

// If condition is true.
for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
trueSuccessor.getArguments())) {
auto &operand = std::get<0>(pair);
auto &argument = std::get<1>(pair);
output << emitter.getOrCreateName(argument) << " = "
<< emitter.getOrCreateName(operand) << ";\n";
}

output << "goto ";
if (!(emitter.hasBlockLabel(trueSuccessor))) {
return condBranchOp.emitOpError()
<< "Unable to find label for successor block";
}
output << emitter.getOrCreateName(trueSuccessor) << ";\n";
output << "} else {\n";
// If condition is false.
for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
falseSuccessor.getArguments())) {
auto &operand = std::get<0>(pair);
auto &argument = std::get<1>(pair);
output << emitter.getOrCreateName(argument) << " = "
<< emitter.getOrCreateName(operand) << ";\n";
}

output << "goto ";
if (!(emitter.hasBlockLabel(falseSuccessor))) {
return condBranchOp.emitOpError()
<< "Unable to find label for successor block";
}
output << emitter.getOrCreateName(falseSuccessor) << ";\n";
output << "}\n";
return success();
}

static LogicalResult translateFailOp(IREE::VM::FailOp failOp,
mlir::emitc::CppEmitter &emitter) {
llvm::raw_ostream &output = emitter.ostream();

auto status = failOp.status();

output << "return vm_fail_or_ok(" << emitter.getOrCreateName(status)
<< ", iree_make_cstring_view(\"" << failOp.message() << "\"));\n";
return success();
}

static LogicalResult translateReturnOpToC(
IREE::VM::ReturnOp returnOp, mlir::emitc::CppEmitter &emitter,
SmallVector<std::string, 4> resultNames) {
Expand All @@ -171,8 +248,14 @@ static LogicalResult translateReturnOpToC(
static LogicalResult translateOpToC(Operation &op,
mlir::emitc::CppEmitter &emitter,
SmallVector<std::string, 4> resultNames) {
if (auto branchOp = dyn_cast<IREE::VM::BranchOp>(op))
return translateBranchOp(branchOp, emitter);
if (auto callOp = dyn_cast<IREE::VM::CallOp>(op))
return translateCallOpToC(callOp, emitter);
if (auto condBranchOp = dyn_cast<IREE::VM::CondBranchOp>(op))
return translateCondBranchOp(condBranchOp, emitter);
if (auto failOp = dyn_cast<IREE::VM::FailOp>(op))
return translateFailOp(failOp, emitter);
if (auto returnOp = dyn_cast<IREE::VM::ReturnOp>(op))
return translateReturnOpToC(returnOp, emitter, resultNames);
// Fall back to generic emitc printer
Expand Down Expand Up @@ -222,9 +305,45 @@ static LogicalResult translateFunctionToC(IREE::VM::ModuleOp &moduleOp,
// struct argument name here must not be changed.
output << moduleName << "_state_t* state) {\n";

// We forward declare all result variables.
for (auto &op : funcOp.getOps()) {
if (failed(translateOpToC(op, emitter, resultNames))) {
return failure();
for (auto result : op.getResults()) {
if (failed(emitter.emitVariableDeclaration(result,
/*trailingSemicolon=*/true))) {
return op.emitError() << "Unable to declare result variable for op";
}
}
}

auto &blocks = funcOp.getBlocks();
// Create label names for basic blocks.
for (auto &block : blocks) {
emitter.getOrCreateName(block);
}

// Emit variables for basic block arguments.
for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
Block &block = *it;
for (auto &arg : block.getArguments()) {
if (emitter.hasValueInScope(arg)) return failure();
if (failed(emitter.emitType(arg.getType()))) {
return failure();
}
output << " " << emitter.getOrCreateName(arg) << ";\n";
}
}

for (auto &block : blocks) {
// Only print a label if there is more than one block.
if (blocks.size() > 1) {
if (failed(emitter.emitLabel(block))) {
return funcOp.emitOpError() << "Unable to print label for basic block";
}
}
for (Operation &op : block.getOperations()) {
if (failed(translateOpToC(op, emitter, resultNames))) {
return failure();
}
}
}

Expand Down Expand Up @@ -459,11 +578,9 @@ static LogicalResult canonicalizeModule(
for (auto *op : context->getRegisteredOperations()) {
// Non-serializable ops must be removed prior to serialization.
if (op->hasTrait<OpTrait::IREE::VM::PseudoOp>()) {
// TODO(simon-camp): reenable pass once support for control flow ops has
// landed
// op->getCanonicalizationPatterns(patterns, context);
// target.setOpAction(OperationName(op->name, context),
// ConversionTarget::LegalizationAction::Illegal);
op->getCanonicalizationPatterns(patterns, context);
target.setOpAction(OperationName(op->name, context),
ConversionTarget::LegalizationAction::Illegal);
}

// Debug ops must not be present when stripping.
Expand All @@ -485,11 +602,9 @@ static LogicalResult canonicalizeModule(

if (targetOptions.optimize) {
// TODO(benvanik): does this run until it quiesces?
// TODO(simon-camp): reenable pass once support for control flow ops has
// landed
// modulePasses.addPass(mlir::createInlinerPass());
modulePasses.addPass(mlir::createInlinerPass());
modulePasses.addPass(mlir::createCSEPass());
// modulePasses.addPass(mlir::createCanonicalizerPass());
modulePasses.addPass(mlir::createCanonicalizerPass());
}

// In the the Bytecode module the order is:
Expand Down Expand Up @@ -546,7 +661,7 @@ LogicalResult translateModuleToC(IREE::VM::ModuleOp moduleOp,
output << "\n";

mlir::emitc::CppEmitter emitter(output, /*restrictToC=*/true,
/*forwardDeclareVariables=*/false);
/*forwardDeclareVariables=*/true);
mlir::emitc::CppEmitter::Scope scope(emitter);

// build struct definitions
Expand Down
6 changes: 4 additions & 2 deletions iree/compiler/Dialect/VM/Target/C/test/add.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
vm.module @add_module {
// CHECK: iree_status_t add_module_add_1_impl(int32_t v1, int32_t v2, int32_t *out0, int32_t *out1, add_module_state_t* state) {
vm.func @add_1(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
// CHECK-NEXT: int32_t v3 = vm_add_i32(v1, v2);
// CHECK-NEXT: int32_t v3;
// CHECK-NEXT: int32_t v4;
// CHECK-NEXT: v3 = vm_add_i32(v1, v2);
%0 = vm.add.i32 %arg0, %arg1 : i32
// CHECK-NEXT: int32_t v4 = vm_add_i32(v3, v3);
// CHECK-NEXT: v4 = vm_add_i32(v3, v3);
%1 = vm.add.i32 %0, %0 : i32
// CHECK-NEXT: *out0 = v3;
// CHECK-NEXT: *out1 = v4;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ vm.module @calling_convention_test {

// CHECK: iree_status_t calling_convention_test_no_in_i32_return_impl(int32_t *out0, calling_convention_test_state_t* state) {
vm.func @no_in_i32_return() -> (i32) {
// CHECK-NEXT: int32_t v1 = vm_const_i32(32);
// CHECK-NEXT: int32_t v1;
// CHECK-NEXT: v1 = 32;
%0 = vm.const.i32 32 : i32
// CHECK-NEXT: *out0 = v1;
// CHECK-NEXT: return iree_ok_status();
Expand All @@ -25,7 +26,8 @@ vm.module @calling_convention_test {

// CHECK: iree_status_t calling_convention_test_i32_in_i32_return_impl(int32_t v1, int32_t *out0, calling_convention_test_state_t* state) {
vm.func @i32_in_i32_return(%arg0 : i32) -> (i32) {
// CHECK-NEXT: int32_t v2 = vm_const_i32(32);
// CHECK-NEXT: int32_t v2;
// CHECK-NEXT: v2 = 32;
%0 = vm.const.i32 32 : i32
// CHECK-NEXT: *out0 = v2;
// CHECK-NEXT: return iree_ok_status();
Expand Down
44 changes: 44 additions & 0 deletions iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: iree-translate -iree-vm-ir-to-c-module -iree-vm-c-module-optimize=false %s | IreeFileCheck %s

vm.module @control_flow_module {
vm.func @control_flow_test(%a: i32, %cond: i32) -> i32 {
vm.cond_br %cond, ^bb1, ^bb2
^bb1:
vm.br ^bb3(%a: i32)
^bb2:
%b = vm.add.i32 %a, %a : i32
vm.br ^bb3(%b: i32)
^bb3(%c: i32):
vm.br ^bb4(%c, %a : i32, i32)
^bb4(%d : i32, %e : i32):
%0 = vm.add.i32 %d, %e : i32
vm.return %0 : i32
}
}
// CHECK: iree_status_t control_flow_module_control_flow_test_impl(int32_t [[A:[^ ]*]], int32_t [[COND:[^ ]*]], int32_t *[[RESULT:[^ ]*]], control_flow_module_state_t* [[STATE:[^ ]*]]) {
// CHECK-NEXT: int32_t [[B:[^ ]*]];
// CHECK-NEXT: int32_t [[V0:[^ ]*]];
// CHECK-NEXT: int32_t [[C:[^ ]*]];
// CHECK-NEXT: int32_t [[D:[^ ]*]];
// CHECK-NEXT: int32_t [[E:[^ ]*]];
// CHECK-NEXT: [[BB0:[^ ]*]]:
// CHECK-NEXT: if ([[COND]]) {
// CHECK-NEXT: goto [[BB1:[^ ]*]];
// CHECK-NEXT: } else {
// CHECK-NEXT: goto [[BB2:[^ ]*]];
// CHECK-NEXT: }
// CHECK-NEXT: [[BB1]]:
// CHECK-NEXT: [[C]] = [[A]];
// CHECK-NEXT: goto [[BB3:[^ ]*]];
// CHECK-NEXT: [[BB2]]:
// CHECK-NEXT: [[B]] = vm_add_i32([[A]], [[A]]);
// CHECK-NEXT: [[C]] = [[B]];
// CHECK-NEXT: goto [[BB3]];
// CHECK-NEXT: [[BB3]]:
// CHECK-NEXT: [[D]] = [[C]];
// CHECK-NEXT: [[E]] = [[A]];
// CHECK-NEXT: goto [[BB4:[^ ]*]];
// CHECK-NEXT: [[BB4]]:
// CHECK-NEXT: [[V0]] = vm_add_i32([[D]], [[E]]);
// CHECK-NEXT: *[[RESULT]] = [[V0]];
// CHECK-NEXT: return iree_ok_status();
12 changes: 8 additions & 4 deletions iree/compiler/Dialect/VM/Target/C/test/global_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// RUN: iree-translate -iree-vm-ir-to-c-module %s | IreeFileCheck %s
// RUN: iree-translate -iree-vm-ir-to-c-module -iree-vm-c-module-optimize=false %s | IreeFileCheck %s

vm.module @global_ops {
// check the generated state struct
// CHECK-LABEL: struct global_ops_state_s {
// CHECK-NEXT: iree_allocator_t allocator;
// CHECK-NEXT: uint8_t rwdata[8];
// CHECK-NEXT: iree_vm_ref_t refs[0];
// CHECK-NEXT: };

vm.global.i32 @c42 42 : i32
Expand All @@ -13,19 +14,22 @@ vm.module @global_ops {
vm.export @test_global_load_i32
// CHECK-LABEL: iree_status_t global_ops_test_global_load_i32_impl(
vm.func @test_global_load_i32() -> i32 {
// CHECK-NEXT: int32_t v1 = vm_global_load_i32(state->rwdata, 0);
// CHECK-NEXT: int32_t v1;
// CHECK-NEXT: v1 = vm_global_load_i32(state->rwdata, 0);
%value = vm.global.load.i32 @c42 : i32
vm.return %value : i32
}

vm.export @test_global_store_i32
// CHECK-LABEL: iree_status_t global_ops_test_global_store_i32_impl(
vm.func @test_global_store_i32() -> i32 {
// CHECK-NEXT: int32_t v1 = vm_const_i32(17);
// CHECK-NEXT: int32_t v1;
// CHECK-NEXT: int32_t v2;
// CHECK-NEXT: v1 = 17;
%c17 = vm.const.i32 17 : i32
// CHECK-NEXT: vm_global_store_i32(state->rwdata, 4, v1);
vm.global.store.i32 %c17, @c107_mut : i32
// CHECK-NEXT: int32_t v2 = vm_global_load_i32(state->rwdata, 4);
// CHECK-NEXT: v2 = vm_global_load_i32(state->rwdata, 4);
%value = vm.global.load.i32 @c107_mut : i32
vm.return %value : i32
}
Expand Down
3 changes: 3 additions & 0 deletions iree/vm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ cc_library(
hdrs = [
"ops.h",
],
deps = [
"//iree/base:api",
],
)

cc_library(
Expand Down
2 changes: 2 additions & 0 deletions iree/vm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ iree_cc_library(
ops
HDRS
"ops.h"
DEPS
iree::base::api
PUBLIC
)

Expand Down
20 changes: 11 additions & 9 deletions iree/vm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#include <stdint.h>

#include "iree/base/api.h"

//===------------------------------------------------------------------===//
// Globals
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -125,17 +127,17 @@ static inline int32_t vm_cmp_nz_i32(int32_t operand) {
}

//===------------------------------------------------------------------===//
// Additional ops
// Control flow ops
//===------------------------------------------------------------------===//
// Check ops
// TODO(simon-camp): These macros should be removed once control flow ops are
// supported in the c module target.
// Note that this will fail if message contains a comma
#define VM_CHECK_EQ(a, b, message) \
if (a != b) { \
return iree_status_allocate(IREE_STATUS_FAILED_PRECONDITION, "<vm>", 0, \
iree_make_cstring_view(#message)); \

static inline iree_status_t vm_fail_or_ok(int32_t status_code,
iree_string_view_t message) {
if (status_code != 0) {
return iree_status_allocate(IREE_STATUS_FAILED_PRECONDITION, "<vm>", 0,
message);
}
return iree_ok_status();
}

//===------------------------------------------------------------------===//
// ExtI64: Conditional assignment
Expand Down
Loading

0 comments on commit c9f3742

Please sign in to comment.