From 10d47183c088ce449da13d74f07171c8106cd6dd Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 26 Sep 2024 10:23:43 +0200 Subject: [PATCH] Quantized version of flux. (#2500) * Quantized version of flux. * More generic sampling. * Hook the quantized model. * Use the newly minted gguf file. * Fix for the quantized model. * Default to avoid the faster cuda kernels. --- candle-examples/examples/flux/README.md | 2 +- candle-examples/examples/flux/main.rs | 83 +++- candle-transformers/src/models/flux/mod.rs | 17 + candle-transformers/src/models/flux/model.rs | 10 +- .../src/models/flux/quantized_model.rs | 465 ++++++++++++++++++ .../src/models/flux/sampling.rs | 4 +- 6 files changed, 555 insertions(+), 26 deletions(-) create mode 100644 candle-transformers/src/models/flux/quantized_model.rs diff --git a/candle-examples/examples/flux/README.md b/candle-examples/examples/flux/README.md index 528f058e38..dfc8ad5f8c 100644 --- a/candle-examples/examples/flux/README.md +++ b/candle-examples/examples/flux/README.md @@ -13,7 +13,7 @@ descriptions, ```bash cargo run --features cuda --example flux -r -- \ - --height 1024 --width 1024 + --height 1024 --width 1024 \ --prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k" ``` diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index 539ae6f260..24b1fa2bc6 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -23,6 +23,10 @@ struct Args { #[arg(long)] cpu: bool, + /// Use the quantized model. + #[arg(long)] + quantized: bool, + /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, @@ -40,6 +44,10 @@ struct Args { #[arg(long, value_enum, default_value = "schnell")] model: Model, + + /// Use the faster kernels which are buggy at the moment. + #[arg(long)] + no_dmmv: bool, } #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] @@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> { tracing, decode_only, model, + quantized, + .. } = args; let width = width.unwrap_or(1360); let height = height.unwrap_or(768); @@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> { }; println!("CLIP\n{clip_emb}"); let img = { - let model_file = match model { - Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?, - Model::Dev => bf_repo.get("flux1-dev.safetensors")?, - }; - let vb = - unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; let cfg = match model { Model::Dev => flux::model::Config::dev(), Model::Schnell => flux::model::Config::schnell(), }; let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; - let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; + let state = if quantized { + flux::sampling::State::new( + &t5_emb.to_dtype(candle::DType::F32)?, + &clip_emb.to_dtype(candle::DType::F32)?, + &img.to_dtype(candle::DType::F32)?, + )? + } else { + flux::sampling::State::new(&t5_emb, &clip_emb, &img)? + }; let timesteps = match model { Model::Dev => { flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15))) } Model::Schnell => flux::sampling::get_schedule(4, None), }; - let model = flux::model::Flux::new(&cfg, vb)?; - println!("{state:?}"); println!("{timesteps:?}"); - flux::sampling::denoise( - &model, - &state.img, - &state.img_ids, - &state.txt, - &state.txt_ids, - &state.vec, - ×teps, - 4., - )? + if quantized { + let model_file = match model { + Model::Schnell => api + .repo(hf_hub::Repo::model("lmz/candle-flux".to_string())) + .get("flux1-schnell.gguf")?, + Model::Dev => todo!(), + }; + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + model_file, &device, + )?; + + let model = flux::quantized_model::Flux::new(&cfg, vb)?; + flux::sampling::denoise( + &model, + &state.img, + &state.img_ids, + &state.txt, + &state.txt_ids, + &state.vec, + ×teps, + 4., + )? + .to_dtype(dtype)? + } else { + let model_file = match model { + Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?, + Model::Dev => bf_repo.get("flux1-dev.safetensors")?, + }; + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? + }; + let model = flux::model::Flux::new(&cfg, vb)?; + flux::sampling::denoise( + &model, + &state.img, + &state.img_ids, + &state.txt, + &state.txt_ids, + &state.vec, + ×teps, + 4., + )? + } }; flux::sampling::unpack(&img, height, width)? } @@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> { fn main() -> Result<()> { let args = Args::parse(); + #[cfg(feature = "cuda")] + candle::quantized::cuda::set_force_dmmv(!args.no_dmmv); run(args) } diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs index 763fa90da1..b0c8a6939a 100644 --- a/candle-transformers/src/models/flux/mod.rs +++ b/candle-transformers/src/models/flux/mod.rs @@ -1,3 +1,20 @@ +use candle::{Result, Tensor}; + +pub trait WithForward { + #[allow(clippy::too_many_arguments)] + fn forward( + &self, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + timesteps: &Tensor, + y: &Tensor, + guidance: Option<&Tensor>, + ) -> Result; +} + pub mod autoencoder; pub mod model; +pub mod quantized_model; pub mod sampling; diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs index 4e47873fe0..17b4eb2532 100644 --- a/candle-transformers/src/models/flux/model.rs +++ b/candle-transformers/src/models/flux/model.rs @@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result { (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec()) } -fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result { +pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result { let q = apply_rope(q, pe)?.contiguous()?; let k = apply_rope(k, pe)?.contiguous()?; let x = scaled_dot_product_attention(&q, &k, v)?; x.transpose(1, 2)?.flatten_from(2) } -fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result { +pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result { const TIME_FACTOR: f64 = 1000.; const MAX_PERIOD: f64 = 10000.; if dim % 2 == 1 { @@ -144,7 +144,7 @@ pub struct EmbedNd { } impl EmbedNd { - fn new(dim: usize, theta: usize, axes_dim: Vec) -> Self { + pub fn new(dim: usize, theta: usize, axes_dim: Vec) -> Self { Self { dim, theta, @@ -575,9 +575,11 @@ impl Flux { final_layer, }) } +} +impl super::WithForward for Flux { #[allow(clippy::too_many_arguments)] - pub fn forward( + fn forward( &self, img: &Tensor, img_ids: &Tensor, diff --git a/candle-transformers/src/models/flux/quantized_model.rs b/candle-transformers/src/models/flux/quantized_model.rs new file mode 100644 index 0000000000..0efeeab573 --- /dev/null +++ b/candle-transformers/src/models/flux/quantized_model.rs @@ -0,0 +1,465 @@ +use super::model::{attention, timestep_embedding, Config, EmbedNd}; +use crate::quantized_nn::{linear, linear_b, Linear}; +use crate::quantized_var_builder::VarBuilder; +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{LayerNorm, RmsNorm}; + +fn layer_norm(dim: usize, vb: VarBuilder) -> Result { + let ws = Tensor::ones(dim, DType::F32, vb.device())?; + Ok(LayerNorm::new_no_bias(ws, 1e-6)) +} + +#[derive(Debug, Clone)] +pub struct MlpEmbedder { + in_layer: Linear, + out_layer: Linear, +} + +impl MlpEmbedder { + fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result { + let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?; + let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?; + Ok(Self { + in_layer, + out_layer, + }) + } +} + +impl candle::Module for MlpEmbedder { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer) + } +} + +#[derive(Debug, Clone)] +pub struct QkNorm { + query_norm: RmsNorm, + key_norm: RmsNorm, +} + +impl QkNorm { + fn new(dim: usize, vb: VarBuilder) -> Result { + let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?; + let query_norm = RmsNorm::new(query_norm, 1e-6); + let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?; + let key_norm = RmsNorm::new(key_norm, 1e-6); + Ok(Self { + query_norm, + key_norm, + }) + } +} + +struct ModulationOut { + shift: Tensor, + scale: Tensor, + gate: Tensor, +} + +impl ModulationOut { + fn scale_shift(&self, xs: &Tensor) -> Result { + xs.broadcast_mul(&(&self.scale + 1.)?)? + .broadcast_add(&self.shift) + } + + fn gate(&self, xs: &Tensor) -> Result { + self.gate.broadcast_mul(xs) + } +} + +#[derive(Debug, Clone)] +struct Modulation1 { + lin: Linear, +} + +impl Modulation1 { + fn new(dim: usize, vb: VarBuilder) -> Result { + let lin = linear(dim, 3 * dim, vb.pp("lin"))?; + Ok(Self { lin }) + } + + fn forward(&self, vec_: &Tensor) -> Result { + let ys = vec_ + .silu()? + .apply(&self.lin)? + .unsqueeze(1)? + .chunk(3, D::Minus1)?; + if ys.len() != 3 { + candle::bail!("unexpected len from chunk {ys:?}") + } + Ok(ModulationOut { + shift: ys[0].clone(), + scale: ys[1].clone(), + gate: ys[2].clone(), + }) + } +} + +#[derive(Debug, Clone)] +struct Modulation2 { + lin: Linear, +} + +impl Modulation2 { + fn new(dim: usize, vb: VarBuilder) -> Result { + let lin = linear(dim, 6 * dim, vb.pp("lin"))?; + Ok(Self { lin }) + } + + fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> { + let ys = vec_ + .silu()? + .apply(&self.lin)? + .unsqueeze(1)? + .chunk(6, D::Minus1)?; + if ys.len() != 6 { + candle::bail!("unexpected len from chunk {ys:?}") + } + let mod1 = ModulationOut { + shift: ys[0].clone(), + scale: ys[1].clone(), + gate: ys[2].clone(), + }; + let mod2 = ModulationOut { + shift: ys[3].clone(), + scale: ys[4].clone(), + gate: ys[5].clone(), + }; + Ok((mod1, mod2)) + } +} + +#[derive(Debug, Clone)] +pub struct SelfAttention { + qkv: Linear, + norm: QkNorm, + proj: Linear, + num_heads: usize, +} + +impl SelfAttention { + fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result { + let head_dim = dim / num_heads; + let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let proj = linear(dim, dim, vb.pp("proj"))?; + Ok(Self { + qkv, + norm, + proj, + num_heads, + }) + } + + fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let qkv = xs.apply(&self.qkv)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + Ok((q, k, v)) + } + + #[allow(unused)] + fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result { + let (q, k, v) = self.qkv(xs)?; + attention(&q, &k, &v, pe)?.apply(&self.proj) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + lin1: Linear, + lin2: Linear, +} + +impl Mlp { + fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result { + let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?; + let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?; + Ok(Self { lin1, lin2 }) + } +} + +impl candle::Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2) + } +} + +#[derive(Debug, Clone)] +pub struct DoubleStreamBlock { + img_mod: Modulation2, + img_norm1: LayerNorm, + img_attn: SelfAttention, + img_norm2: LayerNorm, + img_mlp: Mlp, + txt_mod: Modulation2, + txt_norm1: LayerNorm, + txt_attn: SelfAttention, + txt_norm2: LayerNorm, + txt_mlp: Mlp, +} + +impl DoubleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?; + let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?; + let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?; + let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?; + let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?; + let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?; + let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?; + let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?; + let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?; + let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?; + Ok(Self { + img_mod, + img_norm1, + img_attn, + img_norm2, + img_mlp, + txt_mod, + txt_norm1, + txt_attn, + txt_norm2, + txt_mlp, + }) + } + + fn forward( + &self, + img: &Tensor, + txt: &Tensor, + vec_: &Tensor, + pe: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate + let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate + let img_modulated = img.apply(&self.img_norm1)?; + let img_modulated = img_mod1.scale_shift(&img_modulated)?; + let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?; + + let txt_modulated = txt.apply(&self.txt_norm1)?; + let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?; + let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?; + + let q = Tensor::cat(&[txt_q, img_q], 2)?; + let k = Tensor::cat(&[txt_k, img_k], 2)?; + let v = Tensor::cat(&[txt_v, img_v], 2)?; + + let attn = attention(&q, &k, &v, pe)?; + let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?; + let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?; + + let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?; + let img = (&img + + img_mod2.gate( + &img_mod2 + .scale_shift(&img.apply(&self.img_norm2)?)? + .apply(&self.img_mlp)?, + )?)?; + + let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?; + let txt = (&txt + + txt_mod2.gate( + &txt_mod2 + .scale_shift(&txt.apply(&self.txt_norm2)?)? + .apply(&self.txt_mlp)?, + )?)?; + + Ok((img, txt)) + } +} + +#[derive(Debug, Clone)] +pub struct SingleStreamBlock { + linear1: Linear, + linear2: Linear, + norm: QkNorm, + pre_norm: LayerNorm, + modulation: Modulation1, + h_sz: usize, + mlp_sz: usize, + num_heads: usize, +} + +impl SingleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let head_dim = h_sz / cfg.num_heads; + let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?; + let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?; + let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?; + Ok(Self { + linear1, + linear2, + norm, + pre_norm, + modulation, + h_sz, + mlp_sz, + num_heads: cfg.num_heads, + }) + } + + fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result { + let mod_ = self.modulation.forward(vec_)?; + let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?; + let x_mod = x_mod.apply(&self.linear1)?; + let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + let attn = attention(&q, &k, &v, pe)?; + let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?; + xs + mod_.gate(&output) + } +} + +#[derive(Debug, Clone)] +pub struct LastLayer { + norm_final: LayerNorm, + linear: Linear, + ada_ln_modulation: Linear, +} + +impl LastLayer { + fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result { + let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?; + let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?; + let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?; + Ok(Self { + norm_final, + linear: linear_, + ada_ln_modulation, + }) + } + + fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result { + let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?; + let (shift, scale) = (&chunks[0], &chunks[1]); + let xs = xs + .apply(&self.norm_final)? + .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)? + .broadcast_add(&shift.unsqueeze(1)?)?; + xs.apply(&self.linear) + } +} + +#[derive(Debug, Clone)] +pub struct Flux { + img_in: Linear, + txt_in: Linear, + time_in: MlpEmbedder, + vector_in: MlpEmbedder, + guidance_in: Option, + pe_embedder: EmbedNd, + double_blocks: Vec, + single_blocks: Vec, + final_layer: LastLayer, +} + +impl Flux { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?; + let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?; + let mut double_blocks = Vec::with_capacity(cfg.depth); + let vb_d = vb.pp("double_blocks"); + for idx in 0..cfg.depth { + let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?; + double_blocks.push(db) + } + let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks); + let vb_s = vb.pp("single_blocks"); + for idx in 0..cfg.depth_single_blocks { + let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?; + single_blocks.push(sb) + } + let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?; + let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?; + let guidance_in = if cfg.guidance_embed { + let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?; + Some(mlp) + } else { + None + }; + let final_layer = + LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?; + let pe_dim = cfg.hidden_size / cfg.num_heads; + let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec()); + Ok(Self { + img_in, + txt_in, + time_in, + vector_in, + guidance_in, + pe_embedder, + double_blocks, + single_blocks, + final_layer, + }) + } +} + +impl super::WithForward for Flux { + #[allow(clippy::too_many_arguments)] + fn forward( + &self, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + timesteps: &Tensor, + y: &Tensor, + guidance: Option<&Tensor>, + ) -> Result { + if txt.rank() != 3 { + candle::bail!("unexpected shape for txt {:?}", txt.shape()) + } + if img.rank() != 3 { + candle::bail!("unexpected shape for img {:?}", img.shape()) + } + let dtype = img.dtype(); + let pe = { + let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; + ids.apply(&self.pe_embedder)? + }; + let mut txt = txt.apply(&self.txt_in)?; + let mut img = img.apply(&self.img_in)?; + let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?; + let vec_ = match (self.guidance_in.as_ref(), guidance) { + (Some(g_in), Some(guidance)) => { + (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))? + } + _ => vec_, + }; + let vec_ = (vec_ + y.apply(&self.vector_in))?; + + // Double blocks + for block in self.double_blocks.iter() { + (img, txt) = block.forward(&img, &txt, &vec_, &pe)? + } + // Single blocks + let mut img = Tensor::cat(&[&txt, &img], 1)?; + for block in self.single_blocks.iter() { + img = block.forward(&img, &vec_, &pe)?; + } + let img = img.i((.., txt.dim(1)?..))?; + self.final_layer.forward(&img, &vec_) + } +} diff --git a/candle-transformers/src/models/flux/sampling.rs b/candle-transformers/src/models/flux/sampling.rs index 89b9a95382..f3f0eafd4b 100644 --- a/candle-transformers/src/models/flux/sampling.rs +++ b/candle-transformers/src/models/flux/sampling.rs @@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result { } #[allow(clippy::too_many_arguments)] -pub fn denoise( - model: &super::model::Flux, +pub fn denoise( + model: &M, img: &Tensor, img_ids: &Tensor, txt: &Tensor,