Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Translate cuf.register_kernel and cuf.register_module #112972

Open
wants to merge 1 commit into
base: users/clementval/cuf/module
Choose a base branch
from

Conversation

clementval
Copy link
Contributor

Add LLVM IR translation for cuf.register_module and cuf.register_kernel. These are lowered to function calls to the CUF runtime entry points.

@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 18, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 18, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-runtime

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Add LLVM IR translation for cuf.register_module and cuf.register_kernel. These are lowered to function calls to the CUF runtime entry points.


Full diff: https://github.com/llvm/llvm-project/pull/112972.diff

8 Files Affected:

  • (added) flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h (+29)
  • (modified) flang/include/flang/Optimizer/Support/InitFIR.h (+2)
  • (added) flang/include/flang/Runtime/CUDA/registration.h (+28)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt (+1)
  • (added) flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp (+104)
  • (modified) flang/lib/Optimizer/Transforms/CufOpConversion.cpp (+1)
  • (modified) flang/runtime/CUDA/CMakeLists.txt (+1)
  • (added) flang/runtime/CUDA/registration.cpp (+31)
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
new file mode 100644
index 00000000000000..f3edb7fca649d0
--- /dev/null
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
@@ -0,0 +1,29 @@
+//===- CUFToLLVMIRTranslation.h - CUF Dialect to LLVM IR --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for CUF dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
+#define FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
+
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+} // namespace mlir
+
+namespace cuf {
+
+/// Register the CUF dialect in \p registry together with the interface that
+/// translates its operations to LLVM IR.
+void registerCUFDialectTranslation(mlir::DialectRegistry &registry);
+
+} // namespace cuf
+
+#endif // FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 04a5dd323e5508..1c61c367199923 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -14,6 +14,7 @@
 #define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H
 
 #include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "mlir/Conversion/Passes.h"
@@ -61,6 +62,7 @@ inline void addFIRExtensions(mlir::DialectRegistry &registry,
   if (addFIRInlinerInterface)
     addFIRInlinerExtension(registry);
   addFIRToLLVMIRExtension(registry);
+  cuf::registerCUFDialectTranslation(registry);
 }
 
 inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
