Skip to content

Commit

Permalink
Make edge_id user configurable (#1737)
Browse files Browse the repository at this point in the history
Fixes #1718

Sets `node_id` and `edge_id` as the primary key fields in the GeoPackage.
Also sets `fid` as the index for all tables other than Node and Edge. The
Node table is now sorted by `node_id` only (it used to be sorted by
`node_type` first, then `node_id`).

Touching this much code exposed some bugs, so this commit also includes a
few further changes:
- Split the contextvars into separate ones for reading and writing
(previously, writing a model tried to read the Arrow file).
- Simplified sorting
- Moved all the DataFrame(dict(**kwargs)) boilerplate into a single
validator on TableModel.

---------

Co-authored-by: Martijn Visser <[email protected]>
  • Loading branch information
evetion and visr authored Aug 20, 2024
1 parent 9bdb262 commit 1f2e303
Show file tree
Hide file tree
Showing 33 changed files with 280 additions and 252 deletions.
6 changes: 3 additions & 3 deletions core/src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function create_graph(db::DB, config::Config)::MetaGraph
db,
"""
SELECT
Edge.fid,
Edge.edge_id,
FromNode.node_id AS from_node_id,
FromNode.node_type AS from_node_type,
ToNode.node_id AS to_node_id,
Expand Down Expand Up @@ -59,7 +59,7 @@ function create_graph(db::DB, config::Config)::MetaGraph

errors = false
for (;
fid,
edge_id,
from_node_type,
from_node_id,
to_node_type,
Expand All @@ -79,7 +79,7 @@ function create_graph(db::DB, config::Config)::MetaGraph
subnetwork_id = 0
end
edge_metadata = EdgeMetadata(;
id = fid,
id = edge_id,
flow_idx = edge_type == EdgeType.flow ? flow_counter + 1 : 0,
type = edge_type,
subnetwork_id_source = subnetwork_id,
Expand Down
6 changes: 3 additions & 3 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,14 @@ end
function valid_edge_types(db::DB)::Bool
edge_rows = execute(
db,
"SELECT fid, from_node_id, to_node_id, edge_type FROM Edge ORDER BY fid",
"SELECT edge_id, from_node_id, to_node_id, edge_type FROM Edge ORDER BY edge_id",
)
errors = false

for (; fid, from_node_id, to_node_id, edge_type) in edge_rows
for (; edge_id, from_node_id, to_node_id, edge_type) in edge_rows
if edge_type ∉ ["flow", "control"]
errors = true
@error "Invalid edge type '$edge_type' for edge #$fid from node #$from_node_id to node #$to_node_id."
@error "Invalid edge type '$edge_type' for edge #$edge_id from node #$from_node_id to node #$to_node_id."
end
end
return !errors
Expand Down
2 changes: 1 addition & 1 deletion core/test/run_models_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@

@testset "Results values" begin
@test flow.time[1] == DateTime(2020)
@test coalesce.(flow.edge_id[1:2], -1) == [0, 1]
@test coalesce.(flow.edge_id[1:2], -1) == [100, 101]
@test flow.from_node_id[1:2] == [6, 0]
@test flow.to_node_id[1:2] == [0, 2147483647]

Expand Down
4 changes: 2 additions & 2 deletions core/test/validation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ end
@test length(logger.logs) == 2
@test logger.logs[1].level == Error
@test logger.logs[1].message ==
"Invalid edge type 'foo' for edge #0 from node #1 to node #2."
"Invalid edge type 'foo' for edge #1 from node #1 to node #2."
@test logger.logs[2].level == Error
@test logger.logs[2].message ==
"Invalid edge type 'bar' for edge #1 from node #2 to node #3."
"Invalid edge type 'bar' for edge #2 from node #2 to node #3."
end

@testitem "Subgrid validation" begin
Expand Down
29 changes: 18 additions & 11 deletions python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(

def into_geodataframe(self, node_type: str, node_id: int) -> GeoDataFrame:
extra = self.model_extra if self.model_extra is not None else {}
return GeoDataFrame(
gdf = GeoDataFrame(
data={
"node_id": pd.Series([node_id], dtype=np.int32),
"node_type": pd.Series([node_type], dtype=str),
Expand All @@ -192,6 +192,8 @@ def into_geodataframe(self, node_type: str, node_id: int) -> GeoDataFrame:
},
geometry=[self.geometry],
)
gdf.set_index("node_id", inplace=True)
return gdf


class MultiNodeModel(NodeModel):
Expand Down Expand Up @@ -229,8 +231,8 @@ def add(
)

if node_id is None:
node_id = self._parent.used_node_ids.new_id()
elif node_id in self._parent.used_node_ids:
node_id = self._parent._used_node_ids.new_id()
elif node_id in self._parent._used_node_ids:
raise ValueError(
f"Node IDs have to be unique, but {node_id} already exists."
)
Expand All @@ -243,17 +245,22 @@ def add(
)
assert table.df is not None
table_to_append = table.df.assign(node_id=node_id)
setattr(self, member_name, pd.concat([existing_table, table_to_append]))
setattr(
self,
member_name,
pd.concat([existing_table, table_to_append], ignore_index=True),
)

node_table = node.into_geodataframe(
node_type=self.__class__.__name__, node_id=node_id
)
self.node.df = (
node_table
if self.node.df is None
else pd.concat([self.node.df, node_table])
)
self._parent.used_node_ids.add(node_id)
if self.node.df is None:
self.node.df = node_table
else:
df = pd.concat([self.node.df, node_table])
self.node.df = df

self._parent._used_node_ids.add(node_id)
return self[node_id]

def __getitem__(self, index: int) -> NodeData:
Expand All @@ -265,7 +272,7 @@ def __getitem__(self, index: int) -> NodeData:
f"{node_model_name} index must be an integer, not {indextype}"
)

row = self.node[index].iloc[0]
row = self.node.df.loc[index]
return NodeData(
node_id=int(index), node_type=row["node_type"], geometry=row["geometry"]
)
Expand Down
4 changes: 2 additions & 2 deletions python/ribasim/ribasim/delwaq/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def _setup_graph(nodes, edge, use_evaporation=True):
for row in nodes.df.itertuples():
if row.node_type not in ribasim.geometry.edge.SPATIALCONTROLNODETYPES:
G.add_node(
row.node_id,
row.Index,
type=row.node_type,
id=row.node_id,
id=row.Index,
x=row.geometry.x,
y=row.geometry.y,
pos=(row.geometry.x, row.geometry.y),
Expand Down
3 changes: 2 additions & 1 deletion python/ribasim/ribasim/geometry/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import pandera as pa
from pandera.dtypes import Int32
from pandera.typing import Series
from pandera.typing import Index, Series
from pandera.typing.geopandas import GeoSeries

from ribasim.schemas import _BaseSchema


class BasinAreaSchema(_BaseSchema):
fid: Index[Int32] = pa.Field(default=0, check_name=True)
node_id: Series[Int32] = pa.Field(nullable=False, default=0)
geometry: GeoSeries[Any] = pa.Field(default=None, nullable=True)
59 changes: 32 additions & 27 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, NamedTuple
from typing import Any, NamedTuple, Optional

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -8,11 +8,14 @@
from matplotlib.axes import Axes
from numpy.typing import NDArray
from pandera.dtypes import Int32
from pandera.typing import Series
from pandera.typing import Index, Series
from pandera.typing.geopandas import GeoDataFrame, GeoSeries
from pydantic import NonNegativeInt, PrivateAttr
from shapely.geometry import LineString, MultiLineString, Point

from ribasim.input_base import SpatialTableModel
from ribasim.schemas import _BaseSchema
from ribasim.utils import UsedIDs

__all__ = ("EdgeTable",)

Expand All @@ -31,30 +34,33 @@ class NodeData(NamedTuple):
geometry: Point


class EdgeSchema(pa.DataFrameModel):
class EdgeSchema(_BaseSchema):
edge_id: Index[Int32] = pa.Field(default=0, ge=0, check_name=True)
name: Series[str] = pa.Field(default="")
from_node_id: Series[Int32] = pa.Field(default=0, coerce=True)
to_node_id: Series[Int32] = pa.Field(default=0, coerce=True)
edge_type: Series[str] = pa.Field(default="flow", coerce=True)
subnetwork_id: Series[pd.Int32Dtype] = pa.Field(
default=pd.NA, nullable=True, coerce=True
)
from_node_id: Series[Int32] = pa.Field(default=0)
to_node_id: Series[Int32] = pa.Field(default=0)
edge_type: Series[str] = pa.Field(default="flow")
subnetwork_id: Series[pd.Int32Dtype] = pa.Field(default=pd.NA, nullable=True)
geometry: GeoSeries[Any] = pa.Field(default=None, nullable=True)

class Config:
add_missing_columns = True
@classmethod
def _index_name(self) -> str:
return "edge_id"


class EdgeTable(SpatialTableModel[EdgeSchema]):
"""Defines the connections between nodes."""

_used_edge_ids: UsedIDs = PrivateAttr(default_factory=UsedIDs)

def add(
self,
from_node: NodeData,
to_node: NodeData,
geometry: LineString | MultiLineString | None = None,
name: str = "",
subnetwork_id: int | None = None,
edge_id: Optional[NonNegativeInt] = None,
**kwargs,
):
"""Add an edge between nodes. The type of the edge (flow or control)
Expand Down Expand Up @@ -84,39 +90,38 @@ def add(
"control" if from_node.node_type in SPATIALCONTROLNODETYPES else "flow"
)
assert self.df is not None
if edge_id is None:
edge_id = self._used_edge_ids.new_id()
elif edge_id in self._used_edge_ids:
raise ValueError(
f"Edge IDs have to be unique, but {edge_id} already exists."
)

table_to_append = GeoDataFrame[EdgeSchema](
data={
"from_node_id": pd.Series([from_node.node_id], dtype=np.int32),
"to_node_id": pd.Series([to_node.node_id], dtype=np.int32),
"edge_type": pd.Series([edge_type], dtype=str),
"name": pd.Series([name], dtype=str),
"subnetwork_id": pd.Series([subnetwork_id], dtype=pd.Int32Dtype()),
"from_node_id": [from_node.node_id],
"to_node_id": [to_node.node_id],
"edge_type": [edge_type],
"name": [name],
"subnetwork_id": [subnetwork_id],
**kwargs,
},
geometry=geometry_to_append,
crs=self.df.crs,
index=pd.Index([edge_id], name="edge_id"),
)

self.df = GeoDataFrame[EdgeSchema](
pd.concat([self.df, table_to_append], ignore_index=True)
)
self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append]))
if self.df.duplicated(subset=["from_node_id", "to_node_id"]).any():
raise ValueError(
f"Edges have to be unique, but edge ({from_node.node_id}, {to_node.node_id}) already exists."
f"Edges have to be unique, but edge with from_node_id {from_node.node_id} to_node_id {to_node.node_id} already exists."
)
self.df.index.name = "fid"
self._used_edge_ids.add(edge_id)

