Skip to content

Commit

Permalink
[Feat] 8-bit bool for JitBackend (tracel-ai#2526)
Browse files (browse the repository at this point in the history)
  • Loading branch information
wingertge authored Nov 29, 2024
1 parent 7cdd6b0 commit 42e7c1f
Show file tree
Hide file tree
Showing 60 changed files with 597 additions and 384 deletions.
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type IntElem = B::IntElem;

type BoolTensorPrimitive = B::BoolTensorPrimitive;
type BoolElem = B::BoolElem;

type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
type QuantizedEncoding = B::QuantizedEncoding;
Expand Down
10 changes: 5 additions & 5 deletions crates/burn-autodiff/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,18 @@ macro_rules! testgen_all {

pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolElem;

::paste::paste! {
$(mod [<$float _ty>] {
pub use super::*;

pub type TestBackend = TestBackend2<$float, IntType>;
pub type TestBackend = TestBackend2<$float, IntType, BoolType>;
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
pub type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, D>;
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, BoolType, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, BoolType, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, BoolType, D>;

type FloatType = $float;

Expand Down
1 change: 1 addition & 0 deletions crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
type IntElem = I;

type BoolTensorPrimitive = CandleTensor;
type BoolElem = u32;

type QuantizedTensorPrimitive = CandleQTensor;
type QuantizedEncoding = u8;
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ pub use cubecl::cuda::CudaDevice;
use cubecl::cuda::CudaRuntime;

#[cfg(not(feature = "fusion"))]
pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>;
pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I, u8>;

#[cfg(feature = "fusion")]
pub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<CudaRuntime, F, I>>;
pub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<CudaRuntime, F, I, u8>>;

#[cfg(test)]
mod tests {
Expand All @@ -19,5 +19,5 @@ mod tests {
pub type TestRuntime = cubecl::cuda::CudaRuntime;
pub use half::{bf16, f16};

burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64]);
burn_jit::testgen_all!([f16, bf16, f32], [i8, i16, i32, i64], [u8, u32]);
}
6 changes: 5 additions & 1 deletion crates/burn-fusion/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use burn_tensor::{
backend::{Backend, DeviceOps},
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle},
Device,
Device, Element,
};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
Expand Down Expand Up @@ -35,6 +35,8 @@ impl<B: FusionBackend> Backend for Fusion<B> {

type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;

type BoolElem = B::BoolElem;

type QuantizedTensorPrimitive = QFusionTensor<B::FusionRuntime>;

type QuantizedEncoding = B::QuantizedEncoding;
Expand Down Expand Up @@ -142,6 +144,8 @@ pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug {
type FusionDevice: DeviceOps;
/// The client to interact with the runtime.
type FusionClient: FusionClient<Self>;
/// The type that represents booleans on the backend.
type BoolRepr: Element;

/// The list of optimizations that will be used to optimize the computational graph.
fn optimizations(
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-hip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ use cubecl::hip::HipRuntime;

#[cfg(target_os = "linux")]
#[cfg(not(feature = "fusion"))]
pub type Hip<F = f32, I = i32> = JitBackend<HipRuntime, F, I>;
pub type Hip<F = f32, I = i32, B = u8> = JitBackend<HipRuntime, F, I, B>;

#[cfg(target_os = "linux")]
#[cfg(feature = "fusion")]
pub type Hip<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<HipRuntime, F, I>>;
pub type Hip<F = f32, I = i32, B = u8> = burn_fusion::Fusion<JitBackend<HipRuntime, F, I, B>>;

// TODO: Hang the computer when AMD isn't available.
//
Expand Down
24 changes: 18 additions & 6 deletions crates/burn-jit/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
element::BoolElement,
tensor::{JitTensor, QJitTensor},
FloatElement, IntElement, JitRuntime,
};
Expand All @@ -18,24 +19,27 @@ pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);

/// Generic tensor backend that can be compiled just-in-time to any shader runtime
#[derive(new)]
pub struct JitBackend<R: JitRuntime, F: FloatElement, I: IntElement> {
pub struct JitBackend<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {
_runtime: PhantomData<R>,
_float_elem: PhantomData<F>,
_int_elem: PhantomData<I>,
_bool_elem: PhantomData<BT>,
}

impl<R, F, I> Backend for JitBackend<R, F, I>
impl<R, F, I, BT> Backend for JitBackend<R, F, I, BT>
where
R: JitRuntime,
R::Server: ComputeServer,
R::Device: burn_tensor::backend::DeviceOps,
F: FloatElement,
I: IntElement,
BT: BoolElement,
{
type Device = R::Device;

type FloatElem = F;
type IntElem = I;
type BoolElem = BT;

type FloatTensorPrimitive = JitTensor<R>;
type IntTensorPrimitive = JitTensor<R>;
Expand Down Expand Up @@ -63,19 +67,25 @@ where
}
}

impl<R: JitRuntime, F: FloatElement, I: IntElement> core::fmt::Debug for JitBackend<R, F, I> {
impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug
for JitBackend<R, F, I, BT>
{
/// Formats the backend as `JitBackend { runtime: <name>}`, where `<name>`
/// is whatever `R::name()` returns.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    write!(f, "JitBackend {{ runtime: {}}}", R::name())
}
}

impl<R: JitRuntime, F: FloatElement, I: IntElement> Clone for JitBackend<R, F, I> {
impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone
for JitBackend<R, F, I, BT>
{
/// Returns a newly constructed backend value rather than copying `self`.
fn clone(&self) -> Self {
// Delegates to the `new` constructor; no state is read from `self`.
Self::new()
}
}

impl<R: JitRuntime, F: FloatElement, I: IntElement> Default for JitBackend<R, F, I> {
impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default
for JitBackend<R, F, I, BT>
{
/// Returns the default backend value, which is the same as calling `new`.
fn default() -> Self {
Self::new()
}
Expand All @@ -90,7 +100,9 @@ where
}

#[cfg(not(feature = "fusion"))]
impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> ReprBackend
for JitBackend<R, F, I, BT>
{
type Handle = HandleKind<Self>;

fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
Expand Down
24 changes: 24 additions & 0 deletions crates/burn-jit/src/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,27 @@ pub trait FloatElement: JitElement + Float {}
/// The int element type for the jit backend.
pub trait IntElement: JitElement + Int {}

/// Element type used to store boolean values on the jit backend.
///
/// Every constructor has a default implementation built on `from_int`,
/// so implementors can opt in with an empty `impl` block.
pub trait BoolElement: JitElement + Int {
    /// Value that represents `true`, built from the integer `1`.
    fn true_val() -> Self {
        Self::from_int(1)
    }

    /// Value that represents `false`, built from the integer `0`.
    fn false_val() -> Self {
        Self::from_int(0)
    }

    /// Converts a Rust `bool` into the backend's boolean element.
    fn new_bool(val: bool) -> Self {
        if val {
            Self::true_val()
        } else {
            Self::false_val()
        }
    }
}

impl JitElement for u64 {}
impl JitElement for u32 {}
impl JitElement for u16 {}
Expand All @@ -36,3 +57,6 @@ impl IntElement for i64 {}
impl IntElement for i32 {}
impl IntElement for i16 {}
impl IntElement for i8 {}

// Integer types that may be used as the boolean element; both rely on the
// trait's default `true_val` / `false_val` / `new_bool` implementations.
impl BoolElement for u8 {}
impl BoolElement for u32 {}
30 changes: 20 additions & 10 deletions crates/burn-jit/src/fusion/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState};
use crate::fusion::elemwise::builder::ElementWiseBuilder;
use crate::tensor::{JitQuantizationParameters, QJitTensor};
use crate::{element::BoolElement, fusion::elemwise::builder::ElementWiseBuilder};
use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime};
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
use burn_tensor::quantization::QuantizationScheme;
Expand Down Expand Up @@ -30,13 +30,14 @@ pub enum JitOptimizationState {
ElementWise(ElemwiseOptimizationState),
}

impl<R> burn_fusion::Optimization<FusionJitRuntime<R>> for JitOptimization<R>
impl<R, BT> burn_fusion::Optimization<FusionJitRuntime<R, BT>> for JitOptimization<R>
where
R: JitRuntime,
BT: BoolElement,
{
fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle<R>>) {
match self {
Self::ElementWise2(op) => op.execute(context),
Self::ElementWise2(op) => op.execute::<BT>(context),
}
}

Expand All @@ -61,7 +62,9 @@ where
}
}

impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> ReprBackend
for JitBackend<R, F, I, BT>
{
type Handle = JitFusionHandle<R>;

fn float_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::FloatTensor<Self> {
Expand Down Expand Up @@ -122,30 +125,37 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
}
}

impl<R: JitRuntime> FusionRuntime for FusionJitRuntime<R> {
impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
type OptimizationState = JitOptimizationState;
type Optimization = JitOptimization<R>;
type FusionHandle = JitFusionHandle<R>;
type FusionDevice = R::JitDevice;
type FusionClient = MutexFusionClient<Self>;
type BoolRepr = BT;

fn optimizations(
device: R::Device,
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
vec![Box::new(ElementWiseBuilder::<R>::new(device.clone()))]
vec![Box::new(ElementWiseBuilder::<R>::new(
device.clone(),
BT::as_elem().into(),
))]
}
}

/// Fusion runtime for JIT runtimes.
#[derive(Debug)]
pub struct FusionJitRuntime<R: JitRuntime> {
pub struct FusionJitRuntime<R: JitRuntime, BT: BoolElement> {
_b: PhantomData<R>,
_bool: PhantomData<BT>,
}

impl<R: JitRuntime, F: FloatElement, I: IntElement> FusionBackend for JitBackend<R, F, I> {
type FusionRuntime = FusionJitRuntime<R>;
impl<R: JitRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
for JitBackend<R, F, I, BT>
{
type FusionRuntime = FusionJitRuntime<R, BT>;

type FullPrecisionBackend = JitBackend<R, f32, i32>;
type FullPrecisionBackend = JitBackend<R, f32, i32, BT>;

fn cast_float(
tensor: burn_tensor::ops::FloatTensor<Self>,
Expand Down
9 changes: 6 additions & 3 deletions crates/burn-jit/src/fusion/elemwise/builder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use burn_fusion::OptimizationBuilder;

use crate::{
fusion::{on_write::builder::FuseOnWriteBuilder, JitOptimization},
fusion::{
on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision},
JitOptimization,
},
JitRuntime,
};

Expand All @@ -14,13 +17,13 @@ pub(crate) struct ElementWiseBuilder<R: JitRuntime> {
}

impl<R: JitRuntime> ElementWiseBuilder<R> {
pub fn new(device: R::Device) -> Self {
pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self {
let client = R::client(&device);
let props = client.properties();
let max_bindings = props.hardware_properties().max_bindings;

Self {
builder: FuseOnWriteBuilder::new(max_bindings),
builder: FuseOnWriteBuilder::new(max_bindings, bool_precision),
device,
}
}
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-jit/src/fusion/elemwise/optimization.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::fusion::on_write::kernel::fuse_on_write;
use crate::{fusion::on_write::kernel::fuse_on_write, BoolElement};
use crate::{fusion::JitFusionHandle, JitRuntime};
use burn_fusion::stream::Context;
use burn_tensor::repr::TensorDescription;
Expand Down Expand Up @@ -28,9 +28,9 @@ pub struct ElemwiseOptimizationState {

impl<R: JitRuntime> ElemwiseOptimization<R> {
/// Execute the optimization.
pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle<R>>) {
pub fn execute<BT: BoolElement>(&mut self, context: &mut Context<'_, JitFusionHandle<R>>) {
self.trace
.run::<R, Self>(&self.client, &self.device, context)
.run::<R, BT, Self>(&self.client, &self.device, context)
}

/// Number of element wise operations fused.
Expand Down
12 changes: 6 additions & 6 deletions crates/burn-jit/src/fusion/on_write/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{
ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, UnaryElemwiseArgs},
ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, UnaryElemwiseArgs},
trace::FuseOnWriteTrace,
trace_builder::FuseOnWriteTraceBuilder,
};
Expand Down Expand Up @@ -30,9 +30,9 @@ struct TryFuseBuilder {
}

impl TryFuseBuilder {
fn new(max_bindings: u32) -> Self {
fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self {
Self {
builder: FuseOnWriteTraceBuilder::new(),
builder: FuseOnWriteTraceBuilder::new(bool_precision),
max_bindings,
added_ops: false,
}
Expand Down Expand Up @@ -118,7 +118,7 @@ impl OptimizationBuilder<FuseOnWriteTrace> for FuseOnWriteBuilder {
fn reset(&mut self) {
self.num_ops = 0;
self.status = OptimizationStatus::Open;
self.builder = TryFuseBuilder::new(self.max_bindings);
self.builder = TryFuseBuilder::new(self.max_bindings, self.builder.builder.bool_precision);
self.current_output_shape.clear();
}

Expand All @@ -137,9 +137,9 @@ impl OptimizationBuilder<FuseOnWriteTrace> for FuseOnWriteBuilder {
}

impl FuseOnWriteBuilder {
pub fn new(max_bindings: u32) -> Self {
pub fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self {
Self {
builder: TryFuseBuilder::new(max_bindings),
builder: TryFuseBuilder::new(max_bindings, bool_precision),
num_ops: 0,
max_bindings,
current_output_shape: Vec::new(),
Expand Down
Loading

0 comments on commit 42e7c1f

Please sign in to comment.