Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] VarInfo changes #188

Closed
wants to merge 15 commits into from
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand All @@ -22,6 +23,7 @@ julia = "1"
[extras]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
test = ["AdvancedHMC", "AdvancedMH", "BangBang", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
5 changes: 4 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using Distributions
using Bijectors
using MacroTools
using Requires

import AbstractMCMC
import ZygoteRules
Expand All @@ -28,6 +29,7 @@ import Base: Symbol,
export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
MixedVarInfo,
getlogp,
setlogp!,
acclogp!,
Expand Down Expand Up @@ -116,8 +118,9 @@ include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("varinfo/varinfo.jl")
include("threadsafe.jl")
include("mixedvarinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("prob_macro.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ Convert the `value` to the correct type for the `sampler` and the `vi` object.
function matchingvalue(sampler, vi, value)
T = typeof(value)
if hasmissing(T)
return get_matching_type(sampler, vi, T)(value)
return convert(get_matching_type(sampler, vi, T), value)
else
return value
end
Expand Down
20 changes: 10 additions & 10 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ function tilde(ctx::DefaultContext, sampler, right, vn::VarName, _, vi)
end
function tilde(ctx::PriorContext, sampler, right, vn::VarName, inds, vi)
if ctx.vars !== nothing
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds)
settrans!(vi, false, vn)
end
return _tilde(sampler, right, vn, vi)
end
function tilde(ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi)
if ctx.vars !== nothing
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds)
settrans!(vi, false, vn)
end
return _tilde(sampler, NoDist(right), vn, vi)
Expand Down Expand Up @@ -127,11 +127,11 @@ function assume(
if spl isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(dist, spl)
vi[vn] = vectorize(dist, r)
vi[vn, dist] = r
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
r = vi[vn, dist]
end
else
r = init(dist, spl)
Expand Down Expand Up @@ -297,12 +297,12 @@ function get_and_set_val!(
r = init(dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
vi[vn, dist] = r[:, i]
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r = vi[vns, dist]
end
else
r = init(dist, spl, n)
Expand Down Expand Up @@ -330,12 +330,12 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
vi[vn, dist] = r[i]
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
r = vi[vns, dists]
end
else
f = (vn, dist) -> init(dist, spl)
Expand All @@ -354,7 +354,7 @@ function set_val!(
)
@assert size(val, 2) == length(vns)
foreach(enumerate(vns)) do (i, vn)
vi[vn] = val[:,i]
vi[vn, dist] = val[:,i]
end
return val
end
Expand All @@ -367,7 +367,7 @@ function set_val!(
@assert size(val) == size(vns)
foreach(CartesianIndices(val)) do ind
dist = dists isa AbstractArray ? dists[ind] : dists
vi[vns[ind]] = vectorize(dist, val[ind])
vi[vns[ind], dist] = val[ind]
end
return val
end
Expand Down
Loading