From e33f8dfae3f9aa7ce8048e74d14ad0718f6f8ef1 Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 19 Dec 2022 09:25:46 +1300 Subject: [PATCH] Add terminate_on_cycle option to Historical sampling scheme --- src/plugins/sampling_schemes.jl | 35 +++++++++++++++++++++++--------- test/plugins/sampling_schemes.jl | 24 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/plugins/sampling_schemes.jl b/src/plugins/sampling_schemes.jl index 50948b576..3dcd3b0ea 100644 --- a/src/plugins/sampling_schemes.jl +++ b/src/plugins/sampling_schemes.jl @@ -317,6 +317,7 @@ mutable struct Historical{T,S} <: AbstractSamplingScheme scenarios::Vector{Noise{Vector{Tuple{T,S}}}} sequential::Bool counter::Int + terminate_on_cycle::Bool end function Base.show(io::IO, h::Historical) @@ -331,7 +332,8 @@ end """ Historical( scenarios::Vector{Vector{Tuple{T,S}}}, - probability::Vector{Float64}, + probability::Vector{Float64}; + terminate_on_cycle::Bool = false, ) where {T,S} A sampling scheme that samples a scenario from the vector of scenarios @@ -352,7 +354,8 @@ Historical( """ function Historical( scenarios::Vector{Vector{Tuple{T,S}}}, - probability::Vector{Float64}, + probability::Vector{Float64}; + terminate_on_cycle::Bool = false, ) where {T,S} if !(sum(probability) ≈ 1.0) error( @@ -361,11 +364,14 @@ function Historical( ) end output = [Noise(s, p) for (s, p) in zip(scenarios, probability)] - return Historical(output, false, 0) + return Historical(output, false, 0, terminate_on_cycle) end """ - Historical(scenarios::Vector{Vector{Tuple{T,S}}}) where {T,S} + Historical( + scenarios::Vector{Vector{Tuple{T,S}}}; + terminate_on_cycle::Bool = false, + ) where {T,S} A deterministic sampling scheme that iterates through the vector of provided `scenarios`. @@ -380,12 +386,18 @@ Historical([ ]) ``` """ -function Historical(scenarios::Vector{Vector{Tuple{T,S}}}) where {T,S} - return Historical(Noise.(scenarios, NaN), true, 0) +function Historical( + scenarios::Vector{Vector{Tuple{T,S}}}; + terminate_on_cycle::Bool = false, +) where {T,S} + return Historical(Noise.(scenarios, NaN), true, 0, terminate_on_cycle) end """ - Historical(scenario::Vector{Tuple{T,S}}) where {T,S} + Historical( + scenario::Vector{Tuple{T,S}}; + terminate_on_cycle::Bool = false, + ) where {T,S} A deterministic sampling scheme that always samples `scenario`. @@ -395,7 +407,9 @@ A deterministic sampling scheme that always samples `scenario`. Historical([(1, 0.5), (2, 1.5), (3, 0.75)]) ``` """ -Historical(scenario::Vector{Tuple{T,S}}) where {T,S} = Historical([scenario]) +function Historical(scenario::Vector{Tuple{T,S}}; kwargs...) where {T,S} + return Historical([scenario]; kwargs...) +end function sample_scenario( graph::PolicyGraph{T}, @@ -404,14 +418,15 @@ function sample_scenario( # us the full scenario. kwargs..., ) where {T,NoiseTerm} + ret = sampling_scheme.terminate_on_cycle if sampling_scheme.sequential sampling_scheme.counter += 1 if sampling_scheme.counter > length(sampling_scheme.scenarios) sampling_scheme.counter = 1 end - return sampling_scheme.scenarios[sampling_scheme.counter].term, false + return sampling_scheme.scenarios[sampling_scheme.counter].term, ret end - return sample_noise(sampling_scheme.scenarios), false + return sample_noise(sampling_scheme.scenarios), ret end """ diff --git a/test/plugins/sampling_schemes.jl b/test/plugins/sampling_schemes.jl index 6512ea899..79593e9c4 100644 --- a/test/plugins/sampling_schemes.jl +++ b/test/plugins/sampling_schemes.jl @@ -181,6 +181,30 @@ function test_Historical_SingleTrajectory() return end +function test_Historical_SingleTrajectory_terminate_on_cycle() + model = SDDP.LinearPolicyGraph( + stages = 2, + lower_bound = 0.0, + direct_mode = false, + ) do node, stage + @variable(node, 0 <= x <= 1) + SDDP.parameterize(node, stage * [1, 3], [0.5, 0.5]) do ω + return JuMP.set_upper_bound(x, ω) + end + end + scenario, terminated_due_to_cycle = SDDP.sample_scenario( + model, + SDDP.Historical( + [(1, 0.1), (2, 0.2), (1, 0.3)]; + terminate_on_cycle = true, + ), + ) + @test length(scenario) == 3 + @test terminated_due_to_cycle + @test scenario == [(1, 0.1), (2, 0.2), (1, 0.3)] + return +end + function test_Historical_multiple() model = SDDP.LinearPolicyGraph( stages = 2,