Skip to content

Commit

Permalink
Test setdevice
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 7, 2024
1 parent 6ff2497 commit e00224b
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 12 deletions.
5 changes: 3 additions & 2 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ end
Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray)

Check warning on line 70 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
old_dev = AMDGPU.device() # remember the current device
if !(LuxDeviceUtils.get_device(x) isa LuxAMDGPUDevice)
dev = LuxDeviceUtils.get_device(x)
if !(dev isa LuxAMDGPUDevice)

Check warning on line 73 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L72-L73

Added lines #L72 - L73 were not covered by tests
AMDGPU.device!(to.device)
x_new = AMDGPU.roc(x)
AMDGPU.device!(old_dev)
return x_new
elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device)
elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device)

Check warning on line 78 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L78

Added line #L78 was not covered by tests
return x
else
AMDGPU.device!(to.device)
Expand Down
7 changes: 4 additions & 3 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice)
return CUDA.device!(dev)
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int)
return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id])
return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id])
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int)
id = mod1(rank + 1, length(CUDA.devices()))
Expand All @@ -51,12 +51,13 @@ end
Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray)
old_dev = CUDA.device() # remember the current device
if !(LuxDeviceUtils.get_device(x) isa LuxCUDADevice)
dev = LuxDeviceUtils.get_device(x)
if !(dev isa LuxCUDADevice)
CUDA.device!(to.device)
x_new = CUDA.cu(x)
CUDA.device!(old_dev)
return x_new
elseif CUDA.device(x) == to.device
elseif dev.device == to.device

Check warning on line 60 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L60

Added line #L60 was not covered by tests
return x
else
CUDA.device!(to.device)
Expand Down
14 changes: 7 additions & 7 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ Selects GPU device based on the following criteria:
!!! warning
`device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal` and `CPU`
backends, `device_id` is ignored and a warning is printed.
`device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI`
and `CPU` backends, `device_id` is ignored and a warning is printed.
## Keyword Arguments
Expand Down Expand Up @@ -413,15 +413,15 @@ $SET_DEVICE_DANGER
"""
function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice}
T === LuxCUDADevice &&
@warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." maxlog=1
@warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting."

Check warning on line 416 in src/LuxDeviceUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/LuxDeviceUtils.jl#L416

Added line #L416 was not covered by tests
T === LuxAMDGPUDevice &&
@warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." maxlog=1
@warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting."

Check warning on line 418 in src/LuxDeviceUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/LuxDeviceUtils.jl#L418

Added line #L418 was not covered by tests
T === LuxMetalDevice &&
@warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." maxlog=1
@warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting."

Check warning on line 420 in src/LuxDeviceUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/LuxDeviceUtils.jl#L420

Added line #L420 was not covered by tests
T === LuxoneAPIDevice &&
@warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." maxlog=1
@warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting."

Check warning on line 422 in src/LuxDeviceUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/LuxDeviceUtils.jl#L422

Added line #L422 was not covered by tests
T === LuxCPUDevice &&
@warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." maxlog=1
@warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting."

Check warning on line 424 in src/LuxDeviceUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/LuxDeviceUtils.jl#L424

Added line #L424 was not covered by tests
return
end

Expand Down
19 changes: 19 additions & 0 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ using LuxDeviceUtils, Random
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
force_gpu_usage=true)
@test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing))
@test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!(
LuxAMDGPUDevice, nothing, 1)
end

using AMDGPU
Expand Down Expand Up @@ -93,6 +95,15 @@ using FillArrays, Zygote # Extensions
@test_throws ArgumentError get_device(ps_mixed)
end

@testset "Wrapped Arrays" begin
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
x = rand(10, 10) |> LuxAMDGPUDevice()
@test get_device(x) isa LuxAMDGPUDevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxAMDGPUDevice
end
end

@testset "Multiple Devices AMDGPU" begin
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
Expand All @@ -117,3 +128,11 @@ end
@test ps.bias isa Array
end
end

@testset "setdevice!" begin
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
for i in 1:10
@test_nowarn LuxDeviceUtils.set_device!(LuxAMDGPUDevice, nothing, i)
end
end
end
19 changes: 19 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ using LuxDeviceUtils, Random
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
force_gpu_usage=true)
@test_throws Exception default_device_rng(LuxCUDADevice(nothing))
@test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!(
LuxCUDADevice, nothing, 1)
end

using LuxCUDA
Expand Down Expand Up @@ -92,6 +94,15 @@ using FillArrays, Zygote # Extensions
@test_throws ArgumentError get_device(ps_mixed)
end

@testset "Wrapped Arrays" begin
if LuxDeviceUtils.functional(LuxCUDADevice)
x = rand(10, 10) |> LuxCUDADevice()
@test get_device(x) isa LuxCUDADevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxCUDADevice
end
end

@testset "Multiple Devices CUDA" begin
if LuxDeviceUtils.functional(LuxCUDADevice)
ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
Expand Down Expand Up @@ -143,3 +154,11 @@ using SparseArrays
@test ps.bias isa SparseVector
end
end

@testset "setdevice!" begin
if LuxDeviceUtils.functional(LuxCUDADevice)
for i in 1:10
@test_nowarn LuxDeviceUtils.set_device!(LuxCUDADevice, nothing, i)
end
end
end
17 changes: 17 additions & 0 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,20 @@ using FillArrays, Zygote # Extensions
ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)
end

@testset "Wrapper Arrays" begin
if LuxDeviceUtils.functional(LuxMetalDevice)
x = rand(Float32, 10, 10) |> LuxMetalDevice()
@test get_device(x) isa LuxMetalDevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxMetalDevice
end
end

@testset "setdevice!" begin
if LuxDeviceUtils.functional(LuxMetalDevice)
@test_logs (:warn,
"Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!(
LuxMetalDevice, nothing, 1)
end
end
18 changes: 18 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,21 @@ end
vecarray_dev = vecarray |> gdev
@test get_device(vecarray_dev) isa parameterless_type(typeof(gdev))
end

@testset "CPU default rng" begin
@test default_device_rng(LuxCPUDevice()) isa Random.TaskLocalRNG
end

@testset "CPU setdevice!" begin
@test_logs (:warn,
"Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting.") LuxDeviceUtils.set_device!(
LuxCPUDevice, nothing, 1)
end

@testset "get_device on Arrays" begin
x = rand(10, 10)
x_view = view(x, 1:5, 1:5)

@test get_device(x) isa LuxCPUDevice
@test get_device(x_view) isa LuxCPUDevice
end
17 changes: 17 additions & 0 deletions test/oneapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,20 @@ using FillArrays, Zygote # Extensions
ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)
end

@testset "Wrapper Arrays" begin
if LuxDeviceUtils.functional(LuxoneAPIDevice)
x = rand(10, 10) |> LuxoneAPIDevice()
@test get_device(x) isa LuxoneAPIDevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxoneAPIDevice
end
end

@testset "setdevice!" begin
if LuxDeviceUtils.functional(LuxoneAPIDevice)
@test_logs (:warn,
"Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!(
LuxoneAPIDevice, nothing, 1)
end
end

0 comments on commit e00224b

Please sign in to comment.