diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index bd43c51..8ad0c2a 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -158,7 +158,7 @@ Selects GPU device based on the following criteria: - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU device is found. """ -function gpu_device(device_id::Union{Nothing, Int}=nothing; +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 4840b98..159b241 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxAMDGPUDevice) @@ -93,6 +94,17 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + dev2 = gpu_device(length(AMDGPU.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end @testset "Wrapped Arrays" begin diff --git a/test/cuda.jl b/test/cuda.jl index 3b1983b..5c4a7ee 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxCUDADevice) @@ -92,6 +93,17 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxCUDADevice) + dev2 = gpu_device(length(CUDA.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end @testset "Wrapped Arrays" begin