def _get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]:
assert self.df is not None
return (self.df.edge_type == edge_type).to_numpy()

def sort(self):
# Only sort the index (fid / edge_id) since this needs to be sorted in a GeoPackage.
# Under most circumstances, this retains the input order,
# making the edge_id as stable as possible; useful for post-processing.
self.df.sort_index(inplace=True)

def plot(self, **kwargs) -> Axes:
"""Plot the edges of the model.
Expand Down
23 changes: 8 additions & 15 deletions python/ribasim/ribasim/geometry/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,27 @@
import pandera as pa
from matplotlib.patches import Patch
from pandera.dtypes import Int32
from pandera.typing import Series
from pandera.typing import Index, Series
from pandera.typing.geopandas import GeoSeries

from ribasim.input_base import SpatialTableModel
from ribasim.schemas import _BaseSchema

__all__ = ("NodeTable",)


class NodeSchema(pa.DataFrameModel):
node_id: Series[Int32] = pa.Field(ge=0)
class NodeSchema(_BaseSchema):
node_id: Index[Int32] = pa.Field(default=0, check_name=True)
name: Series[str] = pa.Field(default="")
node_type: Series[str] = pa.Field(default="")
subnetwork_id: Series[pd.Int32Dtype] = pa.Field(
default=pd.NA, nullable=True, coerce=True
)
geometry: GeoSeries[Any] = pa.Field(default=None, nullable=True)

class Config:
add_missing_columns = True
coerce = True
@classmethod
def _index_name(self) -> str:
return "node_id"


class NodeTable(SpatialTableModel[NodeSchema]):
Expand All @@ -37,12 +38,6 @@ def filter(self, nodetype: str):
if self.df is not None:
mask = self.df[self.df["node_type"] != nodetype].index
self.df.drop(mask, inplace=True)
self.df.reset_index(inplace=True, drop=True)

def sort(self):
assert self.df is not None
sort_keys = ["node_type", "node_id"]
self.df.sort_values(sort_keys, ignore_index=True, inplace=True)

def plot_allocation_networks(self, ax=None, zorder=None) -> Any:
if ax is None:
Expand Down Expand Up @@ -156,9 +151,7 @@ def plot(self, ax=None, zorder=None) -> Any:

assert self.df is not None
geometry = self.df["geometry"]
for text, xy in zip(
self.df["node_id"], np.column_stack((geometry.x, geometry.y))
):
for text, xy in zip(self.df.index, np.column_stack((geometry.x, geometry.y))):
ax.annotate(text=text, xy=xy, xytext=(2.0, 2.0), textcoords="offset points")

return ax
Loading

0 comments on commit 1f2e303

Please sign in to comment.