diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs
index 3e0f6d574..008346f04 100644
--- a/candle-examples/examples/moondream/main.rs
+++ b/candle-examples/examples/moondream/main.rs
@@ -9,11 +9,19 @@ use clap::Parser;
 
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
-use candle_transformers::{generation::LogitsProcessor, models::moondream};
+use candle_transformers::{
+    generation::LogitsProcessor,
+    models::{moondream, quantized_moondream},
+};
 use tokenizers::Tokenizer;
 
+enum Model {
+    Moondream(moondream::Model),
+    Quantized(quantized_moondream::Model),
+}
+
 struct TextGeneration {
-    model: moondream::Model,
+    model: Model,
     device: Device,
     tokenizer: Tokenizer,
     logits_processor: LogitsProcessor,
@@ -25,7 +33,7 @@ struct TextGeneration {
 impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
     fn new(
-        model: moondream::Model,
+        model: Model,
         tokenizer: Tokenizer,
         seed: u64,
         temp: Option<f64>,
@@ -64,6 +72,14 @@ impl TextGeneration {
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
 
+        // Moondream tokenizer bos_token is "<|endoftext|>"
+        // https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
+        let bos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+            Some(token) => *token,
+            None => anyhow::bail!("cannot find the BOS token"),
+        };
+        // eos_token is "END"
+        // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L100
         let eos_token = match self.tokenizer.get_vocab(true).get("END") {
             Some(token) => *token,
             None => anyhow::bail!("cannot find the EOS token"),
@@ -75,11 +91,24 @@ impl TextGeneration {
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
             let logits = if index > 0 {
-                self.model.text_model.forward(&input)?
+                match self.model {
+                    Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
+                    Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
+                }
             } else {
-                self.model
-                    .text_model
-                    .forward_with_img(&input, image_embeds)?
+                let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
+                match self.model {
+                    Model::Moondream(ref mut model) => {
+                        model
+                            .text_model
+                            .forward_with_img(&bos_token, &input, image_embeds)?
+                    }
+                    Model::Quantized(ref mut model) => {
+                        model
+                            .text_model
+                            .forward_with_img(&bos_token, &input, image_embeds)?
+                    }
+                }
             };
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
@@ -142,7 +171,7 @@ struct Args {
     top_p: Option<f64>,
 
     /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 299792458)]
+    #[arg(long, default_value_t = 0)]
     seed: u64,
 
     #[arg(long, default_value_t = 5000)]
@@ -156,12 +185,15 @@ struct Args {
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
 
-    #[arg(long, default_value = "vikhyatk/moondream2")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
 
     #[arg(long, default_value = "main")]
     revision: String,
 
+    #[arg(long)]
+    quantized: bool,
+
     #[arg(long)]
     model_file: Option<String>,
 
@@ -216,14 +248,30 @@ async fn main() -> anyhow::Result<()> {
 
     let start = std::time::Instant::now();
     let api = hf_hub::api::tokio::Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => {
+            if args.quantized {
+                "santiagomed/candle-moondream".to_string()
+            } else {
+                "vikhyatk/moondream2".to_string()
+            }
+        }
+    };
     let repo = api.repo(hf_hub::Repo::with_revision(
-        args.model_id,
+        model_id,
         hf_hub::RepoType::Model,
        args.revision,
     ));
     let model_file = match args.model_file {
         Some(m) => m.into(),
-        None => repo.get("model.safetensors").await?,
+        None => {
+            if args.quantized {
+                repo.get("model-q4_0.gguf").await?
+            } else {
+                repo.get("model.safetensors").await?
+            }
+        }
     };
     let tokenizer = match args.tokenizer_file {
         Some(m) => m.into(),
@@ -234,22 +282,35 @@ async fn main() -> anyhow::Result<()> {
 
     let start = std::time::Instant::now();
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let config = moondream::Config::v2();
-    let model = moondream::Model::new(&config, vb)?;
+    let model = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &model_file,
+            &device,
+        )?;
+        let model = quantized_moondream::Model::new(&config, vb)?;
+        Model::Quantized(model)
+    } else {
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+        let model = moondream::Model::new(&config, vb)?;
+        Model::Moondream(model)
+    };
     println!("loaded the model in {:?}", start.elapsed());
 
     let start = std::time::Instant::now();
     let image = load_image(args.image)?.to_device(&device)?;
     let image_embeds = image.unsqueeze(0)?;
-    let image_embeds = image_embeds.apply(model.vision_encoder())?;
+    let image_embeds = match model {
+        Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
+        Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
+    };
     println!(
         "loaded and encoded the image {image:?} in {:?}",
         start.elapsed()
     );
 
     let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
-
     let mut pipeline = TextGeneration::new(
         model,
         tokenizer,
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index edca8b9d7..65a1665a6 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -438,16 +438,20 @@ impl MixFormerSequentialForCausalLM {
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
 
-    pub fn forward_with_img(&mut self, xs: &Tensor, img_embeds: &Tensor) -> Result<Tensor> {
+    pub fn forward_with_img(
+        &mut self,
+        bos_token: &Tensor,
+        xs: &Tensor,
+        img_embeds: &Tensor,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let xs = xs.apply(&self.embedding)?;
-        let mut xs = Tensor::cat(&[img_embeds.clone(), xs], 1)?;
+        let bos_token = bos_token.apply(&self.embedding)?;
+        // Python implementation sequence order is
+        // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
+        let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
         let (_b_size, seq_len, _embds) = xs.dims3()?;
-        let mask = if seq_len <= 1 {
-            None
-        } else {
-            Some(get_mask(seq_len, xs.device())?)
-        };
+        let mask = Some(get_mask(seq_len, xs.device())?);
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
         }
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index ed0e0de71..3514e648c 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -35,6 +35,7 @@ pub mod quantized_llama2_c;
 pub mod quantized_metavoice;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
+pub mod quantized_moondream;
 pub mod quantized_mpt;
 pub mod quantized_rwkv_v5;
 pub mod quantized_rwkv_v6;
diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs
index c36052c67..42b24fb82 100644
--- a/candle-transformers/src/models/moondream.rs
+++ b/candle-transformers/src/models/moondream.rs
@@ -25,15 +25,15 @@ fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
 
 #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct VisionConfig {
-    image_embedding_dim: usize,
-    model_dim: usize,
-    hidden_dim: usize,
-    hidden_features: usize,
-    embed_len: usize,
-    embed_dim: usize,
-    num_blocks: usize,
-    num_heads: usize,
-    act: candle_nn::Activation,
+    pub(crate) image_embedding_dim: usize,
+    pub(crate) model_dim: usize,
+    pub(crate) hidden_dim: usize,
+    pub(crate) hidden_features: usize,
+    pub(crate) embed_len: usize,
+    pub(crate) embed_dim: usize,
+    pub(crate) num_blocks: usize,
+    pub(crate) num_heads: usize,
+    pub(crate) act: candle_nn::Activation,
 }
 
 impl VisionConfig {
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -348,6 +348,30 @@ impl MixFormerSequentialForCausalLM {
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
 
+    pub fn forward_with_img(
+        &mut self,
+        bos_token: &Tensor,
+        xs: &Tensor,
+        img_embeds: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let xs = xs.apply(&self.embedding)?;
+        let bos_token = bos_token.apply(&self.embedding)?;
+        // Python implementation sequence order is
+        // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
+        let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
+        let (_b_size, seq_len, _embds) = xs.dims3()?;
+        let mask = Some(get_mask(seq_len, xs.device())?);
+        for block in self.blocks.iter_mut() {
+            xs = block.forward(&xs, mask.as_ref())?
+        }
+        let xs = xs
+            .narrow(1, seq_len - 1, 1)?
+            .apply(&self.head)?
+            .squeeze(1)?;
+        Ok(xs)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
     }
diff --git a/candle-transformers/src/models/quantized_moondream.rs b/candle-transformers/src/models/quantized_moondream.rs
new file mode 100644
index 000000000..1b125d930
--- /dev/null
+++ b/candle-transformers/src/models/quantized_moondream.rs
@@ -0,0 +1,271 @@
+use crate::models::moondream::{Config, VisionConfig};
+use crate::models::quantized_mixformer::MixFormerSequentialForCausalLM as PhiModel;
+use crate::quantized_nn::{layer_norm, linear_b, Linear};
+use crate::quantized_var_builder::VarBuilder;
+use candle::{IndexOp, Module, Result, Tensor, D};
+
+fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+    let dim = q.dim(D::Minus1)?;
+    let scale_factor = 1.0 / (dim as f64).sqrt();
+    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
+    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
+}
+
+#[derive(Debug, Clone)]
+struct LinearPatchEmbedding {
+    linear: Linear,
+}
+
+impl LinearPatchEmbedding {
+    fn new(vb: VarBuilder) -> Result<Self> {
+        let linear = linear_b(588, 1152, true, vb.pp("linear"))?;
+        Ok(Self { linear })
+    }
+}
+
+impl Module for LinearPatchEmbedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.linear)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+    num_heads: usize,
+    head_dim: usize,
+    qkv: Linear,
+    proj: Linear,
+}
+
+impl Attention {
+    pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
+        let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?;
+        let proj = linear_b(dim, dim, true, vb.pp("proj"))?;
+        Ok(Self {
+            num_heads,
+            head_dim: dim / num_heads,
+            qkv,
+            proj,
+        })
+    }
+}
+
+impl Module for Attention {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b, n, c) = xs.dims3()?;
+        let qkv = xs
+            .apply(&self.qkv)?
+            .reshape((b, n, 3, self.num_heads, self.head_dim))?
+            .permute((2, 0, 3, 1, 4))?;
+        let (q, k, v) = (
+            qkv.i(0)?.contiguous()?,
+            qkv.i(1)?.contiguous()?,
+            qkv.i(2)?.contiguous()?,
+        );
+        scaled_dot_product_attention(&q, &k, &v)?
+            .transpose(1, 2)?
+            .reshape((b, n, c))?
+            .apply(&self.proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VitBlock {
+    attn: Attention,
+    mlp: Mlp,
+    norm1: candle_nn::LayerNorm,
+    norm2: candle_nn::LayerNorm,
+}
+
+impl VitBlock {
+    fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result<Self> {
+        let attn = Attention::new(vb.pp("attn"), dim, num_heads)?;
+        let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?;
+        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        Ok(Self {
+            attn,
+            mlp,
+            norm1,
+            norm2,
+        })
+    }
+}
+
+impl Module for VitBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
+        let xs = (xs + &ys)?;
+        let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
+        let xs = (&xs + &ys)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionTransformer {
+    patch_embed: LinearPatchEmbedding,
+    pos_embed: Tensor,
+    blocks: Vec<VitBlock>,
+    norm: candle_nn::LayerNorm,
+}
+
+impl VisionTransformer {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?;
+        let pos_embed = vb
+            .get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?
+            .dequantize(vb.device())?;
+        let blocks = (0..cfg.num_blocks)
+            .map(|i| {
+                VitBlock::new(
+                    vb.pp(format!("blocks.{}", i)),
+                    cfg.embed_dim,
+                    cfg.num_heads,
+                    cfg,
+                )
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?;
+        Ok(Self {
+            patch_embed,
+            pos_embed,
+            blocks,
+            norm,
+        })
+    }
+}
+
+impl Module for VisionTransformer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
+        for block in self.blocks.iter() {
+            xs = xs.apply(block)?;
+        }
+        xs.apply(&self.norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Encoder {
+    model: VisionTransformer,
+}
+
+impl Encoder {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?;
+        Ok(Self { model })
+    }
+}
+
+impl Module for Encoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.model)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+    fc1: Linear,
+    act: candle_nn::Activation,
+    fc2: Linear,
+}
+
+impl Mlp {
+    fn new(
+        vb: VarBuilder,
+        in_features: usize,
+        hidden_features: usize,
+        out_features: usize,
+        act: candle_nn::Activation,
+    ) -> Result<Self> {
+        let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
+        let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
+        Ok(Self { fc1, act, fc2 })
+    }
+}
+
+impl Module for Mlp {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionProjection {
+    mlp: Mlp,
+}
+
+impl VisionProjection {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let mlp = Mlp::new(
+            vb.pp("mlp"),
+            cfg.image_embedding_dim,
+            cfg.hidden_dim,
+            cfg.model_dim,
+            cfg.act,
+        )?;
+        Ok(Self { mlp })
+    }
+}
+
+impl Module for VisionProjection {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.mlp)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct VisionEncoder {
+    encoder: Encoder,
+    projection: VisionProjection,
+}
+
+impl VisionEncoder {
+    pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
+        Ok(Self {
+            encoder,
+            projection,
+        })
+    }
+}
+
+impl Module for VisionEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b, c, hp1, wp2) = xs.dims4()?;
+        let (p1, p2) = (14, 14);
+        let h = hp1 / p1;
+        let w = wp2 / p2;
+        xs.reshape((b, c, h, p1, w, p2))?
+            .permute((0, 2, 4, 1, 3, 5))?
+            .reshape((b, h * w, c * p1 * p2))?
+            .apply(&self.encoder)?
+            .apply(&self.projection)
+    }
+}
+
+pub struct Model {
+    pub text_model: PhiModel,
+    pub vision_encoder: VisionEncoder,
+}
+
+impl Model {
+    pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
+        let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?;
+        let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?;
+        Ok(Self {
+            text_model,
+            vision_encoder,
+        })
+    }
+
+    pub fn vision_encoder(&self) -> &VisionEncoder {
+        &self.vision_encoder
+    }
+
+    pub fn text_model(&mut self) -> &mut PhiModel {
+        &mut self.text_model
+    }
+}