forked from meta-llama/llama-recipes
-
Notifications
You must be signed in to change notification settings - Fork 1
/
memory_utils.py
62 lines (51 loc) · 2.3 KB
/
memory_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import gc
import psutil
import threading
import torch
def byte2gb(x):
    """Convert a byte count to whole gibibytes (2**30 bytes).

    The result is truncated toward zero, so negative inputs round up
    (e.g. -1.5 GiB becomes -1).
    """
    bytes_per_gib = 1 << 30
    return int(x / bytes_per_gib)
# Context manager that records CUDA and process (CPU) memory usage, including
# peaks, for the code executed inside its `with` block.
class MemoryTrace:
    """Measure CUDA allocator and process-RSS memory usage across a ``with`` block.

    On entry, CUDA's peak-memory statistics are reset and a daemon thread is
    started that samples this process's resident set size (RSS) in a tight
    loop, so the CPU peak can be captured. On exit the thread is stopped and
    joined, and the following attributes are set. Values listed as GB are
    whole gibibytes (2**30 bytes) truncated toward zero by ``byte2gb``.

    CUDA, GB:
        ``begin`` / ``end``: memory allocated at entry / exit.
        ``peak``: maximum memory allocated during the block.
        ``used``: ``end - begin`` (computed in bytes, then converted).
        ``peaked``: ``peak - begin`` (computed in bytes, then converted).
        ``peak_active_gb``: the ``active_bytes.all.peak`` allocator stat.
        ``max_reserved``: maximum memory reserved by the caching allocator.
    CUDA, raw counts:
        ``cuda_malloc_retires``: the ``num_alloc_retries`` allocator stat.
        ``m_cuda_ooms``: the ``num_ooms`` allocator stat.
    CPU, GB:
        ``cpu_begin``: RSS at entry.
        ``cpu_used``: RSS at exit minus RSS at entry.
        ``cpu_peaked``: sampled peak RSS minus RSS at entry.
    CPU, bytes:
        ``cpu_end``: RSS at exit.
        ``cpu_peak``: sampled peak RSS.
    """

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        # Reset the peak gauges so `peak`, `peak_active_gb` and `max_reserved`
        # reflect only this block. This is the non-deprecated replacement for
        # torch.cuda.reset_max_memory_allocated().
        torch.cuda.reset_peak_memory_stats()
        # Keep raw byte counts: deltas must be taken in bytes and converted
        # once, not computed from already-truncated GB values.
        self._begin_bytes = torch.cuda.memory_allocated()
        self.begin = byte2gb(self._begin_bytes)
        self.process = psutil.Process()
        self._cpu_begin_bytes = self.cpu_mem_used()
        self.cpu_begin = byte2gb(self._cpu_begin_bytes)
        self.peak_monitoring = True
        self._peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        self._peak_monitor_thread.daemon = True
        self._peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """Return the resident set size (RSS) of the current process, in bytes."""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Sample RSS until ``peak_monitoring`` is cleared; keep the max in ``cpu_peak``.

        At least one sample is always taken before the stop flag is checked.
        """
        self.cpu_peak = -1
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        # Stop the sampler and wait for it, so `cpu_peak` has been assigned
        # and is no longer changing when it is read below.
        self.peak_monitoring = False
        self._peak_monitor_thread.join()

        gc.collect()
        torch.cuda.empty_cache()
        end_bytes = torch.cuda.memory_allocated()
        peak_bytes = torch.cuda.max_memory_allocated()
        self.end = byte2gb(end_bytes)
        self.peak = byte2gb(peak_bytes)
        self.used = byte2gb(end_bytes - self._begin_bytes)
        self.peaked = byte2gb(peak_bytes - self._begin_bytes)
        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())

        cuda_info = torch.cuda.memory_stats()
        self.peak_active_gb = byte2gb(cuda_info.get("active_bytes.all.peak", 0))
        # Attribute name kept as-is (misspelled "retires") so existing readers
        # of this attribute keep working.
        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)

        # cpu_end and cpu_peak are in bytes; convert only the final deltas.
        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = byte2gb(self.cpu_end - self._cpu_begin_bytes)
        self.cpu_peaked = byte2gb(self.cpu_peak - self._cpu_begin_bytes)