Improve model offloading and GPU memory management
IndigoDosSantos authored Jun 8, 2024
1 parent d8c1607 commit 53af46a
Showing 1 changed file with 6 additions and 3 deletions.
run.py: 9 changes (6 additions & 3 deletions)
@@ -79,15 +79,14 @@ def generate_images(prompt, height, width, negative_prompt, guidance_scale, num_
     # Sanitize user input prompt before using it, with a timeout of 5 seconds
     cleaned_prompt = clean_prompt_with_timeout(prompt, timeout=5)
     print("Processed prompt:", cleaned_prompt)
 
-    # Load, use, and discard the prior model
-    prior = load_model("prior")
-
     with torch.cuda.amp.autocast(dtype=dtype):
         seed = torch.seed() if seed == -1 else seed  # Get the initial seed
         torch.manual_seed(seed)  # Apply the seed for generation
         generator = torch.Generator(device).manual_seed(seed)  # Preserve for reproducibility
 
+        # Load, use, and discard the prior model
+        prior = load_model("prior")
         prior.enable_model_cpu_offload()
         prior_output = prior(
             prompt=cleaned_prompt,
@@ -99,6 +98,8 @@ def generate_images(prompt, height, width, negative_prompt, guidance_scale, num_
             num_images_per_prompt=int(num_images_per_prompt),
             generator=generator,
         )
+        del prior
+        torch.cuda.empty_cache()  # Release GPU memory
 
         # Load, use, and discard the decoder model
         decoder = load_model("decoder")
@@ -112,6 +113,8 @@ def generate_images(prompt, height, width, negative_prompt, guidance_scale, num_
             output_type="pil",
             generator=generator,
         ).images
+        del decoder
+        torch.cuda.empty_cache()  # Release GPU memory
 
         metadata_embedded = {
             "parameters": "Stable Cascade",

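The change implements a load, use, discard pattern: each pipeline is loaded only when its stage runs, and its reference is deleted and the CUDA caching allocator flushed before the next model comes in, so the prior and the decoder never occupy VRAM at the same time. A minimal sketch of that pattern, assuming a hypothetical load_model() helper like the one in run.py that returns a diffusers-style pipeline:

    import gc
    import torch

    device = "cuda"
    dtype = torch.bfloat16
    generator = torch.Generator(device).manual_seed(42)  # fixed seed so results are reproducible

    with torch.cuda.amp.autocast(dtype=dtype):
        prior = load_model("prior")        # hypothetical helper, as in run.py
        prior.enable_model_cpu_offload()   # keep weights on CPU, moving layers to the GPU on demand
        prior_output = prior(prompt="a photo of a cat", generator=generator)
        del prior                          # drop the last Python reference to the pipeline
        gc.collect()                       # ensure the object is reclaimed before flushing
        torch.cuda.empty_cache()           # hand the freed, cached blocks back to the driver

Note that torch.cuda.empty_cache() only releases memory whose tensors have already been freed; it cannot evict live weights, which is why the del has to come first.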