Skip to content

Commit

Permalink
Converting Metal target to support executable-create2.
Browse files Browse the repository at this point in the history
This produces a new flatbuffer that supports multiple shared libraries
per HAL executable, reorganizes per-export information to be per-export,
and swaps HAL pipeline layouts with the metadata required to setup
Metal compute pipelines with proper binding attributes. Extra debug
information is also plumbed through now - though we don't have tracy
wired up yet it could be plumbed through to Metal tooling.
  • Loading branch information
benvanik committed Aug 27, 2024
1 parent 9bbc926 commit e044271
Show file tree
Hide file tree
Showing 23 changed files with 831 additions and 1,107 deletions.
141 changes: 101 additions & 40 deletions compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@ class MetalSPIRVTargetBackend : public TargetBackend {
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
ModuleOp innerModuleOp = variantOp.getInnerModule();

// TODO: rework this to compile all modules into the same metallib and
// source the entry points from them. Or use a linking tool (metal-ar) to
// link the compiled metallibs together. If we were not using spirv-cross
// we'd never do it like this with one module per function.
//
// Currently this is _really_ bad because it doesn't support linking like
// the Vulkan SPIR-V target: that allows multiple spirv::ModuleOps so we
// at least only have a single HAL executable; this should all be reworked
// to have multiple SPIR-V modules in a single executable and then even if
// passing through spirv-cross independently should link the resulting
// metallibs together.
auto spvModuleOp = *innerModuleOp.getOps<spirv::ModuleOp>().begin();
if (!serOptions.dumpIntermediatesPath.empty()) {
std::string assembly;
Expand All @@ -141,14 +153,6 @@ class MetalSPIRVTargetBackend : public TargetBackend {
variantOp.getName(), ".mlir", assembly);
}

// The runtime use ordinals instead of names but Metal requires function
// names for constructing pipeline states. Get an ordered list of the entry
// point names.
SmallVector<StringRef, 8> spirvEntryPointNames;
spvModuleOp.walk([&](spirv::EntryPointOp exportOp) {
spirvEntryPointNames.push_back(exportOp.getFn());
});

// 1. Serialize the spirv::ModuleOp into binary format.
SmallVector<uint32_t, 0> spvBinary;
if (failed(spirv::serialize(spvModuleOp, spvBinary))) {
Expand All @@ -160,6 +164,14 @@ class MetalSPIRVTargetBackend : public TargetBackend {
".spv", spvBinary);
}

// The runtime use ordinals instead of names but Metal requires function
// names for constructing pipeline states. Get an ordered list of the entry
// point names.
SmallVector<StringRef, 8> spirvEntryPointNames;
spvModuleOp.walk([&](spirv::EntryPointOp exportOp) {
spirvEntryPointNames.push_back(exportOp.getFn());
});

// 2. Cross compile SPIR-V to MSL source code.
SmallVector<MetalShader, 2> mslShaders;
SmallVector<std::string, 2> mslEntryPointNames;
Expand Down Expand Up @@ -188,23 +200,25 @@ class MetalSPIRVTargetBackend : public TargetBackend {
}

// 3. Compile MSL to MTLLibrary.
SmallVector<std::unique_ptr<llvm::MemoryBuffer>> metalLibs;
SmallVector<std::unique_ptr<llvm::MemoryBuffer>> metallibs;
metallibs.resize(mslShaders.size());
if (options.compileToMetalLib) {
// We need to use offline Metal shader compilers.
// TODO(#14048): The toolchain can also exist on other platforms. Probe
// the PATH instead.
auto hostTriple = llvm::Triple(llvm::sys::getProcessTriple());
if (hostTriple.isMacOSX()) {
for (auto [shader, entryPoint] :
llvm::zip(mslShaders, mslEntryPointNames)) {
for (auto [i, shader, entryPoint] :
llvm::zip_equal(llvm::seq(mslShaders.size()), mslShaders,
mslEntryPointNames)) {
std::unique_ptr<llvm::MemoryBuffer> lib = compileMSLToMetalLib(
options.targetPlatform, shader.source, entryPoint);
if (!lib) {
return variantOp.emitError()
<< "failed to compile to MTLLibrary from MSL:\n\n"
<< shader.source << "\n\n";
}
metalLibs.push_back(std::move(lib));
metallibs[i] = std::move(lib);
}
}
}
Expand All @@ -217,37 +231,84 @@ class MetalSPIRVTargetBackend : public TargetBackend {
auto sourceFilesRef = createSourceFilesVec(
serOptions.debugLevel, variantOp.getSourcesAttr(), builder);

auto entryPointNamesRef = builder.createStringVec(mslEntryPointNames);
iree_hal_metal_ExecutableDef_entry_points_add(builder, entryPointNamesRef);

iree_hal_metal_ThreadgroupSize_vec_start(builder);
for (auto &shader : mslShaders) {
iree_hal_metal_ThreadgroupSize_vec_push_create(
builder, shader.threadgroupSize.x, shader.threadgroupSize.y,
shader.threadgroupSize.z);
// Each library may provide multiple functions so we encode them
// independently.
SmallVector<iree_hal_metal_LibraryDef_ref_t> libraryRefs;
for (auto [shader, metallib] : llvm::zip_equal(mslShaders, metallibs)) {
const bool embedSource = !metallib || serOptions.debugLevel > 1;
iree_hal_metal_MSLSourceDef_ref_t sourceRef = 0;
if (embedSource) {
// TODO: pull this from an attribute?
// https://developer.apple.com/documentation/metal/mtllanguageversion
unsigned version = 196608; // MTLLanguageVersion3_0
auto sourceStrRef = builder.createString(shader.source);
sourceRef =
iree_hal_metal_MSLSourceDef_create(builder, version, sourceStrRef);
}
flatbuffers_string_ref_t metallibRef = 0;
if (metallib) {
metallibRef = flatbuffers_string_create(
builder, metallib->getBufferStart(), metallib->getBufferSize());
}
iree_hal_metal_LibraryDef_start(builder);
iree_hal_metal_LibraryDef_source_add(builder, sourceRef);
iree_hal_metal_LibraryDef_metallib_add(builder, metallibRef);
libraryRefs.push_back(iree_hal_metal_LibraryDef_end(builder));
}
auto threadgroupSizesRef = iree_hal_metal_ThreadgroupSize_vec_end(builder);
iree_hal_metal_ExecutableDef_threadgroup_sizes_add(builder,
threadgroupSizesRef);

if (metalLibs.empty()) {
auto shaderSourcesRef = builder.createStringVec(
llvm::map_range(mslShaders, [&](const MetalShader &shader) {
return shader.source;
}));
iree_hal_metal_ExecutableDef_shader_sources_add(builder,
shaderSourcesRef);
} else {
auto refs = llvm::to_vector<8>(llvm::map_range(
metalLibs, [&](const std::unique_ptr<llvm::MemoryBuffer> &buffer) {
return flatbuffers_string_create(builder, buffer->getBufferStart(),
buffer->getBufferSize());
}));
auto libsRef =
flatbuffers_string_vec_create(builder, refs.data(), refs.size());
iree_hal_metal_ExecutableDef_shader_libraries_add(builder, libsRef);
auto librariesRef = builder.createOffsetVecDestructive(libraryRefs);

// Generate optional per-export debug information.
// May be empty if no debug information was requested.
auto exportOps = llvm::to_vector_of<IREE::HAL::ExecutableExportOp>(
variantOp.getExportOps());
auto exportDebugInfos =
createExportDefs(serOptions.debugLevel, exportOps, builder);

SmallVector<iree_hal_metal_PipelineDef_ref_t> pipelineRefs;
for (auto [i, shader, entryPoint, exportOp] :
llvm::zip_equal(llvm::seq(mslShaders.size()), mslShaders,
mslEntryPointNames, exportOps)) {
auto entryPointRef = builder.createString(entryPoint);

iree_hal_metal_ThreadgroupSize_t threadgroupSize = {
shader.threadgroupSize.x,
shader.threadgroupSize.y,
shader.threadgroupSize.z,
};

auto layoutAttr = exportOp.getLayoutAttr();
uint32_t constantCount =
static_cast<uint32_t>(layoutAttr.getPushConstants());
SmallVector<iree_hal_metal_BindingBits_enum_t> bindingFlags;
for (auto bindingAttr : layoutAttr.getSetLayout(0).getBindings()) {
iree_hal_metal_BindingBits_enum_t flags = 0;
if (allEnumBitsSet(bindingAttr.getFlags(),
IREE::HAL::DescriptorFlags::ReadOnly)) {
flags |= iree_hal_metal_BindingBits_IMMUTABLE;
}
bindingFlags.push_back(flags);
}
auto bindingFlagsRef = iree_hal_metal_BindingBits_vec_create(
builder, bindingFlags.data(), bindingFlags.size());

iree_hal_metal_PipelineDef_start(builder);
iree_hal_metal_PipelineDef_library_ordinal_add(builder, i);
iree_hal_metal_PipelineDef_entry_point_add(builder, entryPointRef);
iree_hal_metal_PipelineDef_threadgroup_size_add(builder,
&threadgroupSize);
// TODO: embed additional metadata on threadgroup info if available.
// iree_hal_metal_PipelineDef_max_threads_per_threadgroup_add(builder, 0);
// iree_hal_metal_PipelineDef_threadgroup_size_aligned_add(builder,
// false);
iree_hal_metal_PipelineDef_constant_count_add(builder, constantCount);
iree_hal_metal_PipelineDef_binding_flags_add(builder, bindingFlagsRef);
iree_hal_metal_PipelineDef_debug_info_add(builder, exportDebugInfos[i]);
pipelineRefs.push_back(iree_hal_metal_PipelineDef_end(builder));
}
auto pipelinesRef = builder.createOffsetVecDestructive(pipelineRefs);

iree_hal_metal_ExecutableDef_pipelines_add(builder, pipelinesRef);
iree_hal_metal_ExecutableDef_libraries_add(builder, librariesRef);
iree_hal_metal_ExecutableDef_source_files_add(builder, sourceFilesRef);

iree_hal_metal_ExecutableDef_end_as_root(builder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ SmallVector<flatbuffers_ref_t>
createExportDefs(int debugLevel,
ArrayRef<IREE::HAL::ExecutableExportOp> exportOps,
FlatbufferBuilder &fbb) {
SmallVector<flatbuffers_ref_t> exportDefs;
exportDefs.resize(exportOps.size(), 0);

if (debugLevel < 1) {
// No debug information.
return {};
return exportDefs;
}

SmallVector<flatbuffers_ref_t> exportDefs;
exportDefs.resize(exportOps.size(), 0);

for (auto exportOp : exportOps) {
auto ordinalAttr = exportOp.getOrdinalAttr();
assert(ordinalAttr && "ordinals must be assigned");
Expand Down
3 changes: 1 addition & 2 deletions runtime/src/iree/base/internal/threading_win32.c
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,7 @@ void iree_thread_request_affinity(iree_thread_t* thread,
int affinity_desc_length = snprintf(
affinity_desc, IREE_ARRAYSIZE(affinity_desc), "group=%d, id=%d, smt=%d",
affinity.group, affinity.id, affinity.smt);
IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(z0, affinity_desc,
affinity_desc_length);
IREE_TRACE_ZONE_APPEND_TEXT(z0, affinity_desc, affinity_desc_length);
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION

GROUP_AFFINITY group_affinity;
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2(
int xyz_string_length =
snprintf(xyz_string, IREE_ARRAYSIZE(xyz_string), "%ux%ux%u",
workgroup_count[0], workgroup_count[1], workgroup_count[2]);
IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(z0, xyz_string, xyz_string_length);
IREE_TRACE_ZONE_APPEND_TEXT(z0, xyz_string, xyz_string_length);
});
#endif // IREE_HAL_VERBOSE_TRACING_ENABLE

Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/native_executable.c
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ iree_status_t iree_hal_cuda_native_executable_create(
IREE_TRACE_ZONE_BEGIN(z0);

*out_executable = NULL;
iree_hal_cuda_native_executable_t* executable = NULL;

// TODO: move to the executable cache to avoid repeated queries.
iree_hal_cuda_limits_t limits = {0};
Expand Down Expand Up @@ -252,7 +251,8 @@ iree_status_t iree_hal_cuda_native_executable_create(
}
});

// Allocate storage for the kernel module.
// Allocate storage for the executable and its associated data structures.
iree_hal_cuda_native_executable_t* executable = NULL;
const iree_host_size_t total_size =
sizeof(*executable) + module_count * sizeof(executable->modules[0]) +
export_count * sizeof(executable->exports[0]) + total_export_info_length;
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/hip/native_executable.c
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ iree_status_t iree_hal_hip_native_executable_create(
IREE_TRACE_ZONE_BEGIN(z0);

*out_executable = NULL;
iree_hal_hip_native_executable_t* executable = NULL;

// TODO: move to the executable cache to avoid repeated queries.
iree_hal_hip_limits_t limits = {0};
Expand Down Expand Up @@ -251,7 +250,8 @@ iree_status_t iree_hal_hip_native_executable_create(
}
});

// Allocate storage for the kernel module.
// Allocate storage for the executable and its associated data structures.
iree_hal_hip_native_executable_t* executable = NULL;
const iree_host_size_t total_size =
sizeof(*executable) + module_count * sizeof(executable->modules[0]) +
export_count * sizeof(executable->exports[0]) + total_export_info_length;
Expand Down
2 changes: 0 additions & 2 deletions runtime/src/iree/hal/drivers/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ iree_cc_library(
"metal_driver.m"
"nop_executable_cache.h"
"nop_executable_cache.m"
"pipeline_layout.h"
"pipeline_layout.m"
"shared_event.h"
"shared_event.m"
"staging_buffer.h"
Expand Down
30 changes: 18 additions & 12 deletions runtime/src/iree/hal/drivers/metal/builtin_executables.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,22 @@
extern "C" {
#endif // __cplusplus

// Object and launch parameters for a compute function.
typedef struct iree_hal_metal_builtin_pipeline_t {
id<MTLComputePipelineState> pipeline_state;
IREE_TRACE(iree_hal_metal_source_location_t source_location;)
} iree_hal_metal_builtin_pipeline_t;

typedef struct iree_hal_metal_builtin_executable_t {
iree_allocator_t host_allocator;

// The number of entry points in this builtin executable.
iree_host_size_t entry_point_count;
// THe list of entry points, pointing to the end of the struct allocation.
iree_hal_metal_kernel_params_t entry_points[];
// Compiled MTLLibrary instances containing the builtin kernels.
NSArray<id<MTLLibrary>>* libraries;

// The number of pipelines in this builtin executable.
iree_host_size_t pipeline_count;
// The list of pipelines, pointing to the end of the struct allocation.
iree_hal_metal_builtin_pipeline_t pipelines[];
} iree_hal_metal_builtin_executable_t;
// + Additional inline allocation for holding all entry point kernel parameters.

Expand All @@ -33,18 +42,16 @@ iree_status_t iree_hal_metal_builtin_executable_create(
id<MTLDevice> device, iree_allocator_t host_allocator,
iree_hal_metal_builtin_executable_t** out_executable);

void iree_hal_metal_builtin_executable_destroy(
iree_hal_metal_builtin_executable_t* executable);
void iree_hal_metal_builtin_executable_destroy(iree_hal_metal_builtin_executable_t* executable);

// Fills the |target_buffer| at the given |target_offset| of |length| with
// |pattern| using builtin executables dispatched via |encoder|.
//
// Under the hood, this will record all necessary commands to bind kernel
// objects and buffer resources, and the perform dispatch.
iree_status_t iree_hal_metal_builtin_executable_fill_buffer(
const iree_hal_metal_builtin_executable_t* executable,
id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> target_buffer,
iree_device_size_t target_offset, iree_device_size_t length,
const iree_hal_metal_builtin_executable_t* executable, id<MTLComputeCommandEncoder> encoder,
id<MTLBuffer> target_buffer, iree_device_size_t target_offset, iree_device_size_t length,
uint32_t pattern);

// Copies the |source_buffer| at |source_offset| to the |target_buffer| at
Expand All @@ -54,9 +61,8 @@ iree_status_t iree_hal_metal_builtin_executable_fill_buffer(
// Under the hood, this will record all necessary commands to bind kernel
// objects and buffer resources, and the perform dispatch.
iree_status_t iree_hal_metal_builtin_executable_copy_buffer(
const iree_hal_metal_builtin_executable_t* executable,
id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> source_buffer,
iree_device_size_t source_offset, id<MTLBuffer> target_buffer,
const iree_hal_metal_builtin_executable_t* executable, id<MTLComputeCommandEncoder> encoder,
id<MTLBuffer> source_buffer, iree_device_size_t source_offset, id<MTLBuffer> target_buffer,
iree_device_size_t target_offset, iree_device_size_t length);

#ifdef __cplusplus
Expand Down
Loading

0 comments on commit e044271

Please sign in to comment.