Add a settable optimize_hook with access to the current scenario.
odow committed May 21, 2019
1 parent 33e7024 commit efff934
Showing 3 changed files with 41 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ docs/spaghetti_plot.html
 docs/src/examples/*.md
 test/*.js
 *.DS_Store
+.vscode/*
52 changes: 35 additions & 17 deletions src/algorithm.jl
@@ -248,14 +248,20 @@ end
 function solve_subproblem(model::PolicyGraph{T},
                           node::Node{T},
                           state::Dict{Symbol, Float64},
-                          noise;
+                          noise,
+                          scenario_path::Vector{Tuple{T, <:Any}};
                           require_duals::Bool) where {T}
     # Parameterize the model. First, fix the value of the incoming state
     # variables. Then parameterize the model depending on `noise`. Finally,
     # set the objective.
     set_incoming_state(node, state)
     parameterize(node, noise)
-    JuMP.optimize!(node.subproblem)
+    if node.optimize_hook !== nothing
+        node.optimize_hook(
+            model, node, state, noise, scenario_path, require_duals)
+    else
+        JuMP.optimize!(node.subproblem)
+    end
     # Test for primal feasibility.
     if JuMP.primal_status(node.subproblem) != JuMP.MOI.FEASIBLE_POINT
         write_subproblem_to_file(node, "subproblem", throw_error = true)
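
Note: a hook must accept the six arguments forwarded at the call site above, and because the default solve is skipped it is responsible for calling JuMP.optimize! itself. A minimal sketch (the logging body is illustrative, not part of this commit):

    # Illustrative only: a hook that logs the scenario path before solving.
    # The six-argument signature matches the call site in solve_subproblem;
    # everything else here is an assumption for demonstration.
    function my_optimize_hook(model, node, state, noise, scenario_path, require_duals)
        @info "Solving node $(node.index) at depth $(length(scenario_path))"
        JuMP.optimize!(node.subproblem)  # the hook must perform the solve itself
    end

    node.optimize_hook = my_optimize_hook  # assumes a `node::SDDP.Node` is in hand
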
@@ -333,7 +339,7 @@ function forward_pass(model::PolicyGraph{T}, options::Options) where {T}
         model[scenario_path[1][1]])
     objective_states = NTuple{N, Float64}[]
     # Iterate down the scenario.
-    for (node_index, noise) in scenario_path
+    for (depth, (node_index, noise)) in enumerate(scenario_path)
         node = model[node_index]
         # Objective state interpolation.
         objective_state_vector = update_objective_state(
@@ -374,7 +380,8 @@ function forward_pass(model::PolicyGraph{T}, options::Options) where {T}
         # Solve the subproblem, note that `require_duals = false`.
         TimerOutputs.@timeit SDDP_TIMER "solve_subproblem" begin
             subproblem_results = solve_subproblem(
-                model, node, incoming_state_value, noise, require_duals = false)
+                model, node, incoming_state_value, noise,
+                scenario_path[1:depth], require_duals = false)
         end
         # Cumulate the stage_objective.
         cumulative_value += subproblem_results.stage_objective
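
Note: the enumerate/slice idiom above threads the partial path into each solve; scenario_path[1:depth] is a fresh copy of the first depth elements. A tiny standalone illustration with invented values:

    # The node/noise pairs here are placeholders, not from the commit.
    path = [(1, 0.1), (2, 0.2), (3, 0.3)]
    for (depth, (node_index, noise)) in enumerate(path)
        @show depth, node_index, noise, path[1:depth]
    end
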
@@ -456,7 +463,7 @@ function backward_pass(
             belief == 0.0 && continue
             solve_all_children(
                 model, model[node_index], items, belief, belief_state,
-                objective_state, outgoing_state)
+                objective_state, outgoing_state, scenario_path[1:index])
         end
         # We need to refine our estimate at all nodes in the partition.
         for node_index in model.belief_partition[partition_index]
@@ -480,7 +487,7 @@
         end
         solve_all_children(
             model, node, items, 1.0, belief_state, objective_state,
-            outgoing_state)
+            outgoing_state, scenario_path[1:index])
         refine_bellman_function(
             model, node, node.bellman_function,
             options.risk_measures[node_index], outgoing_state,
@@ -526,10 +533,16 @@ end
 function solve_all_children(
     model::PolicyGraph{T}, node::Node{T}, items::BackwardPassItems,
     belief::Float64, belief_state, objective_state,
-    outgoing_state::Dict{Symbol, Float64}) where {T}
+    outgoing_state::Dict{Symbol, Float64}, scenario_path) where {T}
+    length_scenario_path = length(scenario_path)
     for child in node.children
         child_node = model[child.term]
         for noise in child_node.noise_terms
+            if length(scenario_path) == length_scenario_path
+                push!(scenario_path, (child.term, noise.term))
+            else
+                scenario_path[end] = (child.term, noise.term)
+            end
             if haskey(items.cached_solutions, (child.term, noise.term))
                 sol_index = items.cached_solutions[(child.term, noise.term)]
                 push!(items.duals, items.duals[sol_index])
@@ -553,7 +566,7 @@ function solve_all_children(
             TimerOutputs.@timeit SDDP_TIMER "solve_subproblem" begin
                 subproblem_results = solve_subproblem(
                     model, child_node, outgoing_state, noise.term,
-                    require_duals = true)
+                    scenario_path, require_duals = true)
             end
             push!(items.duals, subproblem_results.duals)
             push!(items.supports, noise)
@@ -565,6 +578,12 @@
             end
         end
     end
+    if length(scenario_path) == length_scenario_path
+        # No-op. There weren't any children to solve.
+    else
+        # Drop the last element (i.e., the one we added).
+        pop!(scenario_path)
+    end
 end
 
 """
@@ -601,20 +620,18 @@ function calculate_bound(model::PolicyGraph{T},
                     belief.belief, current_belief, partition_index, noise.term)
             end
             subproblem_results = solve_subproblem(
-                model, node, root_state, noise.term, require_duals = false)
+                model, node, root_state, noise.term,
+                Tuple{T, Any}[(child.term, noise.term)], require_duals = false)
             push!(objectives, subproblem_results.objective)
             push!(probabilities, child.probability * noise.probability)
             push!(noise_supports, noise.term)
         end
     end
     # Now compute the risk-adjusted probability measure:
     risk_adjusted_probability = similar(probabilities)
-    adjust_probability(risk_measure,
-                       risk_adjusted_probability,
-                       probabilities,
-                       noise_supports,
-                       objectives,
-                       model.objective_sense == MOI.MIN_SENSE)
+    adjust_probability(
+        risk_measure, risk_adjusted_probability, probabilities, noise_supports,
+        objectives, model.objective_sense == MOI.MIN_SENSE)
     # Finally, calculate the risk-adjusted value.
     return sum(obj * prob for (obj, prob) in
                zip(objectives, risk_adjusted_probability))
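
Note: adjust_probability writes a reweighted copy of probabilities into risk_adjusted_probability, and the bound is the reweighted inner product with the objectives. Under a plain expectation measure the weights pass through unchanged, so a toy sketch (values invented for illustration) is:

    # Toy illustration of the risk-adjusted bound; the numbers are invented.
    objectives    = [1.0, 3.0]   # child objective values
    probabilities = [0.5, 0.5]   # child probability × noise probability
    # With an expectation-style measure, the adjusted weights are unchanged:
    risk_adjusted_probability = copy(probabilities)
    bound = sum(obj * prob for (obj, prob) in
                zip(objectives, risk_adjusted_probability))   # == 2.0
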
@@ -884,7 +901,7 @@ function _simulate(model::PolicyGraph{T},
     objective_state_vector, N = initialize_objective_state(
         model[scenario_path[1][1]])
     objective_states = NTuple{N, Float64}[]
-    for (node_index, noise) in scenario_path
+    for (depth, (node_index, noise)) in enumerate(scenario_path)
         node = model[node_index]
         # Objective state interpolation.
         objective_state_vector = update_objective_state(node.objective_state,
@@ -902,7 +919,8 @@
         end
         # Solve the subproblem.
         subproblem_results = solve_subproblem(
-            model, node, incoming_state, noise, require_duals = false)
+            model, node, incoming_state, noise,
+            scenario_path[1:depth], require_duals = false)
         # Add the stage-objective
         cumulative_value += subproblem_results.stage_objective
         # Record useful variables from the solve.
5 changes: 5 additions & 0 deletions src/user_interface.jl
@@ -322,6 +322,8 @@ mutable struct Node{T}
     objective_state::Union{Nothing, ObjectiveState}
     # For dynamic interpolation of belief states.
     belief_state::Union{Nothing, BeliefState{T}}
+    # An over-loadable hook for the JuMP.optimize! function.
+    optimize_hook::Union{Nothing, Function}
     # An extension dictionary. This is a useful place for packages that extend
     # SDDP.jl to stash things.
     ext::Dict{Symbol, Any}
@@ -498,6 +500,9 @@ function PolicyGraph(builder::Function, graph::Graph{T};
             nothing,
             # And for belief states.
             nothing,
+            # The optimize hook defaults to nothing.
+            nothing,
+            # The extension dictionary.
             Dict{Symbol, Any}()
         )
         subproblem.ext[:sddp_policy_graph] = policy_graph
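
Note: putting the pieces together, a package extending SDDP.jl could install a hook on every node after building a graph, replacing the default `nothing` set by the constructor above. A hedged sketch using the v0-era API current at the time of this commit (the two-stage model and the GLPK solver choice are assumptions for illustration):

    using SDDP, GLPK, JuMP

    # A trivial two-stage model, purely for demonstration.
    model = SDDP.LinearPolicyGraph(
            stages = 2, sense = :Min, lower_bound = 0.0,
            optimizer = JuMP.with_optimizer(GLPK.Optimizer)) do subproblem, t
        @variable(subproblem, x >= 0, SDDP.State, initial_value = 1.0)
        @stageobjective(subproblem, x.out)
    end

    for (index, node) in model.nodes
        node.optimize_hook =
            (model, node, state, noise, scenario_path, require_duals) -> begin
                JuMP.optimize!(node.subproblem)  # a hook must solve the subproblem
            end
    end
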
