Skip to content

Commit

Permalink
Converting HIP target to support executable-create2.
Browse files Browse the repository at this point in the history
This produces a new flatbuffer that supports multiple hipModule_t modules per
HAL executable, moves export information out of executable-wide parallel
vectors into per-export definitions, and removes HAL pipeline layouts and the
existing stateful command recording.
  • Loading branch information
benvanik committed Aug 20, 2024
1 parent acdda79 commit af25026
Show file tree
Hide file tree
Showing 29 changed files with 670 additions and 1,534 deletions.
191 changes: 92 additions & 99 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,23 +212,30 @@ class ROCMTargetDevice final : public TargetDevice {
getDefaultDeviceTarget(MLIRContext *context,
const TargetRegistry &targetRegistry) const override {
Builder b(context);
SmallVector<NamedAttribute> configAttrItems;

SmallVector<NamedAttribute> deviceConfigAttrs;
if (options.legacySync) {
// Indicates that the runtime HAL driver operates only in the legacy
// synchronous mode.
configAttrItems.emplace_back(b.getStringAttr("legacy_sync"),
b.getUnitAttr());
deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"),
b.getUnitAttr());
}
DictionaryAttr configAttr = b.getDictionaryAttr(configAttrItems);
deviceConfigAttrs.emplace_back(b.getStringAttr("executable_create_2"),
b.getUnitAttr());
auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs);

SmallVector<NamedAttribute> executableConfigAttrs;
auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs);

// If we had multiple target environments we would generate one target attr
// per environment, with each setting its own environment attribute.
SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets(
context, "rocm", configAttr, executableTargetAttrs);
context, "rocm", executableConfigAttr, executableTargetAttrs);

return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"),
configAttr, executableTargetAttrs);
deviceConfigAttr,
executableTargetAttrs);
}

