Use map_optimize in particle filter #533

Open
lwang19-ai opened this issue Jun 3, 2024 · 2 comments

Comments

@lwang19-ai

Is it possible to use map_optimize within the particle filter process?
I tried something like the following:

state_atm = Gen.initialize_particle_filter(unfold_atm, (start_ass_time,), init_obs_atm, num_particles)
selection = Gen.select(:chain => 1 => :u => 4 => :u)
new_trace = Gen.map_optimize(state_atm.traces[1], selection)  # returns a new trace; state_atm itself is not modified

This raises the following error:
MethodError: no method matching zero(::Type{Any})
The error is raised by the choice_gradients call inside map_optimize. I am unsure whether it is caused by my implementation or whether this use is not supported.
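One plausible reading of this error, offered as an assumption rather than a diagnosis confirmed in this thread: somewhere on the gradient path, Gen calls zero on a type that Julia only knows as Any, for example the element type of an untyped container or the return type of a generative function that has no return-type annotation. The plain-Julia sketch below does not use Gen at all; it only reproduces the same MethodError and shows that converting to a concrete element type avoids it. The helper zero_error is hypothetical and exists only for illustration.

```julia
# Hypothetical helper for illustration: returns the exception raised by
# zero(T), or `nothing` if zero(T) succeeds.
function zero_error(T::Type)
    try
        zero(T)
        return nothing
    catch e
        return e
    end
end

typed   = [1.0, 2.0]       # Vector{Float64}: concrete element type
untyped = Any[1.0, 2.0]    # Vector{Any}: also what `[]` followed by push! produces

zero_error(eltype(typed))    # nothing, since zero(Float64) == 0.0
zero_error(eltype(untyped))  # MethodError: no method matching zero(::Type{Any})

# Converting to a concrete element type removes the problem.
fixed = Float64.(untyped)    # broadcasting narrows the result to Vector{Float64}
zero_error(eltype(fixed))    # nothing
```

If this is the cause, the fix would be on the model side (concrete element types for states and return values) rather than in how map_optimize is called.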

@georgematheos
Contributor

Hi @lwang19-ai, sorry you're running into this issue! Can you please share the stacktrace?

@lwang19-ai
Author

lwang19-ai commented Jun 3, 2024

@georgematheos Sure.
(stacktrace screenshot, not preserved)

# Generate proposal samples by running the particle filter (PF) separately,
# then regenerate the proposal samples starting from the first time step.

using Gen, CSV, DataFrames
include("../forward_models/gen_ks_atm_unfold.jl")

# Read from CSV files
df_atm_obs_read = CSV.read("../synthetic_truth/atm_obs.csv", DataFrame)
df_atm_ref_read = CSV.read("../synthetic_truth/atm_ref.csv", DataFrame)

# Convert the observations back into a matrix:
atm_obs = Matrix(df_atm_obs_read)


Nx = 32
Nt = 20
num_particles = 100
init_obs_atm = Gen.choicemap()
start_ass_time = 4
for i in 4:4:Nx
    init_obs_atm[(:chain=>start_ass_time=>:u=>i=>:atm_obs)] = atm_obs[start_ass_time+1, i] # Initial observation of atmosphere
end

state_atm = Gen.initialize_particle_filter(unfold_atm, (start_ass_time,), init_obs_atm, num_particles)


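# map_optimize returns a new trace; assigning it to state_atm would overwrite the particle filter state
new_trace_atm = Gen.map_optimize(state_atm.traces[1], Gen.select(:chain => 1 => :u => 1 => :u))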
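If the aim is to apply map_optimize to every particle rather than only the first trace, a minimal sketch might look like the following. Assumptions: map_rejuvenate! is a hypothetical helper, not part of Gen; it reads and writes the traces field of the particle filter state directly; and toy_model is a stand-in for unfold_atm. Because map_optimize is a deterministic hill-climbing move rather than an MCMC kernel, the particle weights are left unchanged, so the result is best treated as a heuristic rather than a properly weighted posterior sample.

```julia
using Gen

# Hypothetical helper (not part of Gen): replace each particle's trace with the
# result of repeated calls to map_optimize on the selected continuous choices.
# It mutates state.traces directly and does not touch the particle weights.
function map_rejuvenate!(state, selection; n_steps::Int = 1)
    for i in eachindex(state.traces)
        tr = state.traces[i]
        for _ in 1:n_steps
            # map_optimize returns a new trace; it never modifies `tr` in place
            tr = map_optimize(tr, selection)
        end
        state.traces[i] = tr
    end
    return state
end

# Stand-in model for illustration: one latent Gaussian observed with noise.
@gen function toy_model(x0::Float64)
    x ~ normal(x0, 1.0)
    y ~ normal(x, 0.5)
    return x
end

state = initialize_particle_filter(toy_model, (0.0,), choicemap((:y, 2.0)), 10)
old_scores = [get_score(tr) for tr in state.traces]
map_rejuvenate!(state, select(:x); n_steps = 5)
```

Each call to map_optimize either returns a trace whose score is at least as high or hands back the original trace, so looping is safe. Gen's documentation also requires the selected choices to have support on the whole real line, which is worth checking for the :u choices in unfold_atm.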
