From 834d001e10912c815fa7af14422f60c28162f8d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Thu, 17 Oct 2024 08:30:13 -0700 Subject: [PATCH] [flang][cuda] Relax the verifier for cuf.register_kernel op (#112585) Relax the verifier since the `gpu.func` might be converted to `llvm.func` before `cuf.register_kernel` is converted. --- flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp | 27 ++++++++++++++-------- flang/test/Fir/cuf-invalid.fir | 15 ++++++++++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index 9e3bbd1f9cbee9..0b03e070a0076e 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -16,6 +16,7 @@ #include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -276,18 +277,26 @@ mlir::LogicalResult cuf::RegisterKernelOp::verify() { mlir::SymbolTable symTab(mod); auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName()); - if (!gpuMod) + if (!gpuMod) { + // If already a gpu.binary then stop the check here. 
+ if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName())) + return mlir::success(); return emitOpError("gpu module not found"); + } mlir::SymbolTable gpuSymTab(gpuMod); - auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName()); - if (!func) - return emitOpError("device function not found"); - - if (!func.isKernel()) - return emitOpError("only kernel gpu.func can be registered"); - - return mlir::success(); + if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) { + if (!func.isKernel()) + return emitOpError("only kernel gpu.func can be registered"); + return mlir::success(); + } else if (auto func = + gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) { + if (!func->getAttrOfType<mlir::UnitAttr>( + mlir::gpu::GPUDialect::getKernelFuncAttrName())) + return emitOpError("only gpu.kernel llvm.func can be registered"); + return mlir::success(); + } + return emitOpError("device function not found"); } // Tablegen operators diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir index a5747b8ee4a3b3..8a1eb48576832c 100644 --- a/flang/test/Fir/cuf-invalid.fir +++ b/flang/test/Fir/cuf-invalid.fir @@ -175,3 +175,18 @@ module attributes {gpu.container_module} { llvm.return } } + +// ----- + +module attributes {gpu.container_module} { + gpu.module @cuda_device_mod { + llvm.func @_QPsub_device1() { + llvm.return + } + } + llvm.func internal @__cudaFortranConstructor() { + // expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}} + cuf.register_kernel @cuda_device_mod::@_QPsub_device1 + llvm.return + } +}