Skip to content

Commit

Permalink
Merge pull request #250 from heitorlessa/fix/#249
Browse files Browse the repository at this point in the history
fix: prevent touching preconfigured loggers #249
  • Loading branch information
heitorlessa authored Dec 21, 2020
2 parents 835789e + f33cffe commit bdb3925
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 95 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
- **Docs**: Add clarification to Tracer docs for how `capture_method` decorator can cause function responses to be read and serialized.

### Fixed
- **Logger**: Bugfix to prevent parent loggers with the same name from being configured more than once

## [1.9.0] - 2020-12-04

### Added
Expand Down
41 changes: 26 additions & 15 deletions aws_lambda_powertools/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,32 @@ def _get_logger(self):
def _init_logger(self, **kwargs):
    """Configures a new logger, unless it was already configured.

    Skips configuration entirely when this is a child logger or when the
    underlying ``logging.Logger`` was already configured by a previous
    ``Logger`` instance with the same service name. This prevents:

    a) multiple handlers being attached
    b) different sampling mechanisms being applied
    c) duplicated log messages caused by duplicated handlers

    Parameters
    ----------
    **kwargs
        Additional structured-log keys forwarded to ``structure_logs``.
    """
    # `init` is our custom marker attribute (see bottom of this method);
    # std logging returns the same Logger object for a reused name, so the
    # marker survives across Logger() instantiations with the same service.
    is_logger_preconfigured = getattr(self._logger, "init", False)
    if self.child or is_logger_preconfigured:
        return

    self._configure_sampling()
    self._logger.setLevel(self.log_level)
    self._logger.addHandler(self._handler)
    self.structure_logs(**kwargs)

    logger.debug("Adding filter in root logger to suppress child logger records to bubble up")
    for handler in logging.root.handlers:
        # It'll add a filter to suppress any child logger from self.service
        # Where service is Order, it'll reject parent logger Order,
        # and child loggers such as Order.checkout, Order.shared
        handler.addFilter(SuppressFilter(self.service))

    # as per bug in #249, we should not be pre-configuring an existing logger
    # therefore we set a custom attribute in the Logger that will be returned
    # std logging will return the same Logger with our attribute if name is reused
    logger.debug(f"Marking logger {self.service} as preconfigured")
    self._logger.init = True

def _configure_sampling(self):
"""Dynamically set log level based on sampling rate
Expand Down
68 changes: 41 additions & 27 deletions tests/functional/test_aws_lambda_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""aws_lambda_logging tests."""
import io
import json
import random
import string

import pytest

Expand All @@ -12,9 +14,15 @@ def stdout():
return io.StringIO()


@pytest.fixture
def service_name():
    """Yield a random 15-character alphanumeric service name.

    Randomizing the service per test isolates each test's Logger from the
    std logging module's per-name logger cache.
    """
    alphabet = string.ascii_letters + string.digits
    rng = random.SystemRandom()
    return "".join(rng.choice(alphabet) for _ in range(15))


