Skip to content
This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

SEXP: Add stack and debugging information to decoder #59

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 68 additions & 33 deletions src/faebryk/libs/sexp/dataclass_sexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,43 +61,74 @@ def from_field(cls, f: Field):
class SymEnum(StrEnum): ...


def _convert(val, t):
# Recurse (GenericAlias e.g list[])
if (origin := get_origin(t)) is not None:
args = get_args(t)
if origin is list:
return [_convert(_val, args[0]) for _val in val]
if origin is tuple:
return tuple(_convert(_val, _t) for _val, _t in zip(val, args))
if origin in (Union, UnionType) and len(args) == 2 and args[1] is type(None):
return _convert(val, args[0]) if val is not None else None
class DecodeError(Exception):
"""Error during decoding"""


def _convert(
val,
t,
stack: list[tuple[str, type]] | None = None,
name: str | None = None,
):
if name is None:
name = "<" + t.__name__ + ">"
if stack is None:
stack = []
substack = stack + [(name, t)]

raise NotImplementedError(f"{origin} not supported")
try:
# Recurse (GenericAlias e.g list[])
if (origin := get_origin(t)) is not None:
args = get_args(t)
if origin is list:
return [_convert(_val, args[0], substack) for _val in val]
if origin is tuple:
return tuple(
_convert(_val, _t, substack) for _val, _t in zip(val, args)
)
if (
origin in (Union, UnionType)
and len(args) == 2
and args[1] is type(None)
):
return _convert(val, args[0], substack) if val is not None else None

#
if is_dataclass(t):
return _decode(val, t)
raise NotImplementedError(f"{origin} not supported")

# Primitive
#
if is_dataclass(t):
return _decode(val, t, substack)

# Unpack list if single atom
if isinstance(val, list) and len(val) == 1 and not isinstance(val[0], list):
val = val[0]
# Primitive

if issubclass(t, bool):
assert val in [Symbol("yes"), Symbol("no")]
return val == Symbol("yes")
if isinstance(val, Symbol):
return t(str(val))
# Unpack list if single atom
if isinstance(val, list) and len(val) == 1 and not isinstance(val[0], list):
val = val[0]

return t(val)
if issubclass(t, bool):
assert val in [Symbol("yes"), Symbol("no")]
return val == Symbol("yes")
if isinstance(val, Symbol):
return t(str(val))

return t(val)
except DecodeError:
raise
except Exception as e:
pretty_stack = ".".join(s[0] for s in substack)
raise DecodeError(f"Failed to decode {pretty_stack} ({t}) with {val} ") from e


netlist_obj = str | Symbol | int | float | bool | list
netlist_type = list[netlist_obj]


def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
def _decode[T](
sexp: netlist_type,
t: type[T],
stack: list[tuple[str, type]] | None = None,
) -> T:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"parse into: {t.__name__} {'-'*40}")
logger.debug(f"sexp: {sexp}")
Expand Down Expand Up @@ -126,9 +157,9 @@ def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
and isinstance(key := val[0], Symbol)
and (str(key) + "s" in key_fields or str(key) in key_fields)
),
lambda val: str(val[0]) + "s"
if str(val[0]) + "s" in key_fields
else str(val[0]),
lambda val: (
str(val[0]) + "s" if str(val[0]) + "s" in key_fields else str(val[0])
),
)
pos_values = {
i: val
Expand Down Expand Up @@ -166,13 +197,17 @@ def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
args = get_args(f.type)
if origin is list:
val_t = args[0]
value_dict[name] = [_convert(_val[1:], val_t) for _val in values]
value_dict[name] = [
_convert(_val[1:], val_t, stack, name) for _val in values
]
elif origin is dict:
if not sp.key:
raise ValueError(f"Key function required for multidict: {f.name}")
key_t = args[0]
val_t = args[1]
converted_values = [_convert(_val[1:], val_t) for _val in values]
converted_values = [
_convert(_val[1:], val_t, stack, name) for _val in values
]
values_with_key = [(sp.key(_val), _val) for _val in converted_values]

if not all(isinstance(k, key_t) for k, _ in values_with_key):
Expand All @@ -189,21 +224,21 @@ def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
)
else:
assert len(values) == 1, f"Duplicate key: {name}"
value_dict[name] = _convert(values[0][1:], f.type)
value_dict[name] = _convert(values[0][1:], f.type, stack, name)

# Positional
for f, v in (it := zip_non_locked(positional_fields.values(), pos_values.values())):
# special case for missing positional empty StrEnum fields
if isinstance(f.type, type) and issubclass(f.type, StrEnum):
if "" in f.type and not isinstance(v, Symbol):
value_dict[f.name] = _convert(Symbol(""), f.type)
value_dict[f.name] = _convert(Symbol(""), f.type, stack, f.name)
# only advance field iterator
# if no more positional fields, there shouldn't be any more values
if it.next(0) is None:
raise ValueError(f"Unexpected symbol {v}")
continue

value_dict[f.name] = _convert(v, f.type)
value_dict[f.name] = _convert(v, f.type, stack, f.name)

# Check assertions ----------------------------------------------------
for f in fs:
Expand Down