Skip to content

Commit

Permalink
feat: change figure style
Browse files Browse the repository at this point in the history
  • Loading branch information
lausannel committed Sep 20, 2024
1 parent 4e7b994 commit 9853d86
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions tools/plot_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ def parse_args() -> argparse.Namespace:
The script will analyze the buffer assignment file and generate a plot named "tensor_lifecycle.png".
''')
parser.add_argument('--input',
type=str,
help='The file to be parsed',
required=True)
parser.add_argument(
'--input', type=str, help='The file to be parsed', required=True)
args = parser.parse_args()
args.input = args.input.strip()

Expand Down Expand Up @@ -149,8 +147,8 @@ def get_allocation_group_by_size(allocations: dict) -> dict:
'''
allocations_group_by_size = {}
index = 0
sorted_allocations = sorted(allocations.items(),
key=lambda item: int(item[0]))
sorted_allocations = sorted(
allocations.items(), key=lambda item: int(item[0]))

prev_size = None

Expand All @@ -172,6 +170,7 @@ def get_allocation_group_by_size(allocations: dict) -> dict:

return allocations_group_by_size


def plot_tensor_lifecycle(tensors: dict, allocations: dict):
'''
Plots the lifecycle of tensors and buffer usage over time.
Expand Down Expand Up @@ -211,8 +210,8 @@ def plot_tensor_lifecycle(tensors: dict, allocations: dict):
continue

allocations_group_by_size = get_allocation_group_by_size(allocations)
for index, allocation_info in sorted(allocations_group_by_size.items(),
key=lambda x: x[0]):
for index, allocation_info in sorted(
allocations_group_by_size.items(), key=lambda x: x[0]):
size = allocation_info['size']
start = allocation_info['start']
count = allocation_info['count']
Expand Down Expand Up @@ -264,14 +263,16 @@ def plot_tensor_lifecycle(tensors: dict, allocations: dict):
ax.set_yticklabels([f'{y /1024 / 1024 / 1024 :.2f} GB' for y in y_ticks])

ax.axhline(y=max_buffer_size, color='red', linestyle='--', alpha=0.7)
ax.text(1.01,
max_buffer_size,
f'Allocated Size: {max_buffer_size/1024/1024/1024:.2f} GB',
transform=ax.get_yaxis_transform(),
ha='left',
va='center',
color='red',
fontsize=8)

ax.annotate(
f'Allocated Size: {max_buffer_size/1024/1024/1024:.2f} GB',
xy=(max_time, max_buffer_size),
xytext=(max_time, max_buffer_size + max_buffer_size * 0.1),
arrowprops=dict(facecolor='red', shrink=0.05),
fontsize=8,
ha='center',
va='bottom',
)

ax.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
Expand Down

0 comments on commit 9853d86

Please sign in to comment.