Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor jit quantized tensor representation #2604

Merged
merged 17 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "1c4e0036c043422773fd6824c2a888160fca8e5e" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-autodiff/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
todo!()
}

fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
B::q_shape(tensor)
}

fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
B::q_device(tensor)
}
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-candle/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
unimplemented!()
}

fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
super::base::shape(&tensor.qtensor)
}

fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
super::base::device(&tensor.qtensor)
}
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-candle/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ impl QTensorPrimitive for CandleQTensor {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}

fn strategy(&self) -> QuantizationStrategy {
todo!()
}
}

impl TensorMetadata for CandleQTensor {
Expand Down
15 changes: 5 additions & 10 deletions crates/burn-fusion/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::{
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, QFusionTensor,
};
use crate::{client::FusionClient, stream::Context, FusionClientLocator, FusionTensor};
use burn_tensor::{
backend::{Backend, DeviceOps},
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle},
repr::{OperationDescription, ReprBackend, TensorHandle},
Device, Element,
};
use serde::{de::DeserializeOwned, Serialize};
Expand Down Expand Up @@ -37,7 +35,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {

type BoolElem = B::BoolElem;

type QuantizedTensorPrimitive = QFusionTensor<B::FusionRuntime>;
type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;

type QuantizedEncoding = B::QuantizedEncoding;

Expand Down Expand Up @@ -184,10 +182,7 @@ impl<B: FusionBackend> ReprBackend for Fusion<B> {
handle.handle
}

fn quantized_tensor(
_handles: QuantizedKind<TensorHandle<Self::Handle>>,
_scheme: burn_tensor::quantization::QuantizationScheme,
) -> QuantizedTensor<Self> {
fn quantized_tensor(_handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
todo!() // not as simple
}

Expand All @@ -203,7 +198,7 @@ impl<B: FusionBackend> ReprBackend for Fusion<B> {
tensor
}

fn quantized_tensor_handle(_tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
fn quantized_tensor_handle(_tensor: QuantizedTensor<Self>) -> Self::Handle {
todo!() // not as simple
}
}
14 changes: 7 additions & 7 deletions crates/burn-fusion/src/client/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::future::Future;

use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, QFusionTensor,
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
repr::{OperationDescription, TensorDescription, TensorId},
DType, TensorData,
};

Expand Down Expand Up @@ -61,8 +61,8 @@ where
/// Read the values contained by a quantized tensor.
fn read_tensor_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
streams: Vec<StreamId>,
tensor: TensorDescription,
streams: StreamId,
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
Expand Down Expand Up @@ -108,10 +108,10 @@ where
/// Change the client of the given quantized tensor.
fn change_client_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
tensor: TensorDescription,
client: Self,
streams: Vec<StreamId>,
) -> QFusionTensor<R>
stream: StreamId,
) -> FusionTensor<R>
where
B: FusionBackend<FusionRuntime = R>;
/// Drop the tensor with the given [tensor id](TensorId).
Expand Down
54 changes: 11 additions & 43 deletions crates/burn-fusion/src/client/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use super::FusionClient;
use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionQuantizationParameters, FusionRuntime,
FusionServer, FusionTensor, QFusionTensor,
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId},
repr::{OperationDescription, TensorDescription, TensorId},
DType,
};
use spin::Mutex;
Expand Down Expand Up @@ -115,13 +114,13 @@ where

fn read_tensor_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
streams: Vec<StreamId>,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_quantized::<B>(tensor, streams)
self.server.lock().read_quantized::<B>(tensor, stream)
}

fn change_client_float<B>(
Expand Down Expand Up @@ -190,55 +189,24 @@ where

fn change_client_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
tensor: TensorDescription,
client: Self,
streams: Vec<StreamId>,
) -> QFusionTensor<R>
stream: StreamId,
) -> FusionTensor<R>
where
B: FusionBackend<FusionRuntime = R>,
{
let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
for stream in streams {
server_current.drain_stream(stream);
}
server_current.drain_stream(stream);

let mut ids =
let id =
server_current.change_server_quantized::<B>(&tensor, &client.device, &mut server_other);

core::mem::drop(server_other);
core::mem::drop(server_current);

// NOTE: the expected order is known [qtensor, scale, <offset>]
let offset = tensor.qparams.offset.map(|desc| {
FusionTensor::new(
ids.pop().unwrap(),
desc.shape,
desc.dtype,
client.clone(),
StreamId::current(),
)
});
let scale = FusionTensor::new(
ids.pop().unwrap(),
tensor.qparams.scale.shape,
tensor.qparams.scale.dtype,
client.clone(),
StreamId::current(),
);
let qtensor = FusionTensor::new(
ids.pop().unwrap(),
tensor.tensor.shape,
tensor.tensor.dtype,
client,
StreamId::current(),
);

QFusionTensor {
qtensor,
scheme: tensor.scheme,
qparams: FusionQuantizationParameters { scale, offset },
}
FusionTensor::new(id, tensor.shape, tensor.dtype, client, StreamId::current())
}

fn register_orphan(&self, id: &TensorId) {
Expand Down
Loading
Loading