-
Notifications
You must be signed in to change notification settings - Fork 219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TS 2: introduce TypedVarInfo and fix spl.info[:cache_updated] #742
Conversation
@generated function _getranges(vi::TypedVarInfo{Tvis}, idcs) where Tvis
args = []
for f in fieldnames(Tvis)
push!(args, :($f = _map(vi, $(QuoteNode(f)), idcs.$f)))
end
if length(args) == 0
nt = :(NamedTuple())
else
nt = :(($(args...),))
end
return nt
end can be replaced with: _getranges(vi::TypedVarInfo, idcs) = _getranges(vi.vis, vi, idcs)
@inline function _getranges(vis::NamedTuple{names}, vi::TypedVarInfo, idcs) where names
length(names) === 0 && return NamedTuple()
f = names[1]
v = _map(vi, f, getfield(idcs, f))
nt = NamedTuple{(f,), Tuple{typeof(v)}}(v)
return merge(nt, _getranges(Base.tail(vis), vi, idcs))
end |
@mohamed82008 is this ready for a review? |
Yes if you don't mind the generated functions. The semantics will not change by making them into normal functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some comments. Most of them are things I didn't quite understand and ask for explanations. Also, I didn't check the set
and get
functions carefully but I'm happy with them given the tests are covered and pass.
@@ -58,7 +58,7 @@ end | |||
|
|||
function assume(spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi::VarInfo) | |||
r = rand(dist) | |||
push!(vi, vn, r, dist, spl.selector) | |||
push!(vi, vn, r, dist, spl) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that we are storing spl
directly instead of selector
inside vi
now? Why do we do this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No we are still pushing the same way. The reason why I am passing spl
here is to reset the :cache_updated
field inside the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In you opinion, do you think this cache thing is worth doing here at all. I feel it might be over-optimisation here...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is important in cases where the VarInfo
is "filled" most of the time so the cache is valid most of the time. This is the case for most samplers except the particle ones I believe.
dists = [Normal(0, 1), MvNormal([0; 0], [1.0 0; 0 1.0]), Wishart(7, [1 0.5; 0.5 1])] | ||
function test_varinfo!(vi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be nice to improve the tests here a bit, e.g. adding more comments on what's testing and expected, and improving the code (variable names, code reuse, etc).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will add more comments to the src too.
test/core/RandomVariables.jl
Outdated
@@ -1,9 +1,9 @@ | |||
using Turing, Random | |||
using Turing: Selector, reconstruct, invlink, CACHERESET, SampleFromPrior | |||
using Turing.RandomVariables | |||
using Turing.RandomVariables: uid, cuid, getvals, getidcs, | |||
using Turing.RandomVariables: uid, cuid, getidcs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we aware of the test coverage of the TypedVarInfo
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made sure the test coverage of TypedVarInfo
is no less than that of UntypedVarInfo
. But I will see if I can improve it.
src/core/RandomVariables.jl
Outdated
@@ -118,32 +130,224 @@ mutable struct UntypedVarInfo <: AbstractVarInfo | |||
end | |||
VarInfo() = UntypedVarInfo() | |||
|
|||
########################### | |||
# Single variable VarInfo # |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does SingleVarInfo
only contains a single random variable? If so why does ranges
required for SingleVarInfo
and fields like dists
, gids
are still vectors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case we have a vector variable where each element has a different prior distribution, so I think dists
still needs to be a vector. I think gids
can be shared however. Not sure about ranges
, in case we have a matrix variable whose columns are multivariate variables, we probably still need it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, I see your point, though I guess we should discourage users to do that. And I'm not sure whether it makes or not for a single multivariate random variable whose different dimension follows different distributions. But I'm OK with it for now.
end | ||
@inline function _link!(vis::NamedTuple{names}, vi, vns, space) where {names} | ||
length(names) === 0 && return nothing | ||
f = names[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit confused by the variables here, what is f
? Also we are calling getfield(vns, f)
multiple times later, can we assign it to some local variable with a name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f
is the first variable's symbol.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But using recursion, this "first" variable will be a different variable in each level of the recursion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But for the loop below we are calling getfield(vns, f)
with same vns
and f
all the time do we?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, so we are iterating over all the vn
s of the first symbol, then all the vn
s of the second, and so on. Each symbol has a vns
vector because it can consist of multiple random variables, e.g. multiple univariate variables arranged in vector form or multiple multivariate variables arranged in matrix form, etc. The last line is where recursion happens.
end | ||
@inline function _invlink!(vis::NamedTuple{names}, vi, vns, space) where {names} | ||
length(names) === 0 && return nothing | ||
f = names[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar here.
I am working on docstrings for all the major functions and types in the module |
To do:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Review in progress - I'll review this PR in several steps in the next 2-3 days.
src/core/RandomVariables.jl
Outdated
|
||
Examples: | ||
|
||
- `x[2] ~ Normal()` will generate a `VarName` with `sym == :x` and `indexing == "[1]"` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indexing == "[1]"
==> indexing == "[2]"
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
src/core/RandomVariables.jl
Outdated
end | ||
``` | ||
|
||
A variable identifier. Every variable has a symbol `sym`, indices `indexing`, and internal fields: `csym` and `counter`. The Julia variable in the model corresponding to `sym` can refer to a single value or to a hierarchical array structure of univariate, multivariate or matrix variables. `indexing` stores the indices that can access the random variable from the Julia variable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps breaking this line into smaller lines (i.e. <80 chars)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or 92 which is the number from our guide.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do.
src/core/RandomVariables.jl
Outdated
@@ -60,20 +108,57 @@ isequal(x::VarName, y::VarName) = hash(uid(x)) == hash(uid(y)) | |||
Base.string(vn::VarName) = "{$(vn.csym),$(vn.sym)$(vn.indexing)}:$(vn.counter)" | |||
Base.string(vns::Vector{<:VarName}) = replace(string(map(vn -> string(vn), vns)), "String" => "") | |||
|
|||
""" | |||
`sym_idx(vn::VarName)` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this equivalent to the function mentioned in #721?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's the inverse of it.
src/core/RandomVariables.jl
Outdated
|
||
A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of `Metadata`. If `vi isa VarInfo{<:Metadata}`, then only `Metadata` instance is used for all the sybmols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If `vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows for the type specialization of `vi` after the first sampling iteration when all the symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. | ||
|
||
Note: It is the user's responsibility to ensure that each symbol is visited at least once whenever the model is called, regardless of any stochastic branching. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
presumably, it's ok if a variable is visited multiple times since they can have different counter
? If so, can we say this explicitly in the comment to reduce potential confusion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each random variable can only be visited once in a model call. But the symbol of the Julia variable can be visited more than once, e.g. x[1] ~ ...
and x[2] ~ ...
. This line refers to the symbol x
not the random variables x[1]
and x[2]
. The requirement here is that each symbol, e.g. x
, is visited at least once. I will try to make this clearer.
src/core/RandomVariables.jl
Outdated
end | ||
``` | ||
|
||
A variable identifier. Every variable has a symbol `sym`, indices `indexing`, and internal fields: `csym` and `counter`. The Julia variable in the model corresponding to `sym` can refer to a single value or to a hierarchical array structure of univariate, multivariate or matrix variables. `indexing` stores the indices that can access the random variable from the Julia variable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or 92 which is the number from our guide.
src/core/RandomVariables.jl
Outdated
- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. | ||
- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the value of `flag` corresponding to `vn`. | ||
|
||
To make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a Julia variable, say `x`, that is a matrix or a hierarchical array sampled in partitions, e.g. `x[1][:] ~ MvNormal(zeros(2), 1.0); x[2][:] ~ MvNormal(ones(2), 1.0)` and is managed by a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the same type. Type unstable `Metadata` will still work but will have inferior performance. When sampling, the first iteration uses a type unstable `Metadata` for all the variables then a specialized `Metadata` is used for each symbol along with a function barrier to make the rest of the sampling type stable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
src/core/RandomVariables.jl
Outdated
|
||
A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of `Metadata`. If `vi isa VarInfo{<:Metadata}`, then only `Metadata` instance is used for all the sybmols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If `vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows for the type specialization of `vi` after the first sampling iteration when all the symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let vi
be an instance of Metadata
. -> Let vi
be an instance of VarInfo
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
|
||
To make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a Julia variable, say `x`, that is a matrix or a hierarchical array sampled in partitions, e.g. `x[1][:] ~ MvNormal(zeros(2), 1.0); x[2][:] ~ MvNormal(ones(2), 1.0)` and is managed by a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the same type. Type unstable `Metadata` will still work but will have inferior performance. When sampling, the first iteration uses a type unstable `Metadata` for all the variables then a specialized `Metadata` is used for each symbol along with a function barrier to make the rest of the sampling type stable. | ||
""" | ||
struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distribution}, TVN <: AbstractVector{<:VarName}, TVal <: AbstractVector{<:Real}, TGIds <: AbstractVector{Set{Selector}}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, I guess in the future, we could have two types of Metadata
which are NativeMetadata
and FlatMetadata
. The first one being not flatten the native Julia variable at all, as in SMC and PG they are not required to be flatten. I think this would give us some performance back from flattening and reconstruction.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, avoiding the flattening is an interesting idea. I am not sure what the best way to handle this is. Should probably open an issue to discuss it.
src/core/RandomVariables.jl
Outdated
|
||
This function finds all the unique `sym`s from the instances of `VarName{sym}` found in `vi.metadata.vns`. It then extracts the metadata associated with each symbol from the global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each symbol. | ||
""" | ||
function TypedVarInfo(vi::UntypedVarInfo) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't check line by line but I'm happy as long as it has some test cases.
getlogp(vi::AbstractVarInfo) = vi.logp | ||
|
||
""" | ||
`setlogp!(vi::VarInfo, logp::Real)` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just for discussion. I guess there is no way in Julia to make a filed "private" and force people using setlogp!
instead of vi.log =
. Do you think it's actually a good idea for us to have setlogp!
at all?
FYI, this set of get
/set
functions were originally introduced by me at a time I was new to Julia and only know the practise of OO programming.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can overload setproperty!
for :log
and default that to an error, but that's probably overkill unless there is a good reason to stop people from doing this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy with removing those setter and getter methods if we don't need them. I don't see much benefit.
function link!(vi::UntypedVarInfo, spl::Sampler) | ||
vns = getvns(vi, spl) | ||
# TODO: Change to a lazy iterator over `vns` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does it mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As in we don't actually need to materialize vns
, we can replace it with a lazy iterator that loops over the relevant vn
s.
src/core/RandomVariables.jl
Outdated
Base.getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) | ||
function Base.getindex(vi::TypedVarInfo, spl::Sampler) | ||
# Gets the ranges as a NamedTuple | ||
# getfield(ranges, f) is all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no getfield(ranges, f)
any more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ranges
refers to the output variable from _getranges
. Calling getfield(ranges, f)
is what this line refers to. I will try to make it clearer.
src/core/RandomVariables.jl
Outdated
return vcat(_get(vi.metadata, ranges)...) | ||
end | ||
# Recursively builds a tuple of the `vals` of all the symbols | ||
@inline function _get(metadata::NamedTuple{names}, ranges) where {names} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe name it as _getindex
as it's only used by Base.getindex
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure
|
||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bookmark for Kai.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Look good to me! Awesome work.
src/core/RandomVariables.jl
Outdated
VarName(csym, sym, indexing, counter) = VarName{sym}(csym, indexing, counter) | ||
function VarName(csym::Symbol, sym::Symbol, indexing::String) | ||
# TODO: update this method when implementing the sanity check | ||
VarName{sym}(csym, indexing, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it correct that the counter is always 1
in the constructor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't changed this behavior. Maybe @xukai92 knows better.
getlogp(vi::AbstractVarInfo) = vi.logp | ||
|
||
""" | ||
`setlogp!(vi::VarInfo, logp::Real)` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy with removing those setter and getter methods if we don't need them. I don't see much benefit.
Current exported API Turing.jl/src/core/RandomVariables.jl Lines 13 to 36 in 06e0dbb
Key data types
Internal / utility functions
|
I think we need to do the following in a separate PR:
|
Sounds good to do the API redesign in a separate PR. |
06e0dbb
to
7dc8e05
Compare
…gLang/Turing.jl into mt/ts_wave2_untypedvarinfo
Rebased on master |
Looks great to me. The @model gdemo_d() = begin
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
1.5 ~ Normal(m, sqrt(s))
2.0 ~ Normal(m, sqrt(s))
return s, m
end
mf = gdemo_d()
@code_warntype sample(mf, NUTS(2000, 0.8))
Body::Any
1 ─ %1 = invoke Turing.Inference.AHMCAdaptor(_3::NUTS{Turing.Core.ForwardDiffAD{40},Any})::AdvancedHMC.Adaptation.StanNUTSAdaptor
│ %2 = (Turing.Inference.:(#sample#19))(false, Turing.Inference.nothing, 0, %1, Turing.Inference.nothing, Turing.Inference.GLOBAL_RNG, $(QuoteNode(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}())), #self#, model, alg)::Any
└── return %2 |
OK it seems only the return type is not stable but not during sampling? |
No, this PR doesn't hook into
|
Something weird happened when I merged, I will fix the error. |
This issue should go away after #746 is fixed. |
This is the second wave of #660. In this PR:
TypedVarInfo
type is introduced with all its utility functions mapping that ofUntypedVarInfo
.UntypedVarInfo
are also defined.push!
is called onVarInfo
. This seemed like a latent bug that was "worked around" by invalidating the cache at the call site after callingpush!
.getvals(vi, spl)
was removed as it seemed to be doing the same thing asvi[spl]
and it was unused in the rest of Turing.Your feedback is appreciated.