Skip to content

Commit

Permalink
[Arc] Add support for struct and array states
Browse files Browse the repository at this point in the history
Allow `!arc.state` to carry HW structs and arrays. The state only has to
be able to compute the bit width of the inner type, but it does not care
what exactly this type is.

Rework the Arc-to-LLVM lowering to do the entire lowering in one full
conversion, instead of two separate ones. There is no real need for the
split, and combining all patterns into one large conversion allows all
Arc types to be directly converted to LLVM types. Previously, after the
first partial conversion the IR would be in a strange in-between state
of mixing Arc types into LLVM operations (for example, loads and stores
of HW struct/array types).
  • Loading branch information
fabianschuiki committed Dec 11, 2023
1 parent 102ed65 commit 6560d98
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 149 deletions.
4 changes: 2 additions & 2 deletions include/circt/Dialect/Arc/ArcOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def StateReadOp : ArcOp<"state_read", [
]> {
let summary = "Get a state's current value";
let arguments = (ins StateType:$state);
let results = (outs AnyInteger:$value);
let results = (outs AnyType:$value);
let assemblyFormat = [{
$state attr-dict `:` type($state)
}];
Expand All @@ -551,7 +551,7 @@ def StateWriteOp : ArcOp<"state_write", [
immediately without affecting correctness. This allows later lowering passes
to treat `arc.state_write` as an immediate assignment (without defering).
}];
let arguments = (ins StateType:$state, AnyInteger:$value,
let arguments = (ins StateType:$state, AnyType:$value,
Optional<I1>:$condition);
let assemblyFormat = [{
$state `=` $value (`if` $condition^)? attr-dict `:` type($state)
Expand Down
7 changes: 4 additions & 3 deletions include/circt/Dialect/Arc/ArcTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ class ArcTypeDef<string name> : TypeDef<ArcDialect, name> { }

def StateType : ArcTypeDef<"State"> {
let mnemonic = "state";
let parameters = (ins "::mlir::IntegerType":$type);
let parameters = (ins "::mlir::Type":$type);
let assemblyFormat = "`<` $type `>`";
let genVerifyDecl = 1;
let builders = [
AttrBuilderWithInferredContext<(ins "::mlir::IntegerType":$type), [{
AttrBuilderWithInferredContext<(ins "::mlir::Type":$type), [{
return $_get(type.getContext(), type);
}]>
];

let extraClassDeclaration = [{
unsigned getBitWidth() { return getType().getWidth(); }
unsigned getBitWidth();
unsigned getByteWidth() { return (getBitWidth() + 7) / 8; }
}];
}
Expand Down
195 changes: 51 additions & 144 deletions lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "arc-lower-to-llvm"
#define DEBUG_TYPE "lower-arc-to-llvm"

using namespace mlir;
using namespace circt;
Expand Down Expand Up @@ -251,31 +251,6 @@ struct ClockGateOpLowering : public OpConversionPattern<seq::ClockGateOp> {
}
};

struct ReturnOpLowering : public OpConversionPattern<func::ReturnOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
return success();
}
};

struct FuncCallOpLowering : public OpConversionPattern<func::CallOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> newResultTypes;
if (failed(
typeConverter->convertTypes(op->getResultTypes(), newResultTypes)))
return failure();
rewriter.replaceOpWithNewOp<func::CallOp>(
op, op.getCalleeAttr(), newResultTypes, adaptor.getOperands());
return success();
}
};

struct ZeroCountOpLowering : public OpConversionPattern<arc::ZeroCountOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
Expand Down Expand Up @@ -310,72 +285,65 @@ struct ReplaceOpWithInputPattern : public OpConversionPattern<OpTy> {

} // namespace

static bool isArcType(Type type) {
return type.isa<StorageType>() || type.isa<MemoryType>() ||
type.isa<StateType>() || type.isa<seq::ClockType>();
}

static bool hasArcType(TypeRange types) {
return llvm::any_of(types, isArcType);
}
//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//

static bool hasArcType(ValueRange values) {
return hasArcType(values.getTypes());
}
namespace {
struct LowerArcToLLVMPass : public LowerArcToLLVMBase<LowerArcToLLVMPass> {
void runOnOperation() override;
};
} // namespace

