Skip to content

Commit

Permalink
Merge pull request #8832 from JuliaLang/rf/rand-fillarray
Browse files Browse the repository at this point in the history
dSFMT: use fill_array_* API instead of genrand_* API
  • Loading branch information
ViralBShah committed Oct 30, 2014
2 parents da6087e + 3f6bd03 commit 06fb0a3
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 20 deletions.
19 changes: 19 additions & 0 deletions base/dSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export DSFMT_state, dsfmt_get_min_array_size, dsfmt_get_idstring,
dsfmt_genrand_close1_open2, dsfmt_gv_genrand_close1_open2,
dsfmt_genrand_close_open, dsfmt_gv_genrand_close_open,
dsfmt_genrand_uint32, dsfmt_gv_genrand_uint32,
dsfmt_fill_array_close_open!, dsfmt_fill_array_close1_open2!,
win32_SystemFunction036!

type DSFMT_state
Expand Down Expand Up @@ -95,6 +96,24 @@ function dsfmt_gv_genrand_uint32()
())
end

# precondition for dsfmt_fill_array_*:
# the underlying C array must be 16-byte aligned, which is the case for "Array"
function dsfmt_fill_array_close1_open2!(s::DSFMT_state, A::Array{Float64}, n::Int)
@assert dsfmt_min_array_size <= n <= length(A) && iseven(n)
ccall((:dsfmt_fill_array_close1_open2,:libdSFMT),
Void,
(Ptr{Void}, Ptr{Float64}, Int),
s.val, A, n)
end

function dsfmt_fill_array_close_open!(s::DSFMT_state, A::Array{Float64}, n::Int)
@assert dsfmt_min_array_size <= n <= length(A) && iseven(n)
ccall((:dsfmt_fill_array_close_open,:libdSFMT),
Void,
(Ptr{Void}, Ptr{Float64}, Int),
s.val, A, n)
end

## Windows entropy

@windows_only begin
Expand Down
102 changes: 82 additions & 20 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,43 @@ abstract AbstractRNG
type MersenneTwister <: AbstractRNG
state::DSFMT_state
seed::Union(Uint32,Vector{Uint32})
vals::Vector{Float64}
idx::Int

function MersenneTwister(seed::Vector{Uint32})
state = DSFMT_state()
dsfmt_init_by_array(state, seed)
return new(state, seed)
return new(state, seed, Array(Float64, dsfmt_get_min_array_size()), dsfmt_get_min_array_size())
end

MersenneTwister(seed=0) = MersenneTwister(make_seed(seed))
end

## Low level API for MersenneTwister

function gen_rand(r::MersenneTwister)
dsfmt_fill_array_close1_open2!(r.state, r.vals, length(r.vals))
r.idx = 0
end

@inline gen_rand_maybe(r::MersenneTwister) = r.idx == length(r.vals) && gen_rand(r)

# precondition: r.idx < length(r.vals)
@inline rand_close1_open2_inbounds(r::MersenneTwister) = (r.idx += 1; @inbounds return r.vals[r.idx])
@inline rand_inbounds(r::MersenneTwister) = rand_close1_open2_inbounds(r) - 1.0

# produce Float64 values
@inline rand_close1_open2(r::MersenneTwister) = (gen_rand_maybe(r); rand_close1_open2_inbounds(r))
@inline rand_close_open(r::MersenneTwister) = (gen_rand_maybe(r); rand_inbounds(r))

# this is similar to `dsfmt_genrand_uint32` from dSFMT.h:
@inline rand_ui32(r::MersenneTwister) = reinterpret(Uint64, rand_close1_open2(r)) % Uint32


function srand(r::MersenneTwister, seed)
r.seed = seed
dsfmt_init_gen_rand(r.state, seed)
r.idx = length(r.vals)
return r
end

Expand Down Expand Up @@ -60,8 +84,10 @@ __init__() = srand()
## srand()

function srand(seed::Vector{Uint32})
global RANDOM_SEED = seed
dsfmt_gv_init_by_array(seed)
GLOBAL_RNG.seed = seed
dsfmt_init_by_array(GLOBAL_RNG.state, seed)
GLOBAL_RNG.idx = length(GLOBAL_RNG.vals)
return GLOBAL_RNG
end
srand(n::Integer) = srand(make_seed(n))

Expand All @@ -86,28 +112,28 @@ function srand(filename::String, n::Integer)
end
srand(filename::String) = srand(filename, 4)

## Global RNG

const GLOBAL_RNG = MersenneTwister()
globalRNG() = GLOBAL_RNG

## random floating point values

rand(::Type{Float64}) = dsfmt_gv_genrand_close_open()
rand() = dsfmt_gv_genrand_close_open()
rand(r::MersenneTwister=GLOBAL_RNG) = rand_close_open(r)

rand(::Type{Float64}) = rand()

