Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework context handling #2346

Merged
merged 21 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 42 additions & 106 deletions lib/cudadrv/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
export
CuPrimaryContext, CuContext, current_context, has_context, activate,
unsafe_reset!, isactive, flags, setflags!,
device, device_synchronize
unique_id, device, device_synchronize


## construction and destruction
Expand All @@ -20,94 +20,58 @@ the system cleans up the resources allocated to it.

When you are done using the context, call [`CUDA.unsafe_destroy!`](@ref) to mark it for
deletion, or use do-block syntax with this constructor.

"""
mutable struct CuContext
struct CuContext
handle::CUcontext
valid::Bool
id::Int
maleadt marked this conversation as resolved.
Show resolved Hide resolved

function CuContext(dev::CuDevice, flags=0)
handle_ref = Ref{CUcontext}()
cuCtxCreate_v2(handle_ref, flags, dev)
UniqueCuContext(handle_ref[])
end
function CuContext(handle::CUcontext)
handle == C_NULL && throw(UndefRefError())

global function current_context()
handle_ref = Ref{CUcontext}()
cuCtxGetCurrent(handle_ref)
handle_ref[] == C_NULL && throw(UndefRefError())
UniqueCuContext(handle_ref[])
end
id_ref = Ref{Culonglong}()
res = unchecked_cuCtxGetId(handle, id_ref)
res == ERROR_CONTEXT_IS_DESTROYED && throw(UndefRefError())
res != SUCCESS && throw_api_error(res)

global UnsafeCuContext(handle::CUcontext) = new(handle, true)
new(handle, id_ref[])
end
end

unsafe
function CuContext(dev::CuDevice, flags=0)
handle_ref = Ref{CUcontext}()
cuCtxCreate_v2(handle_ref, flags, dev)
CuContext(handle_ref[])
end

"""
current_context()

Returns the current context.
Returns the current context. Throws an undefined reference error if the current thread
has no context bound to it, or if the bound context has been destroyed.

!!! warning

This is a low-level API, returning the current context as known to the CUDA driver.
For most users, it is recommended to use the [`context`](@ref) method instead.
"""
current_context()

"""
has_context()

Returns whether there is an active context.
"""
function has_context()
global function current_context()
handle_ref = Ref{CUcontext}()
cuCtxGetCurrent(handle_ref)
handle_ref[] != C_NULL
@inline cuCtxGetCurrent(handle_ref) # JuliaGPU/CUDA.jl#2347
handle_ref[] == C_NULL && throw(UndefRefError())
CuContext(handle_ref[])
end

# we need to know when a context has been destroyed, to make sure we don't destroy resources
# after the owning context has been destroyed already. this is complicated by the fact that
# contexts obtained from a primary context have the same handle before and after primary
# context destruction, so we cannot use a simple mapping from context handle to a validity
# bit. instead, we unique the context objects and put a validity bit in there.
isvalid(ctx::CuContext) = ctx.valid
function invalidate!(ctx::CuContext)
ctx.valid = false
return
end
# to make this work, every function returning a context (e.g. `cuCtxGetCurrent`, attribute
# functions, etc) need to return the same context objects. because looking up a context is a
# very common operation (often executed from finalizers), we need to ensure this look-up is
# fast and does not switch tasks. we do this by scanning a simple linear vector.
const MAX_CONTEXTS = 1024
const context_objects = Vector{CuContext}(undef, MAX_CONTEXTS)
const context_lock = Base.ThreadSynchronizer()
function UniqueCuContext(handle::CUcontext)
@lock context_lock begin
# look if there's an existing object for this handle
i = 1
@inbounds while i <= MAX_CONTEXTS && isassigned(context_objects, i)
if context_objects[i].handle == handle
if isvalid(context_objects[i])
return context_objects[i]
else
# this object was invalidated, so we can reuse its slot
break
end
end
i += 1
end
if i == MAX_CONTEXTS
error("Exceeded maximum amount of CUDA contexts. This is unexpected; please file an issue.")
end

# we've got a slot we can write to
new_object = UnsafeCuContext(handle)
@inbounds context_objects[i] = new_object
return new_object
end
function isvalid(ctx::CuContext)
# we first try an API call to see if the context handle is usable
id_ref = Ref{Culonglong}()
res = unchecked_cuCtxGetId(ctx, id_ref)
res == ERROR_CONTEXT_IS_DESTROYED && return false
res != SUCCESS && throw_api_error(res)

# do detect handle reuse, which happens when destroying and re-creating a context,
# we ensure that the id for the current version of this context matches the id we
# saved during construction of the original object
return ctx.id == id_ref[]
end

"""
Expand All @@ -119,27 +83,18 @@ respect any users of the context, and might make other objects unusable.
function unsafe_destroy!(ctx::CuContext)
if isvalid(ctx)
cuCtxDestroy_v2(ctx)
invalidate!(ctx)
end
end

