Skip to content

Commit

Permalink
Merge branch 'main' into fused/matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Dec 17, 2024
2 parents 97f5638 + 28f99d1 commit 9f1d47e
Show file tree
Hide file tree
Showing 51 changed files with 1,276 additions and 1,177 deletions.
82 changes: 32 additions & 50 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ version = "0.16.0"
atomic_float = "1"
bytemuck = "1.20.0"
candle-core = { version = "0.8" }
clap = { version = "4.5.21", features = ["derive"] }
clap = { version = "4.5.23", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
csv = "1.3.1"
Expand Down Expand Up @@ -88,7 +88,7 @@ thiserror = "2.0.6"
tokio = { version = "1.42.0", features = ["rt", "macros"] }
tracing-appender = "0.2.3"
tracing-core = "0.1.33"
tracing-subscriber = "0.3.18"
tracing-subscriber = "0.3.19"
web-time = "1.1.0"
zip = "2.2.1"

Expand Down Expand Up @@ -131,19 +131,19 @@ ndarray = { version = "0.16.1", default-features = false }
num-traits = { version = "0.2.19", default-features = false, features = [
"libm",
] } # libm is for no_std
openblas-src = "0.10.9"
openblas-src = "0.10.10"
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
] } # std_rng is for no_std
rand_distr = { version = "0.4.3", default-features = false }
serde = { version = "1.0.215", default-features = false, features = [
serde = { version = "1.0.216", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = { version = "1.0.133", default-features = false }
uuid = { version = "1.11.0", default-features = false }

libc = "0.2.167"
libc = "0.2.168"
nvml-wrapper = "0.10.0"
sysinfo = "0.32.1"
systemstat = "0.2.3"
Expand Down
31 changes: 31 additions & 0 deletions burn-book/src/building-blocks/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,37 @@ Note that the trait doesn't require all methods to be implemented as they are already defined to
perform no operation. If you're only interested in float tensors (like the majority of use cases),
then you can simply implement `map_float` or `visit_float`.

For example, the `ModuleMapper` trait could be implemented to clamp all parameters into the range
`[min, max]`.

```rust, ignore
/// Clamp parameters into the range `[min, max]`.
pub struct Clamp {
    /// Lower-bound of the range.
    pub min: f32,
    /// Upper-bound of the range.
    pub max: f32,
}
// Clamp all floating-point parameter tensors between `[min, max]`.
impl<B: Backend> ModuleMapper<B> for Clamp {
    // Only `map_float` is overridden: per the surrounding text, the trait's
    // other methods default to a no-op, so non-float parameters pass through
    // unchanged.
    fn map_float<const D: usize>(
        &mut self,
        _id: burn::module::ParamId,
        tensor: burn::prelude::Tensor<B, D>,
    ) -> burn::prelude::Tensor<B, D> {
        // Element-wise clamp of this parameter tensor to `[self.min, self.max]`.
        tensor.clamp(self.min, self.max)
    }
}
// Clamp module mapper into the range `[-0.5, 0.5]`
let mut clamp = Clamp {
    min: -0.5,
    max: 0.5,
};
// Apply the mapper to the model; `map` visits every parameter and replaces
// each float tensor with the clamped result.
let model = model.map(&mut clamp);
```

## Module Display

Burn provides a simple way to display the structure of a module and its configuration at a glance.
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-autodiff/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
todo!()
}

fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
B::q_shape(tensor)
}

fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
B::q_device(tensor)
}
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-candle/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
unimplemented!()
}

fn q_shape(tensor: &QuantizedTensor<Self>) -> Shape {
super::base::shape(&tensor.qtensor)
}

fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
super::base::device(&tensor.qtensor)
}
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-candle/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ impl QTensorPrimitive for CandleQTensor {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}

fn strategy(&self) -> QuantizationStrategy {
todo!()
}
}

impl TensorMetadata for CandleQTensor {
Expand Down
Loading

0 comments on commit 9f1d47e

Please sign in to comment.