rand(::Type{Float32}) = float32(rand())
rand(::Type{Float16}) = float16(rand())

rand{T<:Real}(::Type{Complex{T}}) = complex(rand(T),rand(T))


rand(r::MersenneTwister) = dsfmt_genrand_close_open(r.state)

## random integers

dsfmt_randui32() = dsfmt_gv_genrand_uint32()
dsfmt_randui64() = uint64(dsfmt_randui32()) | (uint64(dsfmt_randui32())<<32)

rand(::Type{Uint8}) = rand(Uint32) % Uint8
rand(::Type{Uint16}) = rand(Uint32) % Uint16
rand(::Type{Uint32}) = dsfmt_randui32()
rand(::Type{Uint64}) = dsfmt_randui64()
rand(::Type{Uint32}) = rand_ui32(GLOBAL_RNG)
rand(::Type{Uint64}) = uint64(rand(Uint32)) <<32 | rand(Uint32)
rand(::Type{Uint128}) = uint128(rand(Uint64))<<64 | rand(Uint64)

rand(::Type{Int8}) = rand(Uint32) % Int8
Expand Down Expand Up @@ -142,6 +168,44 @@ function rand!{T}(r::AbstractRNG, A::AbstractArray{T})
A
end

function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float64})
n = length(A)
# what follows is equivalent to this simple loop but more efficient:
# for i=1:n
# @inbounds A[i] = rand(r)
# end
m = 0
while m < n
s = length(r.vals) - r.idx
if s == 0
gen_rand(r)
s = length(r.vals)
end
m2 = min(n, m+s)
for i=m+1:m2
@inbounds A[i] = rand_inbounds(r)
end
m = m2
end
A
end

rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A)

function rand!(r::MersenneTwister, A::Array{Float64})
n = length(A)
if n < dsfmt_get_min_array_size()
rand_AbstractArray_Float64!(r, A)
else
dsfmt_fill_array_close_open!(r.state, A, 2*(n ÷ 2))
isodd(n) && (A[n] = rand(r))
end
A
end

rand!(A::AbstractArray{Float64}) = rand!(GLOBAL_RNG, A)
rand!(A::Array{Float64}) = rand!(GLOBAL_RNG, A)

rand(T::Type, dims::Dims) = rand!(Array(T, dims))
rand{T<:Number}(::Type{T}) = error("no random number generator for type $T; try a more specific type")
rand{T<:Number}(::Type{T}, dims::Int...) = rand(T, dims)
Expand Down Expand Up @@ -241,7 +305,7 @@ rand!(B::BitArray) = Base.bitarray_rand_fill!(B)
randbool(dims::Dims) = rand!(BitArray(dims))
randbool(dims::Int...) = rand!(BitArray(dims))

randbool() = ((dsfmt_randui32() & 1) == 1)
randbool() = ((rand(Uint32) & 1) == 1)
rand(::Type{Bool}) = randbool()

## randn() - Normally distributed random numbers using Ziggurat algorithm
Expand Down Expand Up @@ -737,11 +801,9 @@ ziggurat_nor_r = 3.6541528853610087963519472518
ziggurat_nor_inv_r = inv(ziggurat_nor_r)
ziggurat_exp_r = 7.6971174701310497140446280481

rand(state::DSFMT_state) = dsfmt_genrand_close_open(state)
randi() = reinterpret(Uint64,dsfmt_gv_genrand_close1_open2()) & 0x000fffffffffffff
randi(state::DSFMT_state) = reinterpret(Uint64,dsfmt_genrand_close1_open2(state)) & 0x000fffffffffffff
@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff
for (lhs, rhs) in (([], []),
([:(state::DSFMT_state)], [:state]))
([:(rng::MersenneTwister)], [:rng]))
@eval begin
function randmtzig_randn($(lhs...))
@inbounds begin
Expand Down Expand Up @@ -787,9 +849,9 @@ for (lhs, rhs) in (([], []),
end

randn() = randmtzig_randn()
randn(rng::MersenneTwister) = randmtzig_randn(rng.state)
randn(rng::MersenneTwister) = randmtzig_randn(rng)
randn!(A::Array{Float64}) = (for i = 1:length(A);A[i] = randmtzig_randn();end;A)
randn!(rng::MersenneTwister, A::Array{Float64}) = (for i = 1:length(A);A[i] = randmtzig_randn(rng.state);end;A)
randn!(rng::MersenneTwister, A::Array{Float64}) = (for i = 1:length(A);A[i] = randmtzig_randn(rng);end;A)
randn(dims::Dims) = randn!(Array(Float64, dims))
randn(dims::Int...) = randn!(Array(Float64, dims...))
randn(rng::MersenneTwister, dims::Dims) = randn!(rng, Array(Float64, dims))
Expand Down

0 comments on commit 06fb0a3

Please sign in to comment.