Base.unsafe_convert(::Type{CUcontext}, ctx::CuContext) = ctx.handle

# NOTE: we don't implement `isequal` or `hash` in order to fall back to `===` and `objectid`
# as contexts are unique, and with primary device contexts identical handles might be
# returned after resetting the context (device) and all associated resources.

function Base.show(io::IO, ctx::CuContext)
if ctx.handle != C_NULL
fields = [@sprintf("%p", ctx.handle), @sprintf("instance %x", objectid(ctx))]
if !isvalid(ctx)
push!(fields, "invalidated")
end

print(io, "CuContext(", join(fields, ", "), ")")
else
print(io, "CuContext(NULL)")
fields = [@sprintf("%p", ctx.handle), "id=$(ctx.id)"]
if !isvalid(ctx)
push!(fields, "invalidated")
end

print(io, "CuContext(", join(fields, ", "), ")")
end


Expand Down Expand Up @@ -196,11 +151,6 @@ struct CuPrimaryContext
dev::CuDevice
end

# we need to keep track of contexts derived from primary contexts,
# so that we can invalidate them when the primary context is reset.
const derived_contexts = Dict{CuPrimaryContext,CuContext}()
const derived_lock = ReentrantLock()

"""
CuContext(pctx::CuPrimaryContext)

Expand All @@ -215,9 +165,7 @@ by using the `do`-block syntax.
function CuContext(pctx::CuPrimaryContext)
handle_ref = Ref{CUcontext}()
cuDevicePrimaryCtxRetain(handle_ref, pctx.dev)
ctx = UniqueCuContext(handle_ref[])
Base.@lock derived_lock derived_contexts[pctx] = ctx
return ctx
CuContext(handle_ref[])
end

function CuContext(f::Function, pctx::CuPrimaryContext)
Expand All @@ -237,18 +185,12 @@ does not respect any users of the context, and might make other objects unusable
"""
function unsafe_release!(pctx::CuPrimaryContext)
if driver_version() >= v"11"
cuDevicePrimaryCtxRelease_v2(dev)
cuDevicePrimaryCtxRelease_v2(pctx.dev)
else
cuDevicePrimaryCtxRelease(dev)
cuDevicePrimaryCtxRelease(pctx.dev)
end

# if this releases the last reference, invalidate all derived contexts
if !isactive(pctx)
ctx = @lock derived_lock get(derived_contexts, pctx, nothing)
if ctx !== nothing
invalidate!(ctx)
end
end
return
end

"""
Expand All @@ -265,12 +207,6 @@ function unsafe_reset!(pctx::CuPrimaryContext)
cuDevicePrimaryCtxReset(pctx.dev)
end

# invalidate all derived contexts
ctx = @lock derived_lock get(derived_contexts, pctx, nothing)
if ctx !== nothing
invalidate!(ctx)
end

return
end

Expand Down
17 changes: 0 additions & 17 deletions lib/cudadrv/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,6 @@ Returns the current device.
"""
current_device()

"""
has_device()

Returns whether there is an active device.
"""
function has_device()
device_ref = Ref{CUdevice}()
res = unchecked_cuCtxGetDevice(device_ref)
if res == SUCCESS
return true
elseif res == ERROR_INVALID_CONTEXT
return false
else
throw_api_error(res)
end
end

const DEVICE_CPU = _CuDevice(CUdevice(-1))
const DEVICE_INVALID = _CuDevice(CUdevice(-2))

Expand Down
5 changes: 1 addition & 4 deletions lib/cudadrv/libcuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end
end
end

function check(f)
@inline function check(f)
res = f()
if res != SUCCESS
throw_api_error(res)
Expand Down Expand Up @@ -3269,7 +3269,6 @@ end
end

