diff --git a/src/django_simple_nav/nav.py b/src/django_simple_nav/nav.py index 905efa7..73be0fd 100644 --- a/src/django_simple_nav/nav.py +++ b/src/django_simple_nav/nav.py @@ -18,20 +18,27 @@ class Nav: template_name: str = field(init=False) items: list[NavGroup | NavItem] = field(init=False) - @classmethod - def render_from_request( - cls, request: HttpRequest, template_name: str | None = None - ) -> str: + def get_context_data(self, request: HttpRequest) -> dict[str, Any]: items = [ RenderedNavItem(item, request) - for item in cls.items + for item in self.items if check_item_permissions(item, request.user) # type: ignore[arg-type] ] + return {"items": items} + + def render(self, request: HttpRequest, template_name: str | None = None) -> str: + context = self.get_context_data(request) return render_to_string( - template_name=template_name or cls.template_name, - context={"items": items}, + template_name=template_name or self.template_name, + context=context, ) + @classmethod + def render_from_request( + cls, request: HttpRequest, template_name: str | None = None + ) -> str: + return cls().render(request, template_name) + @dataclass(frozen=True) class NavGroup: diff --git a/tests/test_nav.py b/tests/test_nav.py index ea97d5c..e5e864b 100644 --- a/tests/test_nav.py +++ b/tests/test_nav.py @@ -185,3 +185,43 @@ def test_extra_context_builtins(req): assert rendered_group_item.permissions == ["is_staff"] assert rendered_group_item.extra_context == {"foo": "bar"} assert rendered_group_item.foo == "bar" + + +def test_get_context_data(req): + req.user = baker.make(get_user_model()) + + context = DummyNav().get_context_data(req) + + assert context["items"] + + +def test_get_context_data_override(req): + class OverrideNav(DummyNav): + def get_context_data(self, request): + return {"foo": "bar"} + + req.user = baker.make(get_user_model()) + + context = OverrideNav().get_context_data(req) + + assert context["foo"] == "bar" + + +def test_render(req): + req.user = baker.make(get_user_model()) + + rendered_template = DummyNav().render(req) + + assert count_anchors(rendered_template) == 10 + + +def test_render_override(req): + class OverrideNav(DummyNav): + def get_context_data(self, request): + return {"foo": "bar"} + + req.user = baker.make(get_user_model()) + + rendered_template = OverrideNav().render(req) + + assert count_anchors(rendered_template) == 0