Skip to content

Commit

Permalink
Add tests for rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 7, 2024
1 parent 09d910f commit 6ff2497
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 30 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Aqua = "0.8.4"
ArrayInterface = "7.11"
CUDA = "5.2"
ChainRulesCore = "1.23"
ChainRulesTestUtils = "1.13.0"
ComponentArrays = "0.15.8"
ExplicitImports = "1.4.1"
FillArrays = "1"
Expand Down Expand Up @@ -74,6 +75,7 @@ oneAPI = "1.5"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -91,4 +93,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ArrayInterface", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"]
test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"]
18 changes: 8 additions & 10 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,18 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x))
function LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray)
parent_x = parent(x)
parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x))
return LuxDeviceUtils.get_device(parent_x)

Check warning on line 52 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L49-L52

Added lines #L49 - L52 were not covered by tests
end

# Set Device
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice)
if !AMDGPU.functional()
@warn "AMDGPU is not functional."
return
end
AMDGPU.device!(dev)
return
return AMDGPU.device!(dev)

Check warning on line 57 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L57

Added line #L57 was not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int)
LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id])
return
return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id])

Check warning on line 60 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L60

Added line #L60 was not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int)
id = mod1(rank + 1, length(AMDGPU.devices()))
Expand All @@ -71,7 +69,7 @@ end
Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray)

Check warning on line 70 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
old_dev = AMDGPU.device() # remember the current device
if !(x isa AMDGPU.AnyROCArray)
if !(LuxDeviceUtils.get_device(x) isa LuxAMDGPUDevice)

Check warning on line 72 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L72

Added line #L72 was not covered by tests
AMDGPU.device!(to.device)
x_new = AMDGPU.roc(x)
AMDGPU.device!(old_dev)
Expand Down
15 changes: 7 additions & 8 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,18 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) +
LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()

# Query Device from Array
LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x))
function LuxDeviceUtils.get_device(x::CUDA.AnyCuArray)
parent_x = parent(x)
parent_x === x && return LuxCUDADevice(CUDA.device(x))
return LuxDeviceUtils.get_device(parent_x)

Check warning on line 32 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L32

Added line #L32 was not covered by tests
end
function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
return LuxCUDADevice(CUDA.device(x.nzVal))
end

# Set Device
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice)
if !CUDA.functional()
@warn "CUDA is not functional."
return
end
CUDA.device!(dev)
return
return CUDA.device!(dev)

Check warning on line 40 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L40

Added line #L40 was not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int)
return LuxDeviceUtils.set_device!(LuxCUDADevice, CUDA.devices()[id])

Check warning on line 43 in ext/LuxDeviceUtilsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsCUDAExt.jl#L43

Added line #L43 was not covered by tests
Expand All @@ -52,7 +51,7 @@ end
Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray)
old_dev = CUDA.device() # remember the current device
if !(x isa CUDA.AnyCuArray)
if !(LuxDeviceUtils.get_device(x) isa LuxCUDADevice)
CUDA.device!(to.device)
x_new = CUDA.cu(x)
CUDA.device!(old_dev)
Expand Down
6 changes: 5 additions & 1 deletion ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LuxDeviceUtilsRecursiveArrayToolsExt

using Adapt: Adapt, adapt
using LuxDeviceUtils: AbstractLuxDevice
using LuxDeviceUtils: LuxDeviceUtils, AbstractLuxDevice
using RecursiveArrayTools: VectorOfArray, DiffEqArray

# We want to preserve the structure
Expand All @@ -14,4 +14,8 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray)
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end

function LuxDeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray})
return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u)
end

end
4 changes: 2 additions & 2 deletions ext/LuxDeviceUtilsTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module LuxDeviceUtilsTrackerExt

using Adapt: Adapt
using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice,
LuxoneAPIDevice, LuxCPUDevice
LuxoneAPIDevice
using Tracker: Tracker

@inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray)
Expand All @@ -15,7 +15,7 @@ end
@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice,
LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice, LuxCPUDevice)
LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice)
@eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal})
@warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \
to Tracker.TrackedArray." maxlog=1
Expand Down
6 changes: 1 addition & 5 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,25 +372,21 @@ get_device(x::NamedTuple) = mapreduce(get_device, __combine_devices, values(x))

