TS 2: introduce TypedVarInfo and fix spl.info[:cache_updated] #742

Merged
merged 51 commits into from
May 13, 2019


@mohamed82008
Member

mohamed82008 commented Mar 31, 2019

This is the second wave of #660. In this PR:

  1. The TypedVarInfo type is introduced, with utility functions mirroring those of UntypedVarInfo.
  2. Unit tests mirroring those of UntypedVarInfo are also defined.
  3. As a side change, the sampler's cache is now invalidated every time push! is called on a VarInfo. This looked like a latent bug that had been "worked around" by invalidating the cache at each call site after calling push!.
  4. getvals(vi, spl) was removed as it seemed to be doing the same thing as vi[spl] and it was unused in the rest of Turing.

Your feedback is appreciated.

@mohamed82008
Member Author

mohamed82008 commented Mar 31, 2019

TODO: replace most of the generated functions with recursive functions. Example:

@generated function _getranges(vi::TypedVarInfo{Tvis}, idcs) where Tvis
    args = []
    for f in fieldnames(Tvis)
        push!(args, :($f = _map(vi, $(QuoteNode(f)), idcs.$f)))
    end
    if length(args) == 0
        nt = :(NamedTuple())
    else
        nt = :(($(args...),))
    end
    return nt
end

can be replaced with:

_getranges(vi::TypedVarInfo, idcs) = _getranges(vi.vis, vi, idcs)
@inline function _getranges(vis::NamedTuple{names}, vi::TypedVarInfo, idcs) where names
    length(names) === 0 && return NamedTuple()
    f = names[1]
    v = _map(vi, f, getfield(idcs, f))
    nt = NamedTuple{(f,), Tuple{typeof(v)}}(v)
    return merge(nt, _getranges(Base.tail(vis), vi, idcs))
end
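
The recursive pattern above can be tried in isolation. The following is a minimal sketch, not code from this PR: `double_all` and `_double_all` are hypothetical names, and doubling each element stands in for the real `_map` call.

```julia
# Hypothetical stand-in for `_getranges`: process the first field of a
# NamedTuple, then recurse on the remaining fields via `Base.tail`.
double_all(nt::NamedTuple) = _double_all(nt)

@inline function _double_all(nt::NamedTuple{names}) where {names}
    # Base case: no fields left, return an empty NamedTuple.
    length(names) === 0 && return NamedTuple()
    f = names[1]                           # symbol of the first field
    v = map(x -> 2x, getfield(nt, f))      # placeholder for `_map(vi, f, ...)`
    head = NamedTuple{(f,)}((v,))          # one-field NamedTuple for this symbol
    # Merge this field with the result for the remaining fields.
    return merge(head, _double_all(Base.tail(nt)))
end

result = double_all((a = [1, 2], b = [3]))
```

Because the field names are part of the `NamedTuple` type, each level of the recursion sees a different, statically known "first" symbol, which is what lets this replace the `@generated` version.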

@yebai
Member

yebai commented Apr 3, 2019

@mohamed82008 is this ready for a review?

@mohamed82008
Member Author

Yes, if you don't mind the generated functions. Converting them to normal functions will not change the semantics.

@yebai yebai requested review from KDr2 and xukai92 April 5, 2019 18:32
Member

@xukai92 xukai92 left a comment

I left some comments. Most of them are about things I didn't quite understand, where I ask for explanations. Also, I didn't check the set and get functions carefully, but I'm happy with them given that they are covered by tests and the tests pass.

@@ -58,7 +58,7 @@ end

 function assume(spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi::VarInfo)
     r = rand(dist)
-    push!(vi, vn, r, dist, spl.selector)
+    push!(vi, vn, r, dist, spl)
Member

It seems that we are storing spl directly instead of selector inside vi now? Why do we do this?

Member Author

No we are still pushing the same way. The reason why I am passing spl here is to reset the :cache_updated field inside the function.

Member

In your opinion, is this cache worth doing here at all? I feel it might be over-optimisation...

Member Author

I think it is important in cases where the VarInfo is "filled" most of the time so the cache is valid most of the time. This is the case for most samplers except the particle ones I believe.
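
The idea of invalidating the cache inside `push!` rather than at every call site can be sketched as follows. This is an illustration under stated assumptions, not Turing's implementation: `ToySampler` and `ToyVarInfo` are hypothetical types, and a plain `Bool` stands in for the real `:cache_updated` value.

```julia
# Hypothetical sampler holding an untyped info dictionary, as in the discussion.
struct ToySampler
    info::Dict{Symbol,Any}
end

# Hypothetical container of variable names and their values.
struct ToyVarInfo
    names::Vector{Symbol}
    vals::Vector{Float64}
end

# Passing the sampler lets `push!` reset the cache flag itself, so callers
# cannot forget to invalidate it after adding a variable.
function Base.push!(vi::ToyVarInfo, name::Symbol, val::Real, spl::ToySampler)
    push!(vi.names, name)
    push!(vi.vals, val)
    spl.info[:cache_updated] = false   # assumption: Bool flag, false means stale
    return vi
end

spl = ToySampler(Dict{Symbol,Any}(:cache_updated => true))
vi = ToyVarInfo(Symbol[], Float64[])
push!(vi, :x, 0.5, spl)
```

The cache stays valid as long as no new variable is pushed, which matches the case described above where the VarInfo is already "filled" for most of sampling.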

dists = [Normal(0, 1), MvNormal([0; 0], [1.0 0; 0 1.0]), Wishart(7, [1 0.5; 0.5 1])]
function test_varinfo!(vi)
Member

Would it be nice to improve the tests here a bit, e.g. by adding more comments on what is being tested and what is expected, and by improving the code (variable names, code reuse, etc.)?

Member Author

I will add more comments to the src too.

@@ -1,9 +1,9 @@
 using Turing, Random
 using Turing: Selector, reconstruct, invlink, CACHERESET, SampleFromPrior
 using Turing.RandomVariables
-using Turing.RandomVariables: uid, cuid, getvals, getidcs,
+using Turing.RandomVariables: uid, cuid, getidcs,
Member

Are we aware of the test coverage of the TypedVarInfo?

Member Author

I made sure the test coverage of TypedVarInfo is no less than that of UntypedVarInfo. But I will see if I can improve it.

@@ -118,32 +130,224 @@ mutable struct UntypedVarInfo <: AbstractVarInfo
end
VarInfo() = UntypedVarInfo()

###########################
# Single variable VarInfo #
Member

Does SingleVarInfo contain only a single random variable? If so, why is ranges required for SingleVarInfo, and why are fields like dists and gids still vectors?

Member Author

In case we have a vector variable where each element has a different prior distribution, so I think dists still needs to be a vector. I think gids can be shared however. Not sure about ranges, in case we have a matrix variable whose columns are multivariate variables, we probably still need it.

Member

Hmmm, I see your point, though I guess we should discourage users from doing that. And I'm not sure whether it makes sense for a single multivariate random variable to have dimensions that follow different distributions. But I'm OK with it for now.

end
@inline function _link!(vis::NamedTuple{names}, vi, vns, space) where {names}
length(names) === 0 && return nothing
f = names[1]
Member

A bit confused by the variables here, what is f? Also we are calling getfield(vns, f) multiple times later, can we assign it to some local variable with a name?

Member Author

f is the first variable's symbol.

Member Author

But using recursion, this "first" variable will be a different variable in each level of the recursion.

Member

But in the loop below we are calling getfield(vns, f) with the same vns and f every time, aren't we?

Member Author

Yes, so we are iterating over all the vns of the first symbol, then all the vns of the second, and so on. Each symbol has a vns vector because it can consist of multiple random variables, e.g. multiple univariate variables arranged in vector form or multiple multivariate variables arranged in matrix form, etc. The last line is where recursion happens.

end
@inline function _invlink!(vis::NamedTuple{names}, vi, vns, space) where {names}
length(names) === 0 && return nothing
f = names[1]
Member

Similar here.

@mohamed82008
Member Author

I am working on docstrings for all the major functions and types in the module RandomVariables.

@mohamed82008
Member Author

mohamed82008 commented Apr 7, 2019

To do:

  • More unit tests for TypedVarInfo and UntypedVarInfo
  • More comments and docstrings
  • Make gids for TypedVarInfo a resizeable FillArray type

@yebai yebai mentioned this pull request Apr 19, 2019
56 tasks
Member

@yebai yebai left a comment

Review in progress - I'll review this PR in several steps in the next 2-3 days.


Examples:

- `x[2] ~ Normal()` will generate a `VarName` with `sym == :x` and `indexing == "[1]"`
Member

indexing == "[1]" ==> indexing == "[2]"?

Member Author

Thanks.

end
```

A variable identifier. Every variable has a symbol `sym`, indices `indexing`, and internal fields: `csym` and `counter`. The Julia variable in the model corresponding to `sym` can refer to a single value or to a hierarchical array structure of univariate, multivariate or matrix variables. `indexing` stores the indices that can access the random variable from the Julia variable.
Member

Perhaps breaking this line into smaller lines (i.e. <80 chars)?

Member

Or 92 which is the number from our guide.

Member Author

Will do.

@@ -60,20 +108,57 @@ isequal(x::VarName, y::VarName) = hash(uid(x)) == hash(uid(y))
Base.string(vn::VarName) = "{$(vn.csym),$(vn.sym)$(vn.indexing)}:$(vn.counter)"
Base.string(vns::Vector{<:VarName}) = replace(string(map(vn -> string(vn), vns)), "String" => "")

"""
`sym_idx(vn::VarName)`
Member

Is this equivalent to the function mentioned in #721?

Member Author

It's the inverse of it.


A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of `Metadata`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used for all the symbols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If `vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows for the type specialization of `vi` after the first sampling iteration when all the symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`.

Note: It is the user's responsibility to ensure that each symbol is visited at least once whenever the model is called, regardless of any stochastic branching.
Member

Presumably, it's OK if a variable is visited multiple times, since each visit can have a different counter? If so, can we say this explicitly in the comment to reduce potential confusion?

Member Author

Each random variable can only be visited once in a model call. But the symbol of the Julia variable can be visited more than once, e.g. x[1] ~ ... and x[2] ~ .... This line refers to the symbol x not the random variables x[1] and x[2]. The requirement here is that each symbol, e.g. x, is visited at least once. I will try to make this clearer.

end
```

A variable identifier. Every variable has a symbol `sym`, indices `indexing`, and internal fields: `csym` and `counter`. The Julia variable in the model corresponding to `sym` can refer to a single value or to a hierarchical array structure of univariate, multivariate or matrix variables. `indexing` stores the indices that can access the random variable from the Julia variable.
Member

Or 92 which is the number from our guide.

- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values corresponding to `vn`.
- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the value of `flag` corresponding to `vn`.

To make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a Julia variable, say `x`, that is a matrix or a hierarchical array sampled in partitions, e.g. `x[1][:] ~ MvNormal(zeros(2), 1.0); x[2][:] ~ MvNormal(ones(2), 1.0)` and is managed by a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the same type. Type unstable `Metadata` will still work but will have inferior performance. When sampling, the first iteration uses a type unstable `Metadata` for all the variables then a specialized `Metadata` is used for each symbol along with a function barrier to make the rest of the sampling type stable.
Member

Nice!
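
The indexing convention in the docstring excerpt above, `md.vals[md.ranges[md.idcs[vn]]]`, can be illustrated with a toy layout. This is a sketch under assumptions: `ToyMetadata` and `toy_getval` are hypothetical, and plain strings stand in for `VarName` keys.

```julia
# Hypothetical miniature of the flat storage layout: every variable's values
# live in one flat vector, and each variable owns a contiguous range of it.
struct ToyMetadata
    idcs::Dict{String,Int}            # variable name -> position in `ranges`
    ranges::Vector{UnitRange{Int}}    # position -> slice of `vals`
    vals::Vector{Float64}             # all values, flattened
end

# Two lookups: name -> index, index -> range, then slice the flat vector.
toy_getval(md::ToyMetadata, vn::String) = md.vals[md.ranges[md.idcs[vn]]]

md = ToyMetadata(
    Dict("x[1]" => 1, "x[2]" => 2),
    [1:2, 3:4],
    [0.1, 0.2, 1.1, 1.2],
)

second = toy_getval(md, "x[2]")
```

Keeping a single flat `vals` vector with a concrete element type is what makes it possible for one `Metadata` to manage several partitions of the same symbol, as the docstring describes.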


A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of `Metadata`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used for all the symbols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If `vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows for the type specialization of `vi` after the first sampling iteration when all the symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`.
Member

Let vi be an instance of Metadata. -> Let vi be an instance of VarInfo ?

Member Author

Thanks.


To make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a Julia variable, say `x`, that is a matrix or a hierarchical array sampled in partitions, e.g. `x[1][:] ~ MvNormal(zeros(2), 1.0); x[2][:] ~ MvNormal(ones(2), 1.0)` and is managed by a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the same type. Type unstable `Metadata` will still work but will have inferior performance. When sampling, the first iteration uses a type unstable `Metadata` for all the variables then a specialized `Metadata` is used for each symbol along with a function barrier to make the rest of the sampling type stable.
"""
struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distribution}, TVN <: AbstractVector{<:VarName}, TVal <: AbstractVector{<:Real}, TGIds <: AbstractVector{Set{Selector}}}
Member

In fact, I guess in the future we could have two types of Metadata: NativeMetadata and FlatMetadata. The first would not flatten the native Julia variable at all, since SMC and PG do not require flattening. I think this would give us back some of the performance lost to flattening and reconstruction.

Member Author

Hmm, avoiding the flattening is an interesting idea. I am not sure what the best way to handle this is. Should probably open an issue to discuss it.


This function finds all the unique `sym`s from the instances of `VarName{sym}` found in `vi.metadata.vns`. It then extracts the metadata associated with each symbol from the global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each symbol.
"""
function TypedVarInfo(vi::UntypedVarInfo)
Member

I didn't check line by line but I'm happy as long as it has some test cases.
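
The conversion described in the docstring (collect the unique symbols, pull out the entries for each, and build a `NamedTuple` with one concretely typed container per symbol) can be sketched on toy data. This is not the PR's `TypedVarInfo` constructor: `typed_by_symbol` and `sum_all` are hypothetical names, and plain vectors stand in for `Metadata`.

```julia
# Untyped storage: one flat list for all variables, with element type Any.
syms = [:s, :m, :m]
vals = Any[1.0, 2, 3]

# Group the values by symbol. `map(identity, ...)` narrows each group's
# element type to the concrete type of the values actually stored.
function typed_by_symbol(syms::Vector{Symbol}, vals::Vector)
    names = Tuple(unique(syms))
    cols = map(names) do name
        map(identity, vals[syms .== name])
    end
    # The result type depends on runtime data, so this call is itself
    # type unstable; code that receives `nt` as an argument is specialized.
    return NamedTuple{names}(cols)
end

# Function barrier: inside `sum_all`, the field types of `nt` are known.
sum_all(nt::NamedTuple) = sum(sum, values(nt))

nt = typed_by_symbol(syms, vals)
total = sum_all(nt)
```

This mirrors the two-phase scheme from the docstrings: the first pass works with untyped storage, after which everything behind the function barrier sees concrete types.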

getlogp(vi::AbstractVarInfo) = vi.logp

"""
`setlogp!(vi::VarInfo, logp::Real)`
Member

Just for discussion. I guess there is no way in Julia to make a field "private" and force people to use setlogp! instead of vi.log = . Do you think it's actually a good idea for us to have setlogp! at all?

FYI, this set of get/set functions was originally introduced by me at a time when I was new to Julia and only knew the practice of OO programming.

Member Author

We can overload setproperty! for :log and default that to an error, but that's probably overkill unless there is a good reason to stop people from doing this.

Member

I'm happy with removing those setter and getter methods if we don't need them. I don't see much benefit.
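
For reference, the `setproperty!` overload mentioned in this thread could look roughly like the sketch below. It is an illustration only: `ToyVarInfo`, `toy_setlogp!` and `toy_getlogp` are hypothetical, and the field is named `logp` here on the assumption that it matches the `vi.logp` access in `getlogp`.

```julia
# Hypothetical mutable container with a single log-probability field.
mutable struct ToyVarInfo
    logp::Float64
end

# Reject direct assignment `vi.logp = x`; other fields fall through to the
# default behaviour of writing the field.
function Base.setproperty!(vi::ToyVarInfo, name::Symbol, value)
    name === :logp && error("assign `logp` through toy_setlogp! instead")
    return setfield!(vi, name, value)
end

# The sanctioned accessors bypass `setproperty!` by using `setfield!` and
# `getfield` directly.
toy_setlogp!(vi::ToyVarInfo, logp::Real) = setfield!(vi, :logp, Float64(logp))
toy_getlogp(vi::ToyVarInfo) = getfield(vi, :logp)

vi = ToyVarInfo(0.0)
toy_setlogp!(vi, 1.5)
```

As noted in the thread, this only discourages direct assignment; `setfield!` remains available to anyone, so the field is not truly private.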

function link!(vi::UntypedVarInfo, spl::Sampler)
vns = getvns(vi, spl)
# TODO: Change to a lazy iterator over `vns`
Member

What does it mean?

Member Author

@mohamed82008 mohamed82008 Apr 27, 2019

As in we don't actually need to materialize vns, we can replace it with a lazy iterator that loops over the relevant vns.
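
The difference between materializing the variable names and iterating lazily can be sketched with `Iterators.filter` from Base. The data and the `in_space` predicate below are hypothetical stand-ins for `getvns(vi, spl)` and the sampler's variable space.

```julia
# Hypothetical list of variable names and a predicate selecting the ones
# that belong to a sampler.
all_vns = ["x[1]", "x[2]", "y", "z[1]"]
in_space(vn) = startswith(vn, "x")

# Eager: allocates a new vector holding the selected names.
eager = filter(in_space, all_vns)

# Lazy: no intermediate vector; elements are tested as the loop advances.
lazy = Iterators.filter(in_space, all_vns)

visited = String[]
for vn in lazy
    push!(visited, vn)   # placeholder for the per-variable work in `link!`
end
```

Both forms visit the same names in the same order; the lazy one just avoids the temporary allocation.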

Base.getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl)))
function Base.getindex(vi::TypedVarInfo, spl::Sampler)
# Gets the ranges as a NamedTuple
# getfield(ranges, f) is all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi`
Member

There is no getfield(ranges, f) any more.

Member Author

ranges refers to the output variable from _getranges. Calling getfield(ranges, f) is what this line refers to. I will try to make it clearer.

return vcat(_get(vi.metadata, ranges)...)
end
# Recursively builds a tuple of the `vals` of all the symbols
@inline function _get(metadata::NamedTuple{names}, ranges) where {names}
Member

Maybe name it as _getindex as it's only used by Base.getindex.

Member Author

Sure


"""
Member

Bookmark for Kai.

@yebai yebai mentioned this pull request Apr 29, 2019
Member

@trappmartin trappmartin left a comment

Looks good to me! Awesome work.

VarName(csym, sym, indexing, counter) = VarName{sym}(csym, indexing, counter)
function VarName(csym::Symbol, sym::Symbol, indexing::String)
# TODO: update this method when implementing the sanity check
VarName{sym}(csym, indexing, 1)
Member

Is it correct that the counter is always 1 in the constructor?

Member Author

I haven't changed this behavior. Maybe @xukai92 knows better.

@yebai
Member

yebai commented May 3, 2019

Current exported API

export VarName,
AbstractVarInfo,
VarInfo,
UntypedVarInfo,
uid,
sym,
getlogp,
set_retained_vns_del_by_spl!,
resetlogp!,
is_flagged,
unset_flag!,
setgid!,
copybyindex,
setorder!,
updategid!,
acclogp!,
istrans,
link!,
invlink!,
setlogp!,
getranges,
getrange,
getvns,
getval

  • logp
    function Turing.runmodel!(model::Model, vi::AbstractVarInfo, spl::AbstractSampler = SampleFromPrior())
    setlogp!(vi, zero(Float64))
    if spl isa Sampler && isdefined(spl.info, :eval_num)
    spl.info.eval_num += 1
    end
    model(vi, spl)
    return vi
    end

Key data types

Internal / utility functions

getval and setval!:

getidx:

getrange

setorder!

  • `setorder!(vi::VarInfo, vn::VarName, index::Int)`
    Sets the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe`
    statements run before sampling `vn`.
    """
    function setorder!(vi::UntypedVarInfo, vn::VarName, index::Int)

  • function setorder!(mvi::TypedVarInfo, vn::VarName{sym}, index::Int) where {sym}

flagging and gid functions:

@mohamed82008
Member Author

I think we need to do the following in a separate PR:

  1. Find which functions we need outside of RandomVariables to define the "necessary" API.
  2. For internal functions, we can use any names we like.
  3. For functions needed outside, we can try to find Base functions with similar enough semantics and overload those instead to minimize the exported names and make it easier to remember.

@yebai
Member

yebai commented May 3, 2019

I think we need to do the following in a separate PR:

Sounds good to do the API redesign in a separate PR.

@mohamed82008
Member Author

Rebased on master

@xukai92
Member

xukai92 commented May 13, 2019

Looks great to me.

The @code_warntype check doesn't pass on my local machine, although that is not an issue caused by this PR.

@model gdemo_d() = begin
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    1.5 ~ Normal(m, sqrt(s))
    2.0 ~ Normal(m, sqrt(s))
    return s, m
end

mf = gdemo_d()

@code_warntype sample(mf, NUTS(2000, 0.8))
Body::Any
1 ─ %1 = invoke Turing.Inference.AHMCAdaptor(_3::NUTS{Turing.Core.ForwardDiffAD{40},Any})::AdvancedHMC.Adaptation.StanNUTSAdaptor
│   %2 = (Turing.Inference.:(#sample#19))(false, Turing.Inference.nothing, 0, %1, Turing.Inference.nothing, Turing.Inference.GLOBAL_RNG, $(QuoteNode(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}())), #self#, model, alg)::Any
└──      return %2

@xukai92
Member

xukai92 commented May 13, 2019

OK, it seems only the return type is unstable, not the sampling itself?

@mohamed82008
Member Author

No, this PR doesn't hook into sample yet, this will be the next PR. And even if I hook them, sample won't be type stable; we need to address the info field of Sampler and the value field of Sample to get complete type stability from the second iteration onwards. Currently, spl.info is Dict{Symbol, Any} so when we retrieve values from it during sampling, the compiler can't infer its type and this muddies the rest of the sampling function. So 2 more PRs are needed to get complete type stability from the second iteration onwards for all algorithms:

  1. Hooking TypedVarInfo to sample, and
  2. Making spl.info and sample.value type stable.
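
The inference problem with a `Dict{Symbol, Any}` info field can be seen on a small example. This is a sketch under assumptions: `TypedInfo` and `read_eval_num` are hypothetical, and `Base.return_types` is an unexported reflection helper used here only to show what the compiler infers.

```julia
# Hypothetical typed replacement for the untyped info dictionary.
struct TypedInfo
    eval_num::Int
end

# Reading from the untyped dictionary: the value type is `Any`, so the
# compiler cannot tell what comes back.
read_eval_num(info::Dict{Symbol,Any}) = info[:eval_num]

# Reading from the struct: the field type is known statically.
read_eval_num(info::TypedInfo) = info.eval_num

# Inferred return type of each method for the given argument type.
untyped_ret = first(Base.return_types(read_eval_num, (Dict{Symbol,Any},)))
typed_ret = first(Base.return_types(read_eval_num, (TypedInfo,)))
```

With the dictionary, the inferred type should come out as `Any`, and that uncertainty propagates to everything computed from the retrieved value, which is the effect described above for `spl.info`.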

@mohamed82008
Member Author

Something weird happened when I merged, I will fix the error.

@yebai
Member

yebai commented May 13, 2019

Making spl.info and sample.value type stable.

This issue should go away after #746 is fixed.

@yebai yebai merged commit 51b7880 into master May 13, 2019
@yebai yebai deleted the mt/ts_wave2_untypedvarinfo branch June 8, 2019 22:34
5 participants