Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Jul 25, 2024
1 parent 9772f51 commit b96fae6
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions test/algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,10 @@ function test_refine_at_similar_nodes()
iteration_limit = 1,
refine_at_similar_nodes = false,
print_level = 0,
parallel_scheme = SDDP.Serial(),
)
@test SDDP.calculate_bound(model) 5.7 || SDDP.calculate_bound(model) 6.3
mi1 = length(model[(1, 1)].bellman_function.global_theta.cuts)
mi2 = length(model[(1, 2)].bellman_function.global_theta.cuts)
@test mi1 + mi2 == 1
@test mi1 + mi2 == length(model.most_recent_training_results.log)

model = SDDP.MarkovianPolicyGraph(
transition_matrices = [[0.5 0.5], [0.2 0.8; 0.8 0.2]],
Expand All @@ -268,11 +266,10 @@ function test_refine_at_similar_nodes()
iteration_limit = 1,
refine_at_similar_nodes = true,
print_level = 0,
parallel_scheme = SDDP.Serial(),
)
@test SDDP.calculate_bound(model) 9.5
@test length(model[(1, 1)].bellman_function.global_theta.cuts) == 1
@test length(model[(1, 2)].bellman_function.global_theta.cuts) == 1
@test length(model[(1, 1)].bellman_function.global_theta.cuts) ==
length(model[(1, 2)].bellman_function.global_theta.cuts) ==
length(model.most_recent_training_results.log)
return
end

Expand Down Expand Up @@ -324,19 +321,15 @@ function test_write_log_to_csv()
end
end
@test_throws ErrorException SDDP.write_log_to_csv(model, "sddp.csv")
SDDP.train(
model;
iteration_limit = 2,
print_level = 0,
parallel_scheme = SDDP.Serial(),
)
SDDP.train(model; iteration_limit = 2, print_level = 0)
SDDP.write_log_to_csv(model, "sddp.csv")
log = read("sddp.csv", String)
saved_log = """
iteration, simulation, bound, time
1, 3.0, 3.0, 2.993860960006714
2, 3.0, 3.0, 2.994189739227295
"""
for i in 1:length(model.most_recent_training_results.log)
saved_log *= "$i, 3.0, 3.0, 3.0\n"
end
@test replace(log, r"[0-9\.]+\n" => "") ==
replace(saved_log, r"[0-9\.]+\n" => "")
rm("sddp.csv")
Expand Down

0 comments on commit b96fae6

Please sign in to comment.