Skip to content

Commit

Permalink
Add tests for AD types
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 7, 2024
1 parent 09da47b commit bc811cf
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 29 deletions.
16 changes: 14 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

Expand All @@ -32,19 +34,23 @@ LuxDeviceUtilsGPUArraysExt = "GPUArrays"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"]
LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
LuxDeviceUtilsReverseDiffExt = "ReverseDiff"
LuxDeviceUtilsSparseArraysExt = "SparseArrays"
LuxDeviceUtilsTrackerExt = "Tracker"
LuxDeviceUtilsZygoteExt = "Zygote"
LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"]

[compat]
AMDGPU = "0.8.4, 0.9"
Adapt = "4"
Aqua = "0.8.4"
ArrayInterface = "7.11"
CUDA = "5.2"
ChainRulesCore = "1.20"
ChainRulesCore = "1.23"
ComponentArrays = "0.15.8"
ExplicitImports = "1.4.1"
FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.4"
GPUArrays = "10"
LuxCUDA = "0.3.2"
Expand All @@ -55,28 +61,34 @@ PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.10"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SafeTestsets = "0.1"
SparseArrays = "1.10"
Test = "1.10"
TestSetExtensions = "3"
Tracker = "0.2.34"
Zygote = "0.6.69"
julia = "1.10"
oneAPI = "1.5"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"]
test = ["Aqua", "ArrayInterface", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"]
7 changes: 1 addition & 6 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice)
return
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int)
if !CUDA.functional()
@warn "CUDA is not functional."
return
end
CUDA.device!(id - 1)
return
return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id])

Check warning on line 44 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L44

Added line #L44 was not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int)
id = mod1(rank + 1, length(CUDA.devices()))
Expand Down
13 changes: 13 additions & 0 deletions ext/LuxDeviceUtilsReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module LuxDeviceUtilsReverseDiffExt

using LuxDeviceUtils: LuxDeviceUtils
using ReverseDiff: ReverseDiff

@inline function LuxDeviceUtils.get_device(x::ReverseDiff.TrackedArray)
return LuxDeviceUtils.get_device(ReverseDiff.value(x))
end
@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal})
return LuxDeviceUtils.get_device(ReverseDiff.value.(x))
end

end
26 changes: 26 additions & 0 deletions ext/LuxDeviceUtilsTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module LuxDeviceUtilsTrackerExt

using Adapt: Adapt
using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice,
LuxoneAPIDevice, LuxCPUDevice
using Tracker: Tracker

@inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray)
return LuxDeviceUtils.get_device(Tracker.data(x))
end
@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal})
return LuxDeviceUtils.get_device(Tracker.data.(x))
end

@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice,
LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice, LuxCPUDevice)
@eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal})
@warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \
to Tracker.TrackedArray." maxlog=1
return to(Tracker.collect(x))
end
end

end
8 changes: 5 additions & 3 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ function _get_gpu_device(; force_gpu_usage::Bool)
2. If GPU is available, load the corresponding trigger package.
a. `LuxCUDA.jl` for NVIDIA CUDA Support.
b. `AMDGPU.jl` for AMD GPU ROCM Support.
c. `Metal.jl` for Apple Metal GPU Support.
d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1
c. `Metal.jl` for Apple Metal GPU Support. (Experimental)
d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1
return LuxCPUDevice
end
end
Expand Down Expand Up @@ -319,7 +319,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
@eval begin
function (D::$(ldev))(x::AbstractArray{T}) where {T}
fn = Base.Fix1(Adapt.adapt, D)
return isbitstype(T) ? fn(x) : map(D, x)
return isbitstype(T) || __special_aos(x) ? fn(x) : map(D, x)
end
(D::$(ldev))(x::Tuple) = map(D, x)
(D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x)))
Expand All @@ -336,6 +336,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
end
end

@inline __special_aos(x::AbstractArray) = false

# Query Device from Array
"""
get_device(x) -> AbstractLuxDevice | Exception | Nothing
Expand Down
17 changes: 0 additions & 17 deletions test/component_arrays.jl

This file was deleted.

45 changes: 45 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using LuxDeviceUtils, ComponentArrays, Random
using ArrayInterface: parameterless_type

@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin
dev = LuxCPUDevice()
ps = (; weight=randn(10, 1), bias=randn(1))

ps_ca = ps |> ComponentArray

ps_ca_dev = ps_ca |> dev

@test ps_ca_dev isa ComponentArray

@test ps_ca_dev.weight == ps.weight
@test ps_ca_dev.bias == ps.bias

@test ps_ca_dev == (ps |> dev |> ComponentArray)
end

using ReverseDiff, Tracker, ForwardDiff

@testset "AD Types" begin
x = randn(Float32, 10)

x_rdiff = ReverseDiff.track(x)
@test get_device(x_rdiff) isa LuxCPUDevice
x_rdiff = ReverseDiff.track.(x)
@test get_device(x_rdiff) isa LuxCPUDevice

gdev = gpu_device()

x_tracker = Tracker.param(x)
@test get_device(x_tracker) isa LuxCPUDevice
x_tracker = Tracker.param.(x)
@test get_device(x_tracker) isa LuxCPUDevice
x_tracker_dev = Tracker.param(x) |> gdev
@test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev))
x_tracker_dev = Tracker.param.(x) |> gdev
@test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev))

x_fdiff = ForwardDiff.Dual.(x)
@test get_device(x_fdiff) isa LuxCPUDevice
x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev
@test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev))
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const GROUP = get(ENV, "GROUP", "NONE")
@testset "Others" begin
@testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils)

@safetestset "Component Arrays" include("component_arrays.jl")
@safetestset "Misc Tests" include("misc.jl")

@safetestset "Explicit Imports" include("explicit_imports.jl")
end
Expand Down

0 comments on commit bc811cf

Please sign in to comment.