Skip to content

Commit

Permalink
Merge pull request #8 from newfla/fix_sdxl_turbo
Browse files Browse the repository at this point in the history
fix: sdxl vae NaN
  • Loading branch information
newfla authored Nov 4, 2024
2 parents 4743cc2 + 43367de commit 99da8a0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
7 changes: 7 additions & 0 deletions src/modifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ pub fn real_esrgan_x4plus_anime_6_b(mut builder: ConfigBuilder) -> Result<Config
builder.upscale_model(upscaler_path);
Ok(builder)
}

/// Apply <https://huggingface.co/madebyollin/sdxl-vae-fp16-fix> to avoid black images with xl models
pub fn sdxl_vae_fp16_fix(mut builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
builder.vae(vae_path);
Ok(builder)
}
14 changes: 6 additions & 8 deletions src/preset_builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::api::{self, SampleMethod};
use crate::{
api::{self, SampleMethod},
modifier::sdxl_vae_fp16_fix,
};
use hf_hub::api::sync::ApiError;

use crate::{api::ConfigBuilder, util::download_file_hf_hub};
Expand Down Expand Up @@ -71,18 +74,14 @@ pub fn sdxl_base_1_0() -> Result<ConfigBuilder, ApiError> {
"sd_xl_base_1.0.safetensors",
)?;

let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl_vae.safetensors")?;

let mut config = ConfigBuilder::default();

config
.model(model_path)
.vae(vae_path)
.vae_tiling(true)
.height(1024)
.width(1024);

Ok(config)
sdxl_vae_fp16_fix(config)
}

pub fn flux_1_dev(sd_type: api::WeightType) -> Result<ConfigBuilder, ApiError> {
Expand Down Expand Up @@ -167,8 +166,7 @@ pub fn sdxl_turbo_1_0_fp16() -> Result<ConfigBuilder, ApiError> {
let mut config = ConfigBuilder::default();

config.model(model_path).guidance(0.).cfg_scale(1.).steps(4);

Ok(config)
sdxl_vae_fp16_fix(config)
}

pub fn stable_diffusion_3_5_large_fp16() -> Result<ConfigBuilder, ApiError> {
Expand Down
5 changes: 3 additions & 2 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ pub fn set_hf_token(token: &str) {
*data = token.to_owned();
}

pub(crate) fn download_file_hf_hub(model: &str, file: &str) -> Result<PathBuf, ApiError> {
/// Download file from huggingface hub
pub fn download_file_hf_hub(repo: &str, file: &str) -> Result<PathBuf, ApiError> {
let token = TOKEN.get().map(|token| token.read().unwrap().to_owned());
let repo = ApiBuilder::new()
.with_token(token)
.build()?
.model(model.to_string());
.model(repo.to_string());
repo.get(file)
}

0 comments on commit 99da8a0

Please sign in to comment.