Skip to content

Commit

Permalink
Add tests for CrossSectionCollection
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-kipawa committed Mar 6, 2024
1 parent 032907f commit 9c993c7
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 1 deletion.
23 changes: 22 additions & 1 deletion mikeio1d/cross_sections/cross_section_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,30 @@

class CrossSectionCollection(Dict[Tuple[LocationId, Chainage, TopoId], CrossSection]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if args and isinstance(args[0], list):
self._handle_args(*args)
else:
super().__init__(*args, **kwargs)

self.xns11: Xns11 | None = None

self._validate()

def _handle_args(self, *args):
if not isinstance(args[0][0], CrossSection):
raise ValueError("Input must be a list of CrossSection objects")
for xs in args[0]:
self[xs.location_id, f"{xs.chainage:.3f}", xs.topo_id] = xs

def _validate(self):
for key, cross_section in self.items():
if key[0] != cross_section.location_id:
raise ValueError(f"Location ID mismatch: {key[0]} != {cross_section.location_id}")
if key[1] != f"{cross_section.chainage:.3f}":
raise ValueError(f"Chainage mismatch: {key[1]} != {cross_section.chainage:.3f}")
if key[2] != cross_section.topo_id:
raise ValueError(f"Topo ID mismatch: {key[2]} != {cross_section.topo_id}")

def __repr__(self) -> str:
return f"<CrossSectionCollection {len(self)}>"

Expand Down
159 changes: 159 additions & 0 deletions tests/test_cross_section_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import pytest

from typing import List

from IPython.terminal.interactiveshell import TerminalInteractiveShell

from mikeio1d.cross_sections import CrossSection
from mikeio1d.cross_sections import CrossSectionCollection

from .test_cross_section import create_xz_data


@pytest.fixture()
def shell():
shell = TerminalInteractiveShell()
shell.run_cell(
"""
from mikeio1d.cross_sections import CrossSectionCollection
"""
)
return shell


def complete(shell, prompt) -> List[str]:
prompt, completions = shell.complete(prompt)
completions = [c[len(prompt) :] for c in completions]
return completions


def create_dummy_cross_section(location_id, chainage, topo_id):
x, z = create_xz_data()
return CrossSection.from_xz(x, z, location_id=location_id, chainage=chainage, topo_id=topo_id)


@pytest.fixture
def dummy_cross_section():
return create_dummy_cross_section("loc1", 100, "topo1")


@pytest.fixture
def many_dummy_cross_sections():
xs = []
for i in range(0, 100, 10):
xs.append(create_dummy_cross_section(f"loc{i}", i, "topo"))
for i in range(0, 100, 10):
xs.append(create_dummy_cross_section(f"loc{i}", i, "topo2"))
return xs


class TestCrossSectionCollectionUnits:
"""
Unit tests for the CrossSectionCollection class.
"""

def test_create_empty_collection(self):
c = CrossSectionCollection()
assert len(c) == 0

def test_create_collection_from_list(self, many_dummy_cross_sections):
csc = CrossSectionCollection(many_dummy_cross_sections)
assert len(csc) == 20

def test_create_collection_from_dict(self, dummy_cross_section):
csc = CrossSectionCollection(
{
("loc1", "100.000", "topo1"): dummy_cross_section,
}
)
assert len(csc) == 1

with pytest.raises(ValueError):
csc = CrossSectionCollection(
{
("loc1", "100.000", "topo1"): dummy_cross_section,
("not_matching_xs", "100.000", "topo1"): dummy_cross_section,
}
)

def test_get_item(self, many_dummy_cross_sections):
csc = CrossSectionCollection(many_dummy_cross_sections)
assert csc["loc0", "0.000", "topo"] == many_dummy_cross_sections[0]
assert csc["loc90", "90.000", "topo2"] == many_dummy_cross_sections[-1]

@pytest.mark.parametrize("slice_char", [..., slice(None)])
def test_get_item_slice(self, many_dummy_cross_sections, slice_char):
csc = CrossSectionCollection(many_dummy_cross_sections)

sliced = csc["loc0", slice_char, slice_char]
assert len(csc["loc0", slice_char, slice_char]) == 2
for xs in sliced.values():
assert xs.location_id == "loc0"

sliced = csc[slice_char, slice_char, "topo2"]
assert len(sliced) == 10
for xs in sliced.values():
assert xs.topo_id == "topo2"

sliced = csc[slice_char, "50.000", slice_char]
assert len(sliced) == 2
for xs in sliced.values():
assert xs.chainage == 50

sliced = csc["loc0"]
assert len(sliced) == 2
for xs in sliced.values():
assert xs.location_id == "loc0"

sliced = csc["loc50", "50.000"]
assert len(sliced) == 2
for xs in sliced.values():
assert xs.location_id == "loc50"
assert xs.chainage == 50

@pytest.mark.parametrize(
"prompt,expected_completions",
[
(
"csc['",
[
"loc0",
"loc10",
"loc20",
"loc30",
"loc40",
"loc50",
"loc60",
"loc70",
"loc80",
"loc90",
],
),
(
"csc['loc0', '",
[
"0.000",
"10.000",
"20.000",
"30.000",
"40.000",
"50.000",
"60.000",
"70.000",
"80.000",
"90.000",
],
),
(
"csc['loc0', '0.000', '",
["topo", "topo2"],
),
],
)
def test_index_autocompletion(
self, many_dummy_cross_sections, shell, prompt, expected_completions
):
cross_sections = many_dummy_cross_sections
shell.push({"csc": CrossSectionCollection(cross_sections)})
completions = complete(shell, prompt)
assert completions == expected_completions

0 comments on commit 9c993c7

Please sign in to comment.