Skip to content

Commit

Permalink
Test for potential multi-device
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 7, 2024
1 parent e00224b commit dba99cf
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ Selects GPU device based on the following criteria:
- `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU
device is found.
"""
function gpu_device(device_id::Union{Nothing, Int}=nothing;
function gpu_device(device_id::Union{Nothing, <:Integer}=nothing;
force_gpu_usage::Bool=false)::AbstractLuxDevice
device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed."))

Expand Down
12 changes: 12 additions & 0 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LuxDeviceUtils, Random
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
@test !LuxDeviceUtils.functional(LuxAMDGPUDevice)
Expand Down Expand Up @@ -93,6 +94,17 @@ using FillArrays, Zygote # Extensions

ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)

dev = gpu_device()
x = rand(Float32, 10, 2)
x_dev = x |> dev
@test get_device(x_dev) isa parameterless_type(typeof(dev))

if LuxDeviceUtils.functional(LuxAMDGPUDevice)
dev2 = gpu_device(length(AMDGPU.devices()))
x_dev2 = x_dev |> dev2
@test get_device(x_dev2) isa typeof(dev2)
end
end

@testset "Wrapped Arrays" begin
Expand Down
12 changes: 12 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LuxDeviceUtils, Random
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
@test !LuxDeviceUtils.functional(LuxCUDADevice)
Expand Down Expand Up @@ -92,6 +93,17 @@ using FillArrays, Zygote # Extensions

ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)

dev = gpu_device()
x = rand(Float32, 10, 2)
x_dev = x |> dev
@test get_device(x_dev) isa parameterless_type(typeof(dev))

if LuxDeviceUtils.functional(LuxCUDADevice)
dev2 = gpu_device(length(CUDA.devices()))
x_dev2 = x_dev |> dev2
@test get_device(x_dev2) isa typeof(dev2)
end
end

@testset "Wrapped Arrays" begin
Expand Down

0 comments on commit dba99cf

Please sign in to comment.