Skip to content

Commit

Permalink
Merge pull request #481 from probcomp/20221031-fsaad-minor-fixes
Browse files Browse the repository at this point in the history
20221031 fsaad minor fixes
  • Loading branch information
Feras Saad authored Oct 31, 2022
2 parents e1fb6eb + ca21e40 commit bcf3504
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 38 deletions.
50 changes: 32 additions & 18 deletions src/inference/importance.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
(traces, log_norm_weights, lml_est) = importance_sampling(model::GenerativeFunction,
model_args::Tuple, observations::ChoiceMap, num_samples::Int, verbose=false)
model_args::Tuple, observations::ChoiceMap, num_samples::Int; verbose=false)
(traces, log_norm_weights, lml_est) = importance_sampling(model::GenerativeFunction,
model_args::Tuple, observations::ChoiceMap,
proposal::GenerativeFunction, proposal_args::Tuple,
num_samples::Int, verbose=false)
num_samples::Int; verbose=false)
Run importance sampling, returning a vector of traces with associated log weights.
Expand All @@ -17,9 +17,12 @@ The second variant uses a custom proposal distribution defined by the given gene
All addresses of random choices sampled by the proposal should also be sampled by the model function.
Setting `verbose=true` prints a progress message every sample.
"""
function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap,
num_samples::Int, verbose=false) where {T,U}
function importance_sampling(
model::GenerativeFunction{T,U},
model_args::Tuple,
observations::ChoiceMap,
num_samples::Int;
verbose=false) where {T,U}
traces = Vector{U}(undef, num_samples)
log_weights = Vector{Float64}(undef, num_samples)
for i=1:num_samples
Expand All @@ -32,10 +35,14 @@ function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple,
return (traces, log_normalized_weights, log_ml_estimate)
end

function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap,
proposal::GenerativeFunction, proposal_args::Tuple,
num_samples::Int, verbose=false) where {T,U}
function importance_sampling(
model::GenerativeFunction{T,U},
model_args::Tuple,
observations::ChoiceMap,
proposal::GenerativeFunction,
proposal_args::Tuple,
num_samples::Int;
verbose=false) where {T,U}
traces = Vector{U}(undef, num_samples)
log_weights = Vector{Float64}(undef, num_samples)
for i=1:num_samples
Expand All @@ -53,23 +60,26 @@ end

"""
(trace, lml_est) = importance_resampling(model::GenerativeFunction,
model_args::Tuple, observations::ChoiceMap, num_samples::Int,
model_args::Tuple, observations::ChoiceMap, num_samples::Int;
verbose=false)
(traces, lml_est) = importance_resampling(model::GenerativeFunction,
model_args::Tuple, observations::ChoiceMap,
proposal::GenerativeFunction, proposal_args::Tuple,
num_samples::Int, verbose=false)
num_samples::Int; verbose=false)
Run sampling importance resampling, returning a single trace.
Unlike `importance_sampling`, the memory used is constant in the number of samples.
Setting `verbose=true` prints a progress message every sample.
"""
function importance_resampling(model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap,
num_samples::Int; verbose=false) where {T,U,V,W}
function importance_resampling(
model::GenerativeFunction{T,U},
model_args::Tuple,
observations::ChoiceMap,
num_samples::Int;
verbose=false) where {T,U}
(model_trace::U, log_weight) = generate(model, model_args, observations)
log_total_weight = log_weight
for i=2:num_samples
Expand All @@ -84,10 +94,14 @@ function importance_resampling(model::GenerativeFunction{T,U}, model_args::Tuple
return (model_trace::U, log_ml_estimate::Float64)
end

function importance_resampling(model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap,
proposal::GenerativeFunction{V,W}, proposal_args::Tuple,
num_samples::Int; verbose=false) where {T,U,V,W}
function importance_resampling(
model::GenerativeFunction{T,U},
model_args::Tuple,
observations::ChoiceMap,
proposal::GenerativeFunction{V,W},
proposal_args::Tuple,
num_samples::Int;
verbose=false) where {T,U,V,W}
(proposal_choices, proposal_weight, _) = propose(proposal, proposal_args)
constraints = merge(observations, proposal_choices)
(model_trace::U, model_weight) = generate(model, model_args, constraints)
Expand Down
39 changes: 28 additions & 11 deletions src/inference/particle_filter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,12 @@ end
Initialize the state of a particle filter using a custom proposal for the initial latent state.
"""
function initialize_particle_filter(model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple,
function initialize_particle_filter(
model::GenerativeFunction{T,U},
model_args::Tuple,
observations::ChoiceMap,
proposal::GenerativeFunction,
proposal_args::Tuple,
num_particles::Int) where {T,U}
traces = Vector{Any}(undef, num_particles)
log_weights = Vector{Float64}(undef, num_particles)
Expand All @@ -96,8 +100,11 @@ end
Initialize the state of a particle filter, using the default proposal for the initial latent state.
"""
function initialize_particle_filter(model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap, num_particles::Int) where {T,U}
function initialize_particle_filter(
model::GenerativeFunction{T,U},
model_args::Tuple,
observations::ChoiceMap,
num_particles::Int) where {T,U}
traces = Vector{Any}(undef, num_particles)
log_weights = Vector{Float64}(undef, num_particles)
for i=1:num_particles
Expand Down Expand Up @@ -137,11 +144,16 @@ the support of the model (with the new arguments). If such a trace exists, then
the random choices not determined by the above requirements are sampled using
the internal proposal.
"""
function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, argdiffs::Tuple,
observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple) where {U}
function particle_filter_step!(
state::ParticleFilterState{U},
new_args::Tuple,
argdiffs::Tuple,
observations::ChoiceMap,
proposal::GenerativeFunction,
proposal_args::Tuple) where {U}
trace_translator = SimpleExtendingTraceTranslator(new_args, argdiffs, observations, proposal, proposal_args)
num_particles = length(state.traces)
log_incremental_weights = Vector{Float64}(undef, num_particles)
log_incremental_weights = Vector{Float64}(undef, num_particles)
for i=1:num_particles
(state.new_traces[i], log_weight) = trace_translator(state.traces[i])
log_incremental_weights[i] = log_weight
Expand All @@ -163,10 +175,13 @@ end
Perform a particle filter update, where the model arguments are adjusted, new observations are added, and the default proposal is used for new latent state.
"""
function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, argdiffs::Tuple,
function particle_filter_step!(
state::ParticleFilterState{U},
new_args::Tuple,
argdiffs::Tuple,
observations::ChoiceMap) where {U}
num_particles = length(state.traces)
log_incremental_weights = Vector{Float64}(undef, num_particles)
log_incremental_weights = Vector{Float64}(undef, num_particles)
for i=1:num_particles
(state.new_traces[i], increment, _, discard) = update(
state.traces[i], new_args, argdiffs, observations)
Expand All @@ -192,8 +207,10 @@ end
Do a resampling step if the effective sample size is below the given threshold.
Return `true` if a resample thus occurred, `false` otherwise.
"""
function maybe_resample!(state::ParticleFilterState{U};
ess_threshold::Real=length(state.traces)/2, verbose=false) where {U}
function maybe_resample!(
state::ParticleFilterState{U};
ess_threshold::Real=length(state.traces)/2,
verbose=false) where {U}
num_particles = length(state.traces)
(log_total_weight, log_normalized_weights) = normalize_weights(state.log_weights)
ess = effective_sample_size(log_normalized_weights)
Expand Down
18 changes: 9 additions & 9 deletions src/modeling_library/switch/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
address in keys(choice_value_iterator) && continue
set_value!(new_choices, address, value)
end
# Add (address, submap) to new_choices from prev_choices if address does not occur in choices.

