Skip to content

Commit

Permalink
[FIRRTL] Add options and instance choices
Browse files Browse the repository at this point in the history
  • Loading branch information
nandor committed Dec 8, 2023
1 parent 37b2533 commit ebef73b
Show file tree
Hide file tree
Showing 5 changed files with 439 additions and 2 deletions.
50 changes: 50 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,56 @@ def InstanceOp : HardwareDeclOp<"instance", [
}];
}

def InstanceChoiceOp : HardwareDeclOp<"instance_choice", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Creates an instance of a module based on a option";

let description = [{
The instance choice operation creates an instance choosing the target based
on the value of an option if one is specified, instantiating a default
target otherwise.

The port lists of all instance targets must match exactly.

Examples:
```mlir
%0 = firrtl.instance_choice foo @Foo alternatives @Opt { @FPGA -> @FPGAFoo }
(in io: !firrtl.uint)
```
}];

let arguments = (ins FlatSymbolRefArrayAttr:$moduleNames,
SymbolRefArrayAttr:$caseNames,
StrAttr:$name, NameKindAttr:$nameKind,
APIntAttr:$portDirections, StrArrayAttr:$portNames,
AnnotationArrayAttr:$annotations,
PortAnnotationsAttr:$portAnnotations,
OptionalAttr<InnerSymAttr>:$inner_sym);

let results = (outs Variadic<AnyType>:$results);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
/// Return the port direction for the specified result number.
Direction getPortDirection(size_t resultNo) {
return direction::get(getPortDirections()[resultNo]);
}
}];

let builders = [
OpBuilder<(ins "FModuleLike":$defaultModule,
"ArrayRef<std::pair<OptionCaseOp, FModuleLike>>":$cases,
"mlir::StringRef":$name,
CArg<"NameKindEnum", "NameKindEnum::DroppableName">:$nameKind,
CArg<"ArrayRef<Attribute>", "{}">:$annotations,
CArg<"ArrayRef<Attribute>", "{}">:$portAnnotations,
CArg<"StringAttr", "StringAttr()">:$innerSym)>
];
}


def MemOp : HardwareDeclOp<"mem"> {
let summary = "Define a new mem";
let arguments =
Expand Down
52 changes: 52 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLStructure.td
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,56 @@ def LayerOp : FIRRTLOp<
}];
}

def OptionOp : FIRRTLOp<"option", [
IsolatedFromAbove,
Symbol,
SymbolTable,
NoTerminator,
HasParent<"firrtl::CircuitOp">
]> {
let summary = "An option group definition";
let description = [{
The `firrtl.option` operation defines a specializable parameter with a
known set of values, represented by the `firrtl.option_case` operations
nested underneath.

Operations which support specialization reference the option and its
cases to define the specializations they support.

Example:
```mlir

firrtl.circuit {
firrtl.option @Target {
firrtl.option_case @FPGA
firrtl.option_case @ASIC
}
}
```
}];

let arguments = (ins SymbolNameAttr:$sym_name);
let results = (outs);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = [{
$sym_name attr-dict-with-keyword $body
}];
}

def OptionCaseOp : FIRRTLOp<
"option_case", [Symbol, HasParent<"firrtl::OptionOp">]
> {
let summary = "A configuration option value definition";
let description = [{
`firrtl.option_case` defines an acceptable value to be provided for an
option. Ops reference it to define their behavior when this case is active.
}];

let arguments = (ins SymbolNameAttr:$sym_name);
let results = (outs);
let assemblyFormat = [{
$sym_name attr-dict
}];
}

