Skip to content

Commit

Permalink
Add range tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 7, 2024
1 parent bc811cf commit 09d910f
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ end

Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x

Check warning on line 470 in src/LuxDeviceUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/LuxDeviceUtils.jl#L470

Added line #L470 was not covered by tests
# Prevent Ambiguity
for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice)
for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice,
LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

Expand Down
3 changes: 3 additions & 0 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions
@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c,
d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
range=1:10,
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
3 changes: 3 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions
@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c,
d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
range=1:10,
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -47,6 +48,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -71,6 +73,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
3 changes: 3 additions & 0 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions
@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c,
d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
range=1:10,
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
3 changes: 3 additions & 0 deletions test/oneapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions
@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c,
d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
range=1:10,
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down

0 comments on commit 09d910f

Please sign in to comment.