From 398b8ace1e5ddc7f29d3654faa3a052a48249126 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 15 May 2020 09:11:50 +0200 Subject: [PATCH] Use LLVM's byval instead of rewriting kernels. --- src/irgen.jl | 138 ++++----------------------------------------------- test/ptx.jl | 9 +++- 2 files changed, 16 insertions(+), 131 deletions(-) diff --git a/src/irgen.jl b/src/irgen.jl index fa37df5d..0fdc5284 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -561,138 +561,18 @@ end ## kernel promotion # promote a function to a kernel -# FIXME: sig vs tt (code_llvm vs cufunction) -function promote_kernel!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function) - kernel = wrap_entry!(job, mod, entry_f) +function promote_kernel!(job::CompilerJob, mod::LLVM.Module, kernel::LLVM.Function) + # pass non-opaque pointer arguments by value (this improves performance, + # and is mandated by certain back-ends like SPIR-V). + kernel_ft = eltype(llvmtype(kernel)::LLVM.PointerType)::LLVM.FunctionType + for (i, param_ft) in enumerate(parameters(kernel_ft)) + if param_ft isa LLVM.PointerType && issized(eltype(param_ft)) + push!(parameter_attributes(kernel, i), EnumAttribute("byval")) + end + end # target-specific processing process_kernel!(job, mod, kernel) return kernel end - -function wrapper_type(julia_t::Type, codegen_t::LLVMType)::LLVMType - if !isbitstype(julia_t) - # don't pass jl_value_t by value; it's an opaque structure - return codegen_t - elseif isa(codegen_t, LLVM.PointerType) && !(julia_t <: Ptr) - # we didn't specify a pointer, but codegen passes one anyway. - # make the wrapper accept the underlying value instead. - return eltype(codegen_t) - else - return codegen_t - end -end - -# generate a kernel wrapper to fix & improve argument passing -function wrap_entry!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function) - entry_ft = eltype(llvmtype(entry_f)::LLVM.PointerType)::LLVM.FunctionType - @compiler_assert return_type(entry_ft) == LLVM.VoidType(JuliaContext()) job - - # filter out types which don't occur in the LLVM function signatures - sig = Base.signature_type(job.source.f, job.source.tt)::Type - julia_types = Type[] - for dt::Type in sig.parameters - if !isghosttype(dt) && (VERSION < v"1.5.0-DEV.581" || !Core.Compiler.isconstType(dt)) - push!(julia_types, dt) - end - end - - # generate the wrapper function type & definition - wrapper_types = LLVM.LLVMType[wrapper_type(julia_t, codegen_t) - for (julia_t, codegen_t) - in zip(julia_types, parameters(entry_ft))] - wrapper_fn = LLVM.name(entry_f) - LLVM.name!(entry_f, wrapper_fn * ".inner") - wrapper_ft = LLVM.FunctionType(LLVM.VoidType(JuliaContext()), wrapper_types) - wrapper_f = LLVM.Function(mod, wrapper_fn, wrapper_ft) - - # emit IR performing the "conversions" - let builder = Builder(JuliaContext()) - entry = BasicBlock(wrapper_f, "entry", JuliaContext()) - position!(builder, entry) - - wrapper_args = Vector{LLVM.Value}() - - # perform argument conversions - codegen_types = parameters(entry_ft) - wrapper_params = parameters(wrapper_f) - param_index = 0 - for (julia_t, codegen_t, wrapper_t, wrapper_param) in - zip(julia_types, codegen_types, wrapper_types, wrapper_params) - param_index += 1 - if codegen_t != wrapper_t - # the wrapper argument doesn't match the kernel parameter type. - # this only happens when codegen wants to pass a pointer. - @compiler_assert isa(codegen_t, LLVM.PointerType) job - @compiler_assert eltype(codegen_t) == wrapper_t job - - # copy the argument value to a stack slot, and reference it. - ptr = alloca!(builder, wrapper_t) - if LLVM.addrspace(codegen_t) != 0 - ptr = addrspacecast!(builder, ptr, codegen_t) - end - store!(builder, wrapper_param, ptr) - push!(wrapper_args, ptr) - else - push!(wrapper_args, wrapper_param) - for attr in collect(parameter_attributes(entry_f, param_index)) - push!(parameter_attributes(wrapper_f, param_index), attr) - end - end - end - - call!(builder, entry_f, wrapper_args) - - ret!(builder) - - dispose(builder) - end - - # early-inline the original entry function into the wrapper - push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0, JuliaContext())) - linkage!(entry_f, LLVM.API.LLVMInternalLinkage) - - fixup_metadata!(entry_f) - ModulePassManager() do pm - always_inliner!(pm) - run!(pm, mod) - end - - return wrapper_f -end - -# HACK: get rid of invariant.load and const TBAA metadata on loads from pointer args, -# since storing to a stack slot violates the semantics of those attributes. -# TODO: can we emit a wrapper that doesn't violate Julia's metadata? -function fixup_metadata!(f::LLVM.Function) - for param in parameters(f) - if isa(llvmtype(param), LLVM.PointerType) - # collect all uses of the pointer - worklist = Vector{LLVM.Instruction}(user.(collect(uses(param)))) - while !isempty(worklist) - value = popfirst!(worklist) - - # remove the invariant.load attribute - md = metadata(value) - if haskey(md, LLVM.MD_invariant_load) - delete!(md, LLVM.MD_invariant_load) - end - if haskey(md, LLVM.MD_tbaa) - delete!(md, LLVM.MD_tbaa) - end - - # recurse on the output of some instructions - if isa(value, LLVM.BitCastInst) || - isa(value, LLVM.GetElementPtrInst) || - isa(value, LLVM.AddrSpaceCastInst) - append!(worklist, user.(collect(uses(value)))) - end - - # IMPORTANT NOTE: if we ever want to inline functions at the LLVM level, - # we need to recurse into call instructions here, and strip metadata from - # called functions (see CUDAnative.jl#238). - end - end - end -end diff --git a/test/ptx.jl b/test/ptx.jl index 8797b182..eea517f7 100644 --- a/test/ptx.jl +++ b/test/ptx.jl @@ -15,8 +15,9 @@ include("definitions/ptx.jl") # not a jl_throw referencing a jl_value_t representing the exception @test !occursin("jl_throw", ir) end + @testset "kernel functions" begin -@testset "wrapper function aggregate rewriting" begin +@testset "kernel argument attributes" begin kernel(x) = return @eval struct Aggregate @@ -31,7 +32,11 @@ end end ir = sprint(io->ptx_code_llvm(io, kernel, Tuple{Aggregate}; kernel=true)) - @test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\])\)", ir) + if VERSION < v"1.5.0-DEV.802" + @test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\]) addrspace\(\d+\)?\*.+byval", ir) + else + @test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\])\*.+byval", ir) + end end @testset "property_annotations" begin