Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Threaded parallel scheme #758

Merged
merged 13 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ jobs:
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
JULIA_NUM_THREADS: 4
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
Expand Down
1 change: 1 addition & 0 deletions docs/src/apireference.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ SDDP.SimulatorSamplingScheme
```@docs
SDDP.AbstractParallelScheme
SDDP.Serial
SDDP.Threaded
SDDP.Asynchronous
```

Expand Down
89 changes: 63 additions & 26 deletions src/algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

macro _timeit_threadsafe(timer, label, block)
return esc(quote
if Threads.threadid() == 1
TimerOutputs.@timeit $timer $label $block
else
$block
end
end)
end

# to_nodal_form is an internal helper function so users can pass arguments like:
# risk_measure = SDDP.Expectation(),
# risk_measure = Dict(1=>Expectation(), 2=>WorstCase())
Expand Down Expand Up @@ -101,6 +111,8 @@
forward_pass_callback::Any
post_iteration_callback::Any
last_log_iteration::Ref{Int}
# For threading
lock::ReentrantLock
# Internal function: users should never construct this themselves.
function Options(
model::PolicyGraph{T},
Expand Down Expand Up @@ -144,6 +156,7 @@
forward_pass_callback,
post_iteration_callback,
Ref{Int}(0), # last_log_iteration
ReentrantLock(),
)
end
end
Expand Down Expand Up @@ -387,6 +400,7 @@
scenario_path::Vector{Tuple{T,S}};
duality_handler::Union{Nothing,AbstractDualityHandler},
) where {T,S}
lock(node.lock) # LOCK-ID-005
_initialize_solver(node; throw_error = false)
# Parameterize the model. First, fix the value of the incoming state
# variables. Then parameterize the model depending on `noise`. Finally,
Expand Down Expand Up @@ -423,12 +437,13 @@
end
state = get_outgoing_state(node)
stage_objective = stage_objective_value(node.stage_objective)
TimerOutputs.@timeit model.timer_output "get_dual_solution" begin
@_timeit_threadsafe model.timer_output "get_dual_solution" begin

Check warning on line 440 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L440

Added line #L440 was not covered by tests
objective, dual_values = get_dual_solution(node, duality_handler)
end
if node.post_optimize_hook !== nothing
node.post_optimize_hook(pre_optimize_ret)
end
unlock(node.lock) # LOCK-ID-005
return (
state = state,
duals = dual_values,
Expand Down Expand Up @@ -505,10 +520,6 @@
objective_states::Vector{NTuple{N,Float64}},
belief_states::Vector{Tuple{Int,Dict{T,Float64}}},
) where {T,NoiseType,N}
TimerOutputs.@timeit model.timer_output "prepare_backward_pass" begin
restore_duality =
prepare_backward_pass(model, options.duality_handler, options)
end
# TODO(odow): improve storage type.
cuts = Dict{T,Vector{Any}}(index => Any[] for index in keys(model.nodes))
for index in length(scenario_path):-1:1
Expand All @@ -533,6 +544,7 @@
options.backward_sampling_scheme,
scenario_path[1:index],
options.duality_handler,
options,
)
end
# We need to refine our estimate at all nodes in the partition.
Expand Down Expand Up @@ -573,6 +585,7 @@
options.backward_sampling_scheme,
scenario_path[1:index],
options.duality_handler,
options,
)
new_cuts = refine_bellman_function(
model,
Expand Down Expand Up @@ -613,9 +626,6 @@
end
end
end
TimerOutputs.@timeit model.timer_output "prepare_backward_pass" begin
restore_duality()
end
return cuts
end

Expand Down Expand Up @@ -652,13 +662,22 @@
backward_sampling_scheme::AbstractBackwardSamplingScheme,
scenario_path,
duality_handler::Union{Nothing,AbstractDualityHandler},
options,
) where {T}
length_scenario_path = length(scenario_path)
for child in node.children
if isapprox(child.probability, 0.0, atol = 1e-6)
continue
end
child_node = model[child.term]
lock(child_node.lock) # LOCK-ID-004
@_timeit_threadsafe model.timer_output "prepare_backward_pass" begin

Check warning on line 674 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L674

Added line #L674 was not covered by tests
restore_duality = prepare_backward_pass(
child_node,
options.duality_handler,
options,
)
end
for noise in sample_backward_noise_terms_with_state(
backward_sampling_scheme,
child_node,
Expand Down Expand Up @@ -695,7 +714,7 @@
noise.term,
)
end
TimerOutputs.@timeit model.timer_output "solve_subproblem" begin
@_timeit_threadsafe model.timer_output "solve_subproblem" begin

