Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/chainage indexing #86

Merged
merged 12 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mikeio1d/result_network/result_gridpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def __init__(self, reach, gridpoint, data_items, result_reach, res1d):
self.structure_data_items = []
self.element_indices = []

self.set_static_attributes()

def set_static_attributes(self):
"""Set static attributes. These show up in the html repr."""
self.set_static_attribute("reach_name", self.reach.Name)
self.set_static_attribute("chainage", self.gridpoint.Chainage)
self.set_static_attribute("xcoord", self.gridpoint.X)
self.set_static_attribute("ycoord", self.gridpoint.Y)

def get_m1d_dataset(self, m1d_dataitem=None):
"""Get IRes1DDataSet object associated with ResultGridPoint.

Expand Down
21 changes: 18 additions & 3 deletions mikeio1d/result_network/result_reach.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
from typing import Dict

if TYPE_CHECKING:
from ..geometry import ReachGeometry
from typing import List

from .result_location import ResultLocation
from .result_gridpoint import ResultGridPoint
Expand All @@ -13,7 +15,7 @@
from DHI.Mike1D.ResultDataAccess import Res1DGridPoint


class ResultReach(ResultLocation):
class ResultReach(ResultLocation, Dict[str, ResultGridPoint]):
"""
Class for wrapping a list of ResultData reaches
having the same reach name.
Expand Down Expand Up @@ -66,8 +68,10 @@ def __getattr__(self, name: str):
else:
object.__getattribute__(self, name)

def __getitem__(self, index):
return self.reaches[index]
def __getitem__(self, key: str | int) -> ResultGridPoint:
if isinstance(key, int):
return self.gridpoints[key]
return super().__getitem__(key)

def _get_total_length(self):
total_length = 0
Expand All @@ -78,6 +82,14 @@ def _get_total_length(self):
def _get_total_gridpoints(self):
return sum([len(gp_list) for gp_list in self.result_gridpoints])

@property
def chainages(self) -> List[str]:
return list(self.keys())

@property
def gridpoints(self) -> List[ResultGridPoint]:
return list(self.values())

def set_static_attributes(self):
"""Set static attributes. These show up in the html repr."""
self.set_static_attribute("name", self.reaches[0].Name)
Expand Down Expand Up @@ -185,6 +197,9 @@ def set_gridpoint(self, reach, gridpoint):
)
setattr(self, result_gridpoint_attribute_string, result_gridpoint)

chainage_str = f"{gridpoint.Chainage:.3f}"
self[chainage_str] = result_gridpoint

def set_gridpoint_data_items(self, reach):
"""
Assign data items to ResultGridPoint object belonging to current ResultReach
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"sphinx-copybutton",
"myst-parser",
],
"test": ["pytest", "matplotlib", "pyarrow", "nbformat", "nbconvert", "ipython"],
"test": ["pytest", "matplotlib", "pyarrow", "nbformat", "nbconvert", "ipykernel"],
"all": ["matplotlib", "geopandas"],
},
Expand Down
37 changes: 37 additions & 0 deletions tests/test_network_autocompletion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
The purpose of this test is to verify that the autocompletion of a network is working as expected.
"""

import pytest
from IPython.terminal.interactiveshell import TerminalInteractiveShell

from typing import List


@pytest.fixture(scope="module")
def shell():
shell = TerminalInteractiveShell()
shell.run_cell(
"""
from mikeio1d import Res1D
from tests import testdata
res = Res1D(testdata.network_river_res1d)
"""
)
return shell


def complete(shell, prompt) -> List[str]:
prompt, completions = shell.complete(prompt)
completions = [c[len(prompt) :] for c in completions]
return completions


def test_reach_names(shell):
completions = complete(shell, "res.reaches['")
assert "river" in completions


def test_reach_chainages(shell):
completions = complete(shell, "res.reaches['river']['")
assert "53100.000" in completions
4 changes: 0 additions & 4 deletions tests/test_res1d_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,3 @@ def test_reaches_dict_access_maintains_backwards_compatibility(res1d_network, re
reach = res1d_network.result_network.reaches["100l1"]
assert reach.Name == "100l1"
assert reach.Length == pytest.approx(47.6827148432828)
# Or it could include multiple reaches
reach = res1d_river_network.result_network.reaches["river"]
assert iter(reach), "Should be iterable where there is several subreaches"
assert reach[0].Id == "river-12"
Loading