new file mode 100644
index 00000000000000..cbe202c4d23e0d
--- /dev/null
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -0,0 +1,28 @@
+//===-- include/flang/Runtime/CUDA/registration.h ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+#define FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+
+#include "flang/Runtime/entry-names.h"
+#include <cstddef>
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+/// Register a CUDA module. \p data is forwarded unchanged to
+/// __cudaRegisterFatBinary and the handle it returns is handed back to the
+/// caller.
+void *RTDECL(CUFRegisterModule)(void *data);
+
+/// Register a device function. \p module is the module handle (presumably the
+/// value returned by CUFRegisterModule) and \p fct is the function name; the
+/// same string is passed to __cudaRegisterFunction as host function, device
+/// function and device name.
+void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+
+} // extern "C"
+
+} // namespace Fortran::runtime::cuda
+#endif // FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
diff --git a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
index b2221199995d58..5d4bd0785971f7 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Attributes)
 add_flang_library(CUFDialect
   CUFDialect.cpp
   CUFOps.cpp
+  CUFToLLVMIRTranslation.cpp
 
   DEPENDS
   MLIRIR
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
new file mode 100644
index 00000000000000..c6c9f96b811352
--- /dev/null
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -0,0 +1,104 @@
+//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR CUF dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Translate a cuf.register_module operation into a call to the
+/// CUFRegisterModule runtime entry. The global holding the binary is looked
+/// up under the name "<leaf reference of the op name>_bin_cst"; an error is
+/// emitted on the op when no such global exists. The value returned by the
+/// call is mapped to the result of the op.
+LogicalResult registerModule(cuf::RegisterModuleOp op,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+
+  std::string binaryName = op.getName().getLeafReference().str();
+  binaryName += "_bin_cst";
+  llvm::Value *binaryGlobal =
+      llvmModule->getGlobalVariable(binaryName, /*AllowInternal=*/true);
+  if (binaryGlobal == nullptr)
+    return op.emitError() << "Couldn't find the binary: " << binaryName;
+
+  // Prototype of the runtime entry: ptr (ptr).
+  llvm::Type *ptrTy = builder.getPtrTy(0);
+  llvm::FunctionType *entryTy = llvm::FunctionType::get(
+      ptrTy, ArrayRef<llvm::Type *>({ptrTy}), /*isVarArg=*/false);
+  llvm::FunctionCallee entry = llvmModule->getOrInsertFunction(
+      RTNAME_STRING(CUFRegisterModule), entryTy);
+  llvm::Value *moduleHandle = builder.CreateCall(entry, {binaryGlobal});
+  moduleTranslation.mapValue(op->getResults().front()) = moduleHandle;
+  return mlir::success();
+}
+
+/// Return the global string holding \p kernelName, creating it on first use.
+/// The global is named "<moduleName>_<kernelName>_kernel_name" so that each
+/// kernel of each module gets its own string.
+llvm::Value *getOrCreateFunctionName(llvm::Module *module,
+                                     llvm::IRBuilderBase &builder,
+                                     llvm::StringRef moduleName,
+                                     llvm::StringRef kernelName) {
+  std::string globalName =
+      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));
+
+  // CreateGlobalString emits a global with private linkage, so the lookup has
+  // to accept local symbols; otherwise the existing string is never found and
+  // a duplicate global is created on every call.
+  if (llvm::GlobalVariable *gv =
+          module->getGlobalVariable(globalName, /*AllowInternal=*/true))
+    return gv;
+
+  return builder.CreateGlobalString(kernelName, globalName);
+}
+
+/// Translate a cuf.register_kernel operation into a call to the
+/// CUFRegisterFunction runtime entry, passing the translated module pointer
+/// operand and a global string holding the kernel name.
+LogicalResult registerKernel(cuf::RegisterKernelOp op,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *module = moduleTranslation.getLLVMModule();
+  llvm::Type *ptrTy = builder.getPtrTy(0);
+  // The runtime entry is declared as returning void, so the LLVM prototype
+  // must be `void (ptr, ptr)` rather than `ptr (ptr, ptr)`.
+  llvm::FunctionCallee fct = module->getOrInsertFunction(
+      RTNAME_STRING(CUFRegisterFunction),
+      llvm::FunctionType::get(builder.getVoidTy(),
+                              ArrayRef<llvm::Type *>({ptrTy, ptrTy}), false));
+  llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
+  builder.CreateCall(
+      fct, {modulePtr, getOrCreateFunctionName(module, builder,
+                                               op.getKernelModuleName().str(),
+                                               op.getKernelName().str())});
+  return mlir::success();
+}
+
+/// Implementation of the dialect interface that translates the CUF
+/// registration operations to LLVM IR.
+class CUFDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translate \p operation to LLVM IR using \p builder. Only
+  /// cuf.register_module and cuf.register_kernel are handled; any other
+  /// operation is reported as an error on that operation.
+  LogicalResult
+  convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const override {
+    return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
+        .Case([&](cuf::RegisterModuleOp op) {
+          return registerModule(op, builder, moduleTranslation);
+        })
+        .Case([&](cuf::RegisterKernelOp op) {
+          return registerKernel(op, builder, moduleTranslation);
+        })
+        .Default([&](Operation *op) {
+          // This interface belongs to the CUF dialect, not the GPU dialect.
+          return op->emitError("unsupported CUF operation: ") << op->getName();
+        });
+  }
+};
+
+} // namespace
+
+/// Insert the CUF dialect into \p registry and attach, through a dialect
+/// extension, the interface translating its operations to LLVM IR.
+void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
+  registry.insert<cuf::CUFDialect>();
+  registry.addExtension(
+      +[](MLIRContext *context, cuf::CUFDialect *cufDialect) {
+        cufDialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
+      });
+}
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index 91ef1259332de9..e81fafb529a27d 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -20,6 +20,7 @@
 #include "flang/Runtime/CUDA/descriptor.h"
 #include "flang/Runtime/CUDA/memory.h"
 #include "flang/Runtime/allocatable.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/flang/runtime/CUDA/CMakeLists.txt b/flang/runtime/CUDA/CMakeLists.txt
index 193dd77e934558..86523b419f8711 100644
--- a/flang/runtime/CUDA/CMakeLists.txt
+++ b/flang/runtime/CUDA/CMakeLists.txt
@@ -18,6 +18,7 @@ add_flang_library(${CUFRT_LIBNAME}
   allocatable.cpp
   descriptor.cpp
   memory.cpp
+  registration.cpp
 )
 
 if (BUILD_SHARED_LIBS)
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
new file mode 100644
index 00000000000000..971192b16156be
--- /dev/null
+++ b/flang/runtime/CUDA/registration.cpp
@@ -0,0 +1,31 @@
+//===-- runtime/CUDA/registration.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Runtime/CUDA/registration.h"
+
+#include "cuda_runtime.h"
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+// CUDA runtime registration entry points, declared here by hand.
+extern void **__cudaRegisterFatBinary(void *data);
+extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
+    char *deviceFun, const char *deviceName, int thread_limit, uint3 *tid,
+    uint3 *bid, dim3 *bDim, dim3 *gDim, int *wSize);
+
+/// Register the fat binary \p data and return the handle produced by
+/// __cudaRegisterFatBinary.
+void *RTDEF(CUFRegisterModule)(void *data) {
+  return __cudaRegisterFatBinary(data);
+}
+
+/// Register the function named \p fct with the module handle \p module.
+/// The same string serves as host function, device function and device name;
+/// thread_limit is passed as -1 and the remaining pointer arguments as null.
+void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
+  // __cudaRegisterFunction takes the device function as a non-const char *.
+  __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
+      nullptr, nullptr, nullptr, nullptr, nullptr);
+}
+
+} // extern "C"
+} // namespace Fortran::runtime::cuda

Copy link
Contributor

@Renaud-K Renaud-K left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really good. Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants