Skip to content

Commit

Permalink
Fix the type hints in aesara.printing
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 20, 2022
1 parent 9d36038 commit bb40791
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 262 deletions.
59 changes: 29 additions & 30 deletions aesara/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np

import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.link.utils import get_destroy_dependencies


if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph


@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
Expand All @@ -39,13 +43,13 @@ def extended_open(filename, mode="r"):

logger = logging.getLogger("aesara.compile.profiling")

aesara_imported_time = time.time()
total_fct_exec_time = 0.0
total_graph_opt_time = 0.0
total_time_linker = 0.0
aesara_imported_time: float = time.time()
total_fct_exec_time: float = 0.0
total_graph_opt_time: float = 0.0
total_time_linker: float = 0.0

_atexit_print_list: List = []
_atexit_registered = False
_atexit_print_list: List["ProfileStats"] = []
_atexit_registered: bool = False


def _atexit_print_fn():
Expand Down Expand Up @@ -180,7 +184,6 @@ def register_profiler_printer(fct):


class ProfileStats:

"""
Object to store runtime and memory profiling information for all of
Aesara's operations: compilation, optimization, execution.
Expand Down Expand Up @@ -215,72 +218,68 @@ def reset(self):
#
show_sum: bool = True

compile_time = 0.0
compile_time: float = 0.0
# Total time spent in body of orig_function,
# dominated by graph optimization and compilation of C
#

fct_call_time = 0.0
fct_call_time: float = 0.0
# The total time spent in Function.__call__
#

fct_callcount = 0
fct_callcount: int = 0
# Number of calls to Function.__call__
#

vm_call_time = 0.0
vm_call_time: float = 0.0
# Total time spent in Function.vm.__call__
#

apply_time = None
# dict from `(FunctionGraph, Variable)` to float runtime
#
apply_time: Optional[Dict[Union["FunctionGraph", Variable], float]] = None

apply_callcount = None
# dict from `(FunctionGraph, Variable)` to number of executions
#
apply_callcount: Optional[Dict[Union["FunctionGraph", Variable], int]] = None

apply_cimpl = None
apply_cimpl: Optional[Dict[Apply, bool]] = None
# dict from node -> bool (1 if c, 0 if py)
#

message = None
message: Optional[str] = None
# pretty string to print in summary, to identify this output
#

variable_shape: Dict = {}
variable_shape: Dict[Variable, Any] = {}
# Variable -> shapes
#

variable_strides: Dict = {}
variable_strides: Dict[Variable, Any] = {}
# Variable -> strides
#

variable_offset: Dict = {}
variable_offset: Dict[Variable, Any] = {}
# Variable -> offset
#

optimizer_time = 0.0
optimizer_time: float = 0.0
# time spent optimizing graph (FunctionMaker.__init__)

validate_time = 0.0
validate_time: float = 0.0
# time spent in fgraph.validate
# This is a subset of optimizer_time that is dominated by toposort()
# when the destorymap feature is included.

linker_time = 0.0
linker_time: float = 0.0
# time spent linking graph (FunctionMaker.create)

import_time = 0.0
import_time: float = 0.0
# time spent in importing compiled python module.

linker_node_make_thunks = 0.0
linker_node_make_thunks: float = 0.0

linker_make_thunk_time: Dict = {}

line_width = config.profiling__output_line_width

nb_nodes = -1
nb_nodes: int = -1
# The number of nodes in the graph. We need the information separately in
# case we print the profile when the function wasn't executed, or if there
# is a lazy operation in the graph.
Expand Down
1 change: 1 addition & 0 deletions aesara/link/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def __init__(
self.call_counts = [0] * len(nodes)
self.call_times = [0] * len(nodes)
self.time_thunks = False
self.storage_map: Optional[StorageMapType] = None

@abstractmethod
def __call__(self):
Expand Down
Loading

0 comments on commit bb40791

Please sign in to comment.