diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index aea9288e3..412ac8b02 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -37,27 +37,6 @@ def get_serialized_trace(self) -> t.SerializedNodeMap: for key, flat_trace in self.get_trace().items() } - def serialize( - self, - value: None | t.VarArray | t.ArrayLike[object], - ) -> None | t.ArrayLike[object]: - if value is None: - return None - - if isinstance(value, EnumArray): - return value.decode_to_str().tolist() - - if isinstance(value, numpy.ndarray) and numpy.issubdtype( - value.dtype, - numpy.dtype(bytes), - ): - return value.astype(numpy.dtype(str)).tolist() - - if isinstance(value, numpy.ndarray): - return value.tolist() - - return value - def _get_flat_trace( self, node: t.TraceNode, @@ -83,3 +62,27 @@ def key(node: t.TraceNode) -> t.NodeKey: name = node.name period = node.period return t.NodeKey(f"{name}<{period}>") + + @staticmethod + def serialize( + value: None | t.VarArray | t.ArrayLike[object], + ) -> None | t.ArrayLike[object]: + if value is None: + return None + + if isinstance(value, EnumArray): + return value.decode_to_str().tolist() + + if isinstance(value, numpy.ndarray) and numpy.issubdtype( + value.dtype, + numpy.dtype(bytes), + ): + return value.astype(numpy.dtype(str)).tolist() + + if isinstance(value, numpy.ndarray): + return value.tolist() + + return value + + +__all__ = ["FlatTrace"] diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 56109cc9f..f6f793e19 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -91,10 +91,10 @@ def generate_performance_tables(self, dir_path: str) -> None: def get_nb_requests(self, variable: str) -> int: return sum(self._get_nb_requests(tree, variable) for tree in self.trees) - def get_flat_trace(self) -> dict: + def get_flat_trace(self) -> t.FlatNodeMap: return self.flat_trace.get_trace() - def get_serialized_flat_trace(self) -> dict: + def get_serialized_flat_trace(self) -> t.SerializedNodeMap: return self.flat_trace.get_serialized_trace() def browse_trace(self) -> Iterator[t.TraceNode]: @@ -161,3 +161,6 @@ def _get_nb_requests(self, tree: t.TraceNode, variable: str) -> int: @staticmethod def _get_time_in_sec() -> t.Time: return time.time_ns() / (10**9) + + +__all__ = ["FullTracer"] diff --git a/openfisca_core/tracers/types.py b/openfisca_core/tracers/types.py new file mode 100644 index 000000000..f26c85424 --- /dev/null +++ b/openfisca_core/tracers/types.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import NewType, Protocol +from typing_extensions import TypeAlias, TypedDict + +from openfisca_core.types import ( + Array, + ArrayLike, + ParameterNode, + ParameterNodeChild, + Period, + PeriodInt, + VariableName, +) + +from numpy import generic as VarDType + +#: A type of a generic array. +VarArray: TypeAlias = Array[VarDType] + +#: A type representing a unit time. +Time: TypeAlias = float + +#: A type representing a mapping of flat traces. +FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"] + +#: A type representing a mapping of serialized traces. +SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"] + +#: A stack of simple traces. +SimpleStack: TypeAlias = list["SimpleTraceMap"] + +#: Key of a trace. +NodeKey = NewType("NodeKey", str) + + +class FlatTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | VarArray + calculation_time: Time + formula_time: Time + + +class SerializedTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | ArrayLike[object] + calculation_time: Time + formula_time: Time + + +class SimpleTraceMap(TypedDict, total=True): + name: VariableName + period: int | Period + + +class ComputationLog(Protocol): + def print_log(self, aggregate: bool = ..., max_depth: int = ..., /) -> None: ... + + +class FlatTrace(Protocol): + def get_trace(self, /) -> FlatNodeMap: ... + def get_serialized_trace(self, /) -> SerializedNodeMap: ... + + +class FullTracer(Protocol): + @property + def trees(self, /) -> list[TraceNode]: ... + def browse_trace(self, /) -> Iterator[TraceNode]: ... + + +class PerformanceLog(Protocol): + def generate_graph(self, dir_path: str, /) -> None: ... + def generate_performance_tables(self, dir_path: str, /) -> None: ... + + +class SimpleTracer(Protocol): + @property + def stack(self, /) -> SimpleStack: ... + def record_calculation_start( + self, variable: VariableName, period: PeriodInt | Period, / + ) -> None: ... + def record_calculation_end(self, /) -> None: ... + + +class TraceNode(Protocol): + children: list[TraceNode] + end: Time + name: str + parameters: list[TraceNode] + parent: None | TraceNode + period: PeriodInt | Period + start: Time + value: None | VarArray + + def calculation_time(self, *, round_: bool = ...) -> Time: ... + def formula_time(self, /) -> Time: ... + def append_child(self, node: TraceNode, /) -> None: ... + + +__all__ = [ + "ArrayLike", + "ParameterNode", + "ParameterNodeChild", + "PeriodInt", +] diff --git a/openfisca_core/types.py b/openfisca_core/types.py index 9fb94d1f5..9c8105741 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -148,8 +148,31 @@ class MemoryUsage(TypedDict, total=False): # Parameters +#: A type representing a node of parameters. +ParameterNode: TypeAlias = Union[ + "ParameterNodeAtInstant", "VectorialParameterNodeAtInstant" +] -class ParameterNodeAtInstant(Protocol): ... +#: A type representing a ??? +ParameterNodeChild: TypeAlias = Union[ParameterNode, ArrayLike[object]] + + +class ParameterNodeAtInstant(Protocol): + _instant_str: InstantStr + + def __contains__(self, __item: object, /) -> bool: ... + def __getitem__( + self, __index: str | Array[DTypeGeneric], / + ) -> ParameterNodeChild: ... + + +class VectorialParameterNodeAtInstant(Protocol): + _instant_str: InstantStr + + def __contains__(self, item: object, /) -> bool: ... + def __getitem__( + self, __index: str | Array[DTypeGeneric], / + ) -> ParameterNodeChild: ... # Periods