Commit

Cleanup
nathanielsimard committed Dec 17, 2024
1 parent 9e392d6 commit 66aca7d
Showing 9 changed files with 117 additions and 79 deletions.
5 changes: 3 additions & 2 deletions backend-comparison/benches/matmul.rs
@@ -47,9 +47,10 @@ fn bench<B: Backend>(
token: Option<&str>,
) {
let benchmarks = [
// (3, 4096, 4096, 4096),
(2, 4096, 4096, 4096),
(32, 2048, 2048, 2048),
// (2, 4096, 4096, 512),
(256, 1024, 1024, 1024),
(1024, 256, 256, 256),
]
.into_iter()
.map(|(b, m, n, k)| {
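For orientation, here is a small self-contained sketch of how one (b, m, n, k) entry is commonly read for a batched matmul benchmark — lhs of shape [b, m, k] times rhs of shape [b, k, n] producing [b, m, n] — with a rough FLOP estimate. The shape convention below is an assumption for illustration, not taken from the benchmark code:

    // Hypothetical helper: describe one benchmark configuration.
    // Assumes lhs = [b, m, k], rhs = [b, k, n], out = [b, m, n].
    fn describe(b: u64, m: u64, n: u64, k: u64) {
        let flops = 2 * b * m * n * k; // one multiply-add counted as two ops
        println!("[{b}, {m}, {k}] x [{b}, {k}, {n}] -> [{b}, {m}, {n}], ~{flops} FLOPs");
    }

    fn main() {
        for (b, m, n, k) in [(2, 4096, 4096, 4096), (32, 2048, 2048, 2048)] {
            describe(b, m, n, k);
        }
    }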
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/tests/select.rs
@@ -17,7 +17,7 @@ mod tests {
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().select(0, indices);
let tensor_4 = tensor_2.matmul(tensor_3);
// panic!("Tensor 4 {}", tensor_4);

let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
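The test above exercises select on dimension 0. As a minimal illustration of the assumed semantics — gathering whole rows of a 2-D tensor by index — here is a plain-Rust stand-in, not Burn's implementation:

    // Naive sketch of `select(0, indices)` on a 2-D tensor: pick rows by index.
    fn select_rows(data: &[Vec<f32>], indices: &[usize]) -> Vec<Vec<f32>> {
        indices.iter().map(|&i| data[i].clone()).collect()
    }

    fn main() {
        let data: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let picked = select_rows(&data, &[2, 0]);
        let expected: Vec<Vec<f32>> = vec![vec![5.0, 6.0], vec![1.0, 2.0]];
        assert_eq!(picked, expected);
    }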
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/base.rs
@@ -32,7 +32,7 @@ pub enum JitOptimization<R: JitRuntime> {
pub enum JitOptimizationState {
/// Element wise state.
ElementWise(ElemwiseOptimizationState),
/// Matrix multiplication optimizatio state.
/// Matrix multiplication optimization state.
Matmul(MatmulOptimizationState),
}

95 changes: 47 additions & 48 deletions crates/burn-jit/src/fusion/matmul/args.rs
@@ -6,52 +6,8 @@ use crate::fusion::on_write::{
kernel::fuse_on_write,
};

pub struct FusedMatmulState {
inputs: *const GlobalArgs,
outputs: *mut GlobalArgs,
config: ElemwiseConfig,
lhs: Arg,
rhs: Arg,
out: Arg,
}

#[cube]
impl FusedMatmulState {
pub fn new(
inputs: &FusedMatmulInput,
outputs: &mut GlobalArgs,
#[comptime] config: &ElemwiseConfig,
) -> FusedMatmulState {
FusedMatmulState {
inputs: &inputs.global,
outputs,
config: comptime![config.clone()],
lhs: comptime![inputs.lhs],
rhs: comptime![inputs.rhs],
out: comptime![inputs.out],
}
}
}

#[derive(Clone)]
pub struct FusedMatmulStateExpand {
inputs: GlobalArgsExpand,
outputs: GlobalArgsExpand,
config: ElemwiseConfig,
lhs: Arg,
rhs: Arg,
out: Arg,
}

impl CubeType for FusedMatmulState {
type ExpandType = FusedMatmulStateExpand;
}

impl Init for FusedMatmulStateExpand {
fn init(self, _context: &mut CubeContext) -> Self {
self
}
}
pub struct FusedMatmulArgs;

#[derive(CubeLaunch)]
pub struct FusedMatmulInput {
@@ -66,9 +22,6 @@ pub struct FusedMatmulInput {
out: Arg,
}

#[derive(Clone)]
pub struct FusedMatmulArgs;

#[cube]
impl<EG: Numeric> MatmulArgs<EG> for FusedMatmulArgs {
type Output = GlobalArgs;
@@ -248,3 +201,49 @@ impl<EG: Numeric> MatmulArgs<EG> for FusedMatmulArgs {
}
}
}

pub struct FusedMatmulState {
inputs: *const GlobalArgs,
outputs: *mut GlobalArgs,
config: ElemwiseConfig,
lhs: Arg,
rhs: Arg,
out: Arg,
}
#[cube]
impl FusedMatmulState {
pub fn new(
inputs: &FusedMatmulInput,
outputs: &mut GlobalArgs,
#[comptime] config: &ElemwiseConfig,
) -> FusedMatmulState {
FusedMatmulState {
inputs: &inputs.global,
outputs,
config: comptime![config.clone()],
lhs: comptime![inputs.lhs],
rhs: comptime![inputs.rhs],
out: comptime![inputs.out],
}
}
}

#[derive(Clone)]
pub struct FusedMatmulStateExpand {
inputs: GlobalArgsExpand,
outputs: GlobalArgsExpand,
config: ElemwiseConfig,
lhs: Arg,
rhs: Arg,
out: Arg,
}

impl CubeType for FusedMatmulState {
type ExpandType = FusedMatmulStateExpand;
}

impl Init for FusedMatmulStateExpand {
fn init(self, _context: &mut CubeContext) -> Self {
self
}
}
32 changes: 21 additions & 11 deletions crates/burn-jit/src/fusion/matmul/optimization.rs
@@ -83,7 +83,7 @@ impl<R: JitRuntime> MatmulOptimization<R> {
fn execute_fused<BT: BoolElement>(
&mut self,
context: &mut Context<'_, JitFusionHandle<R>>,
) -> Result<(), MatmulLaunchError> {
) -> Result<(), FusedMatmulError> {
self.trace
.run::<R, BT, FusedMatmul>(&self.client, &self.device, context, &self.matmul)
}
@@ -141,16 +141,28 @@ pub struct FusedMatmul {
op: BinaryOperationDescription,
}

#[derive(Debug)]
pub enum FusedMatmulError {
LaunchError(MatmulLaunchError),
InvalidInput,
}

impl From<MatmulLaunchError> for FusedMatmulError {
fn from(value: MatmulLaunchError) -> Self {
Self::LaunchError(value)
}
}

impl<R: JitRuntime> TraceRunner<R> for FusedMatmul {
type Error = MatmulLaunchError;
type Error = FusedMatmulError;

fn run<'a>(
&'a self,
client: &'a ComputeClient<R::Server, R::Channel>,
inputs: GlobalArgsLaunch<'a, R>,
outputs: GlobalArgsLaunch<'a, R>,
config: &'a ElemwiseConfig,
) -> Result<(), MatmulLaunchError> {
) -> Result<(), FusedMatmulError> {
match self.out.precision() {
ElemwisePrecision::F32 => self.matmul_fused::<R, f32>(client, inputs, outputs, config),
ElemwisePrecision::F16 => self.matmul_fused::<R, f16>(client, inputs, outputs, config),
@@ -169,7 +181,7 @@ impl FusedMatmul {
inputs: GlobalArgsLaunch<'a, R>,
outputs: GlobalArgsLaunch<'a, R>,
config: &'a ElemwiseConfig,
) -> Result<(), MatmulLaunchError> {
) -> Result<(), FusedMatmulError> {
let lhs_shape = inputs.shape(&self.lhs);
let rhs_shape = inputs.shape(&self.rhs);

@@ -189,9 +201,7 @@
let (rhs_make_contiguous, rhs_transposed) = check_layout(rhs_strides);

if lhs_make_contiguous || rhs_make_contiguous {
return Err(MatmulLaunchError::Unavailable(
MatmulAvailabilityError::PlaneDimUnknown,
));
return Err(FusedMatmulError::InvalidInput);
}

let rank = lhs_shape.len();
@@ -209,9 +219,7 @@
};

if out_line_size == 1 && (lhs_line_size > 1 || rhs_line_size > 1) {
return Err(MatmulLaunchError::Unavailable(
MatmulAvailabilityError::PlaneDimUnknown,
));
return Err(FusedMatmulError::InvalidInput);
}

let problem = MatmulProblem {
@@ -261,7 +269,9 @@ impl FusedMatmul {
None => Err(MatmulLaunchError::Unavailable(
MatmulAvailabilityError::PlaneDimUnknown,
)),
}
}?;

Ok(())
}
}

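As a standalone sketch of the error-handling change above: the From impl is what lets a MatmulLaunchError propagate through `?` inside functions that now return FusedMatmulError. The enums below are local stand-ins mirroring the diff, not imports from Burn:

    #[derive(Debug, PartialEq)]
    enum MatmulLaunchError {
        Unavailable,
    }

    #[derive(Debug, PartialEq)]
    enum FusedMatmulError {
        LaunchError(MatmulLaunchError),
        InvalidInput,
    }

    impl From<MatmulLaunchError> for FusedMatmulError {
        fn from(value: MatmulLaunchError) -> Self {
            Self::LaunchError(value)
        }
    }

    // A fallible step written against the inner error type...
    fn launch(fail: bool) -> Result<(), MatmulLaunchError> {
        if fail { Err(MatmulLaunchError::Unavailable) } else { Ok(()) }
    }

    // ...propagates through `?` in a function returning the fused error.
    fn execute(fail: bool, bad_input: bool) -> Result<(), FusedMatmulError> {
        if bad_input {
            return Err(FusedMatmulError::InvalidInput);
        }
        launch(fail)?;
        Ok(())
    }

    fn main() {
        assert_eq!(
            execute(true, false),
            Err(FusedMatmulError::LaunchError(MatmulLaunchError::Unavailable))
        );
        assert_eq!(execute(false, true), Err(FusedMatmulError::InvalidInput));
        assert!(execute(false, false).is_ok());
    }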
48 changes: 42 additions & 6 deletions crates/burn-jit/src/fusion/on_write/ir.rs
@@ -115,22 +115,37 @@ pub struct GlobalArgs {
}

impl<'a, R: Runtime> GlobalArgsLaunch<'a, R> {
/// Get the shape of the given [argument](Arg).
///
/// # Panics
///
/// If the argument doesn't have a handle.
pub fn shape(&self, arg: &Arg) -> &[usize] {
match self.handle(arg) {
match self.resolve_arg(arg) {
TensorArg::Handle { handle, .. } => handle.shape,
TensorArg::Alias { .. } => panic!("Unsupported yet"),
}
}

/// Get the strides of the given [argument](Arg).
///
/// # Panics
///
/// If the argument doesn't have a handle.
pub fn strides(&self, arg: &Arg) -> &[usize] {
match self.handle(arg) {
match self.resolve_arg(arg) {
TensorArg::Handle { handle, .. } => handle.strides,
TensorArg::Alias { .. } => panic!("Unsupported yet"),
}
}

/// Get the line size of the given [argument](Arg).
///
/// # Panics
///
/// If the argument doesn't have a handle.
pub fn line_size(&self, arg: &Arg) -> u8 {
match self.handle(arg) {
match self.resolve_arg(arg) {
TensorArg::Handle {
vectorization_factor,
..
@@ -139,19 +154,40 @@ impl<'a, R: Runtime> GlobalArgsLaunch<'a, R> {
}
}

pub fn handle(&self, arg: &Arg) -> &TensorArg<'_, R> {
/// Resolve the [argument](Arg) to a [tensor argument](TensorArg).
///
/// # Panics
///
/// If the argument isn't a global input or output tensor.
pub fn resolve_arg(&self, arg: &Arg) -> &TensorArg<'_, R> {
match arg {
Arg::Input(pos, precision, _) => match precision {
ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize],
ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize],
ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize],
_ => panic!(),
ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize],
ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize],
ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize],
ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize],
ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize],
ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize],
ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize],
ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize],
ElemwisePrecision::Bool => panic!("Unsupported yet"),
},
Arg::Output(pos, precision, _) => match precision {
ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize],
ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize],
ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize],
_ => panic!(),
ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize],
ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize],
ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize],
ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize],
ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize],
ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize],
ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize],
ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize],
ElemwisePrecision::Bool => panic!("Unsupported yet"),
},
_ => panic!("Only input & output can have a shape"),
}
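A simplified, self-contained sketch of the dispatch pattern resolve_arg implements above — position plus precision selects the right typed container, and anything that is not a global input or output tensor panics. Precision, Arg, and Args below are stand-ins for ElemwisePrecision, Arg, and GlobalArgsLaunch, with strings in place of tensor handles:

    #[derive(Clone, Copy)]
    enum Precision { F32, F16 }

    enum Arg {
        Input(usize, Precision),
        Output(usize, Precision),
        Literal(u32),
    }

    struct Args {
        t_f32: Vec<String>,
        t_f16: Vec<String>,
    }

    impl Args {
        // Resolve an argument to its backing handle; panic for anything
        // that is not a global input or output tensor.
        fn resolve_arg(&self, arg: &Arg) -> &String {
            match arg {
                Arg::Input(pos, precision) | Arg::Output(pos, precision) => match precision {
                    Precision::F32 => &self.t_f32[*pos],
                    Precision::F16 => &self.t_f16[*pos],
                },
                _ => panic!("Only input & output can have a shape"),
            }
        }
    }

    fn main() {
        let args = Args {
            t_f32: vec!["lhs-handle".to_string()],
            t_f16: vec!["out-handle".to_string()],
        };
        assert_eq!(args.resolve_arg(&Arg::Input(0, Precision::F32)), "lhs-handle");
        assert_eq!(args.resolve_arg(&Arg::Output(0, Precision::F16)), "out-handle");
    }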
5 changes: 2 additions & 3 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs
@@ -1,8 +1,7 @@
use cubecl::prelude::Numeric;
use std::marker::PhantomData;

/// Matrix multiplication spec definiting each element types used in the computation as well as
/// how the arguments are passed to the kernel.
/// Implicit convolution spec defining the element types used in the computation.
pub trait ConvSpec: Send + Sync + Clone + 'static {
/// The plane size used by this kernel.
const PLANE_DIM: u32;
@@ -15,7 +14,7 @@ pub trait ConvSpec: Send + Sync + Clone + 'static {
type EA: Numeric;
}

/// Specification for a simple standard matmul using global tensor as inputs.
/// Specification for a single conv using global tensors as inputs.
#[derive(Clone)]
pub struct SingleConvSpec<const PLANE_DIM: u32, EG: Numeric, ES: Numeric, EA: Numeric> {
_eg: PhantomData<EG>,
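For context, a generic, self-contained sketch of the "spec" pattern this trait follows — a zero-sized type that bundles compile-time choices (plane size, element types) behind a single generic parameter. The trait and types below are illustrative stand-ins, not Burn's ConvSpec or SingleConvSpec:

    use std::marker::PhantomData;

    trait Spec: 'static {
        const PLANE_DIM: u32;
        type EG; // global element type (assumed meaning)
        type EA; // accumulator element type (assumed meaning)
    }

    // Zero-sized carrier: it only exists to select the associated items.
    struct SimpleSpec<const PLANE_DIM: u32, EG, EA>(PhantomData<(EG, EA)>);

    impl<const PLANE_DIM: u32, EG: 'static, EA: 'static> Spec for SimpleSpec<PLANE_DIM, EG, EA> {
        const PLANE_DIM: u32 = PLANE_DIM;
        type EG = EG;
        type EA = EA;
    }

    // A kernel-like function reads everything it needs from the spec.
    fn plane_dim<S: Spec>() -> u32 {
        S::PLANE_DIM
    }

    fn main() {
        assert_eq!(plane_dim::<SimpleSpec<32, f32, f32>>(), 32);
    }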
4 changes: 0 additions & 4 deletions crates/burn-tensor/src/lib.rs
@@ -67,10 +67,6 @@ mod cube_wgpu {
WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
WgpuDevice::Cpu => DeviceId::new(3, 0),
WgpuDevice::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
// For an existing device, use the 64 bit wgpu device ID as the burn DeviceID.
// We're only storing 32 bits, so wrap the the 64 bit value to 32 bits. This
// might collide - but a 1 in 4 billion chance seems ok given there's only a few
// devices in flight at any time.
WgpuDevice::Existing(id) => DeviceId::new(5, *id),
}
}
3 changes: 0 additions & 3 deletions crates/burn-tensor/src/repr/handle.rs
@@ -56,7 +56,6 @@ impl<H: Clone> HandleContainer<H> {

/// Register a handle for the given [tensor id](TensorId).
pub fn register_handle(&mut self, id: TensorId, handle: H) {
// println!("Register handle {id:?}");
self.handles.insert(id, Handle::Existing(handle));
}

@@ -68,7 +67,6 @@
/// Make sure the status corresponds to the operation you want to execute the handle on,
/// otherwise you might remove a tensor handle that will be required in the future.
pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {
// println!("Get handle {id:?}");
let (id, handle) = self
.handles
.remove_entry(id)
@@ -148,7 +146,6 @@ impl<H: Clone> HandleContainer<H> {
where
B: ReprBackend<Handle = H>,
{
// println!("Register float tensor {id:?}");
let handle = B::float_tensor_handle(tensor);
self.handles.insert(*id, Handle::Existing(handle));
}
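A simplified stand-in for the register/get life cycle these methods implement — a HashMap keyed by an integer id, with a String in place of the backend handle; the removal-on-read behavior mirrors what the get_handle docs above warn about:

    use std::collections::HashMap;

    fn main() {
        let mut handles: HashMap<u64, String> = HashMap::new();

        // register_handle: store a backend handle under a tensor id.
        handles.insert(42, "backend-buffer-0".to_string());

        // get_handle: the entry is taken out of the container, so using the
        // wrong status could drop a handle that is still needed later.
        let (id, handle) = handles.remove_entry(&42).expect("handle was registered");
        assert_eq!(id, 42);
        assert_eq!(handle, "backend-buffer-0");
        assert!(handles.get(&42).is_none());
    }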
