diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index f6285e1c5..4f2d61946 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -49,9 +49,10 @@ class EdgeSchema(_GeoBaseSchema): edge_type: Series[str] = pa.Field(default="flow") geometry: GeoSeries[LineString] = pa.Field(default=None, nullable=True) - @classmethod - def _index_name(self) -> str: - return "edge_id" + @pa.dataframe_parser + def _name_index(cls, df): + df.index.name = "edge_id" + return df class EdgeTable(SpatialTableModel[EdgeSchema]): diff --git a/python/ribasim/ribasim/geometry/node.py b/python/ribasim/ribasim/geometry/node.py index 691dc4d19..b92354b63 100644 --- a/python/ribasim/ribasim/geometry/node.py +++ b/python/ribasim/ribasim/geometry/node.py @@ -27,9 +27,10 @@ class NodeSchema(_GeoBaseSchema): ) geometry: GeoSeries[Point] = pa.Field(default=None, nullable=True) - @classmethod - def _index_name(self) -> str: - return "node_id" + @pa.dataframe_parser + def _name_index(cls, df): + df.index.name = "node_id" + return df class NodeTable(SpatialTableModel[NodeSchema]): diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index c6613c908..48e0e03fd 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -217,7 +217,6 @@ def _check_dataframe(cls, value: Any) -> Any: # Enable initialization with a DataFrame. if isinstance(value, pd.DataFrame | gpd.GeoDataFrame): - value.index.rename("fid", inplace=True) value = {"df": value} return value @@ -386,7 +385,6 @@ def _from_db(cls, path: Path, table: str): # tell pyarrow to map to pd.ArrowDtype rather than NumPy arrow_to_pandas_kwargs={"types_mapper": pd.ArrowDtype}, ) - df.index.rename(cls.tableschema()._index_name(), inplace=True) else: df = None diff --git a/python/ribasim/ribasim/schemas.py b/python/ribasim/ribasim/schemas.py index 2bcd22222..8de01c774 100644 --- a/python/ribasim/ribasim/schemas.py +++ b/python/ribasim/ribasim/schemas.py @@ -16,9 +16,12 @@ class Config: add_missing_columns = True coerce = True - @classmethod - def _index_name(self) -> str: - return "fid" + @pa.dataframe_parser + def _name_index(cls, df): + # Node and Edge have different index names, avoid running both parsers + if cls.__name__ not in ("NodeSchema", "EdgeSchema"): + df.index.name = "fid" + return df @classmethod def migrate(cls, df: Any, schema_version: int) -> Any: diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index 0320dca87..7c2dcf7b7 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -97,6 +97,11 @@ def test_extra_columns(): def test_index_tables(): p = pump.Static(flow_rate=[1.2]) assert p.df.index.name == "fid" + # Index name is applied by _name_index + df = p.df.reset_index(drop=True) + assert df.index.name is None + p.df = df + assert p.df.index.name == "fid" def test_extra_spatial_columns(): diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index f5e341561..9336b5275 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -112,6 +112,12 @@ def test_write_adds_fid_in_tables(basic, tmp_path): assert model_orig.edge.df.index.name == "edge_id" assert model_orig.edge.df.index.equals(pd.RangeIndex(1, nrow + 1)) + # Index name is applied by _name_index + df = model_orig.edge.df.copy() + df.index.name = "other" + model_orig.edge.df = df + assert model_orig.edge.df.index.name == "edge_id" + model_orig.write(tmp_path / "basic/ribasim.toml") with connect(tmp_path / "basic/database.gpkg") as connection: query = f"select * from {esc_id('Basin / profile')}" diff --git a/utils/templates/schemas.py.jinja b/utils/templates/schemas.py.jinja index 19b53c7ce..6e79f6953 100644 --- a/utils/templates/schemas.py.jinja +++ b/utils/templates/schemas.py.jinja @@ -15,9 +15,12 @@ class _BaseSchema(pa.DataFrameModel): add_missing_columns = True coerce = True - @classmethod - def _index_name(self) -> str: - return "fid" + @pa.dataframe_parser + def _name_index(cls, df): + # Node and Edge have different index names, avoid running both parsers + if cls.__name__ not in ("NodeSchema", "EdgeSchema"): + df.index.name = "fid" + return df @classmethod def migrate(cls, df: Any, schema_version: int) -> Any: