Skip to content
This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit

Permalink
SEXP:Add stack and debugging information to decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
mawildoer committed Sep 11, 2024
1 parent 7e89dd8 commit 00971b4
Showing 1 changed file with 68 additions and 33 deletions.
101 changes: 68 additions & 33 deletions src/faebryk/libs/sexp/dataclass_sexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,43 +61,74 @@ def from_field(cls, f: Field):
class SymEnum(StrEnum): ...


def _convert(val, t):
# Recurse (GenericAlias e.g list[])
if (origin := get_origin(t)) is not None:
args = get_args(t)
if origin is list:
return [_convert(_val, args[0]) for _val in val]
if origin is tuple:
return tuple(_convert(_val, _t) for _val, _t in zip(val, args))
if origin in (Union, UnionType) and len(args) == 2 and args[1] is type(None):
return _convert(val, args[0]) if val is not None else None
class DecodeError(Exception):
"""Error during decoding"""


def _convert(
val,
t,
stack: list[tuple[str, type]] | None = None,
name: str | None = None,
):
if name is None:
name = "<" + t.__name__ + ">"
if stack is None:
stack = []
substack = stack + [(name, t)]

raise NotImplementedError(f"{origin} not supported")
try:
# Recurse (GenericAlias e.g list[])
if (origin := get_origin(t)) is not None:
args = get_args(t)
if origin is list:
return [_convert(_val, args[0], substack) for _val in val]
if origin is tuple:
return tuple(
_convert(_val, _t, substack) for _val, _t in zip(val, args)
)
if (
origin in (Union, UnionType)
and len(args) == 2
and args[1] is type(None)
):
return _convert(val, args[0], substack) if val is not None else None

#
if is_dataclass(t):
return _decode(val, t)
raise NotImplementedError(f"{origin} not supported")

# Primitive
#
if is_dataclass(t):
return _decode(val, t, substack)

# Unpack list if single atom
if isinstance(val, list) and len(val) == 1 and not isinstance(val[0], list):
val = val[0]
# Primitive

if issubclass(t, bool):
assert val in [Symbol("yes"), Symbol("no")]
return val == Symbol("yes")
if isinstance(val, Symbol):
return t(str(val))
# Unpack list if single atom
if isinstance(val, list) and len(val) == 1 and not isinstance(val[0], list):
val = val[0]

return t(val)
if issubclass(t, bool):
assert val in [Symbol("yes"), Symbol("no")]
return val == Symbol("yes")
if isinstance(val, Symbol):
return t(str(val))

return t(val)
except DecodeError:
raise
except Exception as e:
pretty_stack = ".".join(s[0] for s in substack)
raise DecodeError(f"Failed to decode {pretty_stack} ({t}) with {val} ") from e


netlist_obj = str | Symbol | int | float | bool | list
netlist_type = list[netlist_obj]


def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
def _decode[T](
sexp: netlist_type,
t: type[T],
stack: list[tuple[str, type]] | None = None,
) -> T:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"parse into: {t.__name__} {'-'*40}")
logger.debug(f"sexp: {sexp}")
Expand Down Expand Up @@ -126,9 +157,9 @@ def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
and isinstance(key := val[0], Symbol)
and (str(key) + "s" in key_fields or str(key) in key_fields)
),
lambda val: str(val[0]) + "s"
if str(val[0]) + "s" in key_fields
else str(val[0]),
lambda val: (
str(val[0]) + "s" if str(val[0]) + "s" in key_fields else str(val[0])
),
)
pos_values = {
i: val
Expand Down Expand Up @@ -166,13 +197,17 @@ def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
args = get_args(f.type)
if origin is list:
val_t = args[0]
value_dict[name] = [_convert(_val[1:], val_t) for _val in values]
value_dict[name] = [
_convert(_val[1:], val_t, stack, name) for _val in values
]
elif origin is dict:
if not sp.key:
raise ValueError(f"Key function required for multidict: {f.name}")
key_t = args[0]
val_t = args[1]
converted_values = [_convert(_val[1:], val_t) for _val in values]
converted_values = [
_convert(_val[1:], val_t, stack, name) for _val in values
]
values_with_key = [(sp.key(_val), _val) for _val in converted_values]

if not all(isinstance(k, key_t) for k, _ in values_with_key):
Expand All @@ -189,21 +224,21 @@ def _decode[T](sexp: netlist_type, t: type[T], parent: Any | None = None) -> T:
)
else:
assert len(values) == 1, f"Duplicate key: {name}"
value_dict[name] = _convert(values[0][1:], f.type)
value_dict[name] = _convert(values[0][1:], f.type, stack, name)

# Positional
for f, v in (it := zip_non_locked(positional_fields.values(), pos_values.values())):
# special case for missing positional empty StrEnum fields
if isinstance(f.type, type) and issubclass(f.type, StrEnum):
if "" in f.type and not isinstance(v, Symbol):
value_dict[f.name] = _convert(Symbol(""), f.type)
value_dict[f.name] = _convert(Symbol(""), f.type, stack, f.name)
# only advance field iterator
# if no more positional fields, there shouldn't be any more values
if it.next(0) is None:
raise ValueError(f"Unexpected symbol {v}")
continue

value_dict[f.name] = _convert(v, f.type)
value_dict[f.name] = _convert(v, f.type, stack, f.name)

# Check assertions ----------------------------------------------------
for f in fs:
Expand Down

0 comments on commit 00971b4

Please sign in to comment.