Skip to content

Commit

Permalink
custom dictionary flattener works
Browse files Browse the repository at this point in the history
  • Loading branch information
calbaker committed Feb 10, 2025
1 parent b767032 commit bfe86c3
Showing 1 changed file with 67 additions and 15 deletions.
82 changes: 67 additions & 15 deletions python/fastsim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from importlib.metadata import version
from pathlib import Path
from typing import Any, List, Union, Dict
from typing import Any, List, Union, Dict, Optional
from typing_extensions import Self
import numpy as np
import inspect
Expand All @@ -10,7 +10,7 @@
import fastsim
from .fastsim import * # noqa: F403
from .fastsim import Cycle # type: ignore[attr-defined]
from . import utils # type: ignore[attr-defined]
from . import utils # type: ignore[attr-defined] # noqa: F401

DEFAULT_LOGGING_CONFIG = dict(
format="%(asctime)s.%(msecs)03d | %(filename)s:%(lineno)s | %(levelname)s: %(message)s",
Expand Down Expand Up @@ -101,8 +101,55 @@ def to_pydict(self, data_fmt: str = "msg_pack", flatten: bool = False) -> Dict:
if not flatten:
return pydict
else:
return next(iter(pd.json_normalize(pydict, sep=".").to_dict(orient='records')))
hist_len = get_hist_len(pydict)
assert hist_len is not None, "Cannot be flattened"
flat_dict = get_flattened(pydict, hist_len)
return flat_dict

def get_hist_len(obj: Dict) -> Optional[int]:
"""
Finds nested `history` and gets lenth of first element
"""
if 'history' in obj.keys():
return len(next(iter(obj['history'].values())))

elif next(iter(k for k in obj.keys() if '.history.' in k), None) is not None:
return len(next((v for (k, v) in obj.items() if 'history' == k.split(".")[-2])))

for (k, v) in obj.items():
if isinstance(v, dict):
hist_len = get_hist_len(v)
if hist_len is not None:
return hist_len
return None

def get_flattened(obj: Dict | List, hist_len: int, prepend_str: str="") -> Dict:
"""
Flattens and returns dictionary, separating keys and indices with a `"."`
# Arguments
# - `obj`: object to flatten
# - hist_len: length of any lists storing history data
# - `prepend_str`: prepend this to all keys in the returned `flat` dict
"""
flat: Dict = {}
if isinstance(obj, dict):
for (k, v) in obj.items():
new_key = k if (len(prepend_str) == 0) else prepend_str + "." + k
if isinstance(v, dict) or (isinstance(v, list) and len(v) != hist_len):
flat.update(get_flattened(v, hist_len, prepend_str=new_key))
else:
flat[new_key] = v
elif isinstance(obj, list):
for (i, v) in enumerate(obj):
new_key = i if (len(prepend_str) == 0) else prepend_str + "." + str(i)
if isinstance(v, dict) or (isinstance(v, list) and len(v) != hist_len):
flat.update(get_flattened(v, hist_len, prepend_str=new_key))
else:
flat[new_key] = v
else:
raise TypeError("`obj` should be `dict` or `list`")

return flat

@classmethod # type: ignore[misc]
def from_pydict(cls, pydict: Dict, data_fmt: str = "msg_pack", skip_init: bool = False) -> Self: # type: ignore[misc]
Expand Down Expand Up @@ -137,31 +184,36 @@ def to_dataframe(self, pandas: bool = False, allow_partial: bool = False) -> Uni
# Arguments
- `pandas`: returns pandas dataframe if True; otherwise, returns polars dataframe by default
- `allow_partial`: returns dataframe of length equal to solved time steps if simulation fails early
- `allow_partial`: tries to return dataframe of length equal to solved time steps if simulation fails early
"""
obj_dict = self.to_pydict(flatten=True)
history_dict: Dict[str, Any] = {}

history_keys = ['.history.', 'cyc.', '.cyc.']
try:
time_ach = next(iter(v for (k, v) in obj_dict.items() if 'veh' in k and 'time_seconds' in 'k'))
except StopIteration:
time_ach = None
hist_len = get_hist_len(obj_dict)
assert hist_len is not None

history_dict: Dict[str, Any] = {}
for k, v in obj_dict.items():
same_len_as_time_ach = (False if (time_ach is None) else (len(v) == len(time_ach)))
hk_in_k = any(hk in k for hk in history_keys)
if (hk_in_k or same_len_as_time_ach) and ("__len__" in dir(v)):
history_dict[k] = v
if hk_in_k and ("__len__" in dir(v)):
if (len(v) == hist_len) or allow_partial:
history_dict[k] = v

if allow_partial:
cutoff = min([len(val) for val in history_dict.values()])

if not pandas:
df = pl.DataFrame({col: val[:cutoff]
try:
df = pl.DataFrame({col: val[:cutoff]
for col, val in history_dict.items()})
except Exception as err:
raise Exception(f"{err}\n`save_interval` may not be uniform")
else:
df = pd.DataFrame({col: val[:cutoff]
try:
df = pd.DataFrame({col: val[:cutoff]
for col, val in history_dict.items()})
except Exception as err:
raise Exception(f"{err}\n`save_interval` may not be uniform")

else:
if not pandas:
try:
Expand Down

0 comments on commit bfe86c3

Please sign in to comment.