Check warning on line 717 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L717

Added line #L717 was not covered by tests
subproblem_results = solve_subproblem(
model,
child_node,
Expand All @@ -715,6 +734,10 @@
length(items.duals)
end
end
@_timeit_threadsafe model.timer_output "prepare_backward_pass" begin

Check warning on line 737 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L737

Added line #L737 was not covered by tests
restore_duality()
end
unlock(child_node.lock) # LOCK-ID-004
end
if length(scenario_path) == length_scenario_path
# No-op. There weren't any children to solve.
Expand Down Expand Up @@ -752,6 +775,7 @@
continue
end
node = model[child.term]
lock(node.lock) # LOCK-ID-006
for noise in node.noise_terms
if node.objective_state !== nothing
update_objective_state(
Expand Down Expand Up @@ -783,6 +807,7 @@
push!(probabilities, child.probability * noise.probability)
push!(noise_supports, noise.term)
end
unlock(node.lock) # LOCK-ID-006
end
# Now compute the risk-adjusted probability measure:
risk_adjusted_probability = similar(probabilities)
Expand Down Expand Up @@ -812,11 +837,11 @@

function iteration(model::PolicyGraph{T}, options::Options) where {T}
model.ext[:numerical_issue] = false
TimerOutputs.@timeit model.timer_output "forward_pass" begin
@_timeit_threadsafe model.timer_output "forward_pass" begin

Check warning on line 840 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L840

Added line #L840 was not covered by tests
forward_trajectory = forward_pass(model, options, options.forward_pass)
options.forward_pass_callback(forward_trajectory)
end
TimerOutputs.@timeit model.timer_output "backward_pass" begin
@_timeit_threadsafe model.timer_output "backward_pass" begin

Check warning on line 844 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L844

Added line #L844 was not covered by tests
cuts = backward_pass(
model,
options,
Expand All @@ -826,26 +851,31 @@
forward_trajectory.belief_states,
)
end
TimerOutputs.@timeit model.timer_output "calculate_bound" begin
@_timeit_threadsafe model.timer_output "calculate_bound" begin

Check warning on line 854 in src/algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithm.jl#L854

Added line #L854 was not covered by tests
bound = calculate_bound(model)
end
push!(
options.log,
Log(
length(options.log) + 1,
bound,
forward_trajectory.cumulative_value,
time() - options.start_time,
Distributed.myid(),
model.ext[:total_solves],
duality_log_key(options.duality_handler),
model.ext[:numerical_issue],
),
)
lock(options.lock)
try
push!(
options.log,
Log(
length(options.log) + 1,
bound,
forward_trajectory.cumulative_value,
time() - options.start_time,
max(Threads.threadid(), Distributed.myid()),
model.ext[:total_solves],
duality_log_key(options.duality_handler),
model.ext[:numerical_issue],
),
)
finally
unlock(options.lock)
end
has_converged, status =
convergence_test(model, options.log, options.stopping_rules)
return IterationResult(
Distributed.myid(),
max(Threads.threadid(), Distributed.myid()),
bound,
forward_trajectory.cumulative_value,
has_converged,
Expand Down Expand Up @@ -1130,6 +1160,11 @@
finally
# And close the dashboard callback if necessary.
dashboard_callback(nothing, true)
for node in values(model.nodes)
if islocked(node.lock)
unlock(node.lock)
end
end
end
training_results = TrainingResults(status, log)
model.most_recent_training_results = training_results
Expand Down Expand Up @@ -1177,6 +1212,7 @@
objective_states = NTuple{N,Float64}[]
for (depth, (node_index, noise)) in enumerate(scenario_path)
node = model[node_index]
lock(node.lock) # LOCK-ID-002
# Objective state interpolation.
objective_state_vector = update_objective_state(
node.objective_state,
Expand Down Expand Up @@ -1253,6 +1289,7 @@
push!(simulation, store)
# Set outgoing state as the incoming state for the next node.
incoming_state = copy(subproblem_results.state)
unlock(node.lock) # LOCK-ID-002
end
return simulation
end
Expand Down
9 changes: 6 additions & 3 deletions src/plugins/bellman_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ function refine_bellman_function(
nominal_probability::Vector{Float64},
objective_realizations::Vector{Float64},
) where {T}
lock(node.lock) # LOCK-ID-003
# Sanity checks.
@assert length(dual_variables) ==
length(noise_supports) ==
Expand All @@ -426,8 +427,8 @@ function refine_bellman_function(
model.objective_sense == MOI.MIN_SENSE,
)
# The meat of the function.
if bellman_function.cut_type == SINGLE_CUT
return _add_average_cut(
ret = if bellman_function.cut_type == SINGLE_CUT
_add_average_cut(
node,
outgoing_state,
risk_adjusted_probability,
Expand All @@ -438,7 +439,7 @@ function refine_bellman_function(
else # Add a multi-cut
@assert bellman_function.cut_type == MULTI_CUT
_add_locals_if_necessary(node, bellman_function, length(dual_variables))
return _add_multi_cut(
_add_multi_cut(
node,
outgoing_state,
risk_adjusted_probability,
Expand All @@ -447,6 +448,8 @@ function refine_bellman_function(
offset,
)
end
unlock(node.lock) # LOCK-ID-003
return ret
end

function _add_average_cut(
Expand Down
29 changes: 7 additions & 22 deletions src/plugins/duality_handlers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,6 @@ SDDiP(args...; kwargs...) = _deprecate_integrality_handler()

ContinuousRelaxation(args...; kwargs...) = _deprecate_integrality_handler()

function prepare_backward_pass(
model::PolicyGraph,
duality_handler::AbstractDualityHandler,
options::Options,
)
undo = Function[]
for (_, node) in model.nodes
push!(undo, prepare_backward_pass(node, duality_handler, options))
end
function undo_relax()
for f in undo
f()
end
return
end
return undo_relax
end

function get_dual_solution(node::Node, ::Nothing)
return JuMP.objective_value(node.subproblem), Dict{Symbol,Float64}()
end
Expand Down Expand Up @@ -351,8 +333,10 @@ focus more on the more-recent rewards.
mutable struct BanditDuality <: AbstractDualityHandler
arms::Vector{_BanditArm}
last_arm_index::Int
logs_seen::Int

function BanditDuality(args::AbstractDualityHandler...)
return new(_BanditArm[_BanditArm(arg, Float64[]) for arg in args], 1)
return new(_BanditArm[_BanditArm(arg, Float64[]) for arg in args], 1, 1)
end
end

Expand Down Expand Up @@ -404,15 +388,16 @@ function _update_rewards(handler::BanditDuality, log::Vector{Log})
end

function prepare_backward_pass(
model::PolicyGraph,
node::Node,
handler::BanditDuality,
options::Options,
)
if length(options.log) > 1
if length(options.log) > handler.logs_seen
_update_rewards(handler, options.log)
handler.logs_seen = length(options.log)
end
arm = _choose_best_arm(handler)
return prepare_backward_pass(model, arm.handler, options)
return prepare_backward_pass(node, arm.handler, options)
end

function get_dual_solution(node::Node, handler::BanditDuality)
Expand Down
6 changes: 4 additions & 2 deletions src/plugins/forward_passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function forward_pass(
) where {T}
# First up, sample a scenario. Note that if a cycle is detected, this will
# return the cycle node as well.
TimerOutputs.@timeit model.timer_output "sample_scenario" begin
@_timeit_threadsafe model.timer_output "sample_scenario" begin
scenario_path, terminated_due_to_cycle =
sample_scenario(model, options.sampling_scheme)
end
Expand All @@ -51,6 +51,7 @@ function forward_pass(
# Iterate down the scenario.
for (depth, (node_index, noise)) in enumerate(scenario_path)
node = model[node_index]
lock(node.lock) # LOCK-ID-001
# Objective state interpolation.
objective_state_vector = update_objective_state(
node.objective_state,
Expand Down Expand Up @@ -94,7 +95,7 @@ function forward_pass(
end
# ===== End: starting state for infinite horizon =====
# Solve the subproblem, note that `duality_handler = nothing`.
TimerOutputs.@timeit model.timer_output "solve_subproblem" begin
@_timeit_threadsafe model.timer_output "solve_subproblem" begin
subproblem_results = solve_subproblem(
model,
node,
Expand All @@ -112,6 +113,7 @@ function forward_pass(
# Add the outgoing state variable to the list of states we have sampled
# on this forward pass.
push!(sampled_states, incoming_state_value)
unlock(node.lock) # LOCK-ID-001
end
if terminated_due_to_cycle
# We terminated due to a cycle. Here is the list of possible starting
Expand Down
Loading
Loading