From 7121ece901a1b65b51254c8dc10d932b5fcccaeb Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 15 Aug 2024 09:30:10 -0700 Subject: [PATCH] Converting CUDA target to support executable-create2. This produces a new flatbuffer that supports multiple CUmodules per HAL executable, reorganizes per-export information to be per-export, and removes HAL pipeline layouts and the existing stateful command recording. --- compiler/plugins/target/CUDA/CUDATarget.cpp | 304 +++++++------ .../plugins/target/CUDA/test/smoketest.mlir | 2 +- .../Dialect/HAL/Target/TargetBackend.cpp | 2 +- runtime/src/iree/hal/drivers/cuda/BUILD.bazel | 2 - .../src/iree/hal/drivers/cuda/CMakeLists.txt | 2 - .../src/iree/hal/drivers/cuda/cuda_device.c | 27 -- .../hal/drivers/cuda/graph_command_buffer.c | 231 +--------- .../iree/hal/drivers/cuda/native_executable.c | 421 +++++++++++------- .../iree/hal/drivers/cuda/native_executable.h | 35 +- .../iree/hal/drivers/cuda/pipeline_layout.c | 260 ----------- .../iree/hal/drivers/cuda/pipeline_layout.h | 110 ----- .../hal/drivers/cuda/stream_command_buffer.c | 217 +-------- .../src/iree/schemas/cuda_executable_def.fbs | 69 ++- 13 files changed, 514 insertions(+), 1168 deletions(-) delete mode 100644 runtime/src/iree/hal/drivers/cuda/pipeline_layout.c delete mode 100644 runtime/src/iree/hal/drivers/cuda/pipeline_layout.h diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 5a41ffe67642e..b3a7a99746c51 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -55,7 +55,6 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { struct CUDAOptions { - bool dumpPtx = false; std::string clTargetChip = "sm_60"; std::string clTargetFeature = "+ptx76"; bool clUsePtxas = false; @@ -64,8 +63,6 @@ struct CUDAOptions { void bindOptions(OptionsBinder &binder) { static llvm::cl::OptionCategory category("CUDA HAL Target"); - binder.opt("iree-hal-cuda-dump-ptx", 
dumpPtx, llvm::cl::cat(category), - llvm::cl::desc("Dump ptx to the debug stream.")); binder.opt("iree-hal-cuda-llvm-target-arch", clTargetChip, llvm::cl::cat(category), @@ -254,26 +251,14 @@ static std::string produceGpuImage(const CUDAOptions &options, return ptxImage; } -static void dumpLLVMModuleToPath(StringRef path, StringRef baseName, - StringRef suffix, StringRef extPrefix, - llvm::Module &module) { - // Dump disassembly to path. - llvm::SmallVector textData; - llvm::raw_svector_ostream textOstream(textData); - - module.print(textOstream, nullptr); - std::string textExtension = extPrefix.str() + ".ll"; - dumpDataToPath(path, baseName, suffix, textExtension, - StringRef(textData.data(), textData.size())); - - // Dump bitcode to path. - llvm::SmallVector binaryData; - llvm::raw_svector_ostream binaryOstream(binaryData); - // Write the specified module to the specified output stream. - llvm::WriteBitcodeToFile(module, binaryOstream); - std::string binaryExtension = extPrefix.str() + ".bc"; - dumpDataToPath(path, baseName, suffix, binaryExtension, - StringRef(binaryData.data(), binaryData.size())); +static void dumpModuleToPath(StringRef path, StringRef baseName, + StringRef suffix, StringRef extension, + llvm::Module &module) { + llvm::SmallVector data; + llvm::raw_svector_ostream ostream(data); + module.print(ostream, nullptr); + dumpDataToPath(path, baseName, suffix, extension, + StringRef(data.data(), data.size())); } static std::string translateModuleToISA(llvm::Module &module, @@ -394,19 +379,24 @@ class CUDATargetDevice final : public TargetDevice { getDefaultDeviceTarget(MLIRContext *context, const TargetRegistry &targetRegistry) const override { Builder b(context); - SmallVector configItems; - // TODO: device configuration attrs. 
- auto configAttr = b.getDictionaryAttr(configItems); + SmallVector deviceConfigAttrs; + deviceConfigAttrs.emplace_back(b.getStringAttr("executable_create_2"), + b.getUnitAttr()); + auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); + + SmallVector executableConfigAttrs; + auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); // If we had multiple target environments we would generate one target attr // per environment, with each setting its own environment attribute. SmallVector executableTargetAttrs; targetRegistry.getTargetBackend("cuda")->getDefaultExecutableTargets( - context, "cuda", configAttr, executableTargetAttrs); + context, "cuda", executableConfigAttr, executableTargetAttrs); return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("cuda"), - configAttr, executableTargetAttrs); + deviceConfigAttr, + executableTargetAttrs); } private: @@ -471,6 +461,7 @@ class CUDATargetBackend final : public TargetBackend { LogicalResult serializeExecutable(const SerializationOptions &serOptions, IREE::HAL::ExecutableVariantOp variantOp, OpBuilder &executableBuilder) override { + ModuleOp innerModuleOp = variantOp.getInnerModule(); auto targetAttr = variantOp.getTargetAttr(); StringRef targetArch = options.clTargetChip; StringRef targetFeatures = options.clTargetFeature; @@ -479,10 +470,6 @@ class CUDATargetBackend final : public TargetBackend { targetFeatures = attr.getFeatures(); } - // Perform the translation in a separate context to avoid any - // multi-threading issues. 
- llvm::LLVMContext context; - // We name our files after the executable name so that they are easy to // track both during compilation (logs/artifacts/etc), as outputs (final // intermediate code/binary files), and at runtime (loaded @@ -490,41 +477,16 @@ class CUDATargetBackend final : public TargetBackend { auto libraryName = variantOp->getParentOfType().getName().str(); - // TODO(thomasraoux): property handle export ordinals; this code is assuming - // that ordinals are dense starting at 0 but that is not required. - - // Collect all the entry point parameters. - SmallVector> workgroupSizes; - SmallVector workgroupLocalMemories; - for (auto exportOp : variantOp.getExportOps()) { - std::array workgroupSize; - if (std::optional workgroupSizeAttr = - exportOp.getWorkgroupSize()) { - for (auto it : llvm::enumerate(workgroupSizeAttr.value())) { - workgroupSize[it.index()] = - llvm::cast(it.value()).getInt(); - } - } else { - workgroupSize = {1, 1, 1}; - } - workgroupSizes.push_back(workgroupSize); - uint32_t workgroupLocalMemory = 0; - if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) { - workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue(); - } - workgroupLocalMemories.push_back(workgroupLocalMemory); + // Collect all the entry point names. + auto exportOps = llvm::to_vector_of( + variantOp.getExportOps()); + llvm::StringMap exportOpMap; + for (IREE::HAL::ExecutableExportOp exportOp : exportOps) { + exportOpMap[exportOp.getSymName()] = exportOp; } - FlatbufferBuilder builder; - iree_hal_cuda_ExecutableDef_start_as_root(builder); - - // Attach embedded source file contents. 
- auto sourceFilesRef = createSourceFilesVec( - serOptions.debugLevel, variantOp.getSourcesAttr(), builder); - - SmallVector entryPointNames; - std::string ptxImage; - SmallVector sourceLocationRefs; + std::array maxWorkgroupSize = {1, 1, 1}; + std::string targetPTX; if (variantOp.isExternal()) { if (!variantOp.getObjects().has_value()) { return variantOp.emitOpError() @@ -537,42 +499,32 @@ class CUDATargetBackend final : public TargetBackend { "supported for external variants"; } - // Take exported names verbatim. The user must have already sanitized - // these to match the names in their kernels. We don't support any kind of - // mangling and if the user was silly enough to rely on nvcc C++ mangling - // they'll have to figure that out. - for (auto exportOp : variantOp.getExportOps()) { - entryPointNames.emplace_back(exportOp.getSymName()); - } - + // Read the PTX from the object file. auto objectAttr = llvm::cast( variantOp.getObjects()->getValue().front()); if (auto data = objectAttr.loadData()) { - ptxImage = data.value(); + targetPTX = data.value(); } else { return variantOp.emitOpError() << "object file could not be loaded: " << objectAttr; } } else { - ModuleOp innerModuleOp = variantOp.getInnerModule(); + // Perform the translation in a separate context to avoid any + // multi-threading issues. + llvm::LLVMContext context; - auto llvmModule = + std::unique_ptr llvmModule = mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName); if (!llvmModule) { return variantOp.emitError() << "failed to translate the MLIR LLVM " "dialect to the native llvm::Module"; } - for (auto [exportOp, workgroupSize] : - llvm::zip_equal(variantOp.getExportOps(), workgroupSizes)) { - auto *llvmFunc = llvmModule->getFunction(exportOp.getName()); + for (auto func : innerModuleOp.getOps()) { + llvm::Function *llvmFunc = llvmModule->getFunction(func.getName()); if (llvmFunc->isDeclaration()) continue; - // setName will make sure the function name is unique. 
- llvmFunc->setName(sanitizeSymbolName(exportOp.getName())); - entryPointNames.emplace_back(llvmFunc->getName()); - auto *annotations = llvmModule->getOrInsertNamedMetadata("nvvm.annotations"); auto setMetadataValueI32 = [&](StringRef name, int value) { @@ -586,18 +538,25 @@ class CUDATargetBackend final : public TargetBackend { }; // Mark the entry point as a kernel. setMetadataValueI32("kernel", 1); + // Set the maximum number of threads in the thread block (CTA). - setMetadataValueI32("maxntidx", workgroupSize[0]); - setMetadataValueI32("maxntidy", workgroupSize[1]); - setMetadataValueI32("maxntidz", workgroupSize[2]); - - // Optional source location information for debugging/profiling. - if (serOptions.debugLevel >= 1) { - if (auto loc = findFirstFileLoc(exportOp.getLoc())) { - auto filenameRef = builder.createString(loc->getFilename()); - sourceLocationRefs.push_back(iree_hal_debug_FileLineLocDef_create( - builder, filenameRef, loc->getLine())); - } + auto exportOp = exportOpMap[func.getName()]; + if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) { + auto workgroupSizeValues = workgroupSizeAttr->getValue(); + std::array workgroupSize = { + static_cast( + cast(workgroupSizeValues[0]).getInt()), + static_cast( + cast(workgroupSizeValues[1]).getInt()), + static_cast( + cast(workgroupSizeValues[2]).getInt()), + }; + maxWorkgroupSize[0] = std::max(maxWorkgroupSize[0], workgroupSize[0]); + maxWorkgroupSize[1] = std::max(maxWorkgroupSize[1], workgroupSize[1]); + maxWorkgroupSize[2] = std::max(maxWorkgroupSize[2], workgroupSize[2]); + setMetadataValueI32("maxntidx", workgroupSize[0]); + setMetadataValueI32("maxntidy", workgroupSize[1]); + setMetadataValueI32("maxntidz", workgroupSize[2]); } } @@ -617,11 +576,17 @@ class CUDATargetBackend final : public TargetBackend { } } - // Dump just the codegen bitcode before linking and optimization. 
- if (!serOptions.dumpIntermediatesPath.empty()) { - dumpLLVMModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), - ".codegen", *llvmModule); + llvmModule->setDataLayout(targetMachine->createDataLayout()); + + for (llvm::Function &f : llvmModule->functions()) + f.addFnAttr(llvm::Attribute::AlwaysInline); + + // Link user-provided modules. + llvm::Linker linker(*llvmModule); + if (failed(linkCmdlineBitcodeFiles( + variantOp.getLoc(), linker, llvm::Linker::OverrideFromSrc, + *targetMachine, llvmModule->getContext()))) { + return failure(); } // Link user and device bitcode alongside the generated module. @@ -630,77 +595,122 @@ class CUDATargetBackend final : public TargetBackend { return failure(); } - // Dump all linked bitcode prior to optimization. if (!serOptions.dumpIntermediatesPath.empty()) { - dumpLLVMModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), - ".linked", *llvmModule); + dumpModuleToPath(serOptions.dumpIntermediatesPath, + serOptions.dumpBaseName, variantOp.getName(), + ".linked.ll", *llvmModule); } - std::array maxWorkgroupSize = {1, 1, 1}; - for (int64_t i = 0, e = workgroupSizes.size(); i < e; i++) { - for (int64_t j = 0; j < maxWorkgroupSize.size(); j++) { - maxWorkgroupSize[j] = - std::max(maxWorkgroupSize[j], workgroupSizes[i][j]); - } - } - // Run LTO-style full optimization on the linked modules. + // Run LLVM optimization passes. optimizeModule(*llvmModule, *targetMachine, maxWorkgroupSize); - - // Dump bitcode post-linking and optimization. 
if (!serOptions.dumpIntermediatesPath.empty()) { - dumpLLVMModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), - ".optimized", *llvmModule); + dumpModuleToPath(serOptions.dumpIntermediatesPath, + serOptions.dumpBaseName, variantOp.getName(), + ".optimized.ll", *llvmModule); } - // Serialize CUDA kernel into the binary that we will embed in the + // Serialize ptx kernel into the binary that we will embed in the // final FlatBuffer. - ptxImage = translateModuleToISA(*llvmModule, *targetMachine); + targetPTX = translateModuleToISA(*llvmModule, *targetMachine); + if (targetPTX.empty()) + return failure(); } - if (options.dumpPtx) { - llvm::dbgs() << ptxImage; - } if (!serOptions.dumpBinariesPath.empty()) { dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName, - variantOp.getName(), ".ptx", ptxImage); + variantOp.getName(), ".ptx", targetPTX); } - std::string gpuImage = produceGpuImage(options, targetArch, ptxImage); - auto gpuImageRef = - flatbuffers_string_create(builder, gpuImage.c_str(), gpuImage.size()); - iree_hal_cuda_BlockSize_vec_start(builder); - for (const auto &workgroupSize : workgroupSizes) { - iree_hal_cuda_BlockSize_vec_push_create( - builder, workgroupSize[0], workgroupSize[1], workgroupSize[2]); + FlatbufferBuilder builder; + iree_hal_cuda_ExecutableDef_start_as_root(builder); + + auto sourceFilesRef = createSourceFilesVec( + serOptions.debugLevel, variantOp.getSourcesAttr(), builder); + + // Only a single module today. 
+ SmallVector moduleRefs; + { + auto ptxImageRef = flatbuffers_string_create(builder, targetPTX.c_str(), + targetPTX.size()); + moduleRefs.push_back( + iree_hal_cuda_ModuleDef_create(builder, ptxImageRef)); } - auto blockSizesRef = iree_hal_cuda_BlockSize_vec_end(builder); - auto workgroupLocalMemoriesRef = - builder.createInt32Vec(workgroupLocalMemories); - auto entryPointsRef = builder.createStringVec(entryPointNames); - - iree_hal_cuda_ExecutableDef_entry_points_add(builder, entryPointsRef); - iree_hal_cuda_ExecutableDef_block_sizes_add(builder, blockSizesRef); - iree_hal_cuda_ExecutableDef_shared_memory_size_add( - builder, workgroupLocalMemoriesRef); - iree_hal_cuda_ExecutableDef_ptx_image_add(builder, gpuImageRef); - if (!sourceLocationRefs.empty()) { - auto sourceLocationsRef = - builder.createOffsetVecDestructive(sourceLocationRefs); - iree_hal_cuda_ExecutableDef_source_locations_add(builder, - sourceLocationsRef); + auto modulesRef = builder.createOffsetVecDestructive(moduleRefs); + + // Generate optional per-export debug information. + // May be empty if no debug information was requested. 
+ auto exportDebugInfos = + createExportDefs(serOptions.debugLevel, exportOps, builder); + + SmallVector exportRefs; + exportRefs.resize(exportOps.size(), 0); + for (auto exportOp : exportOps) { + auto ordinalAttr = exportOp.getOrdinalAttr(); + if (!ordinalAttr) { + return mlir::emitError(exportOp.getLoc()) + << "could not compile cuda binary: export op is missing ordinal"; + } + int64_t ordinal = ordinalAttr.getInt(); + + auto kernelNameRef = builder.createString(exportOp.getName()); + + iree_hal_cuda_BlockDims_t blockDims = {0}; + if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) { + auto workgroupSize = workgroupSizeAttr->getValue(); + blockDims.x = cast(workgroupSize[0]).getInt(); + blockDims.y = cast(workgroupSize[1]).getInt(); + blockDims.z = cast(workgroupSize[2]).getInt(); + } + + uint32_t blockSharedMemorySize = 0; + if (std::optional workgroupLocalMemoryAttr = + exportOp.getWorkgroupLocalMemory()) { + blockSharedMemorySize = workgroupLocalMemoryAttr->getSExtValue(); + } + + auto layoutAttr = exportOp.getLayoutAttr(); + uint32_t constantCount = + static_cast(layoutAttr.getPushConstants()); + SmallVector bindingFlags; + for (auto bindingAttr : layoutAttr.getSetLayout(0).getBindings()) { + iree_hal_cuda_BindingBits_enum_t flags = 0; + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::ReadOnly)) { + flags |= iree_hal_cuda_BindingBits_READ_ONLY; + } + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::Indirect)) { + flags |= iree_hal_cuda_BindingBits_INDIRECT; + } + bindingFlags.push_back(flags); + } + auto bindingFlagsRef = iree_hal_cuda_BindingBits_vec_create( + builder, bindingFlags.data(), bindingFlags.size()); + + iree_hal_cuda_ExportDef_start(builder); + iree_hal_cuda_ExportDef_module_ordinal_add(builder, 0); // always 0 today + iree_hal_cuda_ExportDef_kernel_name_add(builder, kernelNameRef); + iree_hal_cuda_ExportDef_block_dims_add(builder, &blockDims); + 
iree_hal_cuda_ExportDef_block_shared_memory_size_add( + builder, blockSharedMemorySize); + iree_hal_cuda_ExportDef_constant_count_add(builder, constantCount); + iree_hal_cuda_ExportDef_binding_flags_add(builder, bindingFlagsRef); + iree_hal_cuda_ExportDef_debug_info_add(builder, + exportDebugInfos[ordinal]); + exportRefs[ordinal] = iree_hal_cuda_ExportDef_end(builder); } + auto exportsRef = builder.createOffsetVecDestructive(exportRefs); + + iree_hal_cuda_ExecutableDef_exports_add(builder, exportsRef); + iree_hal_cuda_ExecutableDef_modules_add(builder, modulesRef); iree_hal_cuda_ExecutableDef_source_files_add(builder, sourceFilesRef); iree_hal_cuda_ExecutableDef_end_as_root(builder); // Add the binary data to the target executable. - auto binaryOp = executableBuilder.create( + executableBuilder.create( variantOp.getLoc(), variantOp.getSymName(), variantOp.getTarget().getFormat(), builder.getBufferAttr(executableBuilder.getContext())); - binaryOp.setMimeTypeAttr( - executableBuilder.getStringAttr("application/x-flatbuffers")); return success(); } diff --git a/compiler/plugins/target/CUDA/test/smoketest.mlir b/compiler/plugins/target/CUDA/test/smoketest.mlir index fc7d8fcd1b55a..6e6fa946fcd92 100644 --- a/compiler/plugins/target/CUDA/test/smoketest.mlir +++ b/compiler/plugins/target/CUDA/test/smoketest.mlir @@ -1,5 +1,5 @@ // RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 %s | FileCheck %s -// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 --iree-hal-cuda-dump-ptx %s 2>&1 | FileCheck %s --check-prefix=PTX +// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 --iree-hal-dump-executable-binaries-to=- %s 2>&1 | FileCheck %s --check-prefix=PTX #map = affine_map<(d0) -> (d0)> diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp index 
a668bd5ce4e81..c576369a1c7b6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp @@ -52,7 +52,7 @@ void dumpDataToPath(StringRef path, StringRef baseName, StringRef suffix, llvm::join_items(llvm::sys::path::get_separator(), path, fileName); auto filePath = llvm::sys::path::convert_to_slash(fileParts); std::string error; - auto file = mlir::openOutputFile(filePath, &error); + auto file = mlir::openOutputFile(path == "-" ? path : filePath, &error); if (!file) { llvm::errs() << "Unable to dump debug output to " << filePath << "\n"; return; diff --git a/runtime/src/iree/hal/drivers/cuda/BUILD.bazel b/runtime/src/iree/hal/drivers/cuda/BUILD.bazel index be16732a3dbf6..0f9c2a4e4ffdb 100644 --- a/runtime/src/iree/hal/drivers/cuda/BUILD.bazel +++ b/runtime/src/iree/hal/drivers/cuda/BUILD.bazel @@ -37,8 +37,6 @@ iree_runtime_cc_library( "nccl_channel.h", "nop_executable_cache.c", "nop_executable_cache.h", - "pipeline_layout.c", - "pipeline_layout.h", "stream_command_buffer.c", "stream_command_buffer.h", "timepoint_pool.c", diff --git a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt index 20c4715ac1dda..3e4f6a700f065 100644 --- a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt @@ -38,8 +38,6 @@ iree_cc_library( "nccl_channel.h" "nop_executable_cache.c" "nop_executable_cache.h" - "pipeline_layout.c" - "pipeline_layout.h" "stream_command_buffer.c" "stream_command_buffer.h" "timepoint_pool.c" diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 37d4212094398..4038a602a63b3 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -23,7 +23,6 @@ #include "iree/hal/drivers/cuda/nccl_channel.h" #include "iree/hal/drivers/cuda/nccl_dynamic_symbols.h" #include 
"iree/hal/drivers/cuda/nop_executable_cache.h" -#include "iree/hal/drivers/cuda/pipeline_layout.h" #include "iree/hal/drivers/cuda/stream_command_buffer.h" #include "iree/hal/drivers/cuda/timepoint_pool.h" #include "iree/hal/drivers/cuda/tracing.h" @@ -756,18 +755,6 @@ static iree_status_t iree_hal_cuda_device_create_command_buffer( } } -static iree_status_t iree_hal_cuda_device_create_descriptor_set_layout( - iree_hal_device_t* base_device, - iree_hal_descriptor_set_layout_flags_t flags, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t* bindings, - iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { - iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); - return iree_hal_cuda_descriptor_set_layout_create( - flags, binding_count, bindings, device->host_allocator, - out_descriptor_set_layout); -} - static iree_status_t iree_hal_cuda_device_create_event( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, iree_hal_event_flags_t flags, iree_hal_event_t** out_event) { @@ -799,17 +786,6 @@ static iree_status_t iree_hal_cuda_device_import_file( iree_hal_device_host_allocator(base_device), out_file); } -static iree_status_t iree_hal_cuda_device_create_pipeline_layout( - iree_hal_device_t* base_device, iree_host_size_t push_constants, - iree_host_size_t set_layout_count, - iree_hal_descriptor_set_layout_t* const* set_layouts, - iree_hal_pipeline_layout_t** out_pipeline_layout) { - iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); - return iree_hal_cuda_pipeline_layout_create( - set_layout_count, set_layouts, push_constants, device->host_allocator, - out_pipeline_layout); -} - static iree_status_t iree_hal_cuda_device_create_semaphore( iree_hal_device_t* base_device, uint64_t initial_value, iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) { @@ -1023,12 +999,9 @@ static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = { .query_i64 
= iree_hal_cuda_device_query_i64, .create_channel = iree_hal_cuda_device_create_channel, .create_command_buffer = iree_hal_cuda_device_create_command_buffer, - .create_descriptor_set_layout = - iree_hal_cuda_device_create_descriptor_set_layout, .create_event = iree_hal_cuda_device_create_event, .create_executable_cache = iree_hal_cuda_device_create_executable_cache, .import_file = iree_hal_cuda_device_import_file, - .create_pipeline_layout = iree_hal_cuda_device_create_pipeline_layout, .create_semaphore = iree_hal_cuda_device_create_semaphore, .query_semaphore_compatibility = iree_hal_cuda_device_query_semaphore_compatibility, diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c index 94dc554a491d1..a20d2e848f8f0 100644 --- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c @@ -14,7 +14,6 @@ #include "iree/hal/drivers/cuda/cuda_dynamic_symbols.h" #include "iree/hal/drivers/cuda/cuda_status_util.h" #include "iree/hal/drivers/cuda/native_executable.h" -#include "iree/hal/drivers/cuda/pipeline_layout.h" #include "iree/hal/drivers/cuda/tracing.h" #include "iree/hal/utils/collective_batch.h" #include "iree/hal/utils/resource_set.h" @@ -58,12 +57,6 @@ typedef struct iree_hal_cuda_graph_command_buffer_t { // Iteratively constructed batch of collective operations. iree_hal_collective_batch_t collective_batch; - - // TODO(#18154): drop state used by legacy bindings mechanism. 
- int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; - struct { - CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT]; - } descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT]; } iree_hal_cuda_graph_command_buffer_t; static const iree_hal_command_buffer_vtable_t @@ -705,194 +698,6 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_collective( recv_binding, element_count); } -static iree_status_t iree_hal_cuda_graph_command_buffer_push_constants( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, - const void* values, iree_host_size_t values_length) { - iree_hal_cuda_graph_command_buffer_t* command_buffer = - iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); - iree_host_size_t constant_base_index = offset / sizeof(int32_t); - for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { - command_buffer->push_constants[i + constant_base_index] = - ((uint32_t*)values)[i]; - } - return iree_ok_status(); -} - -static iree_status_t iree_hal_cuda_graph_command_buffer_push_descriptor_set( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, - iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings) { - if (binding_count > IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) { - return iree_make_status( - IREE_STATUS_RESOURCE_EXHAUSTED, - "exceeded available binding slots for push " - "descriptor set #%" PRIu32 "; requested %" PRIhsz " vs. 
maximal %d", - set, binding_count, IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT); - } - - iree_hal_cuda_graph_command_buffer_t* command_buffer = - iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); - IREE_TRACE_ZONE_BEGIN(z0); - - CUdeviceptr* current_bindings = command_buffer->descriptor_sets[set].bindings; - for (iree_host_size_t i = 0; i < binding_count; i++) { - const iree_hal_buffer_ref_t* binding = &bindings[i]; - CUdeviceptr device_ptr = 0; - if (binding->buffer) { - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, - &binding->buffer)); - - CUdeviceptr device_buffer = iree_hal_cuda_buffer_device_pointer( - iree_hal_buffer_allocated_buffer(binding->buffer)); - iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer); - device_ptr = device_buffer + offset + binding->offset; - } - current_bindings[binding->ordinal] = device_ptr; - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, - iree_hal_dispatch_flags_t flags) { - iree_hal_cuda_graph_command_buffer_t* command_buffer = - iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); - IREE_TRACE_ZONE_BEGIN(z0); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer)); - - // Lookup kernel parameters used for side-channeling additional launch - // information from the compiler. 
- iree_hal_cuda_kernel_info_t kernel_info; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda_native_executable_entry_point_kernel_info( - executable, entry_point, &kernel_info)); - - IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( - command_buffer, IREE_HAL_CUDA_TRACING_VERBOSITY_FINE, - kernel_info.source_filename.data, kernel_info.source_filename.size, - kernel_info.source_line, kernel_info.function_name.data, - kernel_info.function_name.size, - /*name=*/NULL, 0); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, - &executable)); - - // The total number of descriptors across all descriptor sets. - iree_host_size_t descriptor_count = - iree_hal_cuda_pipeline_layout_total_binding_count(kernel_info.layout); - // The total number of push constants. - iree_host_size_t push_constant_count = - iree_hal_cuda_pipeline_layout_push_constant_count(kernel_info.layout); - // We append push constants to the end of descriptors to form a linear chain - // of kernel arguments. - iree_host_size_t kernel_params_count = descriptor_count + push_constant_count; - iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); - - // Per CUDA API requirements, we need two levels of indirection for passing - // kernel arguments in. - // "If the kernel has N parameters, then kernelParams needs to be an array - // of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1], - // points to the region of memory from which the actual parameter will be - // copied." - // - // (From the cuGraphAddKernelNode API doc in - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b) - // - // It means each kernel_params[i] is itself a pointer to the corresponding - // element at the *second* inline allocation at the end of the current - // segment. 
- iree_host_size_t total_size = kernel_params_length * 2; - uint8_t* storage_base = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_arena_allocate(&command_buffer->arena, total_size, - (void**)&storage_base)); - void** params_ptr = (void**)storage_base; - - // Set up kernel arguments to point to the payload slots. - CUdeviceptr* payload_ptr = - (CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length); - for (size_t i = 0; i < kernel_params_count; i++) { - params_ptr[i] = &payload_ptr[i]; - } - - // Copy descriptors from all sets to the end of the current segment for later - // access. - iree_host_size_t set_count = - iree_hal_cuda_pipeline_layout_descriptor_set_count(kernel_info.layout); - for (iree_host_size_t i = 0; i < set_count; ++i) { - // TODO: cache this information in the kernel info to avoid recomputation. - iree_host_size_t binding_count = - iree_hal_cuda_descriptor_set_layout_binding_count( - iree_hal_cuda_pipeline_layout_descriptor_set_layout( - kernel_info.layout, i)); - iree_host_size_t index = - iree_hal_cuda_pipeline_layout_base_binding_index(kernel_info.layout, i); - memcpy(payload_ptr + index, command_buffer->descriptor_sets[i].bindings, - binding_count * sizeof(CUdeviceptr)); - } - - // Append the push constants to the kernel arguments. - iree_host_size_t base_index = - iree_hal_cuda_pipeline_layout_push_constant_index(kernel_info.layout); - // As commented in the above, what each kernel parameter points to is a - // CUdeviceptr, which as the size of a pointer on the target machine. we are - // just storing a 32-bit value for the push constant here instead. So we must - // process one element each type, for 64-bit machines. 
- for (iree_host_size_t i = 0; i < push_constant_count; i++) { - *((uint32_t*)params_ptr[base_index + i]) = - command_buffer->push_constants[i]; - } - - CUDA_KERNEL_NODE_PARAMS params = { - .func = kernel_info.function, - .blockDimX = kernel_info.block_size[0], - .blockDimY = kernel_info.block_size[1], - .blockDimZ = kernel_info.block_size[2], - .gridDimX = workgroup_x, - .gridDimY = workgroup_y, - .gridDimZ = workgroup_z, - .kernelParams = params_ptr, - .sharedMemBytes = kernel_info.shared_memory_size, - }; - - if (command_buffer->graph_node_count >= - IREE_HAL_CUDA_MAX_CONCURRENT_GRAPH_NODE_COUNT) { - return iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "exceeded max concurrent node limit"); - } - - size_t dependency_count = command_buffer->cu_barrier_node ? 1 : 0; - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, command_buffer->symbols, - cuGraphAddKernelNode( - &command_buffer->cu_graph_nodes[command_buffer->graph_node_count++], - command_buffer->cu_graph, &command_buffer->cu_barrier_node, - dependency_count, ¶ms), - "cuGraphAddKernelNode"); - - IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END( - command_buffer, IREE_HAL_CUDA_TRACING_VERBOSITY_FINE); - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch_indirect( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "indirect dispatch not yet implemented"); -} - static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, @@ -907,16 +712,18 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2( // Lookup kernel parameters used for side-channeling additional launch // information from the compiler. 
- iree_hal_cuda_kernel_info_t kernel_info; + const iree_hal_cuda_kernel_params_t* kernel_params = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda_native_executable_entry_point_kernel_info( - executable, entry_point, &kernel_info)); + z0, iree_hal_cuda_native_executable_lookup_kernel_params( + executable, entry_point, &kernel_params)); IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer, IREE_HAL_CUDA_TRACING_VERBOSITY_FINE, - kernel_info.source_filename.data, kernel_info.source_filename.size, - kernel_info.source_line, kernel_info.function_name.data, - kernel_info.function_name.size, /*name=*/NULL, 0); + kernel_params->debug_info.source_filename.data, + kernel_params->debug_info.source_filename.size, + kernel_params->debug_info.source_line, + kernel_params->debug_info.name.data, kernel_params->debug_info.name.size, + /*name=*/NULL, 0); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, @@ -924,7 +731,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2( // We append push constants to the end of descriptors to form a linear chain // of kernel arguments. iree_host_size_t kernel_params_count = - kernel_info.binding_count + kernel_info.constant_count; + kernel_params->binding_count + kernel_params->constant_count; iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); // TODO: use packed parameters instead of the indirection mechanism - this @@ -973,21 +780,21 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2( // CUdeviceptr, which as the size of a pointer on the target machine. we are // just storing a 32-bit value for the push constant here instead. So we must // process one element each type, for 64-bit machines. 
- for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) { - *((uint32_t*)params_ptr[kernel_info.binding_count + i]) = + for (iree_host_size_t i = 0; i < kernel_params->constant_count; i++) { + *((uint32_t*)params_ptr[kernel_params->binding_count + i]) = ((const uint32_t*)constants.data)[i]; } CUDA_KERNEL_NODE_PARAMS params = { - .func = kernel_info.function, - .blockDimX = kernel_info.block_size[0], - .blockDimY = kernel_info.block_size[1], - .blockDimZ = kernel_info.block_size[2], + .func = kernel_params->function, + .blockDimX = kernel_params->block_dims[0], + .blockDimY = kernel_params->block_dims[1], + .blockDimZ = kernel_params->block_dims[2], .gridDimX = workgroup_count[0], .gridDimY = workgroup_count[1], .gridDimZ = workgroup_count[2], .kernelParams = params_ptr, - .sharedMemBytes = kernel_info.shared_memory_size, + .sharedMemBytes = kernel_params->block_shared_memory_size, }; if (command_buffer->graph_node_count >= @@ -1038,12 +845,6 @@ static const iree_hal_command_buffer_vtable_t .update_buffer = iree_hal_cuda_graph_command_buffer_update_buffer, .copy_buffer = iree_hal_cuda_graph_command_buffer_copy_buffer, .collective = iree_hal_cuda_graph_command_buffer_collective, - .push_constants = iree_hal_cuda_graph_command_buffer_push_constants, - .push_descriptor_set = - iree_hal_cuda_graph_command_buffer_push_descriptor_set, - .dispatch = iree_hal_cuda_graph_command_buffer_dispatch, - .dispatch_indirect = - iree_hal_cuda_graph_command_buffer_dispatch_indirect, .dispatch2 = iree_hal_cuda_graph_command_buffer_dispatch2, .dispatch2_indirect = iree_hal_cuda_graph_command_buffer_dispatch2_indirect, diff --git a/runtime/src/iree/hal/drivers/cuda/native_executable.c b/runtime/src/iree/hal/drivers/cuda/native_executable.c index c3b32e01ecd23..2c3c703374201 100644 --- a/runtime/src/iree/hal/drivers/cuda/native_executable.c +++ b/runtime/src/iree/hal/drivers/cuda/native_executable.c @@ -11,7 +11,6 @@ #include "iree/base/api.h" #include 
"iree/hal/drivers/cuda/cuda_dynamic_symbols.h" #include "iree/hal/drivers/cuda/cuda_status_util.h" -#include "iree/hal/drivers/cuda/pipeline_layout.h" #include "iree/hal/utils/executable_debug_info.h" // flatcc schemas: @@ -25,20 +24,18 @@ typedef struct iree_hal_cuda_native_executable_t { // Abstract resource used for injecting reference counting and vtable; // must be at offset 0. iree_hal_resource_t resource; - iree_allocator_t host_allocator; const iree_hal_cuda_dynamic_symbols_t* symbols; - // The loaded CUDA module. - CUmodule cu_module; + // Loaded CUDA modules. + iree_host_size_t module_count; + CUmodule* modules; - iree_host_size_t entry_point_count; - // The list of entry point data pointers, pointing to trailing inline - // allocation after the end of this struct. - iree_hal_cuda_kernel_info_t entry_points[]; + // Exported kernels referencing the loaded modules. + iree_host_size_t export_count; + iree_hal_cuda_kernel_params_t exports[]; } iree_hal_cuda_native_executable_t; -// + Additional inline allocation for holding entry point information. 
static const iree_hal_executable_vtable_t iree_hal_cuda_native_executable_vtable; @@ -49,6 +46,41 @@ static iree_hal_cuda_native_executable_t* iree_hal_cuda_native_executable_cast( return (iree_hal_cuda_native_executable_t*)base_value; } +typedef struct iree_hal_cuda_limits_t { + uint32_t max_block_dims[3]; + uint32_t max_block_shared_memory_size; +} iree_hal_cuda_limits_t; +static iree_status_t iree_hal_cuda_query_limits( + const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device, + iree_hal_cuda_limits_t* out_limits) { + memset(out_limits, 0, sizeof(*out_limits)); + + IREE_CUDA_RETURN_IF_ERROR( + symbols, + cuDeviceGetAttribute(&out_limits->max_block_dims[0], + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device), + "cuDeviceGetAttribute"); + IREE_CUDA_RETURN_IF_ERROR( + symbols, + cuDeviceGetAttribute(&out_limits->max_block_dims[1], + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device), + "cuDeviceGetAttribute"); + IREE_CUDA_RETURN_IF_ERROR( + symbols, + cuDeviceGetAttribute(&out_limits->max_block_dims[2], + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device), + "cuDeviceGetAttribute"); + + IREE_CUDA_RETURN_IF_ERROR( + symbols, + cuDeviceGetAttribute( + &out_limits->max_block_shared_memory_size, + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device), + "cuDeviceGetAttribute"); + + return iree_ok_status(); +} + // Verifies the structure of the flatbuffer so that we can avoid doing so during // runtime. // @@ -56,7 +88,8 @@ static iree_hal_cuda_native_executable_t* iree_hal_cuda_native_executable_cast( // functions with internal linkage), however we shouldn't need to bounds check // anything within the flatbuffer after this succeeds. 
static iree_status_t iree_hal_cuda_native_executable_flatbuffer_verify( - iree_const_byte_span_t flatbuffer_data) { + iree_const_byte_span_t flatbuffer_data, + const iree_hal_cuda_limits_t* limits) { if (!flatbuffer_data.data) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "flatbuffer data is not present"); @@ -76,37 +109,99 @@ static iree_status_t iree_hal_cuda_native_executable_flatbuffer_verify( iree_hal_cuda_ExecutableDef_table_t executable_def = iree_hal_cuda_ExecutableDef_as_root(flatbuffer_data.data); - flatbuffers_string_vec_t entry_points_vec = - iree_hal_cuda_ExecutableDef_entry_points_get(executable_def); - size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec); - for (size_t i = 0; i < entry_point_count; ++i) { + iree_hal_cuda_ModuleDef_vec_t modules_vec = + iree_hal_cuda_ExecutableDef_modules_get(executable_def); + iree_host_size_t module_count = iree_hal_cuda_ModuleDef_vec_len(modules_vec); + for (iree_host_size_t i = 0; i < module_count; ++i) { + iree_hal_cuda_ModuleDef_table_t module_def = + iree_hal_cuda_ModuleDef_vec_at(modules_vec, i); + if (!module_def) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "modules[%" PRIhsz "] is NULL", i); + } if (flatbuffers_string_len( - flatbuffers_string_vec_at(entry_points_vec, i)) == 0) { + iree_hal_cuda_ModuleDef_ptx_image_get(module_def)) == 0) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "executable entry point %zu has no name", i); + "modules[%" PRIhsz "] contents are empty", i); } } - iree_hal_cuda_BlockSize_vec_t block_sizes_vec = - iree_hal_cuda_ExecutableDef_block_sizes_get(executable_def); - size_t block_size_count = iree_hal_cuda_BlockSize_vec_len(block_sizes_vec); - if (block_size_count == 0) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "no block sizes present"); - } + iree_hal_cuda_ExportDef_vec_t exports_vec = + iree_hal_cuda_ExecutableDef_exports_get(executable_def); + for (iree_host_size_t i = 0; i < 
iree_hal_cuda_ExportDef_vec_len(exports_vec); + ++i) { + iree_hal_cuda_ExportDef_table_t export_def = + iree_hal_cuda_ExportDef_vec_at(exports_vec, i); + if (!export_def) continue; + + uint32_t module_ordinal = + iree_hal_cuda_ExportDef_module_ordinal_get(export_def); + if (module_ordinal >= module_count) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "exports[%" PRIhsz + "] module_ordinal %u is out of bounds %" PRIhsz, + i, module_ordinal, module_count); + } - if (entry_point_count != block_size_count) { - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "entry points (%zu) and block sizes (%zu) count mismatch", - entry_point_count, block_size_count); - } + if (flatbuffers_string_len( + iree_hal_cuda_ExportDef_kernel_name_get(export_def)) == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "exports[%" PRIhsz "] name is empty", i); + } - flatbuffers_string_t ptx_image = - iree_hal_cuda_ExecutableDef_ptx_image_get(executable_def); - if (flatbuffers_string_len(ptx_image) == 0) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "no PTX image present"); + if (iree_hal_cuda_ExportDef_block_dims_is_present(export_def)) { + const iree_hal_cuda_BlockDims_t* block_dims = + iree_hal_cuda_ExportDef_block_dims_get(export_def); + if (block_dims->x > limits->max_block_dims[0] || + block_dims->y > limits->max_block_dims[1] || + block_dims->z > limits->max_block_dims[2]) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "exports[%" PRIhsz + "] block dims %ux%ux%u exceeds device maximum %ux%ux%u", + i, block_dims->x, block_dims->y, block_dims->z, + limits->max_block_dims[0], limits->max_block_dims[1], + limits->max_block_dims[2]); + } + } else { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "exports[%" PRIhsz "] block dims are missing", + i); + } + + uint32_t block_shared_memory_size = + iree_hal_cuda_ExportDef_block_shared_memory_size_get(export_def); + if (block_shared_memory_size > 
limits->max_block_shared_memory_size) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "exports[%" PRIhsz + "] requires %uB of shared memory and " + "exceeds the device maximum of %uB per block", + i, block_shared_memory_size, + limits->max_block_shared_memory_size); + } + + uint32_t constant_count = + iree_hal_cuda_ExportDef_constant_count_get(export_def); + if (constant_count > IREE_HAL_CUDA_MAX_DISPATCH_CONSTANT_COUNT) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "exports[%" PRIhsz "] constant_count %u exceeds maximum of %u", i, + constant_count, IREE_HAL_CUDA_MAX_DISPATCH_CONSTANT_COUNT); + } + + iree_hal_cuda_BindingBits_vec_t binding_flags_vec = + iree_hal_cuda_ExportDef_binding_flags_get(export_def); + if (iree_hal_cuda_BindingBits_vec_len(binding_flags_vec) > + IREE_HAL_CUDA_MAX_DISPATCH_BINDING_COUNT) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "exports[%" PRIhsz "] binding_flags count %zu exceeds maximum of %u", + i, iree_hal_cuda_BindingBits_vec_len(binding_flags_vec), + IREE_HAL_CUDA_MAX_DISPATCH_BINDING_COUNT); + } + + IREE_RETURN_IF_ERROR(iree_hal_debug_verify_export_def( + iree_hal_cuda_ExportDef_debug_info_get(export_def))); } return iree_ok_status(); @@ -123,167 +218,154 @@ iree_status_t iree_hal_cuda_native_executable_create( *out_executable = NULL; iree_hal_cuda_native_executable_t* executable = NULL; + // TODO: move to the executable cache to avoid repeated queries. 
+ iree_hal_cuda_limits_t limits = {0}; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda_query_limits(symbols, device, &limits)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_cuda_native_executable_flatbuffer_verify( - executable_params->executable_data)); + executable_params->executable_data, &limits)); iree_hal_cuda_ExecutableDef_table_t executable_def = iree_hal_cuda_ExecutableDef_as_root( executable_params->executable_data.data); - flatbuffers_string_t ptx_image = - iree_hal_cuda_ExecutableDef_ptx_image_get(executable_def); - flatbuffers_uint32_vec_t shared_memory_sizes = - iree_hal_cuda_ExecutableDef_shared_memory_size_get(executable_def); - flatbuffers_string_vec_t entry_points_vec = - iree_hal_cuda_ExecutableDef_entry_points_get(executable_def); - iree_hal_cuda_BlockSize_vec_t block_sizes_vec = - iree_hal_cuda_ExecutableDef_block_sizes_get(executable_def); - iree_host_size_t entry_point_count = - flatbuffers_string_vec_len(entry_points_vec); + iree_hal_cuda_ModuleDef_vec_t modules_vec = + iree_hal_cuda_ExecutableDef_modules_get(executable_def); + iree_host_size_t module_count = iree_hal_cuda_ModuleDef_vec_len(modules_vec); + iree_hal_cuda_ExportDef_vec_t exports_vec = + iree_hal_cuda_ExecutableDef_exports_get(executable_def); + iree_host_size_t export_count = iree_hal_cuda_ExportDef_vec_len(exports_vec); // Calculate the total number of characters across all entry point names. This // is only required when tracing so that we can store copies of the names as // the flatbuffer storing the strings may be released while the executable is // still live. 
- iree_host_size_t total_entry_point_name_chars = 0; + iree_host_size_t total_export_info_length = 0; IREE_TRACE({ - for (iree_host_size_t i = 0; i < entry_point_count; i++) { - const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i); - total_entry_point_name_chars += flatbuffers_string_len(entry_name); + for (iree_host_size_t i = 0; i < export_count; ++i) { + iree_hal_cuda_ExportDef_table_t export_def = + iree_hal_cuda_ExportDef_vec_at(exports_vec, i); + total_export_info_length += iree_hal_debug_calculate_export_info_size( + iree_hal_cuda_ExportDef_debug_info_get(export_def)); } }); // Allocate storage for the kernel module. - iree_host_size_t total_size = - sizeof(*executable) + - entry_point_count * sizeof(executable->entry_points[0]) + - total_entry_point_name_chars; + const iree_host_size_t total_size = + sizeof(*executable) + module_count * sizeof(executable->modules[0]) + + export_count * sizeof(executable->exports[0]) + total_export_info_length; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, total_size, (void**)&executable)); - IREE_TRACE( - char* string_table_buffer = - (char*)((char*)executable + sizeof(*executable) + - entry_point_count * sizeof(executable->entry_points[0]))); iree_hal_resource_initialize(&iree_hal_cuda_native_executable_vtable, &executable->resource); - - // Load the PTX image - this will fail if the device cannot handle the - // contents. We could check this prior to creating - CUmodule module = NULL; - - iree_status_t status = IREE_CURESULT_TO_STATUS( - symbols, cuModuleLoadDataEx(&module, ptx_image, 0, NULL, NULL), - "cuModuleLoadDataEx"); - - // Query max optin shared memory per block - we'll use it to compare with - // kernel usages. 
- int32_t max_shared_memory = 0; - if (iree_status_is_ok(status)) { + executable->host_allocator = host_allocator; + executable->symbols = symbols; + executable->module_count = module_count; + executable->modules = + (CUmodule*)((uint8_t*)executable + sizeof(*executable) + + export_count * sizeof(executable->exports[0])); + executable->export_count = export_count; + IREE_TRACE( + iree_hal_debug_export_info_t* export_infos = + (iree_hal_debug_export_info_t*)((uint8_t*)executable->modules + + module_count * + sizeof(executable->modules[0]))); + + // Publish any embedded source files to the tracing infrastructure. + iree_hal_debug_publish_source_files( + iree_hal_cuda_ExecutableDef_source_files_get(executable_def)); + + // Load each module first so that exports can reference them. + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < module_count; ++i) { + iree_hal_cuda_ModuleDef_table_t module_def = + iree_hal_cuda_ModuleDef_vec_at(modules_vec, i); + + // WARNING: CUDA doesn't take an expected length here so we can't bound it. + // It's likely that users could craft inputs that read beyond the extents of + // the embedded binary. + flatbuffers_string_t ptx_image = + iree_hal_cuda_ModuleDef_ptx_image_get(module_def); + + // TODO: pass cuJitOption values to get log info and other info back. + CUmodule module = NULL; status = IREE_CURESULT_TO_STATUS( - symbols, - cuDeviceGetAttribute( - &max_shared_memory, - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device), - "cuDeviceGetAttribute"); + symbols, cuModuleLoadDataEx(&module, ptx_image, 0, NULL, NULL), + "cuModuleLoadDataEx"); + if (!iree_status_is_ok(status)) { + status = iree_status_annotate( + status, + IREE_SV("mismatched target chip? 
missing/wrong bitcode directory?")); + break; + } + + executable->modules[i] = module; } if (iree_status_is_ok(status)) { - executable->host_allocator = host_allocator; - executable->symbols = symbols; - executable->cu_module = module; - executable->entry_point_count = entry_point_count; - - // Publish any embedded source files to the tracing infrastructure. - if (iree_status_is_ok(status)) { - iree_hal_debug_publish_source_files( - iree_hal_cuda_ExecutableDef_source_files_get(executable_def)); - } - - for (iree_host_size_t i = 0; i < entry_point_count; i++) { - // Lookup the function in the module; this should always succeed but we - // cannot trust that the input was generated by our compiler. + for (iree_host_size_t i = 0; i < export_count; ++i) { + iree_hal_cuda_ExportDef_table_t export_def = + iree_hal_cuda_ExportDef_vec_at(exports_vec, i); + + // Lookup the function in the module; this should always succeed but + // we cannot trust that the input was generated by our compiler. + uint32_t module_ordinal = + iree_hal_cuda_ExportDef_module_ordinal_get(export_def); + CUmodule module = executable->modules[module_ordinal]; + flatbuffers_string_t kernel_name = + iree_hal_cuda_ExportDef_kernel_name_get(export_def); CUfunction function = NULL; - const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i); status = IREE_CURESULT_TO_STATUS( - symbols, - cuModuleGetFunction(&function, executable->cu_module, entry_name), + symbols, cuModuleGetFunction(&function, module, kernel_name), "cuModuleGetFunction"); if (!iree_status_is_ok(status)) break; if (!function) { status = iree_make_status(IREE_STATUS_NOT_FOUND, - "exported module function '%s' not found", - entry_name); + "exports[%" PRIhsz + "] kernel `%s` not found in modules[%u]", + i, kernel_name, module_ordinal); break; } - if (shared_memory_sizes[i] > max_shared_memory) { - status = iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "requested shared memory size of %d bytes " - "larger than allowed size of %d 
bytes", - shared_memory_sizes[i], max_shared_memory); - } else { - status = IREE_CURESULT_TO_STATUS( - symbols, - cuFuncSetAttribute(function, - CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_memory_sizes[i]), - "cuFuncSetAttribute"); - } + uint32_t block_shared_memory_size = + iree_hal_cuda_ExportDef_block_shared_memory_size_get(export_def); + status = IREE_CURESULT_TO_STATUS( + symbols, + cuFuncSetAttribute(function, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + block_shared_memory_size), + "cuFuncSetAttribute"); if (!iree_status_is_ok(status)) break; - // TODO(#18154): embed all of this on a single flatbuffer table - // per-export. - // // Package required parameters for kernel launches for each entry point. - iree_hal_cuda_kernel_info_t* info = &executable->entry_points[i]; - info->layout = executable_params->pipeline_layouts[i]; - iree_hal_pipeline_layout_retain(info->layout); - info->function = function; - info->constant_count = - iree_hal_cuda_pipeline_layout_push_constant_count(info->layout); - info->binding_count = - iree_hal_cuda_pipeline_layout_total_binding_count(info->layout); - info->block_size[0] = block_sizes_vec[i].x; - info->block_size[1] = block_sizes_vec[i].y; - info->block_size[2] = block_sizes_vec[i].z; - info->shared_memory_size = shared_memory_sizes[i]; - - if (info->binding_count > - IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) { - status = iree_make_status( - IREE_STATUS_RESOURCE_EXHAUSTED, - "exceeded available binding slots; requested %u of maximum %d", - info->binding_count, - IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT); - } - if (!iree_status_is_ok(status)) break; + iree_hal_cuda_kernel_params_t* kernel_info = &executable->exports[i]; + kernel_info->function = function; + const iree_hal_cuda_BlockDims_t* block_dims = + iree_hal_cuda_ExportDef_block_dims_get(export_def); + kernel_info->block_dims[0] = block_dims->x; + kernel_info->block_dims[1] = block_dims->y; + kernel_info->block_dims[2] = block_dims->z; + 
kernel_info->block_shared_memory_size = + iree_hal_cuda_ExportDef_block_shared_memory_size_get(export_def); + kernel_info->constant_count = + iree_hal_cuda_ExportDef_constant_count_get(export_def); + iree_hal_cuda_BindingBits_vec_t binding_flags_vec = + iree_hal_cuda_ExportDef_binding_flags_get(export_def); + kernel_info->binding_count = + iree_hal_cuda_BindingBits_vec_len(binding_flags_vec); - // Stash the entry point name in the string table for use when tracing. IREE_TRACE({ - iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name); - memcpy(string_table_buffer, entry_name, entry_name_length); - info->function_name = - iree_make_string_view(string_table_buffer, entry_name_length); - string_table_buffer += entry_name_length; - }); - - IREE_TRACE({ - if (iree_hal_cuda_ExecutableDef_source_locations_is_present( - executable_def)) { - iree_hal_debug_FileLineLocDef_vec_t source_locs_vec = - iree_hal_cuda_ExecutableDef_source_locations_get(executable_def); - iree_hal_debug_FileLineLocDef_table_t source_loc = - iree_hal_debug_FileLineLocDef_vec_at(source_locs_vec, i); - flatbuffers_string_t filename = - iree_hal_debug_FileLineLocDef_filename_get(source_loc); - uint32_t line = iree_hal_debug_FileLineLocDef_line_get(source_loc); - info->source_filename = - iree_make_string_view(filename, flatbuffers_string_len(filename)); - info->source_line = line; - } + iree_hal_debug_copy_export_info( + iree_hal_cuda_ExportDef_debug_info_get(export_def), + &export_infos[i]); + kernel_info->debug_info.name = export_infos[i].name; + kernel_info->debug_info.source_filename = + export_infos[i].source_filename; + kernel_info->debug_info.source_line = export_infos[i].source_line; }); } } @@ -305,30 +387,31 @@ static void iree_hal_cuda_native_executable_destroy( iree_allocator_t host_allocator = executable->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); - for (iree_host_size_t i = 0; i < executable->entry_point_count; ++i) { - 
iree_hal_pipeline_layout_release(executable->entry_points[i].layout); - } - if (executable->cu_module) { - IREE_CUDA_IGNORE_ERROR(executable->symbols, - cuModuleUnload(executable->cu_module)); + for (iree_host_size_t i = 0; i < executable->module_count; ++i) { + if (executable->modules[i]) { + IREE_CUDA_IGNORE_ERROR(executable->symbols, + cuModuleUnload(executable->modules[i])); + } } + iree_allocator_free(host_allocator, executable); IREE_TRACE_ZONE_END(z0); } -iree_status_t iree_hal_cuda_native_executable_entry_point_kernel_info( - iree_hal_executable_t* base_executable, int32_t entry_point, - iree_hal_cuda_kernel_info_t* out_info) { +iree_status_t iree_hal_cuda_native_executable_lookup_kernel_params( + iree_hal_executable_t* base_executable, int32_t ordinal, + const iree_hal_cuda_kernel_params_t** out_params) { iree_hal_cuda_native_executable_t* executable = iree_hal_cuda_native_executable_cast(base_executable); - if (entry_point >= executable->entry_point_count) { - return iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "entry point ordinal %d out of range; executable " - "only contains %" PRIhsz " entry points", - entry_point, executable->entry_point_count); + if (ordinal >= executable->export_count) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "export ordinal %d out of range; executable contains %" PRIhsz + " exports", + ordinal, executable->export_count); } - memcpy(out_info, &executable->entry_points[entry_point], sizeof(*out_info)); + *out_params = &executable->exports[ordinal]; return iree_ok_status(); } diff --git a/runtime/src/iree/hal/drivers/cuda/native_executable.h b/runtime/src/iree/hal/drivers/cuda/native_executable.h index 3f6faf5ec8ef8..d0dad3da525ca 100644 --- a/runtime/src/iree/hal/drivers/cuda/native_executable.h +++ b/runtime/src/iree/hal/drivers/cuda/native_executable.h @@ -19,20 +19,31 @@ extern "C" { #endif // __cplusplus -typedef struct iree_hal_cuda_kernel_info_t { - // TODO(#18154): remove when using simplified bindings. 
- iree_hal_pipeline_layout_t* layout; +// The max number of per-dispatch bindings allowed in the CUDA HAL +// implementation. +#define IREE_HAL_CUDA_MAX_DISPATCH_BINDING_COUNT 16 + +// The max number of per-dispatch constants supported by the CUDA HAL +// implementation. +#define IREE_HAL_CUDA_MAX_DISPATCH_CONSTANT_COUNT 64 + +typedef struct iree_hal_cuda_kernel_debug_info_t { + iree_string_view_t name; + iree_string_view_t source_filename; + uint32_t source_line; +} iree_hal_cuda_kernel_debug_info_t; + +typedef struct iree_hal_cuda_kernel_params_t { CUfunction function; + uint32_t constant_count; uint32_t binding_count; - // TODO(#18154): add bitfield indicating indirect bindings. - uint32_t block_size[3]; - uint32_t shared_memory_size; - IREE_TRACE(iree_string_view_t function_name;) - IREE_TRACE(iree_string_view_t source_filename;) - IREE_TRACE(uint32_t source_line;) -} iree_hal_cuda_kernel_info_t; + uint32_t block_dims[3]; + uint32_t block_shared_memory_size; + + IREE_TRACE(iree_hal_cuda_kernel_debug_info_t debug_info;) +} iree_hal_cuda_kernel_params_t; // Creates an IREE executable from a CUDA PTX module. The module may contain // several kernels that can be extracted along with the associated block size. @@ -43,9 +54,9 @@ iree_status_t iree_hal_cuda_native_executable_create( // Returns the kernel launch information for the given |entry_point| in the // |executable|. 
-iree_status_t iree_hal_cuda_native_executable_entry_point_kernel_info( +iree_status_t iree_hal_cuda_native_executable_lookup_kernel_params( iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_cuda_kernel_info_t* out_info); + const iree_hal_cuda_kernel_params_t** out_params); #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree/hal/drivers/cuda/pipeline_layout.c b/runtime/src/iree/hal/drivers/cuda/pipeline_layout.c deleted file mode 100644 index a14d312d1ce8a..0000000000000 --- a/runtime/src/iree/hal/drivers/cuda/pipeline_layout.c +++ /dev/null @@ -1,260 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/hal/drivers/cuda/pipeline_layout.h" - -#include - -#include "iree/base/api.h" -#include "iree/base/tracing.h" - -//===----------------------------------------------------------------------===// -// iree_hal_cuda_descriptor_set_layout_t -//===----------------------------------------------------------------------===// - -typedef struct iree_hal_cuda_descriptor_set_layout_t { - // Abstract resource used for injecting reference counting and vtable; - // must be at offset 0. - iree_hal_resource_t resource; - - // The host allocator used for creating this descriptor set layout struct. - iree_allocator_t host_allocator; - - // The total number of bindings in this descriptor set. 
- iree_host_size_t binding_count; -} iree_hal_cuda_descriptor_set_layout_t; - -static const iree_hal_descriptor_set_layout_vtable_t - iree_hal_cuda_descriptor_set_layout_vtable; - -static iree_hal_cuda_descriptor_set_layout_t* -iree_hal_cuda_descriptor_set_layout_cast( - iree_hal_descriptor_set_layout_t* base_value) { - IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_descriptor_set_layout_vtable); - return (iree_hal_cuda_descriptor_set_layout_t*)base_value; -} - -static const iree_hal_cuda_descriptor_set_layout_t* -iree_hal_cuda_descriptor_set_layout_const_cast( - const iree_hal_descriptor_set_layout_t* base_value) { - IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_descriptor_set_layout_vtable); - return (const iree_hal_cuda_descriptor_set_layout_t*)base_value; -} - -iree_status_t iree_hal_cuda_descriptor_set_layout_create( - iree_hal_descriptor_set_layout_flags_t flags, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t* bindings, - iree_allocator_t host_allocator, - iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { - IREE_ASSERT_ARGUMENT(!binding_count || bindings); - IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); - IREE_TRACE_ZONE_BEGIN(z0); - - *out_descriptor_set_layout = NULL; - - iree_hal_cuda_descriptor_set_layout_t* descriptor_set_layout = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(host_allocator, sizeof(*descriptor_set_layout), - (void**)&descriptor_set_layout)); - - iree_hal_resource_initialize(&iree_hal_cuda_descriptor_set_layout_vtable, - &descriptor_set_layout->resource); - descriptor_set_layout->host_allocator = host_allocator; - descriptor_set_layout->binding_count = binding_count; - *out_descriptor_set_layout = - (iree_hal_descriptor_set_layout_t*)descriptor_set_layout; - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -iree_host_size_t iree_hal_cuda_descriptor_set_layout_binding_count( - const iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { - 
const iree_hal_cuda_descriptor_set_layout_t* descriptor_set_layout = - iree_hal_cuda_descriptor_set_layout_const_cast( - base_descriptor_set_layout); - return descriptor_set_layout->binding_count; -} - -static void iree_hal_cuda_descriptor_set_layout_destroy( - iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { - iree_hal_cuda_descriptor_set_layout_t* descriptor_set_layout = - iree_hal_cuda_descriptor_set_layout_cast(base_descriptor_set_layout); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_allocator_t host_allocator = descriptor_set_layout->host_allocator; - - iree_allocator_free(host_allocator, descriptor_set_layout); - - IREE_TRACE_ZONE_END(z0); -} - -static const iree_hal_descriptor_set_layout_vtable_t - iree_hal_cuda_descriptor_set_layout_vtable = { - .destroy = iree_hal_cuda_descriptor_set_layout_destroy, -}; - -//===----------------------------------------------------------------------===// -// iree_hal_cuda_pipeline_layout_t -//===----------------------------------------------------------------------===// - -typedef struct iree_hal_cuda_pipeline_layout_t { - // Abstract resource used for injecting reference counting and vtable; - // must be at offset 0. - iree_hal_resource_t resource; - - // The host allocator used for creating this pipeline layout struct. - iree_allocator_t host_allocator; - - // The kernel argument index for push constants. - // Note that push constants are placed after all normal descriptors. - iree_host_size_t push_constant_base_index; - iree_host_size_t push_constant_count; - - iree_host_size_t set_layout_count; - // The list of descriptor set layout pointers, pointing to trailing inline - // allocation after the end of this struct. - struct { - iree_hal_descriptor_set_layout_t* set_layout; - // Base kernel argument index for this descriptor set. - iree_host_size_t base_index; - } set_layouts[]; -} iree_hal_cuda_pipeline_layout_t; -// + Additional inline allocation for holding all descriptor sets. 
- -static const iree_hal_pipeline_layout_vtable_t - iree_hal_cuda_pipeline_layout_vtable; - -static iree_hal_cuda_pipeline_layout_t* iree_hal_cuda_pipeline_layout_cast( - iree_hal_pipeline_layout_t* base_value) { - IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_pipeline_layout_vtable); - return (iree_hal_cuda_pipeline_layout_t*)base_value; -} - -static const iree_hal_cuda_pipeline_layout_t* -iree_hal_cuda_pipeline_layout_const_cast( - const iree_hal_pipeline_layout_t* base_value) { - IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_pipeline_layout_vtable); - return (const iree_hal_cuda_pipeline_layout_t*)base_value; -} - -iree_status_t iree_hal_cuda_pipeline_layout_create( - iree_host_size_t set_layout_count, - iree_hal_descriptor_set_layout_t* const* set_layouts, - iree_host_size_t push_constant_count, iree_allocator_t host_allocator, - iree_hal_pipeline_layout_t** out_pipeline_layout) { - IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts); - IREE_ASSERT_ARGUMENT(out_pipeline_layout); - IREE_TRACE_ZONE_BEGIN(z0); - - *out_pipeline_layout = NULL; - if (push_constant_count > IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT) { - IREE_TRACE_ZONE_END(z0); - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "push constant count %" PRIhsz " over the limit of %d", - push_constant_count, IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT); - } - - // Currently the pipeline layout doesn't do anything. - // TODO: Handle creating the argument layout at that time hadling both push - // constant and buffers. 
- iree_hal_cuda_pipeline_layout_t* pipeline_layout = NULL; - iree_host_size_t total_size = - sizeof(*pipeline_layout) + - set_layout_count * sizeof(*pipeline_layout->set_layouts); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(host_allocator, total_size, - (void**)&pipeline_layout)); - - iree_hal_resource_initialize(&iree_hal_cuda_pipeline_layout_vtable, - &pipeline_layout->resource); - pipeline_layout->host_allocator = host_allocator; - pipeline_layout->set_layout_count = set_layout_count; - iree_host_size_t base_index = 0; - for (iree_host_size_t i = 0; i < set_layout_count; ++i) { - pipeline_layout->set_layouts[i].set_layout = set_layouts[i]; - // Copy and retain all descriptor sets so we don't lose them. - iree_hal_descriptor_set_layout_retain(set_layouts[i]); - pipeline_layout->set_layouts[i].base_index = base_index; - base_index += - iree_hal_cuda_descriptor_set_layout_binding_count(set_layouts[i]); - } - pipeline_layout->push_constant_base_index = base_index; - pipeline_layout->push_constant_count = push_constant_count; - *out_pipeline_layout = (iree_hal_pipeline_layout_t*)pipeline_layout; - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static void iree_hal_cuda_pipeline_layout_destroy( - iree_hal_pipeline_layout_t* base_pipeline_layout) { - iree_hal_cuda_pipeline_layout_t* pipeline_layout = - iree_hal_cuda_pipeline_layout_cast(base_pipeline_layout); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_allocator_t host_allocator = pipeline_layout->host_allocator; - - for (iree_host_size_t i = 0; i < pipeline_layout->set_layout_count; ++i) { - iree_hal_descriptor_set_layout_release( - pipeline_layout->set_layouts[i].set_layout); - } - iree_allocator_free(host_allocator, pipeline_layout); - - IREE_TRACE_ZONE_END(z0); -} - -iree_host_size_t iree_hal_cuda_pipeline_layout_descriptor_set_count( - const iree_hal_pipeline_layout_t* base_pipeline_layout) { - const iree_hal_cuda_pipeline_layout_t* pipeline_layout = - 
iree_hal_cuda_pipeline_layout_const_cast(base_pipeline_layout); - return pipeline_layout->set_layout_count; -} - -const iree_hal_descriptor_set_layout_t* -iree_hal_cuda_pipeline_layout_descriptor_set_layout( - const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { - const iree_hal_cuda_pipeline_layout_t* pipeline_layout = - iree_hal_cuda_pipeline_layout_const_cast(base_pipeline_layout); - if (set < pipeline_layout->set_layout_count) { - return pipeline_layout->set_layouts[set].set_layout; - } - return NULL; -} - -iree_host_size_t iree_hal_cuda_pipeline_layout_base_binding_index( - const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { - const iree_hal_cuda_pipeline_layout_t* pipeline_layout = - iree_hal_cuda_pipeline_layout_const_cast(base_pipeline_layout); - return pipeline_layout->set_layouts[set].base_index; -} - -iree_host_size_t iree_hal_cuda_pipeline_layout_total_binding_count( - const iree_hal_pipeline_layout_t* base_pipeline_layout) { - return iree_hal_cuda_pipeline_layout_push_constant_index( - base_pipeline_layout); -} - -iree_host_size_t iree_hal_cuda_pipeline_layout_push_constant_index( - const iree_hal_pipeline_layout_t* base_pipeline_layout) { - const iree_hal_cuda_pipeline_layout_t* pipeline_layout = - iree_hal_cuda_pipeline_layout_const_cast(base_pipeline_layout); - return pipeline_layout->push_constant_base_index; -} - -iree_host_size_t iree_hal_cuda_pipeline_layout_push_constant_count( - const iree_hal_pipeline_layout_t* base_pipeline_layout) { - const iree_hal_cuda_pipeline_layout_t* pipeline_layout = - iree_hal_cuda_pipeline_layout_const_cast(base_pipeline_layout); - return pipeline_layout->push_constant_count; -} - -static const iree_hal_pipeline_layout_vtable_t - iree_hal_cuda_pipeline_layout_vtable = { - .destroy = iree_hal_cuda_pipeline_layout_destroy, -}; diff --git a/runtime/src/iree/hal/drivers/cuda/pipeline_layout.h b/runtime/src/iree/hal/drivers/cuda/pipeline_layout.h deleted file mode 100644 index 
49eaf54821208..0000000000000 --- a/runtime/src/iree/hal/drivers/cuda/pipeline_layout.h +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_HAL_DRIVERS_CUDA_PIPELINE_LAYOUT_H_ -#define IREE_HAL_DRIVERS_CUDA_PIPELINE_LAYOUT_H_ - -#include "iree/base/api.h" -#include "iree/hal/api.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// The max number of bindings per descriptor set allowed in the CUDA HAL -// implementation. -#define IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT 16 - -// The max number of descriptor sets allowed in the CUDA HAL implementation. -// -// This depends on the general descriptor set planning in IREE and should adjust -// with it. -#define IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT 4 - -// The max number of push constants supported by the CUDA HAL implementation. -#define IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT 64 - -// Note that IREE HAL uses a descriptor binding model for expressing resources -// to the kernels--each descriptor specifies the resource information, together -// with a (set, binding) number indicating which "slots" it's bound to. -// -// In CUDA, however, we don't have a direct correspondance of such mechanism. -// Resources are expressed as kernel arguments. Therefore to implement IREE -// HAL descriptor set and pipepline layout in CUDA, we order and flatten all -// sets and bindings and map to them to a linear array of kernel arguments. 
-// -// For example, given a pipeline layout with two sets and two bindings each: -// (set #, binding #) | kernel argument # -// :----------------: | :---------------: -// (0, 0) | 0 -// (0, 4) | 1 -// (2, 1) | 2 -// (2, 3) | 3 - -//===----------------------------------------------------------------------===// -// iree_hal_cuda_descriptor_set_layout_t -//===----------------------------------------------------------------------===// - -// Creates a descriptor set layout with the given |bindings|. -// -// Bindings in a descriptor set map to a list of consecutive kernel arguments in -// CUDA kernels. -iree_status_t iree_hal_cuda_descriptor_set_layout_create( - iree_hal_descriptor_set_layout_flags_t flags, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t* bindings, - iree_allocator_t host_allocator, - iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); - -// Returns the binding count for the given descriptor set layout. -iree_host_size_t iree_hal_cuda_descriptor_set_layout_binding_count( - const iree_hal_descriptor_set_layout_t* descriptor_set_layout); - -//===----------------------------------------------------------------------===// -// iree_hal_cuda_pipeline_layout_t -//===----------------------------------------------------------------------===// - -// Creates the pipeline layout with the given |set_layouts| and -// |push_constant_count|. -// -// Bindings in the pipeline map to kernel arguments in CUDA kernels, followed by -// the kernel argument for the push constant data. -iree_status_t iree_hal_cuda_pipeline_layout_create( - iree_host_size_t set_layout_count, - iree_hal_descriptor_set_layout_t* const* set_layouts, - iree_host_size_t push_constant_count, iree_allocator_t host_allocator, - iree_hal_pipeline_layout_t** out_pipeline_layout); - -// Returns the total number of sets in the given |pipeline_layout|. 
-iree_host_size_t iree_hal_cuda_pipeline_layout_descriptor_set_count( - const iree_hal_pipeline_layout_t* pipeline_layout); - -// Returns the descriptor set layout of the given |set| in |pipeline_layout|. -const iree_hal_descriptor_set_layout_t* -iree_hal_cuda_pipeline_layout_descriptor_set_layout( - const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set); - -// Returns the base kernel argument index for the given set. -iree_host_size_t iree_hal_cuda_pipeline_layout_base_binding_index( - const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set); - -// Returns the total number of descriptor bindings across all sets. -iree_host_size_t iree_hal_cuda_pipeline_layout_total_binding_count( - const iree_hal_pipeline_layout_t* pipeline_layout); - -// Returns the kernel argument index for push constant data. -iree_host_size_t iree_hal_cuda_pipeline_layout_push_constant_index( - const iree_hal_pipeline_layout_t* pipeline_layout); - -// Returns the number of push constants in the pipeline layout. -iree_host_size_t iree_hal_cuda_pipeline_layout_push_constant_count( - const iree_hal_pipeline_layout_t* pipeline_layout); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // IREE_HAL_DRIVERS_CUDA_PIPELINE_LAYOUT_H_ diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c index 75a415c6238eb..a85ee727b70ba 100644 --- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c @@ -10,7 +10,6 @@ #include "iree/hal/drivers/cuda/cuda_status_util.h" #include "iree/hal/drivers/cuda/native_executable.h" #include "iree/hal/drivers/cuda/nccl_channel.h" -#include "iree/hal/drivers/cuda/pipeline_layout.h" #include "iree/hal/utils/collective_batch.h" #include "iree/hal/utils/resource_set.h" @@ -38,12 +37,6 @@ typedef struct iree_hal_cuda_stream_command_buffer_t { // Iteratively constructed batch of collective operations. 
iree_hal_collective_batch_t collective_batch; - - // TODO(#18154): drop state used by legacy bindings mechanism. - int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; - struct { - CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT]; - } descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT]; } iree_hal_cuda_stream_command_buffer_t; static const iree_hal_command_buffer_vtable_t @@ -475,183 +468,6 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_collective( return status; } -static iree_status_t iree_hal_cuda_stream_command_buffer_push_constants( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, - const void* values, iree_host_size_t values_length) { - iree_hal_cuda_stream_command_buffer_t* command_buffer = - iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_host_size_t constant_base_index = offset / sizeof(int32_t); - for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) { - command_buffer->push_constants[i + constant_base_index] = - ((uint32_t*)values)[i]; - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_hal_cuda_stream_command_buffer_push_descriptor_set( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, - iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings) { - if (binding_count > IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) { - return iree_make_status( - IREE_STATUS_RESOURCE_EXHAUSTED, - "exceeded available binding slots for push " - "descriptor set #%" PRIu32 "; requested %" PRIhsz " vs. 
maximal %d", - set, binding_count, IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT); - } - - iree_hal_cuda_stream_command_buffer_t* command_buffer = - iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); - IREE_TRACE_ZONE_BEGIN(z0); - - CUdeviceptr* current_bindings = command_buffer->descriptor_sets[set].bindings; - for (iree_host_size_t i = 0; i < binding_count; i++) { - const iree_hal_buffer_ref_t* binding = &bindings[i]; - CUdeviceptr device_ptr = 0; - if (binding->buffer) { - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, - &binding->buffer)); - CUdeviceptr device_buffer = iree_hal_cuda_buffer_device_pointer( - iree_hal_buffer_allocated_buffer(binding->buffer)); - iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer); - device_ptr = device_buffer + offset + binding->offset; - } - current_bindings[binding->ordinal] = device_ptr; - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, - iree_hal_dispatch_flags_t flags) { - iree_hal_cuda_stream_command_buffer_t* command_buffer = - iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); - IREE_TRACE_ZONE_BEGIN(z0); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, - iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer)); - - // Lookup kernel parameters used for side-channeling additional launch - // information from the compiler. 
- iree_hal_cuda_kernel_info_t kernel_info; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda_native_executable_entry_point_kernel_info( - executable, entry_point, &kernel_info)); - - IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( - command_buffer->tracing_context, &command_buffer->tracing_event_list, - command_buffer->cu_stream, IREE_HAL_CUDA_TRACING_VERBOSITY_FINE, - kernel_info.source_filename.data, kernel_info.source_filename.size, - kernel_info.source_line, kernel_info.function_name.data, - kernel_info.function_name.size, - /*name=*/NULL, 0); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, - &executable)); - - // The total number of descriptors across all descriptor sets. - iree_host_size_t descriptor_count = - iree_hal_cuda_pipeline_layout_total_binding_count(kernel_info.layout); - // The total number of push constants. - iree_host_size_t push_constant_count = - iree_hal_cuda_pipeline_layout_push_constant_count(kernel_info.layout); - // We append push constants to the end of descriptors to form a linear chain - // of kernel arguments. - iree_host_size_t kernel_params_count = descriptor_count + push_constant_count; - iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); - - // Per CUDA API requirements, we need two levels of indirection for passing - // kernel arguments in. - // "If the kernel has N parameters, then kernelParams needs to be an array - // of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1], - // points to the region of memory from which the actual parameter will be - // copied." - // - // (From the cuGraphAddKernelNode API doc in - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b) - // - // It means each kernel_params[i] is itself a pointer to the corresponding - // element at the *second* inline allocation at the end of the current - // segment. 
- iree_host_size_t total_size = kernel_params_length * 2; - uint8_t* storage_base = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_arena_allocate(&command_buffer->arena, total_size, - (void**)&storage_base)); - void** params_ptr = (void**)storage_base; - - // Set up kernel arguments to point to the payload slots. - CUdeviceptr* payload_ptr = - (CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length); - for (size_t i = 0; i < kernel_params_count; i++) { - params_ptr[i] = &payload_ptr[i]; - } - - // Copy descriptors from all sets to the end of the current segment for later - // access. - iree_host_size_t set_count = - iree_hal_cuda_pipeline_layout_descriptor_set_count(kernel_info.layout); - for (iree_host_size_t i = 0; i < set_count; ++i) { - // TODO: cache this information in the kernel info to avoid recomputation. - iree_host_size_t binding_count = - iree_hal_cuda_descriptor_set_layout_binding_count( - iree_hal_cuda_pipeline_layout_descriptor_set_layout( - kernel_info.layout, i)); - iree_host_size_t index = - iree_hal_cuda_pipeline_layout_base_binding_index(kernel_info.layout, i); - memcpy(payload_ptr + index, command_buffer->descriptor_sets[i].bindings, - binding_count * sizeof(CUdeviceptr)); - } - - // Append the push constants to the kernel arguments. - iree_host_size_t base_index = - iree_hal_cuda_pipeline_layout_push_constant_index(kernel_info.layout); - // As commented in the above, what each kernel parameter points to is a - // CUdeviceptr, which as the size of a pointer on the target machine. we are - // just storing a 32-bit value for the push constant here instead. So we must - // process one element each type, for 64-bit machines. 
- for (iree_host_size_t i = 0; i < push_constant_count; i++) { - *((uint32_t*)params_ptr[base_index + i]) = - command_buffer->push_constants[i]; - } - - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, command_buffer->cuda_symbols, - cuLaunchKernel(kernel_info.function, workgroup_x, workgroup_y, - workgroup_z, kernel_info.block_size[0], - kernel_info.block_size[1], kernel_info.block_size[2], - kernel_info.shared_memory_size, command_buffer->cu_stream, - params_ptr, NULL), - "cuLaunchKernel"); - - IREE_CUDA_STREAM_TRACE_ZONE_END( - command_buffer->tracing_context, &command_buffer->tracing_event_list, - command_buffer->cu_stream, IREE_HAL_CUDA_TRACING_VERBOSITY_FINE); - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect( - iree_hal_command_buffer_t* base_command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "need cuda implementation of dispatch indirect"); -} - static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, @@ -667,17 +483,19 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2( // Lookup kernel parameters used for side-channeling additional launch // information from the compiler. 
- iree_hal_cuda_kernel_info_t kernel_info; + const iree_hal_cuda_kernel_params_t* kernel_params = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda_native_executable_entry_point_kernel_info( - executable, entry_point, &kernel_info)); + z0, iree_hal_cuda_native_executable_lookup_kernel_params( + executable, entry_point, &kernel_params)); IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, command_buffer->cu_stream, IREE_HAL_CUDA_TRACING_VERBOSITY_FINE, - kernel_info.source_filename.data, kernel_info.source_filename.size, - kernel_info.source_line, kernel_info.function_name.data, - kernel_info.function_name.size, /*name=*/NULL, 0); + kernel_params->debug_info.source_filename.data, + kernel_params->debug_info.source_filename.size, + kernel_params->debug_info.source_line, + kernel_params->debug_info.name.data, kernel_params->debug_info.name.size, + /*name=*/NULL, 0); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, @@ -686,7 +504,7 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2( // We append push constants to the end of descriptors to form a linear chain // of kernel arguments. iree_host_size_t kernel_params_count = - kernel_info.binding_count + kernel_info.constant_count; + kernel_params->binding_count + kernel_params->constant_count; iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); // TODO: use packed parameters instead of the indirection mechanism - this @@ -735,17 +553,18 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2( // CUdeviceptr, which as the size of a pointer on the target machine. we are // just storing a 32-bit value for the push constant here instead. So we must // process one element each type, for 64-bit machines. 
- for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) { - *((uint32_t*)params_ptr[kernel_info.binding_count + i]) = + for (iree_host_size_t i = 0; i < kernel_params->constant_count; i++) { + *((uint32_t*)params_ptr[kernel_params->binding_count + i]) = ((const uint32_t*)constants.data)[i]; } IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( z0, command_buffer->cuda_symbols, - cuLaunchKernel(kernel_info.function, workgroup_count[0], + cuLaunchKernel(kernel_params->function, workgroup_count[0], workgroup_count[1], workgroup_count[2], - kernel_info.block_size[0], kernel_info.block_size[1], - kernel_info.block_size[2], kernel_info.shared_memory_size, + kernel_params->block_dims[0], kernel_params->block_dims[1], + kernel_params->block_dims[2], + kernel_params->block_shared_memory_size, command_buffer->cu_stream, params_ptr, NULL), "cuLaunchKernel"); @@ -784,12 +603,6 @@ static const iree_hal_command_buffer_vtable_t .update_buffer = iree_hal_cuda_stream_command_buffer_update_buffer, .copy_buffer = iree_hal_cuda_stream_command_buffer_copy_buffer, .collective = iree_hal_cuda_stream_command_buffer_collective, - .push_constants = iree_hal_cuda_stream_command_buffer_push_constants, - .push_descriptor_set = - iree_hal_cuda_stream_command_buffer_push_descriptor_set, - .dispatch = iree_hal_cuda_stream_command_buffer_dispatch, - .dispatch_indirect = - iree_hal_cuda_stream_command_buffer_dispatch_indirect, .dispatch2 = iree_hal_cuda_stream_command_buffer_dispatch2, .dispatch2_indirect = iree_hal_cuda_stream_command_buffer_dispatch2_indirect, diff --git a/runtime/src/iree/schemas/cuda_executable_def.fbs b/runtime/src/iree/schemas/cuda_executable_def.fbs index 0ba9c2552e388..550c3d3d58214 100644 --- a/runtime/src/iree/schemas/cuda_executable_def.fbs +++ b/runtime/src/iree/schemas/cuda_executable_def.fbs @@ -12,33 +12,62 @@ namespace iree.hal.cuda; file_identifier "CDA1"; file_extension "cda1"; -// A struct for the kernel block size along each dimensions. 
-struct BlockSize { +// A struct for the kernel block size along each dimension. +struct BlockDims { x:uint32; y:uint32; z:uint32; } -table ExecutableDef { - // A map of entry point ordinals to string names as used in the shader - // library. - entry_points:[string]; - - // Block sizes for each entry point. - // - // Currently the thread group size/block size is decided during code gen but - // in CUDA it is set by the runtime. - block_sizes:[BlockSize]; - // Size of dynamic shared memory. - shared_memory_size:[uint32]; - - // PTX string of the module. +// Describes the behavior of each binding. +// Roughly maps to iree_hal_descriptor_flags_t but is not required to match +// exactly; if there's additional binding information we want to pass through +// to CUDA we can encode that here. +enum BindingBits:uint64 (bit_flags) { + // IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY + READ_ONLY = 0, // 1u << 0 + // IREE_HAL_DESCRIPTOR_FLAG_INDIRECT + INDIRECT = 1, // 1u << 1 +} + +// Information about an exported function on the executable. +table ExportDef { + // Ordinal of the module containing the entry point in the executable + // modules list. + module_ordinal:uint32; + + // String name of the exported kernel function in the module. + kernel_name:string; + + // Grid block dimensions for the export. + block_dims:BlockDims; + + // Size of dynamic shared memory per block. + block_shared_memory_size:uint32; + + // Total number of 32-bit push constants used by the export. + constant_count:uint32; + + // Binding count and flags for each binding. + binding_flags:[BindingBits]; + + // Optional debug information related to the export. + debug_info:iree.hal.debug.ExportDef; +} + +// A module containing one or more exported functions. +table ModuleDef { + // PTX image. ptx_image:string; +} + +table ExecutableDef { + // Exported functions in canonical executable entry point order. + exports:[ExportDef]; - // A map of entry point ordinals to source locations. 
- // This information is optional and may be used by debuggers and profilers to - // associate executable entry points with the source that generated them. - source_locations:[iree.hal.debug.FileLineLocDef]; + // A list of all kernel modules used by the executable. + // Exports index into this list and multiple exports may use the same module. + modules:[ModuleDef]; // Embedded source files sorted ascending by path. source_files:[iree.hal.debug.SourceFileDef];