This repository has been archived by the owner on Sep 10, 2023. It is now read-only.

Update quap to Turing's MAP #26

Merged 3 commits on May 27, 2020
91 changes: 17 additions & 74 deletions src/quap_turing.jl
@@ -1,96 +1,39 @@
-# Everything is very experimental! If you want to use it, you need to import it yourself.
+# Everything is experimental! If you want to use it, you need to import it yourself.
# If you find problems, please open an issue.

-using Optim, NLSolversBase
-using DynamicPPL
+using Optim
using Turing
using StatsBase
using LinearAlgebra

-# Get the NLL function from a Turing model. Taken from:
-# https://turing.ml/dev/docs/using-turing/advanced#maximum-a-posteriori-estimation
-function get_nlogp(model)
-    # Construct a trace struct
-    vi = Turing.VarInfo(model)
-
-    # Define a function to optimize.
-    function nlogp(sm)
-        spl = Turing.SampleFromPrior()
-        new_vi = Turing.VarInfo(vi, spl, sm)
-        model(new_vi, spl)
-        -Turing.getlogp(new_vi)
-    end
-
-    return nlogp
-end


-# Run like
-# using Turing
-# @model height(heights) = begin
-#     μ ~ Normal(178, 20)
-#     σ ~ Uniform(0, 50)
-#     heights .~ Normal(μ, σ)
-#     return μ, σ    <-- you need this line or you need to provide a start point
-# end
-# m = height(d2.height)
-# res = quap(m)
-# MvNormal(res.coef, res.vcov)
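The removed get_nlogp helper turns a model like this into a closure over a flat parameter vector; a minimal sketch of calling it directly (the parameter values are made up, d2 as above):

m = height(d2.height)
nlogp = get_nlogp(m)
nlogp([178.0, 10.0])  # negative log posterior at μ = 178.0, σ = 10.0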

-# To find a good starting point, try to sample from the prior
-function quap(model::DynamicPPL.Model; method = SimulatedAnnealing())
-    # Find out if sampling from the prior is possible.
-    # Something like `return μ` or `return μ, σ` will give you a Number or a Tuple,
-    # while leaving it out will return the data, usually an Array. This isn't
-    # bullet-proof (e.g. you can pass a single number as data) but I don't have
-    # anything more sophisticated right now.
-    testprior = typeof(model())
-    if !(testprior <: Number || testprior <: Tuple)
-        error("Your model must either include a return statement to sample from the prior or you must provide a start point: `quap(model, start; method)`")
-    end
-
-    priors = [[model()...] for _ in 1:100]  # Tuples -> Arrays
-    start = median(hcat(priors...), dims = 2)
-    start = [start...]  # 2x1 Matrix -> Vector
-
-    quap(model, start; method = method)
-end
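To see the start-point heuristic in action, a sketch assuming, as the comment above says, that calling the model object runs it once and yields its return value:

m = height(d2.height)   # model with `return μ, σ`
typeof(m())             # Tuple{Float64, Float64}, so prior sampling works
# Without the return statement the call yields the simulated data instead
# (usually an Array), and quap insists on an explicit start point.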

-# Find the MAP via optimization and take the hessian at that point.
-# During a bit of simple testing SimulatedAnnealing did the best job getting close
-# to the minimum while NelderMead was good at finishing the job. So if
-# SimulatedAnnealing didn't converge or your supplied method errors, try again
-# with NelderMead (or BFGS in the 1D case).
+# Find the MAP via optimization and get the information matrix at that point.
# Check whether your solution converged. Sometimes even solutions that didn't
# converge might be pretty good; on the other hand, just because the solver
# converged doesn't mean you got the point you are looking for; this isn't even global
-# optimization. In any case, trying other methods or starting points might help.
-# Adapted from:
-# https://julianlsolvers.github.io/Optim.jl/stable/#examples/generated/maxlikenlm/
-function quap(model::DynamicPPL.Model, start; method = SimulatedAnnealing())
-    nlogp = get_nlogp(model)
-    func = TwiceDifferentiable(vars -> nlogp(vars), start; autodiff = :forward)
+# optimization. In any case, trying other optimization methods might help.
+function quap(model::Turing.Model, args...; kwargs...)
+    opt = optimize(model, MAP(), args...; kwargs...)

-    converged = true
-    MAP = start
+    coef = opt.values.array
+    var_cov_matrix = informationmatrix(opt)
+    sym_var_cov_matrix = Symmetric(var_cov_matrix)  # lest MvNormal complain, loudly
+    converged = Optim.converged(opt.optim_result)

-    methods = [
-        method,
-        length(start) == 1 ? BFGS() : NelderMead(),  # NelderMead doesn't work in 1D
-    ]
-    for method in methods
-        try
-            opt = optimize(func, start, method)
-            MAP = Optim.minimizer(opt)
-            converged = Optim.converged(opt)
-        catch
-            converged = false
-        end
-        converged && break
-    end
+    distr = if length(coef) == 1
+        Normal(coef[1], √sym_var_cov_matrix[1])  # Normal expects the stddev
+    else
+        MvNormal(coef, sym_var_cov_matrix)  # MvNormal expects the variance matrix
+    end

-    numerical_hessian = hessian!(func, MAP)
-    var_cov_matrix = inv(numerical_hessian)
-    sym_var_cov_matrix = Symmetric(var_cov_matrix)  # lest MvNormal complain, loudly
-
-    (coef = MAP, vcov = sym_var_cov_matrix, converged = converged)
+    (coef = coef, vcov = sym_var_cov_matrix, converged = converged, distr = distr)
end
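A usage sketch for the new signature, assuming Turing's mode-estimation interface from this period (optimize(model, MAP(), ...)) and reusing the height example from the removed comment; d2 and the choice of NelderMead are assumptions, since any extra arguments are simply forwarded to optimize:

using Optim, Turing

@model height(heights) = begin
    μ ~ Normal(178, 20)
    σ ~ Uniform(0, 50)
    heights .~ Normal(μ, σ)
end

res = quap(height(d2.height), NelderMead())  # extra args go straight to optimize
res.converged || @warn "MAP optimization did not converge"
res.distr  # Normal or MvNormal centered on the MAP, ready to sample from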