From 959132822c6109c56514e44ebedd4cea77d027c7 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 08:28:42 +0200 Subject: [PATCH 1/9] Optimize API call prologue. --- src/compiler/exceptions.jl | 2 -- src/state.jl | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/compiler/exceptions.jl b/src/compiler/exceptions.jl index 3a9a9fef98..4290da5ddd 100644 --- a/src/compiler/exceptions.jl +++ b/src/compiler/exceptions.jl @@ -82,8 +82,6 @@ function create_exceptions!(mod::CuModule) end # check the exception flags on every API call, similarly to how CUDA handles errors -# FIXME: this is expensive. Maybe kernels should return a `wait`able object, a la KA.jl, -# which then performs the necessary checks. function check_exceptions() for (ctx,buf) in exception_flags if isvalid(ctx) diff --git a/src/state.jl b/src/state.jl index 37840088ed..3ed0a0def6 100644 --- a/src/state.jl +++ b/src/state.jl @@ -25,7 +25,7 @@ proper invalidation. task = current_task() # detect when a different task is now executing on a thread - if @inbounds thread_tasks[tid] != task + if @inbounds thread_tasks[tid].value::Task !== task switched_tasks(tid, task) end @@ -34,6 +34,8 @@ proper invalidation. initialize_thread(tid) end + # FIXME: this is expensive. Maybe kernels should return a `wait`able object, a la KA.jl, + # which then performs the necessary checks. Or only check when launching kernels. check_exceptions() return @@ -56,7 +58,6 @@ const thread_contexts = Union{Nothing,CuContext}[] # compatibility with externally-initialized contexts thread_contexts[tid] = ctx end - end # Julia executes with tasks, so we need to keep track of the active task for each thread From b0b771088f198bfcd0e9a2f81908b3c4c557b56a Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 09:58:10 +0200 Subject: [PATCH 2/9] Replace device getter with CuCurrentDevice/CuDevice like CuContext. --- lib/cudadrv/context.jl | 33 +++++++++++++++++++++------------ lib/cudadrv/context/primary.jl | 2 +- lib/cudadrv/libcuda.jl | 1 - lib/cudadrv/occupancy.jl | 2 +- res/wrap/wrap.jl | 1 + src/accumulate.jl | 2 +- src/compiler/exceptions.jl | 2 +- src/compiler/execution.jl | 2 +- test/cudadrv/context.jl | 8 ++++---- test/initialization.jl | 3 ++- 10 files changed, 33 insertions(+), 23 deletions(-) diff --git a/lib/cudadrv/context.jl b/lib/cudadrv/context.jl index 34fd267d49..841981f0a5 100644 --- a/lib/cudadrv/context.jl +++ b/lib/cudadrv/context.jl @@ -2,7 +2,7 @@ export CuContext, CuCurrentContext, activate, - synchronize, device + synchronize, CuCurrentDevice ## construction and destruction @@ -125,22 +125,31 @@ end ## context properties """ - device() - device(ctx::CuContext) + CuCurrentDevice() -Returns the device for a context. +Returns the current device, or `nothing` if there is no active device. """ -function device(ctx::CuContext) - push!(CuContext, ctx) +function CuCurrentDevice() device_ref = Ref{CUdevice}() - cuCtxGetDevice(device_ref) - pop!(CuContext) + res = unsafe_cuCtxGetDevice(device_ref) + if res == ERROR_INVALID_CONTEXT + return nothing + elseif res != SUCCESS + throw_api_error(res) + end return CuDevice(Bool, device_ref[]) end -function device() - device_ref = Ref{CUdevice}() - cuCtxGetDevice(device_ref) - return CuDevice(Bool, device_ref[]) + +""" + CuDevice(::CuContext) + +Returns the device for a context. +""" +function CuDevice(ctx::CuContext) + push!(CuContext, ctx) + dev = CuCurrentDevice() + pop!(CuContext) + return dev end """ diff --git a/lib/cudadrv/context/primary.jl b/lib/cudadrv/context/primary.jl index 2b865f1cca..e86ba8920c 100644 --- a/lib/cudadrv/context/primary.jl +++ b/lib/cudadrv/context/primary.jl @@ -40,7 +40,7 @@ does not respect any users of the context, and might make other objects unusable """ function unsafe_release!(ctx::CuContext) if isvalid(ctx) - dev = device(ctx) + dev = CuDevice(ctx) pctx = CuPrimaryContext(dev) if version() >= v"11" cuDevicePrimaryCtxRelease_v2(dev) diff --git a/lib/cudadrv/libcuda.jl b/lib/cudadrv/libcuda.jl index 7b2a4f8b95..c16f627354 100644 --- a/lib/cudadrv/libcuda.jl +++ b/lib/cudadrv/libcuda.jl @@ -152,7 +152,6 @@ end end @checked function cuCtxGetDevice(device) - initialize_api() @runtime_ccall((:cuCtxGetDevice, libcuda()), CUresult, (Ptr{CUdevice},), device) diff --git a/lib/cudadrv/occupancy.jl b/lib/cudadrv/occupancy.jl index cdc094ade9..e9d4ee73a4 100644 --- a/lib/cudadrv/occupancy.jl +++ b/lib/cudadrv/occupancy.jl @@ -25,7 +25,7 @@ function occupancy(fun::CuFunction, threads::Integer; shmem::Integer=0) mod = fun.mod ctx = mod.ctx - dev = device(ctx) + dev = CuDevice(ctx) threads_per_sm = attribute(dev, DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR) warp_size = attribute(dev, DEVICE_ATTRIBUTE_WARP_SIZE) diff --git a/res/wrap/wrap.jl b/res/wrap/wrap.jl index fc6b51a48e..f1d4ceacf0 100644 --- a/res/wrap/wrap.jl +++ b/res/wrap/wrap.jl @@ -192,6 +192,7 @@ preinit_apicalls = Set{String}([ "cuCtxGetCurrent", "cuCtxPushCurrent", "cuCtxPopCurrent", + "cuCtxGetDevice", # this actually does require a context, but we use it unsafely ## primary context management "cuDevicePrimaryCtxGetState", "cuDevicePrimaryCtxRelease", diff --git a/src/accumulate.jl b/src/accumulate.jl index a69a8d1cdf..cde18ae07d 100644 --- a/src/accumulate.jl +++ b/src/accumulate.jl @@ -153,7 +153,7 @@ function scan!(f::Function, output::CuArray{T}, input::CuArray; # determine the grid layout to cover the other dimensions if length(Rother) > 1 - dev = device(kernel.fun.mod.ctx) + dev = CuDevice(kernel.fun.mod.ctx) max_other_blocks = attribute(dev, DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y) blocks_other = (Base.min(length(Rother), max_other_blocks), cld(length(Rother), max_other_blocks)) diff --git a/src/compiler/exceptions.jl b/src/compiler/exceptions.jl index 4290da5ddd..33833d6a3d 100644 --- a/src/compiler/exceptions.jl +++ b/src/compiler/exceptions.jl @@ -89,7 +89,7 @@ function check_exceptions() flag = unsafe_load(ptr) if flag != 0 unsafe_store!(ptr, 0) - dev = device(ctx) + dev = CuDevice(ctx) throw(KernelException(dev)) end end diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 70df68f3f1..51cee63c92 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -302,7 +302,7 @@ end function _cufunction(source::FunctionSpec; kwargs...) # compile to PTX ctx = context() - dev = device(ctx) + dev = CuDevice(ctx) cap = supported_capability(dev) target = PTXCompilerTarget(; cap=supported_capability(dev), kwargs...) params = CUDACompilerParams() diff --git a/test/cudadrv/context.jl b/test/cudadrv/context.jl index 629385ec6b..b38439d1fd 100644 --- a/test/cudadrv/context.jl +++ b/test/cudadrv/context.jl @@ -1,14 +1,14 @@ @testset "context" begin ctx = CuCurrentContext() -dev = device() +dev = CuCurrentDevice() let ctx2 = CuContext(dev) @test ctx2 == CuCurrentContext() # ctor implicitly pushes activate(ctx) @test ctx == CuCurrentContext() - @test device(ctx2) == dev + @test CuDevice(ctx2) == dev CUDA.unsafe_destroy!(ctx2) end @@ -22,8 +22,8 @@ let global_ctx2 = nothing @test !CUDA.isvalid(global_ctx2) @test ctx == CuCurrentContext() - @test device(ctx) == dev - @test device() == dev + @test CuDevice(ctx) == dev + @test CuCurrentDevice() == dev synchronize() end diff --git a/test/initialization.jl b/test/initialization.jl index 22b4225ca1..8e837fa445 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -3,6 +3,7 @@ # the API shouldn't have been initialized @test CuCurrentContext() == nothing +@test CuCurrentDevice() == nothing context_cb = Union{Nothing, CuContext}[nothing for tid in 1:Threads.nthreads()] CUDA.atcontextswitch() do tid, ctx @@ -17,7 +18,7 @@ end # now cause initialization ctx = context() @test CuCurrentContext() == ctx -@test device() == CuDevice(0) +@test CuCurrentDevice() == CuDevice(0) @test context_cb[1] == ctx @test task_cb[1] == current_task() From 4ca7e97b435dcd11c2a5a849a0a1012c5d3a77f7 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 09:58:47 +0200 Subject: [PATCH 3/9] Make the device an explicit argument to unified memory APIs. --- lib/cudadrv/memory.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/cudadrv/memory.jl b/lib/cudadrv/memory.jl index 4b9cd99da3..4ad192cd60 100644 --- a/lib/cudadrv/memory.jl +++ b/lib/cudadrv/memory.jl @@ -36,8 +36,6 @@ Base.pointer(buf::Buffer) = buf.ptr Base.sizeof(buf::Buffer) = buf.bytesize -CUDA.device(buf::Buffer) = device(buf.ctx) - # ccall integration # # taking the pointer of a buffer means returning the underlying pointer, @@ -254,7 +252,7 @@ end Prefetches memory to the specified destination device. """ function prefetch(buf::UnifiedBuffer, bytes::Integer=sizeof(buf); - device::CuDevice=device(buf), stream::CuStream=CuDefaultStream()) + device::CuDevice=CuCurrentDevice(), stream::CuStream=CuDefaultStream()) bytes > sizeof(buf) && throw(BoundsError(buf, bytes)) CUDA.cuMemPrefetchAsync(buf, bytes, device, stream) end @@ -268,7 +266,7 @@ end Advise about the usage of a given memory range. """ function advise(buf::UnifiedBuffer, advice::CUDA.CUmem_advise, bytes::Integer=sizeof(buf); - device::CuDevice=device(buf)) + device::CuDevice=CuCurrentDevice()) bytes > sizeof(buf) && throw(BoundsError(buf, bytes)) CUDA.cuMemAdvise(buf, bytes, advice, device) end From bce44f997b65eccfc8a1810ae68f2ed0baf00e52 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 10:13:57 +0200 Subject: [PATCH 4/9] Generalize thread-local state, and cache the device too. --- src/accumulate.jl | 2 +- src/compiler/execution.jl | 2 +- src/initialization.jl | 4 +-- src/state.jl | 66 ++++++++++++++++++++++++++------------- test/initialization.jl | 12 ++++++- 5 files changed, 60 insertions(+), 26 deletions(-) diff --git a/src/accumulate.jl b/src/accumulate.jl index cde18ae07d..39a37663c9 100644 --- a/src/accumulate.jl +++ b/src/accumulate.jl @@ -153,7 +153,7 @@ function scan!(f::Function, output::CuArray{T}, input::CuArray; # determine the grid layout to cover the other dimensions if length(Rother) > 1 - dev = CuDevice(kernel.fun.mod.ctx) + dev = device() max_other_blocks = attribute(dev, DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y) blocks_other = (Base.min(length(Rother), max_other_blocks), cld(length(Rother), max_other_blocks)) diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 51cee63c92..6611b141c7 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -302,7 +302,7 @@ end function _cufunction(source::FunctionSpec; kwargs...) # compile to PTX ctx = context() - dev = CuDevice(ctx) + dev = device() cap = supported_capability(dev) target = PTXCompilerTarget(; cap=supported_capability(dev), kwargs...) params = CUDACompilerParams() diff --git a/src/initialization.jl b/src/initialization.jl index 2d707ad692..6c5fa7b9a3 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -67,8 +67,8 @@ function __init__() # enable generation of FMA instructions to mimic behavior of nvcc LLVM.clopts("-nvptx-fma-level=1") - resize!(thread_contexts, Threads.nthreads()) - fill!(thread_contexts, nothing) + resize!(thread_state, Threads.nthreads()) + fill!(thread_state, nothing) resize!(thread_tasks, Threads.nthreads()) fill!(thread_tasks, nothing) diff --git a/src/state.jl b/src/state.jl index 3ed0a0def6..d385c756da 100644 --- a/src/state.jl +++ b/src/state.jl @@ -1,6 +1,6 @@ # global state management -export context, context!, device!, device_reset! +export context, context!, device, device!, device_reset! ## initialization @@ -22,15 +22,16 @@ proper invalidation. """ @inline function prepare_cuda_call() tid = Threads.threadid() - task = current_task() # detect when a different task is now executing on a thread - if @inbounds thread_tasks[tid].value::Task !== task - switched_tasks(tid, task) + task = @inbounds thread_tasks[tid] + if task === nothing || task.value === nothing || task.value::Task !== current_task() + switched_tasks(tid, current_task()) end - # initialize a CUDA context when first executing on a thread - if @inbounds thread_contexts[tid] === nothing + # initialize CUDA state when first executing on a thread + state = @inbounds thread_state[tid] + if state === nothing initialize_thread(tid) end @@ -46,9 +47,10 @@ end # this setting won't be used when switching tasks on a pre-initialized thread. const default_device = Ref{Union{Nothing,CuDevice}}(nothing) -# CUDA uses thread-bound contexts, but calling CuCurrentContext all the time is expensive, -# so we maintain our own thread-local state keeping track of the current context. -const thread_contexts = Union{Nothing,CuContext}[] +# CUDA uses thread-bound state, but calling CuCurrent* all the time is expensive, +# so we maintain our own thread-local copy keeping track of the current CUDA state. +CuCurrentState = NamedTuple{(:ctx, :dev), Tuple{CuContext,CuDevice}} +const thread_state = Union{Nothing,CuCurrentState}[] @noinline function initialize_thread(tid::Int) ctx = CuCurrentContext() if ctx === nothing @@ -56,7 +58,8 @@ const thread_contexts = Union{Nothing,CuContext}[] device!(dev) else # compatibility with externally-initialized contexts - thread_contexts[tid] = ctx + dev = CuCurrentDevice() + thread_state[tid] = (;ctx,dev) end end @@ -103,13 +106,12 @@ current thread). tid = Threads.threadid() prepare_cuda_call() - ctx = @inbounds thread_contexts[tid]::CuContext + state = @inbounds thread_state[tid]::CuCurrentState if Base.JLOptions().debug_level >= 2 - @assert ctx == CuCurrentContext() + @assert state.ctx == CuCurrentContext() end - - ctx + state.ctx end """ @@ -123,10 +125,11 @@ Note that the contexts used with this call should be previously acquired by call function context!(ctx::CuContext) # update the thread-local state tid = Threads.threadid() - thread_ctx = @inbounds thread_contexts[tid] - if thread_ctx != ctx - thread_contexts[tid] = ctx + state = @inbounds thread_state[tid] + if state === nothing || state.ctx != ctx activate(ctx) + dev = CuCurrentDevice() + thread_state[tid] = (;ctx, dev) _atcontextswitch(tid, ctx) end @@ -168,9 +171,29 @@ atcontextswitch(f::Function) = (pushfirst!(context_hooks, f); nothing) const context_hooks = [] _atcontextswitch(tid, ctx) = foreach(f->Base.invokelatest(f, tid, ctx), context_hooks) +# TODO: atdeviceswitch && atdevicereset make more sense + ## device-based API +""" + device()::CuDevice + +Get the CUDA device for the current thread, similar to how [`context()`](@ref) works +compared to [`CuCurrentContext()`](@ref). +""" +@inline function device() + tid = Threads.threadid() + + prepare_cuda_call() + state = @inbounds thread_state[tid]::CuCurrentState + + if Base.JLOptions().debug_level >= 2 + @assert state.dev == CuCurrentDevice() + end + state.dev +end + """ device!(dev::Integer) device!(dev::CuDevice) @@ -196,7 +219,8 @@ function device!(dev::CuDevice, flags=nothing) end # bail out if switching to the current device - if @inbounds thread_contexts[tid] !== nothing && dev == device() + state = @inbounds thread_state[tid] + if state !== nothing && state.dev == dev return end @@ -242,9 +266,9 @@ function device_reset!(dev::CuDevice=device()) unsafe_reset!(pctx) # wipe the context handles for all threads using this device - for (tid, thread_ctx) in enumerate(thread_contexts) - if thread_ctx == ctx - thread_contexts[tid] = nothing + for (tid, state) in enumerate(thread_state) + if state !== nothing && state.ctx == ctx + thread_state[tid] = nothing _atcontextswitch(tid, nothing) end end diff --git a/test/initialization.jl b/test/initialization.jl index 8e837fa445..99533ae41d 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -17,8 +17,9 @@ end # now cause initialization ctx = context() +dev = device() @test CuCurrentContext() == ctx -@test CuCurrentDevice() == CuDevice(0) +@test CuCurrentDevice() == dev @test context_cb[1] == ctx @test task_cb[1] == current_task() @@ -33,6 +34,15 @@ end @test context_cb[1] == nothing @test task_cb[1] == task +fill!(context_cb, nothing) +fill!(task_cb, nothing) + +# ... back to the main task +ctx = context() +dev = device() +@test context_cb[1] == nothing +@test task_cb[1] == current_task() + device!(CuDevice(0)) device!(CuDevice(0)) do nothing From 4d7084d52f7c6f51ebc26fa0d8156b15591d3f0b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 10:20:00 +0200 Subject: [PATCH 5/9] Extend the integer-based initialization API. --- lib/cudadrv/devices.jl | 4 +++- src/state.jl | 10 ++++++++-- test/cudadrv/devices.jl | 2 ++ test/initialization.jl | 2 ++ 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/cudadrv/devices.jl b/lib/cudadrv/devices.jl index a3c2b37d43..94eca9e886 100644 --- a/lib/cudadrv/devices.jl +++ b/lib/cudadrv/devices.jl @@ -89,7 +89,7 @@ end ## device iteration -export devices +export devices, ndevices struct DeviceSet end @@ -114,6 +114,8 @@ end Base.IteratorSize(::DeviceSet) = Base.HasLength() +ndevices() = length(devices()) + ## convenience attribute getters diff --git a/src/state.jl b/src/state.jl index d385c756da..cf276e1d80 100644 --- a/src/state.jl +++ b/src/state.jl @@ -1,6 +1,6 @@ # global state management -export context, context!, device, device!, device_reset! +export context, context!, device, device!, device_reset!, deviceid ## initialization @@ -249,7 +249,6 @@ function device!(f::Function, dev::CuDevice) end end end -device!(f::Function, dev::Integer) = device!(f, CuDevice(dev)) """ device_reset!(dev::CuDevice=device()) @@ -275,3 +274,10 @@ function device_reset!(dev::CuDevice=device()) return end + + +## integer device-based API + +deviceid() = Int(convert(CUdevice, device())) + +device!(f::Function, dev::Integer) = device!(f, CuDevice(dev)) diff --git a/test/cudadrv/devices.jl b/test/cudadrv/devices.jl index 1908413398..6a24292838 100644 --- a/test/cudadrv/devices.jl +++ b/test/cudadrv/devices.jl @@ -13,3 +13,5 @@ capability(dev) @test eltype(devices()) == CuDevice @grab_output show(stdout, "text/plain", CUDA.DEVICE_CPU) @grab_output show(stdout, "text/plain", CUDA.DEVICE_INVALID) + +@test length(devices()) == ndevices() diff --git a/test/initialization.jl b/test/initialization.jl index 99533ae41d..e832c8427d 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -110,3 +110,5 @@ if length(devices()) > 1 end @test device() == CuDevice(0) end + +@test deviceid() >= 0 From 46ffc30b625287681408c4973600200282374bc7 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 11:06:28 +0200 Subject: [PATCH 6/9] Split atcontextswitch in atdeviceswitch and atdevicereset. --- docs/src/api/essentials.md | 3 +- lib/cublas/CUBLAS.jl | 6 ++- lib/cudnn/CUDNN.jl | 6 ++- lib/curand/CURAND.jl | 6 ++- lib/cusolver/CUSOLVER.jl | 6 ++- lib/cusparse/CUSPARSE.jl | 6 ++- lib/cutensor/CUTENSOR.jl | 6 ++- src/state.jl | 92 +++++++++++++++++++++----------------- test/initialization.jl | 44 ++++++++++++------ 9 files changed, 109 insertions(+), 66 deletions(-) diff --git a/docs/src/api/essentials.md b/docs/src/api/essentials.md index 90c32db144..e751c1cae2 100644 --- a/docs/src/api/essentials.md +++ b/docs/src/api/essentials.md @@ -26,5 +26,6 @@ react to context or task switches: ```@docs CUDA.attaskswitch -CUDA.atcontextswitch +CUDA.atdeviceswitch +CUDA.atdevicereset ``` diff --git a/lib/cublas/CUBLAS.jl b/lib/cublas/CUBLAS.jl index be0c9404a7..c83cc76d74 100644 --- a/lib/cublas/CUBLAS.jl +++ b/lib/cublas/CUBLAS.jl @@ -82,12 +82,14 @@ function __init__() resize!(thread_xt_handles, Threads.nthreads()) fill!(thread_xt_handles, nothing) - CUDA.atcontextswitch() do tid, ctx + CUDA.atdeviceswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing thread_xt_handles[tid] = nothing end - CUDA.attaskswitch() do tid, task + CUDA.attaskswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing thread_xt_handles[tid] = nothing end diff --git a/lib/cudnn/CUDNN.jl b/lib/cudnn/CUDNN.jl index efb7ffe437..e1bf4e4b2a 100644 --- a/lib/cudnn/CUDNN.jl +++ b/lib/cudnn/CUDNN.jl @@ -59,11 +59,13 @@ function __init__() resize!(thread_handles, Threads.nthreads()) fill!(thread_handles, nothing) - CUDA.atcontextswitch() do tid, ctx + CUDA.atdeviceswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing end - CUDA.attaskswitch() do tid, task + CUDA.attaskswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing end end diff --git a/lib/curand/CURAND.jl b/lib/curand/CURAND.jl index 6c458844bb..4ca06b3b1b 100644 --- a/lib/curand/CURAND.jl +++ b/lib/curand/CURAND.jl @@ -59,12 +59,14 @@ function __init__() resize!(GPUARRAY_THREAD_RNGs, Threads.nthreads()) fill!(GPUARRAY_THREAD_RNGs, nothing) - CUDA.atcontextswitch() do tid, ctx + CUDA.atdeviceswitch() do + tid = Threads.threadid() CURAND_THREAD_RNGs[tid] = nothing GPUARRAY_THREAD_RNGs[tid] = nothing end - CUDA.attaskswitch() do tid, task + CUDA.attaskswitch() do + tid = Threads.threadid() CURAND_THREAD_RNGs[tid] = nothing GPUARRAY_THREAD_RNGs[tid] = nothing end diff --git a/lib/cusolver/CUSOLVER.jl b/lib/cusolver/CUSOLVER.jl index 0e526c8aaf..3f3e688887 100644 --- a/lib/cusolver/CUSOLVER.jl +++ b/lib/cusolver/CUSOLVER.jl @@ -72,12 +72,14 @@ function __init__() resize!(thread_sparse_handles, Threads.nthreads()) fill!(thread_sparse_handles, nothing) - CUDA.atcontextswitch() do tid, ctx + CUDA.atdeviceswitch() do + tid = Threads.threadid() thread_dense_handles[tid] = nothing thread_sparse_handles[tid] = nothing end - CUDA.attaskswitch() do tid, task + CUDA.attaskswitch() do + tid = Threads.threadid() thread_dense_handles[tid] = nothing thread_sparse_handles[tid] = nothing end diff --git a/lib/cusparse/CUSPARSE.jl b/lib/cusparse/CUSPARSE.jl index 7e12d07ffe..dd4e06728d 100644 --- a/lib/cusparse/CUSPARSE.jl +++ b/lib/cusparse/CUSPARSE.jl @@ -52,11 +52,13 @@ function __init__() resize!(thread_handles, Threads.nthreads()) fill!(thread_handles, nothing) - CUDA.atcontextswitch() do tid, ctx + CUDA.atdeviceswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing end - CUDA.attaskswitch() do tid, task + CUDA.attaskswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing end end diff --git a/lib/cutensor/CUTENSOR.jl b/lib/cutensor/CUTENSOR.jl index 1c3fdc4012..88ac7d8cf9 100644 --- a/lib/cutensor/CUTENSOR.jl +++ b/lib/cutensor/CUTENSOR.jl @@ -42,11 +42,13 @@ function __init__() resize!(thread_handles, Threads.nthreads()) fill!(thread_handles, nothing) - CUDA.atcontextswitch() do tid, ctx + CUDA.atdeviceswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing end - CUDA.attaskswitch() do tid, task + CUDA.attaskswitch() do + tid = Threads.threadid() thread_handles[tid] = nothing end end diff --git a/src/state.jl b/src/state.jl index cf276e1d80..6f284ef5cf 100644 --- a/src/state.jl +++ b/src/state.jl @@ -3,6 +3,44 @@ export context, context!, device, device!, device_reset!, deviceid +## hooks + +""" + CUDA.attaskswitch(f::Function) + +Register a function to be called after switching to or initializing a task on a thread. + +Use this hook to invalidate thread-local state that depends on the current task. +""" +attaskswitch(f::Function) = (pushfirst!(task_hooks, f); nothing) +const task_hooks = [] +_attaskswitch() = foreach(f->Base.invokelatest(f), task_hooks) + +""" + CUDA.atdeviceswitch(f::Function) + +Register a function to be called after switching to or initializing a device on a thread. + +Use this hook to invalidate thread-local state that depends on the current device. If that +state is also context dependent, be sure to query the context in your callback. +""" +atdeviceswitch(f::Function) = (pushfirst!(device_switch_hooks, f); nothing) +const device_switch_hooks = [] +_atdeviceswitch() = foreach(f->Base.invokelatest(f), device_switch_hooks) + +""" + CUDA.atdevicereset(f::Function) + +Register a function to be called after resetting devices. The function is passed one +argument: the device which has been reset. + +Use this hook to invalidate global state that depends on the current device. +""" +atdevicereset(f::Function) = (pushfirst!(device_reset_hooks, f); nothing) +const device_reset_hooks = [] +_atdevicereset(dev) = foreach(f->Base.invokelatest(f, dev), device_reset_hooks) + + ## initialization """ @@ -17,7 +55,7 @@ between) different threads. To synchronize these two worlds, call this function CUDA API call to update thread-local state based on the current task and its context. If you need to maintain your own thread-local state, subscribe to context and task switch -events using [`CUDA.atcontextswitch`](@ref) and [`CUDA.attaskswitch`](@ref) for +events using [`CUDA.atdeviceswitch`](@ref) and [`CUDA.attaskswitch`](@ref) for proper invalidation. """ @inline function prepare_cuda_call() @@ -69,7 +107,7 @@ end const thread_tasks = Union{Nothing,WeakRef}[] @noinline function switched_tasks(tid::Int, task::Task) thread_tasks[tid] = WeakRef(task) - _attaskswitch(tid, task) + _attaskswitch() # switch contexts if task switched to was already bound to one ctx = get(task_local_storage(), :CuContext, nothing) @@ -80,18 +118,6 @@ const thread_tasks = Union{Nothing,WeakRef}[] # but that confuses CUDA and leads to invalid contexts later on. end -""" - CUDA.attaskswitch(f::Function) - -Register a function to be called after switching tasks on a thread. The function is passed -two arguments: the thread ID, and the task switched to. - -Use this hook to invalidate thread-local state that depends on the current task. -""" -attaskswitch(f::Function) = (pushfirst!(task_hooks, f); nothing) -const task_hooks = [] -_attaskswitch(tid, task) = foreach(f->Base.invokelatest(f, tid, task), task_hooks) - ## context-based API @@ -120,7 +146,7 @@ end Bind the current host thread to the context `ctx`. Note that the contexts used with this call should be previously acquired by calling -[`context`](@ref), and not arbirary contexts created by calling the `CuContext` constructor. +[`context`](@ref), and not arbitrary contexts created by calling the `CuContext` constructor. """ function context!(ctx::CuContext) # update the thread-local state @@ -130,7 +156,7 @@ function context!(ctx::CuContext) activate(ctx) dev = CuCurrentDevice() thread_state[tid] = (;ctx, dev) - _atcontextswitch(tid, ctx) + _atdeviceswitch() end # update the task-local state @@ -156,23 +182,6 @@ function context!(f::Function, ctx::CuContext) end end -""" - CUDA.atcontextswitch(f::Function) - -Register a function to be called after switching contexts on a thread. The function is -passed two arguments: the thread ID, and the context switched to. - -If the new context is `nothing`, this indicates that the context is being unbound from this -thread (typically during device reset). - -Use this hook to invalidate thread-local state that depends on the current device or context. -""" -atcontextswitch(f::Function) = (pushfirst!(context_hooks, f); nothing) -const context_hooks = [] -_atcontextswitch(tid, ctx) = foreach(f->Base.invokelatest(f, tid, ctx), context_hooks) - -# TODO: atdeviceswitch && atdevicereset make more sense - ## device-based API @@ -206,7 +215,7 @@ for initial set-up of the environment. If you need to switch devices on a regula work with contexts instead and call [`context!`](@ref) directly (5-10ns). If your library or code needs to perform an action when the active context changes, -add a hook using [`CUDA.atcontextswitch`](@ref). +add a hook using [`CUDA.atdeviceswitch`](@ref). """ function device!(dev::CuDevice, flags=nothing) tid = Threads.threadid() @@ -255,23 +264,26 @@ end Reset the CUDA state associated with a device. This call with release the underlying context, at which point any objects allocated in that context will be invalidated. + +If your library or code needs to perform an action when the active context changes, +add a hook using [`CUDA.atdevicereset`](@ref). Resetting the device will also cause +subsequent API calls to fire the [`CUDA.atdeviceswitch`](@ref) hook. """ function device_reset!(dev::CuDevice=device()) - pctx = CuPrimaryContext(dev) - ctx = CuContext(pctx) - # unconditionally reset the primary context (don't just release it), # as there might be users outside of CUDA.jl + pctx = CuPrimaryContext(dev) unsafe_reset!(pctx) - # wipe the context handles for all threads using this device + # wipe the thread-local state for all threads using this device for (tid, state) in enumerate(thread_state) - if state !== nothing && state.ctx == ctx + if state !== nothing && state.dev == dev thread_state[tid] = nothing - _atcontextswitch(tid, nothing) end end + _atdevicereset(dev) + return end diff --git a/test/initialization.jl b/test/initialization.jl index e832c8427d..01158b8278 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -5,14 +5,25 @@ @test CuCurrentContext() == nothing @test CuCurrentDevice() == nothing -context_cb = Union{Nothing, CuContext}[nothing for tid in 1:Threads.nthreads()] -CUDA.atcontextswitch() do tid, ctx - context_cb[tid] = ctx +task_cb = Any[nothing for tid in 1:Threads.nthreads()] +CUDA.attaskswitch() do + task_cb[Threads.threadid()] = current_task() end -task_cb = Union{Nothing, Task}[nothing for tid in 1:Threads.nthreads()] -CUDA.attaskswitch() do tid, task - task_cb[tid] = task +device_switch_cb = Any[nothing for tid in 1:Threads.nthreads()] +CUDA.atdeviceswitch() do + device_switch_cb[Threads.threadid()] = (dev=device(), ctx=context()) +end + +device_reset_cb = Any[nothing for tid in 1:Threads.nthreads()] +CUDA.atdevicereset() do dev + device_reset_cb[Threads.threadid()] = dev +end + +function reset_cb() + fill!(task_cb, nothing) + fill!(device_switch_cb, nothing) + fill!(device_reset_cb, nothing) end # now cause initialization @@ -20,28 +31,27 @@ ctx = context() dev = device() @test CuCurrentContext() == ctx @test CuCurrentDevice() == dev -@test context_cb[1] == ctx @test task_cb[1] == current_task() +@test device_switch_cb[1].ctx == ctx +@test device_switch_cb[1].dev == dev -fill!(context_cb, nothing) -fill!(task_cb, nothing) +reset_cb() # ... on a different task task = @async begin context() end @test ctx == fetch(task) -@test context_cb[1] == nothing @test task_cb[1] == task +@test device_switch_cb[1] == nothing -fill!(context_cb, nothing) -fill!(task_cb, nothing) +reset_cb() # ... back to the main task ctx = context() dev = device() -@test context_cb[1] == nothing @test task_cb[1] == current_task() +@test device_switch_cb[1] == nothing device!(CuDevice(0)) device!(CuDevice(0)) do @@ -55,9 +65,17 @@ end @test_throws AssertionError device!(0, CUDA.CU_CTX_SCHED_YIELD) +reset_cb() + device_reset!() +@test device_reset_cb[1] == CuDevice(0) + +reset_cb() + device!(0, CUDA.CU_CTX_SCHED_YIELD) +@test task_cb[1] == nothing +@test device_switch_cb[1].dev == CuDevice(0) # test the device selection functionality if length(devices()) > 1 From 8f04df3fc2278a8e9e6380eb42df683eb4755711 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 11:56:25 +0200 Subject: [PATCH 7/9] Compatibility with older Julia. --- src/state.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/state.jl b/src/state.jl index 6f284ef5cf..391cded075 100644 --- a/src/state.jl +++ b/src/state.jl @@ -97,7 +97,7 @@ const thread_state = Union{Nothing,CuCurrentState}[] else # compatibility with externally-initialized contexts dev = CuCurrentDevice() - thread_state[tid] = (;ctx,dev) + thread_state[tid] = (;ctx=ctx, dev=dev) end end @@ -155,7 +155,7 @@ function context!(ctx::CuContext) if state === nothing || state.ctx != ctx activate(ctx) dev = CuCurrentDevice() - thread_state[tid] = (;ctx, dev) + thread_state[tid] = (;ctx=ctx, dev=dev) _atdeviceswitch() end From dc72fd55b8bbbcb37a3db15d356f6952634bf18b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 12:54:57 +0200 Subject: [PATCH 8/9] NFC clean-ups. [ci skip] --- src/state.jl | 48 +++++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/src/state.jl b/src/state.jl index 391cded075..b923693e9d 100644 --- a/src/state.jl +++ b/src/state.jl @@ -54,9 +54,11 @@ Execution can switch between them, and tasks can be executing on (and in the fut between) different threads. To synchronize these two worlds, call this function before any CUDA API call to update thread-local state based on the current task and its context. -If you need to maintain your own thread-local state, subscribe to context and task switch +If you need to maintain your own task-local state, subscribe to device and task switch events using [`CUDA.atdeviceswitch`](@ref) and [`CUDA.attaskswitch`](@ref) for -proper invalidation. +proper invalidation. If your state is device-specific, but global (i.e. not task-bound), it +suffices to index your state with the current [`deviceid()`](@ref) and invalidate that state +when the device is reset by subscribing to [`CUDA.atdevicereset()`](@ref). """ @inline function prepare_cuda_call() tid = Threads.threadid() @@ -80,6 +82,23 @@ proper invalidation. return end +# Julia executes with tasks, so we need to keep track of the active task for each thread +# in order to detect task switches and update the thread-local state accordingly. +# doing so using task_local_storage is too expensive. +const thread_tasks = Union{Nothing,WeakRef}[] +@noinline function switched_tasks(tid::Int, task::Task) + thread_tasks[tid] = WeakRef(task) + _attaskswitch() + + # switch contexts if task switched to was already bound to one + ctx = get(task_local_storage(), :CuContext, nothing) + if ctx !== nothing + context!(ctx) + end + # NOTE: deactivating the context in the case ctx===nothing would be more correct, + # but that confuses CUDA and leads to invalid contexts later on. +end + # the default device unitialized tasks will use, set when switching devices. # this behavior differs from the CUDA Runtime, where device 0 is always used. # this setting won't be used when switching tasks on a pre-initialized thread. @@ -87,7 +106,7 @@ const default_device = Ref{Union{Nothing,CuDevice}}(nothing) # CUDA uses thread-bound state, but calling CuCurrent* all the time is expensive, # so we maintain our own thread-local copy keeping track of the current CUDA state. -CuCurrentState = NamedTuple{(:ctx, :dev), Tuple{CuContext,CuDevice}} +const CuCurrentState = NamedTuple{(:ctx, :dev), Tuple{CuContext,CuDevice}} const thread_state = Union{Nothing,CuCurrentState}[] @noinline function initialize_thread(tid::Int) ctx = CuCurrentContext() @@ -101,23 +120,6 @@ const thread_state = Union{Nothing,CuCurrentState}[] end end -# Julia executes with tasks, so we need to keep track of the active task for each thread -# in order to detect task switches and update the thread-local state accordingly. -# doing so using task_local_storage is too expensive. -const thread_tasks = Union{Nothing,WeakRef}[] -@noinline function switched_tasks(tid::Int, task::Task) - thread_tasks[tid] = WeakRef(task) - _attaskswitch() - - # switch contexts if task switched to was already bound to one - ctx = get(task_local_storage(), :CuContext, nothing) - if ctx !== nothing - context!(ctx) - end - # NOTE: deactivating the context in the case ctx===nothing would be more correct, - # but that confuses CUDA and leads to invalid contexts later on. -end - ## context-based API @@ -248,13 +250,13 @@ device!(dev::Integer, flags=nothing) = device!(CuDevice(dev), flags) Sets the active device for the duration of `f`. """ function device!(f::Function, dev::CuDevice) - old_ctx = CuCurrentContext() + ctx = CuCurrentContext() try device!(dev) f() finally - if old_ctx != nothing - context!(old_ctx) + if ctx != nothing + context!(ctx) end end end From a96a97874a8bcf81eb5fa3ada7a4d7b84a179cc1 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 7 Aug 2020 14:39:07 +0200 Subject: [PATCH 9/9] Make deviceid 1-indexed. --- src/state.jl | 10 +++++++++- test/initialization.jl | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/state.jl b/src/state.jl index b923693e9d..d5f4d5ffd3 100644 --- a/src/state.jl +++ b/src/state.jl @@ -292,6 +292,14 @@ end ## integer device-based API -deviceid() = Int(convert(CUdevice, device())) +""" + deviceid(dev::CuDevice=device())::Int + deviceid()::Int + +Get the ID number of the current device of execution. This is a 1-indexed number, and can +be used to index, e.g., thread-local state. It should not be used to acquire a `CuDevice`; +use [`device()`](@ref) for that. +""" +deviceid(dev::CuDevice=device()) = Int(convert(CUdevice, dev)) + 1 device!(f::Function, dev::Integer) = device!(f, CuDevice(dev)) diff --git a/test/initialization.jl b/test/initialization.jl index 01158b8278..a174f1aad0 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -129,4 +129,4 @@ if length(devices()) > 1 @test device() == CuDevice(0) end -@test deviceid() >= 0 +@test deviceid() >= 1