diff --git a/src/modifier.rs b/src/modifier.rs index 346820d..c4f7f59 100644 --- a/src/modifier.rs +++ b/src/modifier.rs @@ -18,3 +18,53 @@ pub fn sdxl_vae_fp16_fix(mut builder: ConfigBuilder) -> Result taesd autoencoder for faster decoding (SD v1/v2) +pub fn taesd(mut builder: ConfigBuilder) -> Result { + let taesd_path = + download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?; + builder.taesd(taesd_path); + Ok(builder) +} + +/// Apply taesd autoencoder for faster decoding (SDXL) +pub fn taesd_xl(mut builder: ConfigBuilder) -> Result { + let taesd_path = + download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?; + builder.taesd(taesd_path); + Ok(builder) +} + +#[cfg(test)] +mod tests { + use crate::{ + api::txt2img, + preset::{Modifier, Preset, PresetBuilder}, + }; + + use super::{taesd, taesd_xl}; + + static PROMPT: &str = "a lovely duck drinking water from a bottle"; + + fn run(preset: Preset, m: Modifier) { + let config = PresetBuilder::default() + .preset(preset) + .prompt(PROMPT) + .with_modifier(m) + .build() + .unwrap(); + txt2img(config).unwrap(); + } + + #[ignore] + #[test] + fn test_taesd() { + run(Preset::StableDiffusion1_5, taesd); + } + + #[ignore] + #[test] + fn test_taesd_xl() { + run(Preset::SDXLTurbo1_0Fp16, taesd_xl); + } +} diff --git a/src/preset.rs b/src/preset.rs index e25aed2..4b1461d 100644 --- a/src/preset.rs +++ b/src/preset.rs @@ -83,11 +83,12 @@ pub struct PresetConfig { impl PresetBuilder { /// Add modifier that will apply in sequence - pub fn with_modifier(&mut self, f: Modifier) { + pub fn with_modifier(&mut self, f: Modifier) -> &mut Self { if self.modifiers.is_none() { self.modifiers = Some(Vec::new()); } self.modifiers.as_mut().unwrap().push(f); + self } pub fn build(&mut self) -> Result {