From 68a711b5c931d8b96b635c78f8b63f06a0590c14 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 29 Nov 2024 12:51:37 -0500 Subject: [PATCH 01/16] Remove q_shape to use TensorMetadata instead --- crates/burn-autodiff/src/ops/qtensor.rs | 4 ---- crates/burn-candle/src/ops/qtensor.rs | 4 ---- crates/burn-fusion/src/ops/qtensor.rs | 8 +------- crates/burn-jit/src/ops/qtensor.rs | 4 ---- crates/burn-ndarray/src/ops/qtensor.rs | 4 ---- crates/burn-router/src/ops/op_qfloat.rs | 4 ---- crates/burn-tch/src/ops/qtensor.rs | 6 +----- crates/burn-tensor/src/tensor/ops/qtensor.rs | 19 ++++--------------- 8 files changed, 6 insertions(+), 47 deletions(-) diff --git a/crates/burn-autodiff/src/ops/qtensor.rs b/crates/burn-autodiff/src/ops/qtensor.rs index ee3c8ab7f7..b55157f328 100644 --- a/crates/burn-autodiff/src/ops/qtensor.rs +++ b/crates/burn-autodiff/src/ops/qtensor.rs @@ -33,10 +33,6 @@ impl QTensorOps for Autodiff { todo!() } - fn q_shape(tensor: &QuantizedTensor) -> Shape { - B::q_shape(tensor) - } - fn q_device(tensor: &QuantizedTensor) -> Device { B::q_device(tensor) } diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs index 9d63f02917..f4c2e96f04 100644 --- a/crates/burn-candle/src/ops/qtensor.rs +++ b/crates/burn-candle/src/ops/qtensor.rs @@ -29,10 +29,6 @@ impl QTensorOps for Candle) -> Shape { - super::base::shape(&tensor.qtensor) - } - fn q_device(tensor: &QuantizedTensor) -> Device { super::base::device(&tensor.qtensor) } diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index c8b1c09610..7ff0a47d58 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -24,7 +24,7 @@ impl QTensorOps for Fusion { DType::QFloat(scheme) => { let client = get_client::(device); let tensor = B::q_from_data(data, device); - let shape = B::q_shape(&tensor); + let shape = burn_tensor::TensorMetadata::shape(&tensor); let handles = 
B::quantized_tensor_handle(tensor); let qparams = match scheme { @@ -211,12 +211,6 @@ impl QTensorOps for Fusion { out } - fn q_shape(tensor: &QuantizedTensor) -> Shape { - // Conflicting `dtype()` when both `Element` and `TensorMetadata` traits are in - // scope so we use the fully qualified syntax - burn_tensor::TensorMetadata::shape(tensor) - } - fn q_device(tensor: &QuantizedTensor) -> Device { tensor.qtensor.client.device().clone() } diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index 94b1a6f2ee..d9ef224ace 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -72,10 +72,6 @@ where kernel::quantization::dequantize::(tensor) } - fn q_shape(tensor: &QuantizedTensor) -> Shape { - tensor.qtensor.shape.clone() - } - fn q_device(tensor: &QuantizedTensor) -> Device { tensor.qtensor.device.clone() } diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index d610d39804..3542c43992 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -108,10 +108,6 @@ impl QTensorOps) -> Shape { - tensor.qtensor.shape() - } - fn q_device(_tensor: &QuantizedTensor) -> NdArrayDevice { NdArrayDevice::Cpu } diff --git a/crates/burn-router/src/ops/op_qfloat.rs b/crates/burn-router/src/ops/op_qfloat.rs index 1f4784aceb..2ac092505c 100644 --- a/crates/burn-router/src/ops/op_qfloat.rs +++ b/crates/burn-router/src/ops/op_qfloat.rs @@ -32,10 +32,6 @@ impl QTensorOps for BackendRouter { unimplemented!() } - fn q_shape(_tensor: &QuantizedTensor) -> Shape { - unimplemented!() - } - fn q_device(_tensor: &QuantizedTensor) -> Device { unimplemented!() } diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index ec644d066c..360ec628fe 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -132,10 +132,6 @@ impl QTensorOps for LibTorch { 
TchTensor::new(tensor.qtensor.tensor.dequantize().to_kind(E::KIND)) } - fn q_shape(tensor: &QuantizedTensor) -> Shape { - tensor.qtensor.shape() - } - fn q_device(tensor: &QuantizedTensor) -> LibTorchDevice { tensor.qtensor.tensor.device().into() } @@ -157,7 +153,7 @@ impl QTensorOps for LibTorch { } async fn q_into_data(tensor: QuantizedTensor) -> TensorData { - let shape = Self::q_shape(&tensor); + let shape = tensor.shape(); let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let strategy = tensor.strategy(); diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 4b9df13a49..781ed7c6eb 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -4,7 +4,7 @@ use core::{future::Future, ops::Range}; use crate::{ backend::Backend, quantization::{QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme}, - Device, Shape, TensorData, + Device, Shape, TensorData, TensorMetadata, }; use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor}; @@ -74,17 +74,6 @@ pub trait QTensorOps { /// Convert the tensor back to a higher precision data type. fn dequantize(tensor: QuantizedTensor) -> FloatTensor; - /// Gets the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn q_shape(tensor: &QuantizedTensor) -> Shape; - /// Gets the device of the tensor. /// /// # Arguments @@ -460,7 +449,7 @@ pub trait QTensorOps { /// /// The transposed tensor. fn q_transpose(tensor: QuantizedTensor) -> QuantizedTensor { - let ndims = Self::q_shape(&tensor).num_dims(); + let ndims = tensor.shape().num_dims(); Self::q_swap_dims(tensor, ndims - 2, ndims - 1) } @@ -1071,7 +1060,7 @@ pub trait QTensorOps { /// /// A tensor with the maximum element of `tensor`. 
fn q_max(tensor: QuantizedTensor) -> QuantizedTensor { - let shape = B::q_shape(&tensor); + let shape = tensor.shape(); let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); B::q_max_dim(tensor, 0) @@ -1123,7 +1112,7 @@ pub trait QTensorOps { /// /// A tensor with the minimum element of `tensor`. fn q_min(tensor: QuantizedTensor) -> QuantizedTensor { - let shape = B::q_shape(&tensor); + let shape = tensor.shape(); let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); B::q_min_dim(tensor, 0) From eb8edd795e926706a4afc1864cbce3f3034d5753 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 6 Dec 2024 13:25:46 -0500 Subject: [PATCH 02/16] Fix spirv bool type --- crates/burn-wgpu/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 0751ad9f41..deb6a8ebd8 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -22,7 +22,7 @@ pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler; #[cfg(feature = "spirv")] type Compiler = SpirV; #[cfg(feature = "spirv")] -type Byte = u8; +type Bool = u8; #[cfg(not(feature = "spirv"))] type Compiler = Wgsl; #[cfg(not(feature = "spirv"))] From d28e270dcc991613ee5c313e8b7e56f16d264a86 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 9 Dec 2024 09:09:04 -0500 Subject: [PATCH 03/16] Refactor burn-jit quantized tensor representation --- crates/burn-candle/src/tensor.rs | 4 - crates/burn-fusion/src/backend.rs | 15 +- crates/burn-fusion/src/client/base.rs | 14 +- crates/burn-fusion/src/client/mutex.rs | 54 +--- crates/burn-fusion/src/ops/qtensor.rs | 126 ++-------- crates/burn-fusion/src/server.rs | 54 +--- crates/burn-fusion/src/stream/context.rs | 14 +- crates/burn-fusion/src/tensor.rs | 135 ++-------- crates/burn-jit/src/backend.rs | 11 +- crates/burn-jit/src/fusion/base.rs | 35 +-- .../src/kernel/quantization/dequantize.rs | 236 ++++++++---------- 
.../burn-jit/src/kernel/quantization/mod.rs | 2 + .../src/kernel/quantization/qtensor.rs | 95 +++++++ .../src/kernel/quantization/quantize.rs | 205 ++++++++------- crates/burn-jit/src/ops/qtensor.rs | 83 +++--- crates/burn-jit/src/tensor/base.rs | 36 ++- crates/burn-jit/src/tensor/mod.rs | 2 - crates/burn-jit/src/tensor/qtensor.rs | 125 ---------- crates/burn-ndarray/src/backend.rs | 23 +- crates/burn-ndarray/src/ops/qtensor.rs | 4 +- crates/burn-ndarray/src/tensor.rs | 15 +- crates/burn-router/src/backend.rs | 6 +- crates/burn-tch/src/ops/qtensor.rs | 3 +- crates/burn-tch/src/tensor.rs | 35 +-- crates/burn-tensor/src/repr/backend.rs | 19 +- crates/burn-tensor/src/repr/handle.rs | 29 +-- crates/burn-tensor/src/repr/mod.rs | 2 - crates/burn-tensor/src/repr/operation.rs | 26 +- crates/burn-tensor/src/repr/quantization.rs | 25 -- crates/burn-tensor/src/tensor/data.rs | 19 +- .../src/tensor/quantization/primitive.rs | 8 +- .../src/tensor/quantization/scheme.rs | 18 ++ .../src/tests/quantization/ops/quantize.rs | 22 +- 33 files changed, 579 insertions(+), 921 deletions(-) create mode 100644 crates/burn-jit/src/kernel/quantization/qtensor.rs delete mode 100644 crates/burn-jit/src/tensor/qtensor.rs delete mode 100644 crates/burn-tensor/src/repr/quantization.rs diff --git a/crates/burn-candle/src/tensor.rs b/crates/burn-candle/src/tensor.rs index c4e7d740ca..ec71b3d47e 100644 --- a/crates/burn-candle/src/tensor.rs +++ b/crates/burn-candle/src/tensor.rs @@ -70,10 +70,6 @@ impl QTensorPrimitive for CandleQTensor { fn scheme(&self) -> &QuantizationScheme { &self.scheme } - - fn strategy(&self) -> QuantizationStrategy { - todo!() - } } impl TensorMetadata for CandleQTensor { diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index aa308ad9a7..d3f80a79dd 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -1,10 +1,8 @@ -use crate::{ - client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, 
QFusionTensor, -}; +use crate::{client::FusionClient, stream::Context, FusionClientLocator, FusionTensor}; use burn_tensor::{ backend::{Backend, DeviceOps}, ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, - repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle}, + repr::{OperationDescription, ReprBackend, TensorHandle}, Device, Element, }; use serde::{de::DeserializeOwned, Serialize}; @@ -37,7 +35,7 @@ impl Backend for Fusion { type BoolElem = B::BoolElem; - type QuantizedTensorPrimitive = QFusionTensor; + type QuantizedTensorPrimitive = FusionTensor; type QuantizedEncoding = B::QuantizedEncoding; @@ -184,10 +182,7 @@ impl ReprBackend for Fusion { handle.handle } - fn quantized_tensor( - _handles: QuantizedKind>, - _scheme: burn_tensor::quantization::QuantizationScheme, - ) -> QuantizedTensor { + fn quantized_tensor(_handle: TensorHandle) -> QuantizedTensor { todo!() // not as simple } @@ -203,7 +198,7 @@ impl ReprBackend for Fusion { tensor } - fn quantized_tensor_handle(_tensor: QuantizedTensor) -> QuantizedKind { + fn quantized_tensor_handle(_tensor: QuantizedTensor) -> Self::Handle { todo!() // not as simple } } diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 33508c40b7..7b802c537e 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -2,10 +2,10 @@ use std::future::Future; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, QFusionTensor, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, }; use burn_tensor::{ - repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId}, + repr::{OperationDescription, TensorDescription, TensorId}, DType, TensorData, }; @@ -61,8 +61,8 @@ where /// Read the values contained by a quantized tensor. 
fn read_tensor_quantized( &self, - tensor: QuantizedTensorDescription, - streams: Vec, + tensor: TensorDescription, + streams: StreamId, ) -> impl Future + Send + 'static where B: FusionBackend; @@ -108,10 +108,10 @@ where /// Change the client of the given quantized tensor. fn change_client_quantized( &self, - tensor: QuantizedTensorDescription, + tensor: TensorDescription, client: Self, - streams: Vec, - ) -> QFusionTensor + stream: StreamId, + ) -> FusionTensor where B: FusionBackend; /// Drop the tensor with the given [tensor id](TensorId). diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 503fb2d434..1301564fb8 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -1,11 +1,10 @@ use super::FusionClient; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionDevice, FusionHandle, FusionQuantizationParameters, FusionRuntime, - FusionServer, FusionTensor, QFusionTensor, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor, }; use burn_tensor::{ - repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId}, + repr::{OperationDescription, TensorDescription, TensorId}, DType, }; use spin::Mutex; @@ -115,13 +114,13 @@ where fn read_tensor_quantized( &self, - tensor: QuantizedTensorDescription, - streams: Vec, + tensor: TensorDescription, + stream: StreamId, ) -> impl Future + 'static where B: FusionBackend, { - self.server.lock().read_quantized::(tensor, streams) + self.server.lock().read_quantized::(tensor, stream) } fn change_client_float( @@ -190,55 +189,24 @@ where fn change_client_quantized( &self, - tensor: QuantizedTensorDescription, + tensor: TensorDescription, client: Self, - streams: Vec, - ) -> QFusionTensor + stream: StreamId, + ) -> FusionTensor where B: FusionBackend, { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); - for stream in 
streams { - server_current.drain_stream(stream); - } + server_current.drain_stream(stream); - let mut ids = + let id = server_current.change_server_quantized::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); - // NOTE: the expected order is known [qtensor, scale, ] - let offset = tensor.qparams.offset.map(|desc| { - FusionTensor::new( - ids.pop().unwrap(), - desc.shape, - desc.dtype, - client.clone(), - StreamId::current(), - ) - }); - let scale = FusionTensor::new( - ids.pop().unwrap(), - tensor.qparams.scale.shape, - tensor.qparams.scale.dtype, - client.clone(), - StreamId::current(), - ); - let qtensor = FusionTensor::new( - ids.pop().unwrap(), - tensor.tensor.shape, - tensor.tensor.dtype, - client, - StreamId::current(), - ); - - QFusionTensor { - qtensor, - scheme: tensor.scheme, - qparams: FusionQuantizationParameters { scale, offset }, - } + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } fn register_orphan(&self, id: &TensorId) { diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 7ff0a47d58..41bc7ccde6 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -2,11 +2,10 @@ use std::{marker::PhantomData, ops::Range}; use burn_tensor::{ ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType}, + quantization::{QuantizationParametersPrimitive, QuantizationScheme}, repr::{ DequantizeOperationDescription, FloatOperationDescription, HandleContainer, OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, - QuantizedKind, }, DType, Device, Element, Shape, TensorData, }; @@ -15,67 +14,24 @@ use crate::{ client::FusionClient, get_client, stream::{execution::Operation, StreamId}, - Fusion, FusionBackend, FusionQuantizationParameters, QFusionTensor, + Fusion, 
FusionBackend, }; impl QTensorOps for Fusion { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { match data.dtype { - DType::QFloat(scheme) => { + DType::QFloat(_scheme) => { + let dtype = data.dtype; let client = get_client::(device); let tensor = B::q_from_data(data, device); let shape = burn_tensor::TensorMetadata::shape(&tensor); - let handles = B::quantized_tensor_handle(tensor); - let qparams = match scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { - let offset = if let Some(offset) = handles.offset { - offset - } else { - panic!("Expected offset for quantized tensor."); - }; - FusionQuantizationParameters { - scale: client.register_tensor( - handles.scale, - vec![1], - StreamId::current(), - B::FloatElem::dtype(), - ), - offset: Some(client.register_tensor( - offset, - vec![1], - StreamId::current(), - B::IntElem::dtype(), - )), - } - } - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { - assert!( - handles.offset.is_none(), - "Offset should not be provided for symmetric quantization." 
- ); - FusionQuantizationParameters { - scale: client.register_tensor( - handles.scale, - vec![1], - StreamId::current(), - B::FloatElem::dtype(), - ), - offset: None, - } - } - }; - let qtensor = client.register_tensor( - handles.tensor, + client.register_tensor( + B::quantized_tensor_handle(tensor), shape.dims, StreamId::current(), - B::QuantizedEncoding::dtype(), - ); - QFusionTensor { - qtensor, - qparams, - scheme, - } + dtype, + ) } _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", @@ -108,27 +64,14 @@ impl QTensorOps for Fusion { let qparams = QuantizationParametersPrimitive { scale, offset }; let output = B::quantize(tensor, &self.desc.scheme, qparams); - let q_ids = if let Some(offset) = &self.desc.qparams.offset { - QuantizedKind { - tensor: self.desc.out.id, - scale: self.desc.qparams.scale.id, - offset: Some(offset.id), - } - } else { - QuantizedKind { - tensor: self.desc.out.id, - scale: self.desc.qparams.scale.id, - offset: None, - } - }; - handles.register_quantized_tensor::(&q_ids, output); + handles.register_quantized_tensor::(&self.desc.out.id, output); } } let shape: Vec = tensor.shape.clone(); let out = tensor .client - .tensor_uninitialized(shape, B::QuantizedEncoding::dtype()); + .tensor_uninitialized(shape, DType::QFloat(*scheme)); let streams = if let Some(offset) = &qparams.offset { vec![tensor.stream, qparams.scale.stream, offset.stream] @@ -155,11 +98,7 @@ impl QTensorOps for Fusion { QuantizeOp::::new(desc), ); - QFusionTensor { - qtensor: out, - scheme: *scheme, - qparams: qparams.into(), - } + out } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { @@ -171,36 +110,26 @@ impl QTensorOps for Fusion { impl Operation for DequantizeOp { fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_quantized_tensor::(&self.desc.qtensor); + let tensor = handles.get_quantized_tensor::(&self.desc.input); let output = B::dequantize(tensor); handles.register_float_tensor::(&self.desc.out.id, output); } } - 
let shape: Vec = tensor.qtensor.shape.clone(); + let stream = tensor.stream; + let shape: Vec = tensor.shape.clone(); let out = tensor - .qtensor .client .tensor_uninitialized(shape, B::FloatElem::dtype()); - let streams = if let Some(offset) = &tensor.qparams.offset { - vec![ - tensor.qtensor.stream, - tensor.qparams.scale.stream, - offset.stream, - ] - } else { - vec![tensor.qtensor.stream, tensor.qparams.scale.stream] - }; - let desc = DequantizeOperationDescription { - qtensor: tensor.into_description(), + input: tensor.into_description(), out: out.to_description_out(), }; out.client.register( - streams, + vec![stream], OperationDescription::Float( FloatElem::::dtype(), FloatOperationDescription::Dequantize(desc.clone()), @@ -212,33 +141,22 @@ impl QTensorOps for Fusion { } fn q_device(tensor: &QuantizedTensor) -> Device { - tensor.qtensor.client.device().clone() + tensor.client.device().clone() } fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { - // Quantization parameters are on the same device as the qtensor - let device_original: &B::Device = tensor.qtensor.client.device(); + let device_original: &B::Device = tensor.client.device(); let device_target: B::Device = device.clone(); if device_original == &device_target { return tensor; } - println!("q_to_device {:?} {:?}", device_original, device_target); + let id = tensor.stream; let client_target = get_client::(&device_target); - let client_original = tensor.qtensor.client.clone(); - - let ids = if let Some(offset) = &tensor.qparams.offset { - vec![ - tensor.qtensor.stream, - tensor.qparams.scale.stream, - offset.stream, - ] - } else { - vec![tensor.qtensor.stream, tensor.qparams.scale.stream] - }; + let client_original = tensor.client.clone(); - client_original.change_client_quantized::(tensor.into_description(), client_target, ids) + client_original.change_client_quantized::(tensor.into_description(), client_target, id) } fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> 
QuantizedTensor { @@ -246,7 +164,7 @@ impl QTensorOps for Fusion { } async fn q_into_data(tensor: QuantizedTensor) -> TensorData { - tensor.into_data::().await + tensor.q_into_data::().await } fn q_swap_dims( diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index bff39a1d2a..688645fc34 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -2,10 +2,7 @@ use crate::{ stream::{execution::Operation, MultiStream, StreamId}, FusionBackend, FusionRuntime, }; -use burn_tensor::repr::{ - HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription, - TensorDescription, TensorId, -}; +use burn_tensor::repr::{HandleContainer, OperationDescription, TensorDescription, TensorId}; use std::{future::Future, sync::Arc}; pub struct FusionServer { @@ -92,17 +89,15 @@ where pub fn read_quantized( &mut self, - tensor: QuantizedTensorDescription, - ids: Vec, + tensor: TensorDescription, + id: StreamId, ) -> impl Future + 'static where B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. 
- for id in ids { - self.drain_stream(id); - } + self.drain_stream(id); let tensor = self.handles.get_quantized_tensor::(&tensor); B::q_into_data(tensor) @@ -191,45 +186,22 @@ where pub fn change_server_quantized( &mut self, - desc: &QuantizedTensorDescription, + desc: &TensorDescription, device: &R::FusionDevice, server_device: &mut Self, - ) -> Vec> + ) -> Arc where B: FusionBackend, { let tensor = self.handles.get_quantized_tensor::(desc); let tensor = B::q_to_device(tensor, device); - if desc.qparams.offset.is_some() { - let tensor_id = server_device.create_empty_handle(); - let scale_id = server_device.create_empty_handle(); - let offset_id = server_device.create_empty_handle(); - - let q_ids = QuantizedKind { - tensor: *tensor_id, - scale: *scale_id, - offset: Some(*offset_id), - }; - server_device - .handles - .register_quantized_tensor::(&q_ids, tensor); - - vec![tensor_id, scale_id, offset_id] - } else { - let tensor_id = server_device.create_empty_handle(); - let scale_id = server_device.create_empty_handle(); - - let q_ids = QuantizedKind { - tensor: *tensor_id, - scale: *scale_id, - offset: None, - }; - server_device - .handles - .register_quantized_tensor::(&q_ids, tensor); - - vec![tensor_id, scale_id] - } + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_quantized_tensor::(&id, tensor); + + id } pub fn drop_tensor_handle(&mut self, id: TensorId) { diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 565f160362..7dcc81cb77 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -553,19 +553,7 @@ impl RelativeOpsScalar for FloatOperationDescription { } FloatOperationDescription::Dequantize(desc) => { FloatOperationDescription::Dequantize(DequantizeOperationDescription { - qtensor: QuantizedTensorDescription { - tensor: desc.qtensor.tensor.to_relative(converter), - qparams: QuantizationParametersDescription { - scale: 
desc.qtensor.qparams.scale.to_relative(converter), - offset: desc - .qtensor - .qparams - .offset - .as_ref() - .map(|x| x.to_relative(converter)), - }, - scheme: desc.qtensor.scheme, - }, + input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }) } diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 44d71abd4c..c6ea209e29 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -1,12 +1,7 @@ -use crate::{client::FusionClient, stream::StreamId, Client, Fusion, FusionBackend, FusionRuntime}; +use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime}; use burn_tensor::{ - quantization::{ - QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy, - }, - repr::{ - QuantizationParametersDescription, QuantizedTensorDescription, TensorDescription, TensorId, - TensorStatus, - }, + quantization::{QTensorPrimitive, QuantizationScheme}, + repr::{TensorDescription, TensorId, TensorStatus}, DType, Shape, TensorData, TensorMetadata, }; use std::sync::Arc; @@ -133,6 +128,22 @@ impl FusionTensor { .await } + pub(crate) async fn q_into_data(self) -> TensorData + where + B: FusionBackend, + { + if let DType::QFloat(_scheme) = self.dtype { + let id = self.stream; + self.client + .clone() + .read_tensor_quantized::(self.into_description(), id) + .await + // todo!() // doesn't work if we only have one tensordescription when we need the handles for the tensor + qparams + } else { + panic!("Expected quantized float dtype, got {:?}", self.dtype) + } + } + pub(crate) async fn int_into_data(self) -> TensorData where B: FusionBackend, @@ -172,109 +183,15 @@ impl Drop for FusionTensor { } } -/// A quantized tensor primitive for fusion backends. -#[derive(Debug)] -pub struct QFusionTensor { - /// The quantized tensor. - pub qtensor: FusionTensor, - /// The quantization scheme. 
- pub scheme: QuantizationScheme, - /// The quantization parameters. - pub qparams: FusionQuantizationParameters, -} - -impl QTensorPrimitive for QFusionTensor { +impl QTensorPrimitive for FusionTensor { fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { - // TODO - todo!() - } -} - -impl Clone for QFusionTensor { - fn clone(&self) -> Self { - Self { - qtensor: self.qtensor.clone(), - scheme: self.scheme, - qparams: self.qparams.clone(), - } - } -} - -impl TensorMetadata for QFusionTensor { - fn dtype(&self) -> DType { - DType::QFloat(self.scheme) - } - - fn shape(&self) -> Shape { - self.qtensor.shape() - } -} - -impl QFusionTensor { - pub(crate) async fn into_data(self) -> TensorData - where - B: FusionBackend, - { - let streams = if let Some(offset) = &self.qparams.offset { - vec![ - self.qtensor.stream, - self.qparams.scale.stream, - offset.stream, - ] + if let DType::QFloat(scheme) = &self.dtype { + scheme } else { - vec![self.qtensor.stream, self.qparams.scale.stream] - }; - - // Quantized tensor and qparams tensors client are the same - self.qtensor - .client - .clone() - .read_tensor_quantized::(self.into_description(), streams) - .await - } - - /// Description to be used when using an initialized tensor used as input. - pub(crate) fn into_description(self) -> QuantizedTensorDescription { - QuantizedTensorDescription { - tensor: self.qtensor.into_description(), - qparams: QuantizationParametersDescription { - scale: self.qparams.scale.into_description(), - offset: self.qparams.offset.map(|x| x.into_description()), - }, - scheme: self.scheme, - } - } -} - -/// The quantization parameters. -#[derive(Debug)] -pub struct FusionQuantizationParameters { - /// The scaling factor. - pub scale: FusionTensor, - /// The zero-point offset. 
- pub offset: Option>, -} - -impl Clone for FusionQuantizationParameters { - fn clone(&self) -> Self { - Self { - scale: self.scale.clone(), - offset: self.offset.clone(), - } - } -} - -impl From>> - for FusionQuantizationParameters -{ - fn from(value: QuantizationParametersPrimitive>) -> Self { - FusionQuantizationParameters { - scale: value.scale, - offset: value.offset, + panic!( + "Quantization scheme is not valid for dtype {:?}", + self.dtype, + ) } } } diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index b455d859a0..34cb8e0c4e 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -1,8 +1,4 @@ -use crate::{ - element::BoolElement, - tensor::{JitTensor, QJitTensor}, - FloatElement, IntElement, JitRuntime, -}; +use crate::{element::BoolElement, tensor::JitTensor, FloatElement, IntElement, JitRuntime}; use burn_tensor::backend::{Backend, DeviceOps}; use cubecl::server::ComputeServer; use rand::{rngs::StdRng, SeedableRng}; @@ -12,7 +8,7 @@ use std::{marker::PhantomData, sync::Mutex}; use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, quantization::QuantizationScheme, - repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle}, + repr::{HandleKind, ReprBackend, TensorHandle}, }; pub(crate) static SEED: Mutex> = Mutex::new(None); @@ -44,7 +40,7 @@ where type FloatTensorPrimitive = JitTensor; type IntTensorPrimitive = JitTensor; type BoolTensorPrimitive = JitTensor; - type QuantizedTensorPrimitive = QJitTensor; + type QuantizedTensorPrimitive = JitTensor; type QuantizedEncoding = u32; fn name() -> String { @@ -128,7 +124,6 @@ impl ReprBackend fn quantized_tensor( handles: QuantizedKind>, - _scheme: QuantizationScheme, ) -> QuantizedTensor { let handle = handles.tensor.handle; match handle { diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 4572f580b5..96c22d0898 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ 
b/crates/burn-jit/src/fusion/base.rs @@ -1,10 +1,8 @@ use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState}; -use crate::tensor::{JitQuantizationParameters, QJitTensor}; use crate::{element::BoolElement, fusion::elemwise::builder::ElementWiseBuilder}; use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; -use burn_tensor::quantization::QuantizationScheme; -use burn_tensor::repr::{QuantizedKind, TensorHandle}; +use burn_tensor::repr::TensorHandle; use burn_tensor::DType; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; @@ -80,23 +78,9 @@ impl ReprBackend } fn quantized_tensor( - handles: QuantizedKind>, - scheme: QuantizationScheme, + handle: TensorHandle, ) -> burn_tensor::ops::QuantizedTensor { - let qtensor = handles.tensor.handle.into_tensor(handles.tensor.shape); - let scale = handles.scale.handle.into_tensor(handles.scale.shape); - let offset = handles.offset; - - let qparams = JitQuantizationParameters { - scale, - offset: offset.map(|h| h.handle.into_tensor(h.shape)), - }; - - QJitTensor { - qtensor, - scheme, - qparams, - } + handle.handle.into_tensor(handle.shape) } fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor) -> Self::Handle { @@ -111,17 +95,8 @@ impl ReprBackend tensor.into() } - fn quantized_tensor_handle( - tensor: burn_tensor::ops::QuantizedTensor, - ) -> QuantizedKind { - let qtensor: JitFusionHandle = tensor.qtensor.into(); - let scale: JitFusionHandle = tensor.qparams.scale.into(); - - QuantizedKind { - tensor: qtensor, - scale, - offset: tensor.qparams.offset.map(|offset| offset.into()), - } + fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor) -> Self::Handle { + tensor.into() } } diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 4e2aa89cf7..b8b1764220 100644 --- 
a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -1,14 +1,21 @@ -use crate::tensor::{JitTensor, QJitTensor}; +use crate::tensor::JitTensor; use crate::FloatElement; -use crate::{IntElement, JitElement, JitRuntime}; +use crate::{JitElement, JitRuntime}; use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; +use burn_tensor::DType; use cubecl::calculate_cube_count_elemwise; use cubecl::prelude::*; +use super::{QParams, QTensor}; + #[cube] -pub(crate) fn dequantize_affine_int8(value: i32, scale: F, offset: i32) -> F { +pub(crate) fn dequantize_affine_int8( + value: Line, + scale: f32, + offset: i32, +) -> Line { // x = scale * (x_q - offset) - scale * (F::cast_from(value) - F::cast_from(offset)) + Line::cast_from(scale) * Line::cast_from(value - Line::cast_from(offset)) } #[cube] @@ -21,199 +28,154 @@ pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 { i32::cast_from(value) - sub } +#[cube] +pub(crate) fn extract_i8s(value: u32) -> Line { + let mut line = Line::empty(4); + // Extract each 8-bit segment + line[0] = extract_i8(value, 24); + line[1] = extract_i8(value, 16); + line[2] = extract_i8(value, 8); + line[3] = extract_i8(value, 0); + + line +} + #[cube(launch_unchecked)] pub(crate) fn dequantize_per_tensor_affine_int8_kernel( - input: &Tensor, - scale: &Tensor, - offset: &Tensor, - output: &mut Tensor, - #[comptime] vectorized: bool, + input: &QTensor, + output: &mut Tensor>, + #[comptime] scheme: QuantizationScheme, ) { - if ABSOLUTE_POS >= output.len() { + // Last two positions contain the qparams + if ABSOLUTE_POS >= output.len() - 2 { return; } - let scale = scale[0]; - let offset = offset[0]; + let qparams = QParams::new(scheme); + let scale = qparams.scale(input); + let offset = qparams.offset(input); - let num_packed = 4; let value = input[ABSOLUTE_POS]; - let output_pos = ABSOLUTE_POS * num_packed; - if vectorized { - let vectorization_factor = 
vectorization_of(input); + if comptime!(output.line_size() == 4) { + // For input line size = 1 + output[ABSOLUTE_POS] = dequantize_affine_int8(extract_i8s(value[0]), scale, offset); + } else { + // For very small inputs where number of elements < 4, the output line size is 1 #[unroll] - for i in 0..vectorization_factor { - // Extract each 8-bit segment - let v1 = extract_i8(value[i], 24); - let v2 = extract_i8(value[i], 16); - let v3 = extract_i8(value[i], 8); - let v4 = extract_i8(value[i], 0); - - output[output_pos * vectorization_factor + i * num_packed] = - dequantize_affine_int8::(v1, scale, offset); - output[output_pos * vectorization_factor + i * num_packed + 1] = - dequantize_affine_int8::(v2, scale, offset); - output[output_pos * vectorization_factor + i * num_packed + 2] = - dequantize_affine_int8::(v3, scale, offset); - output[output_pos * vectorization_factor + i * num_packed + 3] = - dequantize_affine_int8::(v4, scale, offset); + for i in 0..input.line_size() { + let out = dequantize_affine_int8::(extract_i8s(value[i]), scale, offset); + + #[unroll] + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); + } } - } else { - // Extract each 8-bit segment - let v1 = extract_i8(value, 24); - let v2 = extract_i8(value, 16); - let v3 = extract_i8(value, 8); - let v4 = extract_i8(value, 0); - - output[output_pos] = dequantize_affine_int8::(v1, scale, offset); - output[output_pos + 1] = dequantize_affine_int8::(v2, scale, offset); - output[output_pos + 2] = dequantize_affine_int8::(v3, scale, offset); - output[output_pos + 3] = dequantize_affine_int8::(v4, scale, offset); } } #[cube] -pub(crate) fn dequantize_symmetric_int8(value: i32, scale: F) -> F { +pub(crate) fn dequantize_symmetric_int8(value: Line, scale: f32) -> Line { // x = scale * x_q - scale * F::cast_from(value) + Line::cast_from(scale) * Line::cast_from(value) } // Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset. 
#[cube(launch_unchecked)] pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( - input: &Tensor, - scale: &Tensor, - output: &mut Tensor, - #[comptime] vectorized: bool, + input: &QTensor, + output: &mut Tensor>, + #[comptime] scheme: QuantizationScheme, ) { - if ABSOLUTE_POS >= output.len() { + // Last position contains the qparam + if ABSOLUTE_POS >= input.len() - 1 { return; } - let scale = scale[0]; + let qparams = QParams::new(scheme); + let scale = qparams.scale(input); - let num_packed = 4; let value = input[ABSOLUTE_POS]; - let output_pos = ABSOLUTE_POS * num_packed; - if vectorized { - let vectorization_factor = vectorization_of(input); - #[unroll] - for i in 0..vectorization_factor { - for j in 0..num_packed { - let output_idx = output_pos * vectorization_factor + i * num_packed + j; - if output_idx >= output.len() { - return; // value not quantized (padding) - } - // Extract each 8-bit segment - let v = extract_i8(value[i], (3 - j) * 8); - output[output_idx] = dequantize_symmetric_int8::(v, scale); - } - } + if comptime!(output.line_size() == 4) { + // For input line size = 1 + output[ABSOLUTE_POS] = dequantize_symmetric_int8(extract_i8s(value[0]), scale); } else { - // Extract each 8-bit segment - for j in 0..num_packed { - let output_idx = output_pos + j; - if output_idx >= output.len() { - return; // value not quantized (padding) + // For very small inputs where number of elements < 4, the output line size is 1 + #[unroll] + for i in 0..input.line_size() { + let out = dequantize_symmetric_int8::(extract_i8s(value[i]), scale); + + #[unroll] + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); } - // Extract each 8-bit segment - let v = extract_i8(value, (3 - j) * 8); - output[output_pos + j] = dequantize_symmetric_int8::(v, scale); } } } -pub(crate) fn dequantize_per_tensor( - tensor: JitTensor, - scale: JitTensor, - offset: Option>, -) -> JitTensor +pub(crate) fn dequantize_per_tensor(tensor: JitTensor) -> JitTensor 
where R: JitRuntime, F: JitElement, - I: IntElement, { // The actual number of elements is 1/4 (four int8 values packed in a single u32) - // so we choose a vectorization factor to match a valid input binding size. - let ndims = tensor.shape.num_dims(); + // so we choose a line size to match a valid input binding size. let num_out_elems = tensor.shape.num_elements(); let num_elems = usize::div_ceil(num_out_elems, 4); - let vectorization_factor = [4u8, 2, 1] - .iter() - .filter_map(|&v| { - if num_elems >= v as usize { - Some(v) - } else { - None - } - }) - .next() - .unwrap(); + let line_size_in = 1; + let line_size_out = if num_out_elems < 4 { 1 } else { 4 }; let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size_in as usize, cube_dim); - let shape_output = tensor.shape.clone(); let client = tensor.client.clone(); let handle = client.empty(num_out_elems * core::mem::size_of::()); + let output = JitTensor::new_contiguous( client.clone(), tensor.device.clone(), - shape_output, + tensor.shape.clone(), handle, F::dtype(), ); - let dummy_array = vec![1; ndims]; - if let Some(offset) = offset { - unsafe { - dequantize_per_tensor_affine_int8_kernel::launch_unchecked::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - // Ignore shape and stride - TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), - TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), - output.as_tensor_arg::(1), - vectorization_factor > 1, - ) - }; - } else { - unsafe { - dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - // Ignore shape and stride - TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), - output.as_tensor_arg::(1), - 
vectorization_factor > 1, - ) - }; + if let DType::QFloat(scheme) = tensor.dtype { + match scheme { + QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { + unsafe { + dequantize_per_tensor_affine_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_array_arg::(line_size_in), + output.as_tensor_arg::(line_size_out), + scheme, + ) + }; + } + QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + unsafe { + dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_array_arg::(line_size_in), + output.as_tensor_arg::(line_size_out), + scheme, + ) + }; + } + } } output } /// Convert the tensor back to a higher precision data type. -pub fn dequantize(tensor: QJitTensor) -> JitTensor +pub fn dequantize(tensor: JitTensor) -> JitTensor where R: JitRuntime, F: FloatElement, - I: IntElement, { - match tensor.scheme { - QuantizationScheme::PerTensorAffine(dtype) - | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => dequantize_per_tensor::( - tensor.qtensor, - tensor.qparams.scale, - tensor.qparams.offset, - ), - }, - } + dequantize_per_tensor::(tensor) } diff --git a/crates/burn-jit/src/kernel/quantization/mod.rs b/crates/burn-jit/src/kernel/quantization/mod.rs index a0244df01f..2b47bb61b6 100644 --- a/crates/burn-jit/src/kernel/quantization/mod.rs +++ b/crates/burn-jit/src/kernel/quantization/mod.rs @@ -1,5 +1,7 @@ mod dequantize; +mod qtensor; mod quantize; pub use dequantize::*; +pub use qtensor::*; pub use quantize::*; diff --git a/crates/burn-jit/src/kernel/quantization/qtensor.rs b/crates/burn-jit/src/kernel/quantization/qtensor.rs new file mode 100644 index 0000000000..a07a2eed52 --- /dev/null +++ b/crates/burn-jit/src/kernel/quantization/qtensor.rs @@ -0,0 +1,95 @@ +#![allow(missing_docs)] // cube derive macros + +use burn_tensor::quantization::QuantizationScheme; +use cubecl::prelude::*; + +/// Quantization 
parameters. +#[derive(CubeLaunch)] +pub struct QParams { + #[cube(comptime)] + scheme: QuantizationScheme, +} + +/// Quantized tensor representation. +pub type QTensor = Array>; + +#[cube] +impl QParams { + /// Create a new quantization parameters instance. + pub fn new(scheme: QuantizationScheme) -> Self { + QParams { scheme } + } + + // NOTE: a couple of incompatible things for this to work.. + // notably `switch_expand_expr` only works for CubePrimitive and it doesn't really make sense to implement that for a tuple + // or QParams type + // + // /// Get the quantization parameters: + // /// - Floating-point scaling factor (encoded as u32) + // /// - Zero-point offset + // pub fn qparams(&self) -> (f32, i32) { + // let len = self.tensor.buffer_len(); + // match comptime![self.scheme] { + // QuantizationScheme::PerTensorAffine(_) => match self.tensor.line_size() { + // // For line size of 1, scale is second to last in the buffer while the zero-point offset is the last element + // 1 => ( + // f32::cast_from(self.tensor[len - 2][0]), + // i32::cast_from(self.tensor[len - 1][0]), + // ), + // // QParams { + // // scale: f32::cast_from(self.tensor[len - 2][0]), + // // offset: i32::cast_from(self.tensor[len - 1][0]), + // // }, + // // For any other line size > 1, scale and zero-point offset are the first two elements of the last line + // _ => { + // let line = self.tensor[len - 1]; + // // QParams { + // // scale: f32::cast_from(line[0]), + // // offset: i32::cast_from(line[1]), + // // } + // (f32::cast_from(line[0]), i32::cast_from(line[1])) + // } + // }, + // // Symmetric quantization only contains the scaling factor, so we return 0 for the zero-point offset + // QuantizationScheme::PerTensorSymmetric(_) => { + // (f32::cast_from(self.tensor[len - 1][0]), 0) + // } // QParams { + // // scale: f32::cast_from(self.tensor[len - 1][0]), + // // offset: 0, + // // }, + // } + // } + + /// Get the floating-point scaling factor. 
+ pub fn scale(&self, tensor: &QTensor) -> f32 { + let len = tensor.len(); + match comptime!(self.scheme) { + QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { + // For line size of 1, scale is last in the buffer while the zero-point offset is the second-to-last element + 1 => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), + // For any other line size > 1, scale and zero-point offset are the first two elements of the last line + _ => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 2]), + }, + // Symmetric quantization only contains the scaling factor as the last element + QuantizationScheme::PerTensorSymmetric(_) => { + f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]) + } + } + } + + /// Get the zero-point offset. + pub fn offset(&self, tensor: &QTensor) -> i32 { + let len = tensor.len(); + let line_size = comptime!(tensor.line_size()); + match comptime!(self.scheme) { + QuantizationScheme::PerTensorAffine(_) => match line_size { + // For line size of 1, scale is last in the buffer while the zero-point offset is the second-to-last element + 1 => i32::cast_from(tensor[len - 2][line_size]), + // For any other line size > 1, scale and zero-point offset are the first two elements of the last line + _ => i32::cast_from(tensor[len - 1][line_size]), + }, + // Symmetric quantization only contains the scaling factor, so we return 0 for the zero-point offset + QuantizationScheme::PerTensorSymmetric(_) => 0, + } + } +} diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index 256ae0f418..13b9e3284a 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -1,4 +1,4 @@ -use crate::tensor::{JitQuantizationParameters, JitTensor, QJitTensor}; +use crate::tensor::JitTensor; use crate::FloatElement; use crate::{IntElement, JitElement, JitRuntime}; use 
burn_tensor::quantization::{QuantizationScheme, QuantizationType}; @@ -7,34 +7,31 @@ use cubecl::prelude::*; #[cube] pub(crate) fn quantize_affine_int8( - value: F, - scale: F, + value: Line, + scale: f32, offset: i32, - range_min: F, - range_max: F, -) -> u32 { - let offset = F::cast_from(offset); - + range_min: f32, + range_max: f32, +) -> Line { // x_q = clamp(round(x / scale + offset), a, b) // NOTE: we add 256 before casting to unsigned to correctly represent negative values - u32::cast_from( - i32::cast_from(F::clamp( - F::round((value / scale) + offset), - range_min, - range_max, - )) + 256, + Line::cast_from( + Line::clamp( + Line::round((value / Line::cast_from(scale)) + Line::cast_from(offset)), + Line::cast_from(range_min), + Line::cast_from(range_max), + ) + Line::cast_from(comptime!(256f32)), ) } #[cube(launch_unchecked)] pub(crate) fn quantize_per_tensor_affine_int8_kernel( - input: &Tensor, + input: &Tensor>, scale: &Tensor, offset: &Tensor, range_min: f32, range_max: f32, - output: &mut Tensor, - #[comptime] vectorized: bool, + output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { return; @@ -43,20 +40,29 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( let scale = scale[0]; let offset = offset[0]; - let num_packed = 4; - let mut v_packed = 0; + // Cast the scale to u32 and write the value in the output + if ABSOLUTE_POS == output.len() - 1 { + output[ABSOLUTE_POS] = u32::bitcast_from(scale); + return; + } - if vectorized { - // Assuming a vectorization factor of 4 (equal to the number of values packed) - let value = input[ABSOLUTE_POS]; - let vectorization_factor = vectorization_of(input); - #[unroll] - for i in 0..vectorization_factor { - let v = quantize_affine_int8::(value[i], scale, offset, range_min, range_max); - // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); - } + // Cast the offset to u32 and write the value in the output + if ABSOLUTE_POS == output.len() - 2 { + output[ABSOLUTE_POS] = 
u32::bitcast_from(offset); + return; + } + + let line_size = comptime!(input.line_size()); + if comptime!(line_size == 4) { + // Assuming a line size of 4 (equal to the number of values packed) + let value = + quantize_affine_int8::(input[ABSOLUTE_POS], scale, offset, range_min, range_max); + // Shift and combine into u32 + output[ABSOLUTE_POS] = pack_i8s_to_u32s(value); } else { + let mut v_packed = 0; + let num_packed = comptime!(4); + #[unroll] for i in 0..num_packed { let v = quantize_affine_int8::( input[ABSOLUTE_POS + i], @@ -66,34 +72,52 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( range_max, ); // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + v_packed |= (v[0] & 0xFF) << (8 * (num_packed - i - 1)); } + output[ABSOLUTE_POS] = v_packed; } - - output[ABSOLUTE_POS] = v_packed; } #[cube] pub(crate) fn quantize_symmetric_int8( - value: F, - scale: F, + value: Line, + scale: f32, range_min: F, range_max: F, -) -> u32 { +) -> Line { // x_q = clamp(round(x / scale), a, b) // NOTE: we add 256 before casting to unsigned to correctly represent negative values - u32::cast_from(i32::cast_from(F::clamp(F::round(value / scale), range_min, range_max)) + 256) + Line::cast_from( + Line::clamp( + Line::round(value / Line::cast_from(scale)), + Line::new(range_min), + Line::new(range_max), + ) + Line::cast_from(comptime!(256f32)), + ) +} + +#[cube] +pub(crate) fn pack_i8s_to_u32s(value: Line) -> u32 { + // NOTE: assuming line size of 4 + let line_size = value.size(); + let mut v_packed = 0; + + #[unroll] + for i in 0..line_size { + // Shift and combine into u32 + v_packed |= (value[i] & 0xFF) << (8 * (line_size - i - 1)); + } + v_packed } // Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset. 
#[cube(launch_unchecked)] pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( - input: &Tensor, + input: &Tensor>, scale: &Tensor, range_min: f32, range_max: f32, - output: &mut Tensor, - #[comptime] vectorized: bool, + output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { return; @@ -101,20 +125,23 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( let scale = scale[0]; - let num_packed = 4; - let mut v_packed = 0; + // Cast the scale to u32 and write the value in the output + if ABSOLUTE_POS == output.len() - 1 { + output[ABSOLUTE_POS] = u32::bitcast_from(scale); + return; + } - if vectorized { + let line_size = comptime!(input.line_size()); + if comptime!(line_size == 4) { // Assuming a vectorization factor of 4 (equal to the number of values packed) - let value = input[ABSOLUTE_POS]; - let vectorization_factor = vectorization_of(input); - #[unroll] - for i in 0..vectorization_factor { - let v = quantize_symmetric_int8::(value[i], scale, range_min, range_max); - // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); - } + let value = + quantize_symmetric_int8::(input[ABSOLUTE_POS], scale, range_min, range_max); + // Shift and combine into u32 + output[ABSOLUTE_POS] = pack_i8s_to_u32s(value); } else { + let num_packed = comptime!(4); + let mut v_packed = 0; + #[unroll] for i in 0..num_packed { let v = quantize_symmetric_int8::( input[ABSOLUTE_POS + i], @@ -123,17 +150,17 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( range_max, ); // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + v_packed |= (v[0] & 0xFF) << (8 * (num_packed - i - 1)); } + output[ABSOLUTE_POS] = v_packed; } - - output[ABSOLUTE_POS] = v_packed; } pub(crate) fn quantize_per_tensor( tensor: JitTensor, scale: JitTensor, offset: Option>, + scheme: QuantizationScheme, ) -> JitTensor where R: JitRuntime, @@ -142,86 +169,90 @@ where { let ndims = tensor.shape.num_dims(); let num_elems = 
tensor.shape.num_elements(); - let shape_output = tensor.shape.clone(); let client = tensor.client.clone(); // Output tensor contains 4x less elements (four int8 values packed in a single u32) - let handle = client.empty(usize::div_ceil(num_elems, 4) * core::mem::size_of::()); - let output = JitTensor::new_contiguous( - client.clone(), - tensor.device.clone(), - shape_output, - handle, - burn_tensor::DType::U32, - ); + let output_num_elems = usize::div_ceil(num_elems, 4) * core::mem::size_of::(); // Force vectorization to process 4 quantized values packed for 1 output value - let vectorization_factor: u8 = if num_elems < 4 { 1 } else { 4 }; + let line_size: u8 = if num_elems < 4 { 1 } else { 4 }; let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); let dummy_array = vec![1; ndims]; if let Some(offset) = offset { + // Scale and offset qparams are also packed in the tensor dat + let handle = client + .empty(output_num_elems + core::mem::size_of::() + core::mem::size_of::()); + let output = JitTensor::new_contiguous( + client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + handle, + burn_tensor::DType::QFloat(scheme), + ); + unsafe { quantize_per_tensor_affine_int8_kernel::launch_unchecked::( &client, cube_count, cube_dim, - tensor.as_tensor_arg::(vectorization_factor), + tensor.as_tensor_arg::(line_size), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(i8::MIN as f32), ScalarArg::new(i8::MAX as f32), - output.as_tensor_arg::(1), - vectorization_factor > 1, + output.as_array_arg::(1), ) }; + output } else { + // Scale qparam is also packed in the tensor data + let handle = client.empty(output_num_elems + core::mem::size_of::()); + let 
output = JitTensor::new_contiguous( + client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + handle, + burn_tensor::DType::QFloat(scheme), + ); + unsafe { quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( &client, cube_count, cube_dim, - tensor.as_tensor_arg::(vectorization_factor), + tensor.as_tensor_arg::(line_size), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(-i8::MAX as f32), ScalarArg::new(i8::MAX as f32), - output.as_tensor_arg::(1), - vectorization_factor > 1, + output.as_array_arg::(1), ) }; - } - output + output + } } /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. pub fn quantize( tensor: JitTensor, scheme: &QuantizationScheme, - qparams: JitQuantizationParameters, -) -> QJitTensor + scale: JitTensor, + offset: Option>, +) -> JitTensor where R: JitRuntime, F: FloatElement, I: IntElement, { - let qtensor = match scheme { + match scheme { QuantizationScheme::PerTensorAffine(dtype) | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => quantize_per_tensor::( - tensor, - qparams.scale.clone(), - qparams.offset.clone(), - ), + QuantizationType::QInt8 => { + quantize_per_tensor::(tensor, scale, offset, *scheme) + } }, - }; - - QJitTensor { - qtensor, - scheme: *scheme, - qparams, } } diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index d9ef224ace..cde44b0552 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -2,30 +2,32 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{ - QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType, - }, + quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType}, DType, Device, Shape, TensorData, }; use crate::{ - element::BoolElement, - 
kernel, - tensor::{JitQuantizationParameters, JitTensor, QJitTensor}, - FloatElement, IntElement, JitBackend, JitRuntime, + element::BoolElement, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, + JitRuntime, }; -use cubecl::CubeElement; /// Create a quantized tensor with packed values (u32). -fn packed_tensor>( +fn new_qtensor>( data: &[u8], shape: S, + scheme: QuantizationScheme, device: &R::Device, ) -> JitTensor { let client = R::client(device); let buffer = client.create(data); - JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer, DType::U32) + JitTensor::new_contiguous( + client, + device.clone(), + shape.into(), + buffer, + DType::QFloat(scheme), + ) } impl QTensorOps for JitBackend @@ -40,17 +42,9 @@ where DType::QFloat(scheme) => match scheme { QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { - // Convert quantized values to packed u32s - let qparams = data.get_q_params::().unwrap(); - QJitTensor { - qtensor: packed_tensor(data.values_as_bytes(), data.shape.clone(), device), - scheme, - qparams: JitQuantizationParameters::new( - qparams.scale, - qparams.offset, - device, - ), - } + // TensorData quantized representation is the same, with multiple quantized values + // packed into u32 and quantization parameters appended to the bytes + new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device) } }, _ => panic!( @@ -65,52 +59,37 @@ where scheme: &QuantizationScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { - kernel::quantization::quantize::(tensor, scheme, qparams.into()) + kernel::quantization::quantize::(tensor, scheme, qparams.scale, qparams.offset) } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { - kernel::quantization::dequantize::(tensor) + kernel::quantization::dequantize::(tensor) } fn q_device(tensor: &QuantizedTensor) -> Device { - tensor.qtensor.device.clone() + tensor.device.clone() } fn 
q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { - let mut tensor = tensor; - tensor.qtensor = super::to_device(tensor.qtensor, device); - tensor.qparams.scale = super::to_device(tensor.qparams.scale, device); - tensor.qparams.offset = tensor.qparams.offset.map(|x| super::to_device(x, device)); - - tensor + super::to_device(tensor, device) } fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { - QJitTensor { - qtensor: super::reshape(tensor.qtensor, shape), - scheme: tensor.scheme, - qparams: tensor.qparams, - } + super::reshape(tensor, shape) } async fn q_into_data(tensor: QuantizedTensor) -> TensorData { - let strategy = tensor.strategy(); - let qtensor = kernel::into_contiguous(tensor.qtensor); - - let bytes = qtensor - .client - .read_one_async(qtensor.handle.binding()) - .await; - - // TensorData keeps quantized values packed into 32-bit unsigned integers so we can - // keep the current representation, just cast the bytes as u32. - match &tensor.scheme { - QuantizationScheme::PerTensorAffine(dtype) - | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - TensorData::quantized(u32::from_bytes(&bytes).to_vec(), qtensor.shape, strategy) - } - }, + let tensor = kernel::into_contiguous(tensor); + let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; + + // TODO: this should be refactored such that the bytes type is opaque. + // With this, the logic for handling the bytes representation of quantized data + // (as well as all data manipulations) will be encapsulated in the type. + // Creating a TensorData struct directly from some bytes should probably not be possible outside of the crate. 
+ TensorData { + bytes, + shape: tensor.shape.into(), + dtype: tensor.dtype, } } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 3eb44b3e02..e114b2f8e6 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -1,6 +1,7 @@ use crate::element::JitElement; use crate::kernel::{launch_unary, unary_op, UnaryOp}; use crate::JitRuntime; +use burn_tensor::quantization::QTensorPrimitive; use burn_tensor::{DType, Shape, TensorMetadata}; use cubecl::client::ComputeClient; use cubecl::frontend::Numeric; @@ -73,6 +74,19 @@ impl TensorMetadata for JitTensor { } } +impl QTensorPrimitive for JitTensor { + fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme { + if let DType::QFloat(scheme) = &self.dtype { + scheme + } else { + panic!( + "Quantization scheme is not valid for dtype {:?}", + self.dtype, + ) + } + } +} + /// Macro to execute a kernel/operation for a given element type. /// /// # Panics @@ -241,7 +255,16 @@ where strides: &self.strides, shape: &self.shape.dims, runtime: PhantomData, - elem_size: self.dtype.size(), + elem_size: self.elem_size(), + } + } + + fn elem_size(&self) -> usize { + if let DType::QFloat(_) = self.dtype { + // Encoded as u32 + core::mem::size_of::() + } else { + self.dtype.size() } } @@ -259,6 +282,17 @@ where } } + /// Return the reference to an array argument. 
+ pub fn as_array_arg(&self, vectorisation: u8) -> ArrayArg<'_, R> { + unsafe { + ArrayArg::from_raw_parts::( + &self.handle, + self.handle.size() as usize / core::mem::size_of::(), + vectorisation, + ) + } + } + pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool { if !self.handle.can_mut() || !self.is_contiguous_buffer() { return false; diff --git a/crates/burn-jit/src/tensor/mod.rs b/crates/burn-jit/src/tensor/mod.rs index 960a77e445..cbcb6ac7e7 100644 --- a/crates/burn-jit/src/tensor/mod.rs +++ b/crates/burn-jit/src/tensor/mod.rs @@ -1,5 +1,3 @@ mod base; -mod qtensor; pub use base::*; -pub(crate) use qtensor::*; diff --git a/crates/burn-jit/src/tensor/qtensor.rs b/crates/burn-jit/src/tensor/qtensor.rs deleted file mode 100644 index 4ef5f77589..0000000000 --- a/crates/burn-jit/src/tensor/qtensor.rs +++ /dev/null @@ -1,125 +0,0 @@ -use burn_tensor::{ - quantization::{ - AffineQuantization, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationStrategy, QuantizationType, SymmetricQuantization, - }, - read_sync, DType, TensorData, TensorMetadata, -}; - -use crate::{ - element::BoolElement, ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime, -}; - -use super::JitTensor; - -/// A quantized tensor primitive. -#[derive(Debug)] -pub struct QJitTensor { - /// The quantized tensor. - /// Values are stored as multiple packed quantized values in u32. - pub qtensor: JitTensor, - /// The quantization scheme. - pub scheme: QuantizationScheme, - /// The quantization parameters. 
- pub qparams: JitQuantizationParameters, -} - -impl QTensorPrimitive for QJitTensor { - fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { - match &self.scheme { - QuantizationScheme::PerTensorAffine(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = read_sync(into_data::(self.qparams.scale.clone())) - .iter() - .next() - .unwrap(); - let offset = - read_sync(into_data::(self.qparams.offset.clone().unwrap())) - .iter() - .next() - .unwrap(); - QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( - scale, offset, - )) - } - }, - QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = read_sync(into_data::(self.qparams.scale.clone())) - .iter() - .next() - .unwrap(); - QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale)) - } - }, - } - } -} - -impl Clone for QJitTensor { - fn clone(&self) -> Self { - Self { - qtensor: self.qtensor.clone(), - scheme: self.scheme, - qparams: self.qparams.clone(), - } - } -} - -impl TensorMetadata for QJitTensor { - fn dtype(&self) -> DType { - DType::QFloat(self.scheme) - } - - fn shape(&self) -> burn_tensor::Shape { - self.qtensor.shape() - } -} - -/// The quantization parameters. -#[derive(Debug)] -pub struct JitQuantizationParameters { - /// The scaling factor. - pub scale: JitTensor, - /// The zero-point offset. 
- pub offset: Option>, -} - -impl Clone for JitQuantizationParameters { - fn clone(&self) -> Self { - Self { - scale: self.scale.clone(), - offset: self.offset.clone(), - } - } -} - -impl - From>> - for JitQuantizationParameters -{ - fn from(value: QuantizationParametersPrimitive>) -> Self { - JitQuantizationParameters { - scale: value.scale, - offset: value.offset, - } - } -} - -impl JitQuantizationParameters { - pub fn new( - scale: F, - offset: Option, - device: &R::Device, - ) -> Self { - Self { - scale: crate::ops::from_data::(TensorData::new(vec![scale], [1]), device), - offset: offset - .map(|o| crate::ops::from_data::(TensorData::new(vec![o], [1]), device)), - } - } -} diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 060899b979..6ce29e79cf 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -4,8 +4,7 @@ use alloc::string::String; use burn_common::stub::Mutex; use burn_tensor::backend::{Backend, DeviceId, DeviceOps}; use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; -use burn_tensor::quantization::QuantizationScheme; -use burn_tensor::repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle}; +use burn_tensor::repr::{HandleKind, ReprBackend, TensorHandle}; use core::marker::PhantomData; use rand::{rngs::StdRng, SeedableRng}; @@ -99,14 +98,10 @@ impl ReprBackend } } - fn quantized_tensor( - handles: QuantizedKind>, - _scheme: QuantizationScheme, - ) -> QuantizedTensor { - let handle = handles.tensor.handle; - match handle { + fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { + match handle.handle { HandleKind::Quantized(handle) => handle, - _ => panic!("Expected quantized handle, got {}", handle.name()), + _ => panic!("Expected quantized handle, got {}", handle.handle.name()), } } @@ -122,13 +117,7 @@ impl ReprBackend HandleKind::Bool(tensor) } - fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind { - QuantizedKind { - 
tensor: HandleKind::Quantized(tensor), - // The quantized tensor primitive already encapsulates the required quantization - // parameters so we set the scale as an empty handle (unused). - scale: HandleKind::Empty, - offset: None, - } + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { + HandleKind::Quantized(tensor) } } diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index 3542c43992..9f522b6155 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -3,8 +3,8 @@ use core::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - AffineQuantization, QParams, QTensorPrimitive, QuantizationParametersPrimitive, - QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, + AffineQuantization, QParams, QuantizationParametersPrimitive, QuantizationScheme, + QuantizationStrategy, QuantizationType, SymmetricQuantization, }, DType, Shape, TensorData, TensorMetadata, }; diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 64e8037c91..698119ecea 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -352,12 +352,9 @@ pub struct NdArrayQTensor { pub qparams: QParams, } -impl QTensorPrimitive for NdArrayQTensor { - fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { +impl NdArrayQTensor { + /// Returns the quantization strategy, including quantization parameters, for the given tensor. 
+ pub fn strategy(&self) -> QuantizationStrategy { match self.scheme { QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( @@ -374,6 +371,12 @@ impl QTensorPrimitive for NdArrayQTensor { } } +impl QTensorPrimitive for NdArrayQTensor { + fn scheme(&self) -> &QuantizationScheme { + &self.scheme + } +} + impl TensorMetadata for NdArrayQTensor { fn dtype(&self) -> DType { DType::QFloat(self.scheme) diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index a5ada5e5fd..d809bbb694 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -3,7 +3,7 @@ use core::marker::PhantomData; use burn_tensor::{ backend::Backend, - quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy}, + quantization::{QTensorPrimitive, QuantizationScheme}, }; use super::{get_client, set_seed, RouterTensor, RunnerChannel, RunnerClient}; @@ -36,10 +36,6 @@ impl QTensorPrimitive for RouterTensor { fn scheme(&self) -> &QuantizationScheme { todo!() } - - fn strategy(&self) -> QuantizationStrategy { - todo!() - } } impl Backend for BackendRouter { diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 360ec628fe..0a7ef1c79b 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -3,8 +3,7 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - QParams, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationType, + QParams, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType, }, DType, Shape, TensorData, TensorMetadata, }; diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index b634954c8e..1c9dbf594e 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -328,22 +328,9 @@ pub struct TchQTensor { pub scheme: 
QuantizationScheme, } -impl TensorMetadata for TchQTensor { - fn dtype(&self) -> DType { - DType::QFloat(self.scheme) - } - - fn shape(&self) -> Shape { - self.qtensor.shape() - } -} - -impl QTensorPrimitive for TchQTensor { - fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { +impl TchQTensor { + /// Returns the quantization strategy, including quantization parameters, for the given tensor. + pub fn strategy(&self) -> QuantizationStrategy { match &self.scheme { QuantizationScheme::PerTensorAffine(dtype) => match dtype { QuantizationType::QInt8 => { @@ -367,6 +354,22 @@ impl QTensorPrimitive for TchQTensor { } } +impl TensorMetadata for TchQTensor { + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) + } + + fn shape(&self) -> Shape { + self.qtensor.shape() + } +} + +impl QTensorPrimitive for TchQTensor { + fn scheme(&self) -> &QuantizationScheme { + &self.scheme + } +} + #[cfg(test)] mod tests { use crate::LibTorch; diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs index dad341697b..9d92b1e9ed 100644 --- a/crates/burn-tensor/src/repr/backend.rs +++ b/crates/burn-tensor/src/repr/backend.rs @@ -1,7 +1,6 @@ use crate::{ backend::Backend, ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, - quantization::QuantizationScheme, Shape, }; @@ -14,17 +13,6 @@ pub struct TensorHandle { pub shape: Shape, } -/// A simple struct to encapsulate a quantized tensor kind. -#[derive(Clone)] -pub struct QuantizedKind { - /// The quantized tensor. - pub tensor: T, - /// The scaling factor. - pub scale: T, - /// The zero-point offset. - pub offset: Option, -} - /// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation /// for compilation purpose or other... 
pub trait ReprBackend: Backend { @@ -38,10 +26,7 @@ pub trait ReprBackend: Backend { /// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). fn bool_tensor(handle: TensorHandle) -> BoolTensor; /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive). - fn quantized_tensor( - handle: QuantizedKind>, - scheme: QuantizationScheme, - ) -> QuantizedTensor; + fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor; /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle). fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle; @@ -51,7 +36,7 @@ pub trait ReprBackend: Backend { fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle). /// A quantized tensor has multiple handles for the tensor itself and the quantization parameters. - fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind; + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle; } /// Handle which points to a backend tensor primitive kind. diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 3da083c0e6..85e18ec444 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -14,7 +14,7 @@ use alloc::sync::Arc; #[cfg(not(target_has_atomic = "ptr"))] use portable_atomic_util::Arc; -use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle}; +use super::TensorHandle; /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources /// are used optimally. @@ -124,21 +124,12 @@ impl HandleContainer { /// given [tensor description](TensorDescription). 
pub fn get_quantized_tensor( &mut self, - tensor: &QuantizedTensorDescription, + tensor: &TensorDescription, ) -> B::QuantizedTensorPrimitive where B: ReprBackend, { - let handles = QuantizedKind { - tensor: self.get_tensor_handle(&tensor.tensor), - scale: self.get_tensor_handle(&tensor.qparams.scale), - offset: tensor - .qparams - .offset - .as_ref() - .map(|offset| self.get_tensor_handle(offset)), - }; - B::quantized_tensor(handles, tensor.scheme) + B::quantized_tensor(self.get_tensor_handle(tensor)) } /// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). @@ -153,21 +144,13 @@ impl HandleContainer { /// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId). pub fn register_quantized_tensor( &mut self, - id: &QuantizedKind, + id: &TensorId, tensor: B::QuantizedTensorPrimitive, ) where B: ReprBackend, { - let handles = B::quantized_tensor_handle(tensor); - - self.handles - .insert(id.tensor, Handle::Existing(handles.tensor)); - self.handles - .insert(id.scale, Handle::Existing(handles.scale)); - - if let (Some(id), Some(handle)) = (id.offset, handles.offset) { - self.handles.insert(id, Handle::Existing(handle)); - } + let handle = B::quantized_tensor_handle(tensor); + self.handles.insert(*id, Handle::Existing(handle)); } /// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). 
diff --git a/crates/burn-tensor/src/repr/mod.rs b/crates/burn-tensor/src/repr/mod.rs index e26565b353..b98e43ba3d 100644 --- a/crates/burn-tensor/src/repr/mod.rs +++ b/crates/burn-tensor/src/repr/mod.rs @@ -1,11 +1,9 @@ mod backend; mod handle; mod operation; -mod quantization; mod tensor; pub use backend::*; pub use handle::*; pub use operation::*; -pub use quantization::*; pub use tensor::*; diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 2528f44305..001b9d6e83 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -15,8 +15,6 @@ use crate::{ DType, Distribution, Element, }; -use super::{QuantizationParametersDescription, QuantizedTensorDescription}; - /// Custom operation in fusion stream, declaring it's inputs and outputs. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct CustomOpDescription { @@ -906,6 +904,15 @@ pub struct ConvTranspose3dOptionsDescription { pub groups: usize, } +/// Quantization parameters description. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct QuantizationParametersDescription { + /// The scaling factor. + pub scale: TensorDescription, + /// The zero-point offset. 
+ pub offset: Option, +} + #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct QuantizeOperationDescription { @@ -918,7 +925,7 @@ pub struct QuantizeOperationDescription { #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct DequantizeOperationDescription { - pub qtensor: QuantizedTensorDescription, + pub input: TensorDescription, pub out: TensorDescription, } @@ -1528,18 +1535,7 @@ impl FloatOperationDescription { vec![&desc.tensor, &desc.qparams.scale, &desc.out] } } - FloatOperationDescription::Dequantize(desc) => { - if let Some(offset) = &desc.qtensor.qparams.offset { - vec![ - &desc.qtensor.tensor, - &desc.qtensor.qparams.scale, - &offset, - &desc.out, - ] - } else { - vec![&desc.qtensor.tensor, &desc.qtensor.qparams.scale, &desc.out] - } - } + FloatOperationDescription::Dequantize(desc) => vec![&desc.input, &desc.out], } } } diff --git a/crates/burn-tensor/src/repr/quantization.rs b/crates/burn-tensor/src/repr/quantization.rs deleted file mode 100644 index 8a38a35a44..0000000000 --- a/crates/burn-tensor/src/repr/quantization.rs +++ /dev/null @@ -1,25 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::quantization::QuantizationScheme; - -use super::TensorDescription; - -/// A quantized tensor description represents a snapshot of a quantized tensor when it was used. -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct QuantizedTensorDescription { - /// The quantized tensor. - pub tensor: TensorDescription, - /// The quantization parameters. - pub qparams: QuantizationParametersDescription, - /// The quantization scheme - pub scheme: QuantizationScheme, -} - -/// Quantization parameters description. -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct QuantizationParametersDescription { - /// The scaling factor. - pub scale: TensorDescription, - /// The zero-point offset. 
- pub offset: Option, -} diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index d075d8e409..c80e7c1640 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -93,6 +93,8 @@ impl TensorData { shape: S, strategy: QuantizationStrategy, ) -> Self { + // TODO: this method should go into a dedicated Bytes opaque type with other bytes + // handling logic let mut value = into_bytes(value); // Notes on quantization data representation: @@ -111,8 +113,11 @@ impl TensorData { } else { panic!("Invalid quantized type"); } - let scale_bytes = bytemuck::bytes_of(&q.scale); - let offset_bytes = bytemuck::bytes_of(&q.offset); + // Scale is always stored as f32 and zero-point offset as i32 + let scale = q.scale as f32; + let offset = q.offset as i32; + let scale_bytes = bytemuck::bytes_of(&scale); + let offset_bytes = bytemuck::bytes_of(&offset); value.extend_from_slice(offset_bytes); value.extend_from_slice(scale_bytes); } @@ -446,7 +451,7 @@ impl TensorData { let mut tensor_bytes_end = self.bytes.len() - scale_size; if let QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) = scheme { - tensor_bytes_end -= core::mem::size_of::(); + tensor_bytes_end -= core::mem::size_of::(); // zero-point offset is stored as i32 } &self.bytes[..tensor_bytes_end] @@ -483,17 +488,17 @@ impl TensorData { // Quantization parameters are added at the end of the tensor data. // As such, the last bytes always correspond to the scale parameter. // If the quantization scheme includes an offset (zero-point) parameter, it is next to last. 
- let scale_size = core::mem::size_of::(); + let scale_size = core::mem::size_of::(); // scale is stored as f32 let scale_bytes = &self.bytes[total_bytes - scale_size..]; - let scale = read_unaligned(scale_bytes); + let scale = read_unaligned::(scale_bytes).elem(); let mut offset = None; if let QuantizationScheme::PerTensorAffine(_) = scheme { - let offset_size = core::mem::size_of::(); + let offset_size = core::mem::size_of::(); // zero-point offset is stored as i32 let offset_bytes = &self.bytes[total_bytes - scale_size - offset_size..total_bytes - scale_size]; - offset = Some(read_unaligned(offset_bytes)) + offset = Some(read_unaligned::(offset_bytes).elem()) } Some(QParams { scale, offset }) diff --git a/crates/burn-tensor/src/tensor/quantization/primitive.rs b/crates/burn-tensor/src/tensor/quantization/primitive.rs index acc5ff692d..a4f1fba1c8 100644 --- a/crates/burn-tensor/src/tensor/quantization/primitive.rs +++ b/crates/burn-tensor/src/tensor/quantization/primitive.rs @@ -1,13 +1,7 @@ -use super::{QuantizationScheme, QuantizationStrategy}; +use super::QuantizationScheme; /// Quantized tensor primitive. pub trait QTensorPrimitive { /// Returns the quantization scheme for the given tensor. fn scheme(&self) -> &QuantizationScheme; - /// Returns the quantization strategy for the given tensor. - /// - /// # Remarks - /// Retrieving the quantization strategy with its corresponding parameters might require - /// synchronization on the backend. 
- fn strategy(&self) -> QuantizationStrategy; } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index 3534a01c32..fb141ee16d 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -1,11 +1,17 @@ +#![allow(missing_docs)] // cube derive macros + use serde::{Deserialize, Serialize}; use crate::{backend::Backend, Tensor, TensorPrimitive}; use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive}; +#[cfg(feature = "cubecl")] +use cubecl::prelude::*; + /// Quantization data type. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))] pub enum QuantizationType { /// 8-bit signed integer. QInt8, @@ -13,6 +19,7 @@ pub enum QuantizationType { /// Quantization scheme. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))] pub enum QuantizationScheme { /// Per-tensor affine/asymmetric quantization. PerTensorAffine(QuantizationType), @@ -24,6 +31,17 @@ pub enum QuantizationScheme { // PerChannelSymmetric, } +#[cfg(feature = "cubecl")] +impl CubeType for QuantizationScheme { + type ExpandType = Self; +} +#[cfg(feature = "cubecl")] +impl cubecl::frontend::Init for QuantizationScheme { + fn init(self, _context: &mut CubeContext) -> Self { + self + } +} + impl QuantizationScheme { /// Compute the quantization parameters. 
pub fn compute_q_params( diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs index 72834b8c8a..a120c2802a 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs @@ -18,7 +18,7 @@ mod tests { offset: Some(Tensor::from_ints([72], &device)), }; - let x_q = tensor.quantize(&scheme, qparams); + let x_q = tensor.quantize(&scheme, qparams).into_data(); let expected = TensorData::quantized( vec![-128i8, -39, 72, 127], @@ -26,7 +26,14 @@ mod tests { QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72)), ); - x_q.to_data().assert_eq(&expected, true); + // Values equality + x_q.assert_eq(&expected, true); + + // Quantization parameters check + let qparams = x_q.get_q_params::().unwrap(); + let expected = expected.get_q_params::().unwrap(); + assert_eq!(qparams.scale, expected.scale); + assert_eq!(qparams.offset, expected.offset); } #[test] @@ -39,7 +46,7 @@ mod tests { offset: None, }; - let x_q = tensor.quantize(&scheme, qparams); + let x_q = tensor.quantize(&scheme, qparams).into_data(); let expected = TensorData::quantized( vec![-127i8, -71, 0, 35], @@ -49,7 +56,14 @@ mod tests { )), ); - x_q.to_data().assert_eq(&expected, true); + // Values equality + x_q.assert_eq(&expected, true); + + // Quantization parameters check + let qparams = x_q.get_q_params::().unwrap(); + let expected = expected.get_q_params::().unwrap(); + assert_eq!(qparams.scale, expected.scale); + assert_eq!(qparams.offset, expected.offset); } #[test] From c689f6d0a0a547fb3b7df59c369142a6fe3f5a32 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 9 Dec 2024 09:33:18 -0500 Subject: [PATCH 04/16] Remove dead comment --- crates/burn-fusion/src/tensor.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index c6ea209e29..89de37b07d 100644 --- 
a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -138,7 +138,6 @@ impl FusionTensor { .clone() .read_tensor_quantized::(self.into_description(), id) .await - // todo!() // doesn't work if we only have one tensordescription when we need the handles for the tensor + qparams } else { panic!("Expected quantized float dtype, got {:?}", self.dtype) } From c1d2ee2a1eda9194624687241fbb1699268a9cc7 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 9 Dec 2024 09:52:14 -0500 Subject: [PATCH 05/16] Update cubecl rev --- Cargo.lock | 24 +++++++++---------- Cargo.toml | 4 ++-- crates/burn-jit/src/kernel/matmul/base.rs | 3 ++- .../burn-jit/src/kernel/matmul/tune/base.rs | 9 ++++--- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f7e3226962..5ac5e3d109 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1668,7 +1668,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1700,7 +1700,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1717,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1735,7 +1735,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1749,7 +1749,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1765,7 +1765,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1791,7 +1791,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-core", @@ -1802,7 +1802,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1817,7 +1817,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1854,7 +1854,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "async-channel", "async-lock", @@ -1875,7 +1875,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1889,7 +1889,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "ash", "async-channel", diff --git 
a/Cargo.toml b/Cargo.toml index b93e6d809b..a0a3e0d4e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index e0b87e8931..7fa141cf67 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -43,7 +43,8 @@ pub fn matmul( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ); + ) + .unwrap(); out } #[cfg(feature = "autotune")] diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 2b8050dc07..cea682ae92 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -148,7 +148,8 @@ matmul_tune_ops!( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ); + ) + .unwrap(); } ); @@ -162,7 +163,8 @@ matmul_tune_ops!( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ); + ) + .unwrap(); } ); @@ -176,6 +178,7 @@ matmul_tune_ops!( &lhs.as_handle_ref(), &rhs.as_handle_ref(), 
&out.as_handle_ref(), - ); + ) + .unwrap(); } ); From c60ab2fbb1503cbab3e02e77c856abc0a60cc430 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 9 Dec 2024 09:52:30 -0500 Subject: [PATCH 06/16] Remove dead code --- .../src/kernel/quantization/qtensor.rs | 44 +------------------ 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/crates/burn-jit/src/kernel/quantization/qtensor.rs b/crates/burn-jit/src/kernel/quantization/qtensor.rs index a07a2eed52..a86a0ac4be 100644 --- a/crates/burn-jit/src/kernel/quantization/qtensor.rs +++ b/crates/burn-jit/src/kernel/quantization/qtensor.rs @@ -20,52 +20,12 @@ impl QParams { QParams { scheme } } - // NOTE: a couple of incompatible things for this to work.. - // notably `switch_expand_expr` only works for CubePrimitive and it doesn't really make sense to implement that for a tuple - // or QParams type - // - // /// Get the quantization parameters: - // /// - Floating-point scaling factor (encoded as u32) - // /// - Zero-point offset - // pub fn qparams(&self) -> (f32, i32) { - // let len = self.tensor.buffer_len(); - // match comptime![self.scheme] { - // QuantizationScheme::PerTensorAffine(_) => match self.tensor.line_size() { - // // For line size of 1, scale is second to last in the buffer while the zero-point offset is the last element - // 1 => ( - // f32::cast_from(self.tensor[len - 2][0]), - // i32::cast_from(self.tensor[len - 1][0]), - // ), - // // QParams { - // // scale: f32::cast_from(self.tensor[len - 2][0]), - // // offset: i32::cast_from(self.tensor[len - 1][0]), - // // }, - // // For any other line size > 1, scale and zero-point offset are the first two elements of the last line - // _ => { - // let line = self.tensor[len - 1]; - // // QParams { - // // scale: f32::cast_from(line[0]), - // // offset: i32::cast_from(line[1]), - // // } - // (f32::cast_from(line[0]), i32::cast_from(line[1])) - // } - // }, - // // Symmetric quantization only contains the scaling factor, so we return 0 for the 
zero-point offset - // QuantizationScheme::PerTensorSymmetric(_) => { - // (f32::cast_from(self.tensor[len - 1][0]), 0) - // } // QParams { - // // scale: f32::cast_from(self.tensor[len - 1][0]), - // // offset: 0, - // // }, - // } - // } - /// Get the floating-point scaling factor. pub fn scale(&self, tensor: &QTensor) -> f32 { let len = tensor.len(); match comptime!(self.scheme) { QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { - // For line size of 1, scale is last in the buffer while the zero-point offset is the second-to-last element + // For line size of 1, scale is the last value in the buffer 1 => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), // For any other line size > 1, scale and zero-point offset are the first two elements of the last line _ => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 2]), @@ -83,7 +43,7 @@ impl QParams { let line_size = comptime!(tensor.line_size()); match comptime!(self.scheme) { QuantizationScheme::PerTensorAffine(_) => match line_size { - // For line size of 1, scale is last in the buffer while the zero-point offset is the second-to-last element + // For line size of 1, the zero-point offset is the penultimate value in the buffer 1 => i32::cast_from(tensor[len - 2][line_size]), // For any other line size > 1, scale and zero-point offset are the first two elements of the last line _ => i32::cast_from(tensor[len - 1][line_size]), From 8f5f2199de82d245cc61d8c8fde6f60e9a7f14c3 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 9 Dec 2024 10:07:12 -0500 Subject: [PATCH 07/16] Fix comments --- .../burn-jit/src/kernel/quantization/dequantize.rs | 2 +- crates/burn-jit/src/kernel/quantization/qtensor.rs | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index b8b1764220..2d4f1a79e9 100644 --- 
a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -47,7 +47,7 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( #[comptime] scheme: QuantizationScheme, ) { // Last two positions contain the qparams - if ABSOLUTE_POS >= output.len() - 2 { + if ABSOLUTE_POS >= input.len() - 2 { return; } diff --git a/crates/burn-jit/src/kernel/quantization/qtensor.rs b/crates/burn-jit/src/kernel/quantization/qtensor.rs index a86a0ac4be..e5d4c5ff05 100644 --- a/crates/burn-jit/src/kernel/quantization/qtensor.rs +++ b/crates/burn-jit/src/kernel/quantization/qtensor.rs @@ -27,8 +27,8 @@ impl QParams { QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { // For line size of 1, scale is the last value in the buffer 1 => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), - // For any other line size > 1, scale and zero-point offset are the first two elements of the last line - _ => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 2]), + // For any other line size > 1, scale and zero-point offset are the last two elements + _ => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), }, // Symmetric quantization only contains the scaling factor as the last element QuantizationScheme::PerTensorSymmetric(_) => { @@ -40,13 +40,12 @@ impl QParams { /// Get the zero-point offset. 
pub fn offset(&self, tensor: &QTensor) -> i32 { let len = tensor.len(); - let line_size = comptime!(tensor.line_size()); match comptime!(self.scheme) { - QuantizationScheme::PerTensorAffine(_) => match line_size { + QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { // For line size of 1, the zero-point offset is the penultimate value in the buffer - 1 => i32::cast_from(tensor[len - 2][line_size]), - // For any other line size > 1, scale and zero-point offset are the first two elements of the last line - _ => i32::cast_from(tensor[len - 1][line_size]), + 1 => i32::cast_from(tensor[len - 2][tensor.line_size() - 1]), + // For any other line size > 1, scale and zero-point offset are the last two elements + _ => i32::cast_from(tensor[len - 1][tensor.line_size() - 2]), }, // Symmetric quantization only contains the scaling factor, so we return 0 for the zero-point offset QuantizationScheme::PerTensorSymmetric(_) => 0, From cdef4f7e84db10c521bb32a4936f3615aec5df68 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 9 Dec 2024 10:21:51 -0500 Subject: [PATCH 08/16] Fix clippy --- crates/burn-tensor/src/tensor/data.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index c80e7c1640..d52a181929 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -114,9 +114,8 @@ impl TensorData { panic!("Invalid quantized type"); } // Scale is always stored as f32 and zero-point offset as i32 - let scale = q.scale as f32; let offset = q.offset as i32; - let scale_bytes = bytemuck::bytes_of(&scale); + let scale_bytes = bytemuck::bytes_of(&q.scale); let offset_bytes = bytemuck::bytes_of(&offset); value.extend_from_slice(offset_bytes); value.extend_from_slice(scale_bytes); From ca211ef8ed3dd7c2dc0f12b25dcb1fc98f206429 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 9 Dec 2024 13:54:45 -0500 Subject: 
[PATCH 09/16] Remove unnecessary loop for input line size of 1 --- .../src/kernel/quantization/dequantize.rs | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 2d4f1a79e9..093019bee3 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -57,19 +57,16 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( let value = input[ABSOLUTE_POS]; + // Input line size is fixed to 1 if comptime!(output.line_size() == 4) { - // For input line size = 1 output[ABSOLUTE_POS] = dequantize_affine_int8(extract_i8s(value[0]), scale, offset); } else { // For very small inputs where number of elements < 4, the output line size is 1 - #[unroll] - for i in 0..input.line_size() { - let out = dequantize_affine_int8::(extract_i8s(value[i]), scale, offset); + let out = dequantize_affine_int8::(extract_i8s(value[0]), scale, offset); - #[unroll] - for j in 0..out.size() { - output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); - } + #[unroll] + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); } } } @@ -97,19 +94,16 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( let value = input[ABSOLUTE_POS]; + // Input line size is fixed to 1 if comptime!(output.line_size() == 4) { - // For input line size = 1 output[ABSOLUTE_POS] = dequantize_symmetric_int8(extract_i8s(value[0]), scale); } else { // For very small inputs where number of elements < 4, the output line size is 1 - #[unroll] - for i in 0..input.line_size() { - let out = dequantize_symmetric_int8::(extract_i8s(value[i]), scale); + let out = dequantize_symmetric_int8::(extract_i8s(value[0]), scale); - #[unroll] - for j in 0..out.size() { - output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); - } + #[unroll] + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = 
Line::cast_from(out[j]); } } } From edd1c27c32e3ac6c724a668bcca4b7acf9a24409 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 10 Dec 2024 15:27:45 -0500 Subject: [PATCH 10/16] Remove quantized kind remnant --- crates/burn-jit/src/backend.rs | 19 +++++-------------- crates/burn-tensor/src/repr/backend.rs | 3 --- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 34cb8e0c4e..9b5ea67e24 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -122,13 +122,10 @@ impl ReprBackend } } - fn quantized_tensor( - handles: QuantizedKind>, - ) -> QuantizedTensor { - let handle = handles.tensor.handle; - match handle { + fn quantized_tensor(handles: TensorHandle) -> QuantizedTensor { + match handle.handle { HandleKind::Quantized(handle) => handle, - _ => panic!("Expected quantized handle, got {}", handle.name()), + _ => panic!("Expected quantized handle, got {}", handle.handle.name()), } } @@ -144,13 +141,7 @@ impl ReprBackend HandleKind::Bool(tensor) } - fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind { - QuantizedKind { - tensor: HandleKind::Quantized(tensor), - // The quantized tensor primitive already encapsulates the required quantization - // parameters so we set the scale as an empty handle (unused). - scale: HandleKind::Empty, - offset: None, - } + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { + HandleKind::Quantized(tensor) } } diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs index 9d92b1e9ed..a56c3461cf 100644 --- a/crates/burn-tensor/src/repr/backend.rs +++ b/crates/burn-tensor/src/repr/backend.rs @@ -50,8 +50,6 @@ pub enum HandleKind { Bool(B::BoolTensorPrimitive), /// Quantized tensor handle. Quantized(B::QuantizedTensorPrimitive), - /// Empty handle (used as a dummy representation).
- Empty, } impl HandleKind { @@ -62,7 +60,6 @@ impl HandleKind { HandleKind::Int(_) => "int", HandleKind::Bool(_) => "bool", HandleKind::Quantized(_) => "quantized", - HandleKind::Empty => unreachable!(), // should not happen } } } From 52edc88dab64659c07f63c76762fecfa98d002a3 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 10 Dec 2024 15:30:59 -0500 Subject: [PATCH 11/16] Remove no longer valid comment --- crates/burn-tensor/src/repr/backend.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs index a56c3461cf..696da20c49 100644 --- a/crates/burn-tensor/src/repr/backend.rs +++ b/crates/burn-tensor/src/repr/backend.rs @@ -35,7 +35,6 @@ pub trait ReprBackend: Backend { /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle). fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle). - /// A quantized tensor has multiple handles for the tensor itself and the quantization parameters. 
fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle; } From c49533126ac0d5c35e25b5f93a290a1965e916ce Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 11 Dec 2024 10:44:37 -0500 Subject: [PATCH 12/16] Get qparams values as tuple --- .../src/kernel/quantization/dequantize.rs | 5 +-- .../src/kernel/quantization/qtensor.rs | 39 ++++++++----------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 093019bee3..6d53b2effb 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -52,8 +52,7 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( } let qparams = QParams::new(scheme); - let scale = qparams.scale(input); - let offset = qparams.offset(input); + let (scale, offset) = qparams.values(input); let value = input[ABSOLUTE_POS]; @@ -90,7 +89,7 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( } let qparams = QParams::new(scheme); - let scale = qparams.scale(input); + let (scale, _) = qparams.values(input); let value = input[ABSOLUTE_POS]; diff --git a/crates/burn-jit/src/kernel/quantization/qtensor.rs b/crates/burn-jit/src/kernel/quantization/qtensor.rs index e5d4c5ff05..26d9f65091 100644 --- a/crates/burn-jit/src/kernel/quantization/qtensor.rs +++ b/crates/burn-jit/src/kernel/quantization/qtensor.rs @@ -20,35 +20,30 @@ impl QParams { QParams { scheme } } - /// Get the floating-point scaling factor. - pub fn scale(&self, tensor: &QTensor) -> f32 { + /// Get the quantization parameters values. 
+ pub fn values(&self, tensor: &QTensor) -> (f32, i32) { let len = tensor.len(); match comptime!(self.scheme) { QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { // For line size of 1, scale is the last value in the buffer - 1 => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), + 1 => ( + f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), + i32::cast_from(tensor[len - 2][tensor.line_size() - 1]), + ), // For any other line size > 1, scale and zero-point offset are the last two elements - _ => f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), + _ => { + let values = tensor[len - 1]; + ( + f32::bitcast_from(values[tensor.line_size() - 1]), + i32::cast_from(values[tensor.line_size() - 2]), + ) + } }, // Symmetric quantization only contains the scaling factor as the last element - QuantizationScheme::PerTensorSymmetric(_) => { - f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]) - } - } - } - - /// Get the zero-point offset. 
- pub fn offset(&self, tensor: &QTensor) -> i32 { - let len = tensor.len(); - match comptime!(self.scheme) { - QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { - // For line size of 1, the zero-point offset is the penultimate value in the buffer - 1 => i32::cast_from(tensor[len - 2][tensor.line_size() - 1]), - // For any other line size > 1, scale and zero-point offset are the last two elements - _ => i32::cast_from(tensor[len - 1][tensor.line_size() - 2]), - }, - // Symmetric quantization only contains the scaling factor, so we return 0 for the zero-point offset - QuantizationScheme::PerTensorSymmetric(_) => 0, + QuantizationScheme::PerTensorSymmetric(_) => ( + f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), + 0, + ), } } } From 3a4d9a0faa23dc4aca778d0d8fbb2c842747af02 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 11 Dec 2024 13:20:21 -0500 Subject: [PATCH 13/16] Move data into async context --- crates/burn-fusion/src/tensor.rs | 38 ++++++++++++++------------------ 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 89de37b07d..7cb98c6a1a 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -4,7 +4,7 @@ use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, DType, Shape, TensorData, TensorMetadata, }; -use std::sync::Arc; +use std::{future::Future, sync::Arc}; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. 
pub struct FusionTensor { @@ -117,52 +117,48 @@ impl FusionTensor { } } - pub(crate) async fn into_data(self) -> TensorData + pub(crate) fn into_data(self) -> impl Future where B: FusionBackend, { let id = self.stream; - self.client - .clone() - .read_tensor_float::(self.into_description(), id) - .await + let client = self.client.clone(); + let desc = self.into_description(); + async move { client.read_tensor_float::(desc, id).await } } - pub(crate) async fn q_into_data(self) -> TensorData + pub(crate) fn q_into_data(self) -> impl Future where B: FusionBackend, { if let DType::QFloat(_scheme) = self.dtype { let id = self.stream; - self.client - .clone() - .read_tensor_quantized::(self.into_description(), id) - .await + let client = self.client.clone(); + let desc = self.into_description(); + async move { client.read_tensor_quantized::(desc, id).await } } else { panic!("Expected quantized float dtype, got {:?}", self.dtype) } } - pub(crate) async fn int_into_data(self) -> TensorData + pub(crate) fn int_into_data(self) -> impl Future where B: FusionBackend, { let id = self.stream; - self.client - .clone() - .read_tensor_int::(self.into_description(), id) - .await + let client = self.client.clone(); + let desc = self.into_description(); + async move { client.read_tensor_int::(desc, id).await } } - pub(crate) async fn bool_into_data(self) -> TensorData + pub(crate) fn bool_into_data(self) -> impl Future where B: FusionBackend, { let id = self.stream; - self.client - .clone() - .read_tensor_bool::(self.into_description(), id) - .await + let client = self.client.clone(); + let desc = self.into_description(); + async move { client.read_tensor_bool::(desc, id).await } } } From 3d8fc3c738406ea958244e05ade0f1c92c8a89bf Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 12 Dec 2024 14:13:49 -0500 Subject: [PATCH 14/16] Fix ReprBackend handle type for JitBackend and Fusion --- crates/burn-fusion/src/backend.rs | 8 ++++---- crates/burn-jit/src/backend.rs | 30 
+++++++++--------------------- 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index d3f80a79dd..fccd56b4b0 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -182,8 +182,8 @@ impl ReprBackend for Fusion { handle.handle } - fn quantized_tensor(_handle: TensorHandle) -> QuantizedTensor { - todo!() // not as simple + fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { + handle.handle } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { @@ -198,7 +198,7 @@ impl ReprBackend for Fusion { tensor } - fn quantized_tensor_handle(_tensor: QuantizedTensor) -> Self::Handle { - todo!() // not as simple + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { + tensor } } diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 9b5ea67e24..2d1b64ebce 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -99,49 +99,37 @@ where impl ReprBackend for JitBackend { - type Handle = HandleKind; + type Handle = JitTensor; fn float_tensor(handle: TensorHandle) -> FloatTensor { - match handle.handle { - HandleKind::Float(handle) => handle, - _ => panic!("Expected float handle, got {}", handle.handle.name()), - } + handle.handle } fn int_tensor(handle: TensorHandle) -> IntTensor { - match handle.handle { - HandleKind::Int(handle) => handle, - _ => panic!("Expected int handle, got {}", handle.handle.name()), - } + handle.handle } fn bool_tensor(handle: TensorHandle) -> BoolTensor { - match handle.handle { - HandleKind::Bool(handle) => handle, - _ => panic!("Expected bool handle, got {}", handle.handle.name()), - } + handle.handle } fn quantized_tensor(handles: TensorHandle) -> QuantizedTensor { - match handle.handle { - HandleKind::Quantized(handle) => handle, - _ => panic!("Expected quantized handle, got {}", handle.handle.name()), - } + handle.handle } fn 
float_tensor_handle(tensor: FloatTensor) -> Self::Handle { - HandleKind::Float(tensor) + tensor } fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { - HandleKind::Int(tensor) + tensor } fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { - HandleKind::Bool(tensor) + tensor } fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { - HandleKind::Quantized(tensor) + tensor } } From ba03fada2bccf2190c590b47a8cf50e70eecbbcd Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 13 Dec 2024 14:23:36 -0500 Subject: [PATCH 15/16] Fusion client read takes ownership --- crates/burn-fusion/src/client/base.rs | 8 ++++---- crates/burn-fusion/src/client/mutex.rs | 8 ++++---- crates/burn-fusion/src/tensor.rs | 9 ++++++--- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 7b802c537e..48f1108a8f 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -36,7 +36,7 @@ where ) -> FusionTensor; /// Read the values contained by a float tensor. fn read_tensor_float( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + Send + 'static @@ -44,7 +44,7 @@ where B: FusionBackend; /// Read the values contained by an int tensor. fn read_tensor_int( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + Send + 'static @@ -52,7 +52,7 @@ where B: FusionBackend; /// Read the values contained by a bool tensor. fn read_tensor_bool( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + Send + 'static @@ -60,7 +60,7 @@ where B: FusionBackend; /// Read the values contained by a quantized tensor. 
fn read_tensor_quantized( - &self, + self, tensor: TensorDescription, streams: StreamId, ) -> impl Future + Send + 'static diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 1301564fb8..5c00ac391e 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -79,7 +79,7 @@ where } fn read_tensor_float( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + 'static @@ -91,7 +91,7 @@ where } fn read_tensor_int( - &self, + self, tensor: TensorDescription, id: StreamId, ) -> impl Future + 'static @@ -102,7 +102,7 @@ where } fn read_tensor_bool( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + 'static @@ -113,7 +113,7 @@ where } fn read_tensor_quantized( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + 'static diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 7cb98c6a1a..22eb22c490 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -124,7 +124,8 @@ impl FusionTensor { let id = self.stream; let client = self.client.clone(); let desc = self.into_description(); - async move { client.read_tensor_float::(desc, id).await } + let fut = client.read_tensor_float::(desc, id); + async move { fut.await } } pub(crate) fn q_into_data(self) -> impl Future @@ -135,7 +136,8 @@ impl FusionTensor { let id = self.stream; let client = self.client.clone(); let desc = self.into_description(); - async move { client.read_tensor_quantized::(desc, id).await } + let fut = client.read_tensor_quantized::(desc, id); + async move { fut.await } } else { panic!("Expected quantized float dtype, got {:?}", self.dtype) } @@ -158,7 +160,8 @@ impl FusionTensor { let id = self.stream; let client = self.client.clone(); let desc = self.into_description(); - async move { client.read_tensor_bool::(desc, id).await } + let fut = client.read_tensor_bool::(desc, id); + 
async move { fut.await } } } From a93e8da7efaf6575facab1dba0ec8a903f80c5d4 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 13 Dec 2024 14:30:35 -0500 Subject: [PATCH 16/16] Fix clippy --- crates/burn-fusion/src/tensor.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 22eb22c490..f620e2c722 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -124,8 +124,7 @@ impl FusionTensor { let id = self.stream; let client = self.client.clone(); let desc = self.into_description(); - let fut = client.read_tensor_float::(desc, id); - async move { fut.await } + client.read_tensor_float::(desc, id) } pub(crate) fn q_into_data(self) -> impl Future @@ -136,8 +135,7 @@ impl FusionTensor { let id = self.stream; let client = self.client.clone(); let desc = self.into_description(); - let fut = client.read_tensor_quantized::(desc, id); - async move { fut.await } + client.read_tensor_quantized::(desc, id) } else { panic!("Expected quantized float dtype, got {:?}", self.dtype) } @@ -150,7 +148,7 @@ impl FusionTensor { let id = self.stream; let client = self.client.clone(); let desc = self.into_description(); - async move { client.read_tensor_int::(desc, id).await } + client.read_tensor_int::(desc, id) } pub(crate) fn bool_into_data(self) -> impl Future @@ -160,8 +158,7 @@ impl FusionTensor { let id = self.stream; let client = self.client.clone(); let desc = self.into_description(); - let fut = client.read_tensor_bool::(desc, id); - async move { fut.await } + client.read_tensor_bool::(desc, id) } }