template <typename Op>
static void addGenericLegality(ConversionTarget &target) {
target.addDynamicallyLegalOp<Op>([](Op op) {
return !hasArcType(op->getOperands()) && !hasArcType(op->getResults());
});
}
void LowerArcToLLVMPass::runOnOperation() {
// Collect the symbols in the root op such that the HW-to-LLVM lowering can
// create LLVM globals with non-colliding names.
Namespace globals;
SymbolCache cache;
cache.addDefinitions(getOperation());
globals.add(cache);

static void populateLegality(ConversionTarget &target) {
target.addLegalDialect<mlir::BuiltinDialect>();
target.addLegalDialect<hw::HWDialect>();
target.addLegalDialect<comb::CombDialect>();
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addIllegalDialect<seq::SeqDialect>();

target.addIllegalOp<arc::DefineOp>();
target.addIllegalOp<arc::OutputOp>();
target.addIllegalOp<arc::StateOp>();
target.addIllegalOp<arc::ClockTreeOp>();
target.addIllegalOp<arc::PassThroughOp>();

target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
return hasArcType(block.getArguments());
});
auto resultsConverted = !hasArcType(op.getResultTypes());
return argsConverted && resultsConverted;
});
addGenericLegality<func::ReturnOp>(target);
addGenericLegality<func::CallOp>(target);
}
// Setup the conversion target. Explicitly mark `scf.yield` legal since it
// does not have a conversion itself, which would cause it to fail
// legalization and for the conversion to abort. (It relies on its parent op's
// conversion to remove it.)
LLVMConversionTarget target(getContext());
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<scf::YieldOp>(); // quirk of SCF dialect conversion

