Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LLVM's byval instead of rewriting kernels. #16

Merged
merged 1 commit into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 9 additions & 129 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,138 +561,18 @@ end
## kernel promotion

# promote a function to a kernel
# FIXME: sig vs tt (code_llvm vs cufunction)
function promote_kernel!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function)
kernel = wrap_entry!(job, mod, entry_f)
function promote_kernel!(job::CompilerJob, mod::LLVM.Module, kernel::LLVM.Function)
    # Mark every pointer parameter whose pointee has a known size as `byval`, so the
    # argument is passed by value. This improves performance, and is mandated by
    # certain back-ends like SPIR-V. Pointers to unsized (opaque) types are left alone.
    fn_type = eltype(llvmtype(kernel)::LLVM.PointerType)::LLVM.FunctionType
    for (idx, param_type) in enumerate(parameters(fn_type))
        param_type isa LLVM.PointerType || continue
        issized(eltype(param_type)) || continue
        push!(parameter_attributes(kernel, idx), EnumAttribute("byval"))
    end

    # hand over to the target for any back-end specific kernel processing
    process_kernel!(job, mod, kernel)

    return kernel
end

function wrapper_type(julia_t::Type, codegen_t::LLVMType)::LLVMType
    # Pick the LLVM type the wrapper should accept for an argument whose Julia type is
    # `julia_t` and whose codegen'd parameter type is `codegen_t`.

    # jl_value_t is an opaque structure; never pass it by value
    isbitstype(julia_t) || return codegen_t

    # codegen hands us a pointer even though the Julia type is not a `Ptr`:
    # have the wrapper take the pointed-to value instead.
    passed_by_reference = codegen_t isa LLVM.PointerType && !(julia_t <: Ptr)
    return passed_by_reference ? eltype(codegen_t) : codegen_t
end

# generate a kernel wrapper to fix & improve argument passing
function wrap_entry!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function)
    # Create a void wrapper function in `mod` that takes over `entry_f`'s name, accepts
    # by-value arguments where `wrapper_type` says so, and forwards them to `entry_f`
    # (which is renamed with an ".inner" suffix, made internal and always-inlined).
    #
    # Arguments:
    # - job:     compiler job; supplies the source function and argument tuple type, and
    #            is passed to `@compiler_assert` for diagnostics
    # - mod:     module in which the wrapper is created and the inliner is run
    # - entry_f: the original entry function; must return void
    #
    # Returns the newly created wrapper function.
    ctx = JuliaContext()

    entry_ft = eltype(llvmtype(entry_f)::LLVM.PointerType)::LLVM.FunctionType
    @compiler_assert return_type(entry_ft) == LLVM.VoidType(ctx) job

    # filter out types which don't occur in the LLVM function signatures
    sig = Base.signature_type(job.source.f, job.source.tt)::Type
    julia_types = Type[]
    for dt::Type in sig.parameters
        if !isghosttype(dt) &&
           (VERSION < v"1.5.0-DEV.581" || !Core.Compiler.isconstType(dt))
            push!(julia_types, dt)
        end
    end

    # generate the wrapper function type & definition
    codegen_types = parameters(entry_ft)
    wrapper_types = LLVM.LLVMType[wrapper_type(julia_t, codegen_t)
                                  for (julia_t, codegen_t)
                                  in zip(julia_types, codegen_types)]
    wrapper_fn = LLVM.name(entry_f)
    LLVM.name!(entry_f, wrapper_fn * ".inner")
    wrapper_ft = LLVM.FunctionType(LLVM.VoidType(ctx), wrapper_types)
    wrapper_f = LLVM.Function(mod, wrapper_fn, wrapper_ft)

    # emit IR performing the "conversions"
    builder = Builder(ctx)
    try
        entry = BasicBlock(wrapper_f, "entry", ctx)
        position!(builder, entry)

        wrapper_args = Vector{LLVM.Value}()

        # perform argument conversions
        wrapper_params = parameters(wrapper_f)
        for (param_index, (codegen_t, wrapper_t, wrapper_param)) in
            enumerate(zip(codegen_types, wrapper_types, wrapper_params))
            if codegen_t != wrapper_t
                # the wrapper argument doesn't match the kernel parameter type.
                # this only happens when codegen wants to pass a pointer.
                @compiler_assert isa(codegen_t, LLVM.PointerType) job
                @compiler_assert eltype(codegen_t) == wrapper_t job

                # copy the argument value to a stack slot, and reference it.
                ptr = alloca!(builder, wrapper_t)
                if LLVM.addrspace(codegen_t) != 0
                    ptr = addrspacecast!(builder, ptr, codegen_t)
                end
                store!(builder, wrapper_param, ptr)
                push!(wrapper_args, ptr)
            else
                # types agree: forward the parameter as-is and carry its attributes
                # over from the original entry function.
                push!(wrapper_args, wrapper_param)
                for attr in collect(parameter_attributes(entry_f, param_index))
                    push!(parameter_attributes(wrapper_f, param_index), attr)
                end
            end
        end

        call!(builder, entry_f, wrapper_args)

        ret!(builder)
    finally
        # release the builder even when an assertion or IR-building call throws
        dispose(builder)
    end

    # early-inline the original entry function into the wrapper
    push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0, ctx))
    linkage!(entry_f, LLVM.API.LLVMInternalLinkage)

    # strip metadata that the stack-slot copies above would violate, before inlining
    fixup_metadata!(entry_f)
    ModulePassManager() do pm
        always_inliner!(pm)
        run!(pm, mod)
    end

    return wrapper_f
