Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Support for scalars.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Aug 23, 2019
1 parent db0715a commit 0b49741
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/rand/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ rand_logn(rng::RNG, T::LognormalType, dim1::Integer, dims::Integer...; kwargs...
rand_poisson(rng::RNG, T::PoissonType, dim1::Integer, dims::Integer...; kwargs...) =
rand_poisson(rng, T, Dims((dim1, dims...)); kwargs...)

# scalar (slow, but provided for consistency)
Random.rand(rng::RNG, T::UniformType=Float32) =
Random.rand!(rng, CuArray{T}(undef, 1))[1]
Random.randn(rng::RNG, T::NormalType=Float32; kwargs...) =
Random.randn!(rng, CuArray{T}(undef, 2); kwargs...)[1]
rand_logn(rng::RNG, T::LognormalType=Float32; kwargs...) =
rand_logn!(rng, CuArray{T}(undef, 2); kwargs...)[1]
rand_poisson(rng::RNG, T::PoissonType=Cuint; kwargs...) =
rand_poisson!(rng, CuArray{T}(undef, 1); kwargs...)[1]


## functions that dispatch to either CURAND or GPUArrays

Expand All @@ -121,6 +131,12 @@ rand_logn(T::LognormalType, dim1::Integer, dims::Integer...; kwargs...) =
rand_poisson(T::PoissonType, dim1::Integer, dims::Integer...; kwargs...) =
rand_poisson(generator(), T, Dims((dim1, dims...)); kwargs...)

# scalar
rand(T::UniformType=Float32) = Random.rand(generator(), T)
randn(T::NormalType=Float32; kwargs...) = Random.randn(generator(), T; kwargs...)
rand_logn(T::LognormalType=Float32; kwargs...) = rand_logn(generator(), T; kwargs...)
rand_poisson(T::PoissonType=Cuint; kwargs...) = rand_poisson(generator(), T; kwargs...)

# GPUArrays in-place
Random.rand!(A::CuArray) = Random.rand!(GPUArrays.global_rng(A), A)
Random.randn!(A::CuArray; kwargs...) =
Expand All @@ -146,6 +162,12 @@ rand_logn(T::Type, dim1::Integer, dims::Integer...; kwargs...) =
rand_poisson(T::Type, dim1::Integer, dims::Integer...; kwargs...) =
rand_poisson!(CuArray{T}(undef, dim1, dims...); kwargs...)

# scalar (slow, but provided for consistency)
rand(T::Type) = Random.rand!(CuArray{T}(undef, 1))[1]
randn(T::Type; kwargs...) = Random.randn!(CuArray{T}(undef, 1); kwargs...)[1]
rand_logn(T::Type; kwargs...) = rand_logn!(CuArray{T}(undef, 1); kwargs...)[1]
rand_poisson(T::Type; kwargs...) = rand_poisson!(CuArray{T}(undef, 1); kwargs...)[1]

# untyped out-of-place
rand(dim1::Integer, dims::Integer...) =
Random.rand(generator(), Dims((dim1, dims...)))
Expand Down
14 changes: 14 additions & 0 deletions test/rand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,18 @@ end
@test_throws ErrorException rand_logn!(CuArray{Cuint}(undef, 10))
@test_throws ErrorException rand_poisson!(CuArray{Float64}(undef, 10))

# scalars
@allowscalar begin
for f in (CuArrays.rand, CuArrays.randn, CuArrays.rand_logn, CuArrays.rand_poisson)
@test isa(f(), Number)
end
for (f,T) in ((CuArrays.rand,Float32), (CuArrays.randn,Float32), (CuArrays.rand_logn,Float32),
(CuArrays.rand,Float64), (CuArrays.randn,Float64), (CuArrays.rand_logn,Float64),
(CuArrays.rand_poisson,Cuint),
(rand,Float32), (randn,Float32),
(rand,Float64), (randn,Float64))
@test isa(f(T), T)
end
end

end

0 comments on commit 0b49741

Please sign in to comment.