From c58c5d5b01b1457997ac68b3a873b64ca98afcb6 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 20 Sep 2024 14:31:20 -0600 Subject: [PATCH] Add the mimi audio-tokenizer. (#2488) * Add the mimi audio-tokenizer. * Formatting tweaks. * Add a full example. * Use the transformers names. * More renamings. * Get encoding and decoding to work. * Clippy fixes. --- .gitignore | 3 + candle-examples/Cargo.toml | 5 + candle-examples/examples/mimi/README.md | 20 + candle-examples/examples/mimi/audio_io.rs | 275 ++++++ candle-examples/examples/mimi/main.rs | 131 +++ candle-transformers/src/models/mimi/conv.rs | 670 +++++++++++++++ .../src/models/mimi/encodec.rs | 229 +++++ candle-transformers/src/models/mimi/mod.rs | 22 + .../src/models/mimi/quantization.rs | 404 +++++++++ candle-transformers/src/models/mimi/seanet.rs | 465 ++++++++++ .../src/models/mimi/transformer.rs | 802 ++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 12 files changed, 3027 insertions(+) create mode 100644 candle-examples/examples/mimi/README.md create mode 100644 candle-examples/examples/mimi/audio_io.rs create mode 100644 candle-examples/examples/mimi/main.rs create mode 100644 candle-transformers/src/models/mimi/conv.rs create mode 100644 candle-transformers/src/models/mimi/encodec.rs create mode 100644 candle-transformers/src/models/mimi/mod.rs create mode 100644 candle-transformers/src/models/mimi/quantization.rs create mode 100644 candle-transformers/src/models/mimi/seanet.rs create mode 100644 candle-transformers/src/models/mimi/transformer.rs diff --git a/.gitignore b/.gitignore index 38a7d50471..4dfbcc1663 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,6 @@ candle-wasm-examples/**/config*.json __pycache__ out.safetensors out.wav +bria.mp3 +bria.safetensors +bria.wav diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 6879c48b28..543c96667a 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -67,6 +67,7 @@ onnx = 
["candle-onnx"] metal = ["candle/metal", "candle-nn/metal"] microphone = ["cpal"] encodec = ["cpal", "symphonia", "rubato"] +mimi = ["cpal", "symphonia", "rubato"] depth_anything_v2 = ["palette", "enterpolation"] [[example]] @@ -101,6 +102,10 @@ required-features = ["candle-datasets"] name = "llama2-c" required-features = ["candle-datasets"] +[[example]] +name = "mimi" +required-features = ["mimi"] + [[example]] name = "encodec" required-features = ["encodec"] diff --git a/candle-examples/examples/mimi/README.md b/candle-examples/examples/mimi/README.md new file mode 100644 index 0000000000..bbcfcdb710 --- /dev/null +++ b/candle-examples/examples/mimi/README.md @@ -0,0 +1,20 @@ +# candle-mimi + +[Mimi](https://huggingface.co/kyutai/mimi) is a state-of-the-art audio +compression model using an encoder/decoder architecture with residual vector +quantization. The candle implementation supports streaming, meaning that it's +possible to encode or decode a stream of audio tokens on the fly to provide +low-latency interaction with an audio model. + +## Running one example + +Generating some audio tokens from an audio file. +```bash +wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3 +cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors +``` + +And decoding the audio tokens back into a sound file. 
+```bash +cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav +``` diff --git a/candle-examples/examples/mimi/audio_io.rs b/candle-examples/examples/mimi/audio_io.rs new file mode 100644 index 0000000000..2103dd4adf --- /dev/null +++ b/candle-examples/examples/mimi/audio_io.rs @@ -0,0 +1,275 @@ +#![allow(unused)] +use anyhow::{Context, Result}; +use std::sync::{Arc, Mutex}; + +pub const SAMPLE_RATE: usize = 24_000; + +pub(crate) struct AudioOutputData_ { + resampled_data: std::collections::VecDeque, + resampler: rubato::FastFixedIn, + output_buffer: Vec, + input_buffer: Vec, + input_len: usize, +} + +impl AudioOutputData_ { + pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result { + use rubato::Resampler; + + let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10); + let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64; + let resampler = rubato::FastFixedIn::new( + resample_ratio, + f64::max(resample_ratio, 1.0), + rubato::PolynomialDegree::Septic, + 1024, + 1, + )?; + let input_buffer = resampler.input_buffer_allocate(true).remove(0); + let output_buffer = resampler.output_buffer_allocate(true).remove(0); + Ok(Self { + resampled_data, + resampler, + input_buffer, + output_buffer, + input_len: 0, + }) + } + + pub fn reset(&mut self) { + use rubato::Resampler; + self.output_buffer.fill(0.); + self.input_buffer.fill(0.); + self.resampler.reset(); + self.resampled_data.clear(); + } + + pub(crate) fn take_all(&mut self) -> Vec { + let mut data = Vec::with_capacity(self.resampled_data.len()); + while let Some(elem) = self.resampled_data.pop_back() { + data.push(elem); + } + data + } + + pub(crate) fn is_empty(&self) -> bool { + self.resampled_data.is_empty() + } + + // Assumes that the input buffer is large enough. 
+ fn push_input_buffer(&mut self, samples: &[f32]) { + self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples); + self.input_len += samples.len() + } + + pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> { + use rubato::Resampler; + + let mut pos_in = 0; + loop { + let rem = self.input_buffer.len() - self.input_len; + let pos_end = usize::min(pos_in + rem, samples.len()); + self.push_input_buffer(&samples[pos_in..pos_end]); + pos_in = pos_end; + if self.input_len < self.input_buffer.len() { + break; + } + let (_, out_len) = self.resampler.process_into_buffer( + &[&self.input_buffer], + &mut [&mut self.output_buffer], + None, + )?; + for &elem in self.output_buffer[..out_len].iter() { + self.resampled_data.push_front(elem) + } + self.input_len = 0; + } + Ok(()) + } +} + +type AudioOutputData = Arc>; + +pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> { + use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; + + println!("Setup audio output stream!"); + let host = cpal::default_host(); + let device = host + .default_output_device() + .context("no output device available")?; + let mut supported_configs_range = device.supported_output_configs()?; + let config_range = match supported_configs_range.find(|c| c.channels() == 1) { + // On macOS, it's commonly the case that there are only stereo outputs. + None => device + .supported_output_configs()? 
+ .next() + .context("no audio output available")?, + Some(config_range) => config_range, + }; + let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp( + config_range.min_sample_rate(), + config_range.max_sample_rate(), + ); + let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into(); + let channels = config.channels as usize; + println!( + "cpal device: {} {} {config:?}", + device.name().unwrap_or_else(|_| "unk".to_string()), + config.sample_rate.0 + ); + let audio_data = Arc::new(Mutex::new(AudioOutputData_::new( + SAMPLE_RATE, + config.sample_rate.0 as usize, + )?)); + let ad = audio_data.clone(); + let stream = device.build_output_stream( + &config, + move |data: &mut [f32], _: &cpal::OutputCallbackInfo| { + data.fill(0.); + let mut ad = ad.lock().unwrap(); + let mut last_elem = 0f32; + for (idx, elem) in data.iter_mut().enumerate() { + if idx % channels == 0 { + match ad.resampled_data.pop_back() { + None => break, + Some(v) => { + last_elem = v; + *elem = v + } + } + } else { + *elem = last_elem + } + } + }, + move |err| eprintln!("cpal error: {err}"), + None, // None=blocking, Some(Duration)=timeout + )?; + stream.play()?; + Ok((stream, audio_data)) +} + +pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> { + use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; + + println!("Setup audio input stream!"); + let host = cpal::default_host(); + let device = host + .default_input_device() + .context("no input device available")?; + let mut supported_configs_range = device.supported_input_configs()?; + let config_range = supported_configs_range + .find(|c| c.channels() == 1) + .context("no audio input available")?; + let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp( + config_range.min_sample_rate(), + config_range.max_sample_rate(), + ); + let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into(); + println!( + "cpal device: {} {} {config:?}", + 
device.name().unwrap_or_else(|_| "unk".to_string()), + config.sample_rate.0 + ); + let audio_data = Arc::new(Mutex::new(AudioOutputData_::new( + config.sample_rate.0 as usize, + SAMPLE_RATE, + )?)); + let ad = audio_data.clone(); + let stream = device.build_input_stream( + &config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + let mut ad = ad.lock().unwrap(); + if let Err(err) = ad.push_samples(data) { + eprintln!("error processing audio input {err:?}") + } + }, + move |err| eprintln!("cpal error: {err}"), + None, // None=blocking, Some(Duration)=timeout + )?; + stream.play()?; + Ok((stream, audio_data)) +} + +fn conv(samples: &mut Vec, data: std::borrow::Cow>) +where + T: symphonia::core::sample::Sample, + f32: symphonia::core::conv::FromSample, +{ + use symphonia::core::audio::Signal; + use symphonia::core::conv::FromSample; + samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) +} + +pub(crate) fn pcm_decode>(path: P) -> Result<(Vec, u32)> { + use symphonia::core::audio::{AudioBufferRef, Signal}; + + let src = std::fs::File::open(path)?; + let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); + let hint = symphonia::core::probe::Hint::new(); + let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); + let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); + let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?; + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) + .expect("no supported audio tracks"); + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &Default::default()) + .expect("unsupported codec"); + let track_id = track.id; + let sample_rate = track.codec_params.sample_rate.unwrap_or(0); + let mut pcm_data = Vec::new(); + while let Ok(packet) = format.next_packet() { + while 
!format.metadata().is_latest() { + format.metadata().pop(); + } + if packet.track_id() != track_id { + continue; + } + match decoder.decode(&packet)? { + AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), + AudioBufferRef::U8(data) => conv(&mut pcm_data, data), + AudioBufferRef::U16(data) => conv(&mut pcm_data, data), + AudioBufferRef::U24(data) => conv(&mut pcm_data, data), + AudioBufferRef::U32(data) => conv(&mut pcm_data, data), + AudioBufferRef::S8(data) => conv(&mut pcm_data, data), + AudioBufferRef::S16(data) => conv(&mut pcm_data, data), + AudioBufferRef::S24(data) => conv(&mut pcm_data, data), + AudioBufferRef::S32(data) => conv(&mut pcm_data, data), + AudioBufferRef::F64(data) => conv(&mut pcm_data, data), + } + } + Ok((pcm_data, sample_rate)) +} + +pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result> { + use rubato::Resampler; + + let mut pcm_out = + Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024); + + let mut resampler = rubato::FftFixedInOut::::new(sr_in, sr_out, 1024, 1)?; + let mut output_buffer = resampler.output_buffer_allocate(true); + let mut pos_in = 0; + while pos_in + resampler.input_frames_next() < pcm_in.len() { + let (in_len, out_len) = + resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?; + pos_in += in_len; + pcm_out.extend_from_slice(&output_buffer[0][..out_len]); + } + + if pos_in < pcm_in.len() { + let (_in_len, out_len) = resampler.process_partial_into_buffer( + Some(&[&pcm_in[pos_in..]]), + &mut output_buffer, + None, + )?; + pcm_out.extend_from_slice(&output_buffer[0][..out_len]); + } + + Ok(pcm_out) +} diff --git a/candle-examples/examples/mimi/main.rs b/candle-examples/examples/mimi/main.rs new file mode 100644 index 0000000000..cfc1a553e5 --- /dev/null +++ b/candle-examples/examples/mimi/main.rs @@ -0,0 +1,131 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + 
+use anyhow::Result; +use candle::{DType, IndexOp, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::mimi::{Config, Model}; +use clap::{Parser, ValueEnum}; +use hf_hub::api::sync::Api; + +mod audio_io; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Action { + AudioToAudio, + AudioToCode, + CodeToAudio, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// The action to be performed, specifies the format for the input and output data. + action: Action, + + /// The input file, either an audio file or some mimi tokens stored as safetensors. + in_file: String, + + /// The output file, either a wave audio file or some mimi tokens stored as safetensors. + out_file: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// The model weight file, in safetensor format. + #[arg(long)] + model: Option, +} + +fn main() -> Result<()> { + let args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + let model = match args.model { + Some(model) => std::path::PathBuf::from(model), + None => Api::new()? + .model("kyutai/mimi".to_string()) + .get("model.safetensors")?, + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? 
}; + let config = Config::v0_1(None); + let mut model = Model::new(config, vb)?; + + let codes = match args.action { + Action::CodeToAudio => { + let codes = candle::safetensors::load(args.in_file, &device)?; + codes.get("codes").expect("no codes in input file").clone() + } + Action::AudioToCode | Action::AudioToAudio => { + let pcm = if args.in_file == "-" { + println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<"); + let (stream, input_audio) = audio_io::setup_input_stream()?; + let mut pcms = vec![]; + let stdin = std::thread::spawn(|| { + let mut s = String::new(); + std::io::stdin().read_line(&mut s) + }); + while !stdin.is_finished() { + let input = input_audio.lock().unwrap().take_all(); + if input.is_empty() { + std::thread::sleep(std::time::Duration::from_millis(100)); + continue; + } + pcms.push(input) + } + drop(stream); + pcms.concat() + } else { + let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?; + if sample_rate != 24_000 { + println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling..."); + audio_io::resample(&pcm, sample_rate as usize, 24_000)? + } else { + pcm + } + }; + let pcm_len = pcm.len(); + let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?; + println!("input pcm shape: {:?}", pcm.shape()); + model.encode(&pcm)? 
+ } + }; + println!("codes shape: {:?}", codes.shape()); + + match args.action { + Action::AudioToCode => { + codes.save_safetensors("codes", &args.out_file)?; + } + Action::AudioToAudio | Action::CodeToAudio => { + let pcm = model.decode(&codes)?; + println!("output pcm shape: {:?}", pcm.shape()); + let pcm = pcm.i(0)?.i(0)?; + let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?; + let pcm = pcm.to_vec1::()?; + if args.out_file == "-" { + let (stream, ad) = audio_io::setup_output_stream()?; + { + let mut ad = ad.lock().unwrap(); + ad.push_samples(&pcm)?; + } + loop { + let ad = ad.lock().unwrap(); + if ad.is_empty() { + break; + } + // That's very weird, calling thread::sleep here triggers the stream to stop + // playing (the callback doesn't seem to be called anymore). + // std::thread::sleep(std::time::Duration::from_millis(100)); + } + drop(stream) + } else { + let mut output = std::fs::File::create(&args.out_file)?; + candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; + } + } + } + Ok(()) +} diff --git a/candle-transformers/src/models/mimi/conv.rs b/candle-transformers/src/models/mimi/conv.rs new file mode 100644 index 0000000000..87e9fb4cdd --- /dev/null +++ b/candle-transformers/src/models/mimi/conv.rs @@ -0,0 +1,670 @@ +// Copyright (c) Kyutai, all rights reserved. +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D}; +use candle_nn::{Conv1d, VarBuilder}; + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Norm { + WeightNorm, + SpectralNorm, + TimeGroupNorm, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum PadMode { + Constant, + Reflect, + Replicate, +} + +// Applies weight norm for inference by recomputing the weight tensor. This +// does not apply to training. 
+// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html +fn conv1d_weight_norm( + in_c: usize, + out_c: usize, + kernel_size: usize, + bias: bool, + config: candle_nn::Conv1dConfig, + vb: VarBuilder, +) -> Result { + let weight = if vb.contains_tensor("weight") { + vb.get((out_c, in_c, kernel_size), "weight")? + } else { + let weight_g = vb.get((out_c, 1, 1), "weight_g")?; + let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?; + let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; + weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)? + }; + let bias = if bias { + Some(vb.get(out_c, "bias")?) + } else { + None + }; + Ok(Conv1d::new(weight, bias, config)) +} + +#[derive(Debug, Clone)] +pub struct NormConv1d { + conv: Conv1d, + norm: Option, + span: tracing::Span, +} + +impl NormConv1d { + #[allow(clippy::too_many_arguments)] + pub fn new( + in_c: usize, + out_c: usize, + k_size: usize, + causal: bool, + norm: Option, + bias: bool, + cfg: candle_nn::Conv1dConfig, + vb: VarBuilder, + ) -> Result { + let conv = match norm { + None | Some(Norm::TimeGroupNorm) => { + if bias { + candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))? + } else { + candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))? + } + } + Some(Norm::WeightNorm) => { + conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))? 
+ } + Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."), + }; + let norm = match norm { + None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None, + Some(Norm::TimeGroupNorm) => { + if causal { + candle::bail!("GroupNorm doesn't support causal evaluation.") + } + let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; + Some(norm) + } + }; + Ok(Self { + conv, + norm, + span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"), + }) + } +} + +impl Module for NormConv1d { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let xs = xs.apply(&self.conv)?; + match self.norm.as_ref() { + None => Ok(xs), + Some(norm) => xs.apply(norm), + } + } +} + +#[derive(Debug, Clone)] +pub struct NormConvTranspose1d { + ws: Tensor, + bs: Option, + k_size: usize, + stride: usize, + groups: usize, + norm: Option, + span: tracing::Span, +} + +impl NormConvTranspose1d { + #[allow(clippy::too_many_arguments)] + pub fn new( + in_c: usize, + out_c: usize, + k_size: usize, + causal: bool, + norm: Option, + bias: bool, + stride: usize, + groups: usize, + vb: VarBuilder, + ) -> Result { + let vb = vb.pp("conv"); + let bs = if bias { + Some(vb.get(out_c, "bias")?) + } else { + None + }; + let ws = match norm { + None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?, + Some(Norm::WeightNorm) => { + if vb.contains_tensor("weight") { + vb.get((in_c, out_c, k_size), "weight")? + } else { + let weight_g = vb.get((in_c, 1, 1), "weight_g")?; + let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?; + let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; + weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)? + } + } + Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."), + }; + let (ws, groups) = if groups == out_c && in_c == out_c { + let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?; + let ws = ws + .repeat((1, out_c, 1))? 
+ .mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?; + (ws, 1) + } else { + (ws, groups) + }; + let norm = match norm { + None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None, + Some(Norm::TimeGroupNorm) => { + if causal { + candle::bail!("GroupNorm doesn't support causal evaluation.") + } + let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; + Some(norm) + } + }; + Ok(Self { + ws, + bs, + k_size, + stride, + groups, + norm, + span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"), + }) + } +} + +impl Module for NormConvTranspose1d { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + // conv-transpose1d seems to be broken on metal after enough iterations. Causing + // the following error: + // _status < MTLCommandBufferStatusCommitted > + // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:] + // This is now fixed in candle. + let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?; + let xs = match &self.bs { + None => xs, + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1))?; + xs.broadcast_add(&bias)? 
+ } + }; + match self.norm.as_ref() { + None => Ok(xs), + Some(norm) => xs.apply(norm), + } + } +} + +fn get_extra_padding_for_conv1d( + xs: &Tensor, + k_size: usize, + stride: usize, + padding_total: usize, +) -> Result { + let len = xs.dim(D::Minus1)?; + let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0; + let ideal_len = + ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total); + Ok(ideal_len.saturating_sub(len)) +} + +fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result { + match mode { + PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r), + PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"), + PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r), + } +} + +fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result { + let len = xs.dim(D::Minus1)?; + if len < unpad_l + unpad_r { + candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}") + } + xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r)) +} + +#[derive(Debug, Clone)] +pub struct StreamableConv1d { + conv: NormConv1d, + causal: bool, + pad_mode: PadMode, + state_prev_xs: StreamTensor, + left_pad_applied: bool, + kernel_size: usize, + span: tracing::Span, +} + +impl StreamableConv1d { + #[allow(clippy::too_many_arguments)] + pub fn new( + in_c: usize, + out_c: usize, + k_size: usize, + stride: usize, + dilation: usize, + groups: usize, + bias: bool, + causal: bool, + norm: Option, + pad_mode: PadMode, + vb: VarBuilder, + ) -> Result { + let cfg = candle_nn::Conv1dConfig { + padding: 0, + stride, + dilation, + groups, + }; + let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?; + if k_size < stride { + candle::bail!("kernel-size {k_size} is smaller than stride {stride}") + } + Ok(Self { + conv, + causal, + pad_mode, + state_prev_xs: StreamTensor::empty(), + left_pad_applied: false, + kernel_size: k_size, + span: 
tracing::span!(tracing::Level::TRACE, "streamable-conv1d"), + }) + } +} + +impl Module for StreamableConv1d { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let (_b, _t, _c) = xs.dims3()?; + let k_size = self.conv.conv.weight().dim(D::Minus1)?; + let conv_cfg = self.conv.conv.config(); + // Effective kernel size with dilations. + let k_size = (k_size - 1) * conv_cfg.dilation + 1; + let padding_total = k_size - conv_cfg.stride; + let extra_padding = + get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?; + let xs = if self.causal { + pad1d(xs, padding_total, extra_padding, self.pad_mode)? + } else { + let padding_right = padding_total / 2; + let padding_left = padding_total - padding_right; + pad1d( + xs, + padding_left, + padding_right + extra_padding, + self.pad_mode, + )? + }; + xs.apply(&self.conv) + } +} + +impl StreamingModule for StreamableConv1d { + fn reset_state(&mut self) { + self.state_prev_xs.reset(); + self.left_pad_applied = false; + } + + fn step(&mut self, xs: &StreamTensor) -> Result { + let _enter = self.span.enter(); + let xs = match xs.as_option() { + None => return Ok(().into()), + Some(xs) => xs.clone(), + }; + let xs = if self.left_pad_applied { + xs + } else { + self.left_pad_applied = true; + let k_size = self.conv.conv.weight().dim(D::Minus1)?; + let conv_cfg = self.conv.conv.config(); + let k_size = (k_size - 1) * conv_cfg.dilation + 1; + let padding_total = k_size - conv_cfg.stride; + pad1d(&xs, padding_total, 0, self.pad_mode)? 
+ }; + let cfg = self.conv.conv.config(); + let stride = cfg.stride; + let dilation = cfg.dilation; + let kernel = (self.kernel_size - 1) * dilation + 1; + let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?; + let seq_len = xs.seq_len(D::Minus1)?; + let num_frames = (seq_len + stride).saturating_sub(kernel) / stride; + if num_frames > 0 { + let offset = num_frames * stride; + self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?; + let in_l = (num_frames - 1) * stride + kernel; + let xs = xs.narrow(D::Minus1, 0, in_l)?; + // We apply the underlying convtr directly rather than through forward so as + // not to apply any padding here. + xs.apply(&self.conv.conv) + } else { + self.state_prev_xs = xs; + Ok(StreamTensor::empty()) + } + } +} + +#[derive(Debug, Clone)] +pub struct StreamableConvTranspose1d { + convtr: NormConvTranspose1d, + causal: bool, + state_prev_ys: StreamTensor, + kernel_size: usize, + span: tracing::Span, +} + +impl StreamableConvTranspose1d { + #[allow(clippy::too_many_arguments)] + pub fn new( + in_c: usize, + out_c: usize, + k_size: usize, + stride: usize, + groups: usize, + bias: bool, + causal: bool, + norm: Option, + vb: VarBuilder, + ) -> Result { + let convtr = + NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?; + Ok(Self { + convtr, + causal, + kernel_size: k_size, + state_prev_ys: StreamTensor::empty(), + span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"), + }) + } +} + +impl Module for StreamableConvTranspose1d { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let k_size = self.convtr.k_size; + let stride = self.convtr.stride; + let padding_total = k_size.saturating_sub(stride); + let xs = xs.apply(&self.convtr)?; + if self.causal { + // This corresponds to trim_right_ratio = 1. 
+ unpad1d(&xs, 0, padding_total) + } else { + let padding_right = padding_total / 2; + let padding_left = padding_total - padding_right; + unpad1d(&xs, padding_left, padding_right) + } + } +} + +impl StreamingModule for StreamableConvTranspose1d { + fn reset_state(&mut self) { + self.state_prev_ys.reset() + } + + fn step(&mut self, xs: &StreamTensor) -> Result { + let _enter = self.span.enter(); + let xs = match xs.as_option() { + Some(xs) => xs, + None => return Ok(StreamTensor::empty()), + }; + let stride = self.convtr.stride; + // We apply the underlying convtr directly rather than through forward so as + // not to apply any padding here. + let ys = self.convtr.forward(xs)?; + let ot = ys.dim(D::Minus1)?; + let ys = match self.state_prev_ys.as_option() { + None => ys, + Some(prev_ys) => { + let pt = prev_ys.dim(D::Minus1)?; + // Remove the bias as it will be applied multiple times. + let prev_ys = match &self.convtr.bs { + None => prev_ys.clone(), + Some(bias) => { + let bias = bias.reshape((1, (), 1))?; + prev_ys.broadcast_sub(&bias)? + } + }; + let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?; + let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?; + Tensor::cat(&[ys1, ys2], D::Minus1)? 
+ } + }; + let invalid_steps = self.kernel_size - stride; + let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?; + self.state_prev_ys = prev_ys; + Ok(ys) + } +} + +#[derive(Debug, Clone)] +pub struct ConvDownsample1d { + conv: StreamableConv1d, +} + +impl ConvDownsample1d { + pub fn new( + stride: usize, + dim: usize, + causal: bool, + learnt: bool, + vb: VarBuilder, + ) -> Result { + if !learnt { + candle::bail!("only learnt=true is supported") + } + let conv = StreamableConv1d::new( + /* in_c */ dim, + /* out_c */ dim, + /* k_size_c */ 2 * stride, + /* stride */ stride, + /* dilation */ 1, + /* groups */ 1, // channel_wise = false + /* bias */ false, + /* causal */ causal, + /* norm */ None, + /* pad_mode */ PadMode::Replicate, + vb, + )?; + Ok(Self { conv }) + } +} + +impl Module for ConvDownsample1d { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.conv) + } +} + +impl StreamingModule for ConvDownsample1d { + fn reset_state(&mut self) { + self.conv.reset_state() + } + + fn step(&mut self, xs: &StreamTensor) -> Result { + self.conv.step(xs) + } +} + +#[derive(Debug, Clone)] +pub struct ConvTrUpsample1d { + convtr: StreamableConvTranspose1d, +} + +impl ConvTrUpsample1d { + pub fn new( + stride: usize, + dim: usize, + causal: bool, + learnt: bool, + vb: VarBuilder, + ) -> Result { + if !learnt { + candle::bail!("only learnt=true is supported") + } + let convtr = StreamableConvTranspose1d::new( + dim, + dim, + /* k_size */ 2 * stride, + /* stride */ stride, + /* groups */ dim, + /* bias */ false, + /* causal */ causal, + /* norm */ None, + vb, + )?; + Ok(Self { convtr }) + } +} + +impl Module for ConvTrUpsample1d { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.convtr) + } +} + +impl StreamingModule for ConvTrUpsample1d { + fn reset_state(&mut self) { + self.convtr.reset_state() + } + + fn step(&mut self, xs: &StreamTensor) -> Result { + self.convtr.step(xs) + } +} + +#[cfg(test)] +mod tests { + use 
super::*; + use candle::IndexOp; + + fn run_conv1d( + k_size: usize, + stride: usize, + dilation: usize, + step_size: usize, + len: usize, + bias: bool, + ) -> Result<()> { + // TODO: We should ensure for the seed to be constant when running these tests. + let dev = &candle::Device::Cpu; + let vm = candle_nn::VarMap::new(); + let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev); + let conv1d = StreamableConv1d::new( + /* in_c */ 2, + /* out_c */ 3, + /* k_size */ k_size, + /* stride */ stride, + /* dilation */ dilation, + /* groups */ 1, + /* bias */ bias, + /* causal */ true, + /* norm */ None, + /* pad_mode */ PadMode::Constant, + vb, + )?; + let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?; + let ys = conv1d.forward(&xs)?; + let mut conv1d = conv1d; + let mut ys_steps = vec![]; + for idx in 0..len { + let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?; + let ys = conv1d.step(&xs.into())?; + if let Some(ys) = ys.as_option() { + ys_steps.push(ys.clone()) + } + } + let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?; + let diff = (&ys - &ys_steps)? + .abs()? + .flatten_all()? + .max(0)? + .to_vec0::()?; + if diff > 1e-5 { + println!("{xs}"); + println!("{ys}"); + println!("{ys_steps}"); + candle::bail!("larger diff than expected {diff}") + } + Ok(()) + } + + fn run_conv_tr1d( + k_size: usize, + stride: usize, + step_size: usize, + len: usize, + bias: bool, + ) -> Result<()> { + // TODO: We should ensure for the seed to be constant when running these tests. 
+ let dev = &candle::Device::Cpu; + let vm = candle_nn::VarMap::new(); + let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev); + let conv1d = StreamableConvTranspose1d::new( + /* in_c */ 2, /* out_c */ 3, /* k_size */ k_size, + /* stride */ stride, /* groups */ 1, /* bias */ bias, + /* causal */ true, /* norm */ None, vb, + )?; + let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?; + let ys = conv1d.forward(&xs)?; + let mut conv1d = conv1d; + let mut ys_steps = vec![]; + for idx in 0..len { + let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?; + let ys = conv1d.step(&xs.into())?; + if let Some(ys) = ys.as_option() { + ys_steps.push(ys.clone()) + } + } + let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?; + let diff = (&ys - &ys_steps)? + .abs()? + .flatten_all()? + .max(0)? + .to_vec0::()?; + if diff > 1e-5 { + println!("{xs}"); + println!("{ys}"); + println!("{ys_steps}"); + candle::bail!("larger diff than expected {diff}") + } + Ok(()) + } + + #[test] + fn conv1d() -> Result<()> { + for step_size in [1, 2, 3] { + for bias in [false, true] { + run_conv1d(1, 1, 1, step_size, 5, bias)?; + run_conv1d(2, 1, 1, step_size, 5, bias)?; + run_conv1d(2, 2, 1, step_size, 6, bias)?; + run_conv1d(3, 2, 1, step_size, 8, bias)?; + run_conv1d(3, 2, 2, step_size, 8, bias)?; + } + } + Ok(()) + } + + #[test] + fn conv_tr1d() -> Result<()> { + for step_size in [1, 2, 3] { + for bias in [false, true] { + run_conv_tr1d(1, 1, step_size, 5, bias)?; + run_conv_tr1d(2, 1, step_size, 5, bias)?; + run_conv_tr1d(3, 1, step_size, 5, bias)?; + run_conv_tr1d(3, 2, step_size, 5, bias)?; + } + } + Ok(()) + } +} diff --git a/candle-transformers/src/models/mimi/encodec.rs b/candle-transformers/src/models/mimi/encodec.rs new file mode 100644 index 0000000000..f659da3a1f --- /dev/null +++ b/candle-transformers/src/models/mimi/encodec.rs @@ -0,0 +1,229 @@ +// Copyright (c) Kyutai, all rights reserved. 
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use super::{conv, quantization, seanet, transformer};
+use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
+use candle_nn::VarBuilder;
+
+/// Resampling strategy recorded in [`Config`].
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum ResampleMethod {
+    Conv,
+    Interpolate,
+}
+
+/// Hyper-parameters of the codec: audio format, SeaNet and transformer
+/// settings, and the quantizer dimensions.
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub channels: usize,
+    pub sample_rate: f64,
+    pub frame_rate: f64,
+    pub renormalize: bool,
+    pub resample_method: ResampleMethod,
+    pub seanet: seanet::Config,
+    pub transformer: transformer::Config,
+    pub quantizer_n_q: usize,
+    pub quantizer_bins: usize,
+    pub quantizer_dim: usize,
+}
+
+impl Config {
+    /// Hard-coded "v0.1" configuration. `num_codebooks` sets `quantizer_n_q`
+    /// and defaults to 16 when `None`.
+    // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
+    pub fn v0_1(num_codebooks: Option<usize>) -> Self {
+        let seanet_cfg = seanet::Config {
+            dimension: 512,
+            channels: 1,
+            causal: true,
+            n_filters: 64,
+            n_residual_layers: 1,
+            activation: candle_nn::Activation::Elu(1.),
+            compress: 2,
+            dilation_base: 2,
+            disable_norm_outer_blocks: 0,
+            final_activation: None,
+            kernel_size: 7,
+            residual_kernel_size: 3,
+            last_kernel_size: 3,
+            lstm: 0,
+            norm: conv::Norm::WeightNorm,
+            pad_mode: conv::PadMode::Constant,
+            ratios: vec![8, 6, 5, 4],
+            true_skip: true,
+        };
+        let transformer_cfg = transformer::Config {
+            d_model: seanet_cfg.dimension,
+            num_heads: 8,
+            num_layers: 8,
+            causal: true,
+            norm_first: true,
+            bias_ff: false,
+            bias_attn: false,
+            layer_scale: Some(0.01),
+            context: 250,
+            conv_kernel_size: 5,
+            use_conv_bias: true,
+            use_conv_block: false,
+            cross_attention: false,
+            max_period: 10000,
+            gating: None,
+            norm: super::NormType::LayerNorm,
+            positional_embedding: transformer::PositionalEmbedding::Rope,
+
+            dim_feedforward: 2048,
+            kv_repeat: 1,
+            conv_layout: true, // see builders.py
+            max_seq_len: 8192, // the transformer works at 25hz so this is ~5 mins.
+        };
+        Config {
+            channels: 1,
+            sample_rate: 24_000.,
+            frame_rate: 12.5,
+            renormalize: true,
+            resample_method: ResampleMethod::Conv,
+            seanet: seanet_cfg,
+            transformer: transformer_cfg,
+            quantizer_n_q: num_codebooks.unwrap_or(16),
+            quantizer_bins: 2048,
+            quantizer_dim: 256,
+        }
+    }
+}
+
+/// Audio codec combining a SeaNet encoder/decoder, one projected transformer on
+/// each side, a strided down/up-sampling convolution pair and a split residual
+/// vector quantizer.
+#[derive(Debug, Clone)]
+pub struct Encodec {
+    encoder: seanet::SeaNetEncoder,
+    decoder: seanet::SeaNetDecoder,
+    encoder_transformer: transformer::ProjectedTransformer,
+    decoder_transformer: transformer::ProjectedTransformer,
+    downsample: conv::ConvDownsample1d,
+    upsample: conv::ConvTrUpsample1d,
+    quantizer: quantization::SplitResidualVectorQuantizer,
+    config: Config,
+}
+
+impl Encodec {
+    /// Builds every sub-module from `cfg`, loading weights from `vb` under the
+    /// prefixes `encoder`, `decoder`, `encoder_transformer`, `decoder_transformer`,
+    /// `quantizer`, `downsample` and `upsample`.
+    pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
+        let dim = cfg.seanet.dimension;
+        let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
+        let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
+        let encoder_transformer = transformer::ProjectedTransformer::new(
+            dim,
+            &[dim],
+            &cfg.transformer,
+            vb.pp("encoder_transformer"),
+        )?;
+        let decoder_transformer = transformer::ProjectedTransformer::new(
+            dim,
+            &[dim],
+            &cfg.transformer,
+            vb.pp("decoder_transformer"),
+        )?;
+        let quantizer = quantization::SplitResidualVectorQuantizer::new(
+            /* dim */ cfg.quantizer_dim,
+            /* input_dim */ Some(dim),
+            /* output_dim */ Some(dim),
+            /* n_q */ cfg.quantizer_n_q,
+            /* bins */ cfg.quantizer_bins,
+            vb.pp("quantizer"),
+        )?;
+        // Frame rate at the SeaNet output: the sample rate divided by the product of
+        // the strides. With `v0_1` this is 24000 / (8 * 6 * 5 * 4) = 25.
+        let encoder_frame_rate =
+            cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
+
+        // Truncating cast; with `v0_1` this is 25 / 12.5 = 2.
+        let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
+        // `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
+        let downsample = conv::ConvDownsample1d::new(
+            /* stride */ downsample_stride,
+            /* dim */ dim,
+            /* causal */ true,
+            /* learnt */ true,
+            vb.pp("downsample"),
+        )?;
+        let upsample = conv::ConvTrUpsample1d::new(
+            /* stride */ downsample_stride,
+            /* dim */ dim,
+            /* causal */ true,
+            /* learnt */ true,
+            vb.pp("upsample"),
+        )?;
+
+        Ok(Self {
+            encoder,
+            decoder,
+            encoder_transformer,
+            decoder_transformer,
+            quantizer,
+            downsample,
+            upsample,
+            config: cfg,
+        })
+    }
+
+    /// The configuration this model was built with.
+    pub fn config(&self) -> &Config {
+        &self.config
+    }
+
+    /// Non-streaming encoding up to, but not including, quantization: SeaNet
+    /// encoder, encoder transformer (its state is reset first, and only its first
+    /// output is kept), then the downsampling convolution.
+    pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.encoder.forward(xs)?;
+        self.encoder_transformer.reset_state();
+        let xs = self.encoder_transformer.forward(&xs)?;
+        let xs = &xs[0];
+        xs.apply(&self.downsample)
+    }
+
+    /// Non-streaming encoding of `xs` into quantizer codes.
+    pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.encode_pre_quantize(xs)?;
+        self.quantizer.encode(&xs)
+    }
+
+    /// Streaming encoding of one chunk. Returns an empty `StreamTensor` when the
+    /// downsampling stage has no output for this chunk.
+    pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+        let xs = self.encoder.step(xs)?;
+        let xs = self.encoder_transformer.step(&xs)?;
+        let xs = self.downsample.step(&xs)?;
+        match xs.as_option() {
+            None => Ok(().into()),
+            Some(xs) => {
+                let codes = self.quantizer.encode(xs)?;
+                Ok(codes.into())
+            }
+        }
+    }
+
+    /// Non-streaming decoding: quantizer decode, upsampling, decoder transformer
+    /// (its state is reset first, and only its first output is kept), then the
+    /// SeaNet decoder.
+    pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
+        let emb = self.quantizer.decode(codes)?;
+        let emb = emb.apply(&self.upsample)?;
+        self.decoder_transformer.reset_state();
+        let outs = self.decoder_transformer.forward(&emb)?;
+        let out = &outs[0];
+        self.decoder.forward(out)
+    }
+
+    /// Streaming decoding of one chunk of codes. An empty input is still pushed
+    /// through the streaming stages as an empty `StreamTensor`.
+    pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
+        let emb = match codes.as_option() {
+            Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
+            None => StreamTensor::empty(),
+        };
+        let emb = self.upsample.step(&emb)?;
+        let out = self.decoder_transformer.step(&emb)?;
+        self.decoder.step(&out)
+    }
+
+    /// Clears the streaming state of every stateful sub-module so that a new
+    /// stream can be processed from scratch.
+    pub fn reset_state(&mut self) {
+        self.encoder.reset_state();
+        self.encoder_transformer.reset_state();
+        self.decoder.reset_state();
+        self.decoder_transformer.reset_state();
+        // `downsample` is driven through `step` in `encode_step`, so it is reset
+        // together with `upsample`.
+        self.downsample.reset_state();
+        self.upsample.reset_state();
+    }
+}
+
+/// Loads the `v0_1` model from a safetensors file as `F32` weights on `dev`.
+pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
+    // SAFETY(review): the file is memory-mapped; presumably it must not be modified
+    // while the mapping is alive. Confirm against the `VarBuilder` documentation.
+    let vb =
+        unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
+    let cfg = Config::v0_1(num_codebooks);
+    let encodec = Encodec::new(cfg, vb)?;
+    Ok(encodec)
+}
diff --git a/candle-transformers/src/models/mimi/mod.rs b/candle-transformers/src/models/mimi/mod.rs
new file mode 100644
index 0000000000..dc40e38e29
--- /dev/null
+++ b/candle-transformers/src/models/mimi/mod.rs
@@ -0,0 +1,22 @@
+// Adapted from the reference implementation at:
+// https://github.com/kyutai-labs/moshi
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+pub use candle;
+pub use candle_nn;
+
+pub mod conv;
+pub mod encodec;
+pub mod quantization;
+pub mod seanet;
+pub mod transformer;
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum NormType {
+    RmsNorm,
+    LayerNorm,
+}
+
+pub use encodec::{load, Config, Encodec as Model};
diff --git a/candle-transformers/src/models/mimi/quantization.rs b/candle-transformers/src/models/mimi/quantization.rs
new file mode 100644
index 0000000000..3fde16472b
--- /dev/null
+++ b/candle-transformers/src/models/mimi/quantization.rs
@@ -0,0 +1,404 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::{linear, Linear, VarBuilder};
+
+/// Custom op performing a nearest-neighbour lookup: for every row of `lhs`
+/// (shape `(n, d)`) it returns the index of the row of `rhs` (shape `(m, d)`)
+/// with the smallest squared Euclidean distance. The result has shape `(n,)`
+/// and dtype `u32`. Only the CPU forward pass is implemented in this block.
+struct CodebookEncode;
+
+impl candle::CustomOp2 for CodebookEncode {
+    fn name(&self) -> &'static str {
+        "cb"
+    }
+
+    /// Both inputs must be contiguous rank-2 `f32` tensors sharing a non-empty
+    /// last dimension; anything else is reported as an error.
+    fn cpu_fwd(
+        &self,
+        lhs_storage: &candle::CpuStorage,
+        lhs_layout: &Layout,
+        rhs_storage: &candle::CpuStorage,
+        rhs_layout: &Layout,
+    ) -> Result<(candle::CpuStorage, Shape)> {
+        use rayon::prelude::*;
+
+        let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
+        let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
+        if lhs_dim2 != rhs_dim2 {
+            candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
+        }
+        if lhs_dim2 == 0 {
+            candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
+        }
+        let lhs = match lhs_layout.contiguous_offsets() {
+            None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
+            Some((o1, o2)) => {
+                let slice = lhs_storage.as_slice::<f32>()?;
+                &slice[o1..o2]
+            }
+        };
+        let rhs = match rhs_layout.contiguous_offsets() {
+            None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
+            Some((o1, o2)) => {
+                let slice = rhs_storage.as_slice::<f32>()?;
+                &slice[o1..o2]
+            }
+        };
+        // One independent search per lhs row, run in parallel with rayon.
+        let dst = (0..lhs_dim1)
+            .into_par_iter()
+            .map(|idx1| {
+                let mut where_min = 0;
+                let mut min_dist = f32::INFINITY;
+                let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
+                for idx2 in 0..rhs_dim1 {
+                    let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
+                    let mut dist = 0f32;
+                    for (a, b) in lhs.iter().zip(rhs.iter()) {
+                        dist += (a - b) * (a - b)
+                    }
+                    // Strict comparison: on ties the lowest index wins.
+                    if dist < min_dist {
+                        min_dist = dist;
+                        where_min = idx2;
+                    }
+                }
+                where_min as u32
+            })
+            .collect();
+        let storage = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((storage, (lhs_dim1,).into()))
+    }
+}
+
+/// Codebook of `codebook_size` vectors of size `dim`. The embedding table is
+/// derived at load time as `embed_sum / max(cluster_usage, epsilon)`.
+#[allow(unused)]
+#[derive(Debug, Clone)]
+pub struct EuclideanCodebook {
+    initialized: Tensor,
+    cluster_usage: Tensor,
+    embedding_sum: Tensor,
+    embedding: Tensor,
+    // Half of the squared norm of every embedding row, used by `encode_slow`.
+    c2: Tensor,
+    epsilon: f64,
+    dim: usize,
+    span_encode: tracing::Span,
+    span_decode: tracing::Span,
+}
+
+impl EuclideanCodebook {
+    /// Loads `initialized`, `cluster_usage` and `embed_sum` from `vb` and
+    /// precomputes the embedding table and its half squared norms.
+    pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
+        let epsilon = 1e-5;
+        let initialized = vb.get(1, "initialized")?;
+        let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
+        let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
+        let embedding = {
+            // Clamp the usage counts from below to avoid dividing by zero.
+            let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
+            embedding_sum.broadcast_div(&cluster_usage)?
+        };
+        let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
+        Ok(Self {
+            initialized,
+            cluster_usage,
+            embedding_sum,
+            embedding,
+            c2,
+            epsilon,
+            dim,
+            span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
+            span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-decode"),
+        })
+    }
+
+    /// Reference encoder: materializes every pairwise difference between the
+    /// input vectors and the codebook and takes the argmin of the squared
+    /// distances. The output has the shape of `xs` without its last dimension.
+    pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_encode.enter();
+        let mut target_shape = xs.dims().to_vec();
+        target_shape.pop();
+        let xs = xs.flatten_to(D::Minus2)?;
+        let _ = xs.dims2()?;
+        // Manual cdist implementation, using the table precomputed in `new`.
+        let diff = xs
+            .unsqueeze(1)?
+            .broadcast_sub(&self.embedding.unsqueeze(0)?)?;
+        let dists = diff.sqr()?.sum(D::Minus1)?;
+        let codes = dists.argmin(D::Minus1)?;
+        codes.reshape(target_shape)
+    }
+
+    /// Matmul-based encoder. Minimizing `|e|^2 / 2 - x.e` over the codebook rows
+    /// `e` is the same as minimizing `|x - e|^2`, since the `|x|^2` term does not
+    /// depend on `e`.
+    pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_encode.enter();
+        let mut target_shape = xs.dims().to_vec();
+        target_shape.pop();
+        let xs = xs.flatten_to(D::Minus2)?;
+        let _ = xs.dims2()?;
+        let dot_prod = xs.matmul(&self.embedding.t()?)?;
+        let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
+        codes.reshape(target_shape)
+    }
+
+    /// Encoder backed by the `CodebookEncode` custom op.
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_encode.enter();
+        let mut target_shape = xs.dims().to_vec();
+        target_shape.pop();
+        let xs = xs.flatten_to(D::Minus2)?;
+        let _ = xs.dims2()?;
+        let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
+        codes.reshape(target_shape)
+    }
+
+    /// Looks up the embedding of every index; the output has the shape of
+    /// `indexes` with a trailing dimension of size `dim` appended.
+    pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_decode.enter();
+        // let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
+        let mut final_dims = indexes.dims().to_vec();
+        final_dims.push(self.dim);
+        let indexes = indexes.flatten_all()?;
+        let values = self.embedding.index_select(&indexes, 0)?;
+        let values = values.reshape(final_dims)?;
+        Ok(values)
+    }
+}
+
+/// A single codebook with optional linear projections in and out of the
+/// codebook dimension.
+#[allow(unused)]
+#[derive(Debug, Clone)]
+pub struct VectorQuantization {
+    project_in: Option<Linear>,
+    project_out: Option<Linear>,
+    codebook: EuclideanCodebook,
+}
+
+impl VectorQuantization {
+    /// The projections are only created when `codebook_dim` is given and differs
+    /// from `dim`.
+    pub fn new(
+        dim: usize,
+        codebook_size: usize,
+        codebook_dim: Option<usize>,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let codebook_dim = codebook_dim.unwrap_or(dim);
+        let (project_in, project_out) = if codebook_dim == dim {
+            (None, None)
+        } else {
+            let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
+            let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
+            (Some(p_in), Some(p_out))
+        };
+        let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
+        Ok(Self {
+            project_in,
+            project_out,
+            codebook,
+        })
+    }
+
+    /// Swaps the last two dimensions of `xs`, applies the input projection when
+    /// there is one, and returns the nearest codebook indices.
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = xs.t()?.apply(&self.project_in.as_ref())?;
+        self.codebook.encode_slow(&xs)
+    }
+
+    /// Inverse of `encode`: embedding lookup, optional output projection, then
+    /// the last two dimensions are swapped back.
+    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+        let quantized = self.codebook.decode(codes)?;
+        let quantized = match &self.project_out {
+            None => quantized,
+            Some(p) => quantized.apply(p)?,
+        };
+        quantized.t()
+    }
+}
+
+/// Stack of `n_q` quantization layers, each encoding the residual left by the
+/// previous ones.
+#[derive(Debug, Clone)]
+pub struct ResidualVectorQuantization {
+    layers: Vec<VectorQuantization>,
+}
+
+impl ResidualVectorQuantization {
+    /// Loads the layers from `vb` under `layers.0`, `layers.1`, ...
+    pub fn new(
+        n_q: usize,
+        dim: usize,
+        codebook_size: usize,
+        codebook_dim: Option<usize>,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let vb = vb.pp("layers");
+        let mut layers = Vec::with_capacity(n_q);
+        for i in 0..n_q {
+            let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
+            layers.push(layer)
+        }
+        Ok(Self { layers })
+    }
+
+    /// Returns the per-layer indices stacked along a new leading dimension.
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut codes = Vec::with_capacity(self.layers.len());
+        let mut residual = xs.clone();
+        for layer in self.layers.iter() {
+            let indices = layer.encode(&residual)?;
+            let quantized = layer.decode(&indices)?;
+            // The next layer only sees what this one failed to represent.
+            residual = (residual - quantized)?;
+            codes.push(indices)
+        }
+        Tensor::stack(&codes, 0)
+    }
+
+    /// Sums the decoded contribution of every layer. The leading dimension of
+    /// `xs` must be equal to the number of layers.
+    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+        if self.layers.is_empty() {
+            candle::bail!("empty layers in ResidualVectorQuantization")
+        }
+        if self.layers.len() != xs.dim(0)? {
+            candle::bail!(
+                "mismatch between the number of layers {} and the code shape {:?}",
+                self.layers.len(),
+                xs.shape()
+            )
+        }
+        let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
+        for (i, layer) in self.layers.iter().enumerate().skip(1) {
+            let xs = xs.i(i)?;
+            quantized = (quantized + layer.decode(&xs))?
+        }
+        Ok(quantized)
+    }
+}
+
+/// Residual vector quantization wrapped with optional bias-free 1x1
+/// convolutions projecting into and out of the quantizer dimension.
+#[allow(unused)]
+#[derive(Debug, Clone)]
+pub struct ResidualVectorQuantizer {
+    vq: ResidualVectorQuantization,
+    input_proj: Option<candle_nn::Conv1d>,
+    output_proj: Option<candle_nn::Conv1d>,
+}
+
+impl ResidualVectorQuantizer {
+    /// `input_dim` and `output_dim` default to `dim`. A projection is created
+    /// when the corresponding dimension differs from `dim`, or unconditionally
+    /// when `force_projection` is set.
+    pub fn new(
+        dim: usize,
+        input_dim: Option<usize>,
+        output_dim: Option<usize>,
+        n_q: usize,
+        bins: usize,
+        force_projection: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let input_dim = input_dim.unwrap_or(dim);
+        let output_dim = output_dim.unwrap_or(dim);
+
+        let input_proj = if input_dim == dim && !force_projection {
+            None
+        } else {
+            let c = candle_nn::conv1d_no_bias(
+                input_dim,
+                dim,
+                1,
+                Default::default(),
+                vb.pp("input_proj"),
+            )?;
+            Some(c)
+        };
+        let output_proj = if output_dim == dim && !force_projection {
+            None
+        } else {
+            let c = candle_nn::conv1d_no_bias(
+                dim,
+                output_dim,
+                1,
+                Default::default(),
+                vb.pp("output_proj"),
+            )?;
+            Some(c)
+        };
+
+        let vq = ResidualVectorQuantization::new(
+            n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb,
+        )?;
+        Ok(Self {
+            vq,
+            input_proj,
+            output_proj,
+        })
+    }
+
+    /// Projects `xs` and encodes it; the first two dimensions of the stacked
+    /// codes are swapped before returning.
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
+        codes.transpose(0, 1)
+    }
+
+    /// Decodes `codes` and applies the output projection when there is one.
+    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+        // codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
+        let codes = codes.transpose(0, 1)?;
+        let quantized = self.vq.decode(&codes)?;
+        match &self.output_proj {
+            None => Ok(quantized),
+            Some(p) => quantized.apply(p),
+        }
+    }
+}
+
+// we do not use any codebook_offset at the moment. When reconstructing the codes, we could just
+// concatenate the indexes.
+#[derive(Debug, Clone)]
+pub struct SplitResidualVectorQuantizer {
+    rvq_first: ResidualVectorQuantizer,
+    rvq_rest: ResidualVectorQuantizer,
+    n_q: usize,
+    span_encode: tracing::Span,
+    span_decode: tracing::Span,
+}
+
+impl SplitResidualVectorQuantizer {
+    /// Builds a one-codebook quantizer loaded from
+    /// `semantic_residual_vector_quantizer` and an `n_q - 1` codebook quantizer
+    /// loaded from `acoustic_residual_vector_quantizer`. Both always use their
+    /// input and output projections.
+    ///
+    /// # Errors
+    /// Fails when `n_q` is zero or when a weight cannot be loaded from `vb`.
+    pub fn new(
+        dim: usize,
+        input_dim: Option<usize>,
+        output_dim: Option<usize>,
+        n_q: usize,
+        bins: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        // `n_q - 1` below would underflow for `n_q == 0`.
+        if n_q == 0 {
+            candle::bail!("SplitResidualVectorQuantizer requires n_q >= 1, got 0")
+        }
+        let rvq_first = ResidualVectorQuantizer::new(
+            dim,
+            input_dim,
+            output_dim,
+            1,
+            bins,
+            true,
+            vb.pp("semantic_residual_vector_quantizer"),
+        )?;
+        let rvq_rest = ResidualVectorQuantizer::new(
+            dim,
+            input_dim,
+            output_dim,
+            n_q - 1,
+            bins,
+            true,
+            vb.pp("acoustic_residual_vector_quantizer"),
+        )?;
+        let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
+        let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
+        Ok(Self {
+            rvq_first,
+            rvq_rest,
+            n_q,
+            span_encode,
+            span_decode,
+        })
+    }
+
+    /// Encodes `xs` with both quantizers and concatenates the codes along
+    /// dimension 1, the first quantizer's code coming first.
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_encode.enter();
+        let codes = self.rvq_first.encode(xs)?;
+        if self.n_q > 1 {
+            // We encode xs again here rather than the residual. The decomposition is not
+            // hierarchical but rather having semantic tokens for rvq_first and the acoustic tokens
+            // for rvq_rest.
+            let rest_codes = self.rvq_rest.encode(xs)?;
+            Tensor::cat(&[codes, rest_codes], 1)
+        } else {
+            Ok(codes)
+        }
+    }
+
+    /// Decodes the first codebook with `rvq_first` and, when `n_q > 1`, adds
+    /// the decoding of the remaining codebooks by `rvq_rest`.
+    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+        // codes is [B, K, T], with T frames, K nb of codebooks.
+        let _enter = self.span_decode.enter();
+        let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
+        let quantized = if self.n_q > 1 {
+            (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
+        } else {
+            quantized
+        };
+        Ok(quantized)
+    }
+}
diff --git a/candle-transformers/src/models/mimi/seanet.rs b/candle-transformers/src/models/mimi/seanet.rs
new file mode 100644
index 0000000000..aa5c7d2139
--- /dev/null
+++ b/candle-transformers/src/models/mimi/seanet.rs
@@ -0,0 +1,465 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
+use candle_nn::VarBuilder;
+
+use super::conv::{StreamableConv1d, StreamableConvTranspose1d};
+
+/// Hyper-parameters shared by the SeaNet encoder and decoder.
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub dimension: usize,
+    pub channels: usize,
+    pub causal: bool,
+    pub n_filters: usize,
+    pub n_residual_layers: usize,
+    pub ratios: Vec<usize>,
+    pub activation: candle_nn::Activation,
+    pub norm: super::conv::Norm,
+    pub kernel_size: usize,
+    pub residual_kernel_size: usize,
+    pub last_kernel_size: usize,
+    pub dilation_base: usize,
+    pub pad_mode: super::conv::PadMode,
+    pub true_skip: bool,
+    pub compress: usize,
+    // Any value above zero is rejected by the encoder and decoder constructors.
+    pub lstm: usize,
+    pub disable_norm_outer_blocks: usize,
+    pub final_activation: Option<candle_nn::Activation>,
+}
+
+/// Residual block: a chain of (activation, convolution) pairs whose output is
+/// added to the input, which optionally goes through a 1x1 shortcut convolution.
+#[derive(Debug, Clone)]
+pub struct SeaNetResnetBlock {
+    block: Vec<StreamableConv1d>,
+    shortcut: Option<StreamableConv1d>,
+    activation: candle_nn::Activation,
+    skip_op: candle::StreamingBinOp,
+    span: tracing::Span,
+}
+
+impl SeaNetResnetBlock {
+    /// Creates one convolution per `(kernel size, dilation)` pair. The first one
+    /// maps `dim` to `dim / compress` channels and the last one maps back to
+    /// `dim`. With `true_skip` the skip connection is the identity; otherwise a
+    /// 1x1 convolution is loaded from `shortcut`.
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        dim: usize,
+        k_sizes_and_dilations: &[(usize, usize)],
+        activation: candle_nn::Activation,
+        norm: Option<super::conv::Norm>,
+        causal: bool,
+        pad_mode: super::conv::PadMode,
+        compress: usize,
+        true_skip: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
+        let hidden = dim / compress;
+        let vb_b = vb.pp("block");
+        for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
+            let in_c = if i == 0 { dim } else { hidden };
+            let out_c = if i == k_sizes_and_dilations.len() - 1 {
+                dim
+            } else {
+                hidden
+            };
+            let c = StreamableConv1d::new(
+                in_c,
+                out_c,
+                /* k_size */ *k_size,
+                /* stride */ 1,
+                /* dilation */ *dilation,
+                /* groups */ 1,
+                /* bias */ true,
+                /* causal */ causal,
+                /* norm */ norm,
+                /* pad_mode */ pad_mode,
+                // Convolutions sit at the odd indices of `block`; presumably the even
+                // ones are the parameter-free activations of the reference model.
+                vb_b.pp(2 * i + 1),
+            )?;
+            block.push(c)
+        }
+        let shortcut = if true_skip {
+            None
+        } else {
+            let c = StreamableConv1d::new(
+                dim,
+                dim,
+                /* k_size */ 1,
+                /* stride */ 1,
+                /* dilation */ 1,
+                /* groups */ 1,
+                /* bias */ true,
+                /* causal */ causal,
+                /* norm */ norm,
+                /* pad_mode */ pad_mode,
+                vb.pp("shortcut"),
+            )?;
+            Some(c)
+        };
+        Ok(Self {
+            block,
+            shortcut,
+            activation,
+            skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
+            span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
+        })
+    }
+}
+
+impl Module for SeaNetResnetBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let mut ys = xs.clone();
+        for block in self.block.iter() {
+            ys = ys.apply(&self.activation)?.apply(block)?;
+        }
+        match self.shortcut.as_ref() {
+            None => ys + xs,
+            Some(shortcut) => ys + xs.apply(shortcut),
+        }
+    }
+}
+
+impl StreamingModule for SeaNetResnetBlock {
+    fn reset_state(&mut self) {
+        for block in self.block.iter_mut() {
+            block.reset_state()
+        }
+        if let Some(shortcut) = self.shortcut.as_mut() {
+            shortcut.reset_state()
+        }
+    }
+
+    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+        let _enter = self.span.enter();
+        let mut ys = xs.clone();
+        for block in self.block.iter_mut() {
+            ys = block.step(&ys.apply(&self.activation)?)?;
+        }
+        // The skip connection is added through the streaming binary op.
+        match self.shortcut.as_ref() {
+            None => self.skip_op.step(&ys, xs),
+            Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
+        }
+    }
+}
+
+/// One encoder stage: residual blocks followed by a strided convolution.
+#[derive(Debug, Clone)]
+struct EncoderLayer {
+    residuals: Vec<SeaNetResnetBlock>,
+    downsample: StreamableConv1d,
+}
+
+/// SeaNet encoder: an initial convolution, one `EncoderLayer` per entry of
+/// `ratios` (visited in reverse order) and a final convolution producing
+/// `dimension` channels.
+#[derive(Debug, Clone)]
+pub struct SeaNetEncoder {
+    init_conv1d: StreamableConv1d,
+    activation: candle_nn::Activation,
+    layers: Vec<EncoderLayer>,
+    final_conv1d: StreamableConv1d,
+    span: tracing::Span,
+}
+
+impl SeaNetEncoder {
+    /// Loads the encoder from `vb` under `layers.<idx>`. The channel count
+    /// doubles at every stage, starting from `n_filters`.
+    ///
+    /// # Errors
+    /// Fails when `cfg.lstm > 0` or when a weight cannot be loaded.
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        if cfg.lstm > 0 {
+            candle::bail!("seanet lstm is not supported")
+        }
+        let n_blocks = 2 + cfg.ratios.len();
+        let mut mult = 1usize;
+        let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
+            None
+        } else {
+            Some(cfg.norm)
+        };
+        // Position in the flat `layers` list of the checkpoint. Some positions are
+        // skipped below; presumably they hold parameter-free activation modules.
+        let mut layer_idx = 0;
+        let vb = vb.pp("layers");
+        let init_conv1d = StreamableConv1d::new(
+            cfg.channels,
+            mult * cfg.n_filters,
+            cfg.kernel_size,
+            /* stride */ 1,
+            /* dilation */ 1,
+            /* groups */ 1,
+            /* bias */ true,
+            /* causal */ cfg.causal,
+            /* norm */ init_norm,
+            /* pad_mode */ cfg.pad_mode,
+            vb.pp(layer_idx),
+        )?;
+        layer_idx += 1;
+        let mut layers = Vec::with_capacity(cfg.ratios.len());
+
+        for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
+            let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
+                None
+            } else {
+                Some(cfg.norm)
+            };
+            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
+            for j in 0..cfg.n_residual_layers {
+                let resnet_block = SeaNetResnetBlock::new(
+                    mult * cfg.n_filters,
+                    &[
+                        (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
+                        (1, 1),
+                    ],
+                    cfg.activation,
+                    norm,
+                    cfg.causal,
+                    cfg.pad_mode,
+                    cfg.compress,
+                    cfg.true_skip,
+                    vb.pp(layer_idx),
+                )?;
+                residuals.push(resnet_block);
+                layer_idx += 1;
+            }
+            // NOTE(review): `causal` is hard-coded to `true` here while the other
+            // convolutions use `cfg.causal`; confirm this is intended.
+            let downsample = StreamableConv1d::new(
+                mult * cfg.n_filters,
+                mult * cfg.n_filters * 2,
+                /* k_size */ ratio * 2,
+                /* stride */ ratio,
+                /* dilation */ 1,
+                /* groups */ 1,
+                /* bias */ true,
+                /* causal */ true,
+                /* norm */ norm,
+                /* pad_mode */ cfg.pad_mode,
+                vb.pp(layer_idx + 1),
+            )?;
+            layer_idx += 2;
+            let layer = EncoderLayer {
+                downsample,
+                residuals,
+            };
+            layers.push(layer);
+            mult *= 2
+        }
+
+        let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
+            None
+        } else {
+            Some(cfg.norm)
+        };
+        let final_conv1d = StreamableConv1d::new(
+            mult * cfg.n_filters,
+            cfg.dimension,
+            cfg.last_kernel_size,
+            /* stride */ 1,
+            /* dilation */ 1,
+            /* groups */ 1,
+            /* bias */ true,
+            /* causal */ cfg.causal,
+            /* norm */ final_norm,
+            /* pad_mode */ cfg.pad_mode,
+            vb.pp(layer_idx + 1),
+        )?;
+        Ok(Self {
+            init_conv1d,
+            activation: cfg.activation,
+            layers,
+            final_conv1d,
+            span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
+        })
+    }
+}
+
+impl Module for SeaNetEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let mut xs = xs.apply(&self.init_conv1d)?;
+        for layer in self.layers.iter() {
+            for residual in layer.residuals.iter() {
+                xs = xs.apply(residual)?
+            }
+            xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
+        }
+        xs.apply(&self.activation)?.apply(&self.final_conv1d)
+    }
+}
+
+impl StreamingModule for SeaNetEncoder {
+    fn reset_state(&mut self) {
+        self.init_conv1d.reset_state();
+        self.layers.iter_mut().for_each(|v| {
+            v.residuals.iter_mut().for_each(|v| v.reset_state());
+            v.downsample.reset_state()
+        });
+        self.final_conv1d.reset_state();
+    }
+
+    /// Streaming counterpart of `forward`, applying the same stages in the same order.
+    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+        let _enter = self.span.enter();
+        let mut xs = self.init_conv1d.step(xs)?;
+        for layer in self.layers.iter_mut() {
+            for residual in layer.residuals.iter_mut() {
+                xs = residual.step(&xs)?;
+            }
+            xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
+        }
+        self.final_conv1d.step(&xs.apply(&self.activation)?)
+    }
+}
+
+/// One decoder stage: a strided transposed convolution followed by residual blocks.
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+    upsample: StreamableConvTranspose1d,
+    residuals: Vec<SeaNetResnetBlock>,
+}
+
+/// SeaNet decoder: an initial convolution, one `DecoderLayer` per entry of
+/// `ratios` (visited in order), a final convolution producing `channels`
+/// channels and an optional final activation.
+#[derive(Debug, Clone)]
+pub struct SeaNetDecoder {
+    init_conv1d: StreamableConv1d,
+    activation: candle_nn::Activation,
+    layers: Vec<DecoderLayer>,
+    final_conv1d: StreamableConv1d,
+    final_activation: Option<candle_nn::Activation>,
+    span: tracing::Span,
+}
+
+impl SeaNetDecoder {
+    /// Loads the decoder from `vb` under `layers.<idx>`. The channel count starts
+    /// at `n_filters << ratios.len()` and halves at every stage.
+    ///
+    /// # Errors
+    /// Fails when `cfg.lstm > 0` or when a weight cannot be loaded.
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        if cfg.lstm > 0 {
+            candle::bail!("seanet lstm is not supported")
+        }
+        let n_blocks = 2 + cfg.ratios.len();
+        let mut mult = 1 << cfg.ratios.len();
+        let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
+            None
+        } else {
+            Some(cfg.norm)
+        };
+        // Position in the flat `layers` list of the checkpoint. Some positions are
+        // skipped below; presumably they hold parameter-free activation modules.
+        let mut layer_idx = 0;
+        let vb = vb.pp("layers");
+        let init_conv1d = StreamableConv1d::new(
+            cfg.dimension,
+            mult * cfg.n_filters,
+            cfg.kernel_size,
+            /* stride */ 1,
+            /* dilation */ 1,
+            /* groups */ 1,
+            /* bias */ true,
+            /* causal */ cfg.causal,
+            /* norm */ init_norm,
+            /* pad_mode */ cfg.pad_mode,
+            vb.pp(layer_idx),
+        )?;
+        layer_idx += 1;
+        let mut layers = Vec::with_capacity(cfg.ratios.len());
+        for (i, &ratio) in cfg.ratios.iter().enumerate() {
+            let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
+                None
+            } else {
+                Some(cfg.norm)
+            };
+            // NOTE(review): `causal` is hard-coded to `true` here while the other
+            // convolutions use `cfg.causal`; confirm this is intended.
+            let upsample = StreamableConvTranspose1d::new(
+                mult * cfg.n_filters,
+                mult * cfg.n_filters / 2,
+                /* k_size */ ratio * 2,
+                /* stride */ ratio,
+                /* groups */ 1,
+                /* bias */ true,
+                /* causal */ true,
+                /* norm */ norm,
+                vb.pp(layer_idx + 1),
+            )?;
+            layer_idx += 2;
+
+            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
+            for j in 0..cfg.n_residual_layers {
+                let resnet_block = SeaNetResnetBlock::new(
+                    mult * cfg.n_filters / 2,
+                    &[
+                        (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
+                        (1, 1),
+                    ],
+                    cfg.activation,
+                    norm,
+                    cfg.causal,
+                    cfg.pad_mode,
+                    cfg.compress,
+                    cfg.true_skip,
+                    vb.pp(layer_idx),
+                )?;
+                residuals.push(resnet_block);
+                layer_idx += 1;
+            }
+            let layer = DecoderLayer {
+                upsample,
+                residuals,
+            };
+            layers.push(layer);
+            mult /= 2
+        }
+        let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
+            None
+        } else {
+            Some(cfg.norm)
+        };
+        let final_conv1d = StreamableConv1d::new(
+            cfg.n_filters,
+            cfg.channels,
+            cfg.last_kernel_size,
+            /* stride */ 1,
+            /* dilation */ 1,
+            /* groups */ 1,
+            /* bias */ true,
+            /* causal */ cfg.causal,
+            /* norm */ final_norm,
+            /* pad_mode */ cfg.pad_mode,
+            vb.pp(layer_idx + 1),
+        )?;
+        Ok(Self {
+            init_conv1d,
+            activation: cfg.activation,
+            layers,
+            final_conv1d,
+            final_activation: cfg.final_activation,
+            span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
+        })
+    }
+}
+
+impl Module for SeaNetDecoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let mut xs = xs.apply(&self.init_conv1d)?;
+        for layer in self.layers.iter() {
+            xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
+            for residual in layer.residuals.iter() {
+                xs = xs.apply(residual)?
+            }
+        }
+        let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
+        let xs = match self.final_activation.as_ref() {
+            None => xs,
+            Some(act) => xs.apply(act)?,
+        };
+        Ok(xs)
+    }
+}
+
+impl StreamingModule for SeaNetDecoder {
+    fn reset_state(&mut self) {
+        self.init_conv1d.reset_state();
+        self.layers.iter_mut().for_each(|v| {
+            v.residuals.iter_mut().for_each(|v| v.reset_state());
+            v.upsample.reset_state()
+        });
+        self.final_conv1d.reset_state();
+    }
+
+    /// Streaming counterpart of `forward`, applying the same stages in the same order.
+    fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+        let _enter = self.span.enter();
+        let mut xs = self.init_conv1d.step(xs)?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
+            for residual in layer.residuals.iter_mut() {
+                xs = residual.step(&xs)?;
+            }
+        }
+        let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
+        let xs = match self.final_activation.as_ref() {
+            None => xs,
+            Some(act) => xs.apply(act)?,
+        };
+        Ok(xs)
+    }
+}
diff --git a/candle-transformers/src/models/mimi/transformer.rs
b/candle-transformers/src/models/mimi/transformer.rs new file mode 100644 index 0000000000..de22127462 --- /dev/null +++ b/candle-transformers/src/models/mimi/transformer.rs @@ -0,0 +1,802 @@ +// Copyright (c) Kyutai, all rights reserved. +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D}; +use candle_nn::{linear_no_bias, Linear, VarBuilder}; +use std::sync::Arc; + +fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result { + if bias { + candle_nn::linear(in_d, out_d, vb) + } else { + linear_no_bias(in_d, out_d, vb) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum PositionalEmbedding { + Rope, + Sin, + None, +} + +#[derive(Debug, Clone)] +pub struct Config { + pub d_model: usize, + pub num_heads: usize, + pub num_layers: usize, + pub causal: bool, + pub norm_first: bool, + pub bias_ff: bool, + pub bias_attn: bool, + pub layer_scale: Option, + pub positional_embedding: PositionalEmbedding, + pub use_conv_block: bool, + pub cross_attention: bool, + pub conv_kernel_size: usize, + pub use_conv_bias: bool, + pub gating: Option, + pub norm: super::NormType, + pub context: usize, + pub max_period: usize, + pub max_seq_len: usize, + + pub kv_repeat: usize, + pub dim_feedforward: usize, + pub conv_layout: bool, +} + +#[derive(Debug, Clone)] +pub struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, + span: tracing::Span, +} + +impl RotaryEmbedding { + pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result { + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / theta.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? 
+ .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + span: tracing::span!(tracing::Level::TRACE, "rot"), + }) + } + + pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result { + let _enter = self.span.enter(); + let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?; + let qk_dtype = qk.dtype(); + let c = self.cos.narrow(0, seqlen_offset, seqlen)?; + let s = self.sin.narrow(0, seqlen_offset, seqlen)?; + candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype) + } +} + +#[derive(Debug, Clone)] +pub struct LayerScale { + scale: Tensor, +} + +impl LayerScale { + pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result { + let scale = vb.get(d_model, "scale")?; + Ok(Self { scale }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result { + xs.broadcast_mul(&self.scale) + } +} + +pub(crate) fn get_mask( + size1: usize, + size2: usize, + context: usize, + device: &Device, +) -> Result { + let mask: Vec<_> = (0..size1) + .flat_map(|i| { + (0..size2) + .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)) + }) + .collect(); + Tensor::from_slice(&mask, (size1, size2), device) +} + +#[derive(Debug, Clone)] +pub struct StreamingMultiheadAttention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + out_proj: Linear, + kv_repeat: usize, + num_heads: usize, + context: usize, + neg_inf: Tensor, + rope: Option>, + kv_cache: candle_nn::kv_cache::KvCache, + pos: usize, + use_flash_attn: bool, + span: tracing::Span, +} + +impl StreamingMultiheadAttention { + pub fn new(rope: &Option>, cfg: &Config, vb: VarBuilder) -> Result { + let embed_dim = cfg.d_model; + let num_kv = cfg.num_heads / cfg.kv_repeat; + let kv_dim = num_kv * (embed_dim / cfg.num_heads); + let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?; + let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, 
vb.pp("k_proj"))?; + let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?; + let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?; + Ok(Self { + q_proj, + k_proj, + v_proj, + out_proj, + rope: rope.clone(), + kv_repeat: cfg.kv_repeat, + num_heads: cfg.num_heads, + context: cfg.context, + neg_inf, + kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len), + pos: 0, + use_flash_attn: false, + span: tracing::span!(tracing::Level::TRACE, "mha"), + }) + } + + pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + if self.kv_repeat != 1 { + candle::bail!("only kv-repeat = 1 is supported") + } + let (b, t, hd) = xs.dims3()?; + let head_dim = hd / self.num_heads; + let q = xs + .apply(&self.q_proj)? + .reshape((b, t, self.num_heads, head_dim))?; + let k = xs + .apply(&self.k_proj)? + .reshape((b, t, self.num_heads, head_dim))?; + let v = xs + .apply(&self.v_proj)? + .reshape((b, t, self.num_heads, head_dim))?; + // qk_layer_norm = None + // kv_repeat = 1, otherwise we would need repeat_kv + let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d + let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d + let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d + if let Some(rope) = &self.rope { + q = rope.apply_rotary_emb(&q, self.pos)?; + k = rope.apply_rotary_emb(&k, self.pos)?; + } + + let (k, v) = { + self.pos += k.dim(2)?; + self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)? + }; + // The KV cache keeps all the data at the moment, we want to trim + // down the part that comes from the cache to at most context to + // be coherent with the mask shape we provide. 
+ let k_len = k.dim(2)?; + let k_target_len = t + usize::min(self.context, k_len - t); + let (k, v) = if k_target_len < k_len { + let k = k.narrow(2, k_len - k_target_len, k_target_len)?; + let v = v.narrow(2, k_len - k_target_len, k_target_len)?; + (k, v) + } else { + (k.clone(), v.clone()) + }; + + let xs = if q.dtype() == DType::BF16 && self.use_flash_attn { + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + let softmax_scale = 1f32 / (head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)? + } else { + let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k + let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?; + + let pre_ws = match mask { + None => pre_ws, + Some(mask) => { + let mask = mask.broadcast_left((b, self.num_heads))?; + let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?; + mask.where_cond(&neg_inf, &pre_ws)? + } + }; + + let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k + ws.matmul(&v)? // b,h,t,d + }; + let xs = xs + .transpose(1, 2)? // b,t,h,d + .reshape((b, t, hd))? 
+ .apply(&self.out_proj)?; + Ok(xs) + } + + pub fn reset_kv_cache(&mut self) { + self.kv_cache.reset() + } + + pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) { + self.kv_cache = kv_cache + } +} + +#[derive(Debug, Clone)] +pub struct StreamingMultiheadCrossAttention { + in_proj_q: Linear, + in_proj_k: Linear, + in_proj_v: Linear, + out_proj: Linear, + kv_repeat: usize, + num_heads: usize, + neg_inf: Tensor, + span: tracing::Span, +} + +impl StreamingMultiheadCrossAttention { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_dim = cfg.d_model; + let num_kv = cfg.num_heads / cfg.kv_repeat; + let kv_dim = num_kv * (embed_dim / cfg.num_heads); + let out_dim = embed_dim + 2 * kv_dim; + let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?; + let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?; + let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?; + let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?; + let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn { + let b = vb.get(out_dim, "in_proj_bias")?; + let q = b.narrow(0, 0, embed_dim)?; + let k = b.narrow(0, embed_dim, kv_dim)?; + let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?; + (Some(q), Some(k), Some(v)) + } else { + (None, None, None) + }; + let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q); + let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k); + let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v); + let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?; + Ok(Self { + in_proj_q, + in_proj_k, + in_proj_v, + out_proj, + kv_repeat: cfg.kv_repeat, + num_heads: cfg.num_heads, + neg_inf, + span: tracing::span!(tracing::Level::TRACE, "mhca"), + }) + } + + pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result { + let _enter = 
self.span.enter(); + if self.kv_repeat != 1 { + candle::bail!("only kv-repeat = 1 is supported") + } + let (b, t, hd) = xs.dims3()?; + let head_dim = hd / self.num_heads; + // time_dim = 1, layout: b,t,h,d + let q = xs.apply(&self.in_proj_q)?; + let k = ca_src.apply(&self.in_proj_k)?; + let v = ca_src.apply(&self.in_proj_v)?; + let (ca_b, ca_t, ca_dim) = k.dims3()?; + let q = q.reshape((b, t, self.num_heads, head_dim))?; + let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?; + let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?; + // qk_layer_norm = None + // kv_repeat = 1, otherwise we would need repeat_kv + let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d + let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d + let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d + + let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k + let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?; + + let pre_ws = match mask { + None => pre_ws, + Some(mask) => { + let mask = mask.broadcast_left((b, self.num_heads))?; + let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?; + mask.where_cond(&neg_inf, &pre_ws)? + } + }; + + let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k + let xs = ws.matmul(&v)?; // b,h,t,d + let xs = xs + .transpose(1, 2)? // b,t,h,d + .reshape((b, t, hd))? 
+ .apply(&self.out_proj)?; + Ok(xs) + } +} + +#[derive(Debug, Clone)] +pub enum Mlp { + NoGating { + span1: tracing::Span, + linear1: Linear, + span2: tracing::Span, + linear2: Linear, + span: tracing::Span, + }, + Gating { + linear_in: Linear, + linear_out: Linear, + activation: candle_nn::Activation, + span: tracing::Span, + }, +} + +impl Mlp { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let d_model = cfg.d_model; + let span = tracing::span!(tracing::Level::TRACE, "mlp"); + + match cfg.gating { + None => { + let span1 = tracing::span!(tracing::Level::TRACE, "lin1"); + let span2 = tracing::span!(tracing::Level::TRACE, "lin2"); + let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?; + let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?; + Ok(Self::NoGating { + linear1, + linear2, + span, + span1, + span2, + }) + } + Some(activation) => { + let vb = vb.pp("gating"); + let hidden = if cfg.dim_feedforward == 4 * d_model { + 11 * d_model / 4 + } else { + 2 * cfg.dim_feedforward / 3 + }; + // TODO: Maybe use bias_ff here? + let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?; + let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?; + Ok(Self::Gating { + linear_in, + linear_out, + activation, + span, + }) + } + } + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::NoGating { + linear1, + linear2, + span, + span1, + span2, + } => { + let _enter = span.enter(); + let xs = { + let _enter = span1.enter(); + xs.apply(linear1)? + }; + let xs = xs.gelu_erf()?; + { + let _enter = span2.enter(); + xs.apply(linear2) + } + } + Self::Gating { + linear_in, + linear_out, + activation, + span, + } => { + let _enter = span.enter(); + let xs = xs.apply(linear_in)?; + let (b, t, _) = xs.dims3()?; + let xs = xs.reshape((b, t, 2, ()))?; + let xs = (xs.i((.., .., 0))?.apply(activation)? 
* xs.i((.., .., 1))?)?; + xs.apply(linear_out) + } + } + } +} + +#[derive(Debug, Clone)] +pub struct RmsNorm { + pub(crate) alpha: Tensor, + pub(crate) eps: f32, +} + +impl RmsNorm { + pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result { + let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?; + Ok(Self { alpha, eps }) + } +} + +impl Module for RmsNorm { + fn forward(&self, xs: &Tensor) -> Result { + candle_nn::ops::rms_norm(xs, &self.alpha, self.eps) + } +} + +#[derive(Debug, Clone)] +pub enum Norm { + LayerNorm(candle_nn::LayerNorm), + RmsNorm(RmsNorm), +} + +impl Norm { + pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result { + let norm = match cfg.norm { + super::NormType::LayerNorm => { + let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?; + Self::LayerNorm(norm) + } + super::NormType::RmsNorm => { + let norm = RmsNorm::new(d_model, 1e-8, vb)?; + Self::RmsNorm(norm) + } + }; + Ok(norm) + } +} + +impl Module for Norm { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::LayerNorm(m) => m.forward(xs), + Self::RmsNorm(m) => m.forward(xs), + } + } +} + +#[derive(Debug, Clone)] +pub struct StreamingTransformerLayer { + self_attn: StreamingMultiheadAttention, + mlp: Mlp, + norm1: Norm, + norm2: Norm, + layer_scale_1: Option, + layer_scale_2: Option, + cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>, + norm_first: bool, + span: tracing::Span, +} + +impl StreamingTransformerLayer { + pub fn new(rope: &Option>, cfg: &Config, vb: VarBuilder) -> Result { + if cfg.use_conv_block { + candle::bail!("conv-block is not supported") + } + let d_model = cfg.d_model; + let mlp = Mlp::new(cfg, vb.clone())?; + let (norm1, norm2) = match cfg.norm { + super::NormType::LayerNorm => { + let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?; + let norm2 = + candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?; + (Norm::LayerNorm(norm1), Norm::LayerNorm(norm2)) + } 
+ super::NormType::RmsNorm => { + let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?; + let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?; + (Norm::RmsNorm(norm1), Norm::RmsNorm(norm2)) + } + }; + let layer_scale_1 = match cfg.layer_scale { + None => None, + Some(ls) => { + let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?; + Some(ls) + } + }; + let layer_scale_2 = match cfg.layer_scale { + None => None, + Some(ls) => { + let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?; + Some(ls) + } + }; + let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?; + let cross_attn = if cfg.cross_attention { + let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?; + let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?; + Some((norm_cross, cross_attn)) + } else { + None + }; + Ok(Self { + self_attn, + mlp, + norm1, + norm2, + layer_scale_1, + layer_scale_2, + cross_attn, + norm_first: cfg.norm_first, + span: tracing::span!(tracing::Level::TRACE, "transformer-layer"), + }) + } + + pub fn forward( + &mut self, + xs: &Tensor, + ca_src: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result { + let _enter = self.span.enter(); + if !self.norm_first { + candle::bail!("only norm_first = true is supported") + } + let norm1 = xs.apply(&self.norm1)?; + let xs = (xs + + self + .self_attn + .forward(&norm1, mask)? + .apply(&self.layer_scale_1.as_ref())?)?; + + let xs = match (&self.cross_attn, ca_src) { + (Some((norm_cross, cross_attn)), Some(ca_src)) => { + let residual = &xs; + let xs = xs.apply(norm_cross)?; + (residual + cross_attn.forward(&xs, ca_src, None)?)? + } + _ => xs, + }; + + let xs = (&xs + + xs.apply(&self.norm2)? + .apply(&self.mlp)? 
+ .apply(&self.layer_scale_2.as_ref()))?; + Ok(xs) + } + + pub fn reset_kv_cache(&mut self) { + self.self_attn.reset_kv_cache() + } + + pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) { + self.self_attn.set_kv_cache(kv_cache) + } +} + +#[derive(Debug, Clone)] +pub struct StreamingTransformer { + layers: Vec, + context: usize, + positional_embedding: PositionalEmbedding, + max_period: usize, +} + +impl StreamingTransformer { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_l = vb.pp("layers"); + let rope = match cfg.positional_embedding { + PositionalEmbedding::Rope => { + let rope = RotaryEmbedding::new( + cfg.d_model / cfg.num_heads, + cfg.max_seq_len, + cfg.max_period as f32, + vb.device(), + )?; + Some(Arc::new(rope)) + } + PositionalEmbedding::Sin | PositionalEmbedding::None => None, + }; + let mut layers = Vec::with_capacity(cfg.num_layers); + for layer_idx in 0..cfg.num_layers { + let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + Ok(Self { + layers, + context: cfg.context, + positional_embedding: cfg.positional_embedding, + max_period: cfg.max_period, + }) + } + + pub fn forward(&mut self, xs: &Tensor) -> Result { + self.forward_ca(xs, None) + } + + pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result { + let (_b, t, c) = xs.dims3()?; + // We will extract at most "context" from the kv_cache. + // Note that the mask will discard the values that are before context. + let pos = self.layers[0] + .self_attn + .kv_cache + .k_cache() + .current_seq_len() + .min(self.context); + let mask = if t == 1 { + None + } else { + Some(get_mask(t, pos + t, self.context, xs.device())?) 
+ }; + let mut xs = match self.positional_embedding { + PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(), + PositionalEmbedding::Sin => { + let dev = xs.device(); + let theta = self.max_period as f32; + let half_dim = c / 2; + let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)? + .unsqueeze(1)? + .to_dtype(DType::F32)?; + let inv_freq: Vec<_> = (0..half_dim) + .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let freqs = positions.broadcast_mul(&inv_freq)?; + let pos_emb = + Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?; + xs.broadcast_add(&pos_emb)? + } + }; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, ca_src, mask.as_ref())?; + } + Ok(xs) + } + + pub fn copy_state(&mut self, from: &Self) -> Result<()> { + if self.layers.len() != from.layers.len() { + candle::bail!("cannot copy kv-caches as the transformers have different depths") + } + self.layers + .iter_mut() + .zip(from.layers.iter()) + .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone())); + Ok(()) + } +} + +impl StreamingModule for StreamingTransformer { + fn reset_state(&mut self) { + self.layers.iter_mut().for_each(|v| v.reset_kv_cache()) + } + + fn step(&mut self, xs: &StreamTensor) -> Result { + match xs.as_option() { + None => Ok(StreamTensor::empty()), + Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)), + } + } +} + +#[derive(Debug, Clone)] +pub struct ProjectedTransformer { + transformer: StreamingTransformer, + input_proj: Option, + output_projs: Vec>, + conv_layout: bool, + span: tracing::Span, +} + +impl ProjectedTransformer { + pub fn new( + input_dim: usize, + output_dims: &[usize], + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let transformer = StreamingTransformer::new(cfg, vb.clone())?; + let input_proj = if input_dim == cfg.d_model { + None + } 
else { + let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?; + Some(l) + }; + let mut output_projs = Vec::with_capacity(output_dims.len()); + let vb_o = vb.pp("output_projs"); + for (i, &output_dim) in output_dims.iter().enumerate() { + let output_proj = if output_dim == cfg.d_model { + None + } else { + let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?; + Some(l) + }; + output_projs.push(output_proj) + } + Ok(Self { + transformer, + input_proj, + output_projs, + conv_layout: cfg.conv_layout, + span: tracing::span!(tracing::Level::TRACE, "proj-transformer"), + }) + } + + pub fn forward(&mut self, xs: &Tensor) -> Result> { + let _enter = self.span.enter(); + let xs = if self.conv_layout { + xs.transpose(1, 2)? + } else { + xs.clone() + }; + let xs = xs.apply(&self.input_proj.as_ref())?; + let xs = self.transformer.forward(&xs)?; + let mut ys = Vec::with_capacity(self.output_projs.len()); + for output_proj in self.output_projs.iter() { + let ys_ = xs.apply(&output_proj.as_ref())?; + let ys_ = if self.conv_layout { + ys_.transpose(1, 2)? 
+ } else { + ys_ + }; + ys.push(ys_) + } + Ok(ys) + } +} + +impl StreamingModule for ProjectedTransformer { + fn reset_state(&mut self) { + self.transformer.reset_state() + } + + fn step(&mut self, xs: &StreamTensor) -> Result { + let xs = xs.apply(&|x: &Tensor| { + if self.conv_layout { + x.transpose(1, 2) + } else { + Ok(x.clone()) + } + })?; + let xs = xs.apply(&self.input_proj.as_ref())?; + let xs = self.transformer.step(&xs)?; + let ys = xs.apply(&self.output_projs[0].as_ref())?; + ys.apply(&|y: &Tensor| { + if self.conv_layout { + y.transpose(1, 2) + } else { + Ok(y.clone()) + } + }) + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 9f7856ea20..07672bcc33 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -33,6 +33,7 @@ pub mod llava; pub mod mamba; pub mod marian; pub mod metavoice; +pub mod mimi; pub mod mistral; pub mod mixformer; pub mod mixtral;