Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework thread state management #356

Merged
merged 9 commits on
Aug 7, 2020
3 changes: 2 additions & 1 deletion docs/src/api/essentials.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ react to context or task switches:

```@docs
CUDA.attaskswitch
CUDA.atcontextswitch
CUDA.atdeviceswitch
CUDA.atdevicereset
```
6 changes: 4 additions & 2 deletions lib/cublas/CUBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@ function __init__()
resize!(thread_xt_handles, Threads.nthreads())
fill!(thread_xt_handles, nothing)

CUDA.atcontextswitch() do tid, ctx
CUDA.atdeviceswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
thread_xt_handles[tid] = nothing
end

CUDA.attaskswitch() do tid, task
CUDA.attaskswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
thread_xt_handles[tid] = nothing
end
Expand Down
33 changes: 21 additions & 12 deletions lib/cudadrv/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

export
CuContext, CuCurrentContext, activate,
synchronize, device
synchronize, CuCurrentDevice


## construction and destruction
Expand Down Expand Up @@ -125,22 +125,31 @@ end
## context properties

"""
    CuCurrentDevice()

Returns the current device, or `nothing` if there is no active device.
"""
function CuCurrentDevice()
    device_ref = Ref{CUdevice}()
    # Use the unchecked variant so that "no active context" can be detected and
    # reported as `nothing` instead of throwing; any other failure is an error.
    res = unsafe_cuCtxGetDevice(device_ref)
    if res == ERROR_INVALID_CONTEXT
        return nothing
    elseif res != SUCCESS
        throw_api_error(res)
    end
    # The `Bool` argument selects the internal "trusted handle" constructor.
    # NOTE(review): this span in the rendered diff interleaved pre-change lines
    # (a checked `cuCtxGetDevice` call and an unbalanced `pop!(CuContext)`);
    # those were diff residue and have been removed here.
    return CuDevice(Bool, device_ref[])
end
function device()
device_ref = Ref{CUdevice}()
cuCtxGetDevice(device_ref)
return CuDevice(Bool, device_ref[])

"""
    CuDevice(::CuContext)

Returns the device for a context.
"""
function CuDevice(ctx::CuContext)
    # Temporarily activate `ctx` so the current-device query reflects it.
    push!(CuContext, ctx)
    try
        return CuCurrentDevice()
    finally
        # Always restore the previous context, even if the query throws;
        # otherwise the context stack would be left corrupted.
        pop!(CuContext)
    end
end

"""
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/context/primary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ does not respect any users of the context, and might make other objects unusable
"""
function unsafe_release!(ctx::CuContext)
if isvalid(ctx)
dev = device(ctx)
dev = CuDevice(ctx)
pctx = CuPrimaryContext(dev)
if version() >= v"11"
cuDevicePrimaryCtxRelease_v2(dev)
Expand Down
4 changes: 3 additions & 1 deletion lib/cudadrv/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ end

## device iteration

export devices
export devices, ndevices

# Singleton iterator type over the system's CUDA devices (returned by `devices()`).
struct DeviceSet end

Expand All @@ -114,6 +114,8 @@ end

# The device count is known up front, so advertise a length to iterator consumers.
Base.IteratorSize(::DeviceSet) = Base.HasLength()

"""
    ndevices()

Return the number of available CUDA devices.
"""
function ndevices()
    return length(devices())
end


## convenience attribute getters

Expand Down
1 change: 0 additions & 1 deletion lib/cudadrv/libcuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ end
end

@checked function cuCtxGetDevice(device)
initialize_api()
@runtime_ccall((:cuCtxGetDevice, libcuda()), CUresult,
(Ptr{CUdevice},),
device)
Expand Down
6 changes: 2 additions & 4 deletions lib/cudadrv/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ Base.pointer(buf::Buffer) = buf.ptr

Base.sizeof(buf::Buffer) = buf.bytesize

CUDA.device(buf::Buffer) = device(buf.ctx)

# ccall integration
#
# taking the pointer of a buffer means returning the underlying pointer,
Expand Down Expand Up @@ -254,7 +252,7 @@ end
Prefetches memory to the specified destination device.
"""
# Prefetches memory to the specified destination device.
# NOTE(review): the rendered diff left both the pre- and post-change keyword
# lines in place, yielding a duplicate `device` keyword (a syntax error);
# only the post-change default (`CuCurrentDevice()`) is kept here.
function prefetch(buf::UnifiedBuffer, bytes::Integer=sizeof(buf);
                  device::CuDevice=CuCurrentDevice(), stream::CuStream=CuDefaultStream())
    # Reject requests larger than the buffer before touching the driver.
    bytes > sizeof(buf) && throw(BoundsError(buf, bytes))
    CUDA.cuMemPrefetchAsync(buf, bytes, device, stream)
end
Expand All @@ -268,7 +266,7 @@ end
Advise about the usage of a given memory range.
"""
# Advise about the usage of a given memory range.
# NOTE(review): the rendered diff left both the pre- and post-change keyword
# lines in place, yielding a duplicate `device` keyword (a syntax error);
# only the post-change default (`CuCurrentDevice()`) is kept here.
function advise(buf::UnifiedBuffer, advice::CUDA.CUmem_advise, bytes::Integer=sizeof(buf);
                device::CuDevice=CuCurrentDevice())
    # Reject requests larger than the buffer before touching the driver.
    bytes > sizeof(buf) && throw(BoundsError(buf, bytes))
    CUDA.cuMemAdvise(buf, bytes, advice, device)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/occupancy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function occupancy(fun::CuFunction, threads::Integer; shmem::Integer=0)

