Also adapt the img2img and inpaint pipelines.
LaurentMazare committed Dec 11, 2022
1 parent 935c349 commit a6b1788
Showing 3 changed files with 152 additions and 61 deletions.
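
At a glance, the change replaces the hard-coded v1.5 pipeline setup in both examples with a StableDiffusionConfig chosen at runtime through a new --sd-version flag. The following is a minimal sketch of the resulting flow, distilled from the diffs below; the Option<i64> type for the slicing parameter, the literal paths and the step count are assumptions, while the config and builder calls mirror the changed code.

use diffusers::pipelines::stable_diffusion;
use diffusers::transformers::clip;
use tch::Device;

// Sketch only: mirrors the calls visible in the diffs below.
fn build_pipeline(use_v2_1: bool, sliced_attention_size: Option<i64>) -> anyhow::Result<()> {
    // Pick a per-version configuration instead of hard-coding the v1.5 settings.
    // The Option<i64> type is an assumption; the examples pass their
    // sliced_attention_size flag through unchanged.
    let sd_config = if use_v2_1 {
        stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
    } else {
        stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size)
    };
    let device = Device::cuda_if_available();
    // Weight paths are illustrative; the examples derive per-version defaults.
    let clip_weights = "data/clip_v2.1.ot".to_string();
    let vae_weights = "data/vae_v2.1.ot".to_string();
    let unet_weights = "data/unet_v2.1.ot".to_string();
    // Every component is now built from the shared config rather than from
    // free functions tied to a single Stable Diffusion release.
    let _scheduler = sd_config.build_scheduler(30);
    let _tokenizer = clip::Tokenizer::create("data/bpe_simple_vocab_16e6.txt", &sd_config.clip)?;
    let _text_model = sd_config.build_clip_transformer(&clip_weights, device)?;
    let _vae = sd_config.build_vae(&vae_weights, device)?;
    let _unet = sd_config.build_unet(&unet_weights, device, 4)?;
    Ok(())
}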
Cargo.toml: 4 changes (0 additions, 4 deletions)
@@ -27,10 +27,6 @@ clap = { version = "4.0.19", optional = true, features = ["derive"] }
name = "stable-diffusion"
required-features = ["clap"]

[[example]]
name = "stable-diffusion-2"
required-features = ["clap"]

[[example]]
name = "stable-diffusion-img2img"
required-features = ["clap"]
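The dedicated stable-diffusion-2 example entry is dropped because version selection now happens at runtime. Below is a stripped-down sketch of the clap wiring that takes its place; the Args struct here is a stand-in for the full one added further down, and the value strings follow clap's kebab-case ValueEnum convention, which the committed default_value = "v2-1" confirms.

use clap::Parser;

#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum StableDiffusionVersion {
    V1_5,
    V2_1,
}

// Stripped-down stand-in for the full Args struct added in this commit.
#[derive(Parser)]
struct Args {
    /// Which Stable Diffusion release to target; parsed as v1-5 or v2-1.
    #[arg(long, value_enum, default_value = "v2-1")]
    sd_version: StableDiffusionVersion,
}

fn main() {
    // `--sd-version v1-5` selects V1_5; the default is V2_1.
    let args = Args::parse();
    println!("{:?}", args.sd_version);
}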
examples/stable-diffusion-img2img/main.rs: 100 changes (74 additions, 26 deletions)
@@ -9,8 +9,8 @@
// image: https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg
// prompt = "A fantasy landscape, trending on artstation"
use clap::Parser;
use diffusers::pipelines::stable_diffusion::v1_5 as stable_diffusion;
use diffusers::{schedulers::ddim, transformers::clip};
use diffusers::pipelines::stable_diffusion;
use diffusers::transformers::clip;
use tch::{nn::Module, Device, Kind, Tensor};

const GUIDANCE_SCALE: f64 = 7.5;
@@ -24,24 +24,28 @@ struct Args {

/// The prompt to be used for image generation.
#[arg(long, default_value = "A fantasy landscape, trending on artstation.")]
prompt: Option<String>,
prompt: String,

/// When set, use the CPU for the listed devices, can be 'all', 'unet', 'clip', etc.
/// Multiple values can be set.
#[arg(long)]
cpu: Vec<String>,

/// The UNet weight file, in .ot format.
#[arg(long, value_name = "FILE", default_value = "data/unet.ot")]
unet_weights: String,
#[arg(long, value_name = "FILE")]
unet_weights: Option<String>,

/// The CLIP weight file, in .ot format.
#[arg(long, value_name = "FILE", default_value = "data/pytorch_model.ot")]
clip_weights: String,
#[arg(long, value_name = "FILE")]
clip_weights: Option<String>,

/// The VAE weight file, in .ot format.
#[arg(long, value_name = "FILE", default_value = "data/vae.ot")]
vae_weights: String,
#[arg(long, value_name = "FILE")]
vae_weights: Option<String>,

#[arg(long, value_name = "FILE", default_value = "data/bpe_simple_vocab_16e6.txt")]
/// The file specifying the vocabulary to be used for tokenization.
vocab_file: String,

/// The size of the sliced attention or 0 to disable slicing (default)
#[arg(long)]
@@ -72,6 +76,47 @@ struct Args {
/// Do not use autocast.
#[arg(long, action)]
no_autocast: bool,

#[arg(long, value_enum, default_value = "v2-1")]
sd_version: StableDiffusionVersion,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum StableDiffusionVersion {
V1_5,
V2_1,
}

impl Args {
fn clip_weights(&self) -> String {
match &self.clip_weights {
Some(w) => w.clone(),
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/pytorch_model.ot".to_string(),
StableDiffusionVersion::V2_1 => "data/clip_v2.1.ot".to_string(),
},
}
}

fn vae_weights(&self) -> String {
match &self.vae_weights {
Some(w) => w.clone(),
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/vae.ot".to_string(),
StableDiffusionVersion::V2_1 => "data/vae_v2.1.ot".to_string(),
},
}
}

fn unet_weights(&self) -> String {
match &self.unet_weights {
Some(w) => w.clone(),
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/unet.ot".to_string(),
StableDiffusionVersion::V2_1 => "data/unet_v2.1.ot".to_string(),
},
}
}
}

fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
@@ -84,27 +129,38 @@ fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor
}

fn run(args: Args) -> anyhow::Result<()> {
let clip_weights = args.clip_weights();
let vae_weights = args.vae_weights();
let unet_weights = args.unet_weights();
let Args {
prompt,
cpu,
n_steps,
seed,
final_image,
vae_weights,
clip_weights,
unet_weights,
sliced_attention_size,
num_samples,
strength,
input_image,
no_autocast: _,
sd_version,
vocab_file,
..
} = args;
if !(0. ..=1.).contains(&strength) {
anyhow::bail!("strength should be between 0 and 1, got {strength}")
}
tch::maybe_init_cuda();
println!("Cuda available: {}", tch::Cuda::is_available());
println!("Cudnn available: {}", tch::Cuda::cudnn_is_available());
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size)
}
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
}
};

let cuda_device = Device::cuda_if_available();
let cpu_or_cuda = |name: &str| {
if cpu.iter().any(|c| c == "all" || c == name) {
@@ -117,13 +173,9 @@ fn run(args: Args) -> anyhow::Result<()> {
let clip_device = cpu_or_cuda("clip");
let vae_device = cpu_or_cuda("vae");
let unet_device = cpu_or_cuda("unet");
let scheduler = ddim::DDIMScheduler::new(n_steps, 1000, Default::default());
let scheduler = sd_config.build_scheduler(n_steps);

let clip_config = stable_diffusion::clip_config();
let tokenizer = clip::Tokenizer::create("data/bpe_simple_vocab_16e6.txt", &clip_config)?;
let prompt = prompt.unwrap_or_else(|| {
"A very realistic photo of a rusty robot walking on a sandy beach".to_string()
});
let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?;
println!("Running with prompt \"{prompt}\".");
let tokens = tokenizer.encode(&prompt)?;
let tokens: Vec<i64> = tokens.into_iter().map(|x| x as i64).collect();
@@ -135,19 +187,15 @@ fn run(args: Args) -> anyhow::Result<()> {
let no_grad_guard = tch::no_grad_guard();

println!("Building the Clip transformer.");
let text_model = diffusers::pipelines::stable_diffusion::build_clip_transformer(
&clip_weights,
&clip_config,
clip_device,
)?;
let text_model = sd_config.build_clip_transformer(&clip_weights, clip_device)?;
let text_embeddings = text_model.forward(&tokens);
let uncond_embeddings = text_model.forward(&uncond_tokens);
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0).to(unet_device);

println!("Building the autoencoder.");
let vae = stable_diffusion::build_vae(&vae_weights, vae_device)?;
let vae = sd_config.build_vae(&vae_weights, vae_device)?;
println!("Building the unet.");
let unet = stable_diffusion::build_unet(&unet_weights, unet_device, 4, sliced_attention_size)?;
let unet = sd_config.build_unet(&unet_weights, unet_device, 4)?;

println!("Generating the latent from the input image {:?}.", init_image.size());
let init_image = init_image.to(vae_device);
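
The three weight-path helpers added to the img2img example all follow the same pattern: an explicit flag wins, otherwise a per-version default under data/ is used. A condensed restatement follows; the generic helper is illustrative and not part of the commit.

#[derive(Debug, Clone, Copy)]
enum StableDiffusionVersion {
    V1_5,
    V2_1,
}

// Illustrative only: the commit spells this logic out once per weight file.
fn weights_or_default(
    explicit: &Option<String>,
    version: StableDiffusionVersion,
    v1_5_default: &str,
    v2_1_default: &str,
) -> String {
    match explicit {
        // An explicit --clip-weights/--vae-weights/--unet-weights path wins.
        Some(path) => path.clone(),
        // Otherwise fall back to the default file for the selected version.
        None => match version {
            StableDiffusionVersion::V1_5 => v1_5_default.to_string(),
            StableDiffusionVersion::V2_1 => v2_1_default.to_string(),
        },
    }
}

With this in place, targeting v2.1 only requires downloading the *_v2.1.ot weight files; the --unet-weights, --clip-weights and --vae-weights flags stay optional.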
examples/stable-diffusion-inpaint/main.rs: 109 changes (78 additions, 31 deletions)
@@ -15,12 +15,10 @@
// Sample mask:
// https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png
use clap::Parser;
use diffusers::pipelines::stable_diffusion::v1_5 as stable_diffusion;
use diffusers::{schedulers::ddim, transformers::clip};
use diffusers::pipelines::stable_diffusion;
use diffusers::transformers::clip;
use tch::{nn::Module, Device, Kind, Tensor};

const HEIGHT: i64 = 512;
const WIDTH: i64 = 512;
const GUIDANCE_SCALE: f64 = 7.5;

#[derive(Parser)]
@@ -37,24 +35,28 @@ struct Args {

/// The prompt to be used for image generation.
#[arg(long, default_value = "Face of a yellow cat, high resolution, sitting on a park bench")]
prompt: Option<String>,
prompt: String,

/// When set, use the CPU for the listed devices, can be 'all', 'unet', 'clip', etc.
/// Multiple values can be set.
#[arg(long)]
cpu: Vec<String>,

#[arg(long, value_name = "FILE", default_value = "data/bpe_simple_vocab_16e6.txt")]
/// The file specifying the vocabulary to be used for tokenization.
vocab_file: String,

/// The UNet weight file, in .ot format.
#[arg(long, value_name = "FILE", default_value = "data/unet-inpaint.ot")]
unet_weights: String,
#[arg(long, value_name = "FILE")]
unet_weights: Option<String>,

/// The CLIP weight file, in .ot format.
#[arg(long, value_name = "FILE", default_value = "data/pytorch_model.ot")]
clip_weights: String,
#[arg(long, value_name = "FILE")]
clip_weights: Option<String>,

/// The VAE weight file, in .ot format.
#[arg(long, value_name = "FILE", default_value = "data/vae.ot")]
vae_weights: String,
#[arg(long, value_name = "FILE")]
vae_weights: Option<String>,

/// The size of the sliced attention or 0 to disable slicing (default)
#[arg(long)]
@@ -75,6 +77,47 @@ struct Args {
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
final_image: String,

#[arg(long, value_enum, default_value = "v2-1")]
sd_version: StableDiffusionVersion,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum StableDiffusionVersion {
V1_5,
V2_1,
}

impl Args {
fn clip_weights(&self) -> String {
match &self.clip_weights {
Some(w) => w.clone(),
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/pytorch_model.ot".to_string(),
StableDiffusionVersion::V2_1 => "data/clip_v2.1.ot".to_string(),
},
}
}

fn vae_weights(&self) -> String {
match &self.vae_weights {
Some(w) => w.clone(),
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/vae.ot".to_string(),
StableDiffusionVersion::V2_1 => "data/vae_v2.1.ot".to_string(),
},
}
}

fn unet_weights(&self) -> String {
match &self.unet_weights {
Some(w) => w.clone(),
None => match self.sd_version {
StableDiffusionVersion::V1_5 => "data/unet-inpaint.ot".to_string(),
StableDiffusionVersion::V2_1 => "data/unet-inpaint_v2.1.ot".to_string(),
},
}
}
}

fn prepare_mask_and_masked_image<T: AsRef<std::path::Path>>(
@@ -92,23 +135,34 @@ fn prepare_mask_and_masked_image<T: AsRef<std::path::Path>>(
}

fn run(args: Args) -> anyhow::Result<()> {
let clip_weights = args.clip_weights();
let vae_weights = args.vae_weights();
let unet_weights = args.unet_weights();
let Args {
prompt,
cpu,
n_steps,
seed,
final_image,
vae_weights,
clip_weights,
unet_weights,
sliced_attention_size,
num_samples,
input_image,
mask_image,
vocab_file,
sd_version,
..
} = args;
tch::maybe_init_cuda();
println!("Cuda available: {}", tch::Cuda::is_available());
println!("Cudnn available: {}", tch::Cuda::cudnn_is_available());
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size)
}
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
}
};
let cuda_device = Device::cuda_if_available();
let cpu_or_cuda = |name: &str| {
if cpu.iter().any(|c| c == "all" || c == name) {
@@ -122,13 +176,9 @@ fn run(args: Args) -> anyhow::Result<()> {
let clip_device = cpu_or_cuda("clip");
let vae_device = cpu_or_cuda("vae");
let unet_device = cpu_or_cuda("unet");
let scheduler = ddim::DDIMScheduler::new(n_steps, 1000, Default::default());
let scheduler = sd_config.build_scheduler(n_steps);

let clip_config = stable_diffusion::clip_config();
let tokenizer = clip::Tokenizer::create("data/bpe_simple_vocab_16e6.txt", &clip_config)?;
let prompt = prompt.unwrap_or_else(|| {
"A very realistic photo of a rusty robot walking on a sandy beach".to_string()
});
let tokenizer = clip::Tokenizer::create(&vocab_file, &sd_config.clip)?;
println!("Running with prompt \"{prompt}\".");
let tokens = tokenizer.encode(&prompt)?;
let tokens: Vec<i64> = tokens.into_iter().map(|x| x as i64).collect();
@@ -140,22 +190,17 @@ fn run(args: Args) -> anyhow::Result<()> {
let no_grad_guard = tch::no_grad_guard();

println!("Building the Clip transformer.");
let text_model = diffusers::pipelines::stable_diffusion::build_clip_transformer(
&clip_weights,
&clip_config,
clip_device,
)?;
let text_model = sd_config.build_clip_transformer(&clip_weights, clip_device)?;
let text_embeddings = text_model.forward(&tokens);
let uncond_embeddings = text_model.forward(&uncond_tokens);
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0).to(unet_device);

println!("Building the autoencoder.");
let vae = stable_diffusion::build_vae(&vae_weights, vae_device)?;
let vae = sd_config.build_vae(&vae_weights, vae_device)?;
println!("Building the unet.");
let unet = stable_diffusion::build_unet(&unet_weights, unet_device, 9, sliced_attention_size)?;
let unet = sd_config.build_unet(&unet_weights, unet_device, 9)?;

// torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
let mask = mask.upsample_nearest2d(&[HEIGHT / 8, WIDTH / 8], None, None);
let mask = mask.upsample_nearest2d(&[sd_config.height / 8, sd_config.width / 8], None, None);
let mask = Tensor::cat(&[&mask, &mask], 0).to_device(unet_device);
let masked_image_dist = vae.encode(&masked_image.to_device(vae_device));

@@ -164,8 +209,10 @@ fn run(args: Args) -> anyhow::Result<()> {
tch::manual_seed(seed + idx);
let masked_image_latents = (masked_image_dist.sample() * 0.18215).to(unet_device);
let masked_image_latents = Tensor::cat(&[&masked_image_latents, &masked_image_latents], 0);
let mut latents =
Tensor::randn(&[bsize, 4, HEIGHT / 8, WIDTH / 8], (Kind::Float, unet_device));
let mut latents = Tensor::randn(
&[bsize, 4, sd_config.height / 8, sd_config.width / 8],
(Kind::Float, unet_device),
);

for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
println!("Timestep {timestep_index}/{n_steps}");
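
In the inpainting example the UNet is now built with 9 input channels via build_unet(&unet_weights, unet_device, 9), matching the inpainting checkpoints: 4 channels of noisy latents, 1 mask channel and 4 channels of masked-image latents, all at the sd_config.height / 8 by sd_config.width / 8 latent resolution. The per-step assembly of that input sits below the visible part of the diff, so the sketch here is an assumption about the shape handling rather than a copy of the committed code; variable names are illustrative.

use tch::Tensor;

// Sketch of the 9-channel UNet input for inpainting; names are assumptions.
fn unet_input(latents: &Tensor, mask: &Tensor, masked_image_latents: &Tensor) -> Tensor {
    // Classifier-free guidance runs the unconditional and conditional prompts
    // in one batch, so the latents are duplicated to match the mask and
    // masked-image tensors that the diff already concatenates twice.
    let latent_model_input = Tensor::cat(&[latents, latents], 0);
    // [2b, 4, h/8, w/8] + [2b, 1, h/8, w/8] + [2b, 4, h/8, w/8]
    //   -> [2b, 9, h/8, w/8] along the channel dimension.
    Tensor::cat(&[&latent_model_input, mask, masked_image_latents], 1)
}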
