Skip to content

Commit

Permalink
Rework context handling (#2346)
Browse files Browse the repository at this point in the history
Change CuContext to a simple immutable struct. This removes the ability
to identify context sessions by simply looking at the context object,
which can be restored on CUDA 12+ using the cuCtxGetId API.

As a consequence, it is no longer safe to reset the primary context
on drivers below CUDA 12.
  • Loading branch information
maleadt committed Apr 26, 2024
1 parent dc8985b commit 752571b
Show file tree
Hide file tree
Showing 24 changed files with 312 additions and 290 deletions.
170 changes: 67 additions & 103 deletions lib/cudadrv/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

export
CuPrimaryContext, CuContext, current_context, has_context, activate,
unsafe_reset!, isactive, flags, setflags!,
unsafe_reset!, isactive, flags, setflags!, unique_id, api_version,
device, device_synchronize


Expand All @@ -20,93 +20,70 @@ the system cleans up the resources allocated to it.
When you are done using the context, call [`CUDA.unsafe_destroy!`](@ref) to mark it for
deletion, or use do-block syntax with this constructor.
"""
mutable struct CuContext
struct CuContext
handle::CUcontext
valid::Bool
id::UInt64

function CuContext(dev::CuDevice, flags=0)
handle_ref = Ref{CUcontext}()
cuCtxCreate_v2(handle_ref, flags, dev)
UniqueCuContext(handle_ref[])
end
function CuContext(handle::CUcontext)
handle == C_NULL && throw(UndefRefError())

global function current_context()
handle_ref = Ref{CUcontext}()
cuCtxGetCurrent(handle_ref)
handle_ref[] == C_NULL && throw(UndefRefError())
UniqueCuContext(handle_ref[])
end
id = if driver_version() >= v"12"
id_ref = Ref{Culonglong}()
res = unchecked_cuCtxGetId(handle, id_ref)
res == ERROR_CONTEXT_IS_DESTROYED && throw(UndefRefError())
res != SUCCESS && throw_api_error(res)
id_ref[]
else
typemax(UInt64)
end

global UnsafeCuContext(handle::CUcontext) = new(handle, true)
new(handle, id)
end
end

unsafe
function CuContext(dev::CuDevice, flags=0)
handle_ref = Ref{CUcontext}()
cuCtxCreate_v2(handle_ref, flags, dev)
CuContext(handle_ref[])
end

"""
current_context()
Returns the current context.
Returns the current context. Throws an undefined reference error if the current thread
has no context bound to it, or if the bound context has been destroyed.
!!! warning
This is a low-level API, returning the current context as known to the CUDA driver.
For most users, it is recommended to use the [`context`](@ref) method instead.
"""
current_context()

"""
has_context()
Returns whether there is an active context.
"""
function has_context()
function current_context()
handle_ref = Ref{CUcontext}()
cuCtxGetCurrent(handle_ref)
handle_ref[] != C_NULL
handle_ref[] == C_NULL && throw(UndefRefError())
CuContext(handle_ref[])
end

# we need to know when a context has been destroyed, to make sure we don't destroy resources
# after the owning context has been destroyed already. this is complicated by the fact that
# contexts obtained from a primary context have the same handle before and after primary
# context destruction, so we cannot use a simple mapping from context handle to a validity
# bit. instead, we unique the context objects and put a validity bit in there.
isvalid(ctx::CuContext) = ctx.valid
function invalidate!(ctx::CuContext)
ctx.valid = false
return
end
# to make this work, every function returning a context (e.g. `cuCtxGetCurrent`, attribute
# functions, etc) need to return the same context objects. because looking up a context is a
# very common operation (often executed from finalizers), we need to ensure this look-up is
# fast and does not switch tasks. we do this by scanning a simple linear vector.
const MAX_CONTEXTS = 1024
const context_objects = Vector{CuContext}(undef, MAX_CONTEXTS)
const context_lock = Base.ThreadSynchronizer()
function UniqueCuContext(handle::CUcontext)
@lock context_lock begin
# look if there's an existing object for this handle
i = 1
@inbounds while i <= MAX_CONTEXTS && isassigned(context_objects, i)
if context_objects[i].handle == handle
if isvalid(context_objects[i])
return context_objects[i]
else
# this object was invalidated, so we can reuse its slot
break
end
end
i += 1
end
if i == MAX_CONTEXTS
error("Exceeded maximum amount of CUDA contexts. This is unexpected; please file an issue.")
end
function isvalid(ctx::CuContext)
# we first try an API call to see if the context handle is usable
if driver_version() >= v"12"
id_ref = Ref{Culonglong}()
res = unchecked_cuCtxGetId(ctx, id_ref)
res == ERROR_CONTEXT_IS_DESTROYED && return false
res != SUCCESS && throw_api_error(res)

# detect handle reuse, which happens when destroying and re-creating a context, by
# looking at the context's unique ID (which does change on re-creation)
return ctx.id == id_ref[]
else
version_ref = Ref{Cuint}()
res = unchecked_cuCtxGetApiVersion(ctx, version_ref)
res == ERROR_INVALID_CONTEXT && return false

# we've got a slot we can write to
new_object = UnsafeCuContext(handle)
@inbounds context_objects[i] = new_object
return new_object
# we can't detect handle reuse, so we just assume the context is valid
return true
end
end

Expand All @@ -119,27 +96,21 @@ respect any users of the context, and might make other objects unusable.
function unsafe_destroy!(ctx::CuContext)
if isvalid(ctx)
cuCtxDestroy_v2(ctx)
invalidate!(ctx)
end
end

Base.unsafe_convert(::Type{CUcontext}, ctx::CuContext) = ctx.handle

# NOTE: we don't implement `isequal` or `hash` in order to fall back to `===` and `objectid`
# as contexts are unique, and with primary device contexts identical handles might be
# returned after resetting the context (device) and all associated resources.

function Base.show(io::IO, ctx::CuContext)
if ctx.handle != C_NULL
fields = [@sprintf("%p", ctx.handle), @sprintf("instance %x", objectid(ctx))]
if !isvalid(ctx)
push!(fields, "invalidated")
end

print(io, "CuContext(", join(fields, ", "), ")")
else
print(io, "CuContext(NULL)")
fields = [@sprintf("%p", ctx.handle)]
if driver_version() >= v"12"
push!(fields, "id=$(ctx.id)")
end
if !isvalid(ctx)
push!(fields, "destroyed")
end

print(io, "CuContext(", join(fields, ", "), ")")
end


Expand Down Expand Up @@ -181,6 +152,18 @@ function CuContext(f::Function, dev::CuDevice, args...)
end
end

function unique_id(ctx::CuContext)
id_ref = Ref{Culonglong}()
cuCtxGetId(ctx, id_ref)
return id_ref[]
end

function api_version(ctx::CuContext)
version = Ref{Cuint}()
cuCtxGetApiVersion(ctx, version)
return version[]
end


## primary context management

Expand All @@ -196,11 +179,6 @@ struct CuPrimaryContext
dev::CuDevice
end

# we need to keep track of contexts derived from primary contexts,
# so that we can invalidate them when the primary context is reset.
const derived_contexts = Dict{CuPrimaryContext,CuContext}()
const derived_lock = ReentrantLock()

"""
CuContext(pctx::CuPrimaryContext)
Expand All @@ -215,9 +193,7 @@ by using the `do`-block syntax.
function CuContext(pctx::CuPrimaryContext)
handle_ref = Ref{CUcontext}()
cuDevicePrimaryCtxRetain(handle_ref, pctx.dev)
ctx = UniqueCuContext(handle_ref[])
Base.@lock derived_lock derived_contexts[pctx] = ctx
return ctx
CuContext(handle_ref[])
end

function CuContext(f::Function, pctx::CuPrimaryContext)
Expand All @@ -237,18 +213,12 @@ does not respect any users of the context, and might make other objects unusable
"""
function unsafe_release!(pctx::CuPrimaryContext)
if driver_version() >= v"11"
cuDevicePrimaryCtxRelease_v2(dev)
cuDevicePrimaryCtxRelease_v2(pctx.dev)
else
cuDevicePrimaryCtxRelease(dev)
cuDevicePrimaryCtxRelease(pctx.dev)
end

# if this releases the last reference, invalidate all derived contexts
if !isactive(pctx)
ctx = @lock derived_lock get(derived_contexts, pctx, nothing)
if ctx !== nothing
invalidate!(ctx)
end
end
return
end

"""
Expand All @@ -265,12 +235,6 @@ function unsafe_reset!(pctx::CuPrimaryContext)
cuDevicePrimaryCtxReset(pctx.dev)
end

# invalidate all derived contexts
ctx = @lock derived_lock get(derived_contexts, pctx, nothing)
if ctx !== nothing
invalidate!(ctx)
end

return
end

Expand Down
17 changes: 0 additions & 17 deletions lib/cudadrv/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,6 @@ Returns the current device.
"""
current_device()

"""
has_device()
Returns whether there is an active device.
"""
function has_device()
device_ref = Ref{CUdevice}()
res = unchecked_cuCtxGetDevice(device_ref)
if res == SUCCESS
return true
elseif res == ERROR_INVALID_CONTEXT
return false
else
throw_api_error(res)
end
end

const DEVICE_CPU = _CuDevice(CUdevice(-1))
const DEVICE_INVALID = _CuDevice(CUdevice(-2))

Expand Down
5 changes: 1 addition & 4 deletions lib/cudadrv/libcuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end
end
end

function check(f)
@inline function check(f)
res = f()
if res != SUCCESS
throw_api_error(res)
Expand Down Expand Up @@ -3269,7 +3269,6 @@ end
end

@checked function cuCtxCreate_v3(pctx, paramsArray, numParams, flags, dev)
initialize_context()
@gcsafe_ccall libcuda.cuCtxCreate_v3(pctx::Ptr{CUcontext},
paramsArray::Ptr{CUexecAffinityParam},
numParams::Cint, flags::Cuint,
Expand Down Expand Up @@ -3299,7 +3298,6 @@ end
end

@checked function cuCtxGetId(ctx, ctxId)
initialize_context()
@gcsafe_ccall libcuda.cuCtxGetId(ctx::CUcontext, ctxId::Ptr{Culonglong})::CUresult
end

Expand Down Expand Up @@ -3329,7 +3327,6 @@ end
end

@checked function cuCtxGetApiVersion(ctx, version)
initialize_context()
@gcsafe_ccall libcuda.cuCtxGetApiVersion(ctx::CUcontext, version::Ptr{Cuint})::CUresult
end

Expand Down
7 changes: 4 additions & 3 deletions lib/cudadrv/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ Base.unsafe_convert(T::Type{<:Union{Ptr,CuPtr,CuArrayPtr}}, mem::AbstractMemory)
Device memory residing on the GPU.
"""
struct DeviceMemory <: AbstractMemory
dev::CuDevice
ctx::CuContext
ptr::CuPtr{Cvoid}
bytesize::Int

async::Bool
end

DeviceMemory() = DeviceMemory(context(), CU_NULL, 0, false)
DeviceMemory() = DeviceMemory(device(), context(), CU_NULL, 0, false)

Base.pointer(mem::DeviceMemory) = mem.ptr
Base.sizeof(mem::DeviceMemory) = mem.bytesize
Expand Down Expand Up @@ -75,7 +76,7 @@ function alloc(::Type{DeviceMemory}, bytesize::Integer;
cuMemAlloc_v2(ptr_ref, bytesize)
end

return DeviceMemory(context(), reinterpret(CuPtr{Cvoid}, ptr_ref[]), bytesize, async)
return DeviceMemory(device(), context(), reinterpret(CuPtr{Cvoid}, ptr_ref[]), bytesize, async)
end

function free(mem::DeviceMemory; stream::Union{Nothing,CuStream}=nothing)
Expand Down Expand Up @@ -796,7 +797,7 @@ end
Identify the context memory was allocated in.
"""
context(ptr::Union{Ptr,CuPtr}) =
UniqueCuContext(attribute(CUcontext, ptr, POINTER_ATTRIBUTE_CONTEXT))
CuContext(attribute(CUcontext, ptr, POINTER_ATTRIBUTE_CONTEXT))

"""
device(ptr)
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/module/global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct CuGlobal{T}
if nbytes_ref[] != sizeof(T)
throw(ArgumentError("size of global '$name' does not match type parameter type $T"))
end
buf = DeviceMemory(context(), ptr_ref[], nbytes_ref[], false)
buf = DeviceMemory(device(), context(), ptr_ref[], nbytes_ref[], false)

return new{T}(buf)
end
Expand Down
Loading

0 comments on commit 752571b

Please sign in to comment.