Skip to content

Commit

Permalink
MCMCChain → MCMChains Changeover (#695)
Browse files Browse the repository at this point in the history
* Changed MCMCChain to MCMCChains.

* Rewrite of Turing native Chains functions.

* Add MCMCChains to REQUIRE file.

* Import save! and resume functions for the samplers.

* Fix save! and resume functionality.

* Removed names2inds import.

* Cleaning up.

* Changed :logevidence to :lp, as the previous functionality to handle this was removed.

* Modified SMC and IS to store logevidence, previously this was exp.(logevidence).

* The save! function now works with the NamedTuple info field.

* Adapted tests to use chn.logevidence, used to be chn[:logevidence].

* Updated REQUIRE to MCMCChains v0.3.0

* Removed save!, replaced with save as Chains are immutable.

* Made save function immutable.

* Removed unneeded test set, all are tested at MCMCChains.

* Added serialization tests to save/resume suite.

* Reincluded the io.jl test set for resumeable sampling.

* Removed vcat from utilities/io.jl

* Remove chain read/write tests.

* Changed logevidence in PMMH to -Inf

* Removed exp.(...) from PG

* Removed the correct exp.(...)
  • Loading branch information
cpfiffer authored and yebai committed Mar 5, 2019
1 parent bc8eb03 commit 6e6d7a7
Show file tree
Hide file tree
Showing 20 changed files with 78 additions and 320 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Requires 0.5.0
Reexport 0.2.0
Distributions 0.16.0
ForwardDiff 0.8.0
MCMCChain 0.1.1
MCMCChains 0.3.0
Libtask 0.2.5
Flux 0.6.7
MacroTools
Expand Down
2 changes: 1 addition & 1 deletion docs/src/_docs/get-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end
# Run sampler, collect results
chn = sample(gdemo(1.5, 2), HMC(1000, 0.1, 5))

# Summarise results (currently requires the master branch from MCMCChain)
# Summarise results (currently requires the master branch from MCMCChains)
describe(chn)

# Plot and save results
Expand Down
10 changes: 5 additions & 5 deletions docs/src/_docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ c5 = sample(gdemo(1.5, 2), HMCDA(1000, 0.15, 0.65))
c6 = sample(gdemo(1.5, 2), NUTS(1000, 0.65))
```

The `MCMCChain` module (which is re-exported by Turing) provides plotting tools for the `Chain` objects returned by a `sample` function. See the [MCMCChain](https://github.com/TuringLang/MCMCChain.jl) repository for more information on the suite of tools available for diagnosing MCMC chains.
The `MCMCChains` module (which is re-exported by Turing) provides plotting tools for the `Chain` objects returned by a `sample` function. See the [MCMCChains](https://github.com/TuringLang/MCMCChains.jl) repository for more information on the suite of tools available for diagnosing MCMC chains.

```julia
# Summarise results
Expand Down Expand Up @@ -315,18 +315,18 @@ The `Gibbs` sampler can be used to specify unique automatic differentation backe

For more details of compositional sampling in Turing.jl, please check the corresponding [paper](http://xuk.ai/assets/aistats2018-turing.pdf).

### Working with MCMCChain.jl
### Working with MCMCChains.jl

Turing.jl wraps its samples using `MCMCChain.Chain` so that all the functions working for `MCMCChain.Chain` can be re-used in Turing.jl. Two typical functions are `MCMCChain.describe` and `MCMCChain.plot`, which can be used as follows for an obtained chain `chn`. For more information on `MCMCChain`, please see the [GitHub repository](https://github.com/TuringLang/MCMCChain.jl).
Turing.jl wraps its samples using `MCMCChains.Chain` so that all the functions working for `MCMCChains.Chain` can be re-used in Turing.jl. Two typical functions are `MCMCChains.describe` and `MCMCChains.plot`, which can be used as follows for an obtained chain `chn`. For more information on `MCMCChains`, please see the [GitHub repository](https://github.com/TuringLang/MCMCChains.jl).

```julia
using MCMCChain: describe, plot
using MCMCChains: describe, plot

describe(chn) # Lists statistics of the samples.
plot(chn) # Plots statistics of the samples.
```

There are numerous functions in addition to `describe` and `plot` in the `MCMCChain` package, such as those used in convergence diagnostics. For more information on the package, please see the [GitHub repository](https://github.com/TuringLang/MCMCChain.jl).
There are numerous functions in addition to `describe` and `plot` in the `MCMCChains` package, such as those used in convergence diagnostics. For more information on the package, please see the [GitHub repository](https://github.com/TuringLang/MCMCChains.jl).

### Working with Libtask.jl

Expand Down
2 changes: 1 addition & 1 deletion docs/src/_docs/quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ permalink: /docs/quick-start/

If you are already well-versed in probabalistic programming and just want to take a quick look at how Turing's syntax works or otherwise just want a model to start with, we have provided a Bayesian coin-flipping model to play with.

This example can be run on however you have Julia installed (see [Getting Started](get-started.md)), but you will need to install the packages `Turing`, `Distributions`, `MCMCChain`, and `StatsPlots` if you have not done so already.
This example can be run on however you have Julia installed (see [Getting Started](get-started.md)), but you will need to install the packages `Turing`, `Distributions`, `MCMCChains`, and `StatsPlots` if you have not done so already.

This is an excerpt from a more formal example introducing probabalistic programming which can be found in Jupyter notebook form [here](https://nbviewer.jupyter.org/github/TuringLang/TuringTutorials/blob/master/0_Introduction.ipynb) or as part of the documentation website [here](../../tutorials/0-introduction/).

Expand Down
4 changes: 2 additions & 2 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ using Requires, Reexport, ForwardDiff
using Bijectors, StatsFuns, SpecialFunctions
using Statistics, LinearAlgebra, ProgressMeter
using Markdown, Libtask, MacroTools
@reexport using Distributions, MCMCChain, Libtask
@reexport using Distributions, MCMCChains, Libtask
using Flux.Tracker: Tracker

import Base: ~, convert, promote_rule, rand, getindex, setindex!
import Distributions: sample
import MCMCChain: AbstractChains, Chains
import MCMCChains: AbstractChains, Chains

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
Expand Down
24 changes: 12 additions & 12 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using StatsFuns: logsumexp

import Distributions: sample
import ..Core: getchunksize, getADtype
import ..Utilities: Sample
import ..Utilities: Sample, save, resume

export InferenceAlgorithm,
Hamiltonian,
Expand All @@ -21,20 +21,20 @@ export InferenceAlgorithm,
HamiltonianRobustInit,
SampleFromPrior,
AnySampler,
MH,
MH,
Gibbs, # classic sampling
HMC,
SGLD,
SGHMC,
HMCDA,
HMC,
SGLD,
SGHMC,
HMCDA,
NUTS, # Hamiltonian-like sampling
DynamicNUTS,
IS,
SMC,
CSMC,
PG,
PIMH,
PMMH,
IS,
SMC,
CSMC,
PG,
PIMH,
PMMH,
IPMCMC, # particle-based sampling
getspace,
assume,
Expand Down
4 changes: 2 additions & 2 deletions src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,12 @@ function sample(
@info(" Running time = $time_total;")

if resume_from != nothing # concat samples
pushfirst!(samples, resume_from.value2...)
pushfirst!(samples, resume_from.info[:samples]...)
end
c = Chain(0.0, samples) # wrap the result by Chain

if save_state # save state
save!(c, spl, model, varInfo)
c = save(c, spl, model, varInfo, samples)
end

return c
Expand Down
6 changes: 3 additions & 3 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function sample(model::Model, alg::Hamiltonian;
reuse_spl_n=0, # flag for spl re-using
adapt_conf=STAN_DEFAULT_ADAPT_CONF, # adapt configuration
)

spl = reuse_spl_n > 0 ?
resume_from.info[:spl] :
Sampler(alg, adapt_conf)
Expand Down Expand Up @@ -184,13 +184,13 @@ function sample(model::Model, alg::Hamiltonian;
end

if resume_from != nothing # concat samples
pushfirst!(samples, resume_from.value2...)
pushfirst!(samples, resume_from.info[:samples]...)
end
c = Chain(0.0, samples) # wrap the result by Chain
if save_state # save state
# Convert vi back to X if vi is required to be saved
if spl.alg.gid == 0 invlink!(vi, spl) end
save!(c, spl, model, vi)
c = save(c, spl, model, vi, samples)
end
return c
end
Expand Down
2 changes: 1 addition & 1 deletion src/inference/is.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function sample(model::Model, alg::IS)

le = logsumexp(map(x->x[:lp], samples)) - log(n)

Chain(exp.(le), samples)
Chain(le, samples)
end

function assume(spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi::VarInfo)
Expand Down
4 changes: 2 additions & 2 deletions src/inference/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,11 @@ function sample(model::Model, alg::MH;
println(" Accept rate = $accept_rate;")

if resume_from != nothing # concat samples
pushfirst!(samples, resume_from.value2...)
pushfirst!(samples, resume_from.info[:samples]...)
end
c = Chain(0.0, samples) # wrap the result by Chain
if save_state # save state
save!(c, spl, model, vi)
c = save(c, spl, model, vi, samples)
end

c
Expand Down
34 changes: 17 additions & 17 deletions src/inference/pgibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ function step(model, spl::Sampler{<:PG}, vi::VarInfo)
return particles[indx].vi, true
end

function sample( model::Model,
function sample( model::Model,
alg::PG;
save_state=false, # flag for state saving
resume_from=nothing, # chain to continue
reuse_spl_n=0 # flag for spl re-using
)

spl = reuse_spl_n > 0 ?
resume_from.info[:spl] :
Sampler(alg)
Expand Down Expand Up @@ -128,24 +128,24 @@ function sample( model::Model,

loge = exp.(mean(spl.info[:logevidence]))
if resume_from != nothing # concat samples
pushfirst!(samples, resume_from.value2...)
pre_loge = resume_from.weight
pushfirst!(samples, resume_from.info[:samples]...)
pre_loge = exp.(resume_from.logevidence)
# Calculate new log-evidence
pre_n = length(resume_from.value2)
loge = exp.((log(pre_loge) * pre_n + log(loge) * n) / (pre_n + n))
pre_n = length(resume_from.info[:samples])
loge = (log(pre_loge) * pre_n + log(loge) * n) / (pre_n + n)
end
c = Chain(loge, samples) # wrap the result by Chain

if save_state # save state
save!(c, spl, model, vi)
c = save(c, spl, model, vi, samples)
end

return c
end

function assume( spl::Sampler{T},
dist::Distribution,
vn::VarName,
function assume( spl::Sampler{T},
dist::Distribution,
vn::VarName,
_::VarInfo
) where T<:Union{PG,SMC}

Expand Down Expand Up @@ -177,10 +177,10 @@ function assume( spl::Sampler{T},
return r, zero(Real)
end

function assume( spl::Sampler{A},
dists::Vector{D},
vn::VarName,
var::Any,
function assume( spl::Sampler{A},
dists::Vector{D},
vn::VarName,
var::Any,
vi::VarInfo
) where {A<:Union{PG,SMC},D<:Distribution}
error("[Turing] PG and SMC doesn't support vectorizing assume statement")
Expand All @@ -191,9 +191,9 @@ function observe(spl::Sampler{T}, dist::Distribution, value, vi) where T<:Union{
return zero(Real)
end

function observe( spl::Sampler{A},
ds::Vector{D},
value::Any,
function observe( spl::Sampler{A},
ds::Vector{D},
value::Any,
vi::VarInfo
) where {A<:Union{PG,SMC},D<:Distribution}
error("[Turing] PG and SMC doesn't support vectorizing observe statement")
Expand Down
8 changes: 4 additions & 4 deletions src/inference/pmmh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function step(model, spl::Sampler{<:PMMH}, vi::VarInfo, is_first::Bool)
return vi, is_accept
end

function sample( model::Model,
function sample( model::Model,
alg::PMMH;
save_state=false, # flag for state saving
resume_from=nothing, # chain to continue
Expand Down Expand Up @@ -170,12 +170,12 @@ function sample( model::Model,
println(" Accept rate = $accept_rate;")

if resume_from != nothing # concat samples
pushfirst!(samples, resume_from.value2...)
pushfirst!(samples, resume_from.info[:samples]...)
end
c = Chain(0.0, samples) # wrap the result by Chain
c = Chain(-Inf, samples) # wrap the result by Chain

if save_state # save state
save!(c, spl, model, vi)
c = save(c, spl, model, vi, samples)
end

c
Expand Down
2 changes: 1 addition & 1 deletion src/inference/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ function sample(model::Model, alg::SMC)
end
end
w, samples = getsample(particles)
res = Chain(w, samples)
res = Chain(log(w), samples)
return res
end
2 changes: 1 addition & 1 deletion src/utilities/Utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Utilities
using ..Turing: Sampler
using Distributions, Bijectors
using StatsFuns, SpecialFunctions
using MCMCChain: AbstractChains, Chains, names2inds
using MCMCChains: AbstractChains, Chains, setinfo
import Distributions: sample

export resample,
Expand Down
Loading

0 comments on commit 6e6d7a7

Please sign in to comment.