diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs
index eded9f5787..c1004c6c6c 100644
--- a/candle-core/benches/benchmarks/affine.rs
+++ b/candle-core/benches/benchmarks/affine.rs
@@ -12,7 +12,7 @@ fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
     let m = 1024;
     let k = 1024;
 
-    let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+    let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
 
     let flops = b * m * k * dtype.size_in_bytes();
 
diff --git a/candle-core/benches/benchmarks/qmatmul.rs b/candle-core/benches/benchmarks/qmatmul.rs
index ccb136ac1e..4d34588b36 100644
--- a/candle-core/benches/benchmarks/qmatmul.rs
+++ b/candle-core/benches/benchmarks/qmatmul.rs
@@ -7,7 +7,7 @@ use criterion::{black_box, criterion_group, Criterion, Throughput};
 use std::time::Instant;
 
 fn run(matmul: &QMatMul, x: &Tensor) {
-    matmul.forward(&x).unwrap();
+    matmul.forward(x).unwrap();
 }
 
 fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
@@ -50,7 +50,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
 fn criterion_benchmark(c: &mut Criterion) {
     let handler = BenchDeviceHandler::new().unwrap();
     for device in handler.devices {
-        for dtype in vec![
+        for dtype in [
             GgmlDType::F32,
             GgmlDType::F16,
             GgmlDType::Q4_0,
diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs
index a8e0d02500..9efd75093d 100644
--- a/candle-core/benches/benchmarks/unary.rs
+++ b/candle-core/benches/benchmarks/unary.rs
@@ -12,7 +12,7 @@ fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &
     let m = 1024;
     let k = 1024;
 
-    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
+    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
         .unwrap()
         .to_dtype(dtype)
         .unwrap()
diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs
index c517dcf590..0e91f656fc 100644
--- a/candle-core/benches/benchmarks/where_cond.rs
+++ b/candle-core/benches/benchmarks/where_cond.rs
@@ -25,9 +25,9 @@ const SIZE: usize = B * M * K;
 const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
 
 fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
-    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
-    let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
-    let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
+    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
+    let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
+    let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
 
     let elements = B * M * K;
     // E.g. 2 f32 tensors + 1 u8 tensor
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index dd1b44b0a0..82532f204f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -590,9 +590,9 @@ impl Tensor {
     ///
     /// * `args` - A slice of 1D tensors.
     /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
-    /// first dimension corresponds to the cardinality of the second input and the second
-    /// dimension corresponds to the cardinality of the first input. If ij is selected, the
-    /// dimensions are in the same order as the cardinality of the inputs.
+    ///   first dimension corresponds to the cardinality of the second input and the second
+    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
+    ///   dimensions are in the same order as the cardinality of the inputs.
     ///
     /// # Examples
     ///
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index fa5c620a48..56e3d535de 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -35,7 +35,7 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 symphonia = { version = "0.5.3", features = ["all"], optional = true }
 tokenizers = { workspace = true, features = ["onig"] }
-cpal= { version = "0.15.2", optional = true }
+cpal = { version = "0.15.2", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index fa7ce81b59..93f1e50825 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -32,7 +32,9 @@ enum Which {
     V1,
     V2,
     V3,
+    V31,
     V3Instruct,
+    V31Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -133,6 +135,8 @@ fn main() -> Result<()> {
         Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
         Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
         Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+        Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+        Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
         Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
         Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
     });
@@ -146,7 +150,13 @@ fn main() -> Result<()> {
     let config = config.into_config(args.use_flash_attn);
 
     let filenames = match args.which {
-        Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
+        Which::V1
+        | Which::V2
+        | Which::V3
+        | Which::V3Instruct
+        | Which::V31
+        | Which::V31Instruct
+        | Which::Solar10_7B => {
             candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
         }
         Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -157,9 +167,11 @@ fn main() -> Result<()> {
         (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = config
-        .eos_token_id
-        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::LlamaEosToks::Single)
+    });
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -217,8 +229,14 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);
 
-        if Some(next_token) == eos_token_id {
-            break;
+        match eos_token_id {
+            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
         }
         if let Some(t) = tokenizer.next_token(next_token)? {
             print!("{t}");
diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs
index 331e712c69..1892acdd68 100644
--- a/candle-examples/examples/yolo-v3/darknet.rs
+++ b/candle-examples/examples/yolo-v3/darknet.rs
@@ -272,7 +272,7 @@ impl Darknet {
         let mut prev_channels: usize = 3;
         for (index, block) in self.blocks.iter().enumerate() {
             let channels_and_bl = match block.block_type.as_str() {
-                "convolutional" => conv(vb.pp(&index.to_string()), index, prev_channels, block)?,
+                "convolutional" => conv(vb.pp(index.to_string()), index, prev_channels, block)?,
                 "upsample" => upsample(prev_channels)?,
                 "shortcut" => shortcut(index, prev_channels, block)?,
                 "route" => route(index, &blocks, block)?,
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index b9745375cd..fc1819f5e0 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -93,9 +93,9 @@ impl candle::Module for PReLU {
 /// # Arguments
 ///
 /// * `num_channels` - The number of channels. Use `None` to have as single trainable value and
-/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
-/// function, the input tensor shape `s` should either be one dimension with this number of
-/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
+///   `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
+///   function, the input tensor shape `s` should either be one dimension with this number of
+///   channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
 pub fn prelu(num_channels: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> {
     let init_ws = crate::init::Init::Const(0.25);
     // When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1.
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index d6f6214faf..f6e6160bd2 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -264,6 +264,7 @@ impl SimpleBackend for VarMap {
     }
 }
 
+#[allow(dead_code)]
 pub struct SafeTensorWithRouting<'a> {
     routing: HashMap<String, usize>,
     safetensors: Vec<SafeTensors<'a>>,
diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs
index 62bdd75adc..8f6284a8e6 100644
--- a/candle-transformers/src/models/beit.rs
+++ b/candle-transformers/src/models/beit.rs
@@ -288,7 +288,7 @@ impl BeitVisionTransformer {
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs
index 4e4b4c9084..51db14ee0c 100644
--- a/candle-transformers/src/models/clip/text_model.rs
+++ b/candle-transformers/src/models/clip/text_model.rs
@@ -249,7 +249,7 @@ impl ClipEncoder {
         let vs = vs.pp("layers");
         let mut layers: Vec<ClipEncoderLayer> = Vec::new();
         for index in 0..c.num_hidden_layers() {
-            let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
             layers.push(layer)
         }
         Ok(ClipEncoder { layers })
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
index 00e501ce0d..706dfda0e7 100644
--- a/candle-transformers/src/models/dinov2.rs
+++ b/candle-transformers/src/models/dinov2.rs
@@ -214,7 +214,7 @@ impl DinoVisionTransformer {
         let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
            .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs
index 6bbe2e2410..1d81703c9c 100644
--- a/candle-transformers/src/models/dinov2reg4.rs
+++ b/candle-transformers/src/models/dinov2reg4.rs
@@ -212,7 +212,7 @@ impl DinoVisionTransformer {
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs
index 14a85d3ef9..fb70fb52f4 100644
--- a/candle-transformers/src/models/encodec.rs
+++ b/candle-transformers/src/models/encodec.rs
@@ -571,7 +571,7 @@ impl<'a> Layer<'a> {
     }
 
     fn next(&mut self) -> VarBuilder {
-        let vb = self.vb.pp(&self.cnt.to_string());
+        let vb = self.vb.pp(self.cnt.to_string());
         self.cnt += 1;
         vb
     }
diff --git a/candle-transformers/src/models/eva2.rs b/candle-transformers/src/models/eva2.rs
index eb2df4cd0d..013c385d1c 100644
--- a/candle-transformers/src/models/eva2.rs
+++ b/candle-transformers/src/models/eva2.rs
@@ -255,14 +255,7 @@ impl EVA2VisionTransformer {
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| {
-                Block::new(
-                    vb_b.pp(&i.to_string()),
-                    embed_dim,
-                    num_heads,
-                    &rot_pos_embed,
-                )
-            })
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a1f43d35b8..3681472be8 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,9 +1,33 @@
 use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use std::collections::HashMap;
+use std::{collections::HashMap, f32::consts::PI};
 
-pub const MAX_SEQ_LEN: usize = 4096;
+pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub enum Llama3RopeType {
+    #[serde(rename = "llama3")]
+    Llama3,
+    #[default]
+    #[serde(rename = "default")]
+    Default,
+}
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub struct Llama3RopeConfig {
+    pub factor: f32,
+    pub low_freq_factor: f32,
+    pub high_freq_factor: f32,
+    pub original_max_position_embeddings: usize,
+    pub rope_type: Llama3RopeType,
+}
+#[derive(Debug, Clone, serde::Deserialize)]
+#[serde(untagged)]
+pub enum LlamaEosToks {
+    Single(u32),
+    Multiple(Vec<u32>),
+}
 
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct LlamaConfig {
@@ -17,7 +41,9 @@ pub struct LlamaConfig {
     #[serde(default = "default_rope")]
     pub rope_theta: f32,
     pub bos_token_id: Option<u32>,
-    pub eos_token_id: Option<u32>,
+    pub eos_token_id: Option<LlamaEosToks>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
+    pub max_position_embeddings: usize,
 }
 
 impl LlamaConfig {
@@ -44,6 +70,8 @@ impl LlamaConfig {
             use_flash_attn,
             bos_token_id: self.bos_token_id,
             eos_token_id: self.eos_token_id,
+            rope_scaling: self.rope_scaling,
+            max_position_embeddings: self.max_position_embeddings,
         }
     }
 }
@@ -60,7 +88,9 @@ pub struct Config {
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
     pub bos_token_id: Option<u32>,
-    pub eos_token_id: Option<u32>,
+    pub eos_token_id: Option<LlamaEosToks>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
+    pub max_position_embeddings: usize,
 }
 
 impl Config {
@@ -77,6 +107,8 @@ impl Config {
             rope_theta: 10_000.0,
             bos_token_id: None,
             eos_token_id: None,
+            rope_scaling: None,
+            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
         }
     }
 
@@ -93,6 +125,8 @@ impl Config {
             rope_theta: 10_000.0,
             bos_token_id: None,
             eos_token_id: None,
+            rope_scaling: None,
+            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
         }
     }
 }
@@ -107,18 +141,54 @@ pub struct Cache {
     device: Device,
 }
 
+fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
+    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
 impl Cache {
     pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
-        let n_elem = config.hidden_size / config.num_attention_heads;
-        let theta: Vec<_> = (0..n_elem)
-            .step_by(2)
-            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
-            .collect();
-        let theta = Tensor::new(theta.as_slice(), device)?;
-        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+        let theta = match &config.rope_scaling {
+            None
+            | Some(Llama3RopeConfig {
+                rope_type: Llama3RopeType::Default,
+                ..
+            }) => calculate_default_inv_freq(config),
+            Some(rope_scaling) => {
+                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.low_freq_factor;
+                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.high_freq_factor;
+
+                calculate_default_inv_freq(config)
+                    .into_iter()
+                    .map(|freq| {
+                        let wavelen = 2. * PI / freq;
+                        if wavelen < high_freq_wavelen {
+                            freq
+                        } else if wavelen > low_freq_wavelen {
+                            freq / rope_scaling.factor
+                        } else {
+                            let smooth = (rope_scaling.original_max_position_embeddings as f32
+                                / wavelen
+                                - rope_scaling.low_freq_factor)
+                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
+                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            }
+        };
+
+        let theta = Tensor::new(theta, device)?;
+
+        let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
             .to_dtype(DType::F32)?
-            .reshape((MAX_SEQ_LEN, 1))?
+            .reshape((config.max_position_embeddings, 1))?
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
@@ -160,6 +230,7 @@ struct CausalSelfAttention {
     use_flash_attn: bool,
     span: tracing::Span,
     span_rot: tracing::Span,
+    max_position_embeddings: usize,
 }
 
 #[cfg(feature = "flash-attn")]
@@ -220,15 +291,23 @@ impl CausalSelfAttention {
                 k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
                 v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
-                if k_seq_len > MAX_SEQ_LEN {
+                if k_seq_len > self.max_position_embeddings {
                     k = k
-                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(
+                            D::Minus1,
+                            k_seq_len - self.max_position_embeddings,
+                            self.max_position_embeddings,
+                        )?
                         .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
-                if v_seq_len > 2 * MAX_SEQ_LEN {
+                if v_seq_len > 2 * self.max_position_embeddings {
                     v = v
-                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(
+                            D::Minus1,
+                            v_seq_len - self.max_position_embeddings,
+                            self.max_position_embeddings,
+                        )?
                         .contiguous()?
                 }
             }
@@ -291,6 +370,7 @@ impl CausalSelfAttention {
             use_flash_attn: cfg.use_flash_attn,
             span,
             span_rot,
+            max_position_embeddings: cfg.max_position_embeddings,
         })
     }
 }
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
index d2d47003ec..5dca68704e 100644
--- a/candle-transformers/src/models/llava/config.rs
+++ b/candle-transformers/src/models/llava/config.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 
 use crate::models::{
     clip::{text_model::Activation, vision_model::ClipVisionConfig},
-    llama::Config,
+    llama::{Config, LlamaEosToks},
 };
 use serde::{Deserialize, Serialize};
 
@@ -73,8 +73,10 @@ impl LLaVAConfig {
             rms_norm_eps: self.rms_norm_eps as f64,
             rope_theta: self.rope_theta,
             bos_token_id: Some(self.bos_token_id as u32),
-            eos_token_id: Some(self.eos_token_id as u32),
+            eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
             use_flash_attn: false,
+            rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
+            max_position_embeddings: self.max_position_embeddings,
         }
     }
 }
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 4d5a7c47af..5cc59e8203 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -358,7 +358,7 @@ impl SpatialTransformer {
         let vs_tb = vs.pp("transformer_blocks");
         for index in 0..config.depth {
             let tb = BasicTransformerBlock::new(
-                vs_tb.pp(&index.to_string()),
+                vs_tb.pp(index.to_string()),
                 inner_dim,
                 n_heads,
                 d_head,
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 20e8ceaca4..5254818e60 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -322,7 +322,7 @@ impl ClipEncoder {
         let vs = vs.pp("layers");
         let mut layers: Vec<ClipEncoderLayer> = Vec::new();
         for index in 0..c.num_hidden_layers {
-            let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
             layers.push(layer)
         }
         Ok(ClipEncoder { layers })
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
index f23bd42597..cbef3316dd 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
@@ -161,7 +161,7 @@ impl UNet2DConditionModel {
                     transformer_layers_per_block,
                 };
                 let block = CrossAttnDownBlock2D::new(
-                    vs_db.pp(&i.to_string()),
+                    vs_db.pp(i.to_string()),
                     in_channels,
                     out_channels,
                     Some(time_embed_dim),
@@ -171,7 +171,7 @@ impl UNet2DConditionModel {
                 Ok(UNetDownBlock::CrossAttn(block))
             } else {
                 let block = DownBlock2D::new(
-                    vs_db.pp(&i.to_string()),
+                    vs_db.pp(i.to_string()),
                     in_channels,
                     out_channels,
                     Some(time_embed_dim),
@@ -251,7 +251,7 @@ impl UNet2DConditionModel {
                     transformer_layers_per_block,
                 };
                 let block = CrossAttnUpBlock2D::new(
-                    vs_ub.pp(&i.to_string()),
+                    vs_ub.pp(i.to_string()),
                     in_channels,
                     prev_out_channels,
                     out_channels,
@@ -262,7 +262,7 @@ impl UNet2DConditionModel {
                 Ok(UNetUpBlock::CrossAttn(block))
             } else {
                 let block = UpBlock2D::new(
-                    vs_ub.pp(&i.to_string()),
+                    vs_ub.pp(i.to_string()),
                     in_channels,
                     prev_out_channels,
                     out_channels,
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
index 18448427d6..028c51b744 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
@@ -146,7 +146,7 @@ impl DownEncoderBlock2D {
             (0..(config.num_layers))
                 .map(|i| {
                     let in_channels = if i == 0 { in_channels } else { out_channels };
-                    ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+                    ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
                 })
                 .collect::<Result<Vec<_>>>()?
         };
@@ -235,7 +235,7 @@ impl UpDecoderBlock2D {
             (0..(config.num_layers))
                 .map(|i| {
                     let in_channels = if i == 0 { in_channels } else { out_channels };
-                    ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+                    ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
                 })
                 .collect::<Result<Vec<_>>>()?
         };
@@ -328,9 +328,9 @@ impl UNetMidBlock2D {
         };
         let mut attn_resnets = vec![];
         for index in 0..config.num_layers {
-            let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+            let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?;
             let resnet = ResnetBlock2D::new(
-                vs_resnets.pp(&(index + 1).to_string()),
+                vs_resnets.pp((index + 1).to_string()),
                 in_channels,
                 resnet_cfg,
             )?;
@@ -425,7 +425,7 @@ impl UNetMidBlock2DCrossAttn {
         let mut attn_resnets = vec![];
         for index in 0..config.num_layers {
             let attn = SpatialTransformer::new(
-                vs_attns.pp(&index.to_string()),
+                vs_attns.pp(index.to_string()),
                 in_channels,
                 n_heads,
                 in_channels / n_heads,
@@ -433,7 +433,7 @@ impl UNetMidBlock2DCrossAttn {
                 attn_cfg,
             )?;
             let resnet = ResnetBlock2D::new(
-                vs_resnets.pp(&(index + 1).to_string()),
+                vs_resnets.pp((index + 1).to_string()),
                 in_channels,
                 resnet_cfg,
             )?;
@@ -515,7 +515,7 @@ impl DownBlock2D {
         let resnets = (0..config.num_layers)
             .map(|i| {
                 let in_channels = if i == 0 { in_channels } else { out_channels };
-                ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+                ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
             })
             .collect::<Result<Vec<_>>>()?;
         let downsampler = if config.add_downsample {
@@ -619,7 +619,7 @@ impl CrossAttnDownBlock2D {
         let attentions = (0..config.downblock.num_layers)
             .map(|i| {
                 SpatialTransformer::new(
-                    vs_attn.pp(&i.to_string()),
+                    vs_attn.pp(i.to_string()),
                     out_channels,
                     n_heads,
                     out_channels / n_heads,
@@ -724,7 +724,7 @@ impl UpBlock2D {
                     out_channels
                 };
                 let in_channels = resnet_in_channels + res_skip_channels;
-                ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+                ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
             })
             .collect::<Result<Vec<_>>>()?;
         let upsampler = if config.add_upsample {
@@ -826,7 +826,7 @@ impl CrossAttnUpBlock2D {
         let attentions = (0..config.upblock.num_layers)
             .map(|i| {
                 SpatialTransformer::new(
-                    vs_attn.pp(&i.to_string()),
+                    vs_attn.pp(i.to_string()),
                     out_channels,
                     n_heads,
                     out_channels / n_heads,
diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs
index 21709afe8b..670b3f5638 100644
--- a/candle-transformers/src/models/stable_diffusion/vae.rs
+++ b/candle-transformers/src/models/stable_diffusion/vae.rs
@@ -80,7 +80,7 @@ impl Encoder {
                 ..Default::default()
             };
             let down_block = DownEncoderBlock2D::new(
-                vs_down_blocks.pp(&index.to_string()),
+                vs_down_blocks.pp(index.to_string()),
                 in_channels,
                 out_channels,
                 cfg,
@@ -222,7 +222,7 @@ impl Decoder {
                 ..Default::default()
             };
             let up_block = UpDecoderBlock2D::new(
-                vs_up_blocks.pp(&index.to_string()),
+                vs_up_blocks.pp(index.to_string()),
                 in_channels,
                 out_channels,
                 cfg,
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 8a7a8955b6..21517d64b5 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -601,7 +601,7 @@ impl T5Block {
             None
         };
         let ff_i = if cross_attn.is_some() { 2 } else { 1 };
-        let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
+        let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
         Ok(Self {
             self_attn,
             cross_attn,