Skip to content

Commit

Permalink
feat: taesd modifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
newfla committed Nov 5, 2024
1 parent e325d07 commit 266e0e3
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
50 changes: 50 additions & 0 deletions src/modifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,53 @@ pub fn sdxl_vae_fp16_fix(mut builder: ConfigBuilder) -> Result<ConfigBuilder, Ap
builder.vae(vae_path);
Ok(builder)
}

/// Apply <https://huggingface.co/madebyollin/taesd> taesd autoencoder for faster decoding (SD v1/v2)
pub fn taesd(mut builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
let taesd_path =
download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
builder.taesd(taesd_path);
Ok(builder)
}

/// Apply <https://huggingface.co/madebyollin/taesdxl> taesd autoencoder for faster decoding (SDXL)
pub fn taesd_xl(mut builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
let taesd_path =
download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
builder.taesd(taesd_path);
Ok(builder)
}

#[cfg(test)]
mod tests {
use crate::{
api::txt2img,
preset::{Modifier, Preset, PresetBuilder},
};

use super::{taesd, taesd_xl};

static PROMPT: &str = "a lovely duck drinking water from a bottle";

fn run(preset: Preset, m: Modifier) {
let config = PresetBuilder::default()
.preset(preset)
.prompt(PROMPT)
.with_modifier(m)
.build()
.unwrap();
txt2img(config).unwrap();
}

#[ignore]
#[test]
fn test_taesd() {
run(Preset::StableDiffusion1_5, taesd);
}

#[ignore]
#[test]
fn test_taesd_xl() {
run(Preset::SDXLTurbo1_0Fp16, taesd_xl);
}
}
3 changes: 2 additions & 1 deletion src/preset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ pub struct PresetConfig {

impl PresetBuilder {
/// Add modifier that will apply in sequence
pub fn with_modifier(&mut self, f: Modifier) {
pub fn with_modifier(&mut self, f: Modifier) -> &mut Self {
if self.modifiers.is_none() {
self.modifiers = Some(Vec::new());
}
self.modifiers.as_mut().unwrap().push(f);
self
}

pub fn build(&mut self) -> Result<Config, ConfigBuilderError> {
Expand Down

0 comments on commit 266e0e3

Please sign in to comment.