Skip to content

Commit

Permalink
Merge branch 'master' into mmd
Browse files Browse the repository at this point in the history
  • Loading branch information
kzkadc authored May 8, 2024
2 parents 26b131b + 0c680df commit 597e8c6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
7 changes: 6 additions & 1 deletion ignite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def setup_logger(
filepath: Optional[str] = None,
distributed_rank: Optional[int] = None,
reset: bool = False,
encoding: Optional[str] = "utf-8",
) -> logging.Logger:
"""Setups logger: name, level, format etc.
Expand All @@ -175,6 +176,7 @@ def setup_logger(
distributed_rank: Optional, rank in distributed configuration to avoid logger setup for workers.
If None, distributed_rank is initialized to the rank of process.
reset: if True, reset an existing logger rather than keep format, handlers, and level.
encoding: open the file with the encoding. By default, 'utf-8'.
Returns:
logging.Logger
Expand Down Expand Up @@ -228,6 +230,9 @@ def setup_logger(
.. versionchanged:: 0.4.5
Added ``reset`` parameter.
.. versionchanged:: 0.5.1
Argument ``encoding`` added to correctly handle special characters in the file, default "utf-8".
"""
# check if the logger already exists
existing = name is None or name in logging.root.manager.loggerDict
Expand Down Expand Up @@ -265,7 +270,7 @@ def setup_logger(
logger.addHandler(ch)

if filepath is not None:
fh = logging.FileHandler(filepath)
fh = logging.FileHandler(filepath, encoding=encoding)
fh.setLevel(level)
fh.setFormatter(formatter)
logger.addHandler(fh)
Expand Down
24 changes: 24 additions & 0 deletions tests/ignite/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import platform
import sys
from collections import namedtuple

Expand Down Expand Up @@ -174,6 +175,29 @@ def test_override_setup_logger(capsys):
logging.shutdown()


@pytest.mark.parametrize("encoding", [None, "utf-8"])
def test_setup_logger_encoding(encoding, dirname):
fp = dirname / "log.txt"
logger = setup_logger(name="logger", filepath=fp, encoding=encoding, reset=True)
test_words = ["say hello", "say 你好", "say こんにちわ", "say 안녕하세요", "say привет"]
for w in test_words:
logger.info(w)
logging.shutdown()

with open(fp, "r", encoding=encoding) as h:
data = h.readlines()

if platform.system() == "Windows" and encoding is None:
flatten_data = "\n".join(data)
assert test_words[0] in flatten_data
for word in test_words[1:]:
assert word not in flatten_data
else:
assert len(data) == len(test_words)
for expected, output in zip(test_words, data):
assert expected in output


def test_deprecated():
# Test on function without docs, @deprecated without reasons
@deprecated("0.4.2", "0.6.0")
Expand Down

0 comments on commit 597e8c6

Please sign in to comment.