@pytest.mark.parametrize("level", ["DEBUG", "WARNING", "ERROR", "INFO", "CRITICAL"])
def test_setup_with_valid_log_levels(stdout, level):
logger = Logger(level=level, stream=stdout, request_id="request id!", another="value")
def test_setup_with_valid_log_levels(stdout, level, service_name):
logger = Logger(service=service_name, level=level, stream=stdout, request_id="request id!", another="value")
msg = "This is a test"
log_command = {
"INFO": logger.info,
Expand All @@ -37,8 +45,8 @@ def test_setup_with_valid_log_levels(stdout, level):
assert "exception" not in log_dict


def test_logging_exception_traceback(stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_logging_exception_traceback(stdout, service_name):
logger = Logger(service=service_name, level="DEBUG", stream=stdout)

try:
raise ValueError("Boom")
Expand All @@ -52,9 +60,9 @@ def test_logging_exception_traceback(stdout):
assert "exception" in log_dict


def test_setup_with_invalid_log_level(stdout):
def test_setup_with_invalid_log_level(stdout, service_name):
with pytest.raises(ValueError) as e:
Logger(level="not a valid log level")
Logger(service=service_name, level="not a valid log level")
assert "Unknown level" in e.value.args[0]


Expand All @@ -65,8 +73,8 @@ def check_log_dict(log_dict):
assert "message" in log_dict


def test_with_dict_message(stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_with_dict_message(stdout, service_name):
logger = Logger(service=service_name, level="DEBUG", stream=stdout)

msg = {"x": "isx"}
logger.critical(msg)
Expand All @@ -76,8 +84,8 @@ def test_with_dict_message(stdout):
assert msg == log_dict["message"]


def test_with_json_message(stdout):
logger = Logger(stream=stdout)
def test_with_json_message(stdout, service_name):
logger = Logger(service=service_name, stream=stdout)

msg = {"x": "isx"}
logger.info(json.dumps(msg))
Expand All @@ -87,8 +95,8 @@ def test_with_json_message(stdout):
assert msg == log_dict["message"]


def test_with_unserializable_value_in_message(stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_with_unserializable_value_in_message(stdout, service_name):
logger = Logger(service=service_name, level="DEBUG", stream=stdout)

class Unserializable:
pass
Expand All @@ -101,12 +109,17 @@ class Unserializable:
assert log_dict["message"]["x"].startswith("<")


def test_with_unserializable_value_in_message_custom(stdout):
def test_with_unserializable_value_in_message_custom(stdout, service_name):
class Unserializable:
pass

# GIVEN a custom json_default
logger = Logger(level="DEBUG", stream=stdout, json_default=lambda o: f"<non-serializable: {type(o).__name__}>")
logger = Logger(
service=service_name,
level="DEBUG",
stream=stdout,
json_default=lambda o: f"<non-serializable: {type(o).__name__}>",
)

# WHEN we log a message
logger.debug({"x": Unserializable()})
Expand All @@ -118,9 +131,9 @@ class Unserializable:
assert "json_default" not in log_dict


def test_log_dict_key_seq(stdout):
def test_log_dict_key_seq(stdout, service_name):
# GIVEN the default logger configuration
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("Message")
Expand All @@ -131,9 +144,9 @@ def test_log_dict_key_seq(stdout):
assert ",".join(list(log_dict.keys())[:4]) == "level,location,message,timestamp"


def test_log_dict_key_custom_seq(stdout):
def test_log_dict_key_custom_seq(stdout, service_name):
# GIVEN a logger configuration with log_record_order set to ["message"]
logger = Logger(stream=stdout, log_record_order=["message"])
logger = Logger(service=service_name, stream=stdout, log_record_order=["message"])

# WHEN logging a message
logger.info("Message")
Expand All @@ -144,9 +157,9 @@ def test_log_dict_key_custom_seq(stdout):
assert list(log_dict.keys())[0] == "message"


def test_log_custom_formatting(stdout):
def test_log_custom_formatting(stdout, service_name):
# GIVEN a logger where we have a custom `location`, 'datefmt' format
logger = Logger(stream=stdout, location="[%(funcName)s] %(module)s", datefmt="fake-datefmt")
logger = Logger(service=service_name, stream=stdout, location="[%(funcName)s] %(module)s", datefmt="fake-datefmt")

# WHEN logging a message
logger.info("foo")
Expand All @@ -158,7 +171,7 @@ def test_log_custom_formatting(stdout):
assert log_dict["timestamp"] == "fake-datefmt"


def test_log_dict_key_strip_nones(stdout):
def test_log_dict_key_strip_nones(stdout, service_name):
# GIVEN a logger confirmation where we set `location` and `timestamp` to None
# Note: level, sampling_rate and service can not be suppressed
logger = Logger(stream=stdout, level=None, location=None, timestamp=None, sampling_rate=None, service=None)
Expand All @@ -170,14 +183,15 @@ def test_log_dict_key_strip_nones(stdout):

# THEN the keys should only include `level`, `message`, `service`, `sampling_rate`
assert sorted(log_dict.keys()) == ["level", "message", "sampling_rate", "service"]
assert log_dict["service"] == "service_undefined"


def test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch):
def test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch, service_name):
# GIVEN a logger is initialized within a Lambda function with X-Ray enabled
trace_id = "1-5759e988-bd862e3fe1be46a994272793"
trace_header = f"Root={trace_id};Parent=53995c3f42cd8ad8;Sampled=1"
monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header)
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("foo")
Expand All @@ -190,9 +204,9 @@ def test_log_dict_xray_is_present_when_tracing_is_enabled(stdout, monkeypatch):
monkeypatch.delenv(name="_X_AMZN_TRACE_ID")


def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypatch):
def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypatch, service_name):
# GIVEN a logger is initialized within a Lambda function with X-Ray disabled (default)
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("foo")
Expand All @@ -203,12 +217,12 @@ def test_log_dict_xray_is_not_present_when_tracing_is_disabled(stdout, monkeypat
assert "xray_trace_id" not in log_dict


def test_log_dict_xray_is_updated_when_tracing_id_changes(stdout, monkeypatch):
def test_log_dict_xray_is_updated_when_tracing_id_changes(stdout, monkeypatch, service_name):
# GIVEN a logger is initialized within a Lambda function with X-Ray enabled
trace_id = "1-5759e988-bd862e3fe1be46a994272793"
trace_header = f"Root={trace_id};Parent=53995c3f42cd8ad8;Sampled=1"
monkeypatch.setenv(name="_X_AMZN_TRACE_ID", value=trace_header)
logger = Logger(stream=stdout)
logger = Logger(service=service_name, stream=stdout)

# WHEN logging a message
logger.info("foo")
Expand Down
Loading

0 comments on commit bdb3925

Please sign in to comment.