Skip to content

Commit

Permalink
Removing legacy pipeline layout and dispatch binding model.
Browse files Browse the repository at this point in the history
This does not yet rename the methods and is just stripping all of the
legacy ops and methods.

Progress on #18154.
  • Loading branch information
benvanik committed Aug 21, 2024
1 parent e0bca76 commit bb40c45
Show file tree
Hide file tree
Showing 59 changed files with 67 additions and 2,709 deletions.
2 changes: 0 additions & 2 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,6 @@ class CUDATargetDevice final : public TargetDevice {
Builder b(context);

SmallVector<NamedAttribute> deviceConfigAttrs;
deviceConfigAttrs.emplace_back(b.getStringAttr("executable_create_2"),
b.getUnitAttr());
auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs);

SmallVector<NamedAttribute> executableConfigAttrs;
Expand Down
2 changes: 0 additions & 2 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ class ROCMTargetDevice final : public TargetDevice {
deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"),
b.getUnitAttr());
}
deviceConfigAttrs.emplace_back(b.getStringAttr("executable_create_2"),
b.getUnitAttr());
auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs);

SmallVector<NamedAttribute> executableConfigAttrs;
Expand Down
2 changes: 0 additions & 2 deletions compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ class VulkanTargetDevice : public TargetDevice {
Builder b(context);

SmallVector<NamedAttribute> deviceConfigAttrs;
deviceConfigAttrs.emplace_back(b.getStringAttr("executable_create_2"),
b.getUnitAttr());
auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs);

SmallVector<NamedAttribute> executableConfigAttrs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,111 +320,6 @@ class CommandBufferCollectiveOpConversion
mutable IREE::VM::ImportOp importOp;
};

class CommandBufferPushDescriptorSetOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferPushDescriptorSetOp> {
public:
CommandBufferPushDescriptorSetOpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::CommandBufferPushDescriptorSetOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();

auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();

SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
adaptor.getPipelineLayout(),
castToImportType(adaptor.getSet(), i32Type, rewriter),
};
SmallVector<int16_t, 5> segmentSizes = {
/*command_buffer=*/-1,
/*pipeline_layout=*/-1,
/*set=*/-1,
/*bindings=*/
static_cast<int16_t>(adaptor.getBindingOrdinals().size()),
};
for (size_t i = 0; i < adaptor.getBindingOrdinals().size(); ++i) {
callOperands.push_back(
castToImportType(adaptor.getBindingOrdinals()[i], i32Type, rewriter));
auto [bindingBufferSlot, bindingBuffer] = splitBufferSlot(
op.getLoc(), adaptor.getBindingBuffers()[i], rewriter);
callOperands.push_back(bindingBufferSlot);
callOperands.push_back(bindingBuffer);
callOperands.push_back(
castToImportType(adaptor.getBindingOffsets()[i], i64Type, rewriter));
callOperands.push_back(
castToImportType(adaptor.getBindingLengths()[i], i64Type, rewriter));
}

auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
importType.getInputs(), callOperands);
copyImportAttrs(importOp, callOp);
return success();
}

private:
mutable IREE::VM::ImportOp importOp;
};

class CommandBufferDispatchIndirectOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatchIndirectOp> {
public:
CommandBufferDispatchIndirectOpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(typeConverter, context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::CommandBufferDispatchIndirectOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
auto [workgroupsBufferSlot, workgroupsBuffer] =
splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
auto flags = adaptor.getFlagsAttr()
? rewriter
.create<IREE::VM::ConstI64Op>(
op.getLoc(), adaptor.getFlagsAttr().getInt())
.getResult()
: rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
.getResult();
SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
adaptor.getExecutable(),
castToImportType(adaptor.getEntryPoint(), rewriter.getI32Type(),
rewriter),
workgroupsBufferSlot,
workgroupsBuffer,
castToImportType(adaptor.getWorkgroupsOffset(), rewriter.getI64Type(),
rewriter),
flags,
};
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
op, SymbolRefAttr::get(importOp), importType.getResults(),
callOperands);
copyImportAttrs(importOp, callOp);
return success();
}

private:
mutable IREE::VM::ImportOp importOp;
};

class CommandBufferDispatch2OpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatch2Op> {
public:
Expand Down Expand Up @@ -612,18 +507,6 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context,
context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer");
patterns.insert<CommandBufferCollectiveOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.collective");
patterns
.insert<VMImportOpConversion<IREE::HAL::CommandBufferPushConstantsOp>>(
context, importSymbols, typeConverter,
"hal.command_buffer.push_constants");
patterns.insert<CommandBufferPushDescriptorSetOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.push_descriptor_set");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch");
patterns.insert<CommandBufferDispatchIndirectOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.dispatch.indirect");
patterns.insert<CommandBufferDispatch2OpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch2");
patterns.insert<CommandBufferDispatch2IndirectOpConversion>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,72 +84,6 @@ class RemoveExecutableOpConversion
}
};

class ExecutableCreateOpConversion
: public OpConversionPattern<IREE::HAL::ExecutableCreateOp> {
public:
ExecutableCreateOpConversion(MLIRContext *context, SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::ExecutableCreateOp createOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Materialize vm.rodata for the binary.
auto executableBinaryOp =
SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableBinaryOp>(
createOp, createOp.getExecutableTarget());
auto executableOp = executableBinaryOp.getOperation()
->getParentOfType<IREE::HAL::ExecutableOp>();
std::string rodataName = sanitizeSymbolName(
(executableOp.getName() + "_" + executableBinaryOp.getName()).str());
auto rodataOp = rewriter.create<IREE::VM::RodataInlineOp>(
executableBinaryOp.getLoc(),
IREE::VM::RefType::get(rewriter.getType<IREE::VM::BufferType>()),
rewriter.getStringAttr(rodataName), executableBinaryOp.getData(),
rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr());

// Get format string as a rodata blob.
auto executableFormatStr = rewriter.create<IREE::VM::RodataInlineOp>(
createOp.getLoc(), executableBinaryOp.getFormatAttr());

// Pack constants, if any.
auto constantBuffer = createPackedConstantBuffer(
createOp.getLoc(), adaptor.getConstants(), rewriter);

SmallVector<int16_t, 5> segmentSizes = {
/*device=*/-1,
/*executable_format=*/-1,
/*executable_data=*/-1,
/*constants=*/-1,
/*pipeline_layouts=*/
static_cast<int16_t>(llvm::size(adaptor.getLayouts())),
};
SmallVector<Value, 8> callOperands = {
adaptor.getDevice(),
executableFormatStr,
rodataOp,
constantBuffer,
};
callOperands.append(adaptor.getLayouts().begin(),
adaptor.getLayouts().end());

auto importType = importOp.getFunctionType();
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
createOp, SymbolRefAttr::get(importOp), importType.getResults(),
segmentSizes, importType.getInputs(), callOperands);
copyImportAttrs(importOp, callOp);

return success();
}

private:
mutable IREE::VM::ImportOp importOp;
};

class ExecutableCreate2OpConversion
: public OpConversionPattern<IREE::HAL::ExecutableCreate2Op> {
public:
Expand Down Expand Up @@ -216,16 +150,8 @@ void populateHALExecutableToVMPatterns(MLIRContext *context,
// contents during conversion of the ops that use them.
patterns.insert<RemoveExecutableOpConversion>(context);

patterns.insert<ExecutableCreateOpConversion>(
context, importSymbols, typeConverter, "hal.executable.create");
patterns.insert<ExecutableCreate2OpConversion>(
context, importSymbols, typeConverter, "hal.executable.create2");

patterns.insert<VMImportOpConversion<IREE::HAL::DescriptorSetLayoutCreateOp>>(
context, importSymbols, typeConverter,
"hal.descriptor_set_layout.create");
patterns.insert<VMImportOpConversion<IREE::HAL::PipelineLayoutCreateOp>>(
context, importSymbols, typeConverter, "hal.pipeline_layout.create");
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -292,112 +292,6 @@ util.func public @command_buffer_collective_send(

// -----

// CHECK-LABEL: @command_buffer_push_descriptor_set
// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref<!hal.pipeline_layout>,
// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
// CHECK-SAME: %[[SLOT:.+]]: i32)
util.func public @command_buffer_push_descriptor_set(
%cmd: !hal.command_buffer,
%layout: !hal.pipeline_layout,
%buffer: !hal.buffer,
%slot: index
) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c4096 = arith.constant 4096 : index
%c8000 = arith.constant 8000 : index
// CHECK: %[[C0:.+]] = vm.const.i32.zero
// CHECK: %[[C1:.+]] = vm.const.i32 1
// CHECK: %[[NULL:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
// CHECK: vm.call.variadic @hal.command_buffer.push_descriptor_set
// CHECK-SAME: (%[[CMD]], %[[LAYOUT]], %c1, [
// CHECK-SAME: (%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000),
// CHECK-SAME: (%[[C1]], %[[SLOT]], %[[NULL]], %c4, %c4096)
// CHECK-SAME: ]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.pipeline_layout>, i32, tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
hal.command_buffer.push_descriptor_set<%cmd : !hal.command_buffer>
layout(%layout : !hal.pipeline_layout)[%c1]
bindings([
%c0 = (%buffer : !hal.buffer)[%c4096, %c8000],
%c1 = (%slot : index)[%c4, %c4096]
])
util.return
}

// -----

// CHECK-LABEL: @command_buffer_dispatch
// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>)
util.func public @command_buffer_dispatch(
%cmd: !hal.command_buffer,
%executable: !hal.executable
) {
// CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
// CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
// CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300, %[[FLAGS]])
hal.command_buffer.dispatch<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups([%c100, %c200, %c300])
flags(None)
util.return
}

// -----

// CHECK-LABEL: @command_buffer_dispatch_indirect
// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>)
util.func public @command_buffer_dispatch_indirect(
%cmd: !hal.command_buffer,
%executable: !hal.executable,
%buffer: !hal.buffer
) {
// CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
// CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
// CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[FLAGS]])
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups(%buffer : !hal.buffer)[%c100]
flags(None)
util.return
}

// -----

// CHECK-LABEL: @command_buffer_dispatch_indirect_indirect
// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
// CHECK-SAME: %[[BUFFER_SLOT:.+]]: i32)
util.func public @command_buffer_dispatch_indirect_indirect(
%cmd: !hal.command_buffer,
%executable: !hal.executable,
%buffer_slot: index
) {
// CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
// CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
// CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
// CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[FLAGS]])
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups(%buffer_slot : index)[%c100]
flags(None)
util.return
}

// -----

// CHECK-LABEL: @command_buffer_dispatch2
// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
Expand Down
Loading

0 comments on commit bb40c45

Please sign in to comment.