private:
Expand Down Expand Up @@ -355,35 +362,27 @@ class ROCMTargetBackend final : public TargetBackend {
auto exportOps = llvm::to_vector_of<IREE::HAL::ExecutableExportOp>(
variantOp.getExportOps());
llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOpMap;
std::vector<std::array<int32_t, 3>> workgroupSizes;
SmallVector<uint32_t> workgroupLocalMemories;
uint32_t subgroupSize = 64;
std::optional<uint32_t> subgroupSize;
for (IREE::HAL::ExecutableExportOp exportOp : exportOps) {
exportOpMap[exportOp.getSymName()] = exportOp;

std::array<int32_t, 3> workgroupSize = {1, 1, 1};
if (std::optional<ArrayAttr> workgroupSizeAttr =
exportOp.getWorkgroupSize()) {
for (auto [value, sizeAttr] :
llvm::zip_equal(workgroupSize, *workgroupSizeAttr))
value = cast<IntegerAttr>(sizeAttr).getInt();
}
workgroupSizes.push_back(workgroupSize);

// TODO: put this either on the variant or propagate as a function
// attribute instead - today this *must* be consistent across all exports
// and it shouldn't need to be.
if (auto setSubgroupSize = exportOp.getSubgroupSizeAsUInt()) {
if (setSubgroupSize.value() != 32 && setSubgroupSize.value() != 64) {
return variantOp.emitError()
<< "invalid subgroup size " << setSubgroupSize.value();
}
if (subgroupSize.has_value() &&
setSubgroupSize.value() != subgroupSize.value()) {
return variantOp.emitError()
<< "multiple exports with different subgroup sizes; this is a "
"limitation of the IREE compilation process and should be "
"fixed";
}
subgroupSize = setSubgroupSize.value();
}

uint32_t workgroupLocalMemory = 0;
if (std::optional<APInt> workgroupLocalMemoryAttr =
exportOp.getWorkgroupLocalMemory()) {
workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue();
}
workgroupLocalMemories.push_back(workgroupLocalMemory);
}

std::string targetHSACO;
Expand Down Expand Up @@ -468,10 +467,15 @@ class ROCMTargetBackend final : public TargetBackend {
std::string features;
if (targetArch.starts_with("gfx10") ||
targetArch.starts_with("gfx11")) {
if (subgroupSize == 32)
switch (subgroupSize.value_or(64)) {
case 32:
features = "+wavefrontsize32";
if (subgroupSize == 64)
break;
default:
case 64:
features = "+wavefrontsize64";
break;
}
}
if (!targetFeatures.empty()) {
features += (features.empty() ? "" : ",") + targetFeatures.str();
Expand Down Expand Up @@ -579,92 +583,81 @@ class ROCMTargetBackend final : public TargetBackend {
auto sourceFilesRef = createSourceFilesVec(
serOptions.debugLevel, variantOp.getSourcesAttr(), builder);

SmallVector<StringRef> entryPointNames;
SmallVector<iree_hal_debug_FileLineLocDef_ref_t> sourceLocationRefs;
entryPointNames.resize(exportOps.size());
// Only a single module today.
SmallVector<iree_hal_hip_ModuleDef_ref_t> moduleRefs;
{
auto hsacoImageRef = flatbuffers_string_create(
builder, targetHSACO.c_str(), targetHSACO.size());
moduleRefs.push_back(
iree_hal_hip_ModuleDef_create(builder, hsacoImageRef));
}
auto modulesRef = builder.createOffsetVecDestructive(moduleRefs);

// Generate optional per-export debug information.
// May be empty if no debug information was requested.
auto exportDebugInfos =
createExportDefs(serOptions.debugLevel, exportOps, builder);

SmallVector<iree_hal_hip_ExportDef_ref_t> exportRefs;
exportRefs.resize(exportOps.size(), 0);
for (auto exportOp : exportOps) {
auto ordinalAttr = exportOp.getOrdinalAttr();
if (!ordinalAttr) {
return mlir::emitError(exportOp.getLoc())
<< "could not compile rocm binary: export op is missing ordinal";
}
int64_t ordinal = ordinalAttr.getInt();
entryPointNames[ordinal] = exportOp.getName();

// Optional source location information for debugging/profiling.
if (serOptions.debugLevel >= 1) {
if (auto loc = findFirstFileLoc(exportOp.getLoc())) {
// We only ever resize to the maximum -- so all previous data will
// be kept as-is.
sourceLocationRefs.resize(exportOps.size());
auto filenameRef = builder.createString(loc->getFilename());
sourceLocationRefs[ordinal] = iree_hal_debug_FileLineLocDef_create(
builder, filenameRef, loc->getLine());
}

auto kernelNameRef = builder.createString(exportOp.getName());

iree_hal_hip_BlockDims_t blockDims = {0};
if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) {
auto workgroupSize = workgroupSizeAttr->getValue();
blockDims.x = cast<IntegerAttr>(workgroupSize[0]).getInt();
blockDims.y = cast<IntegerAttr>(workgroupSize[1]).getInt();
blockDims.z = cast<IntegerAttr>(workgroupSize[2]).getInt();
}
}

// Optional compilation stage source files.
SmallVector<iree_hal_debug_StageLocationsDef_ref_t> stageLocationsRefs;
if (serOptions.debugLevel >= 3) {
for (auto exportOp : exportOps) {
SmallVector<iree_hal_debug_StageLocationDef_ref_t> stageLocationRefs;
if (auto locsAttr = exportOp.getSourceLocsAttr()) {
for (auto locAttr : locsAttr.getValue()) {
if (auto loc =
findFirstFileLoc(cast<LocationAttr>(locAttr.getValue()))) {
auto stageNameRef = builder.createString(locAttr.getName());
auto filenameRef = builder.createString(loc->getFilename());
stageLocationRefs.push_back(
iree_hal_debug_StageLocationDef_create(
builder, stageNameRef,
iree_hal_debug_FileLineLocDef_create(builder, filenameRef,
loc->getLine())));
}
}
uint32_t blockSharedMemorySize = 0;
if (std::optional<APInt> workgroupLocalMemoryAttr =
exportOp.getWorkgroupLocalMemory()) {
blockSharedMemorySize = workgroupLocalMemoryAttr->getSExtValue();
}

auto layoutAttr = exportOp.getLayoutAttr();
uint32_t constantCount =
static_cast<uint32_t>(layoutAttr.getPushConstants());
SmallVector<iree_hal_hip_BindingBits_enum_t> bindingFlags;
for (auto bindingAttr : layoutAttr.getSetLayout(0).getBindings()) {
iree_hal_hip_BindingBits_enum_t flags = 0;
if (allEnumBitsSet(bindingAttr.getFlags(),
IREE::HAL::DescriptorFlags::ReadOnly)) {
flags |= iree_hal_hip_BindingBits_READ_ONLY;
}
if (!stageLocationRefs.empty()) {
// We only ever resize to the maximum -- so all previous data will
// be kept as-is.
stageLocationsRefs.resize(exportOps.size());
int64_t ordinal = exportOp.getOrdinalAttr().getInt();
stageLocationsRefs[ordinal] = iree_hal_debug_StageLocationsDef_create(
builder, builder.createOffsetVecDestructive(stageLocationRefs));
if (allEnumBitsSet(bindingAttr.getFlags(),
IREE::HAL::DescriptorFlags::Indirect)) {
flags |= iree_hal_hip_BindingBits_INDIRECT;
}
bindingFlags.push_back(flags);
}
auto bindingFlagsRef = iree_hal_hip_BindingBits_vec_create(
builder, bindingFlags.data(), bindingFlags.size());

iree_hal_hip_ExportDef_start(builder);
iree_hal_hip_ExportDef_module_ordinal_add(builder, 0); // always 0 today
iree_hal_hip_ExportDef_kernel_name_add(builder, kernelNameRef);
iree_hal_hip_ExportDef_block_dims_add(builder, &blockDims);
iree_hal_hip_ExportDef_block_shared_memory_size_add(
builder, blockSharedMemorySize);
iree_hal_hip_ExportDef_constant_count_add(builder, constantCount);
iree_hal_hip_ExportDef_binding_flags_add(builder, bindingFlagsRef);
iree_hal_hip_ExportDef_debug_info_add(builder, exportDebugInfos[ordinal]);
exportRefs[ordinal] = iree_hal_hip_ExportDef_end(builder);
}
auto exportsRef = builder.createOffsetVecDestructive(exportRefs);

auto hsacoRef = flatbuffers_string_create(builder, targetHSACO.c_str(),
targetHSACO.size());

auto entryPointsRef = builder.createStringVec(entryPointNames);
iree_hal_hip_BlockSize_vec_start(builder);
auto blockSizes = workgroupSizes.begin();
for (int i = 0, e = entryPointNames.size(); i < e; ++i) {
iree_hal_hip_BlockSize_vec_push_create(
builder, (*blockSizes)[0], (*blockSizes)[1], (*blockSizes)[2]);
++blockSizes;
}
auto workgroupLocalMemoriesRef =
builder.createInt32Vec(workgroupLocalMemories);
auto blockSizesRef = iree_hal_hip_BlockSize_vec_end(builder);
iree_hal_hip_ExecutableDef_entry_points_add(builder, entryPointsRef);
iree_hal_hip_ExecutableDef_block_sizes_add(builder, blockSizesRef);
iree_hal_hip_ExecutableDef_shared_memory_sizes_add(
builder, workgroupLocalMemoriesRef);
iree_hal_hip_ExecutableDef_hsaco_image_add(builder, hsacoRef);
if (!sourceLocationRefs.empty()) {
auto sourceLocationsRef =
builder.createOffsetVecDestructive(sourceLocationRefs);
iree_hal_hip_ExecutableDef_source_locations_add(builder,
sourceLocationsRef);
}
if (!stageLocationsRefs.empty()) {
auto stageLocationsRef =
builder.createOffsetVecDestructive(stageLocationsRefs);
iree_hal_hip_ExecutableDef_stage_locations_add(builder,
stageLocationsRef);
}
iree_hal_hip_ExecutableDef_exports_add(builder, exportsRef);
iree_hal_hip_ExecutableDef_modules_add(builder, modulesRef);
iree_hal_hip_ExecutableDef_source_files_add(builder, sourceFilesRef);
iree_hal_hip_ExecutableDef_end_as_root(builder);

Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ DeviceAnalysis::lookupDeviceGlobals(Value deviceValue) {
return globalOps;
}

// Convenience overload taking the device global op directly: builds a flat
// symbol reference to `deviceGlobalOp` and defers to the symbol-reference
// overload of lookupDeviceTargets.
std::optional<DeviceSet> DeviceAnalysis::lookupDeviceTargets(
    IREE::Util::GlobalOpInterface deviceGlobalOp) {
  auto deviceGlobalAttr = FlatSymbolRefAttr::get(deviceGlobalOp);
  return lookupDeviceTargets(deviceGlobalAttr);
}

// Looks up the targets of the device global named by `deviceGlobalAttr`.
// Targets are collected by gatherDeviceTargets, which is handed the explorer's
// root op, into an initially empty ordered set that is then wrapped in a
// DeviceSet. As written this always returns an engaged optional, even when
// nothing was gathered (the wrapped set is then empty).
std::optional<DeviceSet>
DeviceAnalysis::lookupDeviceTargets(SymbolRefAttr deviceGlobalAttr) {
  SetVector<IREE::HAL::DeviceTargetAttr> gatheredTargets;
  gatherDeviceTargets(deviceGlobalAttr, explorer.getRootOp(), gatheredTargets);
  return DeviceSet(gatheredTargets.getArrayRef());
}

std::optional<DeviceSet>
DeviceAnalysis::lookupDeviceTargets(Value deviceValue) {
auto valuePVS = solver.lookupElementFor<DeviceTargetValuePVS>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ class DeviceAnalysis {
std::optional<SmallVector<IREE::Util::GlobalOpInterface>>
lookupDeviceGlobals(Value deviceValue);

// Returns a set of possible targets of the given `!hal.device` global, if
// analyzed. This overload takes the global op itself; its definition forwards
// to the symbol-reference overload using a flat symbol reference to the op.
std::optional<DeviceSet>
lookupDeviceTargets(IREE::Util::GlobalOpInterface deviceGlobalOp);

// Returns a set of possible targets of the `!hal.device` global referenced by
// the symbol `deviceGlobalAttr`, if analyzed.
// NOTE(review): the visible definition always returns a value (possibly an
// empty set) rather than std::nullopt - confirm callers do not rely on
// nullopt to detect an unanalyzed global.
std::optional<DeviceSet> lookupDeviceTargets(SymbolRefAttr deviceGlobalAttr);

// Returns a set of possible targets of the given `!hal.device` value, if
// analyzed.
std::optional<DeviceSet> lookupDeviceTargets(Value deviceValue);
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/DeviceSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ DeviceSet::DeviceSet(ArrayAttr targetsAttr) {
}
}

// Builds a set from a flat list of device target attributes. Every element of
// `targetAttrs` is inserted into the member set of the same name.
DeviceSet::DeviceSet(ArrayRef<IREE::HAL::DeviceTargetAttr> targetAttrs) {
  // `this->` is required: the parameter shadows the member.
  this->targetAttrs.insert(targetAttrs.begin(), targetAttrs.end());
}

// Builds a set by copying an existing DenseSet of device target attributes
// into the member set of the same name.
DeviceSet::DeviceSet(const DenseSet<IREE::HAL::DeviceTargetAttr> &targetAttrs)
: targetAttrs(targetAttrs) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class DeviceSet {
public:
DeviceSet() = default;
explicit DeviceSet(ArrayAttr targetsAttr);
explicit DeviceSet(ArrayRef<IREE::HAL::DeviceTargetAttr> targetAttrs);
explicit DeviceSet(const DenseSet<IREE::HAL::DeviceTargetAttr> &targetAttrs);
~DeviceSet();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static llvm::cl::opt<bool> clIndirectCommandBuffers{
static llvm::cl::opt<bool> clExperimentalDispatch2{
"iree-hal-experimental-dispatch2",
llvm::cl::desc("Whether to emit iree_hal_command_buffer_dispatch2 ops."),
llvm::cl::init(false),
llvm::cl::init(true),
};

struct ContextResolveOpPattern
Expand Down
Loading

0 comments on commit af25026

Please sign in to comment.