Skip to content

Commit

Permalink
Allow direct devices as well
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 27, 2024
1 parent f0bc7eb commit 0faa572
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
8 changes: 8 additions & 0 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()
LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x))

Check warning on line 27 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L27

Added line #L27 was not covered by tests

# Set Device
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice)
if !AMDGPU.functional()
@warn "AMDGPU is not functional."
return

Check warning on line 33 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L30-L33

Added lines #L30 - L33 were not covered by tests
end
AMDGPU.device!(dev)
return

Check warning on line 36 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int)
if !AMDGPU.functional()
@warn "AMDGPU is not functional."
Expand Down
8 changes: 8 additions & 0 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()
LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x))

Check warning on line 28 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L28

Added line #L28 was not covered by tests

# Set Device
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice)
if !CUDA.functional()
@warn "CUDA is not functional."
return

Check warning on line 34 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L31-L34

Added lines #L31 - L34 were not covered by tests
end
CUDA.device!(dev)
return

Check warning on line 37 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L36-L37

Added lines #L36 - L37 were not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int)
if !CUDA.functional()
@warn "CUDA is not functional."
Expand Down
8 changes: 5 additions & 3 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,18 +355,20 @@ const SET_DEVICE_DANGER = """
"""

"""
set_device!(T::Type{<:AbstractLuxDevice}, id::Int)
set_device!(T::Type{<:AbstractLuxDevice}, dev_or_id)
$SET_DEVICE_DOCS
## Arguments
- `T::Type{<:AbstractLuxDevice}`: The device type to set.
- `id::Int`: The device id to set. This is `1`-indexed.
- `dev_or_id`: Can be the device from the corresponding package. For example for CUDA it
can be a `CuDevice`. If it is an integer, it is the device id to set. This is
`1`-indexed.
$SET_DEVICE_DANGER
"""
function set_device!(::Type{T}, id::Int) where {T <: AbstractLuxDevice}
function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice}
T === LuxCUDADevice &&
@warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." maxlog=1
T === LuxAMDGPUDevice &&
Expand Down

0 comments on commit 0faa572

Please sign in to comment.