diff --git a/tests/test_api.py b/tests/test_api.py index 1acc26f..38f0279 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -280,3 +280,15 @@ def test_deepcopy_timezone(): o2 = copy.deepcopy(o) assert o2["dob"] == o["dob"] assert o2["dob"] is not o["dob"] + + +def test_list_of_inline_tables_with_preserve(): + test_str = """[[test]] +x = [ { name = "test1" }, { name = "test2" },] + +""" + s = toml.dumps(toml.loads(test_str, + decoder=toml.TomlDecoder()), + encoder=toml.TomlEncoder(preserve_list=True)) + + assert len(s) == len(test_str) and sorted(test_str) == sorted(s) diff --git a/toml/encoder.py b/toml/encoder.py index bf17a72..44f9a43 100644 --- a/toml/encoder.py +++ b/toml/encoder.py @@ -128,9 +128,10 @@ def _dump_time(v): class TomlEncoder(object): - def __init__(self, _dict=dict, preserve=False): + def __init__(self, _dict=dict, preserve=False, preserve_list=False): self._dict = _dict self.preserve = preserve + self.preserve_list = preserve_list self.dump_funcs = { str: _dump_str, unicode: _dump_str, @@ -154,7 +155,7 @@ def dump_list(self, v): retval += "]" return retval - def dump_inline_table(self, section): + def dump_inline_table(self, section, with_newline=True): """Preserve inline table in its compact syntax instead of expanding into subsection. @@ -166,14 +167,21 @@ def dump_inline_table(self, section): for k, v in section.items(): val = self.dump_inline_table(v) val_list.append(k + " = " + val) - retval += "{ " + ", ".join(val_list) + " }\n" + retval += "{ " + ", ".join(val_list) + " }" + + if with_newline: + retval += "\n" + return retval else: return unicode(self.dump_value(section)) + def dump_inline_table_value(self, value): + return self.dump_inline_table(value, False) + def dump_value(self, v): # Lookup function corresponding to v's type - dump_fn = self.dump_funcs.get(type(v)) + dump_fn = self.dump_inline_table_value if isinstance(v, InlineTableDict) else self.dump_funcs.get(type(v)) if dump_fn is None and hasattr(v, '__iter__'): dump_fn = self.dump_funcs[list] # Evaluate function (if it exists) else return v @@ -192,11 +200,23 @@ def dump_sections(self, o, sup): qsection = _dump_str(section) if not isinstance(o[section], dict): arrayoftables = False + inlines = [] + if isinstance(o[section], list): for a in o[section]: if isinstance(a, dict): arrayoftables = True if arrayoftables: + for a in o[section]: + if isinstance(a, InlineTableDict): + inlines.append(a) + + if self.preserve_list and len(inlines) == len(o[section]): + retstr += (qsection + " = " + + unicode(self.dump_value(o[section])) + "\n") + + continue + for a in o[section]: arraytabstr = "\n" arraystr += "[[" + sup + qsection + "]]\n" @@ -235,14 +255,14 @@ def dump_sections(self, o, sup): class TomlPreserveInlineDictEncoder(TomlEncoder): - def __init__(self, _dict=dict): - super(TomlPreserveInlineDictEncoder, self).__init__(_dict, True) + def __init__(self, _dict=dict, preserve_list=False): + super(TomlPreserveInlineDictEncoder, self).__init__(_dict, True, preserve_list) class TomlArraySeparatorEncoder(TomlEncoder): - def __init__(self, _dict=dict, preserve=False, separator=","): - super(TomlArraySeparatorEncoder, self).__init__(_dict, preserve) + def __init__(self, _dict=dict, preserve=False, separator=",", preserve_list=False): + super(TomlArraySeparatorEncoder, self).__init__(_dict, preserve, preserve_list) if separator.strip() == "": separator = "," + separator elif separator.strip(' \t\n\r,'): @@ -269,9 +289,9 @@ def dump_list(self, v): class TomlNumpyEncoder(TomlEncoder): - def __init__(self, _dict=dict, preserve=False): + def __init__(self, _dict=dict, preserve=False, preserve_list=False): import numpy as np - super(TomlNumpyEncoder, self).__init__(_dict, preserve) + super(TomlNumpyEncoder, self).__init__(_dict, preserve, preserve_list) self.dump_funcs[np.float16] = _dump_float self.dump_funcs[np.float32] = _dump_float self.dump_funcs[np.float64] = _dump_float @@ -285,9 +305,9 @@ def _dump_int(self, v): class TomlPreserveCommentEncoder(TomlEncoder): - def __init__(self, _dict=dict, preserve=False): + def __init__(self, _dict=dict, preserve=False, preserve_list=False): from toml.decoder import CommentValue - super(TomlPreserveCommentEncoder, self).__init__(_dict, preserve) + super(TomlPreserveCommentEncoder, self).__init__(_dict, preserve, preserve_list) self.dump_funcs[CommentValue] = lambda v: v.dump(self.dump_value) diff --git a/toml/encoder.pyi b/toml/encoder.pyi index 194a358..f8b12be 100644 --- a/toml/encoder.pyi +++ b/toml/encoder.pyi @@ -8,27 +8,29 @@ def dumps(o: Mapping[str, Any], encoder: TomlEncoder = ...) -> str: ... class TomlEncoder: preserve: Any = ... + preserve_list: Any = ... dump_funcs: Any = ... - def __init__(self, _dict: Any = ..., preserve: bool = ...): ... + def __init__(self, _dict: Any = ..., preserve: bool = ..., preserve_list: bool = ...): ... def get_empty_table(self): ... def dump_list(self, v: Any): ... - def dump_inline_table(self, section: Any): ... + def dump_inline_table(self, section: Any, with_newline: bool = ...): ... + def dump_inline_table_value(self, section: Any): ... def dump_value(self, v: Any): ... def dump_sections(self, o: Any, sup: Any): ... class TomlPreserveInlineDictEncoder(TomlEncoder): - def __init__(self, _dict: Any = ...) -> None: ... + def __init__(self, _dict: Any = ..., preserve_list: bool = ...) -> None: ... class TomlArraySeparatorEncoder(TomlEncoder): separator: Any = ... - def __init__(self, _dict: Any = ..., preserve: bool = ..., separator: str = ...) -> None: ... + def __init__(self, _dict: Any = ..., preserve: bool = ..., separator: str = ..., preserve_list: bool = ...) -> None: ... def dump_list(self, v: Any): ... class TomlNumpyEncoder(TomlEncoder): - def __init__(self, _dict: Any = ..., preserve: bool = ...) -> None: ... + def __init__(self, _dict: Any = ..., preserve: bool = ..., preserve_list: bool = ...) -> None: ... class TomlPreserveCommentEncoder(TomlEncoder): - def __init__(self, _dict: Any = ..., preserve: bool = ...): ... + def __init__(self, _dict: Any = ..., preserve: bool = ..., preserve_list: bool = ...): ... class TomlPathlibEncoder(TomlEncoder): def dump_value(self, v: Any): ...