diff --git a/Project.toml b/Project.toml index 2322d2b..cd57505 100644 --- a/Project.toml +++ b/Project.toml @@ -47,6 +47,7 @@ Aqua = "0.8.4" ArrayInterface = "7.11" CUDA = "5.2" ChainRulesCore = "1.23" +ChainRulesTestUtils = "1.13.0" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FillArrays = "1" @@ -74,6 +75,7 @@ oneAPI = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -91,4 +93,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ArrayInterface", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] +test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] diff --git a/ext/LuxDeviceUtilsAMDGPUExt.jl b/ext/LuxDeviceUtilsAMDGPUExt.jl index 1f2352a..87043cf 100644 --- a/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -46,20 +46,18 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) +function LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) + parent_x = parent(x) + parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x)) + return LuxDeviceUtils.get_device(parent_x) +end # Set Device function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) - if !AMDGPU.functional() - @warn "AMDGPU is not functional." - return - end - AMDGPU.device!(dev) - return + return AMDGPU.device!(dev) end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) - LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) - return + return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) id = mod1(rank + 1, length(AMDGPU.devices())) @@ -71,7 +69,7 @@ end Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device - if !(x isa AMDGPU.AnyROCArray) + if !(LuxDeviceUtils.get_device(x) isa LuxAMDGPUDevice) AMDGPU.device!(to.device) x_new = AMDGPU.roc(x) AMDGPU.device!(old_dev) diff --git a/ext/LuxDeviceUtilsCUDAExt.jl b/ext/LuxDeviceUtilsCUDAExt.jl index c484558..3e7d253 100644 --- a/ext/LuxDeviceUtilsCUDAExt.jl +++ b/ext/LuxDeviceUtilsCUDAExt.jl @@ -26,19 +26,18 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array -LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) +function LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) + parent_x = parent(x) + parent_x === x && return LuxCUDADevice(CUDA.device(x)) + return LuxDeviceUtils.get_device(parent_x) +end function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) return LuxCUDADevice(CUDA.device(x.nzVal)) end # Set Device function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) - if !CUDA.functional() - @warn "CUDA is not functional." - return - end - CUDA.device!(dev) - return + return CUDA.device!(dev) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id]) @@ -52,7 +51,7 @@ end Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device - if !(x isa CUDA.AnyCuArray) + if !(LuxDeviceUtils.get_device(x) isa LuxCUDADevice) CUDA.device!(to.device) x_new = CUDA.cu(x) CUDA.device!(old_dev) diff --git a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 0142242..78aec5e 100644 --- a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -1,7 +1,7 @@ module LuxDeviceUtilsRecursiveArrayToolsExt using Adapt: Adapt, adapt -using LuxDeviceUtils: AbstractLuxDevice +using LuxDeviceUtils: LuxDeviceUtils, AbstractLuxDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure @@ -14,4 +14,8 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end +function LuxDeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray}) + return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u) +end + end diff --git a/ext/LuxDeviceUtilsTrackerExt.jl b/ext/LuxDeviceUtilsTrackerExt.jl index 7ae149e..6746b9b 100644 --- a/ext/LuxDeviceUtilsTrackerExt.jl +++ b/ext/LuxDeviceUtilsTrackerExt.jl @@ -2,7 +2,7 @@ module LuxDeviceUtilsTrackerExt using Adapt: Adapt using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, - LuxoneAPIDevice, LuxCPUDevice + LuxoneAPIDevice using Tracker: Tracker @inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray) @@ -15,7 +15,7 @@ end @inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, - LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice, LuxCPUDevice) + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ to Tracker.TrackedArray." maxlog=1 diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 12bfc0d..4e48e46 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -372,7 +372,6 @@ get_device(x::NamedTuple) = mapreduce(get_device, __combine_devices, values(x)) CRC.@non_differentiable get_device(::Any...) -__combine_devices(dev1) = dev1 function __combine_devices(dev1, dev2) dev1 === nothing && return dev2 dev2 === nothing && return dev1 @@ -380,9 +379,6 @@ function __combine_devices(dev1, dev2) throw(ArgumentError("Objects are on different devices: $dev1 and $dev2.")) return dev1 end -function __combine_devices(dev1, dev2, rem_devs...) - return foldl(__combine_devices, (dev1, dev2, rem_devs...)) -end # Set the device const SET_DEVICE_DOCS = """ @@ -390,7 +386,7 @@ Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxC and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not loaded. -Currently, `LuxMetalDevice` doesn't support setting the device. +Currently, `LuxMetalDevice` and `LuxoneAPIDevice` doesn't support setting the device. """ const SET_DEVICE_DANGER = """ diff --git a/test/amdgpu.jl b/test/amdgpu.jl index c6350e3..7c472fa 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing)) end using AMDGPU diff --git a/test/cuda.jl b/test/cuda.jl index ec996a9..189503e 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxCUDADevice(nothing)) end using LuxCUDA diff --git a/test/metal.jl b/test/metal.jl index 9ac4446..57d1ff6 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxMetalDevice()) end using Metal diff --git a/test/misc.jl b/test/misc.jl index e1eba18..c4194bf 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -1,5 +1,8 @@ -using LuxDeviceUtils, ComponentArrays, Random +using Adapt, LuxDeviceUtils, ComponentArrays, Random using ArrayInterface: parameterless_type +using ChainRulesTestUtils: test_rrule +using ReverseDiff, Tracker, ForwardDiff +using SparseArrays, FillArrays, Zygote, RecursiveArrayTools @testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin dev = LuxCPUDevice() @@ -17,8 +20,6 @@ using ArrayInterface: parameterless_type @test ps_ca_dev == (ps |> dev |> ComponentArray) end -using ReverseDiff, Tracker, ForwardDiff - @testset "AD Types" begin x = randn(Float32, 10) @@ -43,3 +44,47 @@ using ReverseDiff, Tracker, ForwardDiff x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) end + +@testset "CRC Tests" begin + dev = cpu_device() # Other devices don't work with FiniteDifferences.jl + test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true) + + gdev = gpu_device() + if !(gdev isa LuxMetalDevice) # On intel devices causes problems + x = randn(10) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x) + @test ∂dev === nothing + @test ∂x ≈ ones(10) + + x = randn(10) |> gdev + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x) + @test ∂dev === nothing + @test ∂x ≈ gdev(ones(10)) + @test get_device(∂x) isa parameterless_type(typeof(gdev)) + end +end + +# The following just test for noops +@testset "NoOps CPU" begin + cdev = cpu_device() + + @test cdev(sprand(10, 10, 0.9)) isa SparseMatrixCSC + @test cdev(1:10) isa AbstractRange + @test cdev(Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4))) isa Zygote.OneElement +end + +@testset "RecursiveArrayTools" begin + gdev = gpu_device() + + diffeqarray = DiffEqArray([rand(10) for _ in 1:10], rand(10)) + @test get_device(diffeqarray) isa LuxCPUDevice + + diffeqarray_dev = diffeqarray |> gdev + @test get_device(diffeqarray_dev) isa parameterless_type(typeof(gdev)) + + vecarray = VectorOfArray([rand(10) for _ in 1:10]) + @test get_device(vecarray) isa LuxCPUDevice + + vecarray_dev = vecarray |> gdev + @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) +end diff --git a/test/oneapi.jl b/test/oneapi.jl index 8dc079b..d3f6806 100644 --- a/test/oneapi.jl +++ b/test/oneapi.jl @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxoneAPIDevice()) end using oneAPI