diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index 41681fbd239f3..38a67b4dca25d 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -19,6 +19,7 @@ import os import pstats import tempfile +from collections import defaultdict from pathlib import Path from typing import Optional, Union @@ -66,14 +67,15 @@ def __init__( If you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: dict[str, cProfile.Profile] = {} + self.profiled_actions: dict[str, cProfile.Profile] = defaultdict(cProfile.Profile) self.line_count_restriction = line_count_restriction self.dump_stats = dump_stats @override def start(self, action_name: str) -> None: - if action_name not in self.profiled_actions: - self.profiled_actions[action_name] = cProfile.Profile() + # Disable all profilers before starting a new one + for pr in self.profiled_actions.values(): + pr.disable() self.profiled_actions[action_name].enable() @override @@ -114,7 +116,7 @@ def summary(self) -> str: @override def teardown(self, stage: Optional[str]) -> None: super().teardown(stage=stage) - self.profiled_actions = {} + self.profiled_actions = defaultdict(cProfile.Profile) def __reduce__(self) -> tuple: # avoids `TypeError: cannot pickle 'cProfile.Profile' object`