Skip to content

Commit

Permalink
Merge pull request #86 from DHI/feature/chainage_indexing
Browse files Browse the repository at this point in the history
Feature/chainage indexing
  • Loading branch information
ryan-kipawa authored Sep 19, 2024
2 parents 55501b1 + 8cdafc5 commit 0fa979c
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 7 deletions.
9 changes: 9 additions & 0 deletions mikeio1d/result_network/result_gridpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def __init__(self, reach, gridpoint, data_items, result_reach, res1d):
self.structure_data_items = []
self.element_indices = []

self.set_static_attributes()

def set_static_attributes(self):
"""Set static attributes. These show up in the html repr."""
self.set_static_attribute("reach_name", self.reach.Name)
self.set_static_attribute("chainage", self.gridpoint.Chainage)
self.set_static_attribute("xcoord", self.gridpoint.X)
self.set_static_attribute("ycoord", self.gridpoint.Y)

def get_m1d_dataset(self, m1d_dataitem=None):
"""Get IRes1DDataSet object associated with ResultGridPoint.
Expand Down
21 changes: 18 additions & 3 deletions mikeio1d/result_network/result_reach.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
from typing import Dict

if TYPE_CHECKING:
from ..geometry import ReachGeometry
from typing import List

from .result_location import ResultLocation
from .result_gridpoint import ResultGridPoint
Expand All @@ -13,7 +15,7 @@
from DHI.Mike1D.ResultDataAccess import Res1DGridPoint


class ResultReach(ResultLocation):
class ResultReach(ResultLocation, Dict[str, ResultGridPoint]):
"""
Class for wrapping a list of ResultData reaches
having the same reach name.
Expand Down Expand Up @@ -66,8 +68,10 @@ def __getattr__(self, name: str):
else:
object.__getattribute__(self, name)

def __getitem__(self, index):
return self.reaches[index]
def __getitem__(self, key: str | int) -> ResultGridPoint:
if isinstance(key, int):
return self.gridpoints[key]
return super().__getitem__(key)

def _get_total_length(self):
total_length = 0
Expand All @@ -78,6 +82,14 @@ def _get_total_length(self):
def _get_total_gridpoints(self):
return sum([len(gp_list) for gp_list in self.result_gridpoints])

@property
def chainages(self) -> List[str]:
return list(self.keys())

@property
def gridpoints(self) -> List[ResultGridPoint]:
return list(self.values())

def set_static_attributes(self):
"""Set static attributes. These show up in the html repr."""
self.set_static_attribute("name", self.reaches[0].Name)
Expand Down Expand Up @@ -185,6 +197,9 @@ def set_gridpoint(self, reach, gridpoint):
)
setattr(self, result_gridpoint_attribute_string, result_gridpoint)

chainage_str = f"{gridpoint.Chainage:.3f}"
self[chainage_str] = result_gridpoint

def set_gridpoint_data_items(self, reach):
"""
Assign data items to ResultGridPoint object belonging to current ResultReach
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"sphinx-copybutton",
"myst-parser",
],
"test": ["pytest", "matplotlib", "pyarrow", "nbformat", "nbconvert", "ipython"],
"test": ["pytest", "matplotlib", "pyarrow", "nbformat", "nbconvert", "ipykernel"],
"all": ["matplotlib", "geopandas"],
},
Expand Down
37 changes: 37 additions & 0 deletions tests/test_network_autocompletion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
The purpose of this test is to verify that the autocompletion of a network is working as expected.
"""

import pytest
from IPython.terminal.interactiveshell import TerminalInteractiveShell

from typing import List


@pytest.fixture(scope="module")
def shell():
shell = TerminalInteractiveShell()
shell.run_cell(
"""
from mikeio1d import Res1D
from tests import testdata
res = Res1D(testdata.network_river_res1d)
"""
)
return shell


def complete(shell, prompt) -> List[str]:
prompt, completions = shell.complete(prompt)
completions = [c[len(prompt) :] for c in completions]
return completions


def test_reach_names(shell):
completions = complete(shell, "res.reaches['")
assert "river" in completions


def test_reach_chainages(shell):
completions = complete(shell, "res.reaches['river']['")
assert "53100.000" in completions
4 changes: 0 additions & 4 deletions tests/test_res1d_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,3 @@ def test_reaches_dict_access_maintains_backwards_compatibility(res1d_network, re
reach = res1d_network.result_network.reaches["100l1"]
assert reach.Name == "100l1"
assert reach.Length == pytest.approx(47.6827148432828)
# Or it could include multiple reaches
reach = res1d_river_network.result_network.reaches["river"]
assert iter(reach), "Should be iterable where there is several subreaches"
assert reach[0].Id == "river-12"

0 comments on commit 0fa979c

Please sign in to comment.