Skip to content

Commit

Permalink
Add some metal gemm benchark. (huggingface#2471)
Browse files Browse the repository at this point in the history
* Add some metal gemm benchark.

* More benchmarks.
  • Loading branch information
LaurentMazare authored Sep 11, 2024
1 parent afb6575 commit 0cb0bd1
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
2 changes: 2 additions & 0 deletions candle-metal-kernels/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ thiserror = "1"
tracing = "0.1.37"

[dev-dependencies]
clap = { version = "4.2.4", features = ["derive"] }
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
anyhow = "1"
rand = "0.8.5"
rand_distr = "0.4.3"
136 changes: 136 additions & 0 deletions candle-metal-kernels/examples/metal_benchmarks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use anyhow::Result;
use candle_metal_kernels::GemmDType;
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
use clap::{Parser, Subcommand};
use half::f16;

fn run_gemm(f32: bool, n: usize) -> Result<()> {
const WARMUP_ITERS: usize = 2;
const MIN_DUR: f64 = 4.;

let device = metal::Device::system_default().unwrap();

let (b, m, n, k) = (1, n, n, n);
let kernels = candle_metal_kernels::Kernels::new();
let command_queue = device.new_command_queue();
let options = metal::MTLResourceOptions::StorageModeManaged;

let (lhs, rhs) = if f32 {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&rhs) as u64,
options,
);
(lhs, rhs)
} else {
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&rhs) as u64,
options,
);
(lhs, rhs)
};
let (dtype, name, sizeof) = if f32 {
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
} else {
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
};
let output = device.new_buffer((b * m * n * sizeof) as u64, options);

for mlx in [false, true] {
let mut sum_dt = 0f64;
let mut iters = 0usize;
for idx in 0.. {
let command_buffer = command_queue.new_command_buffer();
let start_time = std::time::Instant::now();
if mlx {
candle_metal_kernels::call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
} else {
candle_metal_kernels::call_gemm(
&device,
command_buffer,
&kernels,
name,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
}
command_buffer.commit();
command_buffer.wait_until_completed();
let dt = start_time.elapsed().as_secs_f64();
if idx < WARMUP_ITERS {
continue;
}
sum_dt += dt;
iters += 1;
if sum_dt > MIN_DUR {
break;
}
}
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
let mlx = if mlx { "MLX" } else { "MFA" };
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
}

Ok(())
}

#[derive(Subcommand, Debug, Clone)]
enum Task {
Gemm,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// The benchmark to be run.
#[command(subcommand)]
task: Task,
}

fn main() -> Result<()> {
let args = Args::parse();
match args.task {
Task::Gemm => {
for f32 in [false, true] {
for n in [512, 1024, 2048, 4096] {
run_gemm(f32, n)?;
}
}
}
}
Ok(())
}

0 comments on commit 0cb0bd1

Please sign in to comment.