Skip to content

Commit

Permalink
make Sampler{E} encode the type E of elements which are generated
Browse files Browse the repository at this point in the history
Before, a call like `rand(mm, Sampler(mm, 1:10), 3)`
generated an `Array{Any,1}`, so a way to get the `eltype`
of a Sampler is necessary. Instead of changing Sampler -> Sampler{E},
implementing appropriate eltype methods would have been possible,
to keep the helper Sampler subtypes more flexible, but it seemed
to be simpler this way.
  • Loading branch information
rfourquet committed Dec 22, 2017
1 parent 12d756b commit 7f2f88a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
13 changes: 7 additions & 6 deletions base/random/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen01_64}) = rand(r, CloseOpen12())
const bits_in_Limb = sizeof(Limb) << 3
const Limb_high_bit = one(Limb) << (bits_in_Limb-1)

struct SamplerBigFloat{I<:FloatInterval{BigFloat}} <: Sampler
struct SamplerBigFloat{I<:FloatInterval{BigFloat}} <: Sampler{BigFloat}
prec::Int
nlimbs::Int
limbs::Vector{Limb}
Expand Down Expand Up @@ -155,7 +155,7 @@ uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128

#### Fast

struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler{T}
a::T # first element of the range
bw::UInt # bit width
m::U # range length - 1
Expand Down Expand Up @@ -215,7 +215,7 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
div(sup, k + (k == 0))*k - one(k)

struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler{T}
a::T # first element of the range
bw::Int # bit width
k::U # range length or zero for full range
Expand Down Expand Up @@ -270,7 +270,7 @@ end

### BigInt

struct SamplerBigInt <: Sampler
struct SamplerBigInt <: Sampler{BigInt}
a::BigInt # first
m::BigInt # range length - 1
nlimbs::Int # number of limbs in generated BigInt's (z ∈ [0, m])
Expand Down Expand Up @@ -336,9 +336,10 @@ end

## random values from Set

Sampler(rng::AbstractRNG, t::Set, n::Repetition) = SamplerTag{Set}(Sampler(rng, t.dict, n))
Sampler(rng::AbstractRNG, t::Set{T}, n::Repetition) where {T} =
SamplerTag{Set{T}}(Sampler(rng, t.dict, n))

rand(rng::AbstractRNG, sp::SamplerTag{Set,<:Sampler}) = rand(rng, sp.data).first
rand(rng::AbstractRNG, sp::SamplerTag{<:Set,<:Sampler}) = rand(rng, sp.data).first

## random values from BitSet

Expand Down
32 changes: 21 additions & 11 deletions base/random/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export srand,

abstract type AbstractRNG end


### integers

# we define types which encode the generation of a specific number of bits
Expand Down Expand Up @@ -83,7 +84,9 @@ const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}}

### Sampler

abstract type Sampler end
abstract type Sampler{E} end

Base.eltype(::Sampler{E}) where {E} = E

# temporarily for BaseBenchmarks
RangeGenerator(x) = Sampler(GLOBAL_RNG, x)
Expand All @@ -109,41 +112,48 @@ Sampler(rng::AbstractRNG, ::Type{X}) where {X} = Sampler(rng, X, Val(Inf))
#### pre-defined useful Sampler types

# default fall-back for types
struct SamplerType{T} <: Sampler end
struct SamplerType{T} <: Sampler{T} end

Sampler(::AbstractRNG, ::Type{T}, ::Repetition) where {T} = SamplerType{T}()

Base.getindex(sp::SamplerType{T}) where {T} = T
Base.getindex(::SamplerType{T}) where {T} = T

# default fall-back for values
struct SamplerTrivial{T} <: Sampler
struct SamplerTrivial{T,E} <: Sampler{E}
self::T
end

Sampler(::AbstractRNG, X, ::Repetition) = SamplerTrivial(X)
SamplerTrivial(x::T) where {T} = SamplerTrivial{T,eltype(T)}(x)

Sampler(::AbstractRNG, x, ::Repetition) = SamplerTrivial(x)

Base.getindex(sp::SamplerTrivial) = sp.self

# simple sampler carrying data (which can be anything)
struct SamplerSimple{T,S} <: Sampler
struct SamplerSimple{T,S,E} <: Sampler{E}
self::T
data::S
end

SamplerSimple(x::T, data::S) where {T,S} = SamplerSimple{T,S,eltype(T)}(x, data)

Base.getindex(sp::SamplerSimple) = sp.self

# simple sampler carrying a (type) tag T and data
struct SamplerTag{T,S} <: Sampler
struct SamplerTag{T,S,E} <: Sampler{E}
data::S
SamplerTag{T}(s::S) where {T,S} = new{T,S}(s)
SamplerTag{T}(s::S) where {T,S} = new{T,S,eltype(T)}(s)
end


#### helper samplers

# TODO: make constraining constructors to enforce that those
# types are <: Sampler{T}

##### Adapter to generate a randome value in [0, n]

struct LessThan{T<:Integer,S} <: Sampler
struct LessThan{T<:Integer,S} <: Sampler{T}
sup::T
s::S # the scalar specification/sampler to feed to rand
end
Expand All @@ -155,7 +165,7 @@ function rand(rng::AbstractRNG, sp::LessThan)
end
end

struct Masked{T<:Integer,S} <: Sampler
struct Masked{T<:Integer,S} <: Sampler{T}
mask::T
s::S
end
Expand All @@ -164,7 +174,7 @@ rand(rng::AbstractRNG, sp::Masked) = rand(rng, sp.s) & sp.mask

##### Uniform

struct UniformT{T} <: Sampler end
struct UniformT{T} <: Sampler{T} end

uniform(::Type{T}) where {T} = UniformT{T}()

Expand Down

0 comments on commit 7f2f88a

Please sign in to comment.