end

# HACK: get rid of invariant.load and const TBAA metadata on loads from pointer args,
# since storing to a stack slot violates the semantics of those attributes.
# TODO: can we emit a wrapper that doesn't violate Julia's metadata?
function fixup_metadata!(f::LLVM.Function)
    # For every pointer-typed parameter of `f`, visit the instructions that use it
    # (looking through bitcasts, GEPs and address-space casts) and drop their
    # invariant.load and TBAA metadata.
    for arg in parameters(f)
        llvmtype(arg) isa LLVM.PointerType || continue

        # FIFO worklist seeded with the direct users of the pointer
        pending = Vector{LLVM.Instruction}(user.(collect(uses(arg))))
        while !isempty(pending)
            inst = popfirst!(pending)

            # strip the metadata kinds that no longer hold
            md = metadata(inst)
            for kind in (LLVM.MD_invariant_load, LLVM.MD_tbaa)
                haskey(md, kind) && delete!(md, kind)
            end

            # instructions that merely re-express the pointer: follow their users too
            if inst isa Union{LLVM.BitCastInst,
                              LLVM.GetElementPtrInst,
                              LLVM.AddrSpaceCastInst}
                append!(pending, user.(collect(uses(inst))))
            end

            # IMPORTANT NOTE: if we ever want to inline functions at the LLVM level,
            # we need to recurse into call instructions here, and strip metadata from
            # called functions (see CUDAnative.jl#238).
        end
    end
    return nothing
end
9 changes: 7 additions & 2 deletions test/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ include("definitions/ptx.jl")
# not a jl_throw referencing a jl_value_t representing the exception
@test !occursin("jl_throw", ir)
end

@testset "kernel functions" begin
@testset "wrapper function aggregate rewriting" begin
@testset "kernel argument attributes" begin
kernel(x) = return

@eval struct Aggregate
Expand All @@ -31,7 +32,11 @@ end
end

ir = sprint(io->ptx_code_llvm(io, kernel, Tuple{Aggregate}; kernel=true))
@test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\])\)", ir)
if VERSION < v"1.5.0-DEV.802"
@test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\]) addrspace\(\d+\)?\*.+byval", ir)
else
@test occursin(r"@.*julia_kernel.+\(({ i64 }|\[1 x i64\])\*.+byval", ir)
end
end

@testset "property_annotations" begin
Expand Down