diff --git a/src/rand/highlevel.jl b/src/rand/highlevel.jl index 3807d22d..995a96bf 100644 --- a/src/rand/highlevel.jl +++ b/src/rand/highlevel.jl @@ -96,6 +96,16 @@ rand_logn(rng::RNG, T::LognormalType, dim1::Integer, dims::Integer...; kwargs... rand_poisson(rng::RNG, T::PoissonType, dim1::Integer, dims::Integer...; kwargs...) = rand_poisson(rng, T, Dims((dim1, dims...)); kwargs...) +# scalar (slow, but provided for consistency) +Random.rand(rng::RNG, T::UniformType=Float32) = + Random.rand!(rng, CuArray{T}(undef, 1))[1] +Random.randn(rng::RNG, T::NormalType=Float32; kwargs...) = + Random.randn!(rng, CuArray{T}(undef, 2); kwargs...)[1] +rand_logn(rng::RNG, T::LognormalType=Float32; kwargs...) = + rand_logn!(rng, CuArray{T}(undef, 2); kwargs...)[1] +rand_poisson(rng::RNG, T::PoissonType=Cuint; kwargs...) = + rand_poisson!(rng, CuArray{T}(undef, 1); kwargs...)[1] + ## functions that dispatch to either CURAND or GPUArrays @@ -121,6 +131,12 @@ rand_logn(T::LognormalType, dim1::Integer, dims::Integer...; kwargs...) = rand_poisson(T::PoissonType, dim1::Integer, dims::Integer...; kwargs...) = rand_poisson(generator(), T, Dims((dim1, dims...)); kwargs...) +# scalar +rand(T::UniformType=Float32) = Random.rand(generator(), T) +randn(T::NormalType=Float32; kwargs...) = Random.randn(generator(), T; kwargs...) +rand_logn(T::LognormalType=Float32; kwargs...) = rand_logn(generator(), T; kwargs...) +rand_poisson(T::PoissonType=Cuint; kwargs...) = rand_poisson(generator(), T; kwargs...) + # GPUArrays in-place Random.rand!(A::CuArray) = Random.rand!(GPUArrays.global_rng(A), A) Random.randn!(A::CuArray; kwargs...) = @@ -146,6 +162,12 @@ rand_logn(T::Type, dim1::Integer, dims::Integer...; kwargs...) = rand_poisson(T::Type, dim1::Integer, dims::Integer...; kwargs...) = rand_poisson!(CuArray{T}(undef, dim1, dims...); kwargs...) +# scalar (slow, but provided for consistency) +rand(T::Type) = Random.rand!(CuArray{T}(undef, 1))[1] +randn(T::Type; kwargs...) = Random.randn!(CuArray{T}(undef, 1); kwargs...)[1] +rand_logn(T::Type; kwargs...) = rand_logn!(CuArray{T}(undef, 1); kwargs...)[1] +rand_poisson(T::Type; kwargs...) = rand_poisson!(CuArray{T}(undef, 1); kwargs...)[1] + # untyped out-of-place rand(dim1::Integer, dims::Integer...) = Random.rand(generator(), Dims((dim1, dims...))) diff --git a/test/rand.jl b/test/rand.jl index 6cb93305..4e453709 100644 --- a/test/rand.jl +++ b/test/rand.jl @@ -52,4 +52,18 @@ end @test_throws ErrorException rand_logn!(CuArray{Cuint}(undef, 10)) @test_throws ErrorException rand_poisson!(CuArray{Float64}(undef, 10)) +# scalars +@allowscalar begin + for f in (CuArrays.rand, CuArrays.randn, CuArrays.rand_logn, CuArrays.rand_poisson) + @test isa(f(), Number) + end + for (f,T) in ((CuArrays.rand,Float32), (CuArrays.randn,Float32), (CuArrays.rand_logn,Float32), + (CuArrays.rand,Float64), (CuArrays.randn,Float64), (CuArrays.rand_logn,Float64), + (CuArrays.rand_poisson,Cuint), + (rand,Float32), (randn,Float32), + (rand,Float64), (randn,Float64)) + @test isa(f(T), T) + end +end + end