diff --git a/pyhtml/__tag_base.py b/pyhtml/__tag_base.py index 456a2bb..6ba97ab 100644 --- a/pyhtml/__tag_base.py +++ b/pyhtml/__tag_base.py @@ -18,7 +18,7 @@ def __init__(self, *children: Any, **properties: Any) -> None: """ Create a new tag instance """ - self.children = list(children) + self.children = util.flatten_list(list(children)) """Children of this tag""" self.properties = util.filter_properties(properties) @@ -34,7 +34,7 @@ def __call__( properties are based on this original tag, but with additional children appended and additional properties unioned. """ - new_children = self.children + list(children) + new_children = self.children + util.flatten_list(list(children)) new_properties = self.properties | properties return self.__class__(*new_children, **new_properties) diff --git a/pyhtml/__util.py b/pyhtml/__util.py index 73fce0a..0a346f1 100644 --- a/pyhtml/__util.py +++ b/pyhtml/__util.py @@ -3,7 +3,10 @@ Random helpful functions used elsewhere """ -from typing import Any +from typing import Any, TypeVar + + +T = TypeVar('T') def increase_indent(text: list[str], amount: int) -> list[str]: @@ -111,3 +114,18 @@ def render_children(children: list[Any], sep: str = ' ') -> list[str]: for ele in children: rendered.extend(render_inline_element(ele)) return increase_indent(rendered, 2) + + +def flatten_list(the_list: list[T | list[T]]) -> list[T]: + """ + Flatten a list by taking any list elements and inserting their items + individually. Note that other iterables (such as str and tuple) are not + flattened. + """ + result: list[T] = [] + for item in the_list: + if isinstance(item, list): + result.extend(item) + else: + result.append(item) + return result diff --git a/tests/basic_rendering_test.py b/tests/basic_rendering_test.py index 91bfd16..4c90a40 100644 --- a/tests/basic_rendering_test.py +++ b/tests/basic_rendering_test.py @@ -122,3 +122,22 @@ def test_format_through_repr(): doc = html() assert repr(doc) == "" + + +def test_flatten_element_lists(): + """ + If a list of elements is given as a child element, each element should be + considered as a child. + """ + doc = html([p("Hello"), p("world")]) + + assert repr(doc) == "\n".join([ + "", + "

", + " Hello", + "

", + "

", + " world", + "

", + "", + ])