Skip to content

Commit

Permalink
Use the mps device in all examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Feb 26, 2023
1 parent 259fa62 commit 2c8fada
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 35 deletions.
15 changes: 4 additions & 11 deletions examples/stable-diffusion-img2img/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,11 @@ fn run(args: Args) -> anyhow::Result<()> {
}
};

let cuda_device = Device::cuda_if_available();
let cpu_or_cuda = |name: &str| {
if cpu.iter().any(|c| c == "all" || c == name) {
Device::Cpu
} else {
cuda_device
}
};
let init_image = image_preprocess(input_image)?;
let clip_device = cpu_or_cuda("clip");
let vae_device = cpu_or_cuda("vae");
let unet_device = cpu_or_cuda("unet");
let device_setup = diffusers::utils::DeviceSetup::new(cpu);
let clip_device = device_setup.get("clip");
let vae_device = device_setup.get("vae");
let unet_device = device_setup.get("unet");
let scheduler = sd_config.build_scheduler(n_steps);

let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?;
Expand Down
15 changes: 4 additions & 11 deletions examples/stable-diffusion-inpaint/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,12 @@ fn run(args: Args) -> anyhow::Result<()> {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
}
};
let cuda_device = Device::cuda_if_available();
let cpu_or_cuda = |name: &str| {
if cpu.iter().any(|c| c == "all" || c == name) {
Device::Cpu
} else {
cuda_device
}
};
let (mask, masked_image) = prepare_mask_and_masked_image(input_image, mask_image)?;
println!("Loaded input image and mask, {:?} {:?}.", masked_image.size(), mask.size());
let clip_device = cpu_or_cuda("clip");
let vae_device = cpu_or_cuda("vae");
let unet_device = cpu_or_cuda("unet");
let device_setup = diffusers::utils::DeviceSetup::new(cpu);
let clip_device = device_setup.get("clip");
let vae_device = device_setup.get("vae");
let unet_device = device_setup.get("unet");
let scheduler = sd_config.build_scheduler(n_steps);

let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?;
Expand Down
16 changes: 4 additions & 12 deletions examples/stable-diffusion/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,10 @@ fn run(args: Args) -> anyhow::Result<()> {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size)
}
};
let accelerator_device =
if tch::utils::has_mps() { Device::Mps } else { Device::cuda_if_available() };
let cpu_or_accelerator = |name: &str| {
if cpu.iter().any(|c| c == "all" || c == name) {
Device::Cpu
} else {
accelerator_device
}
};
let clip_device = cpu_or_accelerator("clip");
let vae_device = cpu_or_accelerator("vae");
let unet_device = cpu_or_accelerator("unet");
let device_setup = diffusers::utils::DeviceSetup::new(cpu);
let clip_device = device_setup.get("clip");
let vae_device = device_setup.get("vae");
let unet_device = device_setup.get("unet");
let scheduler = sd_config.build_scheduler(n_steps);

let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ pub mod models;
pub mod pipelines;
pub mod schedulers;
pub mod transformers;
mod utils;
pub mod utils;
22 changes: 22 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
// A simple wrapper around File::open adding details about the
// problematic file.
use std::path::Path;
use tch::Device;

pub(crate) fn file_open<P: AsRef<Path>>(path: P) -> anyhow::Result<std::fs::File> {
std::fs::File::open(path.as_ref()).map_err(|e| {
let context = format!("error opening {:?}", path.as_ref().to_string_lossy());
anyhow::Error::new(e).context(context)
})
}

pub struct DeviceSetup {
accelerator_device: Device,
cpu: Vec<String>,
}

impl DeviceSetup {
pub fn new(cpu: Vec<String>) -> Self {
let accelerator_device =
if tch::utils::has_mps() { Device::Mps } else { Device::cuda_if_available() };
Self { accelerator_device, cpu }
}

pub fn get(&self, name: &str) -> Device {
if self.cpu.iter().any(|c| c == "all" || c == name) {
Device::Cpu
} else {
self.accelerator_device
}
}
}

0 comments on commit 2c8fada

Please sign in to comment.