
VarInfo and type stability refactor #660

Closed · wants to merge 23 commits

Conversation

@mohamed82008 (Member) commented Jan 26, 2019

This PR splits out the VarInfo and performance changes from #637.

The main change in this PR is that it defines two subtypes of AbstractVarInfo:

  1. UntypedVarInfo
  2. TypedVarInfo

UntypedVarInfo is exactly the same as the current VarInfo. It is important because it lets us figure out the types of the parameters by actually running the model once. Once the types are known, the UntypedVarInfo can be reduced to a more performant TypedVarInfo with all the eltypes as tight as possible.

TypedVarInfo is designed around Julia's NamedTuple construct, which gives us a mapping from each parameter's Symbol to its own typed value vector and other meta-information. This decouples the eltypes of different parameters, which reduces the need for Real inside TypedVarInfo.

One problem remains, however. When using a Hamiltonian inference algorithm, the eltype of a sampler's parameters is not actually known in advance: it depends on the function being differentiated by the autodiff package, which appears, for example, in the first type parameter of ForwardDiff.Dual. At the same time, we should still be able to just run the model with the parameters' eltype being Float64, not an autodiff type. This makes it difficult to fix the concrete type of a TypedVarInfo instance throughout the sampling process. The workaround used here is to create a new VarInfo instance using https://github.com/TuringLang/Turing.jl/blob/mt/type_stability2/src/core/ad.jl#L111, which replaces the Float64 vectors of the Sampler's parameters with similar vectors of the autodiff type before passing it to the model. This is less than ideal, but it lets us keep concretely typed TypedVarInfo instances throughout. Note that to do this in a type-stable way, the space of the Sampler is encoded in its type as a parameter.
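To make the NamedTuple idea concrete, here is a minimal sketch; it is not the actual implementation, and MiniTypedVarInfo, Metadata, and with_eltype below are hypothetical names:

using ForwardDiff

# One concretely typed value vector per parameter.
struct Metadata{T}
    vals::Vector{T}
end

# The NamedTuple maps each parameter's Symbol to its own Metadata,
# so :s and :m can have completely independent eltypes.
struct MiniTypedVarInfo{M<:NamedTuple}
    metadata::M
end

vi = MiniTypedVarInfo((s = Metadata([1.0]), m = Metadata([0.5])))

# Rebuild the value vectors of the parameters in `space` with a new
# eltype, e.g. a ForwardDiff.Dual, before passing the VarInfo to the
# model. (The real code encodes the space in the Sampler's type to
# keep this type-stable; that is glossed over here.)
function with_eltype(vi::MiniTypedVarInfo, space, ::Type{T}) where {T}
    names = keys(vi.metadata)
    metas = map(names) do sym
        md = getfield(vi.metadata, sym)
        sym in space ? Metadata(convert(Vector{T}, md.vals)) : md
    end
    return MiniTypedVarInfo(NamedTuple{names}(metas))
end

D = ForwardDiff.Dual{Nothing,Float64,2}
vi_dual = with_eltype(vi, (:s, :m), D)  # value vectors are now Vector{D}

The actual metadata carries more than the values, but the decoupling mechanism is the one sketched above.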

These are the remaining tasks here:

  • Make sampler.info into a type-stable struct, one for each inference algorithm (see the sketch after this list)
  • Make sample.value into one or more type-stable fields
  • Make sure tests pass
  • Add tests for TypedVarInfo
  • Prefix the names of functions in VarReplay which are only used internally with _
  • Make sure all of the _sample functions are fully type-stable (these are run from the second iteration onwards), except for the Real in model definitions discussed in "Completely eliminating Real from the model and pre-allocating" #665
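For the first item, a minimal sketch of the intended change, with a hypothetical HMCInfo standing in for the per-algorithm struct:

# Before: a Dict{Symbol,Any} info field, so every access infers as Any.
info = Dict{Symbol,Any}(:epsilon => 0.1, :lf_num => 5)

# After: one concretely typed info struct per inference algorithm,
# so every field access is type-stable.
struct HMCInfo{T<:Real}
    epsilon::T   # leapfrog step size
    lf_num::Int  # number of leapfrog steps
end

info = HMCInfo(0.1, 5)
info.epsilon  # inferred as Float64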

@mohamed82008 (Member, Author)

Benchmarks show a decent speedup.

@mohamed82008 (Member, Author)

I will see if I can make this any faster.

@yebai (Member) commented Jan 26, 2019

> I will see if I can make this any faster.

Can you post some runtime numbers here?

@mohamed82008 (Member, Author)

I added a description of the main changes and design decisions made in this PR. Please give me your feedback on the concept. The code still needs a rebase so it is not clean.

@mohamed82008 (Member, Author) commented Feb 17, 2019

Currently, I am trying to "fix" the samplers and make their infos type-stable. The following inference algorithms currently work on this simple program (a toy type-stability check is sketched after the list):

@model gdemo(x, y) = begin
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))
end
  • IS: sample(gdemo(2.0, 1.5), IS(1000))
  • MH: sample(gdemo(2.0, 1.5), MH(1000))
  • HMC: sample(gdemo(2.0, 1.5), HMC(1000, 0.1, 5))
  • SGHMC: sample(gdemo(2.0, 1.5), SGHMC(1000, 0.001, 0.1))
  • SGLD: sample(gdemo(2.0, 1.5), SGLD(100, 0.01))
  • HMCDA: sample(gdemo(2.0, 1.5), HMCDA(1000, 0.15, 0.65))
  • NUTS: sample(gdemo(2.0, 1.5), NUTS(1000, 0.65))
  • Gibbs: sample(gdemo(2.0, 1.5), Gibbs(1000, MH(1, :s), SGLD(100, 0.01, :m)))
  • SMC: sample(gdemo(2.0, 1.5), SMC(1000))
  • IPMCMC: sample(gdemo(2.0, 1.5), IPMCMC(100, 100, 4, 2))
  • PMMH: sample(gdemo(2.0, 1.5), PMMH(1000, SMC(20, :m), MH(10, :s)))
  • PG: sample(gdemo(2.0, 1.5), PG(10, 1000))
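The fixes are checked with Julia's standard inference tools; a toy illustration of the kind of instability being removed, using @code_warntype and Test.@inferred (plain Julia, not Turing internals):

using Test

# A Dict{Symbol,Any}-backed info makes the result type Any...
unstable(info) = info[:epsilon] * 2
@code_warntype unstable(Dict{Symbol,Any}(:epsilon => 0.1))  # Body::Any

# ...while a concretely typed struct infers cleanly.
struct Info{T}
    epsilon::T
end
stable(info) = info.epsilon * 2
@inferred stable(Info(0.1))  # passes: return type inferred as Float64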

@mohamed82008 (Member, Author) commented Mar 3, 2019

The following tests are failing for various reasons:

  • compiler.jl / assume.jl
  • compiler.jl / explicit_ret.jl
  • compiler.jl / observe.jl
  • compiler.jl / opt_param_of_dist.jl
  • dynamichmc-support.jl / dynamic_hmc.jl
  • gibbs.jl / gibbs_constructor.jl
  • gibbs.jl / gibbs2.jl
  • hmc.jl / constrained_simplex.jl
  • hmc.jl / matrix_support.jl
  • hmc.jl / multivariate_support.jl
  • hmcda.jl / hmcda_cons.jl
  • io.jl / save_resume_chain.jl
  • ipmcmc.jl / ipmcmc.jl
  • ipmcmc.jl / ipmcmc2.jl
  • is.jl / importance_sampling.jl
  • mh.jl / mh_cons.jl
  • models.jl / single_dist_correctness.jl
  • nuts.jl / nuts_cons.jl
  • pmmh.jl / pmmh_cons.jl
  • pmmh.jl / pmmh.jl
  • resample.jl / particlecontainer.jl
  • sghmc.jl / sghmc_cons.jl
  • sghmc.jl / sghmc.jl
  • sgld.jl / sgld_cons.jl
  • sgld.jl / sgld.jl
  • trace.jl / trace.jl
  • varinfo.jl / orders.jl
  • varinfo.jl / test_varname.jl
  • varinfo.jl / varinfo.jl
  • vectorization.jl / vec_assume.jl
  • vectorization.jl / vectorize_observe.jl

Wish me luck!
