Migrate matmul autotune to macro and fix accelerated (#2584)
* Migrate matmul autotune to macro and fix accelerated by checking for CMMA availability first

* Set max anchor on batch
wingertge authored Dec 11, 2024
1 parent c5c6022 commit ebd7649
Showing 5 changed files with 112 additions and 199 deletions.
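
The "Set max anchor on batch" note refers to how the autotune key normalizes shapes: each of m, k and n is rounded up to its closest upper power of two so that similar problem sizes share a tuning result, while the batch product is additionally clamped to 256 (the `#[autotune(anchor(max = 256))]` attribute in key.rs below). The following is a minimal sketch of that anchoring behaviour, assuming only the semantics implied by the removed `anchor` helper and the tests in this diff; the helper name and `main` harness are illustrative, not the actual cubecl implementation.

```rust
/// Round `x` up to the closest power of two, optionally clamped to `max`.
/// Hypothetical helper mirroring the anchoring semantics implied by the
/// tests in key.rs (511 -> 512, 513 -> 1024, large batch -> 256).
fn anchor(x: usize, max: Option<usize>) -> usize {
    let anchored = x.next_power_of_two();
    match max {
        Some(max) => anchored.min(max),
        None => anchored,
    }
}

fn main() {
    assert_eq!(anchor(511, None), 512);
    assert_eq!(anchor(512, None), 512);
    assert_eq!(anchor(513, None), 1024);
    // "Set max anchor on batch": large batch products are clamped to 256.
    assert_eq!(anchor(128 * 512, Some(256)), 256);
}
```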
207 changes: 74 additions & 133 deletions crates/burn-jit/src/kernel/matmul/tune/base.rs
@@ -1,9 +1,10 @@
use core::marker::PhantomData;

use burn_tensor::{Element, ElementConversion};
use cubecl::{
ir::{Elem, FloatKind},
linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy},
tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner},
tune,
tune::{local_tuner, tune_with, LocalTuner},
Feature,
};

use crate::{
@@ -15,73 +16,45 @@ use crate::{
JitRuntime, JitTuneId,
};

use super::key::MatmulAutotuneKey;
use super::key::create_key;

/// Set of matmul implementations available for autotune
/// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
pub struct MatmulAutotuneOperationSet<R: JitRuntime, E: FloatElement> {
#[tune(
operations(matmul_tiling2d, matmul_accelerated, matmul_simple),
create_key = create_key::<R, E>,
should_run = should_run
)]
fn matmul_ops<R: JitRuntime, E: FloatElement>(
key: JitAutotuneKey,
lhs: JitTensor<R>,
rhs: JitTensor<R>,
out: JitTensor<R>,
_e: PhantomData<E>,
}
impl<R: JitRuntime, E: FloatElement> MatmulAutotuneOperationSet<R, E> {
fn new(lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>) -> Self {
Self {
key: JitAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape, E::dtype())),
lhs,
rhs,
out,
_e: PhantomData,
}
}
}
) {
let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
let lhs = random_like_uniform(lhs, random_bounds.0, random_bounds.1);
let rhs = random_like_uniform(rhs, random_bounds.0, random_bounds.1);

impl<R: JitRuntime, E: FloatElement> AutotuneOperationSet<JitAutotuneKey>
for MatmulAutotuneOperationSet<R, E>
{
fn key(&self) -> JitAutotuneKey {
self.key.clone()
}

fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1);
let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1);
let out = empty_device::<R, E>(out.client.clone(), out.device.clone(), out.shape.clone());

let out = empty_device::<R, E>(
self.out.client.clone(),
self.out.device.clone(),
self.out.shape.clone(),
);

vec![
Box::new(MatmulTiling2d::<R, E>::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(MatmulAccelerated::<R, E>::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(MatmulSimple::<R, E>::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
]
}
tune_with!(lhs, rhs, out)
}

fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
match fastest_index {
0 => Box::new(MatmulTiling2d::<R, E>::new(self.lhs, self.rhs, self.out)),
1 => Box::new(MatmulAccelerated::<R, E>::new(self.lhs, self.rhs, self.out)),
2 => Box::new(MatmulSimple::<R, E>::new(self.lhs, self.rhs, self.out)),
_ => panic!("Fastest index is out of bound"),
}
fn should_run<R: JitRuntime, E: FloatElement>(
op: &MatmulOps<R, E>,
_key: &JitAutotuneKey,
index: usize,
) -> bool {
match index {
// Accelerated
// TODO: Add way to query actual requirements from cubecl
1 => op.lhs.client.properties().feature_enabled(Feature::Cmma {
a: Elem::Float(FloatKind::F16),
b: Elem::Float(FloatKind::F16),
c: Elem::Float(FloatKind::F32),
m: 16,
k: 16,
n: 16,
}),
_ => true,
}
}
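
The `should_run` hook above is what "fix accelerated by checking for CMMA availability first" refers to: candidate 1 in the `operations(...)` list (`matmul_accelerated`) is only benchmarked when the runtime reports cooperative matrix (CMMA) support for an f16 × f16 → f32 operation on 16×16×16 tiles. Below is a hedged sketch of the same check as a standalone predicate; the call chain is copied from the diff, while the import path and generic signature are assumptions about cubecl's API rather than code from this commit.

```rust
use cubecl::{
    client::ComputeClient,
    ir::{Elem, FloatKind},
    Feature, Runtime,
};

/// True when the compute client advertises the tensor-core (CMMA) shape the
/// accelerated matmul kernel needs. `client` stands for the client reachable
/// from a tensor, e.g. `lhs.client` in the diff above.
fn cmma_available<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> bool {
    client.properties().feature_enabled(Feature::Cmma {
        a: Elem::Float(FloatKind::F16),
        b: Elem::Float(FloatKind::F16),
        c: Elem::Float(FloatKind::F32),
        m: 16,
        k: 16,
        n: 16,
    })
}
```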

@@ -100,82 +73,50 @@ pub fn matmul_autotune<R: JitRuntime, E: FloatElement + Element>(
TUNER.execute(
&JitTuneId::new::<R>(&lhs.device),
&client,
Box::new(MatmulAutotuneOperationSet::<R, E>::new(
lhs,
rhs,
output.clone(),
)),
Box::new(MatmulOps::<R, E>::new(lhs, rhs, output.clone())),
);

output
}

macro_rules! matmul_tune_ops {
($name:ident, $func:expr) => {
#[derive(new, Debug)]
pub(crate) struct $name<R: JitRuntime, E: FloatElement> {
lhs: JitTensor<R>,
rhs: JitTensor<R>,
out: JitTensor<R>,
_e: PhantomData<E>,
}

impl<R: JitRuntime, E: FloatElement> AutotuneOperation for $name<R, E> {
fn execute(self: Box<Self>) {
#[allow(clippy::redundant_closure_call)]
$func(self.lhs, self.rhs, self.out);
}

fn clone(&self) -> Box<dyn AutotuneOperation> {
Box::new(Self {
lhs: self.lhs.clone(),
rhs: self.rhs.clone(),
out: self.out.clone(),
_e: self._e,
})
}
}
};
fn matmul_accelerated<R: JitRuntime, E: FloatElement>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
out: JitTensor<R>,
) {
cubecl::linalg::matmul::launch_ref::<R, E>(
&Strategy::Accelerated,
&lhs.client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
);
}

// Probably the fastest in the general case.
matmul_tune_ops!(
MatmulAccelerated,
|lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
cubecl::linalg::matmul::launch_ref::<R, E>(
&Strategy::Accelerated,
&lhs.client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
);
}
);

// Probably the fastest when tensor cores are not available.
matmul_tune_ops!(
MatmulTiling2d,
|lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
cubecl::linalg::matmul::launch_ref::<R, E>(
&Strategy::Tiling2D(Tiling2dConfig::default()),
&lhs.client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
);
}
);
fn matmul_tiling2d<R: JitRuntime, E: FloatElement>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
out: JitTensor<R>,
) {
cubecl::linalg::matmul::launch_ref::<R, E>(
&Strategy::Tiling2D(Tiling2dConfig::default()),
&lhs.client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
);
}

// Probably the fastest for small matrices.
matmul_tune_ops!(
MatmulSimple,
|lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
cubecl::linalg::matmul::launch_ref::<R, E>(
&Strategy::Simple,
&lhs.client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
);
}
);
fn matmul_simple<R: JitRuntime, E: FloatElement>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
out: JitTensor<R>,
) {
cubecl::linalg::matmul::launch_ref::<R, E>(
&Strategy::Simple,
&lhs.client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
);
}
84 changes: 37 additions & 47 deletions crates/burn-jit/src/kernel/matmul/tune/key.rs
@@ -1,42 +1,28 @@
use crate::tune::anchor;
use crate::{tensor::JitTensor, FloatElement, JitAutotuneKey, JitRuntime};
use burn_tensor::{DType, Shape};
use core::fmt::Debug;
use cubecl::AutotuneKey;
use serde::{Deserialize, Serialize};
use std::{cmp::max, fmt::Display, hash::Hash};
use std::{cmp::max, hash::Hash};

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct MatmulAutotuneKey {
round: bool, // True when all matmul dims are multiples of 64
broadcast: bool, // True when there are differences in batch size
anchored_m: usize,
anchored_k: usize,
anchored_n: usize,
anchored_batch: usize,
#[autotune(anchor)]
m: usize,
#[autotune(anchor)]
k: usize,
#[autotune(anchor)]
n: usize,
#[autotune(anchor(max = 256))]
batch: usize,
dtype: DType,
}

impl Display for MatmulAutotuneKey {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
format!(
"Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?} dtype:{:?}",
self.round,
self.broadcast,
self.anchored_m,
self.anchored_k,
self.anchored_n,
self.anchored_batch,
self.dtype
)
.as_str(),
)
}
}

impl MatmulAutotuneKey {
/// Create a matmul autotune key from the input shapes
pub fn new(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self {
let ndims = lhs_shape.num_dims();
let m = lhs_shape.dims[ndims - 2];
let k = lhs_shape.dims[ndims - 1];
@@ -57,18 +43,22 @@ impl MatmulAutotuneKey {

let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0;

Self {
round,
broadcast,
anchored_m: anchor(m, None),
anchored_k: anchor(k, None),
anchored_n: anchor(n, None),
anchored_batch: anchor(batch_product, Some(256)),
dtype,
}
Self::new(round, broadcast, m, k, n, batch_product, dtype)
}
}

pub(crate) fn create_key<R: JitRuntime, E: FloatElement>(
lhs: &JitTensor<R>,
rhs: &JitTensor<R>,
_out: &JitTensor<R>,
) -> JitAutotuneKey {
JitAutotuneKey::Matmul(MatmulAutotuneKey::from_shape(
&lhs.shape,
&rhs.shape,
E::dtype(),
))
}

#[cfg(test)]
mod tests {
use super::*;
@@ -77,35 +67,35 @@ mod tests {
fn matmul_autotune_key_all_same_and_round() {
let lhs_shape: Shape = [4, 512, 512].into();
let rhs_shape: Shape = [4, 512, 512].into();
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);

assert!(key.round);
assert!(!key.broadcast);
assert!(key.anchored_m == 512);
assert!(key.anchored_k == 512);
assert!(key.anchored_n == 512);
assert_eq!(key.m, 512);
assert_eq!(key.k, 512);
assert_eq!(key.n, 512);
}

#[test]
fn matmul_autotune_key_all_different() {
let lhs_shape: Shape = [2, 3, 511, 512].into();
let rhs_shape: Shape = [3, 2, 512, 513].into();
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);

assert!(!key.round);
assert!(key.broadcast);
assert!(key.anchored_m == 512);
assert!(key.anchored_k == 512);
assert!(key.anchored_n == 1024);
assert!(key.anchored_batch == 8);
assert_eq!(key.m, 512);
assert_eq!(key.k, 512);
assert_eq!(key.n, 1024);
assert_eq!(key.batch, 8);
}

#[test]
fn matmul_autotune_key_large_batch() {
let lhs_shape: Shape = [128, 512, 511, 512].into();
let rhs_shape: Shape = [200, 400, 512, 513].into();
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape, DType::F32);
let key = MatmulAutotuneKey::from_shape(&lhs_shape, &rhs_shape, DType::F32);

assert!(key.anchored_batch == 256);
assert_eq!(key.batch, 256);
}
}
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/matmul/tune/mod.rs
@@ -3,5 +3,5 @@ mod base;
mod key;

#[cfg(feature = "autotune")]
pub use base::*;
pub use base::matmul_autotune;
pub use key::*;
2 changes: 0 additions & 2 deletions crates/burn-jit/src/lib.rs
@@ -14,8 +14,6 @@ pub mod kernel;
/// Tensor module.
pub mod tensor;

pub(crate) mod tune;

/// Elements for JIT backend
pub mod element;

