From c694a41dba248df7f0d7ff121ecc6cca94f6bbce Mon Sep 17 00:00:00 2001 From: luzhan Date: Fri, 20 Sep 2024 18:06:27 +0800 Subject: [PATCH] feat: add legend --- tools/plot_mem.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tools/plot_mem.py b/tools/plot_mem.py index 4a899e8..72dc4ed 100644 --- a/tools/plot_mem.py +++ b/tools/plot_mem.py @@ -229,7 +229,7 @@ def plot_tensor_lifecycle(tensors: dict, allocations: dict): ax.text( 1.02, start + size * count / 2, - f"Count: {count}\nSize: {size} bytes ({size/1024/1024/1024:.2f} GB)\nTotal: {size * count} bytes ({size * count/1024/1024/1024:.2f} GB)", + f"Buffer Count: {count}\nSize: {size} bytes ({size/1024/1024/1024:.2f} GB)\nTotal: {size * count} bytes ({size * count/1024/1024/1024:.2f} GB)", fontsize=6, ha='left', va='center', @@ -241,7 +241,23 @@ def plot_tensor_lifecycle(tensors: dict, allocations: dict): times = sorted(buffer_usage.keys()) usages = [buffer_usage[t] for t in times] - ax.plot(times, usages, color='Slateblue', linewidth=2, alpha=0.6) + ax.plot( + times, + usages, + color='Slateblue', + linewidth=2, + alpha=0.6, + label='Actual Total Tensor Mem Usage') + ax.plot([], [], + color='skyblue', + linewidth=2, + alpha=0.6, + label='Tensor Lifetime and Mem Usage') + ax.plot([], [], + color='red', + linewidth=2, + alpha=0.6, + label='Buffer Allocation') ax.annotate( f"Time: {peak_time}\nActual Usage: {peak_usage} bytes ({peak_usage/1024/1024/1024:.2f} GB)", @@ -256,7 +272,7 @@ def plot_tensor_lifecycle(tensors: dict, allocations: dict): ax.set_ylim(0, max(max_buffer_size, peak_usage * 1.4)) ax.set_xlabel('Time') ax.set_ylabel('Buffer Size (bytes)') - ax.set_title('Tensor Lifecycle and Buffer Usage') + ax.set_title('Tensor Lifecycle and Buffer Usage', pad=10) y_ticks = ax.get_yticks() ax.set_yticks(y_ticks) @@ -276,7 +292,7 @@ def plot_tensor_lifecycle(tensors: dict, allocations: dict): ax.grid(True, linestyle='--', alpha=0.7) plt.tight_layout() - + plt.legend(fontsize=6) plt.savefig('tensor_lifecycle.png', dpi=1200) print("[INFO] tensor_lifecycle.png saved successfully.")