Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid changing global logging state in tests #1522

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/tlo/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
disable,
getLogger,
initialise,
reset,
restore_global_state,
set_logging_levels,
set_output_file,
)
from .helpers import set_logging_levels

__all__ = [
"CRITICAL",
Expand All @@ -21,7 +21,7 @@
"disable",
"getLogger",
"initialise",
"reset",
"restore_global_state",
"set_output_file",
"set_logging_levels",
]
86 changes: 79 additions & 7 deletions src/tlo/logging/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging as _logging
import sys
import warnings
from contextlib import contextmanager
from functools import partialmethod
from pathlib import Path
from typing import Any, Callable, List, Optional, TypeAlias, Union
Expand Down Expand Up @@ -75,14 +76,69 @@ def initialise(
root_logger.addHandler(handler)


def reset():
"""Reset global logging state to values at initial import."""
@contextmanager
def restore_global_state():
"""Context manager which records global logging state on entry and restores at exit."""
global _get_simulation_date, _loggers
while len(_loggers) > 0:
name, _ = _loggers.popitem()
_logging.root.manager.loggerDict.pop(name, None) # pylint: disable=E1101
_loggers.clear()
_get_simulation_date = _mock_simulation_date_getter
original_get_simulation_date = _get_simulation_date
original_loggers = _loggers.copy()
original_logger_states = {
name: logger.__getstate__() for name, logger in _loggers.items()
}
yield
# For any new loggers created in context managed code reset their attributes
# We don't remove the associated logger instances in the base logging library as
# any children of these loggers will then not work as expected
for name, logger in _loggers.items():
if name not in original_loggers:
logger.reset_attributes()
_get_simulation_date = original_get_simulation_date
for name, logger_state in original_logger_states.items():
original_loggers[name].__setstate__(logger_state)
_loggers = original_loggers


def set_logging_levels(custom_levels: dict[str, LogLevel]) -> None:
"""Set custom logging levels for disease modules

:param custom_levels: Dictionary of modules and their level, '*' can be used as a key for all modules
"""
# get list of `tlo.` loggers to process (this assumes logger have been setup on module import)
tlo_methods_loggers = [
logger for name, logger in _loggers.items() if name.startswith("tlo.methods")
]

# set the baseline logging level from methods, if it's been set
if "*" in custom_levels:
getLogger("tlo.methods").setLevel(custom_levels["*"])

# loop over each of the tlo loggers
for logger in tlo_methods_loggers:
# get the full name
logger_name = logger.name
matched = False
# look for name, or any parent name, in the custom levels
while len(logger_name):
if logger_name in custom_levels:
getLogger(logger_name).setLevel(custom_levels[logger_name])
matched = True
break
elif logger_name == "tlo.methods":
# we've reached the top-level of the `tlo.methods` logger
break
else:
# get the parent logger name
logger_name = ".".join(logger_name.split(".")[:-1])
# if we exited without finding a matching logger in custom levels
if not matched:
if "*" in custom_levels:
getLogger(logger.name).setLevel(custom_levels["*"])

# loggers named in custom_level but, for some reason, haven't been getLogger-ed yet
tlo_methods_logger_names = {logger.name for logger in tlo_methods_loggers}
for logger_name, logger_level in custom_levels.items():
if logger_name != "*" and logger_name not in tlo_methods_logger_names:
getLogger(logger_name).setLevel(logger_level)


def set_output_file(
Expand Down Expand Up @@ -251,6 +307,22 @@ def reset_attributes(self) -> None:
self._columns.clear()
self.setLevel(_DEFAULT_LEVEL)

def __getstate__(self) -> dict:
return {
"name": self.name,
"level": self.level,
"handlers": self.handlers,
"uuids": self._uuids,
"columns": self._columns,
}

def __setstate__(self, state: dict):
self._std_logger = _logging.getLogger(name=state["name"])
self._std_logger.setLevel(state["level"])
self.handlers = state["handlers"]
self._uuids = state["uuids"]
self._columns = state["columns"]

def setLevel(self, level: LogLevel) -> None:
self._std_logger.setLevel(level)

Expand Down
48 changes: 0 additions & 48 deletions src/tlo/logging/helpers.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,9 @@
import logging as _logging
from collections.abc import Collection, Iterable
from typing import Dict, List, Optional, Union

import pandas as pd
from pandas.api.types import is_extension_array_dtype

from .core import getLogger


def set_logging_levels(custom_levels: Dict[str, int]) -> None:
"""Set custom logging levels for disease modules

:param custom_levels: Dictionary of modules and their level, '*' can be used as a key for all modules
"""
# get list of `tlo.` loggers to process (this assumes logger have been setup on module import)
loggers = {
_logging.getLogger(name)
for name in _logging.root.manager.loggerDict # pylint: disable=E1101
if name.startswith('tlo.methods')
}

# set the baseline logging level from methods, if it's been set
if '*' in custom_levels:
getLogger('tlo.methods').setLevel(custom_levels['*'])

# loop over each of the tlo loggers
for logger in loggers:
# get the full name
logger_name = logger.name
matched = False
# look for name, or any parent name, in the custom levels
while len(logger_name):
if logger_name in custom_levels:
getLogger(logger_name).setLevel(custom_levels[logger_name])
matched = True
break
elif logger_name == 'tlo.methods':
# we've reached the top-level of the `tlo.methods` logger
break
else:
# get the parent logger name
logger_name = '.'.join(logger_name.split(".")[:-1])
# if we exited without finding a matching logger in custom levels
if not matched:
if '*' in custom_levels:
getLogger(logger.name).setLevel(custom_levels['*'])

# loggers named in custom_level but, for some reason, haven't been getLogger-ed yet
loggers = {logger.name for logger in loggers}
for logger_name, logger_level in custom_levels.items():
if logger_name != "*" and logger_name not in loggers:
getLogger(logger_name).setLevel(logger_level)


def get_dataframe_row_as_dict_for_logging(
dataframe: pd.DataFrame,
Expand Down
16 changes: 8 additions & 8 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def initialise_logging(
root_level: core.LogLevel,
stdout_handler_level: core.LogLevel,
) -> Generator[None, None, None]:
logging.initialise(
add_stdout_handler=add_stdout_handler,
simulation_date_getter=simulation_date_getter,
root_level=root_level,
stdout_handler_level=stdout_handler_level,
)
yield
logging.reset()
with logging.restore_global_state():
logging.initialise(
add_stdout_handler=add_stdout_handler,
simulation_date_getter=simulation_date_getter,
root_level=root_level,
stdout_handler_level=stdout_handler_level,
)
yield


@pytest.mark.parametrize("add_stdout_handler", [True, False])
Expand Down
Loading