#endif // CIRCT_DIALECT_FIRRTL_FIRRTLSTRUCTURE_TD
243 changes: 241 additions & 2 deletions lib/Dialect/FIRRTL/FIRRTLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ Flow firrtl::foldFlow(Value val, Flow accumulatedFlow) {
// Registers, Wires, and behavioral memory ports are always Duplex.
.Case<RegOp, RegResetOp, WireOp, MemoryPortOp>(
[](auto) { return Flow::Duplex; })
.Case<InstanceOp>([&](auto inst) {
.Case<InstanceOp, InstanceChoiceOp>([&](auto inst) {
auto resultNo = cast<OpResult>(val).getResultNumber();
if (inst.getPortDirection(resultNo) == Direction::Out)
return accumulatedFlow;
Expand Down Expand Up @@ -1889,7 +1889,7 @@ bool ExtClassOp::canDiscardOnUseEmpty() {
}

//===----------------------------------------------------------------------===//
// Declarations
// InstanceOp
//===----------------------------------------------------------------------===//

SmallVector<::circt::hw::PortInfo> InstanceOp::getPortList() {
Expand Down Expand Up @@ -2225,6 +2225,245 @@ std::optional<size_t> InstanceOp::getTargetResultIndex() {
return std::nullopt;
}

// -----------------------------------------------------------------------------
// InstanceChoiceOp
// -----------------------------------------------------------------------------

void InstanceChoiceOp::build(
OpBuilder &builder, OperationState &result, FModuleLike defaultModule,
ArrayRef<std::pair<OptionCaseOp, FModuleLike>> cases, StringRef name,
NameKindEnum nameKind, ArrayRef<Attribute> annotations,
ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
// Gather the result types.
SmallVector<Type> resultTypes;
resultTypes.reserve(defaultModule.getNumPorts());
llvm::transform(
defaultModule.getPortTypes(), std::back_inserter(resultTypes),
[](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });

// Create the port annotations.
ArrayAttr portAnnotationsAttr;
if (portAnnotations.empty()) {
portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
resultTypes.size(), builder.getArrayAttr({})));
} else {
portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
}

// Gather the module & case names.
SmallVector<Attribute> moduleNames;
SmallVector<Attribute> caseNames;
moduleNames.push_back(SymbolRefAttr::get(defaultModule.getModuleNameAttr()));
for (auto [caseOption, caseModule] : cases) {
auto caseGroup = caseOption->getParentOfType<OptionOp>();
caseNames.push_back(SymbolRefAttr::get(caseGroup.getSymNameAttr(),
{SymbolRefAttr::get(caseOption)}));
moduleNames.push_back(SymbolRefAttr::get(caseModule.getModuleNameAttr()));
}

return build(builder, result, resultTypes, builder.getArrayAttr(moduleNames),
builder.getArrayAttr(caseNames), builder.getStringAttr(name),
NameKindEnumAttr::get(builder.getContext(), nameKind),
defaultModule.getPortDirectionsAttr(),
defaultModule.getPortNamesAttr(),
builder.getArrayAttr(annotations), portAnnotationsAttr,
innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
}

std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
return std::nullopt;
}

void InstanceChoiceOp::print(OpAsmPrinter &p) {
// Print the instance name.
p << " ";
p.printKeywordOrString(getName());
if (auto attr = getInnerSymAttr()) {
p << " sym ";
p.printSymbolName(attr.getSymName());
}
if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());

// Print the attr-dict.
SmallVector<StringRef, 9> omittedAttrs = {
"moduleNames", "caseNames", "name",
"portDirections", "portNames", "portTypes",
"portAnnotations", "inner_sym", "nameKind"};
if (getAnnotations().empty())
omittedAttrs.push_back("annotations");
p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);

// Print the module name.
p << ' ';

auto moduleNames = getModuleNamesAttr();
auto caseNames = getCaseNamesAttr();
assert(moduleNames.size() == caseNames.size() + 1);

p.printSymbolName(moduleNames[0].cast<FlatSymbolRefAttr>().getValue());

p << " alternatives ";
p.printSymbolName(
caseNames[0].cast<SymbolRefAttr>().getRootReference().getValue());
p << " { ";
for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
if (i != 0)
p << ", ";

auto symbol = caseNames[i].cast<SymbolRefAttr>();
p.printSymbolName(symbol.getNestedReferences()[0].getValue());
p << " -> ";
p.printSymbolName(moduleNames[i + 1].cast<FlatSymbolRefAttr>().getValue());
}

p << " } ";

// Collect all the result types as TypeAttrs for printing.
SmallVector<Attribute> portTypes;
portTypes.reserve(getNumResults());
llvm::transform(getResultTypes(), std::back_inserter(portTypes),
&TypeAttr::get);
auto portDirections = direction::unpackAttribute(getPortDirectionsAttr());
printModulePorts(p, /*block=*/nullptr, portDirections,
getPortNames().getValue(), portTypes,
getPortAnnotations().getValue(), {}, {});
}

ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
OperationState &result) {
auto *context = parser.getContext();
auto &resultAttrs = result.attributes;

std::string name;
hw::InnerSymAttr innerSymAttr;
SmallVector<Attribute> moduleNames;
SmallVector<Attribute> caseNames;
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<Direction, 4> portDirections;
SmallVector<Attribute, 4> portNames;
SmallVector<Attribute, 4> portTypes;
SmallVector<Attribute, 4> portAnnotations;
SmallVector<Attribute, 4> portSyms;
SmallVector<Attribute, 4> portLocs;
NameKindEnumAttr nameKind;

if (parser.parseKeywordOrString(&name))
return failure();
if (succeeded(parser.parseOptionalKeyword("sym"))) {
if (parser.parseCustomAttributeWithFallback(
innerSymAttr, Type{},
hw::InnerSymbolTable::getInnerSymbolAttrName(),
result.attributes)) {
return failure();
}
}
if (parseNameKind(parser, nameKind) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();

FlatSymbolRefAttr defaultModuleName;
if (parser.parseAttribute(defaultModuleName))
return failure();
moduleNames.push_back(defaultModuleName);

// alternatives { @opt::@case -> @target, ... }
{
FlatSymbolRefAttr optionName;
if (parser.parseKeyword("alternatives") ||
parser.parseAttribute(optionName) || parser.parseLBrace())
return failure();

FlatSymbolRefAttr moduleName;
StringAttr caseName;
while (succeeded(parser.parseOptionalSymbolName(caseName))) {
if (parser.parseArrow() || parser.parseAttribute(moduleName))
return failure();
moduleNames.push_back(moduleName);
caseNames.push_back(SymbolRefAttr::get(
optionName.getAttr(), {FlatSymbolRefAttr::get(caseName)}));
if (failed(parser.parseOptionalComma()))
break;
}
if (parser.parseRBrace())
return failure();
}

if (parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
/*supportsSymbols=*/false, entryArgs, portDirections,
portNames, portTypes, portAnnotations, portSyms,
portLocs))
return failure();

// Add the attributes. We let attributes defined in the attr-dict override
// attributes parsed out of the module signature.
if (!resultAttrs.get("moduleNames"))
result.addAttribute("moduleNames", ArrayAttr::get(context, moduleNames));
if (!resultAttrs.get("caseNames"))
result.addAttribute("caseNames", ArrayAttr::get(context, caseNames));
if (!resultAttrs.get("name"))
result.addAttribute("name", StringAttr::get(context, name));
result.addAttribute("nameKind", nameKind);
if (!resultAttrs.get("portDirections"))
result.addAttribute("portDirections",
direction::packAttribute(context, portDirections));
if (!resultAttrs.get("portNames"))
result.addAttribute("portNames", ArrayAttr::get(context, portNames));
if (!resultAttrs.get("portAnnotations"))
result.addAttribute("portAnnotations",
ArrayAttr::get(context, portAnnotations));

// Annotations and LowerToBind are omitted in the printed format if they are
// empty and false, respectively.
if (!resultAttrs.get("annotations"))
resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));

