diff --git a/benchmark/python/benchmark_e2e.py b/benchmark/python/benchmark_e2e.py
index f5c3ace8c..8cd99f20d 100644
--- a/benchmark/python/benchmark_e2e.py
+++ b/benchmark/python/benchmark_e2e.py
@@ -44,7 +44,7 @@ def monitor_gpu_memory():
 
         memory_usage = result.stdout.splitlines()
 
-        if len(memory_usage) > 1:
+        if len(memory_usage) >= 1:
             gpu_memory = [float(line) for line in memory_usage]
             current_peak = round(max(gpu_memory) / 1024, 2)
             with peak_memory_lock:
@@ -137,7 +137,7 @@ def save_results(args, results, filename, print_memory_usage=False):
         if IS_NVIDIA_SYSTEM:
             columns.append("peak_gpu_memory (GiB)")
         else:
-            columns.append("peak_cpu_memory(GiB)")
+            columns.append("peak_cpu_memory (GiB)")
 
     df = pd.DataFrame(
         results,
@@ -165,6 +165,12 @@ def save_results(args, results, filename, print_memory_usage=False):
         record.metrics.customized["sampling_latency_ms"] = row["Sampling Latency (ms)"]
         record.metrics.customized["wall_clock_throughput_tps"] = row["Wall Clock Throughput (tps)"]
         record.metrics.customized["wall_clock_time_s"] = row["Wall Clock Time (s)"]
+
+        if print_memory_usage:
+            if IS_NVIDIA_SYSTEM:
+                record.metrics.customized["peak_gpu_memory_gb"] = row["peak_gpu_memory (GiB)"]
+            else:
+                record.metrics.customized["peak_cpu_memory_gb"] = row["peak_cpu_memory (GiB)"]
 
         records.append(record)
 
@@ -178,6 +184,13 @@ def run_benchmark_memory(args, batch_size, prompt_length, generation_length, max
     This function is to run benchmark and print the momory usage
     """
     global stop_monitoring
+    global peak_gpu_memory
+    global peak_cpu_memory
+
+    # Reset the peak memory variables and the monitoring flag
+    stop_monitoring = False
+    peak_gpu_memory = 0.0
+    peak_cpu_memory = 0.0
 
     if IS_NVIDIA_SYSTEM:
         monitor_thread = threading.Thread(target=monitor_gpu_memory)
@@ -226,8 +239,9 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
     if args.verbose: print("Running warmup runs...")
     for _ in tqdm(range(args.warmup)):
         generator = og.Generator(model, params)
-        generator.compute_logits()
-        generator.generate_next_token()
+        while not generator.is_done():
+            generator.compute_logits()
+            generator.generate_next_token()
         if args.print_model_output: print(tokenizer.decode(generator.get_sequence(0)))
         # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
         del generator
@@ -241,9 +255,6 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
     for _ in tqdm(range(num_repetitions)):
         wall_clock_start_time = time.time()
 
-        # Prepare run
-        generator = og.Generator(model, params)
-
         # Measure tokenization
         tokenize_start_time = time.perf_counter()
         tokens = tokenizer.encode_batch(prompt)
@@ -329,6 +340,12 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
         print(f"Average Wall Clock Time: {avg_wall_clock_time} s")
         print(f"Average Wall Clock Throughput: {avg_wall_clock_thrpt} tps")
 
+    if args.print_memory_usage:
+        if IS_NVIDIA_SYSTEM:
+            print(f"Peak GPU Memory Usage: {peak_gpu_memory} GiB ")
+        else:
+            print(f"Peak CPU Memory Usage: {peak_cpu_memory} GiB ")
+
     metrics = [
         batch_size,
         prompt_length,
@@ -359,7 +376,7 @@ def main(args):
             max_length = args.max_lengths[0] if len(args.max_lengths) == 1 else args.max_lengths[m]
         else:
             max_length = prompt_length + gen_length
-        print(f"Args: batch_size = {batch_size}, prompt_length = {prompt_length}, tokens = {gen_length}, max_length = {max_length}")
+        print(f"\nArgs: batch_size = {batch_size}, prompt_length = {prompt_length}, tokens = {gen_length}, max_length = {max_length}")
         if args.print_memory_usage:
             metrics = run_benchmark_memory(args, batch_size, prompt_length, gen_length, max_length)
         else:
@@ -370,10 +387,6 @@ def main(args):
         filename = args.output
 
     if args.print_memory_usage:
-        if IS_NVIDIA_SYSTEM:
-            print(f"-------------------* Peak GPU Memory Usage: {peak_gpu_memory} GiB *-------------------")
-        else:
-            print(f"-------------------* Peak CPU Memory Usage: {peak_cpu_memory} GiB *-------------------")
         save_results(args, all_csv_metrics, filename, print_memory_usage=True)
     else:
         save_results(args, all_csv_metrics, filename)