Skip to content

Commit

Permalink
Enable read method again with the new add API (#1243)
Browse files Browse the repository at this point in the history
Re-introduce database context (originally in removed Network node) and
split Node tables per Node. Fixes #1232
  • Loading branch information
evetion authored Mar 13, 2024
1 parent bc3e4fc commit e1f8b81
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 37 deletions.
7 changes: 6 additions & 1 deletion python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import pydantic
from geopandas import GeoDataFrame
from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, model_validator
from shapely.geometry import Point

from ribasim.geometry import BasinAreaSchema, NodeTable
Expand Down Expand Up @@ -108,6 +108,11 @@ class MultiNodeModel(NodeModel):
node: NodeTable = Field(default_factory=NodeTable)
_node_type: str

@model_validator(mode="after")
def filter(self) -> "MultiNodeModel":
    """Drop rows in the node table that belong to other node types.

    Runs after model construction; keeps only the rows of ``self.node``
    whose node_type matches this subclass's name.
    """
    node_type = type(self).__name__
    self.node.filter(node_type)
    return self

def add(self, node: Node, tables: Sequence[TableModel[Any]] | None = None) -> None:
if tables is None:
tables = []
Expand Down
19 changes: 11 additions & 8 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import pandas as pd
import pandera as pa
import shapely
from geopandas import GeoDataFrame
from matplotlib.axes import Axes
from numpy.typing import NDArray
from pandera.typing import DataFrame, Series
from pandera.typing.geopandas import GeoSeries
from pandera.typing import Series
from pandera.typing.geopandas import GeoDataFrame, GeoSeries
from pydantic import model_validator
from shapely.geometry import LineString, MultiLineString, Point

from ribasim.input_base import SpatialTableModel
Expand Down Expand Up @@ -42,9 +42,12 @@ class Config:
class EdgeTable(SpatialTableModel[EdgeSchema]):
"""Defines the connections between nodes."""

def __init__(self, **kwargs):
kwargs.setdefault("df", DataFrame[EdgeSchema]())
super().__init__(**kwargs)
@model_validator(mode="after")
def empty_table(self) -> "EdgeTable":
    """Ensure ``self.df`` is never None after construction.

    When no dataframe was supplied, substitute an empty edge frame with
    its "geometry" column registered as the active geometry.
    """
    if self.df is not None:
        return self
    empty = GeoDataFrame[EdgeSchema]()
    empty.set_geometry("geometry", inplace=True)
    self.df = empty
    return self

def add(
self,
Expand All @@ -60,7 +63,7 @@ def add(
if geometry is None
else [geometry]
)
table_to_append = GeoDataFrame(
table_to_append = GeoDataFrame[EdgeSchema](
data={
"from_node_type": pd.Series([from_node.node_type], dtype=str),
"from_node_id": pd.Series([from_node.node_id], dtype=int),
Expand All @@ -76,7 +79,7 @@ def add(
if self.df is None:
self.df = table_to_append
else:
self.df = pd.concat([self.df, table_to_append])
self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append]))

def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]:
assert self.df is not None
Expand Down
7 changes: 7 additions & 0 deletions python/ribasim/ribasim/geometry/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ class Config:
class NodeTable(SpatialTableModel[NodeSchema]):
"""The Ribasim nodes as Point geometries."""

def filter(self, nodetype: str):
    """Filter the node table based on the node type.

    Removes, in place, every row whose ``node_type`` differs from
    *nodetype*, then renumbers the index from zero. No-op when the
    table has no dataframe.
    """
    if self.df is None:
        return
    keep = self.df["node_type"] == nodetype
    self.df.drop(self.df.index[~keep], inplace=True)
    self.df.reset_index(drop=True, inplace=True)

def plot_allocation_networks(self, ax=None, zorder=None) -> Any:
if ax is None:
_, ax = plt.subplots()
Expand Down
7 changes: 4 additions & 3 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,10 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
with open(filepath, "rb") as f:
config = tomli.load(f)

context_file_loading.get()["directory"] = filepath.parent / config.get(
"input_dir", "."
)
directory = filepath.parent / config.get("input_dir", ".")
context_file_loading.get()["directory"] = directory
context_file_loading.get()["database"] = directory / "database.gpkg"

return config
else:
return {}
Expand Down
28 changes: 5 additions & 23 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import ribasim
import tomli
from numpy.testing import assert_array_equal
from pandas import DataFrame
from pandas.testing import assert_frame_equal
from pydantic import ValidationError
Expand All @@ -17,9 +16,9 @@ def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None:
# We set this on write, needed for GeoPackage.
a.index.name = "fid"
a.index.name = "fid"
else:
a = a.reset_index(drop=True)
b = b.reset_index(drop=True)

a = a.reset_index(drop=True)
b = b.reset_index(drop=True)

# avoid comparing datetime64[ns] with datetime64[ms]
if "time" in a:
Expand All @@ -34,7 +33,6 @@ def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None:
return assert_frame_equal(a, b)


@pytest.mark.xfail(reason="Needs Model read implementation")
def test_basic(basic, tmp_path):
model_orig = basic
toml_path = tmp_path / "basic/ribasim.toml"
Expand All @@ -46,19 +44,10 @@ def test_basic(basic, tmp_path):

assert toml_dict["ribasim_version"] == ribasim.__version__

index_a = model_orig.network.node.df.index.to_numpy(int)
index_b = model_loaded.network.node.df.index.to_numpy(int)
assert_array_equal(index_a, index_b)
__assert_equal(
model_orig.network.node.df, model_loaded.network.node.df, is_network=True
)
__assert_equal(
model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True
)
__assert_equal(model_orig.edge.df, model_loaded.edge.df, is_network=True)
assert model_loaded.basin.time.df is None


@pytest.mark.xfail(reason="Needs Model read implementation")
def test_basic_arrow(basic_arrow, tmp_path):
model_orig = basic_arrow
model_orig.write(tmp_path / "basic_arrow/ribasim.toml")
Expand All @@ -67,18 +56,12 @@ def test_basic_arrow(basic_arrow, tmp_path):
__assert_equal(model_orig.basin.profile.df, model_loaded.basin.profile.df)


@pytest.mark.xfail(reason="Needs Model read implementation")
def test_basic_transient(basic_transient, tmp_path):
model_orig = basic_transient
model_orig.write(tmp_path / "basic_transient/ribasim.toml")
model_loaded = ribasim.Model(filepath=tmp_path / "basic_transient/ribasim.toml")

__assert_equal(
model_orig.network.node.df, model_loaded.network.node.df, is_network=True
)
__assert_equal(
model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True
)
__assert_equal(model_orig.edge.df, model_loaded.edge.df, is_network=True)

time = model_loaded.basin.time
assert model_orig.basin.time.df.time[0] == time.df.time[0]
Expand Down Expand Up @@ -111,7 +94,6 @@ def test_extra_columns(basic_transient):
terminal.Static(meta_id=[-1, -2, -3], extra=[-1, -2, -3])


@pytest.mark.xfail(reason="Needs Model read implementation")
def test_sort(level_setpoint_with_minmax, tmp_path):
model = level_setpoint_with_minmax
table = model.discrete_control.condition
Expand Down
3 changes: 1 addition & 2 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def test_node_ids_unsequential(basic):
model.validate_model_node_field_ids()


@pytest.mark.xfail(reason="Needs Model read implementation")
def test_tabulated_rating_curve_model(tabulated_rating_curve, tmp_path):
model_orig = tabulated_rating_curve
basin_area = tabulated_rating_curve.basin.area.df
Expand All @@ -155,7 +154,7 @@ def test_write_adds_fid_in_tables(basic, tmp_path):
# for node an explicit index was provided
nrow = len(model_orig.basin.node.df)
assert model_orig.basin.node.df.index.name is None
assert model_orig.basin.node.df.index.equals(pd.Index(np.full(nrow, 0)))

# for edge no index was provided, but it still needs to write it to file
nrow = len(model_orig.edge.df)
assert model_orig.edge.df.index.name is None
Expand Down

0 comments on commit e1f8b81

Please sign in to comment.