Skip to content

Commit

Permalink
Remove whitespace in schemas.py
Browse files Browse the repository at this point in the history
  • Loading branch information
visr committed Dec 13, 2024
1 parent 5c08e2f commit dfa1c72
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 14 deletions.
7 changes: 4 additions & 3 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ class EdgeSchema(_GeoBaseSchema):
edge_type: Series[str] = pa.Field(default="flow")
geometry: GeoSeries[LineString] = pa.Field(default=None, nullable=True)

@classmethod
def _index_name(self) -> str:
return "edge_id"
@pa.dataframe_parser
def _name_index(cls, df):
df.index.name = "edge_id"
return df


class EdgeTable(SpatialTableModel[EdgeSchema]):
Expand Down
7 changes: 4 additions & 3 deletions python/ribasim/ribasim/geometry/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ class NodeSchema(_GeoBaseSchema):
)
geometry: GeoSeries[Point] = pa.Field(default=None, nullable=True)

@classmethod
def _index_name(self) -> str:
return "node_id"
@pa.dataframe_parser
def _name_index(cls, df):
df.index.name = "node_id"
return df


class NodeTable(SpatialTableModel[NodeSchema]):
Expand Down
2 changes: 0 additions & 2 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def _check_dataframe(cls, value: Any) -> Any:

# Enable initialization with a DataFrame.
if isinstance(value, pd.DataFrame | gpd.GeoDataFrame):
value.index.rename("fid", inplace=True)
value = {"df": value}

return value
Expand Down Expand Up @@ -386,7 +385,6 @@ def _from_db(cls, path: Path, table: str):
# tell pyarrow to map to pd.ArrowDtype rather than NumPy
arrow_to_pandas_kwargs={"types_mapper": pd.ArrowDtype},
)
df.index.rename(cls.tableschema()._index_name(), inplace=True)
else:
df = None

Expand Down
9 changes: 6 additions & 3 deletions python/ribasim/ribasim/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ class Config:
add_missing_columns = True
coerce = True

@classmethod
def _index_name(self) -> str:
return "fid"
@pa.dataframe_parser
def _name_index(cls, df):
# Node and Edge have different index names, avoid running both parsers
if cls.__name__ not in ("NodeSchema", "EdgeSchema"):
df.index.name = "fid"
return df

@classmethod
def migrate(cls, df: Any, schema_version: int) -> Any:
Expand Down
5 changes: 5 additions & 0 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def test_extra_columns():
def test_index_tables():
p = pump.Static(flow_rate=[1.2])
assert p.df.index.name == "fid"
# Index name is applied by _name_index
df = p.df.reset_index(drop=True)
assert df.index.name is None
p.df = df
assert p.df.index.name == "fid"


def test_extra_spatial_columns():
Expand Down
6 changes: 6 additions & 0 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def test_write_adds_fid_in_tables(basic, tmp_path):
assert model_orig.edge.df.index.name == "edge_id"
assert model_orig.edge.df.index.equals(pd.RangeIndex(1, nrow + 1))

# Index name is applied by _name_index
df = model_orig.edge.df.copy()
df.index.name = "other"
model_orig.edge.df = df
assert model_orig.edge.df.index.name == "edge_id"

model_orig.write(tmp_path / "basic/ribasim.toml")
with connect(tmp_path / "basic/database.gpkg") as connection:
query = f"select * from {esc_id('Basin / profile')}"
Expand Down
9 changes: 6 additions & 3 deletions utils/templates/schemas.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ class _BaseSchema(pa.DataFrameModel):
add_missing_columns = True
coerce = True

@classmethod
def _index_name(self) -> str:
return "fid"
@pa.dataframe_parser
def _name_index(cls, df):
# Node and Edge have different index names, avoid running both parsers
if cls.__name__ not in ("NodeSchema", "EdgeSchema"):
df.index.name = "fid"
return df

@classmethod
def migrate(cls, df: Any, schema_version: int) -> Any:
Expand Down

0 comments on commit dfa1c72

Please sign in to comment.