Skip to content

Commit

Permalink
Renaming dispatch2 -> dispatch and create2 -> create.
Browse files Browse the repository at this point in the history
Progress on #18154.
  • Loading branch information
benvanik committed Aug 26, 2024
1 parent d400ca8 commit 26c3820
Show file tree
Hide file tree
Showing 36 changed files with 251 additions and 391 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -320,20 +320,20 @@ class CommandBufferCollectiveOpConversion
mutable IREE::VM::ImportOp importOp;
};

class CommandBufferDispatch2OpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatch2Op> {
class CommandBufferDispatchOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatchOp> {
public:
CommandBufferDispatch2OpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
CommandBufferDispatchOpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(typeConverter, context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::CommandBufferDispatch2Op op, OpAdaptor adaptor,
matchAndRewrite(IREE::HAL::CommandBufferDispatchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();

Expand Down Expand Up @@ -396,20 +396,20 @@ class CommandBufferDispatch2OpConversion
mutable IREE::VM::ImportOp importOp;
};

class CommandBufferDispatch2IndirectOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatch2IndirectOp> {
class CommandBufferDispatchIndirectOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatchIndirectOp> {
public:
CommandBufferDispatch2IndirectOpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
CommandBufferDispatchIndirectOpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(typeConverter, context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::CommandBufferDispatch2IndirectOp op,
matchAndRewrite(IREE::HAL::CommandBufferDispatchIndirectOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Expand Down Expand Up @@ -507,11 +507,11 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context,
context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer");
patterns.insert<CommandBufferCollectiveOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.collective");
patterns.insert<CommandBufferDispatch2OpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch2");
patterns.insert<CommandBufferDispatch2IndirectOpConversion>(
patterns.insert<CommandBufferDispatchOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch");
patterns.insert<CommandBufferDispatchIndirectOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.dispatch2.indirect");
"hal.command_buffer.dispatch.indirect");
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,19 @@ class RemoveExecutableOpConversion
}
};

class ExecutableCreate2OpConversion
: public OpConversionPattern<IREE::HAL::ExecutableCreate2Op> {
class ExecutableCreateOpConversion
: public OpConversionPattern<IREE::HAL::ExecutableCreateOp> {
public:
ExecutableCreate2OpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
ExecutableCreateOpConversion(MLIRContext *context, SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::ExecutableCreate2Op createOp, OpAdaptor adaptor,
matchAndRewrite(IREE::HAL::ExecutableCreateOp createOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Materialize vm.rodata for the binary.
auto executableBinaryOp =
Expand Down Expand Up @@ -150,8 +149,8 @@ void populateHALExecutableToVMPatterns(MLIRContext *context,
// contents during conversion of the ops that use them.
patterns.insert<RemoveExecutableOpConversion>(context);

patterns.insert<ExecutableCreate2OpConversion>(
context, importSymbols, typeConverter, "hal.executable.create2");
patterns.insert<ExecutableCreateOpConversion>(
context, importSymbols, typeConverter, "hal.executable.create");
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,12 @@ util.func public @command_buffer_collective_send(

// -----

// CHECK-LABEL: @command_buffer_dispatch2
// CHECK-LABEL: @command_buffer_dispatch
// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
// CHECK-SAME: %[[SLOT:.+]]: i32)
util.func public @command_buffer_dispatch2(
util.func public @command_buffer_dispatch(
%cmd: !hal.command_buffer,
%executable: !hal.executable,
%buffer: !hal.buffer,
Expand All @@ -321,15 +321,15 @@ util.func public @command_buffer_dispatch2(
%c8000 = arith.constant 8000 : index
// CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
// CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
// CHECK: vm.call.variadic @hal.command_buffer.dispatch2
// CHECK: vm.call.variadic @hal.command_buffer.dispatch
// CHECK-SAME: %[[CMD]],
// CHECK-SAME: %[[EXECUTABLE]], %[[ORDINAL]],
// CHECK-SAME: %[[X]], %[[Y]], %[[Z]],
// CHECK-SAME: %[[FLAGS]],
// CHECK-SAME: [%[[CONSTANT0]], %[[CONSTANT1]]],
// CHECK-SAME: [(%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000),
// CHECK-SAME: (%[[C0]], %[[SLOT]], %[[NULL_BUFFER]], %c4, %c4096)]
hal.command_buffer.dispatch2<%cmd : !hal.command_buffer>
hal.command_buffer.dispatch<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups([%x, %y, %z])
constants([%constant0, %constant1])
Expand All @@ -343,13 +343,13 @@ util.func public @command_buffer_dispatch2(

// -----

// CHECK-LABEL: vm.func private @command_buffer_dispatch2
// CHECK-LABEL: vm.func private @command_buffer_dispatch
// CHECK-SAME: (%[[CMD:[a-z0-9]+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[EXECUTABLE:[a-z0-9]+]]: !vm.ref<!hal.executable>,
// CHECK-SAME: %[[WORKGROUPS_SLOT:[a-z0-9]+]]: i32,
// CHECK-SAME: %[[BUFFER:[a-z0-9]+]]: !vm.ref<!hal.buffer>,
// CHECK-SAME: %[[SLOT:[a-z0-9]+]]: i32)
util.func public @command_buffer_dispatch2(
util.func public @command_buffer_dispatch(
%cmd: !hal.command_buffer,
%executable: !hal.executable,
%workgroups_slot: index,
Expand All @@ -370,15 +370,15 @@ util.func public @command_buffer_dispatch2(
%c8000 = arith.constant 8000 : index
// CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
// CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
// CHECK: vm.call.variadic @hal.command_buffer.dispatch2.indirect
// CHECK: vm.call.variadic @hal.command_buffer.dispatch.indirect
// CHECK-SAME: %[[CMD]],
// CHECK-SAME: %[[EXECUTABLE]], %[[ORDINAL]],
// CHECK-SAME: %[[WORKGROUPS_SLOT]], %[[NULL_BUFFER]], %[[WORKGROUPS_OFFSET]],
// CHECK-SAME: %[[FLAGS]],
// CHECK-SAME: [%[[CONSTANT0]], %[[CONSTANT1]]],
// CHECK-SAME: [(%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000),
// CHECK-SAME: (%[[C0]], %[[SLOT]], %[[NULL_BUFFER]], %c4, %c4096)]
hal.command_buffer.dispatch2.indirect<%cmd : !hal.command_buffer>
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups(%workgroups_slot : index)[%workgroups_offset]
constants([%constant0, %constant1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,27 @@ hal.executable @exe {
}
}

// CHECK-LABEL: @executableCreate2
util.func public @executableCreate2(
// CHECK-LABEL: @executableCreate
util.func public @executableCreate(
// CHECK-SAME: %[[DEV:.+]]: !vm.ref<!hal.device>
%device: !hal.device
) -> (!hal.executable, !hal.executable) {

// CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format1_
// CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
// CHECK-DAG: %[[NULL1:.+]] = vm.const.ref.zero : !vm.buffer
// CHECK: %[[EXE1:.+]] = vm.call @hal.executable.create2(
// CHECK: %[[EXE1:.+]] = vm.call @hal.executable.create(
// CHECK-SAME: %[[DEV]], %[[FORMAT1]], %[[BINARY1]], %[[NULL1]]
// CHECK-SAME: ) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref<!hal.executable>
%0 = hal.executable.create2 device(%device : !hal.device) target(@exe::@binary1) : !hal.executable
%0 = hal.executable.create device(%device : !hal.device) target(@exe::@binary1) : !hal.executable

// CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format2_
// CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8>
// CHECK-DAG: %[[NULL2:.+]] = vm.const.ref.zero : !vm.buffer
// CHECK: %[[EXE2:.+]] = vm.call @hal.executable.create2(
// CHECK: %[[EXE2:.+]] = vm.call @hal.executable.create(
// CHECK-SAME: %[[DEV]], %[[FORMAT2]], %[[BINARY2]], %[[NULL2]]
// CHECK-SAME: ) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref<!hal.executable>
%1 = hal.executable.create2 device(%device : !hal.device) target(@exe::@binary2) : !hal.executable
%1 = hal.executable.create device(%device : !hal.device) target(@exe::@binary2) : !hal.executable

// CHECK: vm.return %[[EXE1]], %[[EXE2]]
util.return %0, %1 : !hal.executable, !hal.executable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ struct CmdCollectiveOpPattern
}
};

struct CmdDispatch2OpPattern
struct CmdDispatchOpPattern
: public StreamConversionPattern<IREE::Stream::CmdDispatchOp> {
using StreamConversionPattern::StreamConversionPattern;
LogicalResult
Expand Down Expand Up @@ -796,7 +796,7 @@ struct CmdDispatch2OpPattern

auto flags = IREE::HAL::DispatchFlags::None;

return builder.create<IREE::HAL::CommandBufferDispatch2Op>(
return builder.create<IREE::HAL::CommandBufferDispatchOp>(
loc, commandBufferMapping.getHandle(), executable, ordinal,
workgroupCount, adaptor.getUniformOperands(), bindings, flags);
}
Expand Down Expand Up @@ -1365,7 +1365,7 @@ void populateStreamToHALPatterns(MLIRContext *context,
patterns
.insert<CmdFlushOpPattern, CmdInvalidateOpPattern, CmdDiscardOpPattern,
CmdFillOpPattern, CmdCopyOpPattern, CmdCollectiveOpPattern,
CmdDispatch2OpPattern, CmdFuncOpPattern, CmdCallOpPattern,
CmdDispatchOpPattern, CmdFuncOpPattern, CmdCallOpPattern,
CmdExecuteOpPattern, CmdSerialOpPattern, CmdConcurrentOpPattern>(
mapping, typeConverter, context);
patterns.insert<TimepointImmediateOpPattern, TimepointImportOpPattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ iree_lit_test_suite(
srcs = enforce_glob(
[
"channel_ops.mlir",
"cmd_dispatch2_ops.mlir",
"cmd_ops.mlir",
"context_ops.mlir",
"debug_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ iree_lit_test_suite(
lit
SRCS
"channel_ops.mlir"
"cmd_dispatch2_ops.mlir"
"cmd_ops.mlir"
"context_ops.mlir"
"debug_ops.mlir"
Expand Down

This file was deleted.

Loading

0 comments on commit 26c3820

Please sign in to comment.