Skip to content

Commit

Permalink
Use LLVM's byval instead of rewriting kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed May 15, 2020
1 parent 7fd5e2e commit 398b8ac
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 131 deletions.
138 changes: 9 additions & 129 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,138 +561,18 @@ end
## kernel promotion

# promote a function to a kernel
# FIXME: sig vs tt (code_llvm vs cufunction)
function promote_kernel!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function)
kernel = wrap_entry!(job, mod, entry_f)
function promote_kernel!(job::CompilerJob, mod::LLVM.Module, kernel::LLVM.Function)
# pass non-opaque pointer arguments by value (this improves performance,
# and is mandated by certain back-ends like SPIR-V).
kernel_ft = eltype(llvmtype(kernel)::LLVM.PointerType)::LLVM.FunctionType
for (i, param_ft) in enumerate(parameters(kernel_ft))
if param_ft isa LLVM.PointerType && issized(eltype(param_ft))
push!(parameter_attributes(kernel, i), EnumAttribute("byval"))
end
end

# target-specific processing
process_kernel!(job, mod, kernel)

return kernel
end

function wrapper_type(julia_t::Type, codegen_t::LLVMType)::LLVMType
if !isbitstype(julia_t)
# don't pass jl_value_t by value; it's an opaque structure
return codegen_t
elseif isa(codegen_t, LLVM.PointerType) && !(julia_t <: Ptr)
# we didn't specify a pointer, but codegen passes one anyway.
# make the wrapper accept the underlying value instead.
return eltype(codegen_t)
else
return codegen_t
end
end

# generate a kernel wrapper to fix & improve argument passing
function wrap_entry!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function)
entry_ft = eltype(llvmtype(entry_f)::LLVM.PointerType)::LLVM.FunctionType
@compiler_assert return_type(entry_ft) == LLVM.VoidType(JuliaContext()) job

# filter out types which don't occur in the LLVM function signatures
sig = Base.signature_type(job.source.f, job.source.tt)::Type
julia_types = Type[]
for dt::Type in sig.parameters
if !isghosttype(dt) && (VERSION < v"1.5.0-DEV.581" || !Core.Compiler.isconstType(dt))
push!(julia_types, dt)
end
end

# generate the wrapper function type & definition
wrapper_types = LLVM.LLVMType[wrapper_type(julia_t, codegen_t)
for (julia_t, codegen_t)
in zip(julia_types, parameters(entry_ft))]
wrapper_fn = LLVM.name(entry_f)
LLVM.name!(entry_f, wrapper_fn * ".inner")
wrapper_ft = LLVM.FunctionType(LLVM.VoidType(JuliaContext()), wrapper_types)
wrapper_f = LLVM.Function(mod, wrapper_fn, wrapper_ft)

# emit IR performing the "conversions"
let builder = Builder(JuliaContext())
entry = BasicBlock(wrapper_f, "entry", JuliaContext())
position!(builder, entry)

wrapper_args = Vector{LLVM.Value}()

# perform argument conversions
codegen_types = parameters(entry_ft)
wrapper_params = parameters(wrapper_f)
param_index = 0
for (julia_t, codegen_t, wrapper_t, wrapper_param) in
zip(julia_types, codegen_types, wrapper_types, wrapper_params)
param_index += 1
if codegen_t != wrapper_t
# the wrapper argument doesn't match the kernel parameter type.
# this only happens when codegen wants to pass a pointer.
@compiler_assert isa(codegen_t, LLVM.PointerType) job
@compiler_assert eltype(codegen_t) == wrapper_t job

# copy the argument value to a stack slot, and reference it.
ptr = alloca!(builder, wrapper_t)
if LLVM.addrspace(codegen_t) != 0
ptr = addrspacecast!(builder, ptr, codegen_t)
end
store!(builder, wrapper_param, ptr)
push!(wrapper_args, ptr)
else
push!(wrapper_args, wrapper_param)
for attr in collect(parameter_attributes(entry_f, param_index))
push!(parameter_attributes(wrapper_f, param_index), attr)
end
end
end

call!(builder, entry_f, wrapper_args)

ret!(builder)

dispose(builder)
end

# early-inline the original entry function into the wrapper
push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0, JuliaContext()))
linkage!(entry_f, LLVM.API.LLVMInternalLinkage)

fixup_metadata!(entry_f)
ModulePassManager() do pm
always_inliner!(pm)
run!(pm, mod)
end

return wrapper_f
end

# HACK: get rid of invariant.load and const TBAA metadata on loads from pointer args,
# since storing to a stack slot violates the semantics of those attributes.
# TODO: can we emit a wrapper that doesn't violate Julia's metadata?
function fixup_metadata!(f::LLVM.Function)
for param in parameters(f)
if isa(llvmtype(param), LLVM.PointerType)
# collect all uses of the pointer
worklist = Vector{LLVM.Instruction}(user.(collect(uses(param))))
while !isempty(worklist)
value = popfirst!(worklist)

# remove the invariant.load attribute
md = metadata(value)
if haskey(md, LLVM.MD_invariant_load)
delete!(md, LLVM.MD_invariant_load)
end
if haskey(md, LLVM.MD_tbaa)
delete!(md, LLVM.MD_tbaa)
end

# recurse on the output of some instructions
if isa(value, LLVM.BitCastInst) ||
isa(value, LLVM.GetElementPtrInst) ||
isa(value, LLVM.AddrSpaceCastInst)
append!(worklist, user.(collect(uses(value))))
end

# IMPORTANT NOTE: if we ever want to inline functions at the LLVM level,
# we need to recurse into call instructions here, and strip metadata from
# called functions (see CUDAnative.jl#238).
end
end
end
end
9 changes: 7 additions & 2 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ include("definitions/ptx.jl")
# not a jl_throw referencing a jl_value_t representing the exception
@test !occursin("jl_throw", ir)
end

@testset "kernel functions" begin
@testset "wrapper function aggregate rewriting" begin
@testset "kernel argument attributes" begin
kernel(x) = return

@eval struct Aggregate
Expand All @@ -31,7 +32,11 @@ end
end

ir = sprint(io->ptx_code_llvm(io, kernel, Tuple{Aggregate}; kernel=true))
@test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\])\)", ir)
if VERSION < v"1.5.0-DEV.802"
@test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\]) addrspace\(\d+\)?\*.+byval", ir)
else
@test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\])\*.+byval", ir)
end
end

@testset "property_annotations" begin
Expand Down

0 comments on commit 398b8ac

Please sign in to comment.