@checked function cuCtxCreate_v3(pctx, paramsArray, numParams, flags, dev)
initialize_context()
@gcsafe_ccall libcuda.cuCtxCreate_v3(pctx::Ptr{CUcontext},
paramsArray::Ptr{CUexecAffinityParam},
numParams::Cint, flags::Cuint,
Expand Down Expand Up @@ -3299,7 +3298,6 @@ end
end

@checked function cuCtxGetId(ctx, ctxId)
initialize_context()
@gcsafe_ccall libcuda.cuCtxGetId(ctx::CUcontext, ctxId::Ptr{Culonglong})::CUresult
end

Expand Down Expand Up @@ -3329,7 +3327,6 @@ end
end

@checked function cuCtxGetApiVersion(ctx, version)
initialize_context()
@gcsafe_ccall libcuda.cuCtxGetApiVersion(ctx::CUcontext, version::Ptr{Cuint})::CUresult
end

Expand Down
5 changes: 3 additions & 2 deletions lib/cudadrv/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ Base.unsafe_convert(T::Type{<:Union{Ptr,CuPtr,CuArrayPtr}}, mem::AbstractMemory)
Device memory residing on the GPU.
"""
struct DeviceMemory <: AbstractMemory
dev::CuDevice
ctx::CuContext
ptr::CuPtr{Cvoid}
bytesize::Int

async::Bool
end

DeviceMemory() = DeviceMemory(context(), CU_NULL, 0, false)
DeviceMemory() = DeviceMemory(device(), context(), CU_NULL, 0, false)

Base.pointer(mem::DeviceMemory) = mem.ptr
Base.sizeof(mem::DeviceMemory) = mem.bytesize
Expand Down Expand Up @@ -75,7 +76,7 @@ function alloc(::Type{DeviceMemory}, bytesize::Integer;
cuMemAlloc_v2(ptr_ref, bytesize)
end

return DeviceMemory(context(), reinterpret(CuPtr{Cvoid}, ptr_ref[]), bytesize, async)
return DeviceMemory(device(), context(), reinterpret(CuPtr{Cvoid}, ptr_ref[]), bytesize, async)
end

function free(mem::DeviceMemory; stream::Union{Nothing,CuStream}=nothing)
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/module/global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct CuGlobal{T}
if nbytes_ref[] != sizeof(T)
throw(ArgumentError("size of global '$name' does not match type parameter type $T"))
end
buf = DeviceMemory(context(), ptr_ref[], nbytes_ref[], false)
buf = DeviceMemory(device(), context(), ptr_ref[], nbytes_ref[], false)

return new{T}(buf)
end
Expand Down
9 changes: 0 additions & 9 deletions lib/cudadrv/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,22 +302,13 @@ function device!(f::Function, dev::CuDevice)
context!(f, ctx)
end

# NVIDIA bug #3240770
can_reset_device() = !(Base.thisminor(driver_version()) == v"11.2" &&
any(dev->stream_ordered(dev), devices()))

"""
device_reset!(dev::CuDevice=device())

Reset the CUDA state associated with a device. This call with release the underlying
context, at which point any objects allocated in that context will be invalidated.
"""
function device_reset!(dev::CuDevice=device())
if !can_reset_device()
@error "Due to a bug in CUDA, resetting the device is not possible on CUDA 11.2 when using the stream-ordered memory allocator."
return
end

# unconditionally reset the primary context (don't just release it),
# as there might be users outside of CUDA.jl
pctx = CuPrimaryContext(dev)
Expand Down
14 changes: 11 additions & 3 deletions lib/cudadrv/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

export
CuStream, default_stream, legacy_stream, per_thread_stream,
priority, priority_range, synchronize, device_synchronize
unique_id, priority, priority_range, synchronize, device_synchronize

"""
CuStream(; flags=STREAM_DEFAULT, priority=nothing)
Expand All @@ -13,7 +13,7 @@ mutable struct CuStream
const handle::CUstream
Base.@atomic valid::Bool

const ctx::CuContext
const ctx::Union{Nothing,CuContext}

function CuStream(; flags::CUstream_flags=STREAM_DEFAULT,
priority::Union{Nothing,Integer}=nothing)
Expand Down Expand Up @@ -84,6 +84,7 @@ Base.hash(s::CuStream, h::UInt) = hash(s.handle, h)
@enum_without_prefix CUstream_flags_enum CU_

function unsafe_destroy!(s::CuStream)
@assert s.ctx !== nothing "Cannot destroy unassociated stream"
context!(s.ctx; skip_destroyed=true) do
cuStreamDestroy_v2(s)
end
Expand All @@ -93,9 +94,16 @@ end
function Base.show(io::IO, stream::CuStream)
print(io, "CuStream(")
@printf(io, "%p", stream.handle)
if isdefined(stream, :ctx)
if stream.ctx !== nothing
print(io, ", ", stream.ctx)
end
print(io, ")")
end

function unique_id(s::CuStream)
id = Ref{Culonglong}()
cuStreamGetId(s, id)
return id[]
end

"""
Expand Down
2 changes: 1 addition & 1 deletion lib/cupti/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ function process(f, cfg::ActivityConfig)

# extract typed activity records
for (ctx_handle, stream_id, buf_ptr, sz, valid_sz) in cfg.results
ctx = CUDA.UniqueCuContext(ctx_handle)
ctx = ctx_handle == C_NULL ? nothing : CUDA.UniqueCuContext(ctx_handle)
# XXX: can we reconstruct the stream from the stream ID?

record_ptr = Ref{Ptr{CUpti_Activity}}(C_NULL)
Expand Down
Loading