diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 3f5d3ab..12bfc0d 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -469,7 +469,8 @@ end Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x # Prevent Ambiguity -for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) +for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 5adf443..c6350e3 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/cuda.jl b/test/cuda.jl index 88c8cb7..ec996a9 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -47,6 +48,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -71,6 +73,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/metal.jl b/test/metal.jl index 261a6c0..9ac4446 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/oneapi.jl b/test/oneapi.jl index 1e04198..8dc079b 100644 --- a/test/oneapi.jl +++ b/test/oneapi.jl @@ -31,6 +31,7 @@ using FillArrays, Zygote # Extensions @testset "Data Transfer" begin ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) @@ -48,6 +49,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -72,6 +74,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG