From 0dd228cdcdb0d2c2704518ff2ee4ba7fd5f91cd8 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 13 Dec 2024 15:24:01 -0500 Subject: [PATCH 1/5] Refactor jit quantized tensor representation (#2604) * Remove q_shape to use TensorMetadata instead * Fix spirv bool type * Refactor burn-jit quantized tensor representation * Remove dead comment * Update cubecl rev * Remove dead code * Fix comments * Fix clippy * Remove unnecessary loop for input line size of 1 * Remove quantized kindremnant * Remove no longer valid comment * Get qparams values as tuple * Move data into async context * Fix ReprBackend handle type for JitBackend and Fusion * Fusion client read takes ownership * Fix clippy --- Cargo.lock | 24 +- Cargo.toml | 4 +- crates/burn-autodiff/src/ops/qtensor.rs | 4 - crates/burn-candle/src/ops/qtensor.rs | 4 - crates/burn-candle/src/tensor.rs | 4 - crates/burn-fusion/src/backend.rs | 19 +- crates/burn-fusion/src/client/base.rs | 22 +- crates/burn-fusion/src/client/mutex.rs | 62 ++--- crates/burn-fusion/src/ops/qtensor.rs | 136 ++--------- crates/burn-fusion/src/server.rs | 54 +--- crates/burn-fusion/src/stream/context.rs | 14 +- crates/burn-fusion/src/tensor.rs | 162 +++--------- crates/burn-jit/src/backend.rs | 54 ++-- crates/burn-jit/src/fusion/base.rs | 35 +-- crates/burn-jit/src/kernel/matmul/base.rs | 3 +- .../burn-jit/src/kernel/matmul/tune/base.rs | 9 +- .../src/kernel/quantization/dequantize.rs | 231 +++++++----------- .../burn-jit/src/kernel/quantization/mod.rs | 2 + .../src/kernel/quantization/qtensor.rs | 49 ++++ .../src/kernel/quantization/quantize.rs | 205 +++++++++------- crates/burn-jit/src/ops/qtensor.rs | 87 +++---- crates/burn-jit/src/tensor/base.rs | 36 ++- crates/burn-jit/src/tensor/mod.rs | 2 - crates/burn-jit/src/tensor/qtensor.rs | 125 ---------- crates/burn-ndarray/src/backend.rs | 23 +- crates/burn-ndarray/src/ops/qtensor.rs | 8 +- crates/burn-ndarray/src/tensor.rs | 15 +- crates/burn-router/src/backend.rs | 6 +- crates/burn-router/src/ops/op_qfloat.rs | 4 - crates/burn-tch/src/ops/qtensor.rs | 9 +- crates/burn-tch/src/tensor.rs | 35 +-- crates/burn-tensor/src/repr/backend.rs | 23 +- crates/burn-tensor/src/repr/handle.rs | 29 +-- crates/burn-tensor/src/repr/mod.rs | 2 - crates/burn-tensor/src/repr/operation.rs | 26 +- crates/burn-tensor/src/repr/quantization.rs | 25 -- crates/burn-tensor/src/tensor/data.rs | 16 +- crates/burn-tensor/src/tensor/ops/qtensor.rs | 19 +- .../src/tensor/quantization/primitive.rs | 8 +- .../src/tensor/quantization/scheme.rs | 18 ++ .../src/tests/quantization/ops/quantize.rs | 22 +- crates/burn-wgpu/src/lib.rs | 2 +- 42 files changed, 587 insertions(+), 1050 deletions(-) create mode 100644 crates/burn-jit/src/kernel/quantization/qtensor.rs delete mode 100644 crates/burn-jit/src/tensor/qtensor.rs delete mode 100644 crates/burn-tensor/src/repr/quantization.rs diff --git a/Cargo.lock b/Cargo.lock index 7241f4dff6..2367618c2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1668,7 +1668,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1700,7 +1700,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1717,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1735,7 +1735,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1749,7 +1749,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1765,7 +1765,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1791,7 +1791,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "bytemuck", "cubecl-core", @@ -1802,7 +1802,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1817,7 +1817,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1854,7 +1854,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "async-channel", "async-lock", @@ -1875,7 +1875,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1889,7 +1889,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=1c4e0036c043422773fd6824c2a888160fca8e5e#1c4e0036c043422773fd6824c2a888160fca8e5e" +source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index a23077bd0b..07df9cdda4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-autodiff/src/ops/qtensor.rs b/crates/burn-autodiff/src/ops/qtensor.rs index ee3c8ab7f7..b55157f328 100644 --- a/crates/burn-autodiff/src/ops/qtensor.rs +++ b/crates/burn-autodiff/src/ops/qtensor.rs @@ -33,10 +33,6 @@ impl QTensorOps for Autodiff { todo!() } - fn q_shape(tensor: &QuantizedTensor) -> Shape { - B::q_shape(tensor) - } - fn q_device(tensor: &QuantizedTensor) -> Device { B::q_device(tensor) } diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs index 9d63f02917..f4c2e96f04 100644 --- a/crates/burn-candle/src/ops/qtensor.rs +++ b/crates/burn-candle/src/ops/qtensor.rs @@ -29,10 +29,6 @@ impl QTensorOps for Candle) -> Shape { - super::base::shape(&tensor.qtensor) - } - fn q_device(tensor: &QuantizedTensor) -> Device { super::base::device(&tensor.qtensor) } diff --git a/crates/burn-candle/src/tensor.rs b/crates/burn-candle/src/tensor.rs index c4e7d740ca..ec71b3d47e 100644 --- a/crates/burn-candle/src/tensor.rs +++ b/crates/burn-candle/src/tensor.rs @@ -70,10 +70,6 @@ impl QTensorPrimitive for CandleQTensor { fn scheme(&self) -> &QuantizationScheme { &self.scheme } - - fn strategy(&self) -> QuantizationStrategy { - todo!() - } } impl TensorMetadata for CandleQTensor { diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index aa308ad9a7..fccd56b4b0 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -1,10 +1,8 @@ -use crate::{ - client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, QFusionTensor, -}; +use crate::{client::FusionClient, stream::Context, FusionClientLocator, FusionTensor}; use burn_tensor::{ backend::{Backend, DeviceOps}, ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, - repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle}, + repr::{OperationDescription, ReprBackend, TensorHandle}, Device, Element, }; use serde::{de::DeserializeOwned, Serialize}; @@ -37,7 +35,7 @@ impl Backend for Fusion { type BoolElem = B::BoolElem; - type QuantizedTensorPrimitive = QFusionTensor; + type QuantizedTensorPrimitive = FusionTensor; type QuantizedEncoding = B::QuantizedEncoding; @@ -184,11 +182,8 @@ impl ReprBackend for Fusion { handle.handle } - fn quantized_tensor( - _handles: QuantizedKind>, - _scheme: burn_tensor::quantization::QuantizationScheme, - ) -> QuantizedTensor { - todo!() // not as simple + fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { + handle.handle } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { @@ -203,7 +198,7 @@ impl ReprBackend for Fusion { tensor } - fn quantized_tensor_handle(_tensor: QuantizedTensor) -> QuantizedKind { - todo!() // not as simple + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { + tensor } } diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 33508c40b7..48f1108a8f 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -2,10 +2,10 @@ use std::future::Future; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, QFusionTensor, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, }; use burn_tensor::{ - repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId}, + repr::{OperationDescription, TensorDescription, TensorId}, DType, TensorData, }; @@ -36,7 +36,7 @@ where ) -> FusionTensor; /// Read the values contained by a float tensor. fn read_tensor_float( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + Send + 'static @@ -44,7 +44,7 @@ where B: FusionBackend; /// Read the values contained by an int tensor. fn read_tensor_int( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + Send + 'static @@ -52,7 +52,7 @@ where B: FusionBackend; /// Read the values contained by a bool tensor. fn read_tensor_bool( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + Send + 'static @@ -60,9 +60,9 @@ where B: FusionBackend; /// Read the values contained by a quantized tensor. fn read_tensor_quantized( - &self, - tensor: QuantizedTensorDescription, - streams: Vec, + self, + tensor: TensorDescription, + streams: StreamId, ) -> impl Future + Send + 'static where B: FusionBackend; @@ -108,10 +108,10 @@ where /// Change the client of the given quantized tensor. fn change_client_quantized( &self, - tensor: QuantizedTensorDescription, + tensor: TensorDescription, client: Self, - streams: Vec, - ) -> QFusionTensor + stream: StreamId, + ) -> FusionTensor where B: FusionBackend; /// Drop the tensor with the given [tensor id](TensorId). diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 503fb2d434..5c00ac391e 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -1,11 +1,10 @@ use super::FusionClient; use crate::{ stream::{execution::Operation, StreamId}, - FusionBackend, FusionDevice, FusionHandle, FusionQuantizationParameters, FusionRuntime, - FusionServer, FusionTensor, QFusionTensor, + FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor, }; use burn_tensor::{ - repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId}, + repr::{OperationDescription, TensorDescription, TensorId}, DType, }; use spin::Mutex; @@ -80,7 +79,7 @@ where } fn read_tensor_float( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + 'static @@ -92,7 +91,7 @@ where } fn read_tensor_int( - &self, + self, tensor: TensorDescription, id: StreamId, ) -> impl Future + 'static @@ -103,7 +102,7 @@ where } fn read_tensor_bool( - &self, + self, tensor: TensorDescription, stream: StreamId, ) -> impl Future + 'static @@ -114,14 +113,14 @@ where } fn read_tensor_quantized( - &self, - tensor: QuantizedTensorDescription, - streams: Vec, + self, + tensor: TensorDescription, + stream: StreamId, ) -> impl Future + 'static where B: FusionBackend, { - self.server.lock().read_quantized::(tensor, streams) + self.server.lock().read_quantized::(tensor, stream) } fn change_client_float( @@ -190,55 +189,24 @@ where fn change_client_quantized( &self, - tensor: QuantizedTensorDescription, + tensor: TensorDescription, client: Self, - streams: Vec, - ) -> QFusionTensor + stream: StreamId, + ) -> FusionTensor where B: FusionBackend, { let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); - for stream in streams { - server_current.drain_stream(stream); - } + server_current.drain_stream(stream); - let mut ids = + let id = server_current.change_server_quantized::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); - // NOTE: the expected order is known [qtensor, scale, ] - let offset = tensor.qparams.offset.map(|desc| { - FusionTensor::new( - ids.pop().unwrap(), - desc.shape, - desc.dtype, - client.clone(), - StreamId::current(), - ) - }); - let scale = FusionTensor::new( - ids.pop().unwrap(), - tensor.qparams.scale.shape, - tensor.qparams.scale.dtype, - client.clone(), - StreamId::current(), - ); - let qtensor = FusionTensor::new( - ids.pop().unwrap(), - tensor.tensor.shape, - tensor.tensor.dtype, - client, - StreamId::current(), - ); - - QFusionTensor { - qtensor, - scheme: tensor.scheme, - qparams: FusionQuantizationParameters { scale, offset }, - } + FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current()) } fn register_orphan(&self, id: &TensorId) { diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index c8b1c09610..41bc7ccde6 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -2,11 +2,10 @@ use std::{marker::PhantomData, ops::Range}; use burn_tensor::{ ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType}, + quantization::{QuantizationParametersPrimitive, QuantizationScheme}, repr::{ DequantizeOperationDescription, FloatOperationDescription, HandleContainer, OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, - QuantizedKind, }, DType, Device, Element, Shape, TensorData, }; @@ -15,67 +14,24 @@ use crate::{ client::FusionClient, get_client, stream::{execution::Operation, StreamId}, - Fusion, FusionBackend, FusionQuantizationParameters, QFusionTensor, + Fusion, FusionBackend, }; impl QTensorOps for Fusion { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { match data.dtype { - DType::QFloat(scheme) => { + DType::QFloat(_scheme) => { + let dtype = data.dtype; let client = get_client::(device); let tensor = B::q_from_data(data, device); - let shape = B::q_shape(&tensor); - - let handles = B::quantized_tensor_handle(tensor); - let qparams = match scheme { - QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { - let offset = if let Some(offset) = handles.offset { - offset - } else { - panic!("Expected offset for quantized tensor."); - }; - FusionQuantizationParameters { - scale: client.register_tensor( - handles.scale, - vec![1], - StreamId::current(), - B::FloatElem::dtype(), - ), - offset: Some(client.register_tensor( - offset, - vec![1], - StreamId::current(), - B::IntElem::dtype(), - )), - } - } - QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { - assert!( - handles.offset.is_none(), - "Offset should not be provided for symmetric quantization." - ); - FusionQuantizationParameters { - scale: client.register_tensor( - handles.scale, - vec![1], - StreamId::current(), - B::FloatElem::dtype(), - ), - offset: None, - } - } - }; - let qtensor = client.register_tensor( - handles.tensor, + let shape = burn_tensor::TensorMetadata::shape(&tensor); + + client.register_tensor( + B::quantized_tensor_handle(tensor), shape.dims, StreamId::current(), - B::QuantizedEncoding::dtype(), - ); - QFusionTensor { - qtensor, - qparams, - scheme, - } + dtype, + ) } _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", @@ -108,27 +64,14 @@ impl QTensorOps for Fusion { let qparams = QuantizationParametersPrimitive { scale, offset }; let output = B::quantize(tensor, &self.desc.scheme, qparams); - let q_ids = if let Some(offset) = &self.desc.qparams.offset { - QuantizedKind { - tensor: self.desc.out.id, - scale: self.desc.qparams.scale.id, - offset: Some(offset.id), - } - } else { - QuantizedKind { - tensor: self.desc.out.id, - scale: self.desc.qparams.scale.id, - offset: None, - } - }; - handles.register_quantized_tensor::(&q_ids, output); + handles.register_quantized_tensor::(&self.desc.out.id, output); } } let shape: Vec = tensor.shape.clone(); let out = tensor .client - .tensor_uninitialized(shape, B::QuantizedEncoding::dtype()); + .tensor_uninitialized(shape, DType::QFloat(*scheme)); let streams = if let Some(offset) = &qparams.offset { vec![tensor.stream, qparams.scale.stream, offset.stream] @@ -155,11 +98,7 @@ impl QTensorOps for Fusion { QuantizeOp::::new(desc), ); - QFusionTensor { - qtensor: out, - scheme: *scheme, - qparams: qparams.into(), - } + out } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { @@ -171,36 +110,26 @@ impl QTensorOps for Fusion { impl Operation for DequantizeOp { fn execute(self: Box, handles: &mut HandleContainer) { - let tensor = handles.get_quantized_tensor::(&self.desc.qtensor); + let tensor = handles.get_quantized_tensor::(&self.desc.input); let output = B::dequantize(tensor); handles.register_float_tensor::(&self.desc.out.id, output); } } - let shape: Vec = tensor.qtensor.shape.clone(); + let stream = tensor.stream; + let shape: Vec = tensor.shape.clone(); let out = tensor - .qtensor .client .tensor_uninitialized(shape, B::FloatElem::dtype()); - let streams = if let Some(offset) = &tensor.qparams.offset { - vec![ - tensor.qtensor.stream, - tensor.qparams.scale.stream, - offset.stream, - ] - } else { - vec![tensor.qtensor.stream, tensor.qparams.scale.stream] - }; - let desc = DequantizeOperationDescription { - qtensor: tensor.into_description(), + input: tensor.into_description(), out: out.to_description_out(), }; out.client.register( - streams, + vec![stream], OperationDescription::Float( FloatElem::::dtype(), FloatOperationDescription::Dequantize(desc.clone()), @@ -211,40 +140,23 @@ impl QTensorOps for Fusion { out } - fn q_shape(tensor: &QuantizedTensor) -> Shape { - // Conflicting `dtype()` when both `Element` and `TensorMetadata` traits are in - // scope so we use the fully qualified syntax - burn_tensor::TensorMetadata::shape(tensor) - } - fn q_device(tensor: &QuantizedTensor) -> Device { - tensor.qtensor.client.device().clone() + tensor.client.device().clone() } fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { - // Quantization parameters are on the same device as the qtensor - let device_original: &B::Device = tensor.qtensor.client.device(); + let device_original: &B::Device = tensor.client.device(); let device_target: B::Device = device.clone(); if device_original == &device_target { return tensor; } - println!("q_to_device {:?} {:?}", device_original, device_target); + let id = tensor.stream; let client_target = get_client::(&device_target); - let client_original = tensor.qtensor.client.clone(); - - let ids = if let Some(offset) = &tensor.qparams.offset { - vec![ - tensor.qtensor.stream, - tensor.qparams.scale.stream, - offset.stream, - ] - } else { - vec![tensor.qtensor.stream, tensor.qparams.scale.stream] - }; + let client_original = tensor.client.clone(); - client_original.change_client_quantized::(tensor.into_description(), client_target, ids) + client_original.change_client_quantized::(tensor.into_description(), client_target, id) } fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { @@ -252,7 +164,7 @@ impl QTensorOps for Fusion { } async fn q_into_data(tensor: QuantizedTensor) -> TensorData { - tensor.into_data::().await + tensor.q_into_data::().await } fn q_swap_dims( diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index bff39a1d2a..688645fc34 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -2,10 +2,7 @@ use crate::{ stream::{execution::Operation, MultiStream, StreamId}, FusionBackend, FusionRuntime, }; -use burn_tensor::repr::{ - HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription, - TensorDescription, TensorId, -}; +use burn_tensor::repr::{HandleContainer, OperationDescription, TensorDescription, TensorId}; use std::{future::Future, sync::Arc}; pub struct FusionServer { @@ -92,17 +89,15 @@ where pub fn read_quantized( &mut self, - tensor: QuantizedTensorDescription, - ids: Vec, + tensor: TensorDescription, + id: StreamId, ) -> impl Future + 'static where B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. - for id in ids { - self.drain_stream(id); - } + self.drain_stream(id); let tensor = self.handles.get_quantized_tensor::(&tensor); B::q_into_data(tensor) @@ -191,45 +186,22 @@ where pub fn change_server_quantized( &mut self, - desc: &QuantizedTensorDescription, + desc: &TensorDescription, device: &R::FusionDevice, server_device: &mut Self, - ) -> Vec> + ) -> Arc where B: FusionBackend, { let tensor = self.handles.get_quantized_tensor::(desc); let tensor = B::q_to_device(tensor, device); - if desc.qparams.offset.is_some() { - let tensor_id = server_device.create_empty_handle(); - let scale_id = server_device.create_empty_handle(); - let offset_id = server_device.create_empty_handle(); - - let q_ids = QuantizedKind { - tensor: *tensor_id, - scale: *scale_id, - offset: Some(*offset_id), - }; - server_device - .handles - .register_quantized_tensor::(&q_ids, tensor); - - vec![tensor_id, scale_id, offset_id] - } else { - let tensor_id = server_device.create_empty_handle(); - let scale_id = server_device.create_empty_handle(); - - let q_ids = QuantizedKind { - tensor: *tensor_id, - scale: *scale_id, - offset: None, - }; - server_device - .handles - .register_quantized_tensor::(&q_ids, tensor); - - vec![tensor_id, scale_id] - } + let id = server_device.create_empty_handle(); + + server_device + .handles + .register_quantized_tensor::(&id, tensor); + + id } pub fn drop_tensor_handle(&mut self, id: TensorId) { diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 565f160362..7dcc81cb77 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -553,19 +553,7 @@ impl RelativeOpsScalar for FloatOperationDescription { } FloatOperationDescription::Dequantize(desc) => { FloatOperationDescription::Dequantize(DequantizeOperationDescription { - qtensor: QuantizedTensorDescription { - tensor: desc.qtensor.tensor.to_relative(converter), - qparams: QuantizationParametersDescription { - scale: desc.qtensor.qparams.scale.to_relative(converter), - offset: desc - .qtensor - .qparams - .offset - .as_ref() - .map(|x| x.to_relative(converter)), - }, - scheme: desc.qtensor.scheme, - }, + input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }) } diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 44d71abd4c..f620e2c722 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -1,15 +1,10 @@ -use crate::{client::FusionClient, stream::StreamId, Client, Fusion, FusionBackend, FusionRuntime}; +use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime}; use burn_tensor::{ - quantization::{ - QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy, - }, - repr::{ - QuantizationParametersDescription, QuantizedTensorDescription, TensorDescription, TensorId, - TensorStatus, - }, + quantization::{QTensorPrimitive, QuantizationScheme}, + repr::{TensorDescription, TensorId, TensorStatus}, DType, Shape, TensorData, TensorMetadata, }; -use std::sync::Arc; +use std::{future::Future, sync::Arc}; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. pub struct FusionTensor { @@ -122,37 +117,48 @@ impl FusionTensor { } } - pub(crate) async fn into_data(self) -> TensorData + pub(crate) fn into_data(self) -> impl Future where B: FusionBackend, { let id = self.stream; - self.client - .clone() - .read_tensor_float::(self.into_description(), id) - .await + let client = self.client.clone(); + let desc = self.into_description(); + client.read_tensor_float::(desc, id) } - pub(crate) async fn int_into_data(self) -> TensorData + pub(crate) fn q_into_data(self) -> impl Future + where + B: FusionBackend, + { + if let DType::QFloat(_scheme) = self.dtype { + let id = self.stream; + let client = self.client.clone(); + let desc = self.into_description(); + client.read_tensor_quantized::(desc, id) + } else { + panic!("Expected quantized float dtype, got {:?}", self.dtype) + } + } + + pub(crate) fn int_into_data(self) -> impl Future where B: FusionBackend, { let id = self.stream; - self.client - .clone() - .read_tensor_int::(self.into_description(), id) - .await + let client = self.client.clone(); + let desc = self.into_description(); + client.read_tensor_int::(desc, id) } - pub(crate) async fn bool_into_data(self) -> TensorData + pub(crate) fn bool_into_data(self) -> impl Future where B: FusionBackend, { let id = self.stream; - self.client - .clone() - .read_tensor_bool::(self.into_description(), id) - .await + let client = self.client.clone(); + let desc = self.into_description(); + client.read_tensor_bool::(desc, id) } } @@ -172,109 +178,15 @@ impl Drop for FusionTensor { } } -/// A quantized tensor primitive for fusion backends. -#[derive(Debug)] -pub struct QFusionTensor { - /// The quantized tensor. - pub qtensor: FusionTensor, - /// The quantization scheme. - pub scheme: QuantizationScheme, - /// The quantization parameters. - pub qparams: FusionQuantizationParameters, -} - -impl QTensorPrimitive for QFusionTensor { +impl QTensorPrimitive for FusionTensor { fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { - // TODO - todo!() - } -} - -impl Clone for QFusionTensor { - fn clone(&self) -> Self { - Self { - qtensor: self.qtensor.clone(), - scheme: self.scheme, - qparams: self.qparams.clone(), - } - } -} - -impl TensorMetadata for QFusionTensor { - fn dtype(&self) -> DType { - DType::QFloat(self.scheme) - } - - fn shape(&self) -> Shape { - self.qtensor.shape() - } -} - -impl QFusionTensor { - pub(crate) async fn into_data(self) -> TensorData - where - B: FusionBackend, - { - let streams = if let Some(offset) = &self.qparams.offset { - vec![ - self.qtensor.stream, - self.qparams.scale.stream, - offset.stream, - ] + if let DType::QFloat(scheme) = &self.dtype { + scheme } else { - vec![self.qtensor.stream, self.qparams.scale.stream] - }; - - // Quantized tensor and qparams tensors client are the same - self.qtensor - .client - .clone() - .read_tensor_quantized::(self.into_description(), streams) - .await - } - - /// Description to be used when using an initialized tensor used as input. - pub(crate) fn into_description(self) -> QuantizedTensorDescription { - QuantizedTensorDescription { - tensor: self.qtensor.into_description(), - qparams: QuantizationParametersDescription { - scale: self.qparams.scale.into_description(), - offset: self.qparams.offset.map(|x| x.into_description()), - }, - scheme: self.scheme, - } - } -} - -/// The quantization parameters. -#[derive(Debug)] -pub struct FusionQuantizationParameters { - /// The scaling factor. - pub scale: FusionTensor, - /// The zero-point offset. - pub offset: Option>, -} - -impl Clone for FusionQuantizationParameters { - fn clone(&self) -> Self { - Self { - scale: self.scale.clone(), - offset: self.offset.clone(), - } - } -} - -impl From>> - for FusionQuantizationParameters -{ - fn from(value: QuantizationParametersPrimitive>) -> Self { - FusionQuantizationParameters { - scale: value.scale, - offset: value.offset, + panic!( + "Quantization scheme is not valid for dtype {:?}", + self.dtype, + ) } } } diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index b455d859a0..2d1b64ebce 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -1,8 +1,4 @@ -use crate::{ - element::BoolElement, - tensor::{JitTensor, QJitTensor}, - FloatElement, IntElement, JitRuntime, -}; +use crate::{element::BoolElement, tensor::JitTensor, FloatElement, IntElement, JitRuntime}; use burn_tensor::backend::{Backend, DeviceOps}; use cubecl::server::ComputeServer; use rand::{rngs::StdRng, SeedableRng}; @@ -12,7 +8,7 @@ use std::{marker::PhantomData, sync::Mutex}; use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, quantization::QuantizationScheme, - repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle}, + repr::{HandleKind, ReprBackend, TensorHandle}, }; pub(crate) static SEED: Mutex> = Mutex::new(None); @@ -44,7 +40,7 @@ where type FloatTensorPrimitive = JitTensor; type IntTensorPrimitive = JitTensor; type BoolTensorPrimitive = JitTensor; - type QuantizedTensorPrimitive = QJitTensor; + type QuantizedTensorPrimitive = JitTensor; type QuantizedEncoding = u32; fn name() -> String { @@ -103,59 +99,37 @@ where impl ReprBackend for JitBackend { - type Handle = HandleKind; + type Handle = JitTensor; fn float_tensor(handle: TensorHandle) -> FloatTensor { - match handle.handle { - HandleKind::Float(handle) => handle, - _ => panic!("Expected float handle, got {}", handle.handle.name()), - } + handle.handle } fn int_tensor(handle: TensorHandle) -> IntTensor { - match handle.handle { - HandleKind::Int(handle) => handle, - _ => panic!("Expected int handle, got {}", handle.handle.name()), - } + handle.handle } fn bool_tensor(handle: TensorHandle) -> BoolTensor { - match handle.handle { - HandleKind::Bool(handle) => handle, - _ => panic!("Expected bool handle, got {}", handle.handle.name()), - } + handle.handle } - fn quantized_tensor( - handles: QuantizedKind>, - _scheme: QuantizationScheme, - ) -> QuantizedTensor { - let handle = handles.tensor.handle; - match handle { - HandleKind::Quantized(handle) => handle, - _ => panic!("Expected quantized handle, got {}", handle.name()), - } + fn quantized_tensor(handles: TensorHandle) -> QuantizedTensor { + handle.handle } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { - HandleKind::Float(tensor) + tensor } fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { - HandleKind::Int(tensor) + tensor } fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { - HandleKind::Bool(tensor) + tensor } - fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind { - QuantizedKind { - tensor: HandleKind::Quantized(tensor), - // The quantized tensor primitive already encapsulates the required quantization - // parameters so we set the scale as an empty handle (unused). - scale: HandleKind::Empty, - offset: None, - } + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { + tensor } } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 4572f580b5..96c22d0898 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,10 +1,8 @@ use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState}; -use crate::tensor::{JitQuantizationParameters, QJitTensor}; use crate::{element::BoolElement, fusion::elemwise::builder::ElementWiseBuilder}; use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; -use burn_tensor::quantization::QuantizationScheme; -use burn_tensor::repr::{QuantizedKind, TensorHandle}; +use burn_tensor::repr::TensorHandle; use burn_tensor::DType; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; @@ -80,23 +78,9 @@ impl ReprBackend } fn quantized_tensor( - handles: QuantizedKind>, - scheme: QuantizationScheme, + handle: TensorHandle, ) -> burn_tensor::ops::QuantizedTensor { - let qtensor = handles.tensor.handle.into_tensor(handles.tensor.shape); - let scale = handles.scale.handle.into_tensor(handles.scale.shape); - let offset = handles.offset; - - let qparams = JitQuantizationParameters { - scale, - offset: offset.map(|h| h.handle.into_tensor(h.shape)), - }; - - QJitTensor { - qtensor, - scheme, - qparams, - } + handle.handle.into_tensor(handle.shape) } fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor) -> Self::Handle { @@ -111,17 +95,8 @@ impl ReprBackend tensor.into() } - fn quantized_tensor_handle( - tensor: burn_tensor::ops::QuantizedTensor, - ) -> QuantizedKind { - let qtensor: JitFusionHandle = tensor.qtensor.into(); - let scale: JitFusionHandle = tensor.qparams.scale.into(); - - QuantizedKind { - tensor: qtensor, - scale, - offset: tensor.qparams.offset.map(|offset| offset.into()), - } + fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor) -> Self::Handle { + tensor.into() } } diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index e0b87e8931..7fa141cf67 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -43,7 +43,8 @@ pub fn matmul( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ); + ) + .unwrap(); out } #[cfg(feature = "autotune")] diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 98951b8f89..ee17a61a3e 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -90,7 +90,8 @@ fn matmul_accelerated( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ); + ) + .unwrap(); } fn matmul_tiling2d( @@ -104,7 +105,8 @@ fn matmul_tiling2d( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ); + ) + .unwrap(); } fn matmul_simple( @@ -118,5 +120,6 @@ fn matmul_simple( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ); + ) + .unwrap(); } diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 4e2aa89cf7..6d53b2effb 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -1,14 +1,21 @@ -use crate::tensor::{JitTensor, QJitTensor}; +use crate::tensor::JitTensor; use crate::FloatElement; -use crate::{IntElement, JitElement, JitRuntime}; +use crate::{JitElement, JitRuntime}; use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; +use burn_tensor::DType; use cubecl::calculate_cube_count_elemwise; use cubecl::prelude::*; +use super::{QParams, QTensor}; + #[cube] -pub(crate) fn dequantize_affine_int8(value: i32, scale: F, offset: i32) -> F { +pub(crate) fn dequantize_affine_int8( + value: Line, + scale: f32, + offset: i32, +) -> Line { // x = scale * (x_q - offset) - scale * (F::cast_from(value) - F::cast_from(offset)) + Line::cast_from(scale) * Line::cast_from(value - Line::cast_from(offset)) } #[cube] @@ -21,199 +28,147 @@ pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 { i32::cast_from(value) - sub } +#[cube] +pub(crate) fn extract_i8s(value: u32) -> Line { + let mut line = Line::empty(4); + // Extract each 8-bit segment + line[0] = extract_i8(value, 24); + line[1] = extract_i8(value, 16); + line[2] = extract_i8(value, 8); + line[3] = extract_i8(value, 0); + + line +} + #[cube(launch_unchecked)] pub(crate) fn dequantize_per_tensor_affine_int8_kernel( - input: &Tensor, - scale: &Tensor, - offset: &Tensor, - output: &mut Tensor, - #[comptime] vectorized: bool, + input: &QTensor, + output: &mut Tensor>, + #[comptime] scheme: QuantizationScheme, ) { - if ABSOLUTE_POS >= output.len() { + // Last two positions contain the qparams + if ABSOLUTE_POS >= input.len() - 2 { return; } - let scale = scale[0]; - let offset = offset[0]; + let qparams = QParams::new(scheme); + let (scale, offset) = qparams.values(input); - let num_packed = 4; let value = input[ABSOLUTE_POS]; - let output_pos = ABSOLUTE_POS * num_packed; - if vectorized { - let vectorization_factor = vectorization_of(input); + // Input line size is fixed to 1 + if comptime!(output.line_size() == 4) { + output[ABSOLUTE_POS] = dequantize_affine_int8(extract_i8s(value[0]), scale, offset); + } else { + // For very small inputs where number of elements < 4, the output line size is 1 + let out = dequantize_affine_int8::(extract_i8s(value[0]), scale, offset); + #[unroll] - for i in 0..vectorization_factor { - // Extract each 8-bit segment - let v1 = extract_i8(value[i], 24); - let v2 = extract_i8(value[i], 16); - let v3 = extract_i8(value[i], 8); - let v4 = extract_i8(value[i], 0); - - output[output_pos * vectorization_factor + i * num_packed] = - dequantize_affine_int8::(v1, scale, offset); - output[output_pos * vectorization_factor + i * num_packed + 1] = - dequantize_affine_int8::(v2, scale, offset); - output[output_pos * vectorization_factor + i * num_packed + 2] = - dequantize_affine_int8::(v3, scale, offset); - output[output_pos * vectorization_factor + i * num_packed + 3] = - dequantize_affine_int8::(v4, scale, offset); + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); } - } else { - // Extract each 8-bit segment - let v1 = extract_i8(value, 24); - let v2 = extract_i8(value, 16); - let v3 = extract_i8(value, 8); - let v4 = extract_i8(value, 0); - - output[output_pos] = dequantize_affine_int8::(v1, scale, offset); - output[output_pos + 1] = dequantize_affine_int8::(v2, scale, offset); - output[output_pos + 2] = dequantize_affine_int8::(v3, scale, offset); - output[output_pos + 3] = dequantize_affine_int8::(v4, scale, offset); } } #[cube] -pub(crate) fn dequantize_symmetric_int8(value: i32, scale: F) -> F { +pub(crate) fn dequantize_symmetric_int8(value: Line, scale: f32) -> Line { // x = scale * x_q - scale * F::cast_from(value) + Line::cast_from(scale) * Line::cast_from(value) } // Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset. #[cube(launch_unchecked)] pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( - input: &Tensor, - scale: &Tensor, - output: &mut Tensor, - #[comptime] vectorized: bool, + input: &QTensor, + output: &mut Tensor>, + #[comptime] scheme: QuantizationScheme, ) { - if ABSOLUTE_POS >= output.len() { + // Last position contains the qparam + if ABSOLUTE_POS >= input.len() - 1 { return; } - let scale = scale[0]; + let qparams = QParams::new(scheme); + let (scale, _) = qparams.values(input); - let num_packed = 4; let value = input[ABSOLUTE_POS]; - let output_pos = ABSOLUTE_POS * num_packed; - if vectorized { - let vectorization_factor = vectorization_of(input); - #[unroll] - for i in 0..vectorization_factor { - for j in 0..num_packed { - let output_idx = output_pos * vectorization_factor + i * num_packed + j; - if output_idx >= output.len() { - return; // value not quantized (padding) - } - // Extract each 8-bit segment - let v = extract_i8(value[i], (3 - j) * 8); - output[output_idx] = dequantize_symmetric_int8::(v, scale); - } - } + // Input line size is fixed to 1 + if comptime!(output.line_size() == 4) { + output[ABSOLUTE_POS] = dequantize_symmetric_int8(extract_i8s(value[0]), scale); } else { - // Extract each 8-bit segment - for j in 0..num_packed { - let output_idx = output_pos + j; - if output_idx >= output.len() { - return; // value not quantized (padding) - } - // Extract each 8-bit segment - let v = extract_i8(value, (3 - j) * 8); - output[output_pos + j] = dequantize_symmetric_int8::(v, scale); + // For very small inputs where number of elements < 4, the output line size is 1 + let out = dequantize_symmetric_int8::(extract_i8s(value[0]), scale); + + #[unroll] + for j in 0..out.size() { + output[ABSOLUTE_POS + j] = Line::cast_from(out[j]); } } } -pub(crate) fn dequantize_per_tensor( - tensor: JitTensor, - scale: JitTensor, - offset: Option>, -) -> JitTensor +pub(crate) fn dequantize_per_tensor(tensor: JitTensor) -> JitTensor where R: JitRuntime, F: JitElement, - I: IntElement, { // The actual number of elements is 1/4 (four int8 values packed in a single u32) - // so we choose a vectorization factor to match a valid input binding size. - let ndims = tensor.shape.num_dims(); + // so we choose a line size to match a valid input binding size. let num_out_elems = tensor.shape.num_elements(); let num_elems = usize::div_ceil(num_out_elems, 4); - let vectorization_factor = [4u8, 2, 1] - .iter() - .filter_map(|&v| { - if num_elems >= v as usize { - Some(v) - } else { - None - } - }) - .next() - .unwrap(); + let line_size_in = 1; + let line_size_out = if num_out_elems < 4 { 1 } else { 4 }; let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size_in as usize, cube_dim); - let shape_output = tensor.shape.clone(); let client = tensor.client.clone(); let handle = client.empty(num_out_elems * core::mem::size_of::()); + let output = JitTensor::new_contiguous( client.clone(), tensor.device.clone(), - shape_output, + tensor.shape.clone(), handle, F::dtype(), ); - let dummy_array = vec![1; ndims]; - if let Some(offset) = offset { - unsafe { - dequantize_per_tensor_affine_int8_kernel::launch_unchecked::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - // Ignore shape and stride - TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), - TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), - output.as_tensor_arg::(1), - vectorization_factor > 1, - ) - }; - } else { - unsafe { - dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - // Ignore shape and stride - TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), - output.as_tensor_arg::(1), - vectorization_factor > 1, - ) - }; + if let DType::QFloat(scheme) = tensor.dtype { + match scheme { + QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { + unsafe { + dequantize_per_tensor_affine_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_array_arg::(line_size_in), + output.as_tensor_arg::(line_size_out), + scheme, + ) + }; + } + QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { + unsafe { + dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_array_arg::(line_size_in), + output.as_tensor_arg::(line_size_out), + scheme, + ) + }; + } + } } output } /// Convert the tensor back to a higher precision data type. -pub fn dequantize(tensor: QJitTensor) -> JitTensor +pub fn dequantize(tensor: JitTensor) -> JitTensor where R: JitRuntime, F: FloatElement, - I: IntElement, { - match tensor.scheme { - QuantizationScheme::PerTensorAffine(dtype) - | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => dequantize_per_tensor::( - tensor.qtensor, - tensor.qparams.scale, - tensor.qparams.offset, - ), - }, - } + dequantize_per_tensor::(tensor) } diff --git a/crates/burn-jit/src/kernel/quantization/mod.rs b/crates/burn-jit/src/kernel/quantization/mod.rs index a0244df01f..2b47bb61b6 100644 --- a/crates/burn-jit/src/kernel/quantization/mod.rs +++ b/crates/burn-jit/src/kernel/quantization/mod.rs @@ -1,5 +1,7 @@ mod dequantize; +mod qtensor; mod quantize; pub use dequantize::*; +pub use qtensor::*; pub use quantize::*; diff --git a/crates/burn-jit/src/kernel/quantization/qtensor.rs b/crates/burn-jit/src/kernel/quantization/qtensor.rs new file mode 100644 index 0000000000..26d9f65091 --- /dev/null +++ b/crates/burn-jit/src/kernel/quantization/qtensor.rs @@ -0,0 +1,49 @@ +#![allow(missing_docs)] // cube derive macros + +use burn_tensor::quantization::QuantizationScheme; +use cubecl::prelude::*; + +/// Quantization parameters. +#[derive(CubeLaunch)] +pub struct QParams { + #[cube(comptime)] + scheme: QuantizationScheme, +} + +/// Quantized tensor representation. +pub type QTensor = Array>; + +#[cube] +impl QParams { + /// Create a new quantization parameters instance. + pub fn new(scheme: QuantizationScheme) -> Self { + QParams { scheme } + } + + /// Get the quantization parameters values. + pub fn values(&self, tensor: &QTensor) -> (f32, i32) { + let len = tensor.len(); + match comptime!(self.scheme) { + QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) { + // For line size of 1, scale is the last value in the buffer + 1 => ( + f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), + i32::cast_from(tensor[len - 2][tensor.line_size() - 1]), + ), + // For any other line size > 1, scale and zero-point offset are the last two elements + _ => { + let values = tensor[len - 1]; + ( + f32::bitcast_from(values[tensor.line_size() - 1]), + i32::cast_from(values[tensor.line_size() - 2]), + ) + } + }, + // Symmetric quantization only contains the scaling factor as the last element + QuantizationScheme::PerTensorSymmetric(_) => ( + f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]), + 0, + ), + } + } +} diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index 256ae0f418..13b9e3284a 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -1,4 +1,4 @@ -use crate::tensor::{JitQuantizationParameters, JitTensor, QJitTensor}; +use crate::tensor::JitTensor; use crate::FloatElement; use crate::{IntElement, JitElement, JitRuntime}; use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; @@ -7,34 +7,31 @@ use cubecl::prelude::*; #[cube] pub(crate) fn quantize_affine_int8( - value: F, - scale: F, + value: Line, + scale: f32, offset: i32, - range_min: F, - range_max: F, -) -> u32 { - let offset = F::cast_from(offset); - + range_min: f32, + range_max: f32, +) -> Line { // x_q = clamp(round(x / scale + offset), a, b) // NOTE: we add 256 before casting to unsigned to correctly represent negative values - u32::cast_from( - i32::cast_from(F::clamp( - F::round((value / scale) + offset), - range_min, - range_max, - )) + 256, + Line::cast_from( + Line::clamp( + Line::round((value / Line::cast_from(scale)) + Line::cast_from(offset)), + Line::cast_from(range_min), + Line::cast_from(range_max), + ) + Line::cast_from(comptime!(256f32)), ) } #[cube(launch_unchecked)] pub(crate) fn quantize_per_tensor_affine_int8_kernel( - input: &Tensor, + input: &Tensor>, scale: &Tensor, offset: &Tensor, range_min: f32, range_max: f32, - output: &mut Tensor, - #[comptime] vectorized: bool, + output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { return; @@ -43,20 +40,29 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( let scale = scale[0]; let offset = offset[0]; - let num_packed = 4; - let mut v_packed = 0; + // Cast the scale to u32 and write the value in the output + if ABSOLUTE_POS == output.len() - 1 { + output[ABSOLUTE_POS] = u32::bitcast_from(scale); + return; + } - if vectorized { - // Assuming a vectorization factor of 4 (equal to the number of values packed) - let value = input[ABSOLUTE_POS]; - let vectorization_factor = vectorization_of(input); - #[unroll] - for i in 0..vectorization_factor { - let v = quantize_affine_int8::(value[i], scale, offset, range_min, range_max); - // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); - } + // Cast the offset to u32 and write the value in the output + if ABSOLUTE_POS == output.len() - 2 { + output[ABSOLUTE_POS] = u32::bitcast_from(offset); + return; + } + + let line_size = comptime!(input.line_size()); + if comptime!(line_size == 4) { + // Assuming a line size of 4 (equal to the number of values packed) + let value = + quantize_affine_int8::(input[ABSOLUTE_POS], scale, offset, range_min, range_max); + // Shift and combine into u32 + output[ABSOLUTE_POS] = pack_i8s_to_u32s(value); } else { + let mut v_packed = 0; + let num_packed = comptime!(4); + #[unroll] for i in 0..num_packed { let v = quantize_affine_int8::( input[ABSOLUTE_POS + i], @@ -66,34 +72,52 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( range_max, ); // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + v_packed |= (v[0] & 0xFF) << (8 * (num_packed - i - 1)); } + output[ABSOLUTE_POS] = v_packed; } - - output[ABSOLUTE_POS] = v_packed; } #[cube] pub(crate) fn quantize_symmetric_int8( - value: F, - scale: F, + value: Line, + scale: f32, range_min: F, range_max: F, -) -> u32 { +) -> Line { // x_q = clamp(round(x / scale), a, b) // NOTE: we add 256 before casting to unsigned to correctly represent negative values - u32::cast_from(i32::cast_from(F::clamp(F::round(value / scale), range_min, range_max)) + 256) + Line::cast_from( + Line::clamp( + Line::round(value / Line::cast_from(scale)), + Line::new(range_min), + Line::new(range_max), + ) + Line::cast_from(comptime!(256f32)), + ) +} + +#[cube] +pub(crate) fn pack_i8s_to_u32s(value: Line) -> u32 { + // NOTE: assuming line size of 4 + let line_size = value.size(); + let mut v_packed = 0; + + #[unroll] + for i in 0..line_size { + // Shift and combine into u32 + v_packed |= (value[i] & 0xFF) << (8 * (line_size - i - 1)); + } + v_packed } // Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset. #[cube(launch_unchecked)] pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( - input: &Tensor, + input: &Tensor>, scale: &Tensor, range_min: f32, range_max: f32, - output: &mut Tensor, - #[comptime] vectorized: bool, + output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { return; @@ -101,20 +125,23 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( let scale = scale[0]; - let num_packed = 4; - let mut v_packed = 0; + // Cast the scale to u32 and write the value in the output + if ABSOLUTE_POS == output.len() - 1 { + output[ABSOLUTE_POS] = u32::bitcast_from(scale); + return; + } - if vectorized { + let line_size = comptime!(input.line_size()); + if comptime!(line_size == 4) { // Assuming a vectorization factor of 4 (equal to the number of values packed) - let value = input[ABSOLUTE_POS]; - let vectorization_factor = vectorization_of(input); - #[unroll] - for i in 0..vectorization_factor { - let v = quantize_symmetric_int8::(value[i], scale, range_min, range_max); - // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); - } + let value = + quantize_symmetric_int8::(input[ABSOLUTE_POS], scale, range_min, range_max); + // Shift and combine into u32 + output[ABSOLUTE_POS] = pack_i8s_to_u32s(value); } else { + let num_packed = comptime!(4); + let mut v_packed = 0; + #[unroll] for i in 0..num_packed { let v = quantize_symmetric_int8::( input[ABSOLUTE_POS + i], @@ -123,17 +150,17 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( range_max, ); // Shift and combine into u32 - v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + v_packed |= (v[0] & 0xFF) << (8 * (num_packed - i - 1)); } + output[ABSOLUTE_POS] = v_packed; } - - output[ABSOLUTE_POS] = v_packed; } pub(crate) fn quantize_per_tensor( tensor: JitTensor, scale: JitTensor, offset: Option>, + scheme: QuantizationScheme, ) -> JitTensor where R: JitRuntime, @@ -142,86 +169,90 @@ where { let ndims = tensor.shape.num_dims(); let num_elems = tensor.shape.num_elements(); - let shape_output = tensor.shape.clone(); let client = tensor.client.clone(); // Output tensor contains 4x less elements (four int8 values packed in a single u32) - let handle = client.empty(usize::div_ceil(num_elems, 4) * core::mem::size_of::()); - let output = JitTensor::new_contiguous( - client.clone(), - tensor.device.clone(), - shape_output, - handle, - burn_tensor::DType::U32, - ); + let output_num_elems = usize::div_ceil(num_elems, 4) * core::mem::size_of::(); // Force vectorization to process 4 quantized values packed for 1 output value - let vectorization_factor: u8 = if num_elems < 4 { 1 } else { 4 }; + let line_size: u8 = if num_elems < 4 { 1 } else { 4 }; let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); let dummy_array = vec![1; ndims]; if let Some(offset) = offset { + // Scale and offset qparams are also packed in the tensor dat + let handle = client + .empty(output_num_elems + core::mem::size_of::() + core::mem::size_of::()); + let output = JitTensor::new_contiguous( + client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + handle, + burn_tensor::DType::QFloat(scheme), + ); + unsafe { quantize_per_tensor_affine_int8_kernel::launch_unchecked::( &client, cube_count, cube_dim, - tensor.as_tensor_arg::(vectorization_factor), + tensor.as_tensor_arg::(line_size), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(i8::MIN as f32), ScalarArg::new(i8::MAX as f32), - output.as_tensor_arg::(1), - vectorization_factor > 1, + output.as_array_arg::(1), ) }; + output } else { + // Scale qparam is also packed in the tensor data + let handle = client.empty(output_num_elems + core::mem::size_of::()); + let output = JitTensor::new_contiguous( + client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + handle, + burn_tensor::DType::QFloat(scheme), + ); + unsafe { quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( &client, cube_count, cube_dim, - tensor.as_tensor_arg::(vectorization_factor), + tensor.as_tensor_arg::(line_size), // Ignore shape and stride TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(-i8::MAX as f32), ScalarArg::new(i8::MAX as f32), - output.as_tensor_arg::(1), - vectorization_factor > 1, + output.as_array_arg::(1), ) }; - } - output + output + } } /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. pub fn quantize( tensor: JitTensor, scheme: &QuantizationScheme, - qparams: JitQuantizationParameters, -) -> QJitTensor + scale: JitTensor, + offset: Option>, +) -> JitTensor where R: JitRuntime, F: FloatElement, I: IntElement, { - let qtensor = match scheme { + match scheme { QuantizationScheme::PerTensorAffine(dtype) | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => quantize_per_tensor::( - tensor, - qparams.scale.clone(), - qparams.offset.clone(), - ), + QuantizationType::QInt8 => { + quantize_per_tensor::(tensor, scale, offset, *scheme) + } }, - }; - - QJitTensor { - qtensor, - scheme: *scheme, - qparams, } } diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index 94b1a6f2ee..cde44b0552 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -2,30 +2,32 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{ - QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType, - }, + quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType}, DType, Device, Shape, TensorData, }; use crate::{ - element::BoolElement, - kernel, - tensor::{JitQuantizationParameters, JitTensor, QJitTensor}, - FloatElement, IntElement, JitBackend, JitRuntime, + element::BoolElement, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, + JitRuntime, }; -use cubecl::CubeElement; /// Create a quantized tensor with packed values (u32). -fn packed_tensor>( +fn new_qtensor>( data: &[u8], shape: S, + scheme: QuantizationScheme, device: &R::Device, ) -> JitTensor { let client = R::client(device); let buffer = client.create(data); - JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer, DType::U32) + JitTensor::new_contiguous( + client, + device.clone(), + shape.into(), + buffer, + DType::QFloat(scheme), + ) } impl QTensorOps for JitBackend @@ -40,17 +42,9 @@ where DType::QFloat(scheme) => match scheme { QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => { - // Convert quantized values to packed u32s - let qparams = data.get_q_params::().unwrap(); - QJitTensor { - qtensor: packed_tensor(data.values_as_bytes(), data.shape.clone(), device), - scheme, - qparams: JitQuantizationParameters::new( - qparams.scale, - qparams.offset, - device, - ), - } + // TensorData quantized representation is the same, with multiple quantized values + // packed into u32 and quantization parameters appended to the bytes + new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device) } }, _ => panic!( @@ -65,56 +59,37 @@ where scheme: &QuantizationScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { - kernel::quantization::quantize::(tensor, scheme, qparams.into()) + kernel::quantization::quantize::(tensor, scheme, qparams.scale, qparams.offset) } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { - kernel::quantization::dequantize::(tensor) - } - - fn q_shape(tensor: &QuantizedTensor) -> Shape { - tensor.qtensor.shape.clone() + kernel::quantization::dequantize::(tensor) } fn q_device(tensor: &QuantizedTensor) -> Device { - tensor.qtensor.device.clone() + tensor.device.clone() } fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { - let mut tensor = tensor; - tensor.qtensor = super::to_device(tensor.qtensor, device); - tensor.qparams.scale = super::to_device(tensor.qparams.scale, device); - tensor.qparams.offset = tensor.qparams.offset.map(|x| super::to_device(x, device)); - - tensor + super::to_device(tensor, device) } fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { - QJitTensor { - qtensor: super::reshape(tensor.qtensor, shape), - scheme: tensor.scheme, - qparams: tensor.qparams, - } + super::reshape(tensor, shape) } async fn q_into_data(tensor: QuantizedTensor) -> TensorData { - let strategy = tensor.strategy(); - let qtensor = kernel::into_contiguous(tensor.qtensor); - - let bytes = qtensor - .client - .read_one_async(qtensor.handle.binding()) - .await; - - // TensorData keeps quantized values packed into 32-bit unsigned integers so we can - // keep the current representation, just cast the bytes as u32. - match &tensor.scheme { - QuantizationScheme::PerTensorAffine(dtype) - | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - TensorData::quantized(u32::from_bytes(&bytes).to_vec(), qtensor.shape, strategy) - } - }, + let tensor = kernel::into_contiguous(tensor); + let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; + + // TODO: this should be refactored such that the bytes type is opaque. + // With this, the logic for handling the bytes representation of quantized data + // (as well as all data manipulations) will be encapsulated in the type. + // Creating a TensorData struct directly from some bytes should probably not be possible outside of the crate. + TensorData { + bytes, + shape: tensor.shape.into(), + dtype: tensor.dtype, } } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 3eb44b3e02..e114b2f8e6 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -1,6 +1,7 @@ use crate::element::JitElement; use crate::kernel::{launch_unary, unary_op, UnaryOp}; use crate::JitRuntime; +use burn_tensor::quantization::QTensorPrimitive; use burn_tensor::{DType, Shape, TensorMetadata}; use cubecl::client::ComputeClient; use cubecl::frontend::Numeric; @@ -73,6 +74,19 @@ impl TensorMetadata for JitTensor { } } +impl QTensorPrimitive for JitTensor { + fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme { + if let DType::QFloat(scheme) = &self.dtype { + scheme + } else { + panic!( + "Quantization scheme is not valid for dtype {:?}", + self.dtype, + ) + } + } +} + /// Macro to execute a kernel/operation for a given element type. /// /// # Panics @@ -241,7 +255,16 @@ where strides: &self.strides, shape: &self.shape.dims, runtime: PhantomData, - elem_size: self.dtype.size(), + elem_size: self.elem_size(), + } + } + + fn elem_size(&self) -> usize { + if let DType::QFloat(_) = self.dtype { + // Encoded as u32 + core::mem::size_of::() + } else { + self.dtype.size() } } @@ -259,6 +282,17 @@ where } } + /// Return the reference to an array argument. + pub fn as_array_arg(&self, vectorisation: u8) -> ArrayArg<'_, R> { + unsafe { + ArrayArg::from_raw_parts::( + &self.handle, + self.handle.size() as usize / core::mem::size_of::(), + vectorisation, + ) + } + } + pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool { if !self.handle.can_mut() || !self.is_contiguous_buffer() { return false; diff --git a/crates/burn-jit/src/tensor/mod.rs b/crates/burn-jit/src/tensor/mod.rs index 960a77e445..cbcb6ac7e7 100644 --- a/crates/burn-jit/src/tensor/mod.rs +++ b/crates/burn-jit/src/tensor/mod.rs @@ -1,5 +1,3 @@ mod base; -mod qtensor; pub use base::*; -pub(crate) use qtensor::*; diff --git a/crates/burn-jit/src/tensor/qtensor.rs b/crates/burn-jit/src/tensor/qtensor.rs deleted file mode 100644 index 4ef5f77589..0000000000 --- a/crates/burn-jit/src/tensor/qtensor.rs +++ /dev/null @@ -1,125 +0,0 @@ -use burn_tensor::{ - quantization::{ - AffineQuantization, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationStrategy, QuantizationType, SymmetricQuantization, - }, - read_sync, DType, TensorData, TensorMetadata, -}; - -use crate::{ - element::BoolElement, ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime, -}; - -use super::JitTensor; - -/// A quantized tensor primitive. -#[derive(Debug)] -pub struct QJitTensor { - /// The quantized tensor. - /// Values are stored as multiple packed quantized values in u32. - pub qtensor: JitTensor, - /// The quantization scheme. - pub scheme: QuantizationScheme, - /// The quantization parameters. - pub qparams: JitQuantizationParameters, -} - -impl QTensorPrimitive for QJitTensor { - fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { - match &self.scheme { - QuantizationScheme::PerTensorAffine(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = read_sync(into_data::(self.qparams.scale.clone())) - .iter() - .next() - .unwrap(); - let offset = - read_sync(into_data::(self.qparams.offset.clone().unwrap())) - .iter() - .next() - .unwrap(); - QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( - scale, offset, - )) - } - }, - QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { - QuantizationType::QInt8 => { - let scale = read_sync(into_data::(self.qparams.scale.clone())) - .iter() - .next() - .unwrap(); - QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale)) - } - }, - } - } -} - -impl Clone for QJitTensor { - fn clone(&self) -> Self { - Self { - qtensor: self.qtensor.clone(), - scheme: self.scheme, - qparams: self.qparams.clone(), - } - } -} - -impl TensorMetadata for QJitTensor { - fn dtype(&self) -> DType { - DType::QFloat(self.scheme) - } - - fn shape(&self) -> burn_tensor::Shape { - self.qtensor.shape() - } -} - -/// The quantization parameters. -#[derive(Debug)] -pub struct JitQuantizationParameters { - /// The scaling factor. - pub scale: JitTensor, - /// The zero-point offset. - pub offset: Option>, -} - -impl Clone for JitQuantizationParameters { - fn clone(&self) -> Self { - Self { - scale: self.scale.clone(), - offset: self.offset.clone(), - } - } -} - -impl - From>> - for JitQuantizationParameters -{ - fn from(value: QuantizationParametersPrimitive>) -> Self { - JitQuantizationParameters { - scale: value.scale, - offset: value.offset, - } - } -} - -impl JitQuantizationParameters { - pub fn new( - scale: F, - offset: Option, - device: &R::Device, - ) -> Self { - Self { - scale: crate::ops::from_data::(TensorData::new(vec![scale], [1]), device), - offset: offset - .map(|o| crate::ops::from_data::(TensorData::new(vec![o], [1]), device)), - } - } -} diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 060899b979..6ce29e79cf 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -4,8 +4,7 @@ use alloc::string::String; use burn_common::stub::Mutex; use burn_tensor::backend::{Backend, DeviceId, DeviceOps}; use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; -use burn_tensor::quantization::QuantizationScheme; -use burn_tensor::repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle}; +use burn_tensor::repr::{HandleKind, ReprBackend, TensorHandle}; use core::marker::PhantomData; use rand::{rngs::StdRng, SeedableRng}; @@ -99,14 +98,10 @@ impl ReprBackend } } - fn quantized_tensor( - handles: QuantizedKind>, - _scheme: QuantizationScheme, - ) -> QuantizedTensor { - let handle = handles.tensor.handle; - match handle { + fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { + match handle.handle { HandleKind::Quantized(handle) => handle, - _ => panic!("Expected quantized handle, got {}", handle.name()), + _ => panic!("Expected quantized handle, got {}", handle.handle.name()), } } @@ -122,13 +117,7 @@ impl ReprBackend HandleKind::Bool(tensor) } - fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind { - QuantizedKind { - tensor: HandleKind::Quantized(tensor), - // The quantized tensor primitive already encapsulates the required quantization - // parameters so we set the scale as an empty handle (unused). - scale: HandleKind::Empty, - offset: None, - } + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { + HandleKind::Quantized(tensor) } } diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index d610d39804..9f522b6155 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -3,8 +3,8 @@ use core::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - AffineQuantization, QParams, QTensorPrimitive, QuantizationParametersPrimitive, - QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization, + AffineQuantization, QParams, QuantizationParametersPrimitive, QuantizationScheme, + QuantizationStrategy, QuantizationType, SymmetricQuantization, }, DType, Shape, TensorData, TensorMetadata, }; @@ -108,10 +108,6 @@ impl QTensorOps) -> Shape { - tensor.qtensor.shape() - } - fn q_device(_tensor: &QuantizedTensor) -> NdArrayDevice { NdArrayDevice::Cpu } diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 64e8037c91..698119ecea 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -352,12 +352,9 @@ pub struct NdArrayQTensor { pub qparams: QParams, } -impl QTensorPrimitive for NdArrayQTensor { - fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { +impl NdArrayQTensor { + /// Returns the quantization strategy, including quantization parameters, for the given tensor. + pub fn strategy(&self) -> QuantizationStrategy { match self.scheme { QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => { QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( @@ -374,6 +371,12 @@ impl QTensorPrimitive for NdArrayQTensor { } } +impl QTensorPrimitive for NdArrayQTensor { + fn scheme(&self) -> &QuantizationScheme { + &self.scheme + } +} + impl TensorMetadata for NdArrayQTensor { fn dtype(&self) -> DType { DType::QFloat(self.scheme) diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index a5ada5e5fd..d809bbb694 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -3,7 +3,7 @@ use core::marker::PhantomData; use burn_tensor::{ backend::Backend, - quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy}, + quantization::{QTensorPrimitive, QuantizationScheme}, }; use super::{get_client, set_seed, RouterTensor, RunnerChannel, RunnerClient}; @@ -36,10 +36,6 @@ impl QTensorPrimitive for RouterTensor { fn scheme(&self) -> &QuantizationScheme { todo!() } - - fn strategy(&self) -> QuantizationStrategy { - todo!() - } } impl Backend for BackendRouter { diff --git a/crates/burn-router/src/ops/op_qfloat.rs b/crates/burn-router/src/ops/op_qfloat.rs index 1f4784aceb..2ac092505c 100644 --- a/crates/burn-router/src/ops/op_qfloat.rs +++ b/crates/burn-router/src/ops/op_qfloat.rs @@ -32,10 +32,6 @@ impl QTensorOps for BackendRouter { unimplemented!() } - fn q_shape(_tensor: &QuantizedTensor) -> Shape { - unimplemented!() - } - fn q_device(_tensor: &QuantizedTensor) -> Device { unimplemented!() } diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index ec644d066c..0a7ef1c79b 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -3,8 +3,7 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - QParams, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationType, + QParams, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType, }, DType, Shape, TensorData, TensorMetadata, }; @@ -132,10 +131,6 @@ impl QTensorOps for LibTorch { TchTensor::new(tensor.qtensor.tensor.dequantize().to_kind(E::KIND)) } - fn q_shape(tensor: &QuantizedTensor) -> Shape { - tensor.qtensor.shape() - } - fn q_device(tensor: &QuantizedTensor) -> LibTorchDevice { tensor.qtensor.tensor.device().into() } @@ -157,7 +152,7 @@ impl QTensorOps for LibTorch { } async fn q_into_data(tensor: QuantizedTensor) -> TensorData { - let shape = Self::q_shape(&tensor); + let shape = tensor.shape(); let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let strategy = tensor.strategy(); diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index b634954c8e..1c9dbf594e 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -328,22 +328,9 @@ pub struct TchQTensor { pub scheme: QuantizationScheme, } -impl TensorMetadata for TchQTensor { - fn dtype(&self) -> DType { - DType::QFloat(self.scheme) - } - - fn shape(&self) -> Shape { - self.qtensor.shape() - } -} - -impl QTensorPrimitive for TchQTensor { - fn scheme(&self) -> &QuantizationScheme { - &self.scheme - } - - fn strategy(&self) -> QuantizationStrategy { +impl TchQTensor { + /// Returns the quantization strategy, including quantization parameters, for the given tensor. + pub fn strategy(&self) -> QuantizationStrategy { match &self.scheme { QuantizationScheme::PerTensorAffine(dtype) => match dtype { QuantizationType::QInt8 => { @@ -367,6 +354,22 @@ impl QTensorPrimitive for TchQTensor { } } +impl TensorMetadata for TchQTensor { + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) + } + + fn shape(&self) -> Shape { + self.qtensor.shape() + } +} + +impl QTensorPrimitive for TchQTensor { + fn scheme(&self) -> &QuantizationScheme { + &self.scheme + } +} + #[cfg(test)] mod tests { use crate::LibTorch; diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs index dad341697b..696da20c49 100644 --- a/crates/burn-tensor/src/repr/backend.rs +++ b/crates/burn-tensor/src/repr/backend.rs @@ -1,7 +1,6 @@ use crate::{ backend::Backend, ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, - quantization::QuantizationScheme, Shape, }; @@ -14,17 +13,6 @@ pub struct TensorHandle { pub shape: Shape, } -/// A simple struct to encapsulate a quantized tensor kind. -#[derive(Clone)] -pub struct QuantizedKind { - /// The quantized tensor. - pub tensor: T, - /// The scaling factor. - pub scale: T, - /// The zero-point offset. - pub offset: Option, -} - /// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation /// for compilation purpose or other... pub trait ReprBackend: Backend { @@ -38,10 +26,7 @@ pub trait ReprBackend: Backend { /// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). fn bool_tensor(handle: TensorHandle) -> BoolTensor; /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive). - fn quantized_tensor( - handle: QuantizedKind>, - scheme: QuantizationScheme, - ) -> QuantizedTensor; + fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor; /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle). fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle; @@ -50,8 +35,7 @@ pub trait ReprBackend: Backend { /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle). fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle). - /// A quantized tensor has multiple handles for the tensor itself and the quantization parameters. - fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind; + fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle; } /// Handle which points to a backend tensor primitive kind. @@ -65,8 +49,6 @@ pub enum HandleKind { Bool(B::BoolTensorPrimitive), /// Quantized tensor handle. Quantized(B::QuantizedTensorPrimitive), - /// Empty handle (used as a dummy representation). - Empty, } impl HandleKind { @@ -77,7 +59,6 @@ impl HandleKind { HandleKind::Int(_) => "int", HandleKind::Bool(_) => "bool", HandleKind::Quantized(_) => "quantized", - HandleKind::Empty => unreachable!(), // should not happen } } } diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 3da083c0e6..85e18ec444 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -14,7 +14,7 @@ use alloc::sync::Arc; #[cfg(not(target_has_atomic = "ptr"))] use portable_atomic_util::Arc; -use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle}; +use super::TensorHandle; /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources /// are used optimally. @@ -124,21 +124,12 @@ impl HandleContainer { /// given [tensor description](TensorDescription). pub fn get_quantized_tensor( &mut self, - tensor: &QuantizedTensorDescription, + tensor: &TensorDescription, ) -> B::QuantizedTensorPrimitive where B: ReprBackend, { - let handles = QuantizedKind { - tensor: self.get_tensor_handle(&tensor.tensor), - scale: self.get_tensor_handle(&tensor.qparams.scale), - offset: tensor - .qparams - .offset - .as_ref() - .map(|offset| self.get_tensor_handle(offset)), - }; - B::quantized_tensor(handles, tensor.scheme) + B::quantized_tensor(self.get_tensor_handle(tensor)) } /// Register a new [float tensor](crate::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). @@ -153,21 +144,13 @@ impl HandleContainer { /// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId). pub fn register_quantized_tensor( &mut self, - id: &QuantizedKind, + id: &TensorId, tensor: B::QuantizedTensorPrimitive, ) where B: ReprBackend, { - let handles = B::quantized_tensor_handle(tensor); - - self.handles - .insert(id.tensor, Handle::Existing(handles.tensor)); - self.handles - .insert(id.scale, Handle::Existing(handles.scale)); - - if let (Some(id), Some(handle)) = (id.offset, handles.offset) { - self.handles.insert(id, Handle::Existing(handle)); - } + let handle = B::quantized_tensor_handle(tensor); + self.handles.insert(*id, Handle::Existing(handle)); } /// Register a new [int tensor](crate::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). diff --git a/crates/burn-tensor/src/repr/mod.rs b/crates/burn-tensor/src/repr/mod.rs index e26565b353..b98e43ba3d 100644 --- a/crates/burn-tensor/src/repr/mod.rs +++ b/crates/burn-tensor/src/repr/mod.rs @@ -1,11 +1,9 @@ mod backend; mod handle; mod operation; -mod quantization; mod tensor; pub use backend::*; pub use handle::*; pub use operation::*; -pub use quantization::*; pub use tensor::*; diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 2528f44305..001b9d6e83 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -15,8 +15,6 @@ use crate::{ DType, Distribution, Element, }; -use super::{QuantizationParametersDescription, QuantizedTensorDescription}; - /// Custom operation in fusion stream, declaring it's inputs and outputs. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct CustomOpDescription { @@ -906,6 +904,15 @@ pub struct ConvTranspose3dOptionsDescription { pub groups: usize, } +/// Quantization parameters description. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct QuantizationParametersDescription { + /// The scaling factor. + pub scale: TensorDescription, + /// The zero-point offset. + pub offset: Option, +} + #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct QuantizeOperationDescription { @@ -918,7 +925,7 @@ pub struct QuantizeOperationDescription { #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct DequantizeOperationDescription { - pub qtensor: QuantizedTensorDescription, + pub input: TensorDescription, pub out: TensorDescription, } @@ -1528,18 +1535,7 @@ impl FloatOperationDescription { vec![&desc.tensor, &desc.qparams.scale, &desc.out] } } - FloatOperationDescription::Dequantize(desc) => { - if let Some(offset) = &desc.qtensor.qparams.offset { - vec![ - &desc.qtensor.tensor, - &desc.qtensor.qparams.scale, - &offset, - &desc.out, - ] - } else { - vec![&desc.qtensor.tensor, &desc.qtensor.qparams.scale, &desc.out] - } - } + FloatOperationDescription::Dequantize(desc) => vec![&desc.input, &desc.out], } } } diff --git a/crates/burn-tensor/src/repr/quantization.rs b/crates/burn-tensor/src/repr/quantization.rs deleted file mode 100644 index 8a38a35a44..0000000000 --- a/crates/burn-tensor/src/repr/quantization.rs +++ /dev/null @@ -1,25 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::quantization::QuantizationScheme; - -use super::TensorDescription; - -/// A quantized tensor description represents a snapshot of a quantized tensor when it was used. -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct QuantizedTensorDescription { - /// The quantized tensor. - pub tensor: TensorDescription, - /// The quantization parameters. - pub qparams: QuantizationParametersDescription, - /// The quantization scheme - pub scheme: QuantizationScheme, -} - -/// Quantization parameters description. -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct QuantizationParametersDescription { - /// The scaling factor. - pub scale: TensorDescription, - /// The zero-point offset. - pub offset: Option, -} diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index d075d8e409..d52a181929 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -93,6 +93,8 @@ impl TensorData { shape: S, strategy: QuantizationStrategy, ) -> Self { + // TODO: this method should go into a dedicated Bytes opaque type with other bytes + // handling logic let mut value = into_bytes(value); // Notes on quantization data representation: @@ -111,8 +113,10 @@ impl TensorData { } else { panic!("Invalid quantized type"); } + // Scale is always stored as f32 and zero-point offset as i32 + let offset = q.offset as i32; let scale_bytes = bytemuck::bytes_of(&q.scale); - let offset_bytes = bytemuck::bytes_of(&q.offset); + let offset_bytes = bytemuck::bytes_of(&offset); value.extend_from_slice(offset_bytes); value.extend_from_slice(scale_bytes); } @@ -446,7 +450,7 @@ impl TensorData { let mut tensor_bytes_end = self.bytes.len() - scale_size; if let QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) = scheme { - tensor_bytes_end -= core::mem::size_of::(); + tensor_bytes_end -= core::mem::size_of::(); // zero-point offset is stored as i32 } &self.bytes[..tensor_bytes_end] @@ -483,17 +487,17 @@ impl TensorData { // Quantization parameters are added at the end of the tensor data. // As such, the last bytes always correspond to the scale parameter. // If the quantization scheme includes an offset (zero-point) parameter, it is next to last. - let scale_size = core::mem::size_of::(); + let scale_size = core::mem::size_of::(); // scale is stored as f32 let scale_bytes = &self.bytes[total_bytes - scale_size..]; - let scale = read_unaligned(scale_bytes); + let scale = read_unaligned::(scale_bytes).elem(); let mut offset = None; if let QuantizationScheme::PerTensorAffine(_) = scheme { - let offset_size = core::mem::size_of::(); + let offset_size = core::mem::size_of::(); // zero-point offset is stored as i32 let offset_bytes = &self.bytes[total_bytes - scale_size - offset_size..total_bytes - scale_size]; - offset = Some(read_unaligned(offset_bytes)) + offset = Some(read_unaligned::(offset_bytes).elem()) } Some(QParams { scale, offset }) diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 4b9df13a49..781ed7c6eb 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -4,7 +4,7 @@ use core::{future::Future, ops::Range}; use crate::{ backend::Backend, quantization::{QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme}, - Device, Shape, TensorData, + Device, Shape, TensorData, TensorMetadata, }; use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor}; @@ -74,17 +74,6 @@ pub trait QTensorOps { /// Convert the tensor back to a higher precision data type. fn dequantize(tensor: QuantizedTensor) -> FloatTensor; - /// Gets the shape of the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The shape of the tensor. - fn q_shape(tensor: &QuantizedTensor) -> Shape; - /// Gets the device of the tensor. /// /// # Arguments @@ -460,7 +449,7 @@ pub trait QTensorOps { /// /// The transposed tensor. fn q_transpose(tensor: QuantizedTensor) -> QuantizedTensor { - let ndims = Self::q_shape(&tensor).num_dims(); + let ndims = tensor.shape().num_dims(); Self::q_swap_dims(tensor, ndims - 2, ndims - 1) } @@ -1071,7 +1060,7 @@ pub trait QTensorOps { /// /// A tensor with the maximum element of `tensor`. fn q_max(tensor: QuantizedTensor) -> QuantizedTensor { - let shape = B::q_shape(&tensor); + let shape = tensor.shape(); let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); B::q_max_dim(tensor, 0) @@ -1123,7 +1112,7 @@ pub trait QTensorOps { /// /// A tensor with the minimum element of `tensor`. fn q_min(tensor: QuantizedTensor) -> QuantizedTensor { - let shape = B::q_shape(&tensor); + let shape = tensor.shape(); let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); B::q_min_dim(tensor, 0) diff --git a/crates/burn-tensor/src/tensor/quantization/primitive.rs b/crates/burn-tensor/src/tensor/quantization/primitive.rs index acc5ff692d..a4f1fba1c8 100644 --- a/crates/burn-tensor/src/tensor/quantization/primitive.rs +++ b/crates/burn-tensor/src/tensor/quantization/primitive.rs @@ -1,13 +1,7 @@ -use super::{QuantizationScheme, QuantizationStrategy}; +use super::QuantizationScheme; /// Quantized tensor primitive. pub trait QTensorPrimitive { /// Returns the quantization scheme for the given tensor. fn scheme(&self) -> &QuantizationScheme; - /// Returns the quantization strategy for the given tensor. - /// - /// # Remarks - /// Retrieving the quantization strategy with its corresponding parameters might require - /// synchronization on the backend. - fn strategy(&self) -> QuantizationStrategy; } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index 3534a01c32..fb141ee16d 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -1,11 +1,17 @@ +#![allow(missing_docs)] // cube derive macros + use serde::{Deserialize, Serialize}; use crate::{backend::Backend, Tensor, TensorPrimitive}; use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive}; +#[cfg(feature = "cubecl")] +use cubecl::prelude::*; + /// Quantization data type. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))] pub enum QuantizationType { /// 8-bit signed integer. QInt8, @@ -13,6 +19,7 @@ pub enum QuantizationType { /// Quantization scheme. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))] pub enum QuantizationScheme { /// Per-tensor affine/asymmetric quantization. PerTensorAffine(QuantizationType), @@ -24,6 +31,17 @@ pub enum QuantizationScheme { // PerChannelSymmetric, } +#[cfg(feature = "cubecl")] +impl CubeType for QuantizationScheme { + type ExpandType = Self; +} +#[cfg(feature = "cubecl")] +impl cubecl::frontend::Init for QuantizationScheme { + fn init(self, _context: &mut CubeContext) -> Self { + self + } +} + impl QuantizationScheme { /// Compute the quantization parameters. pub fn compute_q_params( diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs index 72834b8c8a..a120c2802a 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs @@ -18,7 +18,7 @@ mod tests { offset: Some(Tensor::from_ints([72], &device)), }; - let x_q = tensor.quantize(&scheme, qparams); + let x_q = tensor.quantize(&scheme, qparams).into_data(); let expected = TensorData::quantized( vec![-128i8, -39, 72, 127], @@ -26,7 +26,14 @@ mod tests { QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72)), ); - x_q.to_data().assert_eq(&expected, true); + // Values equality + x_q.assert_eq(&expected, true); + + // Quantization parameters check + let qparams = x_q.get_q_params::().unwrap(); + let expected = expected.get_q_params::().unwrap(); + assert_eq!(qparams.scale, expected.scale); + assert_eq!(qparams.offset, expected.offset); } #[test] @@ -39,7 +46,7 @@ mod tests { offset: None, }; - let x_q = tensor.quantize(&scheme, qparams); + let x_q = tensor.quantize(&scheme, qparams).into_data(); let expected = TensorData::quantized( vec![-127i8, -71, 0, 35], @@ -49,7 +56,14 @@ mod tests { )), ); - x_q.to_data().assert_eq(&expected, true); + // Values equality + x_q.assert_eq(&expected, true); + + // Quantization parameters check + let qparams = x_q.get_q_params::().unwrap(); + let expected = expected.get_q_params::().unwrap(); + assert_eq!(qparams.scale, expected.scale); + assert_eq!(qparams.offset, expected.offset); } #[test] diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 0751ad9f41..deb6a8ebd8 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -22,7 +22,7 @@ pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler; #[cfg(feature = "spirv")] type Compiler = SpirV; #[cfg(feature = "spirv")] -type Byte = u8; +type Bool = u8; #[cfg(not(feature = "spirv"))] type Compiler = Wgsl; #[cfg(not(feature = "spirv"))] From dda336adf527d6867515332029e87de17389f436 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 08:05:50 -0500 Subject: [PATCH 2/5] Combined PRs (#2619) * Bump openblas-src from 0.10.9 to 0.10.10 Bumps [openblas-src](https://github.com/blas-lapack-rs/openblas-src) from 0.10.9 to 0.10.10. - [Release notes](https://github.com/blas-lapack-rs/openblas-src/releases) - [Changelog](https://github.com/blas-lapack-rs/openblas-src/blob/master/CHANGELOG.md) - [Commits](https://github.com/blas-lapack-rs/openblas-src/compare/openblas-src-v0.10.9...openblas-src-v0.10.10) --- updated-dependencies: - dependency-name: openblas-src dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump clap from 4.5.21 to 4.5.23 Bumps [clap](https://github.com/clap-rs/clap) from 4.5.21 to 4.5.23. - [Release notes](https://github.com/clap-rs/clap/releases) - [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md) - [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.21...clap_complete-v4.5.23) --- updated-dependencies: - dependency-name: clap dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump libc from 0.2.167 to 0.2.168 Bumps [libc](https://github.com/rust-lang/libc) from 0.2.167 to 0.2.168. - [Release notes](https://github.com/rust-lang/libc/releases) - [Changelog](https://github.com/rust-lang/libc/blob/0.2.168/CHANGELOG.md) - [Commits](https://github.com/rust-lang/libc/compare/0.2.167...0.2.168) --- updated-dependencies: - dependency-name: libc dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump tracing-subscriber from 0.3.18 to 0.3.19 Bumps [tracing-subscriber](https://github.com/tokio-rs/tracing) from 0.3.18 to 0.3.19. - [Release notes](https://github.com/tokio-rs/tracing/releases) - [Commits](https://github.com/tokio-rs/tracing/compare/tracing-subscriber-0.3.18...tracing-subscriber-0.3.19) --- updated-dependencies: - dependency-name: tracing-subscriber dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump serde from 1.0.215 to 1.0.216 Bumps [serde](https://github.com/serde-rs/serde) from 1.0.215 to 1.0.216. - [Release notes](https://github.com/serde-rs/serde/releases) - [Commits](https://github.com/serde-rs/serde/compare/v1.0.215...v1.0.216) --- updated-dependencies: - dependency-name: serde dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Cargo.lock | 81 +++++++++++++++++++++--------------------------------- Cargo.toml | 10 +++---- 2 files changed, 36 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2367618c2f..9eb57e9c70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,11 +416,11 @@ dependencies = [ "burn", "burn-common", "burn-wgpu", - "clap 4.5.21", + "clap 4.5.23", "colored", "cubecl", "derive-new 0.7.0", - "dirs 5.0.1", + "dirs", "github-device-flow", "half", "indicatif", @@ -740,7 +740,7 @@ dependencies = [ "burn-common", "csv", "derive-new 0.7.0", - "dirs 5.0.1", + "dirs", "fake", "flate2", "gix-tempfile", @@ -1265,9 +1265,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.21" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" dependencies = [ "clap_builder", "clap_derive 4.5.18", @@ -1275,13 +1275,13 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.21" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" dependencies = [ "anstream", "anstyle", - "clap_lex 0.7.3", + "clap_lex 0.7.4", "strsim 0.11.1", ] @@ -1321,9 +1321,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "clipboard-win" @@ -1840,7 +1840,7 @@ dependencies = [ "cfg_aliases 0.2.1", "cubecl-common 0.3.0", "derive-new 0.6.0", - "dirs 5.0.1", + "dirs", "hashbrown 0.14.5", "log", "md5", @@ -1861,7 +1861,7 @@ dependencies = [ "cfg_aliases 0.2.1", "cubecl-common 0.4.0", "derive-new 0.6.0", - "dirs 5.0.1", + "dirs", "hashbrown 0.14.5", "log", "md5", @@ -2185,33 +2185,13 @@ dependencies = [ "subtle", ] -[[package]] -name = "dirs" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309" -dependencies = [ - "dirs-sys 0.3.7", -] - [[package]] name = "dirs" version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ - "dirs-sys 0.4.1", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", + "dirs-sys", ] [[package]] @@ -3242,7 +3222,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" dependencies = [ - "dirs 5.0.1", + "dirs", "indicatif", "log", "native-tls", @@ -3901,9 +3881,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.167" +version = "0.2.168" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" [[package]] name = "libfuzzer-sys" @@ -4827,27 +4807,28 @@ dependencies = [ [[package]] name = "openblas-build" -version = "0.10.9" +version = "0.10.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4b6b44095098cafc71915cfac3427135b6dd2ea85820a7d94a5871cb0d1e169" +checksum = "8ca8f8c64eb5b43f5538059ccbc71391420bba14d987d7e8ab99ed62ed33e26b" dependencies = [ "anyhow", + "cc", "flate2", "native-tls", "tar", - "thiserror 1.0.69", + "thiserror 2.0.6", "ureq", - "walkdir", ] [[package]] name = "openblas-src" -version = "0.10.9" +version = "0.10.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa4958649f766a1013db4254a852cdf2836764869b6654fa117316905f537363" +checksum = "19a6eee3396f46f65497e83c2f74d75315834bd0fd071cda26bda82ac3fda080" dependencies = [ - "dirs 3.0.2", + "dirs", "openblas-build", + "pkg-config", "vcpkg", ] @@ -6811,9 +6792,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" dependencies = [ "serde_derive", ] @@ -6840,9 +6821,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", @@ -7840,7 +7821,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58fccce80a2ef6bc32a512514a53cf853d438a44abaea286a4acb0c9f8566860" dependencies = [ "anyhow", - "clap 4.5.21", + "clap 4.5.23", "derive_more 0.99.18", "env_logger", "log", @@ -7921,9 +7902,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term", diff --git a/Cargo.toml b/Cargo.toml index 07df9cdda4..208cfbc17f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ version = "0.16.0" atomic_float = "1" bytemuck = "1.20.0" candle-core = { version = "0.8" } -clap = { version = "4.5.21", features = ["derive"] } +clap = { version = "4.5.23", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.1" @@ -88,7 +88,7 @@ thiserror = "2.0.6" tokio = { version = "1.42.0", features = ["rt", "macros"] } tracing-appender = "0.2.3" tracing-core = "0.1.33" -tracing-subscriber = "0.3.18" +tracing-subscriber = "0.3.19" web-time = "1.1.0" zip = "2.2.1" @@ -131,19 +131,19 @@ ndarray = { version = "0.16.1", default-features = false } num-traits = { version = "0.2.19", default-features = false, features = [ "libm", ] } # libm is for no_std -openblas-src = "0.10.9" +openblas-src = "0.10.10" rand = { version = "0.8.5", default-features = false, features = [ "std_rng", ] } # std_rng is for no_std rand_distr = { version = "0.4.3", default-features = false } -serde = { version = "1.0.215", default-features = false, features = [ +serde = { version = "1.0.216", default-features = false, features = [ "derive", "alloc", ] } # alloc is for no_std, derive is needed serde_json = { version = "1.0.133", default-features = false } uuid = { version = "1.11.0", default-features = false } -libc = "0.2.167" +libc = "0.2.168" nvml-wrapper = "0.10.0" sysinfo = "0.32.1" systemstat = "0.2.3" From 53bc165204ca4a2ecb63537036a2276cfbb75f33 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 16 Dec 2024 15:32:22 +0000 Subject: [PATCH 3/5] Relax Fn requirements (#2620) --- crates/burn-core/src/module/param/base.rs | 26 ++++++++--------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/crates/burn-core/src/module/param/base.rs b/crates/burn-core/src/module/param/base.rs index 5331f54240..b7e3bf8868 100644 --- a/crates/burn-core/src/module/param/base.rs +++ b/crates/burn-core/src/module/param/base.rs @@ -72,14 +72,14 @@ pub trait Parameter: Clone + core::fmt::Debug + Send { #[allow(clippy::type_complexity)] struct Uninitialized { - init: Box P + Send>, + init: Box P + Send>, device: P::Device, is_require_grad: bool, } impl Uninitialized

{ - fn initialize(&self) -> P { - let init = &self.init; + fn initialize(self) -> P { + let init = self.init; init(&self.device, self.is_require_grad) } } @@ -97,7 +97,7 @@ impl Param { /// Create a new parameter that is not already initialized. pub fn uninitialized(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self where - F: Fn(&T::Device, bool) -> T + Send + 'static, + F: FnOnce(&T::Device, bool) -> T + Send + 'static, { Self { id, @@ -120,12 +120,8 @@ impl Param { .expect("Should have an initialization when no state provided.") .write() .unwrap(); - let state = result.as_ref().expect("Should exist when not initialized"); - let tensor = state.initialize(); - - *result = None; - - tensor + let state = result.take().expect("Should exist when not initialized"); + state.initialize() }) .clone() } @@ -145,7 +141,7 @@ impl Param { } /// Execute the given function on the inner value. - pub fn map T>(self, func: F) -> Self { + pub fn map T>(self, func: F) -> Self { let (id, tensor) = self.consume(); let tensor = func(tensor); @@ -251,12 +247,8 @@ impl Deref for Param { .write() .unwrap(); - let state = result.as_ref().expect("Should exist when not initialized"); - let tensor = state.initialize(); - - *result = None; - - tensor + let state = result.take().expect("Should exist when not initialized"); + state.initialize() }) } } From 8a89293bf3ee02fe7216705ed3b7370506489e4a Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 16 Dec 2024 11:44:14 -0500 Subject: [PATCH 4/5] Add module mapper book example (#2621) * Add module mapper example * Fix handles typo --- burn-book/src/building-blocks/module.md | 31 +++++++++++++++++++++++++ crates/burn-jit/src/backend.rs | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 128a99a7b6..6404552471 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -112,6 +112,37 @@ Note that the trait doesn't require all methods to be implemented as they are al perform no operation. If you're only interested in float tensors (like the majority of use cases), then you can simply implement `map_float` or `visit_float`. +For example, the `ModuleMapper` trait could be implemented to clamp all parameters into the range +`[min, max]`. + +```rust, ignore +/// Clamp parameters into the range `[min, max]`. +pub struct Clamp { + /// Lower-bound of the range. + pub min: f32, + /// Upper-bound of the range. + pub max: f32, +} + +// Clamp all floating-point parameter tensors between `[min, max]`. +impl ModuleMapper for Clamp { + fn map_float( + &mut self, + _id: burn::module::ParamId, + tensor: burn::prelude::Tensor, + ) -> burn::prelude::Tensor { + tensor.clamp(self.min, self.max) + } +} + +// Clamp module mapper into the range `[-0.5, 0.5]` +let mut clamp = Clamp { + min: -0.5, + max: 0.5, +}; +let model = model.map(&mut clamp); +``` + ## Module Display Burn provides a simple way to display the structure of a module and its configuration at a glance. diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 2d1b64ebce..3f5637e412 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -113,7 +113,7 @@ impl ReprBackend handle.handle } - fn quantized_tensor(handles: TensorHandle) -> QuantizedTensor { + fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { handle.handle } From 28f99d14e9a43ffe31c3d9652da4ac4ff00ee30b Mon Sep 17 00:00:00 2001 From: WorldSEnder Date: Tue, 17 Dec 2024 16:40:44 +0100 Subject: [PATCH 5/5] Fix alignment issue of TensorData bytes (#2416) * implement memory-safe bytes that can be serialized and cloned * change serialization to only serialize the bytes introduce max alignment (which depends on platform anyway) and dont serialize that part fixes Clone, Debug, and Eq impls to work on the bytes, not the pointers. * make bytes no-std compatible * enforce Send and Sync for Bytes * avoid a copy during deserialization if data is already aligned this already improves readability a bit by separating out alloc/dealloc logic and adding a bunch of safety comments and better error messages * revert back to using Vec as deserialization intermediate borrowing from the deserializer will not save a copy, and is moreover inefficient when we could take ownership of an existing byte buffer * add serialization and conversion tests * make Bytes tests run under miri both changes only target miri's borrowing semantics, oprationally the pointers are the same, but they obey different borrow-stack rules. * let the Bytes buffer grow * Clean the code by separation of concerns The new Allocation struct keeps the raw allocation and its layout, the Bytes struct wraps an Allocation and asserts that len bytes of it are initialized * nit: change typo and improve internal naming * use Bytes in jit ops --- Cargo.lock | 1 + crates/burn-core/src/record/serde/data.rs | 12 +- crates/burn-core/src/record/serde/de.rs | 6 +- crates/burn-core/src/record/serde/ser.rs | 2 +- crates/burn-import/src/pytorch/reader.rs | 2 +- crates/burn-jit/src/backend.rs | 3 +- crates/burn-jit/src/ops/qtensor.rs | 8 +- crates/burn-jit/src/ops/transaction.rs | 8 +- crates/burn-tensor/Cargo.toml | 1 + crates/burn-tensor/src/tensor/bytes.rs | 547 ++++++++++++++++++ crates/burn-tensor/src/tensor/data.rs | 122 ++-- crates/burn-tensor/src/tensor/mod.rs | 2 + .../src/tensor/quantization/data.rs | 7 +- 13 files changed, 641 insertions(+), 80 deletions(-) create mode 100644 crates/burn-tensor/src/tensor/bytes.rs diff --git a/Cargo.lock b/Cargo.lock index 9eb57e9c70..0fa466bb87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -937,6 +937,7 @@ dependencies = [ name = "burn-tensor" version = "0.16.0" dependencies = [ + "bincode", "burn-common", "burn-tensor-testgen", "bytemuck", diff --git a/crates/burn-core/src/record/serde/data.rs b/crates/burn-core/src/record/serde/data.rs index 27a55c5b2a..b5e5442b54 100644 --- a/crates/burn-core/src/record/serde/data.rs +++ b/crates/burn-core/src/record/serde/data.rs @@ -8,6 +8,7 @@ use crate::record::{PrecisionSettings, Record}; use crate::tensor::backend::Backend; use alloc::fmt; +use burn_tensor::Bytes; use num_traits::cast::ToPrimitive; use regex::Regex; use serde::Deserialize; @@ -66,7 +67,11 @@ pub enum NestedValue { /// A vector of 32-bit floating point values. F32s(Vec), + + /// An opaque vector of bytes, with alignment. + Bytes(Bytes), } + impl NestedValue { /// Get the nested value as a map. pub fn as_map(self) -> Option> { @@ -184,9 +189,10 @@ impl NestedValue { } /// Get the nested value as a vector of bytes. - pub fn as_bytes(self) -> Option> { + pub fn as_bytes(self) -> Option { match self { - NestedValue::U8s(u) => Some(u), + NestedValue::Bytes(u) => Some(u), + NestedValue::U8s(u) => Some(Bytes::from_elems(u)), _ => None, } } @@ -368,6 +374,7 @@ impl fmt::Debug for NestedValue { NestedValue::U8s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), + NestedValue::Bytes(bytes) if bytes.len() > 3 => write_vec_truncated(bytes, f), // Handle other variants as usual NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(), NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(), @@ -385,6 +392,7 @@ impl fmt::Debug for NestedValue { NestedValue::U8s(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(), + NestedValue::Bytes(bytes) => f.debug_list().entries(bytes.iter()).finish(), } } } diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs index 3c93afed16..04bed6899f 100644 --- a/crates/burn-core/src/record/serde/de.rs +++ b/crates/burn-core/src/record/serde/de.rs @@ -233,7 +233,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { where V: Visitor<'de>, { - visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap()) + let bytes = self.value.unwrap().as_bytes().unwrap(); + match bytes.try_into_vec::() { + Ok(bytes) => visitor.visit_byte_buf(bytes), + Err(bytes) => visitor.visit_bytes(&bytes), + } } fn deserialize_option(self, visitor: V) -> Result diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs index f2803d8bdc..b0baaa5cd1 100644 --- a/crates/burn-core/src/record/serde/ser.rs +++ b/crates/burn-core/src/record/serde/ser.rs @@ -383,6 +383,6 @@ mod tests { .clone() .as_bytes() .expect("has bytes vec"); - assert_eq!(bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened()); + assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened()); } } diff --git a/crates/burn-import/src/pytorch/reader.rs b/crates/burn-import/src/pytorch/reader.rs index 55a8d9838a..87a76278b8 100644 --- a/crates/burn-import/src/pytorch/reader.rs +++ b/crates/burn-import/src/pytorch/reader.rs @@ -152,7 +152,7 @@ where // Because serializer copies individual elements of TensorData `value` into a new Vec, // which is not necessary and inefficient. let mut tensor_data: HashMap = HashMap::new(); - tensor_data.insert("bytes".into(), NestedValue::U8s(bytes)); + tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes)); tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?); tensor_data.insert("dtype".into(), dtype.serialize(serializer)?); diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 3f5637e412..e29a98ca96 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -7,8 +7,7 @@ use std::{marker::PhantomData, sync::Mutex}; #[cfg(not(feature = "fusion"))] use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, - quantization::QuantizationScheme, - repr::{HandleKind, ReprBackend, TensorHandle}, + repr::{ReprBackend, TensorHandle}, }; pub(crate) static SEED: Mutex> = Mutex::new(None); diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index cde44b0552..feca09a41f 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -3,7 +3,7 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType}, - DType, Device, Shape, TensorData, + Bytes, DType, Device, Shape, TensorData, }; use crate::{ @@ -82,12 +82,8 @@ where let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; - // TODO: this should be refactored such that the bytes type is opaque. - // With this, the logic for handling the bytes representation of quantized data - // (as well as all data manipulations) will be encapsulated in the type. - // Creating a TensorData struct directly from some bytes should probably not be possible outside of the crate. TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape: tensor.shape.into(), dtype: tensor.dtype, } diff --git a/crates/burn-jit/src/ops/transaction.rs b/crates/burn-jit/src/ops/transaction.rs index b67740bc97..49bc1936ee 100644 --- a/crates/burn-jit/src/ops/transaction.rs +++ b/crates/burn-jit/src/ops/transaction.rs @@ -1,6 +1,6 @@ use burn_tensor::{ ops::{TransactionOps, TransactionPrimitiveResult}, - DType, TensorData, + Bytes, DType, TensorData, }; use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; @@ -74,7 +74,7 @@ where Kind::Float(index, shape, dtype) => { let bytes = data.get_mut(index).unwrap().take().unwrap(); result.read_floats.push(TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape, dtype, }); @@ -82,7 +82,7 @@ where Kind::Int(index, shape, dtype) => { let bytes = data.get_mut(index).unwrap().take().unwrap(); result.read_ints.push(TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape, dtype, }); @@ -90,7 +90,7 @@ where Kind::Bool(index, shape, dtype) => { let bytes = data.get_mut(index).unwrap().take().unwrap(); result.read_bools.push(TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape, dtype, }); diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 9c4dfa3a49..02683531c6 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -56,6 +56,7 @@ portable-atomic-util = { workspace = true } [dev-dependencies] rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std +bincode = { workspace = true } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-tensor/src/tensor/bytes.rs b/crates/burn-tensor/src/tensor/bytes.rs new file mode 100644 index 0000000000..f9cb238a26 --- /dev/null +++ b/crates/burn-tensor/src/tensor/bytes.rs @@ -0,0 +1,547 @@ +//! A version of [`bytemuck::BoxBytes`] that is cloneable and allows trailing uninitialized elements. + +use alloc::alloc::{Layout, LayoutError}; +use core::mem::MaybeUninit; +use core::ops::{Deref, DerefMut}; +use core::ptr::NonNull; + +use alloc::vec::Vec; + +/// Internally used to avoid accidentally leaking an allocation or using the wrong layout. +struct Allocation { + /// SAFETY: + /// - If `layout.size() > 0`, `ptr` points to a valid allocation from the global allocator + /// of the specified layout. The first `len` bytes are initialized. + /// - If `layout.size() == 0`, `ptr` is aligned to `layout.align()` and `len` is 0. + /// `ptr` is further suitable to be used as the argument for `Vec::from_raw_parts` see [buffer alloc] + /// for more details. + ptr: NonNull, + layout: Layout, +} + +/// A sort of `Box<[u8]>` that remembers the original alignment and can contain trailing uninitialized bytes. +pub struct Bytes { + alloc: Allocation, + // SAFETY: The first `len` bytes of the allocation are initialized + len: usize, +} + +/// The maximum supported alignment. The limit exists to not have to store alignment when serializing. Instead, +/// the bytes are always over-aligned when deserializing to MAX_ALIGN. +const MAX_ALIGN: usize = core::mem::align_of::(); + +fn debug_from_fn) -> core::fmt::Result>( + f: F, +) -> impl core::fmt::Debug { + // See also: std::fmt::from_fn + struct FromFn(F); + impl core::fmt::Debug for FromFn + where + F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result, + { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + (self.0)(f) + } + } + FromFn(f) +} + +impl core::fmt::Debug for Bytes { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let data = &**self; + let fmt_data = move |f: &mut core::fmt::Formatter<'_>| { + if data.len() > 3 { + // There is a nightly API `debug_more_non_exhaustive` which has `finish_non_exhaustive` + f.debug_list().entries(&data[0..3]).entry(&"...").finish() + } else { + f.debug_list().entries(data).finish() + } + }; + f.debug_struct("Bytes") + .field("data", &debug_from_fn(fmt_data)) + .field("len", &self.len) + .finish() + } +} + +impl serde::Serialize for Bytes { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde_bytes::serialize(self.deref(), serializer) + } +} + +impl<'de> serde::Deserialize<'de> for Bytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[cold] + fn too_large(len: usize, align: usize) -> E { + // max_length = largest multiple of align that is <= isize::MAX + // align is a power of 2, hence a multiple has the lower bits unset. Mask them off to find the largest multiple + let max_length = (isize::MAX as usize) & !(align - 1); + E::custom(core::format_args!( + "length too large: {len}. Expected at most {max_length} bytes" + )) + } + + // TODO: we can possibly avoid one copy here by deserializing into an existing, correctly aligned, slice of bytes. + // We might not be able to predict the length of the data, hence it's far more convenient to let `Vec` handle the growth and re-allocations. + // Further, on a lot of systems, the allocator naturally aligns data to some reasonably large alignment, where no further copy is then + // necessary. + let data: Vec = serde_bytes::deserialize(deserializer)?; + // When deserializing, we over-align the data. This saves us from having to encode the alignment (which is platform-dependent in any case). + // If we had more context information here, we could enforce some (smaller) alignment per data type. But this information is only available + // in `TensorData`. Moreover it depends on the Deserializer there whether the datatype or data comes first. + let align = MAX_ALIGN; + let mut bytes = Self::from_elems(data); + bytes + .try_enforce_runtime_align(align) + .map_err(|_| too_large(bytes.len(), align))?; + Ok(bytes) + } +} + +impl Clone for Bytes { + fn clone(&self) -> Self { + // unwrap here: the layout is valid as it has the alignment & size of self + Self::try_from_data(MAX_ALIGN, self.deref()).unwrap() + } +} + +impl PartialEq for Bytes { + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } +} + +impl Eq for Bytes {} + +impl Allocation { + // Wrap the allocation of a vector without copying + fn from_vec(vec: Vec) -> Self { + let mut elems = core::mem::ManuallyDrop::new(vec); + // Set the length to 0, then all data is in the "spare capacity". + // SAFETY: Data is Copy, so in particular does not need to be dropped. In any case, try not to panic until + // we have taken ownership of the data! + unsafe { elems.set_len(0) }; + let data = elems.spare_capacity_mut(); + // We now have one contiguous slice of data to pass to Layout::for_value. + let layout = Layout::for_value(data); + // SAFETY: data is the allocation of a vec, hence can not be null. We use unchecked to avoid a panic-path. + let ptr = unsafe { NonNull::new_unchecked(elems.as_mut_ptr().cast()) }; + Self { ptr, layout } + } + // Create a new allocation with the specified layout + fn new(layout: Layout) -> Self { + let ptr = buffer_alloc(layout); + Self { ptr, layout } + } + // Reallocate to fit at least the size and align of min_layout + fn grow(&mut self, min_layout: Layout) { + (self.layout, self.ptr) = buffer_grow(self.layout, self.ptr, min_layout); + } + // Returns a mutable view of the memory of the whole allocation + fn memory_mut(&mut self) -> &mut [MaybeUninit] { + // SAFETY: See type invariants + unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr().cast(), self.layout.size()) } + } + // Return a pointer to the underlying allocation. This pointer is valid for reads and writes until the allocation is dropped or reallocated. + fn as_mut_ptr(&self) -> *mut u8 { + self.ptr.as_ptr() + } + // Try to convert the allocation to a Vec. The Vec has a length of 0 when returned, but correct capacity and pointer! + fn try_into_vec(self) -> Result, Self> { + let byte_capacity = self.layout.size(); + let Some(capacity) = byte_capacity.checked_div(size_of::()) else { + return Err(self); + }; + if capacity * size_of::() != byte_capacity { + return Err(self); + }; + if self.layout.align() != align_of::() { + return Err(self); + } + // Okay, let's commit + let ptr = self.ptr.as_ptr().cast(); + core::mem::forget(self); + // SAFETY: + // - ptr was allocated by the global allocator as per type-invariant + // - `E` has the same alignment as indicated by the stored layout. + // - capacity * size_of:: == layout.size() + // - 0 <= capacity + // - no bytes are claimed to be initialized + // - the layout represents a valid allocation, hence has allocation size less than isize::MAX + Ok(unsafe { Vec::from_raw_parts(ptr, 0, capacity) }) + } +} + +impl Drop for Allocation { + fn drop(&mut self) { + buffer_dealloc(self.layout, self.ptr); + } +} + +// Allocate a pointer that can be passed to Vec::from_raw_parts +fn buffer_alloc(layout: Layout) -> NonNull { + // [buffer alloc]: The current docs of Vec::from_raw_parts(ptr, ...) say: + // > ptr must have been allocated using the global allocator + // Yet, an empty Vec is guaranteed to not allocate (it is even illegal! to allocate with a zero-sized layout) + // Hence, we slightly re-interpret the above to only needing to hold if `capacity > 0`. Still, the pointer + // must be non-zero. So in case we need a pointer for an empty vec, use a correctly aligned, dangling one. + if layout.size() == 0 { + // we would use NonNull:dangling() but we don't have a concrete type for the requested alignment + let ptr = core::ptr::null_mut::().wrapping_add(layout.align()); + // SAFETY: layout.align() is never 0 + unsafe { NonNull::new_unchecked(ptr) } + } else { + // SAFETY: layout has non-zero size. + let ptr = unsafe { alloc::alloc::alloc(layout) }; + NonNull::new(ptr).unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout)) + } +} + +fn expect_dangling(align: usize, buffer: NonNull) { + debug_assert!( + buffer.as_ptr().wrapping_sub(align).is_null(), + "expected a nullptr for size 0" + ); +} + +#[cold] +fn alloc_overflow() -> ! { + panic!("Overflow, too many elements") +} + +// Grow the buffer while keeping alignment +fn buffer_grow( + old_layout: Layout, + buffer: NonNull, + min_layout: Layout, +) -> (Layout, NonNull) { + let new_align = min_layout.align().max(old_layout.align()); // Don't let data become less aligned + let new_size = min_layout.size().next_multiple_of(new_align); + if new_size > isize::MAX as usize { + alloc_overflow(); + } + + assert!(new_size > old_layout.size(), "size must actually grow"); + if old_layout.size() == 0 { + expect_dangling(old_layout.align(), buffer); + let new_layout = Layout::from_size_align(new_size, new_align).unwrap(); + let buffer = buffer_alloc(new_layout); + return (new_layout, buffer); + }; + let realloc = || { + let new_layout = Layout::from_size_align(new_size, old_layout.align()).unwrap(); + // SAFETY: + // - buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`]. + // - old_layout is the same as with which the pointer was allocated + // - new_size is not 0, since it is larger than old_layout.size() which is non-zero + // - size constitutes a valid layout + let ptr = unsafe { alloc::alloc::realloc(buffer.as_ptr(), old_layout, new_layout.size()) }; + (new_layout, ptr) + }; + if new_align == old_layout.align() { + // happy path. We can just realloc. + let (new_layout, ptr) = realloc(); + let buffer = NonNull::new(ptr); + let buffer = buffer.unwrap_or_else(|| alloc::alloc::handle_alloc_error(new_layout)); + return (new_layout, buffer); + } + // [buffer grow]: alloc::realloc can *not* change the alignment of the allocation's layout. + // The unstable Allocator::{grow,shrink} API changes this, but might take a while to make it + // into alloc::GlobalAlloc. + // + // As such, we can not request a specific alignment. But most allocators will give us the required + // alignment "for free". Hence, we speculatively avoid a mem-copy by using realloc. + // + // If in the future requesting an alignment change for an existing is available, this can be removed. + #[cfg(target_has_atomic = "8")] + mod alignment_assumption { + use core::sync::atomic::{AtomicBool, Ordering}; + static SPECULATE: AtomicBool = AtomicBool::new(true); + pub fn speculate() -> bool { + // We load and store with relaxed order, since worst case this leads to a few more memcopies + SPECULATE.load(Ordering::Relaxed) + } + pub fn report_violation() { + SPECULATE.store(false, Ordering::Relaxed) + } + } + #[cfg(not(target_has_atomic = "8"))] + mod alignment_assumption { + // On these platforms we don't speculate, and take the hit of performance + pub fn speculate() -> bool { + false + } + pub fn report_violation() {} + } + // reminder: old_layout.align() < new_align + let mut old_buffer = buffer; + let mut old_layout = old_layout; + if alignment_assumption::speculate() { + let (realloc_layout, ptr) = realloc(); + if let Some(buffer) = NonNull::new(ptr) { + if buffer.align_offset(new_align) == 0 { + return (realloc_layout, buffer); + } + // Speculating hasn't succeeded, but access now has to go through the reallocated buffer + alignment_assumption::report_violation(); + old_buffer = buffer; + old_layout = realloc_layout; + } else { + // If realloc fails, the later alloc will likely too, but don't report this yet + } + } + // realloc but change alignment. This requires a mem copy as pointed out above + let new_layout = Layout::from_size_align(new_size, new_align).unwrap(); + let new_buffer = buffer_alloc(new_layout); + // SAFETY: two different memory allocations, and old buffer's size is smaller than new_size + unsafe { + core::ptr::copy_nonoverlapping(old_buffer.as_ptr(), new_buffer.as_ptr(), old_layout.size()); + } + buffer_dealloc(old_layout, old_buffer); + (new_layout, new_buffer) +} + +// Deallocate a buffer of a Vec +fn buffer_dealloc(layout: Layout, buffer: NonNull) { + if layout.size() != 0 { + // SAFETY: buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`]. + // The layout is the same as per type-invariants + unsafe { + alloc::alloc::dealloc(buffer.as_ptr(), layout); + } + } else { + // An empty Vec does not allocate, hence nothing to dealloc + expect_dangling(layout.align(), buffer); + } +} + +impl Bytes { + /// Copy an existing slice of data into Bytes that are aligned to `align` + fn try_from_data(align: usize, data: &[u8]) -> Result { + let len = data.len(); + let layout = Layout::from_size_align(len, align)?; + let alloc = Allocation::new(layout); + unsafe { + // SAFETY: + // - data and alloc are distinct allocations of `len` bytes + core::ptr::copy_nonoverlapping::(data.as_ref().as_ptr(), alloc.as_mut_ptr(), len); + }; + Ok(Self { alloc, len }) + } + + /// Ensure the contained buffer is aligned to `align` by possibly moving it to a new buffer. + fn try_enforce_runtime_align(&mut self, align: usize) -> Result<(), LayoutError> { + if self.as_mut_ptr().align_offset(align) == 0 { + // data is already aligned correctly + return Ok(()); + } + *self = Self::try_from_data(align, self)?; + Ok(()) + } + + /// Create a sequence of [Bytes] from the memory representation of an unknown type of elements. + /// Prefer this over [Self::from_elems] when the datatype is not statically known and erased at runtime. + pub fn from_bytes_vec(bytes: Vec) -> Self { + let mut bytes = Self::from_elems(bytes); + // TODO: this method could be datatype aware and enforce a less strict alignment. + // On most platforms, this alignment check is fulfilled either way though, so + // the benefits of potentially saving a memcopy are negligible. + bytes.try_enforce_runtime_align(MAX_ALIGN).unwrap(); + bytes + } + + /// Erase the element type of a vector by converting into a sequence of [Bytes]. + /// + /// In case the element type is not statically known at runtime, prefer to use [Self::from_bytes_vec]. + pub fn from_elems(elems: Vec) -> Self + where + // NoUninit implies Copy + E: bytemuck::NoUninit + Send + Sync, + { + let _: () = const { + assert!( + core::mem::align_of::() <= MAX_ALIGN, + "element type not supported due to too large alignment" + ); + }; + // Note: going through a Box as in Vec::into_boxed_slice would re-allocate on excess capacity. Avoid that. + let byte_len = elems.len() * core::mem::size_of::(); + let alloc = Allocation::from_vec(elems); + Self { + alloc, + len: byte_len, + } + } + + fn reserve(&mut self, additional: usize) { + let needs_to_grow = additional > self.capacity().wrapping_sub(self.len()); + if !needs_to_grow { + return; + } + let Some(required_cap) = self.len().checked_add(additional) else { + alloc_overflow() + }; + // guarantee exponential growth for amortization + let new_cap = required_cap.max(self.capacity() * 2); + let new_cap = new_cap.max(MAX_ALIGN); // Small allocations would be pointless + let Ok(new_layout) = Layout::from_size_align(new_cap, MAX_ALIGN) else { + alloc_overflow() + }; + self.alloc.grow(new_layout); + } + + /// Extend the byte buffer from a slice of bytes + pub fn extend_from_byte_slice(&mut self, bytes: &[u8]) { + let additional = bytes.len(); + self.reserve(additional); + let len = self.len(); + let new_cap = len.wrapping_add(additional); // Can not overflow, as we've just successfully reserved sufficient space for it + let uninit_spare = &mut self.alloc.memory_mut()[len..new_cap]; + // SAFETY: reinterpreting the slice as a MaybeUninit. + // See also #![feature(maybe_uninit_write_slice)], which would replace this with safe code + uninit_spare.copy_from_slice(unsafe { + core::slice::from_raw_parts(bytes.as_ptr().cast(), additional) + }); + self.len = new_cap; + } + + /// Get the total capacity, in bytes, of the wrapped allocation. + pub fn capacity(&self) -> usize { + self.alloc.layout.size() + } + + /// Convert the bytes back into a vector. This requires that the type has the same alignment as the element + /// type this [Bytes] was initialized with. + /// This only returns with Ok(_) if the conversion can be done without a memcopy + pub fn try_into_vec( + mut self, + ) -> Result, Self> { + // See if the length is compatible + let Ok(data) = bytemuck::checked::try_cast_slice_mut::<_, E>(&mut self) else { + return Err(self); + }; + let length = data.len(); + // If so, try to convert the allocation to a vec + let mut vec = match self.alloc.try_into_vec::() { + Ok(vec) => vec, + Err(alloc) => { + self.alloc = alloc; + return Err(self); + } + }; + // SAFETY: We computed this length from the bytemuck-ed slice into this allocation + unsafe { + vec.set_len(length); + }; + Ok(vec) + } +} + +impl Deref for Bytes { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + // SAFETY: see type invariants + unsafe { core::slice::from_raw_parts(self.alloc.as_mut_ptr(), self.len) } + } +} + +impl DerefMut for Bytes { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: see type invariants + unsafe { core::slice::from_raw_parts_mut(self.alloc.as_mut_ptr(), self.len) } + } +} + +// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Send +unsafe impl Send for Bytes {} +// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Sync +unsafe impl Sync for Bytes {} + +#[cfg(test)] +mod tests { + use super::Bytes; + use alloc::{vec, vec::Vec}; + + const _CONST_ASSERTS: fn() = || { + fn test_send() {} + fn test_sync() {} + test_send::(); + test_sync::(); + }; + + fn test_serialization_roundtrip(bytes: &Bytes) { + let config = bincode::config::standard(); + let serialized = + bincode::serde::encode_to_vec(bytes, config).expect("serialization to succeed"); + let (roundtripped, _) = bincode::serde::decode_from_slice(&serialized, config) + .expect("deserialization to succeed"); + assert_eq!( + bytes, &roundtripped, + "roundtripping through serialization didn't lead to equal Bytes" + ); + } + + #[test] + fn test_serialization() { + test_serialization_roundtrip(&Bytes::from_elems::(vec![])); + test_serialization_roundtrip(&Bytes::from_elems(vec![0xdead, 0xbeaf])); + } + + #[test] + fn test_into_vec() { + // We test an edge case here, where the capacity (but not actual size) makes it impossible to convert to a vec + let mut bytes = Vec::with_capacity(6); + let actual_cap = bytes.capacity(); + bytes.extend_from_slice(&[0, 1, 2, 3]); + let mut bytes = Bytes::from_elems::(bytes); + + bytes = bytes + .try_into_vec::<[u8; 0]>() + .expect_err("Conversion should not succeed for a zero-sized type"); + if actual_cap % 4 != 0 { + // We most likely get actual_cap == 6, we can't force Vec to actually do that. Code coverage should complain if the actual test misses this + bytes = bytes.try_into_vec::<[u8; 4]>().err().unwrap_or_else(|| { + panic!("Conversion should not succeed due to capacity {actual_cap} not fitting a whole number of elements"); + }); + } + bytes = bytes + .try_into_vec::() + .expect_err("Conversion should not succeed due to mismatched alignment"); + bytes = bytes.try_into_vec::<[u8; 3]>().expect_err( + "Conversion should not succeed due to size not fitting a whole number of elements", + ); + let bytes = bytes.try_into_vec::<[u8; 2]>().expect("Conversion should succeed for bit-convertible types of equal alignment and compatible size"); + assert_eq!(bytes, &[[0, 1], [2, 3]]); + } + + #[test] + fn test_grow() { + let mut bytes = Bytes::from_elems::(vec![]); + bytes.extend_from_byte_slice(&[0, 1, 2, 3]); + assert_eq!(bytes[..], [0, 1, 2, 3][..]); + + let mut bytes = Bytes::from_elems(vec![42u8; 4]); + bytes.extend_from_byte_slice(&[0, 1, 2, 3]); + assert_eq!(bytes[..], [42, 42, 42, 42, 0, 1, 2, 3][..]); + } + + #[test] + fn test_large_elems() { + let mut bytes = Bytes::from_elems(vec![42u128]); + const TEST_BYTES: [u8; 16] = [ + 0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56, + 0x34, 0x12, + ]; + bytes.extend_from_byte_slice(&TEST_BYTES); + let vec = bytes.try_into_vec::().unwrap(); + assert_eq!(vec, [42u128, u128::from_ne_bytes(TEST_BYTES)]); + } +} diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index d52a181929..c65572a068 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -13,7 +13,7 @@ use half::{bf16, f16}; use crate::{ quantization::{AffineQuantization, Quantization, QuantizationStrategy}, - tensor::Shape, + tensor::{bytes::Bytes, Shape}, DType, Distribution, Element, ElementConversion, }; @@ -43,8 +43,7 @@ pub enum DataError { #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct TensorData { /// The values of the tensor (as bytes). - #[serde(with = "serde_bytes")] - pub bytes: Vec, + pub bytes: Bytes, /// The shape of the tensor. pub shape: Vec, @@ -53,33 +52,12 @@ pub struct TensorData { pub dtype: DType, } -fn into_bytes(mut value: Vec) -> Vec { - // Ensure `E` satisfies the `Pod` trait requirements - assert_eq!(core::mem::size_of::() % core::mem::size_of::(), 0); - - let factor = core::mem::size_of::() / core::mem::size_of::(); - let len = value.len() * factor; - let capacity = value.capacity() * factor; - let ptr = value.as_mut_ptr(); - - core::mem::forget(value); - - unsafe { Vec::from_raw_parts(ptr as *mut u8, len, capacity) } -} - impl TensorData { /// Creates a new tensor data structure. - pub fn new>>(mut value: Vec, shape: S) -> Self { + pub fn new>>(value: Vec, shape: S) -> Self { // Ensure shape is valid let shape = shape.into(); - let shape_numel = Self::numel(&shape); - value.truncate(shape_numel); - let numel = value.len(); - assert_eq!( - shape_numel, numel, - "Shape {:?} is invalid for input of size {:?}", - shape, numel, - ); + Self::check_data_len(&value, &shape, None); Self::init(value, shape, E::dtype()) } @@ -93,10 +71,10 @@ impl TensorData { shape: S, strategy: QuantizationStrategy, ) -> Self { - // TODO: this method should go into a dedicated Bytes opaque type with other bytes - // handling logic - let mut value = into_bytes(value); + let shape = shape.into(); + Self::check_data_len(&value, &shape, Some(&strategy)); + let mut bytes: Bytes; // Notes on quantization data representation: // 1) The quantized values are packed into 32-bit unsigned integers. For example, int8 // quantized values pack 4 grouped values into a single `u32`. When unpacking these values, @@ -107,9 +85,9 @@ impl TensorData { match strategy { QuantizationStrategy::PerTensorAffineInt8(q) => { if TypeId::of::() == TypeId::of::() { - value = bytemuck::checked::cast_slice(&value).to_vec(); // already packed values - } else if TypeId::of::() == TypeId::of::() { - value = bytemuck::checked::cast_slice(&pack_i8s_to_u32s(&value)).to_vec(); + bytes = Bytes::from_elems(value); // already packed values + } else if let Some(value) = ::downcast_ref::>(&value) { + bytes = Bytes::from_elems(pack_i8s_to_u32s(value)); } else { panic!("Invalid quantized type"); } @@ -117,31 +95,62 @@ impl TensorData { let offset = q.offset as i32; let scale_bytes = bytemuck::bytes_of(&q.scale); let offset_bytes = bytemuck::bytes_of(&offset); - value.extend_from_slice(offset_bytes); - value.extend_from_slice(scale_bytes); + bytes.extend_from_byte_slice(offset_bytes); + bytes.extend_from_byte_slice(scale_bytes); } QuantizationStrategy::PerTensorSymmetricInt8(q) => { if TypeId::of::() == TypeId::of::() { - value = bytemuck::checked::cast_slice(&value).to_vec(); // already packed values - } else if TypeId::of::() == TypeId::of::() { - let packed = pack_i8s_to_u32s(&value); - value = bytemuck::checked::cast_slice(&packed).to_vec(); + bytes = Bytes::from_elems(value); // already packed values + } else if let Some(value) = ::downcast_ref::>(&value) { + bytes = Bytes::from_elems(pack_i8s_to_u32s(value)); } else { panic!("Invalid quantized type"); } let scale_bytes = bytemuck::bytes_of(&q.scale); - value.extend_from_slice(scale_bytes); + bytes.extend_from_byte_slice(scale_bytes); } } - Self::init(value, shape, DType::QFloat(strategy.scheme())) + Self { + bytes, + shape, + dtype: DType::QFloat(strategy.scheme()), + } + } + + // Check that the input vector contains a correct number of elements + fn check_data_len( + data: &[E], + shape: &Vec, + quantization: Option<&QuantizationStrategy>, + ) { + let mut expected_data_len = Self::numel(shape); + if let Some(quantization) = quantization { + let elem_per_data = match quantization { + QuantizationStrategy::PerTensorAffineInt8(_) + | QuantizationStrategy::PerTensorSymmetricInt8(_) => { + if TypeId::of::() == TypeId::of::() { + 4 + } else { + 1 + } + } + }; + expected_data_len = expected_data_len.div_ceil(elem_per_data); + } + let num_data = data.len(); + assert_eq!( + expected_data_len, num_data, + "Shape {:?} is invalid for input of size {:?}", + shape, num_data, + ); } /// Initializes a new tensor data structure from the provided values. - fn init>>(value: Vec, shape: S, dtype: DType) -> Self { + fn init(value: Vec, shape: Vec, dtype: DType) -> Self { Self { - bytes: into_bytes(value), - shape: shape.into(), + bytes: Bytes::from_elems(value), + shape, dtype, } } @@ -185,7 +194,7 @@ impl TensorData { } /// Returns the tensor data as a vector of scalar values. - pub fn into_vec(mut self) -> Result, DataError> { + pub fn into_vec(self) -> Result, DataError> { if E::dtype() != self.dtype { return Err(DataError::TypeMismatch(format!( "Invalid target element type (expected {:?}, got {:?})", @@ -194,19 +203,16 @@ impl TensorData { ))); } - let capacity_bytes = self.bytes.capacity(); - let length_bytes = self.bytes.len(); - let size_elem = core::mem::size_of::(); - - let capacity = capacity_bytes / size_elem; - let length = length_bytes / size_elem; - - unsafe { - let ptr = self.bytes.as_mut_ptr(); - core::mem::forget(self.bytes); - - Ok(Vec::from_raw_parts(ptr.cast::(), length, capacity)) - } + let mut me = self; + me.bytes = match me.bytes.try_into_vec::() { + Ok(elems) => return Ok(elems), + Err(bytes) => bytes, + }; + // The bytes might have been deserialized and allocated with a different align. + // In that case, we have to memcopy the data into a new vector, more suitably allocated + Ok(bytemuck::checked::try_cast_slice(me.values_as_bytes()) + .map_err(DataError::CastError)? + .to_vec()) } /// Returns an iterator over the values of the tensor data. @@ -405,7 +411,7 @@ impl TensorData { /// Returns the data as a slice of bytes. pub fn as_bytes(&self) -> &[u8] { - self.bytes.as_slice() + &self.bytes } /// Applies the data quantization strategy. diff --git a/crates/burn-tensor/src/tensor/mod.rs b/crates/burn-tensor/src/tensor/mod.rs index feed9571c4..40fa4b0f2c 100644 --- a/crates/burn-tensor/src/tensor/mod.rs +++ b/crates/burn-tensor/src/tensor/mod.rs @@ -1,12 +1,14 @@ pub(crate) mod stats; mod api; +mod bytes; mod data; mod distribution; mod element; mod shape; pub use api::*; +pub use bytes::*; pub use data::*; pub use distribution::*; pub use element::*; diff --git a/crates/burn-tensor/src/tensor/quantization/data.rs b/crates/burn-tensor/src/tensor/quantization/data.rs index 96096b784b..7f833edc58 100644 --- a/crates/burn-tensor/src/tensor/quantization/data.rs +++ b/crates/burn-tensor/src/tensor/quantization/data.rs @@ -1,10 +1,7 @@ use alloc::vec::Vec; /// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers. -/// -/// # Note -/// This assumes that the bytes represent `i8` values. -pub fn pack_i8s_to_u32s(bytes: &[u8]) -> Vec { +pub fn pack_i8s_to_u32s(bytes: &[i8]) -> Vec { // Shift and combine groups of four 8-bit values into a u32. // Same as doing this: // let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF); @@ -12,7 +9,7 @@ pub fn pack_i8s_to_u32s(bytes: &[u8]) -> Vec { .chunks(4) .map(|x| { x.iter().enumerate().fold(0u32, |acc, (i, x)| { - acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8) + acc | (*x as u32 & 0xFF) << ((3 - i) * 8) }) }) .collect()