diff --git a/Project.toml b/Project.toml index 347f686..2322d2b 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,9 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" @@ -32,7 +34,9 @@ LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +LuxDeviceUtilsReverseDiffExt = "ReverseDiff" LuxDeviceUtilsSparseArraysExt = "SparseArrays" +LuxDeviceUtilsTrackerExt = "Tracker" LuxDeviceUtilsZygoteExt = "Zygote" LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] @@ -40,11 +44,13 @@ LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] AMDGPU = "0.8.4, 0.9" Adapt = "4" Aqua = "0.8.4" +ArrayInterface = "7.11" CUDA = "5.2" -ChainRulesCore = "1.20" +ChainRulesCore = "1.23" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" FillArrays = "1" +ForwardDiff = "0.10.36" Functors = "0.4.4" GPUArrays = "10" LuxCUDA = "0.3.2" @@ -55,28 +61,34 @@ PrecompileTools = "1.2" Preferences = "1.4" Random = "1.10" RecursiveArrayTools = "3.8" +ReverseDiff = "1.15" SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" TestSetExtensions = "3" +Tracker = "0.2.34" Zygote = "0.6.69" julia = "1.10" oneAPI = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"] +test = ["Aqua", "ArrayInterface", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] diff --git a/ext/LuxDeviceUtilsCUDAExt.jl b/ext/LuxDeviceUtilsCUDAExt.jl index 88acd11..c484558 100644 --- a/ext/LuxDeviceUtilsCUDAExt.jl +++ b/ext/LuxDeviceUtilsCUDAExt.jl @@ -41,12 +41,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) return end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) - if !CUDA.functional() - @warn "CUDA is not functional." - return - end - CUDA.device!(id - 1) - return + return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id]) end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) id = mod1(rank + 1, length(CUDA.devices())) diff --git a/ext/LuxDeviceUtilsReverseDiffExt.jl b/ext/LuxDeviceUtilsReverseDiffExt.jl new file mode 100644 index 0000000..a683b3e --- /dev/null +++ b/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -0,0 +1,13 @@ +module LuxDeviceUtilsReverseDiffExt + +using LuxDeviceUtils: LuxDeviceUtils +using ReverseDiff: ReverseDiff + +@inline function LuxDeviceUtils.get_device(x::ReverseDiff.TrackedArray) + return LuxDeviceUtils.get_device(ReverseDiff.value(x)) +end +@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return LuxDeviceUtils.get_device(ReverseDiff.value.(x)) +end + +end diff --git a/ext/LuxDeviceUtilsTrackerExt.jl b/ext/LuxDeviceUtilsTrackerExt.jl new file mode 100644 index 0000000..7ae149e --- /dev/null +++ b/ext/LuxDeviceUtilsTrackerExt.jl @@ -0,0 +1,26 @@ +module LuxDeviceUtilsTrackerExt + +using Adapt: Adapt +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, + LuxoneAPIDevice, LuxCPUDevice +using Tracker: Tracker + +@inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray) + return LuxDeviceUtils.get_device(Tracker.data(x)) +end +@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal}) + return LuxDeviceUtils.get_device(Tracker.data.(x)) +end + +@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true + +for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice, LuxCPUDevice) + @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) + @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ + to Tracker.TrackedArray." maxlog=1 + return to(Tracker.collect(x)) + end +end + +end diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index a14bb24..3f5d3ab 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -238,8 +238,8 @@ function _get_gpu_device(; force_gpu_usage::Bool) 2. If GPU is available, load the corresponding trigger package. a. `LuxCUDA.jl` for NVIDIA CUDA Support. b. `AMDGPU.jl` for AMD GPU ROCM Support. - c. `Metal.jl` for Apple Metal GPU Support. - d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1 + c. `Metal.jl` for Apple Metal GPU Support. (Experimental) + d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 return LuxCPUDevice end end @@ -319,7 +319,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} fn = Base.Fix1(Adapt.adapt, D) - return isbitstype(T) ? fn(x) : map(D, x) + return isbitstype(T) || __special_aos(x) ? fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) @@ -336,6 +336,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) end end +@inline __special_aos(x::AbstractArray) = false + # Query Device from Array """ get_device(x) -> AbstractLuxDevice | Exception | Nothing diff --git a/test/component_arrays.jl b/test/component_arrays.jl deleted file mode 100644 index 3825a22..0000000 --- a/test/component_arrays.jl +++ /dev/null @@ -1,17 +0,0 @@ -using LuxDeviceUtils, ComponentArrays, Random - -@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin - dev = LuxCPUDevice() - ps = (; weight=randn(10, 1), bias=randn(1)) - - ps_ca = ps |> ComponentArray - - ps_ca_dev = ps_ca |> dev - - @test ps_ca_dev isa ComponentArray - - @test ps_ca_dev.weight == ps.weight - @test ps_ca_dev.bias == ps.bias - - @test ps_ca_dev == (ps |> dev |> ComponentArray) -end diff --git a/test/misc.jl b/test/misc.jl new file mode 100644 index 0000000..e1eba18 --- /dev/null +++ b/test/misc.jl @@ -0,0 +1,45 @@ +using LuxDeviceUtils, ComponentArrays, Random +using ArrayInterface: parameterless_type + +@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin + dev = LuxCPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) + + ps_ca = ps |> ComponentArray + + ps_ca_dev = ps_ca |> dev + + @test ps_ca_dev isa ComponentArray + + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias + + @test ps_ca_dev == (ps |> dev |> ComponentArray) +end + +using ReverseDiff, Tracker, ForwardDiff + +@testset "AD Types" begin + x = randn(Float32, 10) + + x_rdiff = ReverseDiff.track(x) + @test get_device(x_rdiff) isa LuxCPUDevice + x_rdiff = ReverseDiff.track.(x) + @test get_device(x_rdiff) isa LuxCPUDevice + + gdev = gpu_device() + + x_tracker = Tracker.param(x) + @test get_device(x_tracker) isa LuxCPUDevice + x_tracker = Tracker.param.(x) + @test get_device(x_tracker) isa LuxCPUDevice + x_tracker_dev = Tracker.param(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + x_tracker_dev = Tracker.param.(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + + x_fdiff = ForwardDiff.Dual.(x) + @test get_device(x_fdiff) isa LuxCPUDevice + x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev + @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 1a38d67..d63a17c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,7 +27,7 @@ const GROUP = get(ENV, "GROUP", "NONE") @testset "Others" begin @testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils) - @safetestset "Component Arrays" include("component_arrays.jl") + @safetestset "Misc Tests" include("misc.jl") @safetestset "Explicit Imports" include("explicit_imports.jl") end