From a5624c15957d94ead7bb6b351047b63c2de3909b Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 29 Nov 2024 20:20:21 +0100 Subject: [PATCH] [Optimization] Implicit gemm rewrite (#2545) --- Cargo.lock | 26 +- Cargo.toml | 4 +- backend-comparison/Cargo.toml | 2 +- backend-comparison/benches/conv2d.rs | 147 +++++- backend-comparison/benches/matmul.rs | 42 +- crates/burn-core/Cargo.toml | 6 +- .../burn-core/src/data/dataloader/batcher.rs | 1 + crates/burn-core/src/lib.rs | 2 + crates/burn-core/src/nn/rnn/lstm.rs | 1 - .../burn-fusion/src/stream/execution/tests.rs | 2 +- crates/burn-jit/src/kernel/contiguous.rs | 6 +- .../burn-jit/src/kernel/conv/conv2d/base.rs | 21 +- .../src/kernel/conv/conv2d/gemm/algorithm.rs | 124 +++++ .../src/kernel/conv/conv2d/gemm/base.rs | 140 ++++++ .../src/kernel/conv/conv2d/gemm/config.rs | 15 + .../conv/conv2d/gemm/homogeneous/base.rs | 435 ++++++++++++++++++ .../conv/conv2d/gemm/homogeneous/mod.rs | 1 + .../src/kernel/conv/conv2d/gemm/launch.rs | 269 +++++++++++ .../kernel/conv/conv2d/gemm/loader/bias.rs | 116 +++++ .../kernel/conv/conv2d/gemm/loader/im2col.rs | 147 ++++++ .../src/kernel/conv/conv2d/gemm/loader/mod.rs | 2 + .../src/kernel/conv/conv2d/gemm/mod.rs | 10 + .../kernel/conv/conv2d/gemm/reader/bias.rs | 38 ++ .../kernel/conv/conv2d/gemm/reader/im2col.rs | 112 +++++ .../src/kernel/conv/conv2d/gemm/reader/mod.rs | 2 + .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 12 +- .../src/kernel/conv/conv2d/implicit_gemm.rs | 43 +- .../src/kernel/conv/conv2d/layout_swap.rs | 2 +- crates/burn-jit/src/kernel/conv/conv2d/mod.rs | 3 + .../src/kernel/conv/conv2d/tune/conv2d.rs | 41 +- .../kernel/conv/deform_conv_transpose2d.rs | 54 ++- crates/burn-jit/src/kernel/matmul/base.rs | 37 +- .../burn-jit/src/kernel/matmul/tune/base.rs | 17 +- .../src/kernel/reduce/subcube/kernel.rs | 2 +- crates/burn-jit/src/ops/transaction.rs | 2 +- crates/burn-jit/src/template/base.rs | 2 +- crates/burn-jit/src/tests/conv2d.rs | 73 ++- crates/burn-ndarray/src/tensor.rs | 2 +- crates/burn-tensor/src/tests/module/conv3d.rs | 3 +- crates/burn-train/src/metric/base.rs | 1 + crates/burn/Cargo.toml | 6 +- 41 files changed, 1830 insertions(+), 141 deletions(-) create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs diff --git a/Cargo.lock b/Cargo.lock index ace556a7e8..8897110b75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1666,7 +1666,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1674,6 +1674,7 @@ dependencies = [ "cubecl-linalg", "cubecl-runtime 0.4.0", "cubecl-wgpu", + "half", ] [[package]] @@ -1697,7 +1698,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1714,7 +1715,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1732,7 +1733,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1746,7 +1747,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1762,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1773,6 +1774,7 @@ dependencies = [ "derive-new 0.6.0", "half", "log", + "paste", ] [[package]] @@ -1787,7 +1789,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "bytemuck", "cubecl-core", @@ -1798,7 +1800,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1813,7 +1815,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1850,7 +1852,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "async-channel", "async-lock", @@ -1871,7 +1873,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1885,7 +1887,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7" +source = "git+https://github.com/tracel-ai/cubecl?rev=0f8312690a4328266ae9267314f50fb4ed0835ad#0f8312690a4328266ae9267314f50fb4ed0835ad" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 5c29707122..ca45f967bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0f8312690a4328266ae9267314f50fb4ed0835ad" } ### For local development. 
### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 9e199ba8c9..1d1a62cf49 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -17,8 +17,8 @@ candle-cuda = ["burn/candle-cuda"] candle-metal = ["burn/candle", "burn/metal"] cuda-jit = ["burn/cuda-jit"] cuda-jit-fusion = ["cuda-jit", "burn/fusion"] -hip-jit = ["burn/hip-jit"] default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] +hip-jit = ["burn/hip-jit"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] diff --git a/backend-comparison/benches/conv2d.rs b/backend-comparison/benches/conv2d.rs index d9fdf47a8c..c2a46ad64f 100644 --- a/backend-comparison/benches/conv2d.rs +++ b/backend-comparison/benches/conv2d.rs @@ -1,3 +1,5 @@ +use std::hint::black_box; + use backend_comparison::persistence::save; use burn::tensor::{ backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor, @@ -5,6 +7,7 @@ use burn::tensor::{ use burn_common::benchmark::{run_benchmark, Benchmark}; pub struct Conv2dBenchmark { + suffix: &'static str, input_shape: Shape, weight_shape: Shape, bias_shape: Shape, @@ -16,7 +19,7 @@ impl Benchmark for Conv2dBenchmark { type Args = (Tensor, Tensor, Tensor); fn name(&self) -> String { - "conv2d".into() + format!("conv2d-{}", self.suffix) } fn shapes(&self) -> Vec> { @@ -50,6 +53,10 @@ impl Benchmark for Conv2dBenchmark { fn sync(&self) { B::sync(&self.device) } + + fn num_samples(&self) -> usize { + 40 + } } #[allow(dead_code)] @@ -75,6 +82,7 @@ fn bench( let groups = 1; let options = ConvOptions::new(strides, padding, dilations, groups); let benchmark = Conv2dBenchmark:: { + suffix: "input_16x512x512_weight_16x3x3_stride_1", input_shape: [batch_size, channels_in, height_in, width_in].into(), weight_shape: [ channels_out, @@ -88,14 +96,135 @@ fn bench( device: device.clone(), }; - save::( - vec![run_benchmark(benchmark)], - device, - feature_name, - url, - token, - ) - .unwrap(); + let conv1 = Conv2dBenchmark:: { + suffix: "input_3x227x227_weight_96x11x11_stride_4", + input_shape: [batch_size, 3, 227, 227].into(), + weight_shape: [96, 3, 11, 11].into(), + bias_shape: [96].into(), + options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv2 = Conv2dBenchmark:: { + suffix: "input_3x231x231_weight_96x11x11_stride_4", + input_shape: [batch_size, 3, 231, 231].into(), + weight_shape: [96, 3, 11, 11].into(), + bias_shape: [96].into(), + options: ConvOptions::new([4, 4], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv3 = Conv2dBenchmark:: { + suffix: "input_3x227x227_weight_64x7x7_stride_2", + input_shape: [batch_size, 3, 227, 227].into(), + weight_shape: [64, 3, 7, 7].into(), + bias_shape: [64].into(), + options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv4 = Conv2dBenchmark:: { + suffix: "input_64x224x224_weight_64x7x7_stride_2", + input_shape: [batch_size, 64, 224, 224].into(), + weight_shape: [64, 64, 7, 7].into(), + bias_shape: [64].into(), + options: ConvOptions::new([2, 2], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv5 = Conv2dBenchmark:: { + suffix: "input_96x24x24_weight_256x5x5_stride_1", + input_shape: [batch_size, 96, 24, 24].into(), + weight_shape: [256, 
96, 5, 5].into(), + bias_shape: [256].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv6 = Conv2dBenchmark:: { + suffix: "input_256x12x12_weight_512x3x3_stride_1", + input_shape: [batch_size, 256, 12, 12].into(), + weight_shape: [512, 256, 3, 3].into(), + bias_shape: [512].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv7 = Conv2dBenchmark:: { + suffix: "input_3x224x224_weight_64x3x3_stride_1", + input_shape: [batch_size, 3, 224, 224].into(), + weight_shape: [64, 3, 3, 3].into(), + bias_shape: [64].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv8 = Conv2dBenchmark:: { + suffix: "input_64x112x112_weight_128x3x3_stride_1", + input_shape: [batch_size, 64, 112, 112].into(), + weight_shape: [128, 64, 3, 3].into(), + bias_shape: [128].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv9 = Conv2dBenchmark:: { + suffix: "input_64x56x56_weight_64x3x3_stride_1", + input_shape: [batch_size, 64, 56, 56].into(), + weight_shape: [64, 64, 3, 3].into(), + bias_shape: [64].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv10 = Conv2dBenchmark:: { + suffix: "input_128x28x28_weight_128x3x3_stride_1", + input_shape: [batch_size, 128, 28, 28].into(), + weight_shape: [128, 128, 3, 3].into(), + bias_shape: [128].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv11 = Conv2dBenchmark:: { + suffix: "input_256x14x14_weight_256x3x3_stride_1", + input_shape: [batch_size, 256, 14, 14].into(), + weight_shape: [256, 256, 3, 3].into(), + bias_shape: [256].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv12 = Conv2dBenchmark:: { + suffix: "input_512x7x7_weight_512x3x3_stride_1", + input_shape: [batch_size, 512, 7, 7].into(), + weight_shape: [512, 512, 3, 3].into(), + bias_shape: [512].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let conv13 = Conv2dBenchmark:: { + suffix: "input_96x224x224_weight_64x1x1_stride_1", + input_shape: [batch_size, 96, 224, 224].into(), + weight_shape: [64, 96, 1, 1].into(), + bias_shape: [64].into(), + options: ConvOptions::new([1, 1], [0, 0], [1, 1], 1), + device: device.clone(), + }; + + let benches = vec![ + benchmark, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, conv9, conv10, conv11, + conv12, conv13, + ]; + let mut results = Vec::new(); + + for bench in benches { + let result = black_box(run_benchmark(bench)); + results.push(result); + } + + save::(results, device, feature_name, url, token).unwrap(); } fn main() { diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index e4766c3df5..d31e7cc954 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -1,5 +1,5 @@ use backend_comparison::persistence::save; -use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; +use burn::tensor::{backend::Backend, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -21,17 +21,13 @@ impl Benchmark for MatmulBenchmark { vec![self.shape_lhs.dims.clone(), self.shape_rhs.dims.clone()] } - fn num_samples(&self) -> usize { - 10 - } - fn execute(&self, (lhs, rhs): Self::Args) { - lhs.clone().matmul(rhs.clone()); + 
lhs.matmul(rhs); } fn prepare(&self) -> Self::Args { - let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device); - let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device); + let lhs = Tensor::zeros(self.shape_lhs.clone(), &self.device); + let rhs = Tensor::zeros(self.shape_rhs.clone(), &self.device); (lhs, rhs) } @@ -48,24 +44,18 @@ fn bench( url: Option<&str>, token: Option<&str>, ) { - const D: usize = 3; - let batch_size = 8; - let m = 2048; - let k = 2048; - let n = 2048; - let shape_lhs = [batch_size, m, k].into(); - let shape_rhs = [batch_size, k, n].into(); - - let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()); - - save::( - vec![run_benchmark(benchmark)], - device, - feature_name, - url, - token, - ) - .unwrap(); + let benchmarks = [(2, 4096, 4096, 4096), (8, 2048, 2048, 2048)] + .into_iter() + .map(|(b, m, n, k)| { + let shape_lhs = [b, m, k].into(); + let shape_rhs = [b, k, n].into(); + + MatmulBenchmark::::new(shape_lhs, shape_rhs, device.clone()) + }) + .map(run_benchmark) + .collect(); + + save::(benchmarks, device, feature_name, url, token).unwrap(); } fn main() { diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 1ef80fd4f1..68b47b3826 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -91,10 +91,10 @@ blas-netlib = ["burn-ndarray?/blas-netlib"] metal = ["burn-candle?/metal"] openblas = ["burn-ndarray?/blas-openblas"] openblas-system = ["burn-ndarray?/blas-openblas-system"] -template = ["burn-wgpu?/template"] remote = ["burn-remote/client"] router = ["burn-router"] server = ["burn-remote/server"] +template = ["burn-wgpu?/template"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] @@ -138,10 +138,10 @@ burn-candle = { path = "../burn-candle", version = "0.16.0", optional = true } burn-cuda = { path = "../burn-cuda", version = "0.16.0", optional = true, default-features = false } burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false } -burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true } burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true } +burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } +burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } data-encoding = { workspace = true } uuid = { workspace = true } diff --git a/crates/burn-core/src/data/dataloader/batcher.rs b/crates/burn-core/src/data/dataloader/batcher.rs index 2ab3b87255..b0c242952e 100644 --- a/crates/burn-core/src/data/dataloader/batcher.rs +++ b/crates/burn-core/src/data/dataloader/batcher.rs @@ -29,6 +29,7 @@ where } } +/// Test batcher #[cfg(test)] #[derive(new, Clone)] pub struct TestBatcher; diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index d1788d10cd..f554518430 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -48,6 +48,7 @@ pub use burn_remote::server; extern crate alloc; +/// Backend for test cases #[cfg(all( test, not(feature = "test-tch"), @@ -65,6 +66,7 @@ pub type TestBackend = burn_wgpu::Wgpu; 
#[cfg(all(test, feature = "test-cuda"))] pub type TestBackend = burn_cuda::Cuda; +/// Backend for autodiff test cases #[cfg(feature = "std")] #[cfg(test)] pub type TestAutodiffBackend = burn_autodiff::Autodiff; diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 802d7f4720..9a7c23399b 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -384,7 +384,6 @@ mod tests { /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725 /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723 /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937 - /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 #[test] diff --git a/crates/burn-fusion/src/stream/execution/tests.rs b/crates/burn-fusion/src/stream/execution/tests.rs index 9dff06f5cd..98cea0e935 100644 --- a/crates/burn-fusion/src/stream/execution/tests.rs +++ b/crates/burn-fusion/src/stream/execution/tests.rs @@ -500,7 +500,7 @@ impl OptimizationBuilder for TestOptimizationBuilder { } } -impl<'i> StreamSegment for TestSegment<'i> { +impl StreamSegment for TestSegment<'_> { // The operations in the process. fn operations(&self) -> &[OperationDescription] { self.operations diff --git a/crates/burn-jit/src/kernel/contiguous.rs b/crates/burn-jit/src/kernel/contiguous.rs index 170f202e76..b21d032c78 100644 --- a/crates/burn-jit/src/kernel/contiguous.rs +++ b/crates/burn-jit/src/kernel/contiguous.rs @@ -7,8 +7,10 @@ pub fn into_contiguous(tensor: JitTensor) -> JitTensor { } execute_with_dtype!(tensor.dtype, E, { - let output = - cubecl::linalg::tensor::into_contiguous::(&tensor.client, tensor.as_handle_ref()); + let output = cubecl::linalg::tensor::into_contiguous::( + &tensor.client, + &tensor.as_handle_ref(), + ); JitTensor::new( tensor.client, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 9f07d36c55..0b3a35dc45 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -1,15 +1,12 @@ -use burn_tensor::{ - ops::{ConvOptions, ConvTransposeOptions}, - TensorData, -}; +use burn_tensor::ops::{ConvOptions, ConvTransposeOptions}; -use crate::{tensor::JitTensor, FloatElement, IntElement, JitElement, JitRuntime}; +use crate::{tensor::JitTensor, FloatElement, IntElement, JitRuntime}; #[cfg(feature = "autotune")] use super::{conv2d_autotune, conv_transpose2d_autotune}; use super::{ conv2d_direct, conv2d_im2col, conv_transpose2d_col2im, conv_transpose2d_direct, - implicit_gemm::conv2d_implicit_gemm, + gemm::launch::conv2d_gemm_cmma_large_m, implicit_gemm::conv2d_implicit_gemm, }; /// The strategy to be used when launching a convolution kernel. @@ -24,6 +21,9 @@ pub enum Conv2dStrategy { /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and /// has constraints on tensor shape. ImplicitGemm, + /// Implicit GEMM implementation of convolution. Uses `cubecl` matmul components to provide + /// the flexibility needed to work well for varied problem sizes. 
+ ImplicitGemmComplex, } impl Default for Conv2dStrategy { @@ -82,6 +82,9 @@ pub fn conv2d( Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::(input, weight, bias, options), + Conv2dStrategy::ImplicitGemmComplex => { + conv2d_gemm_cmma_large_m::(input, weight, bias, options) + } } } @@ -113,9 +116,3 @@ pub fn conv_transpose2d( } } } - -#[allow(unused)] -pub(crate) fn debug_data(tensor: JitTensor) -> TensorData { - let bytes = tensor.client.read_one(tensor.handle.binding()); - TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) -} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs new file mode 100644 index 0000000000..7a210e7c70 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs @@ -0,0 +1,124 @@ +use std::marker::PhantomData; + +use cubecl::{ + linalg::matmul::{ + components::{ + stage::{self, StageSize}, + tile::{ + self, + accelerated::{Accelerated16x16x16, CmmaValid}, + Matmul as _, + }, + MatmulKernel, + }, + kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + }, + prelude::*, +}; + +use super::{ + base::{Convolution, ConvolutionKernel, ConvolutionLaunch, ConvolutionProblem}, + homogeneous::base::ImplicitGemmConvolution, +}; + +/// Specifications for a convolution algorithm +pub trait Algorithm { + const PLANE_DIM: u32; + + type EG: Numeric; + type ES: Numeric; + type EA: Numeric; + + type TileMatmul: tile::Matmul + MatmulKernel; + + type StageSize: StageSize; + type StageMatmul: stage::Matmul + MatmulKernel; + + type GlobalConvolution: Convolution + + ConvolutionLaunch; + + /// Cube dim for launch + fn cube_dim() -> CubeDim; + /// The cube count for a given convolution problem + fn cube_count(problem: &ConvolutionProblem) -> CubeCount; + + /// Make a convolution config from a convolution problem, and launch options + fn make_config( + problem: &ConvolutionProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> >::Config { + Self::GlobalConvolution::make_config(problem, cube_dim, cube_count, advanced_config) + } + + /// Check availability of the matmul algorithm + fn check_availability( + client: &ComputeClient, + ) -> Result<(), MatmulAvailabilityError> { + Self::GlobalConvolution::check_availability::(client) + } + + /// Determine whether the given convolution problem is valid to launch (within hardware limits) + fn can_launch( + client: &ComputeClient, + problem: &ConvolutionProblem, + ) -> bool { + if problem.options.groups > 1 || Self::check_availability::(client).is_err() { + return false; + } + + let cube_count = Self::cube_count(problem); + let (max_x, max_y, max_z) = R::max_cube_count(); + match cube_count { + CubeCount::Static(x, y, z) => x <= max_x && y <= max_y && z <= max_z, + _ => true, + } + } +} + +/// Cmma convolution +pub struct Cmma { + pub _eg: PhantomData, + pub _es: PhantomData, + pub _ea: PhantomData, + pub _stage: PhantomData, +} + +impl Algorithm + for Cmma +where + (ES, EA): CmmaValid, +{ + const PLANE_DIM: u32 = 32; + type EG = EG; + type ES = ES; + type EA = EA; + + type TileMatmul = Accelerated16x16x16; + + type StageSize = Stage; + type StageMatmul = stage::multi_buffer::Matmul< + Self::ES, + Self::EG, + Self::EA, + Self::TileMatmul, + Self::StageSize, + >; + + type GlobalConvolution = + ImplicitGemmConvolution; + + fn cube_dim() -> CubeDim 
{ + CubeDim::new(Self::PLANE_DIM, Self::StageSize::NUM_M, 1) + } + + fn cube_count(problem: &ConvolutionProblem) -> CubeCount { + let m_stage = Self::StageSize::NUM_M * Self::TileMatmul::M; + let n_stage = Self::StageSize::NUM_N * Self::TileMatmul::N; + let cubes_needed_m = (problem.m as u32).div_ceil(m_stage); + let cubes_needed_n = (problem.n as u32).div_ceil(n_stage); + + CubeCount::Static(cubes_needed_m, cubes_needed_n, 1) + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs new file mode 100644 index 0000000000..bc242107f9 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs @@ -0,0 +1,140 @@ +use burn_tensor::ops::ConvOptions; +use cubecl::linalg::matmul::{ + components::{ + global::{AccumulatorLoader, Unloader}, + stage, MatmulProblem, MatrixLayout, + }, + kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, +}; +use cubecl::prelude::*; + +use super::Config; + +#[cube] +pub trait Convolution>: + 'static + Send + Sync + ConvolutionKernel +{ + type LhsLoader: CubeType; + type RhsLoader: CubeType; + type AccumulatorLoader: AccumulatorLoader; + + type Out: Unloader; + type Accumulator: CubeType; + + /// Performs the convolution over data loaded by the + /// LHS and RHS loaders, over the range given for K, and stores with + /// using the output unloader. + /// + /// To compute the whole range of k values, use k_range=(0, K) where + /// K is the K dimension of LHS and RHS. + fn execute( + lhs_loader: Self::LhsLoader, + rhs_loader: Self::RhsLoader, + acc_loader: Self::AccumulatorLoader, + unloader: Self::Out, + acc: &mut Self::Accumulator, + k_range: (u32, u32), + #[comptime] config: Self::Config, + ); + + fn init_lhs_loader( + lhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::LhsLoader; + + fn init_rhs_loader( + rhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::RhsLoader; + + fn init_bias_loader( + rhs: &Tensor>, + n_offset: u32, + #[comptime] config: Self::Config, + #[comptime] has_bias: bool, + ) -> Self::AccumulatorLoader; + + fn init_unloader(out: &mut Tensor>, x_offset: u32, y_offset: u32) -> Self::Out; + + fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator; +} + +/// Provides configuration for a matmul kernel at any level +pub trait ConvolutionKernel { + /// Configuration tailored to the matmul implementation + type Config: Config; + + /// Asserts that the configuration for this matmul will lead to a valid computation + fn check_config(config: Self::Config); + + /// Checks if the client can handle the features used in this computation + fn check_availability( + client: &ComputeClient, + ) -> Result<(), MatmulAvailabilityError>; + + fn make_config( + problem: &ConvolutionProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config; +} + +/// Provides launch entry point to solve a matmul +pub trait ConvolutionLaunch: ConvolutionKernel { + /// Entry point + /// + /// # Safety + /// + /// Out-of-bounds can happen + #[allow(clippy::too_many_arguments)] + unsafe fn launch_unchecked( + client: &ComputeClient<::Server, ::Channel>, + cube_dim: CubeDim, + cube_count: CubeCount, + input: TensorArg<'_, R>, + weight: TensorArg<'_, R>, + bias: TensorArg<'_, R>, + out: TensorArg<'_, R>, + config: >::Config, + ); +} + +#[derive(Clone)] +/// Description of a matmul problem to solve, regardless of actual data +pub struct 
ConvolutionProblem { + pub m: usize, + pub n: usize, + pub k: usize, + pub lhs_layout: MatrixLayout, + pub rhs_layout: MatrixLayout, + pub lhs_line_size: u8, + pub rhs_line_size: u8, + pub out_line_size: u8, + + pub kernel_size: (u32, u32), + pub options: ConvOptions<2>, + pub out_shape_y: usize, + pub out_shape_x: usize, + pub has_bias: bool, +} + +impl ConvolutionProblem { + pub fn as_matmul_problem(&self) -> MatmulProblem { + MatmulProblem { + m: self.m, + n: self.n, + k: self.k, + batches: (vec![], vec![]), + lhs_layout: self.lhs_layout, + rhs_layout: self.rhs_layout, + lhs_line_size: self.lhs_line_size, + rhs_line_size: self.rhs_line_size, + out_line_size: self.out_line_size, + } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs new file mode 100644 index 0000000000..7895a5cc1a --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/config.rs @@ -0,0 +1,15 @@ +use cubecl::linalg::matmul::components::global; + +/// Convolution specific config, extends regular matmul [`Config`](global::Config) +pub trait Config: global::Config { + /// The shape of the output at `dim` + fn out_shape(&self, dim: u32) -> u32; + /// The size of the convolution kernel at `dim` + fn kernel_size(&self, dim: u32) -> u32; + /// The dilation of the kernel at `dim` + fn dilation(&self, dim: u32) -> u32; + /// The stride of the kernel at `dim` + fn stride(&self, dim: u32) -> u32; + /// The padding of the kernel at `dim` + fn padding(&self, dim: u32) -> i32; +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs new file mode 100644 index 0000000000..ca7399e72d --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -0,0 +1,435 @@ +use cubecl::{ + linalg::matmul::{ + components::{ + global::{ + self, + homogeneous::{self, CyclicLoading, RhsLoader}, + unloader::Unloader, + AccumulatorLoader, Config as _, Loader, + }, + stage::{ + self, + multi_buffer::{LhsReader, RhsReader}, + TilingOrderConfig, + }, + Ident, MatrixLayout, StageDim, + }, + kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + }, + prelude::*, +}; +use std::marker::PhantomData; + +use crate::kernel::conv::{ + conv2d::gemm::base::{Convolution, ConvolutionKernel, ConvolutionLaunch, ConvolutionProblem}, + loader::im2col::SimpleIm2colLoader, +}; +use crate::kernel::conv::{conv2d::gemm::Config as _, loader::bias::BiasLoader}; + +/// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities +/// - All planes load data to the stage +/// - All planes are used in the stage matmul computation +pub struct ImplicitGemmConvolution< + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMM: stage::Matmul, +> { + _eg: PhantomData, + _es: PhantomData, + _acc: PhantomData, + _stage_matmul: PhantomData, +} + +#[cube] +impl Convolution + for ImplicitGemmConvolution +where + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMMConf: stage::Config, + SMM: stage::Matmul< + ES, + EG, + Acc, + LhsReader = LhsReader, + RhsReader = RhsReader, + Config = SMMConf, + >, +{ + type LhsLoader = SimpleIm2colLoader; + type RhsLoader = RhsLoader; + type AccumulatorLoader = BiasLoader; + + type Out = Unloader; + type Accumulator = SMM::Accumulator; + + fn execute( + mut lhs_loader: Self::LhsLoader, + mut rhs_loader: Self::RhsLoader, + mut acc_loader: Self::AccumulatorLoader, + mut out_unloader: Self::Out, + acc: &mut Self::Accumulator, + 
k_range: (u32, u32), + #[comptime] config: Self::Config, + ) { + let k_step = SMM::K; + let range = k_range.1 - k_range.0; + #[allow(clippy::manual_div_ceil)] + let num_loops = (range + k_step - 1) / k_step; + + Self::AccumulatorLoader::fill_stage(&mut acc_loader, config.to_smm_config()); + let (mut lhs_tile, mut rhs_tile) = SMM::init_tile_inputs(config.to_smm_config()); + + sync_units(); + + SMM::fill_accumulator::( + &mut acc_loader, + acc, + config.to_smm_config(), + ); + + for _ in 0..num_loops { + sync_units(); + + let lhs_stage_reader = &Self::LhsLoader::fill_stage(&mut lhs_loader, config); + let rhs_stage_reader = + &Self::RhsLoader::fill_stage(&mut rhs_loader, config.to_matmul_config()); + + sync_units(); + + SMM::execute( + lhs_stage_reader, + rhs_stage_reader, + &mut lhs_tile, + &mut rhs_tile, + acc, + config.to_smm_config(), + ); + + Self::LhsLoader::advance_view(&mut lhs_loader, k_step); + Self::RhsLoader::advance_view(&mut rhs_loader, k_step); + } + + sync_units(); + + SMM::read_accumulator::( + acc, + &mut out_unloader, + config.to_smm_config(), + config, + ); + } + + fn init_lhs_loader( + lhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::LhsLoader { + Self::LhsLoader::new( + lhs, + config.out_shape(0), + config.out_shape(1), + x_offset, + y_offset, + config, + ) + } + + fn init_rhs_loader( + rhs: &Tensor>, + x_offset: u32, + y_offset: u32, + #[comptime] config: Self::Config, + ) -> Self::RhsLoader { + Self::RhsLoader::new::(rhs, x_offset, y_offset, 0, config) + } + + fn init_bias_loader( + bias: &Tensor>, + n_offset: u32, + #[comptime] config: Self::Config, + #[comptime] has_bias: bool, + ) -> Self::AccumulatorLoader { + Self::AccumulatorLoader::new(bias, n_offset, config.to_smm_config(), has_bias) + } + + fn init_unloader(out: &mut Tensor>, x_offset: u32, y_offset: u32) -> Self::Out { + Self::Out::new(out, x_offset, y_offset, 0) + } + + fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator { + SMM::init_accumulator(config.to_smm_config()) + } +} + +impl ConvolutionKernel for ImplicitGemmConvolution +where + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMM: stage::Matmul, +{ + type Config = config::Config>; + + fn check_config(config: Self::Config) { + SMM::check_config(config.to_smm_config()); + } + + fn check_availability( + client: &ComputeClient, + ) -> Result<(), MatmulAvailabilityError> { + SMM::check_availability::(client) + } + + fn make_config( + problem: &ConvolutionProblem, + cube_dim: &CubeDim, + cube_count: &CubeCount, + advanced_config: &AdvancedConfig, + ) -> Self::Config { + let smm_config = SMM::make_config( + &problem.as_matmul_problem(), + cube_dim, + cube_count, + advanced_config, + ); + + config::Config::new( + homogeneous::Config::new( + smm_config, + problem.m as u32 % SMM::M != 0, + problem.n as u32 % SMM::N != 0, + problem.k as u32 % SMM::K != 0, + problem.lhs_layout, + problem.rhs_layout, + problem.lhs_line_size as u32, + problem.rhs_line_size as u32, + problem.out_line_size as u32, + ), + (problem.out_shape_y as u32, problem.out_shape_x as u32), + problem.kernel_size, + &problem.options, + problem.has_bias, + ) + } +} + +impl< + EG: Numeric, + ES: Numeric, + Acc: Numeric, + SMM: stage::Matmul, RhsReader = RhsReader>, + > ConvolutionLaunch for ImplicitGemmConvolution +{ + unsafe fn launch_unchecked( + client: &ComputeClient<::Server, ::Channel>, + cube_dim: CubeDim, + cube_count: CubeCount, + input: TensorArg<'_, R>, + weight: TensorArg<'_, R>, + bias: TensorArg<'_, R>, + out: 
TensorArg<'_, R>, + config: >::Config, + ) { + Self::check_config(config); + + implicit_conv::launch_unchecked::( + client, + cube_count, + cube_dim, + input, + weight, + bias, + out, + config, + config.has_bias, + ); + } +} + +#[cube(launch_unchecked)] +pub(crate) fn implicit_conv< + EG: Numeric, + ES: Numeric, + Acc: Numeric, + GMM: Convolution, + SMM: stage::Matmul, +>( + lhs: &Tensor>, + rhs: &Tensor>, + bias: &Tensor>, + out: &mut Tensor>, + #[comptime] config: GMM::Config, + #[comptime] has_bias: bool, +) { + let x_offset = CUBE_POS_X * config.stage_dim(Ident::Lhs).num_elements_x_dim(); + let y_offset = CUBE_POS_Y * config.stage_dim(Ident::Rhs).num_elements_y_dim(); + let k_range = (0, rhs.shape(0)); + + GMM::execute( + GMM::init_lhs_loader(lhs, x_offset, k_range.0, config), + GMM::init_rhs_loader(rhs, k_range.0, y_offset, config), + GMM::init_bias_loader(bias, y_offset, config, has_bias), + GMM::init_unloader(out, x_offset, y_offset), + &mut GMM::init_accumulator(config), + k_range, + config, + ); +} + +pub mod config { + use std::ops::Deref; + + use burn_tensor::ops::ConvOptions; + use cubecl::linalg::matmul::components::MatmulConfig; + + use crate::kernel::conv::conv2d::gemm::{self}; + + use super::*; + + #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] + pub struct Config { + matmul: M, + + out_shape: (u32, u32), + + kernel_size: (u32, u32), + stride: (u32, u32), + dilation: (u32, u32), + padding: (i32, i32), + + pub has_bias: bool, + } + + impl Deref for Config { + type Target = M; + + fn deref(&self) -> &Self::Target { + &self.matmul + } + } + + impl global::Config for Config { + type SmmConfig = M::SmmConfig; + + fn to_smm_config(&self) -> Self::SmmConfig { + self.matmul.to_smm_config() + } + + fn global_line_size(&self, ident: Ident) -> u32 { + self.matmul.global_line_size(ident) + } + + fn stage_line_size(&self, ident: Ident) -> u32 { + self.matmul.stage_line_size(ident) + } + + fn stage_dim(&self, ident: Ident) -> Box { + self.matmul.stage_dim(ident) + } + + fn layout(&self, ident: Ident) -> MatrixLayout { + self.matmul.layout(ident) + } + + fn num_planes(&self) -> u32 { + self.matmul.num_planes() + } + + fn plane_dim(&self) -> u32 { + self.matmul.plane_dim() + } + + fn tiling_order(&self, ident: Ident) -> TilingOrderConfig { + self.matmul.tiling_order(ident) + } + + fn check_m_bounds(&self) -> bool { + self.matmul.check_m_bounds() + } + + fn check_n_bounds(&self) -> bool { + self.matmul.check_n_bounds() + } + + fn check_k_bounds(&self) -> bool { + self.matmul.check_k_bounds() + } + + fn transpose_load(&self, ident: Ident) -> bool { + self.matmul.transpose_load(ident) + } + } + + impl gemm::Config for Config { + fn out_shape(&self, dim: u32) -> u32 { + match dim { + 0 => self.out_shape.0, + 1 => self.out_shape.1, + _ => unreachable!(), + } + } + + fn kernel_size(&self, dim: u32) -> u32 { + match dim { + 0 => self.kernel_size.0, + 1 => self.kernel_size.1, + _ => unreachable!(), + } + } + + fn dilation(&self, dim: u32) -> u32 { + match dim { + 0 => self.dilation.0, + 1 => self.dilation.1, + _ => unreachable!(), + } + } + + fn stride(&self, dim: u32) -> u32 { + match dim { + 0 => self.stride.0, + 1 => self.stride.1, + _ => unreachable!(), + } + } + + fn padding(&self, dim: u32) -> i32 { + match dim { + 0 => self.padding.0, + 1 => self.padding.1, + _ => unreachable!(), + } + } + } + + impl MatmulConfig for Config {} + + impl Config { + #[allow(clippy::too_many_arguments)] + pub fn new( + matmul: M, + out_shape: (u32, u32), + kernel_size: (u32, u32), + conv_args: 
&ConvOptions<2>, + has_bias: bool, + ) -> Self { + Self { + matmul, + out_shape, + kernel_size, + stride: (conv_args.stride[0] as u32, conv_args.stride[1] as u32), + dilation: (conv_args.dilation[0] as u32, conv_args.dilation[1] as u32), + padding: (conv_args.padding[0] as i32, conv_args.padding[1] as i32), + has_bias, + } + } + + pub fn to_matmul_config(self) -> M { + self.matmul + } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs new file mode 100644 index 0000000000..6cf245d4db --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/mod.rs @@ -0,0 +1 @@ +pub mod base; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs new file mode 100644 index 0000000000..0ecf1880a6 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -0,0 +1,269 @@ +use burn_tensor::{ + ops::{conv::calculate_conv_output_size, ConvOptions}, + Shape, +}; +use cubecl::{ + ir::{Elem, FloatKind}, + linalg::matmul::{ + self, + components::{ + stage::{S4x2x4, S8x4x2}, + MatrixLayout, + }, + }, + tensor_line_size, tf32, Feature, +}; +use half::{bf16, f16}; + +use crate::{ + kernel::{ + conv::{ + conv2d::gemm::{ + algorithm::{Algorithm, Cmma}, + base::{ConvolutionLaunch, ConvolutionProblem}, + }, + nchw_to_nhwc, Conv2dAutotuneKey, + }, + into_contiguous, + }, + ops::{numeric::empty_device, permute, reshape}, + tensor::JitTensor, + FloatElement, JitRuntime, +}; + +/// Large m stage size for the usual case where `batch_size * out_h * out_w` is significantly larger +/// than `out_channels` +pub type CmmaLargeMAlgorithm = Cmma; +/// Balanced stage size for cases where `batch_size * out_h * out_w` is relatively small and `k` or +/// `out_channels` is relatively large +pub type CmmaBalancedAlgorithm = Cmma; + +macro_rules! select_launch_algo { + ($algo:tt, $float:ty, $input:expr) => { + match (<$float>::as_elem(), has_tf32(&$input)) { + (Elem::Float(FloatKind::F32), true) => { + conv2d_gemm_with_algo::> + } + (Elem::Float(FloatKind::F16), _) => { + conv2d_gemm_with_algo::> + } + (Elem::Float(FloatKind::BF16), _) => { + conv2d_gemm_with_algo::> + } + _ => conv2d_gemm_with_algo::>, + } + }; +} + +/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul +/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +pub fn conv2d_gemm_cmma_large_m( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let launch = select_launch_algo!(CmmaLargeMAlgorithm, F, input); + launch(input, weight, bias, options) +} + +/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul +/// components. 
Uses [`CmmaBalancedAlgorithm`] for the stage size +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +/// +pub fn conv2d_gemm_cmma_balanced( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let launch = select_launch_algo!(CmmaBalancedAlgorithm, F, input); + launch(input, weight, bias, options) +} + +/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul +/// components, using the specified algorithm. +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +/// +pub fn conv2d_gemm_with_algo>( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let [batch_size, in_channels, height, width] = input.shape.dims(); + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); + + let out_h = calculate_conv_output_size( + kernel_h, + options.stride[0], + options.padding[0], + options.dilation[0], + height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + options.stride[1], + options.padding[1], + options.dilation[1], + width, + ); + + let input = match input.is_contiguous() { + true => nchw_to_nhwc::(input), + false => into_contiguous(permute(input, &[0, 2, 3, 1])), + }; + let weight = into_contiguous(permute(weight, &[2, 3, 1, 0])); + + // Implicit GEMM matrix size + let gemm_m = batch_size * out_h * out_w; + let gemm_n = out_channels; + let gemm_k = kernel_h * kernel_w * in_channels; + + let weight = reshape(weight, Shape::new([gemm_k, gemm_n])); + + let out_shape = Shape::new([gemm_m, gemm_n]); + let out = empty_device::(input.client.clone(), input.device.clone(), out_shape); + + // Target 128 bit accesses + let available_vectorizations = R::supported_line_sizes() + .iter() + .copied() + .filter(|it| *it as usize * size_of::() <= 16) + .collect::>(); + let lhs_line_size = tensor_line_size( + &available_vectorizations, + &input.shape.dims, + &input.strides, + 3, + ); + let rhs_line_size = tensor_line_size( + &available_vectorizations, + &weight.shape.dims, + &weight.strides, + 1, + ); + let out_line_size = + tensor_line_size(&available_vectorizations, &out.shape.dims, &out.strides, 1); + + let problem = ConvolutionProblem { + m: gemm_m, + n: gemm_n, + k: gemm_k, + lhs_layout: matmul::components::MatrixLayout::RowMajor, + rhs_layout: matmul::components::MatrixLayout::RowMajor, + lhs_line_size, + rhs_line_size, + out_line_size, + + kernel_size: (kernel_h as u32, kernel_w as u32), + options, + out_shape_y: out_h, + out_shape_x: out_w, + + has_bias: bias.is_some(), + }; + + if !Alg::can_launch::(&input.client, &problem) { + panic!("Can't do implicit GEMM"); + } + + let cube_dim = Alg::cube_dim(); + let cube_count = Alg::cube_count(&problem); + + let advanced_config = Default::default(); + let config = Alg::make_config(&problem, &cube_dim, &cube_count, &advanced_config); + let bias = bias.unwrap_or_else(|| { + empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])) + }); + + unsafe { + Alg::GlobalConvolution::launch_unchecked::( + &input.client, + cube_dim, + cube_count, + input.as_tensor_arg::(lhs_line_size), + weight.as_tensor_arg::(rhs_line_size), + bias.as_tensor_arg::(out_line_size), + 
out.as_tensor_arg::(out_line_size), + config, + ); + } + + // Reset to NCHW + let out = reshape(out, Shape::new([batch_size, out_h, out_w, out_channels])); + permute(out, &[0, 3, 1, 2]) +} + +pub fn problem_from_key( + key: &Conv2dAutotuneKey, + out_h: usize, + out_w: usize, +) -> ConvolutionProblem { + let in_stride_2 = key.in_channels; + let in_stride_1 = key.width * in_stride_2; + let in_stride_0 = key.height * in_stride_1; + + let m = key.batch_size * out_h * out_w; + let n = key.out_channels; + let k = key.kernel_size[0] * key.kernel_size[1] * key.in_channels; + + let options = ConvOptions { + stride: key.stride, + padding: key.padding, + dilation: key.dilation, + groups: key.groups, + }; + + // Target 128 bit accesses + let available_vectorizations = R::supported_line_sizes() + .iter() + .copied() + .filter(|it| *it as usize * size_of::() <= 16) + .collect::>(); + let lhs_line_size = tensor_line_size( + &available_vectorizations, + &[key.batch_size, key.height, key.width, key.in_channels], + &[in_stride_0, in_stride_1, in_stride_2, 1], + 3, + ); + let rhs_line_size = tensor_line_size(&available_vectorizations, &[k, n], &[n, 1], 1); + let out_line_size = tensor_line_size(&available_vectorizations, &[m, n], &[n, 1], 1); + + ConvolutionProblem { + m, + n, + k, + lhs_layout: MatrixLayout::RowMajor, + rhs_layout: MatrixLayout::RowMajor, + lhs_line_size, + rhs_line_size, + out_line_size, + kernel_size: (key.kernel_size[0] as u32, key.kernel_size[1] as u32), + options, + out_shape_y: out_h, + out_shape_x: out_w, + has_bias: key.has_bias, + } +} + +pub(crate) fn has_tf32(c: &JitTensor) -> bool { + c.client + .properties() + .feature_enabled(Feature::Type(Elem::Float(FloatKind::TF32))) +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs new file mode 100644 index 0000000000..bb4c5bb017 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs @@ -0,0 +1,116 @@ +use std::marker::PhantomData; + +use cubecl::{ + linalg::matmul::components::{ + global::AccumulatorLoader, + stage::{self, Stage}, + tile::{self, Config as _}, + Ident, + }, + prelude::*, +}; + +use crate::kernel::conv::reader::bias::BiasReader; + +/// Special loader to broadcast the 1D bias to the 2D accumulator matrix +#[derive(CubeType)] +pub struct BiasLoader { + pub tensor_view: BiasReader, + pub stage: Stage, + pub has_bias: bool, + _config: PhantomData, +} + +#[cube] +impl AccumulatorLoader + for BiasLoader +{ + fn fill_stage(this: &mut Self, #[comptime] config: G) { + if this.has_bias { + let stage_dim = config.stage_dim(Ident::Rhs); + let line_size = config.line_size(Ident::Out); + + let num_stage_elements = stage_dim.num_elements_y_dim(); + + let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X; + let unit_position_base = unit_id * line_size; + + let mut slice = this.stage.as_slice_mut(); + + if unit_position_base < num_stage_elements { + let read_line = this + .tensor_view + .load_simple::(unit_position_base, config); + slice[unit_id] = Line::cast_from(read_line); + } + } + } + + /// Load accumulator + fn load>( + this: &mut Self, + acc: &mut Tile::Accumulator, + tile_n: u32, + #[comptime] config: Tile::Config, + ) { + if this.has_bias { + let line_size = config.line_size(Ident::Out); + let tile_elems = Tile::N / line_size; + let start = tile_n * tile_elems; + let slice = this.stage.as_slice_mut().slice(start, start + tile_elems); + Tile::fill_accumulator(&slice, acc, 0, config); + } else { + 
Tile::zero_accumulator(acc, config); + } + } +} + +#[cube] +impl BiasLoader { + pub fn new( + tensor: &Tensor>, + n_offset: u32, + #[comptime] config: G, + #[comptime] has_bias: bool, + ) -> Self { + if has_bias { + let stage = { + let line_size = config.line_size(Ident::Out); + + let smem = SharedMemory::new_lined( + comptime!(config.stage_dim(Ident::Rhs).num_elements_y_dim() / line_size), + line_size, + ); + + Stage:: { smem } + }; + let tensor_view = BiasReader:: { + tensor, + n_offset, + shape_n: tensor.shape(0), + }; + + BiasLoader:: { + tensor_view, + stage, + has_bias, + _config: PhantomData::.runtime(), + } + } else { + let stage = Stage:: { + smem: SharedMemory::new(1), + }; + let tensor_view = BiasReader:: { + tensor, + n_offset: 0, + shape_n: 0, + }; + BiasLoader:: { + stage, + tensor_view, + has_bias, + _config: PhantomData::.runtime(), + } + } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs new file mode 100644 index 0000000000..0a1ed8728c --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs @@ -0,0 +1,147 @@ +use cubecl::{ + linalg::matmul::components::{ + global::Loader, + stage::{ + multi_buffer::LhsReader, ColMajorTiling, RowMajorTiling, Stage, TilingOrder as _, + TilingOrderConfig, + }, + Ident, + }, + prelude::*, +}; +use std::marker::PhantomData; + +use crate::kernel::conv::{reader::im2col::Im2colReader, Config}; + +/// Loader that translates matrix coordinates to input coordinates using the `im2col` algorithm +#[derive(CubeType)] +pub struct SimpleIm2colLoader { + pub tensor_view: Im2colReader, + pub stage: Stage, + _config: PhantomData, +} + +#[cube] +impl Loader for SimpleIm2colLoader { + type StageReader = LhsReader; + + fn fill_stage(this: &mut Self, #[comptime] config: G) -> Self::StageReader { + SimpleIm2col::load_to_slice::( + &this.tensor_view, + &mut this.stage.as_slice_mut(), + Ident::Lhs, + config, + ); + LhsReader::new(this.stage) + } + + fn advance_view(this: &mut Self, k_offset: u32) { + this.tensor_view.update_view(k_offset); + } +} + +#[cube] +impl SimpleIm2colLoader { + pub fn new( + tensor: &Tensor>, + shape_out_y: u32, + shape_out_x: u32, + x_offset: u32, + y_offset: u32, + #[comptime] config: G, + ) -> Self { + let stage = Stage::new::(Ident::Lhs, config.to_smm_config()); + let shape_batch = tensor.shape(0); + let shape_channel = tensor.shape(3); + + let shape_m = shape_batch * shape_out_y * shape_out_x; + let shape_k = shape_channel * config.kernel_size(0) * config.kernel_size(1); + + let tensor_view = Im2colReader:: { + tensor, + m_offset: x_offset, + k_offset: y_offset, + stride_batch: tensor.stride(0), + stride_y: tensor.stride(1), + stride_x: tensor.stride(2), + stride_channel: tensor.stride(3), + shape_y: tensor.shape(1), + shape_x: tensor.shape(2), + shape_channel, + shape_out_y, + shape_out_x, + + shape_m, + shape_k, + }; + + SimpleIm2colLoader:: { + tensor_view, + stage, + _config: PhantomData::.runtime(), + } + } +} + +#[derive(CubeType, Clone, Copy)] +/// Loads the content of all tiles in the tensor view using all planes, +/// iterating with steps determined by the plane's dimension. 
+pub struct SimpleIm2col; + +#[cube] +impl SimpleIm2col { + pub fn load_to_slice( + read_view: &Im2colReader, + slice: &mut SliceMut>, + #[comptime] ident: Ident, + #[comptime] config: G, + ) { + let stage_dim = config.stage_dim(ident); + let line_size = config.global_line_size(ident); + + let num_stage_elements = stage_dim.total_elements(); + let total_units = comptime!(config.num_planes() * config.plane_dim()); + let jump_length = comptime!(total_units * line_size); + let num_loads_per_unit = num_stage_elements / jump_length; + + #[allow(clippy::all)] + let _ = comptime!(check_jump_divides_well(num_stage_elements, jump_length)); + + let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X; + let unit_position_base = unit_id * line_size; + + for i in 0..num_loads_per_unit { + let unit_position = unit_position_base + i * jump_length; + + let tile_num_elements = stage_dim.tile_num_elements(); + let nth_tile = unit_position / tile_num_elements; + let pos_within_tile = unit_position % tile_num_elements; + + let (tile_x, tile_y) = match config.tiling_order(ident) { + TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y( + nth_tile, + stage_dim.num_tiles_x_dim(), + stage_dim.num_tiles_y_dim(), + ), + TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y( + nth_tile, + stage_dim.num_tiles_x_dim(), + stage_dim.num_tiles_y_dim(), + ), + }; + + let line_read = + read_view.load_simple::(tile_x, tile_y, pos_within_tile, ident, config); + + slice[unit_position / line_size] = Line::cast_from(line_read); + } + } +} + +pub fn check_jump_divides_well(num_stage_elements: u32, jump_length: u32) { + assert!( + num_stage_elements % jump_length == 0, + "Too many data will be loaded, resulting in out of bounds. + Try setting line size and number of planes so that jump_length divides num_stage_elements." + ); +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs new file mode 100644 index 0000000000..13d3809513 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/mod.rs @@ -0,0 +1,2 @@ +pub mod bias; +pub mod im2col; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs new file mode 100644 index 0000000000..5fd4a309b9 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs @@ -0,0 +1,10 @@ +pub mod algorithm; +pub mod base; +mod config; +pub mod homogeneous; +pub mod launch; +pub mod loader; +pub mod reader; + +pub use config::*; +pub use launch::*; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs new file mode 100644 index 0000000000..67162a28a8 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs @@ -0,0 +1,38 @@ +use cubecl::{ + linalg::matmul::components::{stage, Ident}, + prelude::*, +}; + +#[derive(CubeType)] +/// A view of a tensor that starts reading data from a specified offset. +/// Ensures safe access by preventing out-of-bounds errors. +/// Includes pre-fetched shapes and strides for optimized performance. 
+pub struct BiasReader { + pub tensor: *const Tensor>, + pub n_offset: u32, + pub shape_n: u32, +} + +unsafe impl Sync for BiasReader {} +unsafe impl Send for BiasReader {} + +#[cube] +impl BiasReader { + /// Load the 1D bias into shared memory + pub fn load_simple(&self, unit_id: u32, #[comptime] config: G) -> Line { + let line_size = config.line_size(Ident::Out); + + let view_n = self.n_offset + unit_id; + let read_pos = view_n / line_size; + + select( + view_n < self.shape_n, + self.read(read_pos), + Line::empty(line_size).fill(E::from_int(0)), + ) + } + + fn read(&self, position: u32) -> Line { + unsafe { *(*self.tensor).index_unchecked(position) } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs new file mode 100644 index 0000000000..b278bb051b --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs @@ -0,0 +1,112 @@ +use cubecl::{linalg::matmul::components::Ident, prelude::*}; + +use crate::kernel::conv::Config; + +#[derive(CubeType)] +/// A view of a feature map tensor that starts reading data from a specified offset. +/// Ensures safe access by preventing out-of-bounds errors. +/// Includes pre-fetched shapes and strides for optimized performance. +pub struct Im2colReader { + pub tensor: *const Tensor>, + pub m_offset: u32, + pub k_offset: u32, + + pub stride_batch: u32, + pub stride_y: u32, + pub stride_x: u32, + pub stride_channel: u32, + + pub shape_y: u32, + pub shape_x: u32, + pub shape_channel: u32, + + pub shape_out_y: u32, + pub shape_out_x: u32, + + pub shape_m: u32, + pub shape_k: u32, +} + +unsafe impl Sync for Im2colReader {} +unsafe impl Send for Im2colReader {} + +#[cube] +impl Im2colReader { + /// Advance the view along the k dimension by a specified offset, `k_offset`. + pub fn update_view(&mut self, k_offset: u32) { + self.k_offset += k_offset; + } + + /// Reads data from the tensor view at the specified tile coordinates (tile_x, tile_y) using + /// the `im2col` algorithm to translate them to input coordinates. + /// + /// Each unit loads one line in a coalesced manner for improved efficiency. + /// For row-major tensors, subsequent units read lines horizontally within the tile, + /// while for column-major tensors, they read lines vertically. + /// + /// # Note + /// + /// Out-of-bounds reads will be translated to zeros. 
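
`BiasReader::load_simple` guards the 1D bias read with `select`, so columns past `shape_n` yield a zero line without divergent branching. A rough CPU equivalent that fetches the whole line containing `view_n` (names and the branching form are illustrative; the kernel uses a branchless `select`):

```rust
fn load_bias_line(
    bias: &[f32],
    n_offset: usize,
    unit_id: usize,
    line_size: usize,
    shape_n: usize,
) -> Vec<f32> {
    let view_n = n_offset + unit_id;
    if view_n < shape_n {
        // Read the line that contains column `view_n`.
        let start = view_n / line_size * line_size;
        bias[start..start + line_size].to_vec()
    } else {
        // Out-of-range columns produce a zero-filled line.
        vec![0.0; line_size]
    }
}

fn main() {
    let bias = vec![1.0_f32, 2.0, 3.0, 4.0];
    assert_eq!(load_bias_line(&bias, 0, 1, 2, bias.len()), vec![1.0, 2.0]);
    assert_eq!(load_bias_line(&bias, 0, 5, 2, bias.len()), vec![0.0, 0.0]);
}
```
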
+ pub fn load_simple( + &self, + tile_x: u32, + tile_y: u32, + unit_id: u32, + #[comptime] ident: Ident, + #[comptime] config: G, + ) -> Line { + let line_size = config.global_line_size(ident); + let tile_size_x = config.stage_dim(ident).tile_size_x_dim(); + let tile_size_y = config.stage_dim(ident).tile_size_y_dim(); + + let view_tile_m = tile_x * tile_size_x + self.m_offset; + let view_tile_k = tile_y * tile_size_y + self.k_offset; + + let load_m = unit_id / tile_size_y; + let load_k = unit_id % tile_size_y; + + let view_m = view_tile_m + load_m; + let view_k = view_tile_k + load_k; + + let out_x = view_m % self.shape_out_x; + let rem = view_m / self.shape_out_x; + let out_y = rem % self.shape_out_y; + let batch = rem / self.shape_out_y; + + let kernel_w = config.kernel_size(1); + + let channel = view_k % self.shape_channel; + let rem = view_k / self.shape_channel; + let kernel_x = rem % kernel_w; + let kernel_y = rem / kernel_w; + + let y = + (out_y * config.stride(0) + kernel_y * config.dilation(0)) as i32 - config.padding(0); + let x = + (out_x * config.stride(1) + kernel_x * config.dilation(1)) as i32 - config.padding(1); + + let m_in_bounds = comptime!(!config.check_m_bounds()) || view_m < self.shape_m; + let k_in_bounds = comptime!(!config.check_k_bounds()) || view_k < self.shape_k; + let no_padding = comptime!(config.padding(0) == 0 && config.padding(1) == 0); + let hw_in_bounds = no_padding + || (y >= 0 && (y as u32) < self.shape_y && x >= 0 && (x as u32) < self.shape_x); + let in_bounds = m_in_bounds && k_in_bounds && hw_in_bounds; + let read_pos = batch * self.stride_batch + + y as u32 * self.stride_y + + x as u32 * self.stride_x + + channel * self.stride_channel; + + let read_pos = read_pos / line_size; + + let mut res = Line::empty(line_size).fill(F::from_int(0)); + if in_bounds { + res = self.read(read_pos); + } + + res + } + + fn read(&self, position: u32) -> Line { + unsafe { *(*self.tensor).index_unchecked(position) } + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs new file mode 100644 index 0000000000..13d3809513 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/mod.rs @@ -0,0 +1,2 @@ +pub mod bias; +pub mod im2col; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index abcb8488fb..3a7b23df76 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -6,7 +6,10 @@ use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*}; use crate::{ kernel::{ - conv::index, into_contiguous, launch_binop, matmul::matmul, matmul::MatmulStrategy, AddOp, + conv::index, + into_contiguous, launch_binop, + matmul::{cube_strategy, matmul, MatmulStrategy}, + AddOp, }, ops::{numeric::empty_device, reshape, swap_dims}, tensor::JitTensor, @@ -300,9 +303,10 @@ fn execute( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); matmul::launch_ref::( + &cube_strategy::(&client), &client, - weight.as_handle_ref(), - columns.as_handle_ref(), - out.as_handle_ref(), + &weight.as_handle_ref(), + &columns.as_handle_ref(), + &out.as_handle_ref(), ); } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 6771f2c5e2..a021f7c089 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ 
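
The im2col translation in `load_simple` above is the heart of the loader: a GEMM row `m` decodes to `(batch, out_y, out_x)`, a column `k` decodes to `(kernel_y, kernel_x, channel)`, and stride, dilation, and padding map those to an input coordinate, with out-of-bounds reads returning zero. A CPU reference of the same index math for an NHWC input (illustrative; `ConvGeom` and `im2col_value` are not part of the patch):

```rust
struct ConvGeom {
    shape_out: (usize, usize), // (out_y, out_x)
    kernel: (usize, usize),    // (kernel_h, kernel_w)
    stride: (usize, usize),
    dilation: (usize, usize),
    padding: (i64, i64),
}

/// Value of the implicit im2col matrix at row `m`, column `k`, for an NHWC input
/// of shape (batch, in_y, in_x, channels). Padding and out-of-bounds reads are zero.
fn im2col_value(
    input: &[f32],
    shape: (usize, usize, usize, usize),
    g: &ConvGeom,
    m: usize,
    k: usize,
) -> f32 {
    let (_, in_y, in_x, channels) = shape;

    // m -> (batch, out_y, out_x)
    let out_x = m % g.shape_out.1;
    let rem = m / g.shape_out.1;
    let out_y = rem % g.shape_out.0;
    let batch = rem / g.shape_out.0;

    // k -> (kernel_y, kernel_x, channel)
    let channel = k % channels;
    let rem = k / channels;
    let kernel_x = rem % g.kernel.1;
    let kernel_y = rem / g.kernel.1;

    let y = (out_y * g.stride.0 + kernel_y * g.dilation.0) as i64 - g.padding.0;
    let x = (out_x * g.stride.1 + kernel_x * g.dilation.1) as i64 - g.padding.1;

    if y < 0 || x < 0 || y as usize >= in_y || x as usize >= in_x {
        return 0.0; // padding region
    }
    let idx = ((batch * in_y + y as usize) * in_x + x as usize) * channels + channel;
    input[idx]
}

fn main() {
    // 1x2x2x1 input with a 1x1 kernel, stride 1, no padding: im2col is the identity map.
    let input = [1.0, 2.0, 3.0, 4.0];
    let g = ConvGeom {
        shape_out: (2, 2),
        kernel: (1, 1),
        stride: (1, 1),
        dilation: (1, 1),
        padding: (0, 0),
    };
    assert_eq!(im2col_value(&input, (1, 2, 2, 1), &g, 3, 0), 4.0);
}
```
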
-318,14 +318,23 @@ fn implicit_gemm_kernel( let pos = calculate_positions(gemm_settings); + let in_vec = input.line_size(); + let weight_vec = weight.line_size(); + // Shared memory tiles, currently only holds enough data for // each warp to have its own tile for a single MMA op (8 * 16 * 16 elements) // conceptually a WARPS_PER_CUBE x (CMMA_M * CMMA_K) matrix - let mut smem_input_tile = SharedMemory::::new(cmma_input_tile_size * warps_per_cube); - let mut smem_weight_tile = SharedMemory::::new(cmma_filter_tile_size * warps_per_cube); + let mut smem_input_tile = SharedMemory::::new_lined( + comptime!(cmma_input_tile_size * warps_per_cube / in_vec), + in_vec, + ); + let mut smem_weight_tile = SharedMemory::::new_lined( + comptime!(cmma_filter_tile_size * warps_per_cube / weight_vec), + weight_vec, + ); - let input_tile_start = pos.cube_linear_warp_idx * cmma_input_tile_size; - let weight_tile_start = pos.cube_linear_warp_idx * cmma_filter_tile_size; + let input_tile_start = pos.cube_linear_warp_idx * (cmma_input_tile_size / in_vec); + let weight_tile_start = pos.cube_linear_warp_idx * (cmma_filter_tile_size / weight_vec); let mut input_tile = smem_input_tile.slice_mut(input_tile_start, input_tile_start + cmma_input_tile_size); let mut weight_tile = @@ -441,8 +450,8 @@ fn execute_gemm( weight: &Tensor>, bias: &Tensor, out: &mut SliceMut, - input_tile: &mut SliceMut, - weight_tile: &mut SliceMut, + input_tile: &mut SliceMut>, + weight_tile: &mut SliceMut>, dims: &Dimensions, pos: &Positions, args: &ConvArgs, @@ -484,7 +493,7 @@ fn execute_gemm( fn load_input_tile( input: &Tensor>, args: &ConvArgs, - tile: &mut SliceMut, + tile: &mut SliceMut>, dims: &Dimensions, pos: &Positions, k: u32, @@ -566,21 +575,18 @@ fn load_input_tile( + channel; let value = select( in_bounds && m_in_bounds && k_in_bounds, - FMat::cast_from(input[idx / vec]), - FMat::vectorized(0.0, vec), + Line::cast_from(input[idx / vec]), + Line::new(FMat::new(0.0)), ); - #[unroll] - for i in 0..vec { - tile[m + i] = value[i]; - } + tile[m / vec] = value; } } #[cube] fn load_weight_tile( weight: &Tensor>, - tile: &mut SliceMut, + tile: &mut SliceMut>, dims: &Dimensions, pos: &Positions, k: u32, @@ -629,13 +635,10 @@ fn load_weight_tile( let idx = k_idx + global_n; - let value = FMat::cast_from(weight[idx / vec]); - let value = select(k_in_bounds && n_in_bounds, value, FMat::new(0.0)); + let value = Line::cast_from(weight[idx / vec]); + let value = select(k_in_bounds && n_in_bounds, value, Line::new(FMat::new(0.0))); - #[unroll] - for i in 0..vec { - tile[n + i] = value[i]; - } + tile[n / vec] = value; } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs index a998bea86d..62f0e56d8f 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs @@ -113,7 +113,7 @@ fn nchw_to_nhwc_kernel( let batch_offset = batch * input.stride(0); let warp_id = plane_broadcast(unit_pos / 32, 0); - let warp_id_x = warp_id / CUBE_DIM_Y; + let warp_id_x = warp_id % tiles_x; let tile_x = CUBE_POS_X * tiles_x + warp_id_x; let tile_y = ABSOLUTE_POS_Y; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs index 13900acdc1..f48a490aa6 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/mod.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs @@ -1,15 +1,18 @@ mod base; mod col2im; mod direct; +mod gemm; mod im2col; mod implicit_gemm; mod layout_swap; mod 
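
The implicit-GEMM change above switches the shared-memory tiles to lined (vectorized) storage: tile sizes and start offsets are divided by the line width, and a whole `Line` is written per store instead of unrolling scalar writes. A tiny plain-Rust sketch of that index change, with the vector width fixed to 4 and `m` assumed to be line-aligned:

```rust
fn store_line(tile: &mut [[f32; 4]], m: usize, value: [f32; 4]) {
    const VEC: usize = 4;
    // Replaces the old scalar unroll: for i in 0..VEC { tile[m + i] = value[i]; }
    tile[m / VEC] = value;
}

fn main() {
    let mut tile = [[0.0f32; 4]; 8]; // 32 scalar elements stored as 8 lines
    store_line(&mut tile, 12, [1.0, 2.0, 3.0, 4.0]);
    assert_eq!(tile[3], [1.0, 2.0, 3.0, 4.0]);
}
```
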
transpose_direct; + mod tune; pub use base::*; pub use col2im::*; pub use direct::*; +pub use gemm::*; pub use im2col::*; pub use implicit_gemm::*; pub use layout_swap::*; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 4a8122a478..56fc73965e 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -3,15 +3,19 @@ use burn_tensor::{ ElementConversion, Shape, }; use cubecl::{ - tune, + ir::{Elem, FloatKind}, + tf32, tune, tune::{local_tuner, tune_with, LocalTuner}, }; +use half::{bf16, f16}; use crate::{ kernel::{ conv::{ - batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col, - conv2d_implicit_gemm, + algorithm::Algorithm, batches_per_run, can_do_implicit_gemm, conv2d_direct, + conv2d_gemm_cmma_balanced, conv2d_gemm_cmma_large_m, conv2d_im2col, + conv2d_implicit_gemm, has_tf32, problem_from_key, CmmaBalancedAlgorithm, + CmmaLargeMAlgorithm, }, prng::random_uniform, }, @@ -40,7 +44,13 @@ pub fn conv2d_autotune( } #[tune( - operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm), + operations( + conv2d_direct, + conv2d_im2col, + conv2d_implicit_gemm, + conv2d_gemm_cmma_large_m, + conv2d_gemm_cmma_balanced + ), create_key = create_key::, should_run = should_run )] @@ -72,6 +82,23 @@ pub fn conv2d_operations( tune_with!(input, weights, bias, options) } +macro_rules! check_algo { + ($algo:tt, $float:ty, $input:expr, $problem:expr) => { + match (<$float>::as_elem(), has_tf32(&$input)) { + (Elem::Float(FloatKind::F32), true) => { + $algo::<$float, tf32, f32>::can_launch::(&$input.client, &$problem) + } + (Elem::Float(FloatKind::F16), _) => { + $algo::<$float, f16, f16>::can_launch::(&$input.client, &$problem) + } + (Elem::Float(FloatKind::BF16), _) => { + $algo::<$float, bf16, f32>::can_launch::(&$input.client, &$problem) + } + _ => $algo::<$float, f16, f32>::can_launch::(&$input.client, &$problem), + } + }; +} + fn should_run( op: &Conv2dOperations, key: &JitAutotuneKey, @@ -97,6 +124,8 @@ fn should_run( key.width, ); + let conv_problem = problem_from_key::(key, out_h, out_w); + match index { // im2col 1 => batches_per_run(key.batch_size, out_h, out_w).is_some(), @@ -111,6 +140,10 @@ fn should_run( out_w, &op.input.client, ), + // GEMM large m + 3 => check_algo!(CmmaLargeMAlgorithm, F, op.input, conv_problem), + // GEMM balanced + 4 => check_algo!(CmmaBalancedAlgorithm, F, op.input, conv_problem), _ => true, } } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 907b5ef344..163e4796e4 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -211,29 +211,31 @@ fn compute_offset_and_mask_gradient( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elements_offset, cube_dim); - deform_col2img_coord_kernel::launch::( - &image.client, - cube_count, - cube_dim, - image.as_handle_ref().as_tensor_arg(1), - offset.as_handle_ref().as_tensor_arg(1), - mask.as_handle_ref().as_tensor_arg(1), - columns.as_handle_ref().as_tensor_arg(1), - grad_offset.as_handle_ref().as_tensor_arg(1), - grad_mask.as_handle_ref().as_tensor_arg(1), - DeformConv2dCol2ImgCoordArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - 
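
In the autotune changes above, `check_algo!` picks the CMMA compute and accumulator precision from the stored element type, promoting `f32` inputs to `tf32` only when the hardware reports support. A plain-Rust sketch of that selection table (the enum and string pairs are simplified stand-ins, not the real cubecl types):

```rust
#[derive(Clone, Copy)]
#[allow(dead_code)]
enum FloatKind {
    F32,
    F16,
    BF16,
    Other,
}

/// (compute type, accumulator type) chosen for the CMMA path.
fn pick_precision(elem: FloatKind, has_tf32: bool) -> (&'static str, &'static str) {
    match (elem, has_tf32) {
        (FloatKind::F32, true) => ("tf32", "f32"),
        (FloatKind::F16, _) => ("f16", "f16"),
        (FloatKind::BF16, _) => ("bf16", "f32"),
        // f32 without tf32 support (and anything else) falls back to f16 compute.
        _ => ("f16", "f32"),
    }
}

fn main() {
    assert_eq!(pick_precision(FloatKind::F32, true), ("tf32", "f32"));
    assert_eq!(pick_precision(FloatKind::F32, false), ("f16", "f32"));
    assert_eq!(pick_precision(FloatKind::BF16, true), ("bf16", "f32"));
}
```
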
ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(E::from_elem(options.padding[0] as f32)), - ScalarArg::new(E::from_elem(options.padding[1] as f32)), - ScalarArg::new(options.offset_groups as u32), - ScalarArg::new(kernel_height as u32), - ScalarArg::new(kernel_width as u32), - ), - use_mask, - ); + unsafe { + deform_col2img_coord_kernel::launch_unchecked::( + &image.client, + cube_count, + cube_dim, + image.as_handle_ref().as_tensor_arg(1), + offset.as_handle_ref().as_tensor_arg(1), + mask.as_handle_ref().as_tensor_arg(1), + columns.as_handle_ref().as_tensor_arg(1), + grad_offset.as_handle_ref().as_tensor_arg(1), + grad_mask.as_handle_ref().as_tensor_arg(1), + DeformConv2dCol2ImgCoordArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(E::from_elem(options.padding[0] as f32)), + ScalarArg::new(E::from_elem(options.padding[1] as f32)), + ScalarArg::new(options.offset_groups as u32), + ScalarArg::new(kernel_height as u32), + ScalarArg::new(kernel_width as u32), + ), + use_mask, + ) + }; let mask_gradient = if use_mask { Some(grad_mask) } else { None }; (grad_offset, mask_gradient) @@ -253,7 +255,7 @@ struct DeformConv2dCol2ImgCoordArgs { } #[allow(clippy::collapsible_if)] -#[cube(launch)] +#[cube(launch_unchecked)] fn deform_col2img_coord_kernel( image: &Tensor, offset: &Tensor, @@ -267,6 +269,10 @@ fn deform_col2img_coord_kernel( // Position format: [batch, [offset_group, kernel_h, kernel_w, 2], out_h, out_w] // Alternatively : [batch, offset_channels, out_h, out_w] + if ABSOLUTE_POS >= grad_offset.len() { + return; + } + let offset_channels = offset.shape(1); let out_h = offset.shape(2); let out_w = offset.shape(3); diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 562197647f..60c7cbbd1c 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,7 +1,12 @@ use super::{init_matmul_output, matmul_simple}; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; use burn_tensor::Shape; -use cubecl::prelude::*; +use cubecl::{ + ir::{Elem, FloatKind}, + linalg::matmul::Strategy, + prelude::*, + Feature, +}; #[cfg(feature = "autotune")] use super::matmul_autotune; @@ -42,16 +47,20 @@ pub fn matmul( match strategy { MatmulStrategy::Simple { grid_x, grid_y } => { let out = init_matmul_output::(&lhs, &rhs); + matmul_simple::(lhs, rhs, out, grid_x, grid_y) } MatmulStrategy::Cube => { let out = init_matmul_output::(&lhs, &rhs); + let client = &lhs.client; + cubecl::linalg::matmul::launch_ref::( + &cube_strategy::(client), client, - lhs.as_handle_ref(), - rhs.as_handle_ref(), - out.as_handle_ref(), + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), ); out } @@ -60,6 +69,26 @@ pub fn matmul( } } +pub(crate) fn cube_strategy( + client: &ComputeClient, +) -> Strategy { + // TODO: Replace with auto option once cubecl has one + let cmma_available = client.properties().feature_enabled(Feature::Cmma { + a: Elem::Float(FloatKind::F16), + b: Elem::Float(FloatKind::F16), + c: Elem::Float(FloatKind::F32), + m: 16, + k: 16, + n: 16, + }); + let plane_available = client.properties().feature_enabled(Feature::Plane); + match (cmma_available, plane_available) { + (true, _) => Strategy::Accelerated, + (false, true) => Strategy::PlaneMma, + _ => Strategy::Tiling2D(Default::default()), + } +} + pub(crate) fn 
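
Switching the deform-conv kernel to `launch_unchecked` means the cube count may overrun the output length, so the new `ABSOLUTE_POS >= grad_offset.len()` guard is what keeps writes in bounds. A plain-Rust stand-in for that early-return pattern (the kernel body here is purely illustrative):

```rust
fn kernel_body(absolute_pos: usize, grad_offset: &mut [f32], value: f32) {
    // With an unchecked launch, threads past the end of the output must bail out.
    if absolute_pos >= grad_offset.len() {
        return;
    }
    grad_offset[absolute_pos] = value;
}

fn main() {
    let mut grad = vec![0.0f32; 3];
    // Pretend the launch rounded the thread count up to 4; position 3 is a no-op.
    for pos in 0..4 {
        kernel_body(pos, &mut grad, 1.0);
    }
    assert_eq!(grad, vec![1.0, 1.0, 1.0]);
}
```
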
simple_cube_count( lhs_shape: &Shape, rhs_shape: &Shape, diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 38df9f7fe1..e49ea3154c 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -5,7 +5,10 @@ use cubecl::tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTu use crate::{ element::FloatElement, - kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, + kernel::{ + matmul::{cube_strategy, utils::init_matmul_output}, + prng::random_like_uniform, + }, ops::numeric::empty_device, tensor::JitTensor, tune_key::JitAutotuneKey, @@ -87,10 +90,10 @@ pub fn matmul_autotune( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { - let client = lhs.client.clone(); - let output = init_matmul_output::(&lhs, &rhs); + let client = lhs.client.clone(); + static TUNER: LocalTuner = local_tuner!(); TUNER.execute( @@ -149,11 +152,13 @@ matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { matmul_tune_ops!( MatmulCube, |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + let strategy = cube_strategy::(&lhs.client); cubecl::linalg::matmul::launch_ref::( + &strategy, &lhs.client, - lhs.as_handle_ref(), - rhs.as_handle_ref(), - out.as_handle_ref(), + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), ); } ); diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs index 4e783e74e9..4a32b5d641 100644 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs @@ -28,7 +28,7 @@ pub fn reduce_dim_subcube_kernel< let should_unroll = elems_per_thread <= 8; - let warp_id = UNIT_POS / PLANE_DIM; + let warp_id = plane_broadcast(UNIT_POS / PLANE_DIM, 0); let mut shared_memory = RD::init_shared(subcube_size); diff --git a/crates/burn-jit/src/ops/transaction.rs b/crates/burn-jit/src/ops/transaction.rs index 7320186570..b67740bc97 100644 --- a/crates/burn-jit/src/ops/transaction.rs +++ b/crates/burn-jit/src/ops/transaction.rs @@ -60,7 +60,7 @@ where let client = client.unwrap(); async move { - let mut data = client + let mut data: Vec> = client .read_async(bindings) .await .into_iter() diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index 9ff5f28247..54e50468fb 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -24,7 +24,7 @@ impl CubeTask for SourceKernel { let source = source_template.complete(); CompiledKernel { - entrypoint_name: "kernel".to_string(), + entrypoint_name: "main".to_string(), debug_name: Some(core::any::type_name::()), source, cube_dim: self.cube_dim, diff --git a/crates/burn-jit/src/tests/conv2d.rs b/crates/burn-jit/src/tests/conv2d.rs index f93adffe8f..061ab54e65 100644 --- a/crates/burn-jit/src/tests/conv2d.rs +++ b/crates/burn-jit/src/tests/conv2d.rs @@ -52,7 +52,78 @@ mod tests { output .into_data() - .assert_approx_eq(&output_ref.into_data(), 1); + .assert_approx_eq(&output_ref.into_data(), 2); + } + + /// Regression test for bias loader in new implicit GEMM + #[test] + fn conv2d_should_match_reference_backend_bias_regression() { + let test_device = Default::default(); + let input = + Tensor::::random([1, 1, 1, 1], Distribution::Default, &test_device); + let weight = + Tensor::::random([32, 1, 3, 3], Distribution::Default, &test_device); + let bias = Tensor::::random([32], Distribution::Default, &test_device); + let ref_device = 
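
The subcube reduction fix above routes the warp id through `plane_broadcast`, pinning every lane to lane 0's value so the index is explicitly uniform across the plane. A sketch of the effect with a toy `plane_broadcast` (assumed semantics: every lane receives the source lane's value; this is not the cubecl intrinsic):

```rust
fn plane_broadcast(per_lane_values: &[u32], src_lane: usize) -> u32 {
    // All lanes of the plane receive the value held by `src_lane`.
    per_lane_values[src_lane]
}

fn main() {
    let plane_dim = 32u32;
    // Each lane of warp 2 computes UNIT_POS / PLANE_DIM for its own UNIT_POS...
    let per_lane: Vec<u32> = (64u32..96).map(|unit_pos| unit_pos / plane_dim).collect();
    // ...and the broadcast makes the shared result explicitly uniform.
    assert!(per_lane.iter().all(|&warp_id| warp_id == 2));
    assert_eq!(plane_broadcast(&per_lane, 0), 2);
}
```
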
Default::default(); + + let input_ref = Tensor::::from_data(input.to_data(), &ref_device); + let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); + let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); + + let options = burn_tensor::ops::ConvOptions::new([1, 1], [1, 1], [1, 1], 1); + + let output = + module::conv2d(input, weight, Some(bias), options.clone()).permute([0, 2, 3, 1]); + let output_ref = + module::conv2d(input_ref, weight_ref, Some(bias_ref), options).permute([0, 2, 3, 1]); + + output + .into_data() + .assert_approx_eq(&output_ref.into_data(), 2); + } + + #[test] + fn nchw_to_nhwc_should_match_into_contiguous() { + let test_device = Default::default(); + let input = + Tensor::::random([4, 72, 53, 56], Distribution::Default, &test_device); + + type Float = ::FloatElem; + + let output = nchw_to_nhwc::(input.clone().into_primitive().tensor()); + let output_ref = into_contiguous( + input + .clone() + .permute([0, 2, 3, 1]) + .into_primitive() + .tensor(), + ); + + into_data_sync::(output) + .assert_approx_eq(&into_data_sync::(output_ref), 4); + } + + /// Regression test for transpose kernel that was causing corruption with 17-64 in channels and + /// at least 17 hw + #[test] + fn nchw_to_nhwc_should_match_into_contiguous_regression() { + let test_device = Default::default(); + let input = + Tensor::::random([1, 18, 17, 1], Distribution::Default, &test_device); + + type Float = ::FloatElem; + + let output = nchw_to_nhwc::(input.clone().into_primitive().tensor()); + let output_ref = into_contiguous( + input + .clone() + .permute([0, 2, 3, 1]) + .into_primitive() + .tensor(), + ); + + into_data_sync::(output) + .assert_approx_eq(&into_data_sync::(output_ref), 4); } #[test] diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index a69faf73de..64e8037c91 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -463,7 +463,7 @@ mod tests { scale: B::float_from_data(TensorData::from([0.009_019_608]), &device), offset: Some(B::int_from_data(TensorData::from([72]), &device)), }; - let qtensor: NdArrayQTensor = B::quantize(tensor.into(), &scheme, qparams); + let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams); assert_eq!(qtensor.scheme(), &scheme); assert_eq!( diff --git a/crates/burn-tensor/src/tests/module/conv3d.rs b/crates/burn-tensor/src/tests/module/conv3d.rs index b7a12b374c..77c827d928 100644 --- a/crates/burn-tensor/src/tests/module/conv3d.rs +++ b/crates/burn-tensor/src/tests/module/conv3d.rs @@ -293,7 +293,8 @@ mod tests { ), ); - y.to_data().assert_approx_eq(&output.into_data(), 3); + y.to_data() + .assert_approx_eq_diff(&output.into_data(), 0.002); } } } diff --git a/crates/burn-train/src/metric/base.rs b/crates/burn-train/src/metric/base.rs index e0eafe649d..db58c15886 100644 --- a/crates/burn-train/src/metric/base.rs +++ b/crates/burn-train/src/metric/base.rs @@ -19,6 +19,7 @@ pub struct MetricMetadata { } impl MetricMetadata { + /// Fake metric metadata #[cfg(test)] pub fn fake() -> Self { Self { diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index d6287382d9..583bfc1ddf 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -52,12 +52,12 @@ candle = ["burn-core/candle"] cuda-jit = ["burn-core/cuda-jit"] hip-jit = ["burn-core/hip-jit"] ndarray = ["burn-core/ndarray"] +remote = ["burn-core/remote"] +router = ["burn-core/router"] +server = ["burn-core/server"] tch = ["burn-core/tch"] wgpu = ["burn-core/wgpu"] wgpu-spirv = ["burn-core/wgpu-spirv"] 
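
The new layout tests above compare `nchw_to_nhwc` against `permute([0, 2, 3, 1])` followed by `into_contiguous`. A CPU reference of that layout conversion, useful for reasoning about the expected output (illustrative only; the real kernel operates on GPU tensors):

```rust
fn nchw_to_nhwc(input: &[f32], n: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
    let mut out = vec![0.0; input.len()];
    for b in 0..n {
        for ch in 0..c {
            for y in 0..h {
                for x in 0..w {
                    let src = ((b * c + ch) * h + y) * w + x; // NCHW index
                    let dst = ((b * h + y) * w + x) * c + ch; // NHWC index
                    out[dst] = input[src];
                }
            }
        }
    }
    out
}

fn main() {
    // 1x2x1x2 NCHW -> NHWC: channels become the innermost dimension.
    let input = [1.0, 2.0, 3.0, 4.0];
    assert_eq!(nchw_to_nhwc(&input, 1, 2, 1, 2), vec![1.0, 3.0, 2.0, 4.0]);
}
```
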
-remote = ["burn-core/remote"] -server = ["burn-core/server"] -router = ["burn-core/router"] # Network utils network = ["burn-core/network"]