CRC.@non_differentiable get_device(::Any...)

__combine_devices(dev1) = dev1
function __combine_devices(dev1, dev2)
dev1 === nothing && return dev2
dev2 === nothing && return dev1
dev1 != dev2 &&
throw(ArgumentError("Objects are on different devices: $dev1 and $dev2."))
return dev1
end
function __combine_devices(dev1, dev2, rem_devs...)
return foldl(__combine_devices, (dev1, dev2, rem_devs...))
end

# Set the device
const SET_DEVICE_DOCS = """
Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice`
and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not
loaded.
Currently, `LuxMetalDevice` doesn't support setting the device.
Currently, `LuxMetalDevice` and `LuxoneAPIDevice` doesn't support setting the device.
"""

const SET_DEVICE_DANGER = """
Expand Down
1 change: 1 addition & 0 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random
@test gpu_device() isa LuxCPUDevice
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
force_gpu_usage=true)
@test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing))
end

using AMDGPU
Expand Down
1 change: 1 addition & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random
@test gpu_device() isa LuxCPUDevice
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
force_gpu_usage=true)
@test_throws Exception default_device_rng(LuxCUDADevice(nothing))
end

using LuxCUDA
Expand Down
1 change: 1 addition & 0 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random
@test gpu_device() isa LuxCPUDevice
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
force_gpu_usage=true)
@test_throws Exception default_device_rng(LuxMetalDevice())
end

using Metal
Expand Down
51 changes: 48 additions & 3 deletions test/misc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using LuxDeviceUtils, ComponentArrays, Random
using Adapt, LuxDeviceUtils, ComponentArrays, Random
using ArrayInterface: parameterless_type
using ChainRulesTestUtils: test_rrule
using ReverseDiff, Tracker, ForwardDiff
using SparseArrays, FillArrays, Zygote, RecursiveArrayTools

@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin
dev = LuxCPUDevice()
Expand All @@ -17,8 +20,6 @@ using ArrayInterface: parameterless_type
@test ps_ca_dev == (ps |> dev |> ComponentArray)
end

using ReverseDiff, Tracker, ForwardDiff

@testset "AD Types" begin
x = randn(Float32, 10)

Expand All @@ -43,3 +44,47 @@ using ReverseDiff, Tracker, ForwardDiff
x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev
@test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev))
end

@testset "CRC Tests" begin
dev = cpu_device() # Other devices don't work with FiniteDifferences.jl
test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true)

gdev = gpu_device()
if !(gdev isa LuxMetalDevice) # On intel devices causes problems
x = randn(10)
∂dev, ∂x = Zygote.gradient(sum Adapt.adapt_storage, gdev, x)
@test ∂dev === nothing
@test ∂x ones(10)

x = randn(10) |> gdev
∂dev, ∂x = Zygote.gradient(sum Adapt.adapt_storage, cpu_device(), x)
@test ∂dev === nothing
@test ∂x gdev(ones(10))
@test get_device(∂x) isa parameterless_type(typeof(gdev))
end
end

# The following just test for noops
@testset "NoOps CPU" begin
cdev = cpu_device()

@test cdev(sprand(10, 10, 0.9)) isa SparseMatrixCSC
@test cdev(1:10) isa AbstractRange
@test cdev(Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4))) isa Zygote.OneElement
end

@testset "RecursiveArrayTools" begin
gdev = gpu_device()

diffeqarray = DiffEqArray([rand(10) for _ in 1:10], rand(10))
@test get_device(diffeqarray) isa LuxCPUDevice

diffeqarray_dev = diffeqarray |> gdev
@test get_device(diffeqarray_dev) isa parameterless_type(typeof(gdev))

vecarray = VectorOfArray([rand(10) for _ in 1:10])
@test get_device(vecarray) isa LuxCPUDevice

vecarray_dev = vecarray |> gdev
@test get_device(vecarray_dev) isa parameterless_type(typeof(gdev))
end
1 change: 1 addition & 0 deletions test/oneapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using LuxDeviceUtils, Random
@test gpu_device() isa LuxCPUDevice
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
force_gpu_usage=true)
@test_throws Exception default_device_rng(LuxoneAPIDevice())
end

using oneAPI
Expand Down

0 comments on commit 6ff2497

Please sign in to comment.