Skip to content

Commit

Permalink
Quantized version of flux. (huggingface#2500)
Browse files Browse the repository at this point in the history
* Quantized version of flux.

* More generic sampling.

* Hook the quantized model.

* Use the newly minted gguf file.

* Fix for the quantized model.

* Default to avoid the faster cuda kernels.
  • Loading branch information
LaurentMazare committed Sep 26, 2024
1 parent d01207d commit 10d4718
Show file tree
Hide file tree
Showing 6 changed files with 555 additions and 26 deletions.
2 changes: 1 addition & 1 deletion candle-examples/examples/flux/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ descriptions,

```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024
--height 1024 --width 1024 \
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```

83 changes: 64 additions & 19 deletions candle-examples/examples/flux/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ struct Args {
#[arg(long)]
cpu: bool,

/// Use the quantized model.
#[arg(long)]
quantized: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
Expand All @@ -40,6 +44,10 @@ struct Args {

#[arg(long, value_enum, default_value = "schnell")]
model: Model,

/// Use the faster kernels which are buggy at the moment.
#[arg(long)]
no_dmmv: bool,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
Expand All @@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
quantized,
..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
Expand Down Expand Up @@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let state = if quantized {
flux::sampling::State::new(
&t5_emb.to_dtype(candle::DType::F32)?,
&clip_emb.to_dtype(candle::DType::F32)?,
&img.to_dtype(candle::DType::F32)?,
)?
} else {
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
};
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;

println!("{state:?}");
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
if quantized {
let model_file = match model {
Model::Schnell => api
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
.get("flux1-schnell.gguf")?,
Model::Dev => todo!(),
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;

let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
.to_dtype(dtype)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
};
flux::sampling::unpack(&img, height, width)?
}
Expand Down Expand Up @@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> {

fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
run(args)
}
17 changes: 17 additions & 0 deletions candle-transformers/src/models/flux/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
use candle::{Result, Tensor};

pub trait WithForward {
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor>;
}

pub mod autoencoder;
pub mod model;
pub mod quantized_model;
pub mod sampling;
10 changes: 6 additions & 4 deletions candle-transformers/src/models/flux/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}

fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}

fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
Expand Down Expand Up @@ -144,7 +144,7 @@ pub struct EmbedNd {
}

impl EmbedNd {
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,
Expand Down Expand Up @@ -575,9 +575,11 @@ impl Flux {
final_layer,
})
}
}

impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
pub fn forward(
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
Expand Down
Loading

0 comments on commit 10d4718

Please sign in to comment.