diff --git a/ext/LuxDeviceUtilsAMDGPUExt.jl b/ext/LuxDeviceUtilsAMDGPUExt.jl index 87043cf..d39c8f9 100644 --- a/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -69,12 +69,13 @@ end Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device - if !(LuxDeviceUtils.get_device(x) isa LuxAMDGPUDevice) + dev = LuxDeviceUtils.get_device(x) + if !(dev isa LuxAMDGPUDevice) AMDGPU.device!(to.device) x_new = AMDGPU.roc(x) AMDGPU.device!(old_dev) return x_new - elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) + elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) return x else AMDGPU.device!(to.device) diff --git a/ext/LuxDeviceUtilsCUDAExt.jl b/ext/LuxDeviceUtilsCUDAExt.jl index 3e7d253..19cc144 100644 --- a/ext/LuxDeviceUtilsCUDAExt.jl +++ b/ext/LuxDeviceUtilsCUDAExt.jl @@ -40,7 +40,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) return CUDA.device!(dev) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) - return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id]) + return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id]) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) id = mod1(rank + 1, length(CUDA.devices())) @@ -51,12 +51,13 @@ end Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device - if !(LuxDeviceUtils.get_device(x) isa LuxCUDADevice) + dev = LuxDeviceUtils.get_device(x) + if !(dev isa LuxCUDADevice) CUDA.device!(to.device) x_new = CUDA.cu(x) CUDA.device!(old_dev) return x_new - elseif CUDA.device(x) == to.device + elseif dev.device == to.device return x else CUDA.device!(to.device) diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 4e48e46..bd43c51 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -150,8 +150,8 @@ Selects GPU device based on the following criteria: !!! warning - `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal` and `CPU` - backends, `device_id` is ignored and a warning is printed. + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` + and `CPU` backends, `device_id` is ignored and a warning is printed. ## Keyword Arguments @@ -413,15 +413,15 @@ $SET_DEVICE_DANGER """ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} T === LuxCUDADevice && - @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." T === LuxAMDGPUDevice && - @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." T === LuxMetalDevice && - @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." maxlog=1 + @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." T === LuxoneAPIDevice && - @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." maxlog=1 + @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." T === LuxCPUDevice && - @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." maxlog=1 + @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." return end diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 7c472fa..4840b98 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -7,6 +7,8 @@ using LuxDeviceUtils, Random @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing)) + @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxAMDGPUDevice, nothing, 1) end using AMDGPU @@ -93,6 +95,15 @@ using FillArrays, Zygote # Extensions @test_throws ArgumentError get_device(ps_mixed) end +@testset "Wrapped Arrays" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + x = rand(10, 10) |> LuxAMDGPUDevice() + @test get_device(x) isa LuxAMDGPUDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxAMDGPUDevice + end +end + @testset "Multiple Devices AMDGPU" begin if LuxDeviceUtils.functional(LuxAMDGPUDevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) @@ -117,3 +128,11 @@ end @test ps.bias isa Array end end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + for i in 1:10 + @test_nowarn LuxDeviceUtils.set_device!(LuxAMDGPUDevice, nothing, i) + end + end +end diff --git a/test/cuda.jl b/test/cuda.jl index 189503e..3b1983b 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -7,6 +7,8 @@ using LuxDeviceUtils, Random @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) @test_throws Exception default_device_rng(LuxCUDADevice(nothing)) + @test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxCUDADevice, nothing, 1) end using LuxCUDA @@ -92,6 +94,15 @@ using FillArrays, Zygote # Extensions @test_throws ArgumentError get_device(ps_mixed) end +@testset "Wrapped Arrays" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + x = rand(10, 10) |> LuxCUDADevice() + @test get_device(x) isa LuxCUDADevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxCUDADevice + end +end + @testset "Multiple Devices CUDA" begin if LuxDeviceUtils.functional(LuxCUDADevice) ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) @@ -143,3 +154,11 @@ using SparseArrays @test ps.bias isa SparseVector end end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + for i in 1:10 + @test_nowarn LuxDeviceUtils.set_device!(LuxCUDADevice, nothing, i) + end + end +end diff --git a/test/metal.jl b/test/metal.jl index 57d1ff6..5c500bf 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -92,3 +92,20 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) end + +@testset "Wrapper Arrays" begin + if LuxDeviceUtils.functional(LuxMetalDevice) + x = rand(Float32, 10, 10) |> LuxMetalDevice() + @test get_device(x) isa LuxMetalDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxMetalDevice + end +end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxMetalDevice) + @test_logs (:warn, + "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxMetalDevice, nothing, 1) + end +end diff --git a/test/misc.jl b/test/misc.jl index c4194bf..6d59372 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -88,3 +88,21 @@ end vecarray_dev = vecarray |> gdev @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) end + +@testset "CPU default rng" begin + @test default_device_rng(LuxCPUDevice()) isa Random.TaskLocalRNG +end + +@testset "CPU setdevice!" begin + @test_logs (:warn, + "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxCPUDevice, nothing, 1) +end + +@testset "get_device on Arrays" begin + x = rand(10, 10) + x_view = view(x, 1:5, 1:5) + + @test get_device(x) isa LuxCPUDevice + @test get_device(x_view) isa LuxCPUDevice +end diff --git a/test/oneapi.jl b/test/oneapi.jl index d3f6806..619ef8d 100644 --- a/test/oneapi.jl +++ b/test/oneapi.jl @@ -92,3 +92,20 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) end + +@testset "Wrapper Arrays" begin + if LuxDeviceUtils.functional(LuxoneAPIDevice) + x = rand(10, 10) |> LuxoneAPIDevice() + @test get_device(x) isa LuxoneAPIDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxoneAPIDevice + end +end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxoneAPIDevice) + @test_logs (:warn, + "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxoneAPIDevice, nothing, 1) + end +end