Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add terminate_on_cycle option to Historical sampling scheme #549

Merged
merged 1 commit into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions src/plugins/sampling_schemes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ mutable struct Historical{T,S} <: AbstractSamplingScheme
scenarios::Vector{Noise{Vector{Tuple{T,S}}}}
sequential::Bool
counter::Int
terminate_on_cycle::Bool
end

function Base.show(io::IO, h::Historical)
Expand All @@ -331,7 +332,8 @@ end
"""
Historical(
scenarios::Vector{Vector{Tuple{T,S}}},
probability::Vector{Float64},
probability::Vector{Float64};
terminate_on_cycle::Bool = false,
) where {T,S}

A sampling scheme that samples a scenario from the vector of scenarios
Expand All @@ -352,7 +354,8 @@ Historical(
"""
function Historical(
scenarios::Vector{Vector{Tuple{T,S}}},
probability::Vector{Float64},
probability::Vector{Float64};
terminate_on_cycle::Bool = false,
) where {T,S}
if !(sum(probability) ≈ 1.0)
error(
Expand All @@ -361,11 +364,14 @@ function Historical(
)
end
output = [Noise(s, p) for (s, p) in zip(scenarios, probability)]
return Historical(output, false, 0)
return Historical(output, false, 0, terminate_on_cycle)
end

"""
Historical(scenarios::Vector{Vector{Tuple{T,S}}}) where {T,S}
Historical(
scenarios::Vector{Vector{Tuple{T,S}}};
terminate_on_cycle::Bool = false,
) where {T,S}

A deterministic sampling scheme that iterates through the vector of provided
`scenarios`.
Expand All @@ -380,12 +386,18 @@ Historical([
])
```
"""
function Historical(scenarios::Vector{Vector{Tuple{T,S}}}) where {T,S}
return Historical(Noise.(scenarios, NaN), true, 0)
function Historical(
scenarios::Vector{Vector{Tuple{T,S}}};
terminate_on_cycle::Bool = false,
) where {T,S}
return Historical(Noise.(scenarios, NaN), true, 0, terminate_on_cycle)
end

"""
Historical(scenario::Vector{Tuple{T,S}}) where {T,S}
Historical(
scenario::Vector{Tuple{T,S}};
terminate_on_cycle::Bool = false,
) where {T,S}

A deterministic sampling scheme that always samples `scenario`.

Expand All @@ -395,7 +407,9 @@ A deterministic sampling scheme that always samples `scenario`.
Historical([(1, 0.5), (2, 1.5), (3, 0.75)])
```
"""
Historical(scenario::Vector{Tuple{T,S}}) where {T,S} = Historical([scenario])
function Historical(scenario::Vector{Tuple{T,S}}; kwargs...) where {T,S}
return Historical([scenario]; kwargs...)
end

function sample_scenario(
graph::PolicyGraph{T},
Expand All @@ -404,14 +418,15 @@ function sample_scenario(
# us the full scenario.
kwargs...,
) where {T,NoiseTerm}
ret = sampling_scheme.terminate_on_cycle
if sampling_scheme.sequential
sampling_scheme.counter += 1
if sampling_scheme.counter > length(sampling_scheme.scenarios)
sampling_scheme.counter = 1
end
return sampling_scheme.scenarios[sampling_scheme.counter].term, false
return sampling_scheme.scenarios[sampling_scheme.counter].term, ret
end
return sample_noise(sampling_scheme.scenarios), false
return sample_noise(sampling_scheme.scenarios), ret
end

"""
Expand Down
24 changes: 24 additions & 0 deletions test/plugins/sampling_schemes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,30 @@ function test_Historical_SingleTrajectory()
return
end

function test_Historical_SingleTrajectory_terminate_on_cycle()
model = SDDP.LinearPolicyGraph(
stages = 2,
lower_bound = 0.0,
direct_mode = false,
) do node, stage
@variable(node, 0 <= x <= 1)
SDDP.parameterize(node, stage * [1, 3], [0.5, 0.5]) do ω
return JuMP.set_upper_bound(x, ω)
end
end
scenario, terminated_due_to_cycle = SDDP.sample_scenario(
model,
SDDP.Historical(
[(1, 0.1), (2, 0.2), (1, 0.3)];
terminate_on_cycle = true,
),
)
@test length(scenario) == 3
@test terminated_due_to_cycle
@test scenario == [(1, 0.1), (2, 0.2), (1, 0.3)]
return
end

function test_Historical_multiple()
model = SDDP.LinearPolicyGraph(
stages = 2,
Expand Down