Skip to content

Commit

Permalink
Implement conversion for list ops
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-camp committed Apr 8, 2021
1 parent 6e94e8e commit ae05781
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 135 deletions.
210 changes: 202 additions & 8 deletions iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "emitc/Dialect/EmitC/EmitCDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -154,9 +155,9 @@ class GlobalLoadOpConversion : public OpConversionPattern<LoadOpTy> {
auto type = loadOp.getOperation()->getResultTypes();
StringAttr callee = rewriter.getStringAttr(funcName);

// TODO(simon-camp): We can't represent structs in emitc (yet maybe), so the
// buffer where globals live after code generation as well as the state
// struct argument name are hardcoded here.
// TODO(simon-camp): We can't represent structs in emitc (yet maybe), so
// the buffer where globals live after code generation as well as the
// state struct argument name are hardcoded here.
ArrayAttr args = rewriter.getArrayAttr(
{rewriter.getStringAttr("state->rwdata"),
rewriter.getUI32IntegerAttr(static_cast<uint32_t>(
Expand Down Expand Up @@ -190,9 +191,9 @@ class GlobalStoreOpConversion : public OpConversionPattern<StoreOpTy> {
auto type = storeOp.getOperation()->getResultTypes();
StringAttr callee = rewriter.getStringAttr(funcName);

// TODO(simon-camp): We can't represent structs in emitc (yet maybe), so the
// buffer where globals live after code generation as well as the state
// struct argument name are hardcoded here.
// TODO(simon-camp): We can't represent structs in emitc (yet maybe), so
// the buffer where globals live after code generation as well as the
// state struct argument name are hardcoded here.
ArrayAttr args = rewriter.getArrayAttr(
{rewriter.getStringAttr("state->rwdata"),
rewriter.getUI32IntegerAttr(static_cast<uint32_t>(
Expand All @@ -217,10 +218,189 @@ class ListAllocOpConversion
LogicalResult matchAndRewrite(
IREE::VM::ListAllocOp allocOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return failure();
// clang-format off
// The generated c code looks roughly like this.
// iree_vm_type_def_t element_type = iree_vm_type_def_make_value_type(IREE_VM_VALUE_TYPE_I32);
// iree_vm_type_def_t* element_type_ptr = &element_type;
// iree_vm_list_t* list = nullptr;
// iree_vm_list_t** list_ptr = &list;
// iree_vm_list_create(&element_type, {initial_capacity}, state->allocator, &list);
// clang-format on

auto ctx = allocOp.getContext();
auto loc = allocOp.getLoc();

auto listType = allocOp.getType()
.cast<IREE::VM::RefType>()
.getObjectType()
.cast<IREE::VM::ListType>();
auto elementType = listType.getElementType();
std::string elementTypeStr;
StringRef elementTypeConstructor;
if (elementType.isa<IntegerType>()) {
unsigned int bitWidth = elementType.getIntOrFloatBitWidth();
elementTypeStr =
std::string("IREE_VM_VALUE_TYPE_I") + std::to_string(bitWidth);
elementTypeConstructor = "iree_vm_type_def_make_value_type";
} else {
return allocOp.emitError() << "Unhandeled element type " << elementType;
}

auto elementTypeOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t"),
/*callee=*/rewriter.getStringAttr(elementTypeConstructor),
/*args=*/ArrayAttr::get(ctx, {StringAttr::get(ctx, elementTypeStr)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{});

auto elementTypePtrOp = rewriter.create<emitc::GetAddressOfOp>(
/*location=*/loc,
/*result=*/emitc::OpaqueType::get(ctx, "iree_vm_type_def_t*"),
/*operand=*/elementTypeOp.getResult(0));

auto listOp = rewriter.replaceOpWithNewOp<emitc::ConstOp>(
/*op=*/allocOp,
/*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t*"),
/*value=*/StringAttr::get(ctx, "nullptr"));

auto listPtrOp = rewriter.create<emitc::GetAddressOfOp>(
/*location=*/loc,
/*result=*/emitc::OpaqueType::get(ctx, "iree_vm_list_t**"),
/*operand=*/listOp.getResult());

rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/TypeRange{},
/*callee=*/rewriter.getStringAttr("iree_vm_list_create"),
/*args=*/
ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1),
StringAttr::get(ctx, "state->allocator"),
rewriter.getIndexAttr(2)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/
ArrayRef<Value>{elementTypePtrOp.getResult(), operands[0],
listPtrOp.getResult()});

return success();
}
};

template <typename GetOpTy>
class ListGetOpConversion : public OpConversionPattern<GetOpTy> {
using OpConversionPattern<GetOpTy>::OpConversionPattern;

private:
LogicalResult matchAndRewrite(
GetOpTy getOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto ctx = getOp.getContext();
auto loc = getOp.getLoc();

Optional<StringRef> valueTypeEnum;
Optional<StringRef> valueExtractor;

std::tie(valueTypeEnum, valueExtractor) =
TypeSwitch<Operation *,
std::pair<Optional<StringRef>, Optional<StringRef>>>(
getOp.getOperation())
.Case<IREE::VM::ListGetI32Op>([&](auto op) {
return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I32"),
StringRef("vm_list_value_extract_i32"));
})
.template Case<IREE::VM::ListGetI64Op>([&](auto op) {
return std::make_pair(StringRef("IREE_VM_VALUE_TYPE_I64"),
StringRef("vm_list_value_extract_i64"));
})
.Default([](Operation *) { return std::make_pair(None, None); });

if (!valueTypeEnum.hasValue() || !valueExtractor.hasValue()) {
return getOp.emitOpError() << " not handeled";
}

auto valueOp = rewriter.create<emitc::ConstOp>(
/*location=*/loc,
/*resultType=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"),
/*value=*/StringAttr::get(ctx, ""));

auto valuePtrOp = rewriter.create<emitc::GetAddressOfOp>(
/*location=*/loc,
/*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"),
/*operand=*/valueOp.getResult());

auto getValueOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/TypeRange{},
/*callee=*/rewriter.getStringAttr("iree_vm_list_get_value_as"),
/*args=*/
ArrayAttr::get(ctx, {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1),
StringAttr::get(ctx, valueTypeEnum.getValue()),
rewriter.getIndexAttr(2)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/
ArrayRef<Value>{getOp.list(), getOp.index(), valuePtrOp.getResult()});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
/*op=*/getOp,
/*type=*/getOp.getType(),
/*callee=*/rewriter.getStringAttr(valueExtractor.getValue()),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{valuePtrOp.getResult()});

return success();
}
};

template <typename SetOpTy>
class ListSetOpConversion : public OpConversionPattern<SetOpTy> {
using OpConversionPattern<SetOpTy>::OpConversionPattern;

private:
LogicalResult matchAndRewrite(
SetOpTy setOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto ctx = setOp.getContext();
auto loc = setOp.getLoc();

Optional<StringRef> valueConstructor =
TypeSwitch<Operation *, Optional<StringRef>>(setOp.getOperation())
.Case<IREE::VM::ListSetI32Op>(
[&](auto op) { return StringRef("iree_vm_value_make_i32"); })
.template Case<IREE::VM::ListSetI64Op>(
[&](auto op) { return StringRef("iree_vm_value_make_i64"); })
.Default([](Operation *) { return None; });

if (!valueConstructor.hasValue()) {
return setOp.emitOpError() << " not handeled";
}

auto valueOp = rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t"),
/*callee=*/rewriter.getStringAttr(valueConstructor.getValue()),
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{setOp.value()});

auto valuePtrOp = rewriter.create<emitc::GetAddressOfOp>(
/*location=*/loc,
/*result=*/emitc::OpaqueType::get(ctx, "iree_vm_value_t*"),
/*operand=*/valueOp.getResult(0));

rewriter.replaceOpWithNewOp<emitc::CallOp>(
/*op=*/setOp,
/*type=*/TypeRange{},
/*callee=*/rewriter.getStringAttr("iree_vm_list_set_value"),
/*args=*/
ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/
ArrayRef<Value>{setOp.list(), setOp.index(), valuePtrOp.getResult()});

return success();
}
};
} // namespace

void populateVMToCPatterns(MLIRContext *context,
Expand All @@ -238,8 +418,18 @@ void populateVMToCPatterns(MLIRContext *context,
patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI32ZeroOp>>(context);
patterns.insert<ConstRefZeroOpConversion>(context);

// Lists
// List ops
// TODO(simon-camp): We leak memory in the generated code, as we never release
// the lists.
patterns.insert<ListAllocOpConversion>(context);
patterns.insert<CallOpConversion<IREE::VM::ListReserveOp>>(
context, "iree_vm_list_reserve");
patterns.insert<CallOpConversion<IREE::VM::ListResizeOp>>(
context, "iree_vm_list_resize");
patterns.insert<CallOpConversion<IREE::VM::ListSizeOp>>(context,
"iree_vm_list_size");
patterns.insert<ListGetOpConversion<IREE::VM::ListGetI32Op>>(context);
patterns.insert<ListSetOpConversion<IREE::VM::ListSetI32Op>>(context);

// Conditional assignment ops
patterns.insert<CallOpConversion<IREE::VM::SelectI32Op>>(context,
Expand Down Expand Up @@ -299,6 +489,10 @@ void populateVMToCPatterns(MLIRContext *context,
patterns.insert<ConstOpConversion<IREE::VM::ConstI64Op>>(context);
patterns.insert<ConstZeroOpConversion<IREE::VM::ConstI64ZeroOp>>(context);

// ExtI64: List ops
patterns.insert<ListGetOpConversion<IREE::VM::ListGetI64Op>>(context);
patterns.insert<ListSetOpConversion<IREE::VM::ListSetI64Op>>(context);

// ExtI64: Conditional assignment ops
patterns.insert<CallOpConversion<IREE::VM::SelectI64Op>>(context,
"vm_select_i64");
Expand Down
17 changes: 17 additions & 0 deletions iree/vm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <stdint.h>

#include "iree/base/api.h"
#include "iree/vm/value.h"

//===------------------------------------------------------------------===//
// Globals
Expand All @@ -34,6 +35,14 @@ static inline void vm_global_store_i32(uint8_t* base, uint32_t byte_offset,
*global_ptr = value;
}

//===------------------------------------------------------------------===//
// List ops
//===------------------------------------------------------------------===//

static inline int32_t vm_list_value_extract_i32(iree_vm_value_t* value) {
return value->i32;
}

//===------------------------------------------------------------------===//
// Conditional assignment
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -139,6 +148,14 @@ static inline iree_status_t vm_fail_or_ok(int32_t status_code,
return iree_ok_status();
}

//===------------------------------------------------------------------===//
// ExtI64: List ops
//===------------------------------------------------------------------===//

static inline int64_t vm_list_value_extract_i64(iree_vm_value_t* value) {
return value->i64;
}

//===------------------------------------------------------------------===//
// ExtI64: Conditional assignment
//===------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit ae05781

Please sign in to comment.