static void populateTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion([&](seq::ClockType type) {
// Setup the arc dialect type conversion.
LLVMTypeConverter converter(&getContext());
converter.addConversion([&](seq::ClockType type) {
return IntegerType::get(type.getContext(), 1);
});
typeConverter.addConversion([&](StorageType type) {
converter.addConversion([&](StorageType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
typeConverter.addConversion([&](MemoryType type) {
converter.addConversion([&](MemoryType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
typeConverter.addConversion([&](StateType type) {
converter.addConversion([&](StateType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
typeConverter.addConversion([](hw::ArrayType type) { return type; });
typeConverter.addConversion([](mlir::IntegerType type) { return type; });
}

static void populateOpConversion(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
auto *context = patterns.getContext();
// Setup the conversion patterns.
RewritePatternSet patterns(&getContext());

// MLIR patterns.
populateSCFToControlFlowConversionPatterns(patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
arith::populateArithToLLVMConversionPatterns(converter, patterns);
populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter);

// CIRCT patterns.
DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
populateHWToLLVMConversionPatterns(converter, patterns, globals,
constAggregateGlobalsMap);
populateHWToLLVMTypeConversions(converter);
populateCombToLLVMConversionPatterns(converter, patterns);

// Arc patterns.
// clang-format off
patterns.add<
AllocMemoryOpLowering,
Expand All @@ -384,82 +352,21 @@ static void populateOpConversion(RewritePatternSet &patterns,
AllocStateLikeOpLowering<arc::RootOutputOp>,
AllocStorageOpLowering,
ClockGateOpLowering,
FuncCallOpLowering,
MemoryReadOpLowering,
MemoryWriteOpLowering,
ModelOpLowering,
ReplaceOpWithInputPattern<seq::ToClockOp>,
ReplaceOpWithInputPattern<seq::FromClockOp>,
ReturnOpLowering,
StateReadOpLowering,
StateWriteOpLowering,
StorageGetOpLowering,
ZeroCountOpLowering
>(typeConverter, context);
>(converter, &getContext());
// clang-format on

mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
}

//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//

namespace {
struct LowerArcToLLVMPass : public LowerArcToLLVMBase<LowerArcToLLVMPass> {
void runOnOperation() override;
LogicalResult lowerToMLIR();
LogicalResult lowerArcToLLVM();
};
} // namespace

void LowerArcToLLVMPass::runOnOperation() {
if (failed(lowerToMLIR()))
return signalPassFailure();

if (failed(lowerArcToLLVM()))
return signalPassFailure();
}

/// Perform the lowering to Func and SCF.
LogicalResult LowerArcToLLVMPass::lowerToMLIR() {
LLVM_DEBUG(llvm::dbgs() << "Lowering arcs to Func/SCF dialects\n");
ConversionTarget target(getContext());
TypeConverter converter;
RewritePatternSet patterns(&getContext());
populateLegality(target);
populateTypeConversion(converter);
populateOpConversion(patterns, converter);
return applyPartialConversion(getOperation(), target, std::move(patterns));
}

/// Perform lowering to LLVM.
LogicalResult LowerArcToLLVMPass::lowerArcToLLVM() {
LLVM_DEBUG(llvm::dbgs() << "Lowering to LLVM dialect\n");

Namespace globals;
SymbolCache cache;
cache.addDefinitions(getOperation());
globals.add(cache);

LLVMConversionTarget target(getContext());
LLVMTypeConverter converter(&getContext());
RewritePatternSet patterns(&getContext());
target.addLegalOp<mlir::ModuleOp>();
target.addIllegalOp<arc::ModelOp>();
populateSCFToControlFlowConversionPatterns(patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);

DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
populateHWToLLVMConversionPatterns(converter, patterns, globals,
constAggregateGlobalsMap);
populateHWToLLVMTypeConversions(converter);
populateCombToLLVMConversionPatterns(converter, patterns);
arith::populateArithToLLVMConversionPatterns(converter, patterns);

return applyFullConversion(getOperation(), target, std::move(patterns));
// Apply the conversion.
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>> circt::createLowerArcToLLVMPass() {
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/Arc/ArcTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "circt/Dialect/Arc/ArcTypes.h"
#include "circt/Dialect/Arc/ArcDialect.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Dialect/Seq/SeqTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
Expand All @@ -20,6 +21,17 @@ using namespace mlir;
#define GET_TYPEDEF_CLASSES
#include "circt/Dialect/Arc/ArcTypes.cpp.inc"

unsigned StateType::getBitWidth() { return hw::getBitWidth(getType()); }

LogicalResult
StateType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
Type innerType) {
if (hw::getBitWidth(innerType) < 0)
return emitError() << "state type must have a known bit width; got "
<< innerType;
return success();
}

unsigned MemoryType::getStride() {
unsigned stride = (getWordType().getWidth() + 7) / 8;
return llvm::alignToPowerOf2(stride, llvm::bit_ceil(std::min(stride, 16U)));
Expand Down
37 changes: 37 additions & 0 deletions test/Conversion/ArcToLLVM/lower-arc-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,40 @@ func.func @seqClocks(%clk1: !seq.clock, %clk2: !seq.clock) -> !seq.clock {
// CHECK-SAME: ([[CLK1:%.+]]: i1, [[CLK2:%.+]]: i1)
// CHECK: [[RES:%.+]] = llvm.xor [[CLK1]], [[CLK2]]
// CHECK: llvm.return [[RES]] : i1

// CHECK-LABEL: llvm.func @ReadAggregates(
// CHECK-SAME: %arg0: !llvm.ptr
// CHECK-SAME: %arg1: !llvm.ptr
func.func @ReadAggregates(%arg0: !arc.state<!hw.struct<a: i1, b: i1>>, %arg1: !arc.state<!hw.array<4xi1>>) {
// CHECK: llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(i1, i1)>
// CHECK: llvm.load %arg1 : !llvm.ptr -> !llvm.array<4 x i1>
arc.state_read %arg0 : <!hw.struct<a: i1, b: i1>>
arc.state_read %arg1 : <!hw.array<4xi1>>
return
}

// CHECK-LABEL: llvm.func @WriteStruct(
// CHECK-SAME: %arg0: !llvm.ptr
// CHECK-SAME: %arg1: !llvm.struct<(i1, i1)>
func.func @WriteStruct(%arg0: !arc.state<!hw.struct<a: i1, b: i1>>, %arg1: !hw.struct<a: i1, b: i1>) {
// CHECK: [[CONST:%.+]] = llvm.load {{%.+}} : !llvm.ptr -> !llvm.struct<(i1, i1)>
%0 = hw.aggregate_constant [false, false] : !hw.struct<a: i1, b: i1>
// CHECK: llvm.store [[CONST]], %arg0 : !llvm.struct<(i1, i1)>, !llvm.ptr
// CHECK: llvm.store %arg1, %arg0 : !llvm.struct<(i1, i1)>, !llvm.ptr
arc.state_write %arg0 = %0 : <!hw.struct<a: i1, b: i1>>
arc.state_write %arg0 = %arg1 : <!hw.struct<a: i1, b: i1>>
return
}

// CHECK-LABEL: llvm.func @WriteArray(
// CHECK-SAME: %arg0: !llvm.ptr
// CHECK-SAME: %arg1: !llvm.array<4 x i1>
func.func @WriteArray(%arg0: !arc.state<!hw.array<4xi1>>, %arg1: !hw.array<4xi1>) {
// CHECK: [[CONST:%.+]] = llvm.load {{%.+}} : !llvm.ptr -> !llvm.array<4 x i1>
%0 = hw.aggregate_constant [false, false, false, false] : !hw.array<4xi1>
// CHECK: llvm.store [[CONST]], %arg0 : !llvm.array<4 x i1>, !llvm.ptr
// CHECK: llvm.store %arg1, %arg0 : !llvm.array<4 x i1>, !llvm.ptr
arc.state_write %arg0 = %0 : <!hw.array<4xi1>>
arc.state_write %arg0 = %arg1 : <!hw.array<4xi1>>
return
}
5 changes: 5 additions & 0 deletions test/Dialect/Arc/basic-errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,8 @@ hw.module @vectorize(in %in0: i4, in %in1: i4, out out0: i4) {
}
hw.output %0 : i4
}

// -----

// expected-error @below {{state type must have a known bit width}}
func.func @InvalidStateType(%arg0: !arc.state<index>)

0 comments on commit 6560d98

Please sign in to comment.