This repository has been archived by the owner on Sep 10, 2023. It is now read-only.

Add quap for Turing #24

Merged 2 commits into StatisticalRethinkingJulia:master on May 20, 2020

Conversation

karajan9
Contributor

This takes a Turing model and performs quadratic approximation: it finds the maximum of the log-likelihood (or the minimum of the negative log-likelihood, NLL) and takes the Hessian at that point to get the variances and covariances of the parameters.

The goal is to make it just work™, which is why I put some tricks in there, like sampling from the prior to find a starting point and chaining optimizers to make sure it finds a good optimum. I tried to put some checks in there; still, there are probably many ways to make this break.

On the simple model in the file it takes ~5 ms on my computer, so I didn't spend much time making it any faster.
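
Roughly, the idea looks like this (quap_sketch and nll are illustrative names, not the exact code in this PR; nll stands in for a closure returning the model's negative log-likelihood):

using Optim, ForwardDiff, LinearAlgebra

# Sketch only: chain a global and a local optimizer to land in a good
# optimum, then take the Hessian at the mode for the (co)variances.
function quap_sketch(nll, start)
    opt = optimize(nll, start, SimulatedAnnealing())
    opt = optimize(nll, Optim.minimizer(opt), BFGS(); autodiff = :forward)
    coef = Optim.minimizer(opt)
    # Inverse Hessian at the mode = covariance of the quadratic (Gaussian) approximation.
    H = ForwardDiff.hessian(nll, coef)
    (coef = coef, vcov = Symmetric(inv(H)), converged = Optim.converged(opt))
end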

Since I consider this very experimental, there is still a list of things that need to be done:

  • Importing the file
  • Add dependencies
  • Tests and testing this with more than one or two models
  • Docstrings
  • Formatting? The existing code mixes 2-space indents, 4-space indents, and tabs; I think it would be good to unify this.
  • Add more features, like calling `precis`

@goedman goedman merged commit 6285781 into StatisticalRethinkingJulia:master May 20, 2020
@goedman
Member

goedman commented May 20, 2020

@karajan9 I definitely like this approach.

Do you think (or know if) it is possible to get hold of the mu & sigma in the prior and simply use those as start values for quap? E.g.:

quap(m, [178.0,20.0])
(coef = [154.60702994440686, 7.731331556591009], vcov = [0.16973954614799505 0.00021817160619773412; 0.00021817160619773412 0.08490575274440608], converged = true)

It would be nice if we could have just a single Turing model definition, instead of a version augmented with a return statement just for quap.

I think Turing.VarInfo(m) holds the mu & sigma.

I remember that somewhere in the book Richard mentions quap might sometimes need some coaching.

Rob

@karajan9
Contributor Author

karajan9 commented May 20, 2020

Yes, this would be nicer. Right now, with the model in the file (I guess that's m4.1 or something?), return μ, σ is okay, while all of

return σ, μ
return σ
return μ
nothing

won't work correctly (I think, haven't tested).
(Although according to the Turing docs this is the way to go.)

It's pretty easy to get the distributions needed:

vi = Turing.VarInfo(m)
d = vi.metadata.μ.dists
# 1-element Array{Normal{Float64},1}:
#  Normal{Float64}(μ=178.0, σ=20.0)
median(d[1]) == 178.0

which means that

function quap(model::DynamicPPL.Model; method = SimulatedAnnealing())
    vi = Turing.VarInfo(model)
    meta = vi.metadata

    # Take the median of each prior distribution as the starting point.
    start = [median(var.dists[1]) for var in meta]
    @show start

    quap(model, start; method = method)
end

should work pretty well. I'm a little wary of the hardcoded [1], though. I don't even know what else is supposed to go into that array, so as soon as anything does, it means trouble. Maybe I can ask the Turing devs.

If I change the model to

@model height(heights) = begin
    μ ~ Normal(178, 20)
    σ ~ Uniform(0, 50)
    tmp ~ Normal(μ, 1)
    heights .~ Normal(tmp, σ)
end

then tmp's distribution is hardcoded to something like Normal(165..., 1), presumably with some random number drawn from μ as its mean. I'm not sure how to circumvent that, except maybe by taking a few samples; the value seems to get redrawn by running m().
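
A rough sketch of that take-a-few-samples idea (prior_start is a hypothetical helper; it assumes each Turing.VarInfo(model) call reruns the model and therefore redraws values like tmp's mean):

using Statistics

# Hypothetical: rebuild the VarInfo several times and take per-parameter
# medians, so prior medians that depend on another random draw average out.
function prior_start(model; nsamples = 100)
    draws = [reduce(vcat, [median.(var.dists) for var in Turing.VarInfo(model).metadata])
             for _ in 1:nsamples]
    vec(median(reduce(hcat, draws), dims = 2))
end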

In spite of those few open questions this seems to me like a good idea and a way better solution than what I cooked up.

@goedman
Member

goedman commented May 20, 2020

Hmm, not sure if I didn't like your initial approach better than adding the tmp ~ Normal(μ, 1) line.

So in m4.1t, your solution would be:

@model line(x, y) = begin
    #priors
    alpha ~ Normal(178.0, 100.0)
    beta ~ Normal(0.0, 10.0)
    s ~ Uniform(0, 50)

    #model
    mu = alpha .+ beta*x
    y .~ Normal.(mu, s)

    return alpha, beta, s
end

which can then be used as in your example:

ml = line(x, y);

quap(ml, [160.0, 0.0, 0.0])
# (coef = [154.5972710428122, 0.9050152762513541, 5.071883205900706], vcov = [0.07307900945890079 -1.4023086590478741e-11 2.570950382816481e-6; -1.4023086590478741e-11 0.0017579294049063208 -1.9930588768800978e-7; 2.570950382816481e-6 -1.9930588768800978e-7 0.036540143453361824], converged = true)

quap(ml, [160.0, 0.0, 25.0])
# (coef = [154.59725198909285, 0.90501297192992, 5.071871422255509], vcov = [0.07307866985227743 -1.4612147142026734e-11 2.296387802369754e-6; -1.4612147142026734e-11 0.001757921236960678 -2.3250589352437402e-7; 2.296387802369754e-6 -2.3250589352437402e-7 0.03653971897865233], converged = true)

MvNormal(res.coef, res.vcov)
# FullNormal(
# dim: 3
# μ: [154.59725198909285, 0.90501297192992, 5.071871422255509]
# Σ: [0.07307866985227743 -1.4612147142026734e-11 2.296387802369754e-6; -1.4612147142026734e-11 0.001757921236960678 -2.3250589352437402e-7; 2.296387802369754e-6 -2.3250589352437402e-7 0.03653971897865233]
# )

while script m4.1t.jl works as before.

I actually wouldn't mind at all if we asked our customers to always provide start values. They do that anyway in the priors.

@karajan9
Contributor Author

karajan9 commented May 20, 2020

I'm not sure we aren't talking past each other. The tmp ~ Normal(μ, 1) was just to show that there might be problems once the models get bigger/hierarchical. But this is easily resolved by taking multiple samples, just like I did in the first version by sampling from the prior. The reference to m4.1t was probably misleading, sorry about that. But I'll take it as an example going forward.

So basically, there are three possible versions given this model:

@model line(x, y) = begin
    #priors
    alpha ~ Normal(178.0, 100.0)
    beta ~ Normal(0.0, 10.0)
    s ~ Uniform(0, 50)

    #model
    mu = alpha .+ beta*x
    y .~ Normal.(mu, s)

    return alpha, beta, s
end
  1. Specifying a start vector. In this case it could be [178.0, 0.0, 25.0] (the mean values).
  2. The version in the PR that samples from the prior. Requires the return statement. Since it takes the median of a bunch of samples the vector will be approximately [178.0, 0.0, 25.0] as well.
  3. The quap version in this comment. It looks at the distributions and from there calculates (not averages) the identical start vector [178.0, 0.0, 25.0] without the need for a return statement. In fact, it will work when you remove the return statement from the model above, while version 2) will fail.
  4. (Using NUTS to compare results.)
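
Illustratively, the calls would look like this (hypothetical; assuming versions 2 and 3 both provide a quap method that derives the start vector itself):

quap(m, [178.0, 0.0, 25.0])  # version 1: explicit start vector
quap(m)                      # versions 2/3: start vector derived from the priors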

With

d = CSV.read(datadir("exp_raw/Howell_1.csv"), copycols = true)
d2 = filter(row -> row.age >= 18, d)
# equivalently: d2 = d[d.age .>= 18, :]
m = line(d2.weight, d2.height)

all 3 optimizations result in (approximately)

(coef = [113.90338516901606, 0.9045062614138886, 5.071867147465665], vcov = [3.6300435001413702 -0.07906095794838724 0.00034555641375722055; -0.07906095794838724 0.0017572945349224637 -7.530697042737666e-6; 0.00034555641375722055 -7.530697042737666e-6 0.036539573181395675], converged = true)

and NUTS

  parameters      mean     std  naive_se    mcse       ess   r_hat
  ──────────  ────────  ──────  ────────  ──────  ────────  ──────
       alpha  113.7695  1.7799    0.0796  0.1339  184.8117  0.9997
        beta    0.9075  0.0393    0.0018  0.0029  183.2906  0.9990
           s    5.1037  0.2186    0.0098  0.0161  179.0444  0.9985

@karajan9
Contributor Author

It seems like this PR might need a rework soon anyway 😄
TuringLang/Turing.jl#1230

@cpfiffer

Yeah, after TuringLang/Turing.jl#1230 goes through, you should be able to just do

using Turing
using Optim

@model function something()
    # ... your model here ...
end

model = something()

# To account for the prior probability
optimize(model, MAP())

# If you just want maximum likelihood 
optimize(model, MLE())

@goedman
Member

goedman commented May 20, 2020

Cool!

@goedman
Member

goedman commented May 21, 2020

@karajan9 Is your decision to hold off on further work on quap-turing for now until #1230 is released? If so, I hope the effort you put in was still worthwhile! I certainly learned some stuff again, and it was good to refresh my Turing memory a bit!

@karajan9
Contributor Author

karajan9 commented May 21, 2020

I think so, yes. It sounds like it's going to be merged relatively soon, and in the meantime I hope the current version will do.
Once it's here, my plan is to use it to provide an interface similar to the book's (just a thin layer; it should be pretty easy with all the hard work already done by the Turing team) and to integrate it a little more (tests, docstrings, making it work with precis or sampling from distributions, etc.).
Since most of the current code won't be used anymore then, I don't think working on it is a good investment of time, but I certainly enjoyed making it work (although I'm equally glad we'll get a well-thought-out version soon!).
