Skip to content

Commit

Permalink
Autogenerate the validation tables for Python (#1797)
Browse files Browse the repository at this point in the history
Fixes #1779
  • Loading branch information
Jingru923 authored Sep 6, 2024
1 parent d319af4 commit 4e48967
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 53 deletions.
11 changes: 10 additions & 1 deletion core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using OrdinaryDiffEqRosenbrock: Rodas5, Rosenbrock23

export Config, Solver, Results, Logging, Toml
export algorithm,
snake_case, input_path, results_path, convert_saveat, convert_dt, nodetypes
camel_case, snake_case, input_path, results_path, convert_saveat, convert_dt, nodetypes

const schemas =
getfield.(
Expand All @@ -49,6 +49,15 @@ end

snake_case(sym::Symbol)::Symbol = Symbol(snake_case(String(sym)))

"Convert a string from snake_case to CamelCase."
function camel_case(snake_case::AbstractString)::String
camel_case = replace(snake_case, r"_([a-z])" => s -> uppercase(s[2]))
camel_case = uppercase(first(camel_case)) * camel_case[2:end]
return camel_case
end

camel_case(sym::Symbol)::Symbol = Symbol(camel_case(String(sym)))

"""
Add fieldnames with Union{String, Nothing} type to struct expression. Requires @option use before it.
"""
Expand Down
2 changes: 1 addition & 1 deletion core/src/validation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Allowed types for downstream (to_node_id) nodes given the type of the upstream (from_node_id) node
neighbortypes(nodetype::Symbol) = neighbortypes(Val(nodetype))
neighbortypes(nodetype::Symbol) = neighbortypes(Val(config.snake_case(nodetype)))
neighbortypes(::Val{:pump}) = Set((:basin, :terminal, :level_boundary))
neighbortypes(::Val{:outlet}) = Set((:basin, :terminal, :level_boundary))
neighbortypes(::Val{:user_demand}) = Set((:basin, :terminal, :level_boundary))
Expand Down
12 changes: 12 additions & 0 deletions core/test/config_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,16 @@ end
@test Ribasim.snake_case("CamelCase") == "camel_case"
@test Ribasim.snake_case("ABCdef") == "a_b_cdef"
@test Ribasim.snake_case("snake_case") == "snake_case"
@test Ribasim.snake_case(:CamelCase) == :camel_case
@test Ribasim.snake_case(:ABCdef) == :a_b_cdef
@test Ribasim.snake_case(:snake_case) == :snake_case
end

@testitem "camel_case" begin
@test Ribasim.camel_case("camel_case") == "CamelCase"
@test Ribasim.camel_case("a_b_cdef") == "ABCdef"
@test Ribasim.camel_case("CamelCase") == "CamelCase"
@test Ribasim.camel_case(:camel_case) == :CamelCase
@test Ribasim.camel_case(:a_b_cdef) == :ABCdef
@test Ribasim.camel_case(:CamelCase) == :CamelCase
end
3 changes: 2 additions & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,13 @@ tests = { depends_on = ["lint", "test-ribasim-python", "test-ribasim-core"] }
delwaq = { cmd = "pytest python/ribasim/tests/test_delwaq.py" }
model-integration-test = { cmd = "julia --project=core --eval 'using Pkg; Pkg.test(test_args=[\"integration\"])'" }
# Codegen
codegen = { cmd = "julia --project utils/gen_python.jl && ruff format python/ribasim/ribasim/schemas.py", depends_on = [
codegen = { cmd = "julia --project utils/gen_python.jl && ruff format python/ribasim/ribasim/schemas.py python/ribasim/ribasim/validation.py", depends_on = [
"initialize-julia",
], inputs = [
"core",
], outputs = [
"python/ribasim/ribasim/schemas.py",
"python/ribasim/ribasim/validation.py",
] }
# Publish
add-ribasim-icon = { cmd = "rcedit build/ribasim/ribasim.exe --set-icon docs/assets/ribasim.ico" }
Expand Down
136 changes: 87 additions & 49 deletions python/ribasim/ribasim/validation.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,123 @@
# Table for connectivity between different node types
# Automatically generated file. Do not modify.

# Table for connectivity
# "Basin": ["LinearResistance"] means that the downstream of basin can be LinearResistance only
node_type_connectivity: dict[str, list[str]] = {
"Basin": [
"LinearResistance",
"ManningResistance",
"TabulatedRatingCurve",
"Pump",
"PidControl": [
"Outlet",
"UserDemand",
],
"LinearResistance": ["Basin", "LevelBoundary"],
"ManningResistance": ["Basin"],
"TabulatedRatingCurve": ["Basin", "Terminal", "LevelBoundary"],
"LevelBoundary": ["LinearResistance", "Pump", "Outlet", "TabulatedRatingCurve"],
"FlowBoundary": ["Basin", "Terminal", "LevelBoundary"],
"Pump": ["Basin", "Terminal", "LevelBoundary"],
"Outlet": ["Basin", "Terminal", "LevelBoundary"],
"Terminal": [],
"DiscreteControl": [
"Pump",
],
"LevelBoundary": [
"Outlet",
"TabulatedRatingCurve",
"LinearResistance",
"ManningResistance",
"PidControl",
"Pump",
],
"Pump": [
"LevelBoundary",
"Basin",
"Terminal",
],
"UserDemand": [
"LevelBoundary",
"Basin",
"Terminal",
],
"TabulatedRatingCurve": [
"LevelBoundary",
"Basin",
"Terminal",
],
"ContinuousControl": ["Pump", "Outlet"],
"PidControl": ["Pump", "Outlet"],
"UserDemand": ["Basin", "Terminal", "LevelBoundary"],
"LevelDemand": ["Basin"],
"FlowDemand": [
"LinearResistance",
"Outlet",
"TabulatedRatingCurve",
"ManningResistance",
"Pump",
],
"FlowBoundary": [
"LevelBoundary",
"Basin",
"Terminal",
],
"Basin": [
"LinearResistance",
"UserDemand",
"Outlet",
"TabulatedRatingCurve",
"ManningResistance",
"Pump",
],
"ManningResistance": [
"Basin",
],
"LevelDemand": [
"Basin",
],
"DiscreteControl": [
"LinearResistance",
"PidControl",
"Outlet",
"TabulatedRatingCurve",
"ManningResistance",
"Pump",
],
"Outlet": [
"LevelBoundary",
"Basin",
"Terminal",
],
"ContinuousControl": [
"Outlet",
"Pump",
],
"LinearResistance": [
"LevelBoundary",
"Basin",
],
"Terminal": [],
}


# Function to validate connectivity between two node types
# Function to validate connection
def can_connect(node_type_up: str, node_type_down: str) -> bool:
if node_type_up in node_type_connectivity:
return node_type_down in node_type_connectivity[node_type_up]
return False


flow_edge_neighbor_amount: dict[str, list[int]] = {
# list[int] = [in_min, in_max, out_min, out_max]
"Basin": [0, int(1e9), 0, int(1e9)],
"LinearResistance": [1, 1, 1, 1],
"ManningResistance": [1, 1, 1, 1],
"TabulatedRatingCurve": [1, 1, 1, int(1e9)],
"LevelBoundary": [0, int(1e9), 0, int(1e9)],
"FlowBoundary": [0, 0, 1, int(1e9)],
"Pump": [1, 1, 1, int(1e9)],
"Outlet": [1, 1, 1, 1],
"Terminal": [1, int(1e9), 0, 0],
"DiscreteControl": [0, 0, 0, 0],
"ContinuousControl": [0, 0, 0, 0],
"PidControl": [0, 0, 0, 0],
"LevelBoundary": [0, 9223372036854775807, 0, 9223372036854775807],
"Pump": [1, 1, 1, 1],
"UserDemand": [1, 1, 1, 1],
"LevelDemand": [0, 0, 0, 0],
"TabulatedRatingCurve": [1, 1, 1, 1],
"FlowDemand": [0, 0, 0, 0],
"FlowBoundary": [0, 0, 1, 9223372036854775807],
"Basin": [0, 9223372036854775807, 0, 9223372036854775807],
"ManningResistance": [1, 1, 1, 1],
"LevelDemand": [0, 0, 0, 0],
"DiscreteControl": [0, 0, 0, 0],
"Outlet": [1, 1, 1, 1],
"ContinuousControl": [0, 0, 0, 0],
"LinearResistance": [1, 1, 1, 1],
"Terminal": [1, 9223372036854775807, 0, 0],
}

control_edge_neighbor_amount: dict[str, list[int]] = {
# list[int] = [in_min, in_max, out_min, out_max]
"Basin": [0, 1, 0, 0],
"LinearResistance": [0, 1, 0, 0],
"ManningResistance": [0, 1, 0, 0],
"TabulatedRatingCurve": [0, 1, 0, 0],
"PidControl": [0, 1, 1, 1],
"LevelBoundary": [0, 0, 0, 0],
"FlowBoundary": [0, 0, 0, 0],
"Pump": [0, 1, 0, 0],
"Outlet": [0, 1, 0, 0],
"Terminal": [0, 0, 0, 0],
"DiscreteControl": [0, 0, 1, int(1e9)],
"ContinuousControl": [0, 0, 1, int(1e9)],
"PidControl": [0, 1, 1, 1],
"UserDemand": [0, 0, 0, 0],
"LevelDemand": [0, 0, 1, int(1e9)],
"TabulatedRatingCurve": [0, 1, 0, 0],
"FlowDemand": [0, 0, 1, 1],
"FlowBoundary": [0, 0, 0, 0],
"Basin": [0, 1, 0, 0],
"ManningResistance": [0, 1, 0, 0],
"LevelDemand": [0, 0, 1, 9223372036854775807],
"DiscreteControl": [0, 0, 1, 9223372036854775807],
"Outlet": [0, 1, 0, 0],
"ContinuousControl": [0, 0, 1, 9223372036854775807],
"LinearResistance": [0, 1, 0, 0],
"Terminal": [0, 0, 0, 0],
}
2 changes: 1 addition & 1 deletion python/ribasim/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_connectivity(trivial):
with pytest.raises(
ValueError,
match=re.escape(
"Node of type Terminal cannot be downstream of node of type Basin. Possible downstream node: ['LinearResistance', 'ManningResistance', 'TabulatedRatingCurve', 'Pump', 'Outlet', 'UserDemand']"
"Node of type Terminal cannot be downstream of node of type Basin. Possible downstream node: ['LinearResistance', 'UserDemand', 'Outlet', 'TabulatedRatingCurve', 'ManningResistance', 'Pump']"
),
):
model.edge.add(model.basin[6], model.terminal[2147483647])
Expand Down
27 changes: 27 additions & 0 deletions utils/gen_python.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,30 @@ open(normpath(@__DIR__, "..", "python", "ribasim", "ribasim", "schemas.py"), "w"
init = Dict("models" => get_models())
println(io, model_template(; init = init))
end

function get_connectivity()
"""
Set up a vector contains all possible connecting node for all node types.
"""
[
(
name = T,
connectivity = Set(
Ribasim.config.camel_case(x) for x in Ribasim.neighbortypes(T)
),
flow_neighbor_bound = Ribasim.n_neighbor_bounds_flow(T),
control_neighbor_bound = Ribasim.n_neighbor_bounds_control(T),
) for T in keys(Ribasim.config.nodekinds)
]
end

connection_template = Template(
normpath(@__DIR__, "templates", "validation.py.jinja");
config = Dict("trim_blocks" => true, "lstrip_blocks" => true, "autoescape" => false),
)

# Write validation.py
open(normpath(@__DIR__, "..", "python", "ribasim", "ribasim", "validation.py"), "w") do io
init = Dict("nodes" => get_connectivity())
println(io, connection_template(; init = init))
end
31 changes: 31 additions & 0 deletions utils/templates/validation.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Automatically generated file. Do not modify.

# Table for connectivity
# "Basin": ["LinearResistance"] means that the downstream of basin can be LinearResistance only
node_type_connectivity: dict[str, list[str]] = {
{% for n in nodes %}
'{{n[:name]}}': [{% for value in n[:connectivity] %}
'{{ value }}',
{% end %}],
{% end %}
}

# Function to validate connection
def can_connect(node_type_up: str, node_type_down: str) -> bool:
if node_type_up in node_type_connectivity:
return node_type_down in node_type_connectivity[node_type_up]
return False

flow_edge_neighbor_amount: dict[str, list[int]] = {
{% for n in nodes %}
'{{n[:name]}}':
[{{ n[:flow_neighbor_bound].in_min }}, {{ n[:flow_neighbor_bound].in_max }}, {{ n[:flow_neighbor_bound].out_min }}, {{ n[:flow_neighbor_bound].out_max }}],
{% end %}
}

control_edge_neighbor_amount: dict[str, list[int]] = {
{% for n in nodes %}
'{{n[:name]}}':
[{{ n[:control_neighbor_bound].in_min }}, {{ n[:control_neighbor_bound].in_max }}, {{ n[:control_neighbor_bound].out_min }}, {{ n[:control_neighbor_bound].out_max }}],
{% end %}
}

0 comments on commit 4e48967

Please sign in to comment.