// Add result types.
result.types.reserve(portTypes.size());
llvm::transform(
portTypes, std::back_inserter(result.types),
[](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });

return success();
}

void InstanceChoiceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
StringRef base = getName().empty() ? "inst" : getName();
for (auto [result, name] : llvm::zip(getResults(), getPortNames()))
setNameFn(result, (base + "_" + name.cast<StringAttr>().getValue()).str());
}

LogicalResult
InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
for (auto moduleName : getModuleNamesAttr()) {
if (failed(instance_like_impl::verifyReferencedModule(
*this, symbolTable, moduleName.cast<FlatSymbolRefAttr>())))
return failure();
}

auto caseNames = getCaseNamesAttr();
if (caseNames.empty())
return emitOpError() << "must have at least one case";

auto root = caseNames[0].cast<SymbolRefAttr>().getRootReference();
for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
auto ref = caseNames[i].cast<SymbolRefAttr>();
if (!symbolTable.lookupNearestSymbolFrom<OptionCaseOp>(*this, ref))
return emitOpError() << "case refence " << ref << " is invalid";

if (ref.getRootReference() != root)
return emitOpError() << "case " << ref
<< " is not in the same option group as "
<< caseNames[0];
}

return success();
}

//===----------------------------------------------------------------------===//
// MemOp
//===----------------------------------------------------------------------===//

void MemOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, uint32_t readLatency,
uint32_t writeLatency, uint64_t depth, RUWAttr ruw,
Expand Down
Loading

0 comments on commit ebef73b

Please sign in to comment.