# Add (address, submap) to new_choices from prev_choices if address does not occur in choices.
# If it does, enter a recursive call to update_recurse_merge.
for (address, node1) in prev_choice_submap_iterator
if address in keys(choice_submap_iterator)
Expand Down Expand Up @@ -87,8 +87,8 @@ function process!(gen_fn::Switch{C, N, K, T},
index_argdiff::UnknownChange,
args::Tuple,
kernel_argdiffs::Tuple,
choices::ChoiceMap,
state::SwitchUpdateState{T}) where {C, N, K, T, DV}
choices::ChoiceMap,
state::SwitchUpdateState{T}) where {C, N, K, T}

# Generate new trace.
merged = update_recurse_merge(get_choices(state.prev_trace), choices)
Expand All @@ -111,7 +111,7 @@ function process!(gen_fn::Switch{C, N, K, T},
index_argdiff::NoChange, # TODO: Diffed wrapper?
args::Tuple,
kernel_argdiffs::Tuple,
choices::ChoiceMap,
choices::ChoiceMap,
state::SwitchUpdateState{T}) where {C, N, K, T}

# Update trace.
Expand All @@ -130,15 +130,15 @@ end
@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state)

function update(trace::SwitchTrace{A, T, U},
args::Tuple,
args::Tuple,
argdiffs::Tuple,
choices::ChoiceMap) where {A, T, U}
gen_fn = trace.gen_fn
index, index_argdiff = args[1], argdiffs[1]
state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace)
process!(gen_fn, index, index_argdiff,
process!(gen_fn, index, index_argdiff,
args[2 : end], argdiffs[2 : end], choices, state)
return SwitchTrace(gen_fn, state.trace,
get_retval(state.trace), args,
return SwitchTrace(gen_fn, state.trace,
get_retval(state.trace), args,
state.score, state.noise), state.weight, state.updated_retdiff, state.discard
end

0 comments on commit bcf3504

Please sign in to comment.