diff --git a/astar_utils/nested_mapping.py b/astar_utils/nested_mapping.py index e4e8ce7..5e348bf 100644 --- a/astar_utils/nested_mapping.py +++ b/astar_utils/nested_mapping.py @@ -227,24 +227,30 @@ def _repr_pretty_(self, printer, cycle): printer.text(str(self)) -class RecursiveNestedMapping(NestedMapping): - """Like NestedMapping but internally resolves any bang-string values. - - In the event of an infinite loop of recursive bang-string keys pointing - back to each other, this should savely and quickly throw a - ``RecursionError``. - """ +class RecursiveMapping: + """Mixin class just to factor out resolving string key functionality.""" def __getitem__(self, key: str): """x.__getitem__(y) <==> x[y].""" - value = super().__getitem__(key) - while is_bangkey(value): + value = super().__getitem__(key.removesuffix("!")) + + if is_bangkey(value) and is_resolving_key(key): try: - value = self[value] + value = self[f"{value}!"] except KeyError: - return value + pass # return value unresolved + return value + +class RecursiveNestedMapping(RecursiveMapping, NestedMapping): + """Like NestedMapping but internally resolves any bang-string values. + + In the event of an infinite loop of recursive bang-string keys pointing + back to each other, this should savely and quickly throw a + ``RecursionError``. + """ + @classmethod def from_maps(cls, maps, key): """Yield instances from maps if key is found.""" @@ -255,7 +261,7 @@ def from_maps(cls, maps, key): mapping[key], title=f"[{i}] mapping") -class NestedChainMap(ChainMap): +class NestedChainMap(RecursiveMapping, ChainMap): """Subclass of ``collections.ChainMap`` using ``RecursiveNestedMapping``. Only overrides ``__getitem__`` to allow for both recursive bang-string keys @@ -278,16 +284,14 @@ def __getitem__(self, key): """x.__getitem__(y) <==> x[y].""" value = super().__getitem__(key) - if isinstance(value, abc.Mapping): - submaps = tuple(RecursiveNestedMapping.from_maps(self.maps, key)) - if len(submaps) == 1: - # Don't need the chain if it's just one... - return submaps[0] - return NestedChainMap(*submaps) + if not isinstance(value, abc.Mapping): + return value - if is_bangkey(value): - value = self[value] - return value + submaps = tuple(RecursiveNestedMapping.from_maps(self.maps, key)) + if len(submaps) == 1: + # Don't need the chain if it's just one... + return submaps[0] + return NestedChainMap(*submaps) def __str__(self): """Return str(self).""" @@ -306,6 +310,11 @@ def is_bangkey(key) -> bool: return isinstance(key, str) and key.startswith("!") +def is_resolving_key(key) -> bool: + """Return ``True`` if the key is a ``str`` and ends with a "!".""" + return isinstance(key, str) and key.endswith("!") + + def is_nested_mapping(mapping) -> bool: """Return ``True`` if `mapping` contains any further map as a value.""" if not isinstance(mapping, abc.Mapping): diff --git a/tests/test_nested_mapping.py b/tests/test_nested_mapping.py index fbd10b8..711c2d2 100644 --- a/tests/test_nested_mapping.py +++ b/tests/test_nested_mapping.py @@ -230,7 +230,15 @@ def test_repr_pretty(self, nested_nestmap): class TestRecursiveNestedMapping: - def test_resolves_bangs(self): + @pytest.mark.parametrize(("key", "result"), (("bar", "!foo"), + ("bar!", "a"))) + def test_resolves_bangs(self, key, result): + rnm = RecursiveNestedMapping({"foo": "a", "bar": "!foo"}) + assert rnm[key] == result + + @pytest.mark.parametrize(("key", "result"), (("!foo.b", "!bar.y"), + ("!foo.b!", 42))) + def test_resolves_bangs_multistage(self, key, result): rnm = RecursiveNestedMapping( {"foo": { "a": "!bar.x", @@ -241,7 +249,7 @@ def test_resolves_bangs(self): "y": "!foo.a", }, }) - assert rnm["!foo.b"] == 42 + assert rnm[key] == result def test_infinite_loop(self): rnm = RecursiveNestedMapping( @@ -255,7 +263,7 @@ def test_infinite_loop(self): }, }) with pytest.raises(RecursionError): - rnm["!foo.b"] + rnm["!foo.b!"] def test_returns_unresolved_as_is(self): rnm = RecursiveNestedMapping( @@ -264,12 +272,21 @@ def test_returns_unresolved_as_is(self): "b": "!bar.y", }, }) - assert rnm["!foo.b"] == "!bar.y" + assert rnm["!foo.b!"] == "!bar.y" class TestNestedChainMap: - def test_resolves_bangs(self, simple_nestchainmap): - assert simple_nestchainmap["!foo.a"] == "bogus" + @pytest.mark.parametrize(("key", "result"), (("!foo.a", "!foo.b"), + ("!foo.a!", "bogus"))) + def test_resolves_bangs(self, simple_nestchainmap, key, result): + assert simple_nestchainmap[key] == result + + def test_returns_unresolved_as_is(self): + ncm = NestedChainMap( + RecursiveNestedMapping({"foo": {"a": "!foo.b"}}), + RecursiveNestedMapping({"foo": {"b": "!foo.c"}}) + ) + assert ncm["!foo.a!"] == "!foo.c" def test_infinite_loop(self): ncm = NestedChainMap( @@ -277,7 +294,7 @@ def test_infinite_loop(self): RecursiveNestedMapping({"foo": {"b": "!foo.a"}}) ) with pytest.raises(RecursionError): - ncm["!foo.a"] + ncm["!foo.a!"] def test_repr_pretty(self, simple_nestchainmap): printer = Mock()