From 5270224f407502b82fe90bc2622894ce3871b002 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Tue, 16 Jan 2024 07:34:16 +0200
Subject: [PATCH] Add MobileOne model. (#1595)

* Add MobileOne model.

* Clippy fixes

* Remove a comment.

---------

Co-authored-by: laurent
---
 candle-examples/examples/mobileone/README.md |  22 ++
 candle-examples/examples/mobileone/main.rs   |  96 ++++++
 candle-transformers/src/models/mobileone.rs  | 333 +++++++++++++++++++
 candle-transformers/src/models/mod.rs        |   1 +
 4 files changed, 452 insertions(+)
 create mode 100644 candle-examples/examples/mobileone/README.md
 create mode 100644 candle-examples/examples/mobileone/main.rs
 create mode 100644 candle-transformers/src/models/mobileone.rs

diff --git a/candle-examples/examples/mobileone/README.md b/candle-examples/examples/mobileone/README.md
new file mode 100644
index 0000000000..b5e88b6f28
--- /dev/null
+++ b/candle-examples/examples/mobileone/README.md
@@ -0,0 +1,22 @@
+# candle-mobileone
+
+[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
+
+This candle implementation uses a pre-trained MobileOne network for inference. The
+classification head has been trained on the ImageNet dataset and returns the
+probabilities for the top-5 classes.
+
+## Running an example
+
+```
+$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
+
+loaded image Tensor[dims 3, 224, 224; f32]
+model built
+mountain bike, all-terrain bike, off-roader: 79.33%
+bicycle-built-for-two, tandem bicycle, tandem: 15.32%
+crash helmet            : 2.58%
+unicycle, monocycle     : 1.70%
+alp                     : 0.21%
+
+```
diff --git a/candle-examples/examples/mobileone/main.rs b/candle-examples/examples/mobileone/main.rs
new file mode 100644
index 0000000000..4cd55001d3
--- /dev/null
+++ b/candle-examples/examples/mobileone/main.rs
@@ -0,0 +1,96 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::mobileone;
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    S0,
+    S1,
+    S2,
+    S3,
+    S4,
+}
+
+impl Which {
+    fn model_filename(&self) -> String {
+        let name = match self {
+            Self::S0 => "s0",
+            Self::S1 => "s1",
+            Self::S2 => "s2",
+            Self::S3 => "s3",
+            Self::S4 => "s4",
+        };
+        format!("timm/mobileone_{}.apple_in1k", name)
+    }
+
+    fn config(&self) -> mobileone::Config {
+        match self {
+            Self::S0 => mobileone::Config::s0(),
+            Self::S1 => mobileone::Config::s1(),
+            Self::S2 => mobileone::Config::s2(),
+            Self::S3 => mobileone::Config::s3(),
+            Self::S4 => mobileone::Config::s4(),
+        }
+    }
+}
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(value_enum, long, default_value_t=Which::S0)]
+    which: Which,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let image = candle_examples::imagenet::load_image224(args.image)?;
+    println!("loaded image {image:?}");
+
+    let model_file = match args.model {
+        None => {
+            let model_name = args.which.model_filename();
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model(model_name);
+            api.get("model.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+    let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
+    println!("model built");
+    let logits = model.forward(&image.unsqueeze(0)?)?;
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for &(category_idx, pr) in prs.iter().take(5) {
+        println!(
+            "{:24}: {:.2}%",
+            candle_examples::imagenet::CLASSES[category_idx],
+            100. * pr
+        );
+    }
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mobileone.rs b/candle-transformers/src/models/mobileone.rs
new file mode 100644
index 0000000000..674da40b97
--- /dev/null
+++ b/candle-transformers/src/models/mobileone.rs
@@ -0,0 +1,333 @@
+//! MobileOne inference implementation based on timm and candle-repvgg
+//!
+//! See "MobileOne: An Improved One millisecond Mobile Backbone"
+//! https://arxiv.org/abs/2206.04040
+
+use candle::{DType, Result, Tensor, D};
+use candle_nn::{
+    batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
+    Func, VarBuilder,
+};
+
+struct StageConfig {
+    blocks: usize,
+    channels: usize,
+}
+
+// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
+// by concatenating the 5th stage (starts with stride 1) to the previous one.
+const STAGES: [StageConfig; 5] = [
+    StageConfig {
+        blocks: 1,
+        channels: 64,
+    },
+    StageConfig {
+        blocks: 2,
+        channels: 64,
+    },
+    StageConfig {
+        blocks: 8,
+        channels: 128,
+    },
+    StageConfig {
+        blocks: 10,
+        channels: 256,
+    },
+    StageConfig {
+        blocks: 1,
+        channels: 512,
+    },
+];
+
+#[derive(Clone)]
+pub struct Config {
+    /// overparameterization factor
+    k: usize,
+    /// per-stage channel number multipliers
+    alphas: [f32; 5],
+}
+
+impl Config {
+    pub fn s0() -> Self {
+        Self {
+            k: 4,
+            alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
+        }
+    }
+    pub fn s1() -> Self {
+        Self {
+            k: 1,
+            alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
+        }
+    }
+    pub fn s2() -> Self {
+        Self {
+            k: 1,
+            alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
+        }
+    }
+    pub fn s3() -> Self {
+        Self {
+            k: 1,
+            alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
+        }
+    }
+    pub fn s4() -> Self {
+        Self {
+            k: 1,
+            alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
+        }
+    }
+}
+
+// SE blocks are used in the last stages of the s4 variant.
+fn squeeze_and_excitation(
+    in_channels: usize,
+    squeeze_channels: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        ..Default::default()
+    };
+    let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
+    let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
+
+    Ok(Func::new(move |xs| {
+        let residual = xs;
+        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
+        let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
+
+        residual.broadcast_mul(&xs)
+    }))
+}
+
+// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
+// based on the _fuse_bn_tensor method in timm
+// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
+fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
+    let (gamma, beta) = bn.weight_and_bias().unwrap();
+    let mu = bn.running_mean();
+    let sigma = (bn.running_var() + bn.eps())?.sqrt();
+    let gps = (gamma / sigma)?;
+    let bias = (beta - mu * &gps)?;
+    let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
+
+    Ok((weights, bias))
+}
+
+// A mobileone block has a different training time and inference time architecture.
+// The latter is a simple and efficient equivalent transformation of the former
+// realized by a structural reparameterization technique, where convolutions
+// along with identity branches and batchnorm layers are fused into a single convolution.
+#[allow(clippy::too_many_arguments)]
+fn mobileone_block(
+    has_identity: bool,
+    k: usize,
+    dim: usize,
+    stride: usize,
+    padding: usize,
+    groups: usize,
+    kernel: usize,
+    in_channels: usize,
+    out_channels: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        stride,
+        padding,
+        groups,
+        ..Default::default()
+    };
+
+    let mut w = Tensor::zeros(
+        (out_channels, in_channels / groups, kernel, kernel),
+        DType::F32,
+        vb.device(),
+    )?;
+    let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
+
+    // k is the training-time overparameterization factor, larger than 1 only in the s0 variant
+    for i in 0..k {
+        let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
+        let conv_kxk = conv2d_no_bias(
+            in_channels,
+            out_channels,
+            kernel,
+            conv2d_cfg,
+            vb.pp(format!("conv_kxk.{i}.conv")),
+        )?;
+        let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
+        w = (w + wk)?;
+        b = (b + bk)?;
+    }
+
+    if kernel > 1 {
+        let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
+        let conv_scale = conv2d_no_bias(
+            in_channels,
+            out_channels,
+            1,
+            conv2d_cfg,
+            vb.pp("conv_scale.conv"),
+        )?;
+
+        let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
+        // resize to 3x3
+        ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
+        ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
+
+        w = (w + ws)?;
+        b = (b + bs)?;
+    }
+
+    // Use SE blocks if present (last layers of the s4 variant)
+    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
+
+    // read and reparameterize the identity bn into wi and bi
+    if has_identity {
+        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
+
+        let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
+
+        let id = in_channels / groups;
+        // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
+        for i in 0..in_channels {
+            if kernel > 1 {
+                weights[i * kernel * kernel + 4] = 1.0;
+            } else {
+                weights[i * (id + 1)] = 1.0;
+            }
+        }
+
+        let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
+        let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
+
+        w = (w + wi)?;
+        b = (b + bi)?;
+    }
+
+    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
+
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.apply(&reparam_conv)?;
+        if let Ok(f) = &se {
+            xs = xs.apply(f)?;
+        }
+        xs = xs.relu()?;
+        Ok(xs)
+    }))
+}
+
+// Get the number of output channels per stage taking into account the multipliers
+fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
+    let channels = STAGES[stage].channels as f32;
+    let alpha = cfg.alphas[stage];
+
+    match stage {
+        0 => std::cmp::min(64, (channels * alpha) as usize),
+        _ => (channels * alpha) as usize,
+    }
+}
+
+// Each stage is made of blocks. The first layer always downsamples with stride 2.
+// All but the first block have a residual connection.
+fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let nblocks = STAGES[idx].blocks;
+    let mut blocks = Vec::with_capacity(nblocks);
+
+    let mut in_channels = output_channels_per_stage(cfg, idx - 1);
+
+    for block_idx in 0..nblocks {
+        let out_channels = output_channels_per_stage(cfg, idx);
+        let (has_identity, stride) = if block_idx == 0 {
+            (false, 2)
+        } else {
+            (true, 1)
+        };
+
+        // depthwise convolution layer
+        blocks.push(mobileone_block(
+            has_identity,
+            cfg.k,
+            in_channels,
+            stride,
+            1,
+            in_channels,
+            3,
+            in_channels,
+            in_channels,
+            vb.pp(block_idx * 2),
+        )?);
+
+        // pointwise convolution layer
+        blocks.push(mobileone_block(
+            has_identity,
+            cfg.k,
+            out_channels,
+            1, // stride
+            0, // padding
+            1, // groups
+            1, // kernel
+            in_channels,
+            out_channels,
+            vb.pp(block_idx * 2 + 1),
+        )?);
+
+        in_channels = out_channels;
+    }
+
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.clone();
+        for block in blocks.iter() {
+            xs = xs.apply(block)?
+        }
+        Ok(xs)
+    }))
+}
+
+// Build a mobileone model for a given configuration.
+fn mobileone_model(
+    config: &Config,
+    nclasses: Option<usize>,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let cls = match nclasses {
+        None => None,
+        Some(nclasses) => {
+            let outputs = output_channels_per_stage(config, 4);
+            let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
+            Some(linear)
+        }
+    };
+
+    let stem_dim = output_channels_per_stage(config, 0);
+    let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
+    let vb = vb.pp("stages");
+    let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
+    let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
+    let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
+    let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
+
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .apply(&stem)?
+            .apply(&stage1)?
+            .apply(&stage2)?
+            .apply(&stage3)?
+            .apply(&stage4)?
+            .mean(D::Minus2)?
+            .mean(D::Minus1)?;
+        match &cls {
+            None => Ok(xs),
+            Some(cls) => xs.apply(cls),
+        }
+    }))
+}
+
+pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    mobileone_model(cfg, Some(nclasses), vb)
+}
+
+pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
+    mobileone_model(cfg, None, vb)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 9af6df69af..a94fd07a06 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -15,6 +15,7 @@ pub mod marian;
 pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
+pub mod mobileone;
 pub mod mpt;
 pub mod persimmon;
 pub mod phi;
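
For reference, the reparameterization performed by `fuse_conv_bn` above folds a batch-norm layer with affine parameters gamma and beta, running statistics mu and var, and epsilon eps into the preceding convolution: every kernel weight of an output channel is scaled by gamma / sqrt(var + eps), and the fused bias is beta - mu * gamma / sqrt(var + eps). A minimal standalone sketch of that per-channel arithmetic (plain Rust, illustrative names only, independent of the candle tensor API):

```rust
/// Fold batch-norm statistics into a preceding convolution.
/// Returns the factor to multiply that output channel's kernel weights by,
/// and the bias that channel gets in the fused convolution.
fn fuse_bn_channel(gamma: f32, beta: f32, mean: f32, var: f32, eps: f32) -> (f32, f32) {
    let scale = gamma / (var + eps).sqrt(); // gamma / sigma
    let bias = beta - mean * scale;         // beta - mu * (gamma / sigma)
    (scale, bias)
}

fn main() {
    // Hypothetical per-channel statistics, for illustration only.
    let (scale, bias) = fuse_bn_channel(1.2, 0.1, 0.5, 0.04, 1e-5);
    println!("kernel scale = {scale:.4}, fused bias = {bias:.4}");
}
```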