VarInfo and type stability refactor #660
Conversation
Force-pushed from e655b49 to 4220e18
Benchmarks show a decent speedup.

I will see if I can make this any faster.
Force-pushed from 4220e18 to 4ebd69c
I added a description of the main changes and design decisions made in this PR. Please give me your feedback on the concept. The code still needs a rebase, so it is not clean.
Force-pushed from 6bdb1b7 to 45c785e
Force-pushed from 45c785e to ab21264
Currently, I am trying to "fix" the samplers and make their `info`s type stable. The following list shows the currently working inference algorithms on this simple program:

```julia
@model gdemo(x, y) = begin
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))
end
```
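For context, sampling from this model would look something like the following (the exact `sample`/`HMC` signatures have changed across Turing.jl versions, so treat this as an illustration rather than the version used in this PR):

```julia
using Turing

# Condition the model on observations x = 1.5, y = 2.0 and draw
# 1000 samples with HMC (step size 0.1, 5 leapfrog steps).
chain = sample(gdemo(1.5, 2.0), HMC(0.1, 5), 1000)
```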
Force-pushed from c9704a2 to b80a361
Force-pushed from b80a361 to cedccdc
The following tests are failing for various reasons:

Wish me luck!
Force-pushed from 709d8b4 to 23b06b9
This PR dissects the `VarInfo` and performance changes from #637.

The main change in this PR is that it defines 2 different subtypes of `AbstractVarInfo`:
- `UntypedVarInfo`
- `TypedVarInfo`

The `UntypedVarInfo` is exactly the same as the current `VarInfo`. This `VarInfo` is important because it lets us figure out the types of the parameters by actually running the model once. When the types have been figured out, we can then reduce the `UntypedVarInfo` to a more performant `TypedVarInfo` with all the `eltype`s as tight as possible. The `TypedVarInfo` is designed around the `NamedTuple` construct in Julia, which lets us have a mapping from the parameters' `Symbol`s to their differently typed value vectors and other meta-information. This lets us decouple the `eltype`s of different parameters, which lowers the need for `Real` inside `TypedVarInfo`.
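As an illustration of why `NamedTuple` enables this decoupling, here is a minimal sketch (the struct and function names below are hypothetical, not Turing.jl's actual implementation):

```julia
# Hypothetical sketch: each parameter symbol gets its own concretely
# typed value vector, carried in a NamedTuple whose field types are
# part of the wrapper's type.
struct Meta{T}
    vals::Vector{T}
end

struct SketchTypedVarInfo{NT<:NamedTuple}
    metadata::NT
end

# `s` and `m` keep independent, concrete eltypes instead of sharing
# one abstractly typed Vector{Real}:
vi = SketchTypedVarInfo((s = Meta([1.0]), m = Meta([0.5, 0.3])))

# Lookup by symbol is type stable, because the field types are
# recoverable from typeof(vi) at compile time:
getvals(vi, ::Val{sym}) where {sym} = vi.metadata[sym].vals
getvals(vi, Val(:m))  # a Vector{Float64}
```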
However, there is one problem remaining. When using a Hamiltonian inference algorithm, the `eltype` of a sampler's parameter is not actually known in advance, since it depends on the function being differentiated by the autodiff package, which is part of the first type parameter of `ForwardDiff.Dual`, for example. Additionally, we should still be able to just run the model and have the parameter's `eltype` be `Float64`, not any autodiff type. This makes it difficult to fix a concrete type for the `TypedVarInfo` instance throughout the sampling process. The workaround used here is to create a new `VarInfo` instance using https://github.com/TuringLang/Turing.jl/blob/mt/type_stability2/src/core/ad.jl#L111, which replaces the `Float64` vectors of the `Sampler`'s parameters with similar vectors of the autodiff type before passing it to `model`. This is less than ideal, but it lets us have concretely typed `TypedVarInfo` instances throughout. Note that, to do this in a type stable way, the space of the `Sampler` was encoded in its type as a parameter.

These are the remaining tasks to do here:
- Turn `sampler.info` into a type stable struct, one for each inference algorithm
- Turn `sample.value` into 1 or more type stable fields
- `TypedVarInfo`
- Add `_` before the names of functions in `VarReplay` which are only used internally

The `_sample` functions are fully type stable (this is what runs from the second iteration onwards), except for the `Real` in model definitions, discussed in "Completely eliminating `Real` from the model and pre-allocating" (#665).
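The `Float64`-to-`Dual` replacement described in the workaround above might be sketched as follows (a hypothetical helper assuming ForwardDiff; this is not the actual code at the linked line):

```julia
using ForwardDiff

# Hypothetical sketch: build a vector with the same values but an
# autodiff eltype, so the resulting TypedVarInfo stays concretely
# typed during differentiation.
function with_dual_eltype(vals::Vector{Float64}, ::Type{D}) where {D<:ForwardDiff.Dual}
    dvals = similar(vals, D)  # same length, Dual eltype
    dvals .= vals             # convert each Float64 to a Dual (zero partials)
    return dvals
end
```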