
[HMC] Update random variables independent of the joint likelihood? #372

Closed
emilemathieu opened this issue Nov 13, 2017 · 24 comments


emilemathieu commented Nov 13, 2017

In HMC, during a leapfrog step, all variables are updated (via p -= ϵ * grad / 2; θ += ϵ * p).

Shouldn't we avoid updating random variables which do not affect the joint probability? Otherwise, we might be moving these random variables towards one of their prior's modes.

We could detect those variables since their partial derivative (i.e. the corresponding component of grad = gradient(vi, model, spl)) equals the derivative of their prior. For instance, by doing the following:

find_independent(vi::VarInfo, spl::Sampler, grad::Vector{<:Real}) = begin
  gidcs = getidcs(vi, spl)
  independent = Int[]
  for i in 1:length(gidcs)
    # Compare the joint gradient with the prior's derivative at the current
    # value (assuming a univariate prior here); equality means the likelihood
    # does not depend on this variable.
    prior_diff = ForwardDiff.derivative(x -> logpdf(vi.dists[i], x), first(vi.vals[vi.ranges[i]]))
    prior_diff == grad[i] && push!(independent, i)
  end
  return independent
end

# in leapfrog(...) ...
grad = gradient(vi, model, spl)
independent_θ = find_independent(vi, spl, grad)
grad[independent_θ] .= 0

This would be useful (but is it correct?) for a BNP mixture model, to avoid updating the unused clusters' weights and locations.
We could otherwise resample these variables from their prior:

for index in independent_θ
  θ[index] = rand(vi.dists[index])
end
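
For context, here is a minimal, self-contained sketch of a single leapfrog step with the proposed masking applied. It only illustrates the idea and is not the actual Turing implementation; ∇U stands for the gradient of the negative log-joint and independent_idx for the indices returned by a detection routine like the one above.

function leapfrog_masked(θ, p, ϵ, ∇U, independent_idx)
  grad = ∇U(θ)
  grad[independent_idx] .= 0   # do not move variables the likelihood ignores
  p -= ϵ .* grad ./ 2          # half-step for the momentum
  θ += ϵ .* p                  # full step for the position
  grad = ∇U(θ)
  grad[independent_idx] .= 0
  p -= ϵ .* grad ./ 2          # second half-step for the momentum
  return θ, p
end
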
@emilemathieu

@yebai @xukai92 what do you think about this?


yebai commented Nov 29, 2017

It makes a lot of sense to do this - actually, this is something that @xukai92 and I thought about when designing the HMC sampler. Basically, we wanted a way to figure out which random variables in VarInfo have disappeared from the model. Knowing these disappeared variables allows us to handle them differently, either:

  • delete them from the VarInfo data structure
  • keep them, but stop updating them in our MCMC schemes (or re-simulate them from the prior)

@xukai92 do you have any ideas here?


emilemathieu commented Nov 29, 2017

A random variable could still "be in the model" in the sense that it is sampled but not used afterwards, meaning that the partial derivative of the joint with respect to that variable will only be its prior's derivative. I think that we should keep them but not update them, since in the case of HMC within Gibbs such a variable could be "unused" (likelihood independent of that variable) at one iteration, but "used" at a later iteration.

I don't know how we can detect these random variables in an efficient way.


xukai92 commented Nov 29, 2017

I'm not 100% sure I understand the situation in which this is useful. Do you mean some of the dependencies between the likelihood and the variables may disappear in some iterations because of stochastic conditions?

@emilemathieu

Yes, for instance with a Dirichlet process mixture, a cluster can be empty at an iteration (no observation assigned to it), and therefore the cluster's weight and location are unused.


xukai92 commented Nov 29, 2017

Emm, it's interesting. This is detecting stochastic dependencies dynamically instead of using the compiler to check them, which is what we discussed before.

Let me write a minimal example for further discussion:

@model test() = begin
  a ~ Normal(0, 1)
  b ~ Normal(0, 1)
  if a > 0
    1.5 ~ Normal(a + b, 1)
  else
    1.5 ~ Normal(a, 1)
  end
end

sample(test(), Gibbs(..., HMC(..., :b), PG(..., :a)))

In this model, if the PG step gives a <= 0, then in the HMC step we shouldn't change b at all.

I think your proposal will do the job. I am also wondering if we can achieve the same thing by separately tracking log_likelihood and log_prior during the model simulation, and computing log_joint = log_likelihood + log_prior only when needed. If we do so, the dual parts of those variables on which log_likelihood does not depend will simply be 0 (?)

Also, I think we may be able to use Cassette.jl in the future to do this, but I'm not sure either.
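
A quick, self-contained way to check this dual-part claim for the example above (the normal_logpdf helper and the test values are purely illustrative, not Turing code):

using ForwardDiff

# Hand-written Normal log-density so the example only needs ForwardDiff.
normal_logpdf(y, μ, σ) = -0.5 * ((y - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Log-likelihood of the test() model above, with θ = [a, b].
loglike_test(θ) = θ[1] > 0 ? normal_logpdf(1.5, θ[1] + θ[2], 1.0) :
                             normal_logpdf(1.5, θ[1], 1.0)

ForwardDiff.gradient(loglike_test, [-0.3, 0.7])  # second component (∂/∂b) is exactly 0.0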

@emilemathieu

That's a perfect test! Good idea: in assume{T<:Hamiltonian}(... we could accumulate logpdf_with_trans(dist, r, istrans(vi, vn)) in some spl.info[:log_likelihood] field and check which components of the gradient are equal to 0.


xukai92 commented Nov 29, 2017

Oh yes, we can simply use spl.info. And I think you mean we could accumulate this in observe(), not assume(), right?

One minor concern is that with this approach we do more additions than before. I don't know to what extent it will slow us down, but we can separate them into two fields in VarInfo anyway if this happens.

@emilemathieu

Oh yes, I meant observe. Since logp is a Vector{Real}, could we accumulate the prior in the first component and the likelihood in the second, for instance? And at the end add the two together (modulo the components that we don't want to update). I'm not sure what the cleanest and most efficient way to do this is.


xukai92 commented Nov 30, 2017

The reason logp is a Vector is that we designed VarInfo to be able to store multiple copies of the sample state at the same time; however, this is not really used.

I would prefer to split logp (i.e. the log-posterior) into loglike and logprior, and each place where we currently use logp could simply be replaced by loglike + logprior, or by an interface logp(vi).

Let's see what @yebai thinks about this design!
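
A minimal sketch of what this split could look like (the struct and function names here are illustrative only, not the actual VarInfo fields):

using Distributions

mutable struct LogDensityParts
  logprior::Float64
  loglike::Float64
end

# assume-like statements would accumulate only the prior term ...
add_prior!(lp::LogDensityParts, dist, value) = (lp.logprior += logpdf(dist, value))

# ... and observe-like statements only the likelihood term.
add_likelihood!(lp::LogDensityParts, dist, value) = (lp.loglike += logpdf(dist, value))

# Every place that previously read logp directly could use this interface instead.
logp(lp::LogDensityParts) = lp.logprior + lp.loglike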

@emilemathieu

@xukai92 I'm working on a PR; I like your solution!
Should we keep loglike and logprior as Vector{Real}, or maybe make them simply Real?


xukai92 commented Nov 30, 2017

Maybe keep them as Vector for the moment, until we decide to abandon the ability to store multiple copies of the sample state at the same time.

@emilemathieu

Is there no computational overhead from using a Vector?


xukai92 commented Nov 30, 2017

There is, but I think it's not significant (based on our previous profiling experience).


emilemathieu commented Nov 30, 2017

I have just realised that what we are currently doing is sampling these "unused" random variables from their priors, since HMC samples from the joint (which is proportional to the posterior) and the joint reduces to the prior for these random variables. Therefore I'm not sure whether this is actually an issue.
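
Concretely, if a variable z is a priori independent of the remaining variables θ and no observe statement uses it, the joint factorizes as p(z, θ | x) ∝ p(x | θ) p(θ) p(z), so the conditional that HMC effectively targets for z is just its prior p(z). (This relies on the a priori independence assumption.)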

@emilemathieu

PR #401


emilemathieu commented Dec 1, 2017

@xukai92 Actually this solution cannot work, since a random variable can appear in the joint without appearing in the likelihood, through other variables' priors. For instance, with

@model gdemo(x) = begin
  s ~ InverseGamma(2,3)
  m ~ Normal(0,sqrt.(s))
  x[1] ~ Normal(m, 1)
  x[2] ~ Normal(m, 1)
  return s, m
end

s is independent of the likelihood but appears twice in the joint: in its own prior p(s) but also in p(m|s).
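
A small illustrative check of this point, using ForwardDiff directly with a hand-written Normal log-density and hypothetical observed values (not Turing code): the likelihood gradient with respect to s is exactly zero, while the p(m|s) term contributes a nonzero gradient for s.

using ForwardDiff

normal_logpdf(y, μ, σ) = -0.5 * ((y - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

x = [1.5, 2.0]   # hypothetical observations, purely for illustration
loglike_gdemo(θ) = normal_logpdf(x[1], θ[2], 1.0) + normal_logpdf(x[2], θ[2], 1.0)   # θ = [s, m]
logprior_m(θ)    = normal_logpdf(θ[2], 0.0, sqrt(θ[1]))                              # p(m | s)

ForwardDiff.gradient(loglike_gdemo, [1.0, 0.5])   # ∂/∂s component is 0.0
ForwardDiff.gradient(logprior_m,    [1.0, 0.5])   # ∂/∂s component is nonzero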


xukai92 commented Dec 7, 2017

I'm a bit lost here. Do you mean we shouldn't update s here or something?

@emilemathieu

In that example s SHOULD be updated, but with our proposed fix it won't appear in loglike's gradient, only in logprior's gradient (through its own prior and m's prior). Therefore I think this solution does not work.


xukai92 commented Dec 11, 2017

I see. You are right.

We probably still need to do this at model compilation time..., where we can somehow extract the variable dependencies. But again, this way we cannot deal with stochastic control flow easily.

I don't know if we can use Cassette.jl for run-time stochastic dependency checking once it's mature.

@emilemathieu

Indeed, accessing the underlying dependency graph could be a way. I'm trying to think of another way.


xukai92 commented Dec 17, 2017

What's the alternative in your mind?

@emilemathieu

Maybe check this independence during each assume and observe statement, by updating some array of booleans.


yebai commented Jan 8, 2018

> Indeed, accessing the underlying dependency graph could be a way. I'm trying to think of another way.

> I don't know if we can use Cassette.jl for run-time stochastic dependency checking once it's mature.

@emilemathieu @xukai92 I'm hoping that we will eventually be able to use Cassette.jl or something similar to extract the computational graph, since it will help speed up computing both the joint target distribution and the local conditional distributions. Before we can do that, let's stick with the slow but correct solution: rerun the program fully.
