Skip to content

Commit

Permalink
Deprecate mulithreading in bootstrap (#674)
Browse files Browse the repository at this point in the history
* make use_threads a noop in bootstrap

* test deprecation warning instead of results

* NEWS

* version bump

* drop references in docs

* remove threading from replicate as well
  • Loading branch information
palday authored Apr 11, 2023
1 parent c68bb4b commit e5d8a6a
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 82 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
MixedModels v4.10.0 Release Notes
==============================
* Multithreading in `parametricbootstrap` with `use_threads` is now deprecated and a noop. With improvements in BLAS threading, multithreading at the Julia level did not help performance and sometimes hurt it. [#674]

MixedModels v4.9.0 Release Notes
==============================
* Support `StatsModels` 0.7, drop support for `StatsModels` 0.6. [#664]
Expand Down Expand Up @@ -400,3 +404,4 @@ Package dependencies
[#664]: https://github.com/JuliaStats/MixedModels.jl/issues/664
[#665]: https://github.com/JuliaStats/MixedModels.jl/issues/665
[#667]: https://github.com/JuliaStats/MixedModels.jl/issues/667
[#674]: https://github.com/JuliaStats/MixedModels.jl/issues/674
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MixedModels"
uuid = "ff71e718-51f3-5ec2-a782-8ffcbfa3c316"
author = ["Phillip Alday <[email protected]>", "Douglas Bates <[email protected]>", "Jose Bayoan Santiago Calderon <[email protected]>"]
version = "4.9.0"
version = "4.10.0"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/bootstrap.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ m2 = fit(
DisplayAs.Text(ans) # hide
```
```@example Main
samp2 = parametricbootstrap(rng, 10_000, m2, use_threads=true);
samp2 = parametricbootstrap(rng, 10_000, m2);
df2 = DataFrame(samp2.allpars);
first(df2, 10)
```
Expand Down
54 changes: 13 additions & 41 deletions src/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end

"""
parametricbootstrap([rng::AbstractRNG], nsamp::Integer, m::MixedModel{T}, ftype=T;
β = coef(m), σ = m.σ, θ = m.θ, use_threads=false, hide_progress=false)
β = coef(m), σ = m.σ, θ = m.θ, hide_progress=false)
Perform `nsamp` parametric bootstrap replication fits of `m`, returning a `MixedModelBootstrap`.
Expand All @@ -52,20 +52,8 @@ performance benefits.
- `β`, `σ`, and `θ` are the values of `m`'s parameters for simulating the responses.
- `σ` is only valid for `LinearMixedModel` and `GeneralizedLinearMixedModel` for
families with a dispersion parameter.
- `use_threads` determines whether or not to use thread-based parallelism.
- `hide_progress` can be used to disable the progress bar. Note that the progress
bar is automatically disabled for non-interactive (i.e. logging) contexts.
!!! note
Note that `use_threads=true` may not offer a performance boost and may even
decrease performance if multithreaded linear algebra (BLAS) routines are available.
In this case, threads at the level of the linear algebra may already occupy all
processors/processor cores. There are plans to provide better support in coordinating
Julia- and BLAS-level threads in the future.
!!! warning
The PRNG shared between threads is locked using `Threads.SpinLock`, which
should not be used recursively. Do not wrap `parametricbootstrap` in an outer `SpinLock`.
"""
function parametricbootstrap(
rng::AbstractRNG,
Expand All @@ -88,35 +76,19 @@ function parametricbootstrap(

β_names = (Symbol.(fixefnames(morig))...,)

# we need arrays of these for in-place operations to work across threads
m_threads = [m]
βsc_threads = [βsc]
θsc_threads = [θsc]

if use_threads
Threads.resize_nthreads!(m_threads)
Threads.resize_nthreads!(βsc_threads)
Threads.resize_nthreads!(θsc_threads)
end
# we use locks to guarantee thread-safety, but there might be better ways to do this for some RNGs
# see https://docs.julialang.org/en/v1.3/manual/parallel-computing/#Side-effects-and-mutable-function-arguments-1
# see https://docs.julialang.org/en/v1/stdlib/Future/index.html
rnglock = Threads.SpinLock()
samp = replicate(n; use_threads=use_threads, hide_progress=hide_progress) do
tidx = use_threads ? Threads.threadid() : 1
mod = m_threads[tidx]
local βsc = βsc_threads[tidx]
local θsc = θsc_threads[tidx]
lock(rnglock)
mod = simulate!(rng, mod; β=β, σ=σ, θ=θ)
unlock(rnglock)
refit!(mod; progress=false)
use_threads && Base.depwarn(
"use_threads is deprecated and will be removed in a future release",
:parametricbootstrap,
)
samp = replicate(n; hide_progress=hide_progress) do
simulate!(rng, m; β, σ, θ)
refit!(m; progress=false)
(
objective=ftype.(mod.objective),
σ=ismissing(mod.σ) ? missing : ftype(mod.σ),
β=NamedTuple{β_names}(fixef!(βsc, mod)),
se=SVector{p,ftype}(stderror!(βsc, mod)),
θ=SVector{k,ftype}(getθ!(θsc, mod)),
objective=ftype.(m.objective),
σ=ismissing(m.σ) ? missing : ftype(m.σ),
β=NamedTuple{β_names}(fixef!(βsc, m)),
se=SVector{p,ftype}(stderror!(βsc, m)),
θ=SVector{k,ftype}(getθ!(θsc, m)),
)
end
return MixedModelBootstrap{ftype}(
Expand Down
27 changes: 9 additions & 18 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,39 +124,30 @@ end
_is_logging(io) = isa(io, Base.TTY) == false || (get(ENV, "CI", nothing) == "true")

"""
replicate(f::Function, n::Integer; use_threads=false)
replicate(f::Function, n::Integer; hide_progress=false)
Return a vector of the values of `n` calls to `f()` - used in simulations where the value of `f` is stochastic.
`hide_progress` can be used to disable the progress bar. Note that the progress
bar is automatically disabled for non-interactive (i.e. logging) contexts.
!!! warning
If `f()` is not thread-safe or depends on a non thread-safe RNG,
then you must set `use_threads=false`. Also note that ordering of replications
is not guaranteed when `use_threads=true`, although the replications are not
otherwise affected for thread-safe `f()`.
"""
function replicate(f::Function, n::Integer; use_threads=false, hide_progress=false)
# no macro version yet: https://github.com/timholy/ProgressMeter.jl/issues/143
use_threads && Base.depwarn(
"use_threads is deprecated and will be removed in a future release",
:replicate,
)
# and we want some advanced options
p = Progress(n; output=Base.stderr, enabled=!hide_progress && !_is_logging(stderr))
# get the type
rr = f()
next!(p)
# pre-allocate
results = [rr for _ in Base.OneTo(n)]
if use_threads
Threads.@threads for idx in 2:n
results[idx] = f()
next!(p)
end
else
for idx in 2:n
results[idx] = f()
next!(p)
end
for idx in 2:n
results[idx] = f()
next!(p)
end
finish!(p)
return results
end

Expand Down
12 changes: 2 additions & 10 deletions test/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,8 @@ end
@test propertynames(coefp) == [:iter, :coefname, , :se, :z, :p]

@testset "threaded bootstrap" begin
bsamp_threaded = parametricbootstrap(MersenneTwister(1234321), 100, fm;
use_threads=true, hide_progress=true)
# even though it's bad practice with floating point, exact equality should
# be a valid test here -- if everything is working right, then it's the exact
# same operations occurring within each bootstrap sample, which IEEE 754
# guarantees will yield the same result
@test sort(bsamp_threaded.σ) == sort(bsamp.σ)
@test sort(bsamp_threaded.θ) == sort(bsamp.θ)
@test sort(columntable(bsamp_threaded.β).β) == sort(columntable(bsamp.β).β)
@test sum(issingular(bsamp)) == sum(issingular(bsamp_threaded))
@test_logs (:warn, r"use_threads is deprecated") parametricbootstrap(MersenneTwister(1234321), 1, fm;
use_threads=true, hide_progress=true)
end

@testset "zerocorr + Base.length + ftype" begin
Expand Down
12 changes: 1 addition & 11 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,7 @@ end
end

@testset "threaded_replicate" begin
rng = StableRNG(42);
single_thread = replicate(10;use_threads=false) do; only(randn(rng, 1)) ; end
rng = StableRNG(42);
multi_thread = replicate(10;use_threads=true) do
if Threads.threadid() % 2 == 0
sleep(0.001)
end
r = only(randn(rng, 1));
end

@test all(sort!(single_thread) .≈ sort!(multi_thread))
@test_logs (:warn, r"use_threads is deprecated") replicate(string, 1; use_threads=true)
end

@testset "datasets" begin
Expand Down

0 comments on commit e5d8a6a

Please sign in to comment.