mod = fun.mod
ctx = mod.ctx
dev = device(ctx)
dev = CuDevice(ctx)

threads_per_sm = attribute(dev, DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR)
warp_size = attribute(dev, DEVICE_ATTRIBUTE_WARP_SIZE)
Expand Down
6 changes: 4 additions & 2 deletions lib/cudnn/CUDNN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ function __init__()
resize!(thread_handles, Threads.nthreads())
fill!(thread_handles, nothing)

CUDA.atcontextswitch() do tid, ctx
CUDA.atdeviceswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
end

CUDA.attaskswitch() do tid, task
CUDA.attaskswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
end
end
Expand Down
6 changes: 4 additions & 2 deletions lib/curand/CURAND.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ function __init__()
resize!(GPUARRAY_THREAD_RNGs, Threads.nthreads())
fill!(GPUARRAY_THREAD_RNGs, nothing)

CUDA.atcontextswitch() do tid, ctx
CUDA.atdeviceswitch() do
tid = Threads.threadid()
CURAND_THREAD_RNGs[tid] = nothing
GPUARRAY_THREAD_RNGs[tid] = nothing
end

CUDA.attaskswitch() do tid, task
CUDA.attaskswitch() do
tid = Threads.threadid()
CURAND_THREAD_RNGs[tid] = nothing
GPUARRAY_THREAD_RNGs[tid] = nothing
end
Expand Down
6 changes: 4 additions & 2 deletions lib/cusolver/CUSOLVER.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ function __init__()
resize!(thread_sparse_handles, Threads.nthreads())
fill!(thread_sparse_handles, nothing)

CUDA.atcontextswitch() do tid, ctx
CUDA.atdeviceswitch() do
tid = Threads.threadid()
thread_dense_handles[tid] = nothing
thread_sparse_handles[tid] = nothing
end

CUDA.attaskswitch() do tid, task
CUDA.attaskswitch() do
tid = Threads.threadid()
thread_dense_handles[tid] = nothing
thread_sparse_handles[tid] = nothing
end
Expand Down
6 changes: 4 additions & 2 deletions lib/cusparse/CUSPARSE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ function __init__()
resize!(thread_handles, Threads.nthreads())
fill!(thread_handles, nothing)

CUDA.atcontextswitch() do tid, ctx
CUDA.atdeviceswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
end

CUDA.attaskswitch() do tid, task
CUDA.attaskswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
end
end
Expand Down
6 changes: 4 additions & 2 deletions lib/cutensor/CUTENSOR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ function __init__()
resize!(thread_handles, Threads.nthreads())
fill!(thread_handles, nothing)

CUDA.atcontextswitch() do tid, ctx
CUDA.atdeviceswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
end

CUDA.attaskswitch() do tid, task
CUDA.attaskswitch() do
tid = Threads.threadid()
thread_handles[tid] = nothing
end
end
Expand Down
1 change: 1 addition & 0 deletions res/wrap/wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ preinit_apicalls = Set{String}([
"cuCtxGetCurrent",
"cuCtxPushCurrent",
"cuCtxPopCurrent",
"cuCtxGetDevice", # this actually does require a context, but we use it unsafely
## primary context management
"cuDevicePrimaryCtxGetState",
"cuDevicePrimaryCtxRelease",
Expand Down
2 changes: 1 addition & 1 deletion src/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ function scan!(f::Function, output::CuArray{T}, input::CuArray;

# determine the grid layout to cover the other dimensions
if length(Rother) > 1
dev = device(kernel.fun.mod.ctx)
dev = device()
max_other_blocks = attribute(dev, DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y)
blocks_other = (Base.min(length(Rother), max_other_blocks),
cld(length(Rother), max_other_blocks))
Expand Down
4 changes: 1 addition & 3 deletions src/compiler/exceptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,14 @@ function create_exceptions!(mod::CuModule)
end

# check the exception flags on every API call, similarly to how CUDA handles errors
# FIXME: this is expensive. Maybe kernels should return a `wait`able object, a la KA.jl,
# which then performs the necessary checks.
function check_exceptions()
for (ctx,buf) in exception_flags
if isvalid(ctx)
ptr = convert(Ptr{Int}, buf)
flag = unsafe_load(ptr)
if flag != 0
unsafe_store!(ptr, 0)
dev = device(ctx)
dev = CuDevice(ctx)
throw(KernelException(dev))
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ end
function _cufunction(source::FunctionSpec; kwargs...)
# compile to PTX
ctx = context()
dev = device(ctx)
dev = device()
cap = supported_capability(dev)
target = PTXCompilerTarget(; cap=supported_capability(dev), kwargs...)
params = CUDACompilerParams()
Expand Down
4 changes: 2 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ function __init__()
# enable generation of FMA instructions to mimic behavior of nvcc
LLVM.clopts("-nvptx-fma-level=1")

resize!(thread_contexts, Threads.nthreads())
fill!(thread_contexts, nothing)
resize!(thread_state, Threads.nthreads())
fill!(thread_state, nothing)

resize!(thread_tasks, Threads.nthreads())
fill!(thread_tasks, nothing)
Expand Down
Loading