Skip to content

Commit

Permalink
Use context managers with open
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 12, 2022
1 parent b4912d9 commit d69eaab
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 129 deletions.
22 changes: 9 additions & 13 deletions aesara/compile/compiledir.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ def cleanup():
"""
compiledir = config.compiledir
for directory in os.listdir(compiledir):
file = None
try:
try:
filename = os.path.join(compiledir, directory, "key.pkl")
file = open(filename, "rb")
# print file
filename = os.path.join(compiledir, directory, "key.pkl")
# print file
with open(filename, "rb") as file:
try:
keydata = pickle.load(file)

for key in list(keydata.keys):
have_npy_abi_version = False
have_c_compiler = False
Expand Down Expand Up @@ -86,14 +85,11 @@ def cleanup():
"the clean-up, please remove manually "
"the directory containing it."
)
except OSError:
_logger.error(
f"Could not clean up this directory: '{directory}'. To complete "
"the clean-up, please remove it manually."
)
finally:
if file is not None:
file.close()
except OSError:
_logger.error(
f"Could not clean up this directory: '{directory}'. To complete "
"the clean-up, please remove it manually."
)


def print_title(title, overline="", underline=""):
Expand Down
196 changes: 104 additions & 92 deletions aesara/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List

import numpy as np
Expand All @@ -25,6 +26,17 @@
from aesara.link.utils import get_destroy_dependencies


@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
yield sys.stdout
elif filename == "<stderr>":
yield sys.stderr
else:
with open(filename, mode=mode) as f:
yield f


logger = logging.getLogger("aesara.compile.profiling")

aesara_imported_time = time.time()
Expand All @@ -37,93 +49,92 @@


def _atexit_print_fn():
"""
Print ProfileStat objects in _atexit_print_list to _atexit_print_file.
"""
"""Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`."""
if config.profile:
to_sum = []

if config.profiling__destination == "stderr":
destination_file = sys.stderr
destination_file = "<stderr>"
elif config.profiling__destination == "stdout":
destination_file = sys.stdout
destination_file = "<stdout>"
else:
destination_file = open(config.profiling__destination, "w")

# Reverse sort in the order of compile+exec time
for ps in sorted(
_atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
)[::-1]:
if (
ps.fct_callcount >= 1
or ps.compile_time > 1
or getattr(ps, "callcount", 0) > 1
):
ps.summary(
destination_file = config.profiling__destination

with extended_open(destination_file, mode="w"):

# Reverse sort in the order of compile+exec time
for ps in sorted(
_atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
)[::-1]:
if (
ps.fct_callcount >= 1
or ps.compile_time > 1
or getattr(ps, "callcount", 0) > 1
):
ps.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)

if ps.show_sum:
to_sum.append(ps)
else:
# TODO print the name if there is one!
print("Skipping empty Profile")
if len(to_sum) > 1:
# Make a global profile
cum = copy.copy(to_sum[0])
msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
cum.message = msg
for ps in to_sum[1:]:
for attr in [
"compile_time",
"fct_call_time",
"fct_callcount",
"vm_call_time",
"optimizer_time",
"linker_time",
"validate_time",
"import_time",
"linker_node_make_thunks",
]:
setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))

# merge dictionary
for attr in [
"apply_time",
"apply_callcount",
"apply_cimpl",
"variable_shape",
"variable_strides",
"variable_offset",
"linker_make_thunk_time",
]:
cum_attr = getattr(cum, attr)
for key, val in getattr(ps, attr.items()):
assert key not in cum_attr, (key, cum_attr)
cum_attr[key] = val

if cum.optimizer_profile and ps.optimizer_profile:
try:
merge = cum.optimizer_profile[0].merge_profile(
cum.optimizer_profile[1], ps.optimizer_profile[1]
)
assert len(merge) == len(cum.optimizer_profile[1])
cum.optimizer_profile = (cum.optimizer_profile[0], merge)
except Exception as e:
print(e)
cum.optimizer_profile = None
else:
cum.optimizer_profile = None

cum.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)

if ps.show_sum:
to_sum.append(ps)
else:
# TODO print the name if there is one!
print("Skipping empty Profile")
if len(to_sum) > 1:
# Make a global profile
cum = copy.copy(to_sum[0])
msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
cum.message = msg
for ps in to_sum[1:]:
for attr in [
"compile_time",
"fct_call_time",
"fct_callcount",
"vm_call_time",
"optimizer_time",
"linker_time",
"validate_time",
"import_time",
"linker_node_make_thunks",
]:
setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))

# merge dictionary
for attr in [
"apply_time",
"apply_callcount",
"apply_cimpl",
"variable_shape",
"variable_strides",
"variable_offset",
"linker_make_thunk_time",
]:
cum_attr = getattr(cum, attr)
for key, val in getattr(ps, attr.items()):
assert key not in cum_attr, (key, cum_attr)
cum_attr[key] = val

if cum.optimizer_profile and ps.optimizer_profile:
try:
merge = cum.optimizer_profile[0].merge_profile(
cum.optimizer_profile[1], ps.optimizer_profile[1]
)
assert len(merge) == len(cum.optimizer_profile[1])
cum.optimizer_profile = (cum.optimizer_profile[0], merge)
except Exception as e:
print(e)
cum.optimizer_profile = None
else:
cum.optimizer_profile = None

cum.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)

if config.print_global_stats:
print_global_stats()

Expand All @@ -139,24 +150,25 @@ def print_global_stats():
"""

if config.profiling__destination == "stderr":
destination_file = sys.stderr
destination_file = "<stderr>"
elif config.profiling__destination == "stdout":
destination_file = sys.stdout
destination_file = "<stdout>"
else:
destination_file = open(config.profiling__destination, "w")

print("=" * 50, file=destination_file)
print(
(
"Global stats: ",
f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, "
f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, "
"Time spent compiling Aesara functions: "
f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ",
),
file=destination_file,
)
print("=" * 50, file=destination_file)
destination_file = config.profiling__destination

with extended_open(destination_file, mode="w"):
print("=" * 50, file=destination_file)
print(
(
"Global stats: ",
f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, "
f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, "
"Time spent compiling Aesara functions: "
f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ",
),
file=destination_file,
)
print("=" * 50, file=destination_file)


_profiler_printers = []
Expand Down
3 changes: 2 additions & 1 deletion aesara/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,8 @@ def _filter_compiledir(path):
init_file = os.path.join(path, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
with open(init_file, "w"):
pass
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
Expand Down
14 changes: 8 additions & 6 deletions aesara/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,8 +1008,8 @@ def unpickle_failure():
entry = key_data.get_entry()
try:
# Test to see that the file is [present and] readable.
open(entry).close()
gone = False
with open(entry):
gone = False
except OSError:
gone = True

Expand Down Expand Up @@ -1505,8 +1505,8 @@ def clear_unversioned(self, min_age=None):
if filename.startswith("tmp"):
try:
fname = os.path.join(self.dirname, filename, "key.pkl")
open(fname).close()
has_key = True
with open(fname):
has_key = True
except OSError:
has_key = False
if not has_key:
Expand Down Expand Up @@ -1599,7 +1599,8 @@ def _rmtree(
if os.path.exists(parent):
try:
_logger.info(f'placing "delete.me" in {parent}')
open(os.path.join(parent, "delete.me"), "w").close()
with open(os.path.join(parent, "delete.me"), "w"):
pass
except Exception as ee:
_logger.warning(
f"Failed to remove or mark cache directory {parent} for removal {ee}"
Expand Down Expand Up @@ -2641,7 +2642,8 @@ def print_command_line_error():

if py_module:
# touch the __init__ file
open(os.path.join(location, "__init__.py"), "w").close()
with open(os.path.join(location, "__init__.py"), "w"):
pass
assert os.path.isfile(lib_filename)
return dlimport(lib_filename)

Expand Down
3 changes: 2 additions & 1 deletion aesara/link/c/cutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def compile_cutils():
assert e.errno == errno.EEXIST
assert os.path.exists(location), location
if not os.path.exists(os.path.join(location, "__init__.py")):
open(os.path.join(location, "__init__.py"), "w").close()
with open(os.path.join(location, "__init__.py"), "w"):
pass

try:
from cutils_ext.cutils_ext import * # noqa
Expand Down
12 changes: 9 additions & 3 deletions aesara/link/c/lazylinker_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def try_reload():
init_file = os.path.join(location, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
with open(init_file, "w"):
pass
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
Expand Down Expand Up @@ -126,10 +127,12 @@ def try_reload():
"code generation."
)
raise ImportError("The file lazylinker_c.c is not available.")
code = open(cfile).read()

with open(cfile) as f:
code = f.read()

loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc):

try:
os.mkdir(loc)
except OSError as e:
Expand All @@ -140,14 +143,17 @@ def try_reload():
GCC_compiler.compile_str(dirname, code, location=loc, preargs=args)
# Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py")

with open(init_py, "w") as f:
f.write(f"_version = {version}\n")

# If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not
# reload the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc")
if os.path.isfile(init_pyc):
os.remove(init_pyc)

try_import()
try_reload()
from lazylinker_ext import lazylinker_ext as lazy_c
Expand Down
Loading

0 comments on commit d69eaab

Please sign in to comment.