Update Torch-DirectML samples and docs for Torch-DirectML 2.3.0 (#610)
* Update Torch-DirectML samples and docs for torch-directml 2.3.0

* Update PyTorch/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/app.py

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/app.py

Co-authored-by: Dwayne Robinson <[email protected]>

---------

Co-authored-by: Sheil Kumar <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
1 parent 61a1a50 commit 1d738f7
Showing 9 changed files with 165 additions and 2 deletions.
@@ -0,0 +1,58 @@
# Stable Diffusion Turbo & XL Turbo
This sample provides a simple way to load and run Stability AI's text-to-image generation models, Stable Diffusion Turbo & XL Turbo, with our DirectML backend.

- [About the Models](#about-the-models)
- [Setup](#setup)
- [Run the App](#run-the-app)
- [External Links](#external-links)
- [Model License](#model-license)

## About the Models

Stable Diffusion Turbo & XL Turbo are distilled versions of SD 2.1 and SDXL 1.0 respectively. Both are fast generative text-to-image models that can synthesize photorealistic images from a text prompt in a single network evaluation.

Refer to the HuggingFace repositories for [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo) and [SD Turbo](https://huggingface.co/stabilityai/sd-turbo) for more information.

## Setup
Once you've set up `torch-directml` following our [Windows](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows) and [WSL](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-wsl) guidance, install the requirements by running:

```
pip install -r requirements.txt
```

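Before installing the sample's requirements, it can help to confirm that `torch-directml` itself works. The following is a minimal sketch and not part of this sample: `describe_directml` is a hypothetical helper name, and the sketch assumes the installed `torch_directml` package exposes `is_available`, `default_device`, `device`, and `device_name`. It returns `None` rather than raising when the package is missing.

```python
def describe_directml():
    """Return a short description of the default DirectML device, or None."""
    try:
        import torch
        import torch_directml
    except ImportError:
        # torch or torch-directml is not installed in this environment.
        return None

    if not torch_directml.is_available():
        return None

    index = torch_directml.default_device()
    dml = torch_directml.device(index)

    # Run a tiny computation on the device to confirm tensors can be moved there.
    total = (torch.ones(2, 2).to(dml) * 2).sum().item()
    return f"{torch_directml.device_name(index)} (device {index}), test sum = {total}"


print(describe_directml() or "torch-directml is not available in this environment")
```

If the fallback message is printed, revisit the Windows or WSL setup guidance before continuing.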
## Run the App
To use Stable Diffusion with the text-to-image interface, run:
```bash
python app.py
```

When you run this code, a local URL is displayed on the console. Open http://localhost:7860 (or the local URL you see) in a browser to interact with the text-to-image interface.

Within the interface, use the dropdown to switch between SD Turbo and SDXL Turbo. You can also use the slider to set the number of iteration steps (1 to 4) for image generation.

![slider_dropdown](assets/slider_dropdown.png)

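Switching models in the dropdown only needs to reload a pipeline when the selection actually changes. The sketch below illustrates that idea in isolation; it is not the code in `app.py`. `PipelineCache`, `clamp_steps`, and the label-to-ID mapping are hypothetical names introduced for illustration, and the loader is a stand-in so that the example runs without downloading any model.

```python
# Hypothetical mapping from dropdown labels to Hugging Face model IDs.
MODEL_IDS = {
    "SD Turbo": "stabilityai/sd-turbo",
    "SDXL Turbo": "stabilityai/sdxl-turbo",
}


class PipelineCache:
    """Keep one pipeline in memory and reload only when the selection changes."""

    def __init__(self, loader):
        self._loader = loader  # callable: model ID -> pipeline object
        self._model_id = None
        self._pipe = None

    def get(self, label):
        model_id = MODEL_IDS[label]
        if model_id != self._model_id:
            self._pipe = self._loader(model_id)
            self._model_id = model_id
        return self._pipe


def clamp_steps(value, low=1, high=4):
    """Force the slider value into the supported 1-to-4 step range."""
    return max(low, min(high, int(value)))


# Usage with a stand-in loader that only records what it was asked to load.
loaded = []
cache = PipelineCache(lambda model_id: loaded.append(model_id) or model_id)
cache.get("SD Turbo")
cache.get("SD Turbo")    # no reload: same model as before
cache.get("SDXL Turbo")  # reload: the selection changed
print(loaded)
```

In a real app the loader would be the function that calls `from_pretrained` and moves the pipeline to the DirectML device.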
Enter the desired prompt and click "Run" to generate images:
```
Sample Prompt: A professional photo of a cat eating cake
```

Two sample images will be generated:
![image1](assets/t2i.png)

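Two images come back because the app sends the same prompt twice in one batched call, with guidance disabled. The sketch below shows how those arguments could be assembled; `build_pipeline_kwargs` is a hypothetical helper, not a function in `app.py`, and the empty-prompt check is an added assumption.

```python
def build_pipeline_kwargs(prompt, steps, num_samples=2):
    """Assemble keyword arguments for one batched text-to-image call."""
    if not prompt or not prompt.strip():
        raise ValueError("prompt must not be empty")
    return {
        # Repeating the prompt produces one image per copy in a single call.
        "prompt": [prompt.strip()] * num_samples,
        "num_inference_steps": int(steps),
        # The app passes guidance_scale=0.0 for the Turbo models.
        "guidance_scale": 0.0,
    }


kwargs = build_pipeline_kwargs("A professional photo of a cat eating cake", steps=1)
print(len(kwargs["prompt"]), "prompts,", kwargs["num_inference_steps"], "step(s)")

# With a loaded diffusers pipeline, these would be used as:
#     images = pipe(**kwargs).images
```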
## External Links
- [SDXL Turbo HuggingFace Repo](https://huggingface.co/stabilityai/sdxl-turbo)
- [SD Turbo HuggingFace Repo](https://huggingface.co/stabilityai/sd-turbo)

## Model License
The models are intended for both non-commercial and commercial usage under the following licenses: [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo/blob/main/LICENSE.md), [SD Turbo](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE.md).

For commercial use, please refer to https://stability.ai/license.
@@ -0,0 +1,99 @@
import torch
import torch_directml
import gradio as gr
from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline, LMSDiscreteScheduler
from PIL import Image
import numpy as np

def preprocess(image):
    # Convert a PIL image to a float tensor of shape (1, C, H, W) scaled to [-1, 1].
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # round down to an integer multiple of 32
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2. * image - 1.

lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)

device = torch_directml.device(torch_directml.default_device())

block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
num_samples = 2

def load_model(model_name):
    return AutoPipelineForText2Image.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        variant="fp16"
    ).to(device)

model_name = "stabilityai/sd-turbo"
pipe = load_model("stabilityai/sd-turbo")

def infer(prompt, inference_step, model_selector):
    global model_name, pipe

    # Reload the pipeline only when the selected model differs from the loaded one.
    if model_selector == "SD Turbo":
        if model_name != "stabilityai/sd-turbo":
            model_name = "stabilityai/sd-turbo"
            pipe = load_model("stabilityai/sd-turbo")
    else:
        if model_name != "stabilityai/sdxl-turbo":
            model_name = "stabilityai/sdxl-turbo"
            pipe = load_model("stabilityai/sdxl-turbo")

    # Generate num_samples images for the same prompt in one batched call.
    images = pipe(prompt=[prompt] * num_samples, num_inference_steps=inference_step, guidance_scale=0.0)[0]
    return images


with block as demo:
    gr.Markdown("<h1><center>Stable Diffusion Turbo and XL Turbo with DirectML Backend</center></h1>")

    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):

                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                iteration_slider = gr.Slider(
                    label="Steps",
                    step=1,
                    maximum=4,
                    minimum=1,
                    value=1
                )

                model_selector = gr.Dropdown(
                    ["SD Turbo", "SD Turbo XL"], label="Model", info="Select the SD model to use", value="SD Turbo"
                )

        gallery = gr.Gallery(label="Generated images", show_label=False).style(
            grid=[2], height="auto"
        )
        text.submit(infer, inputs=[text, iteration_slider, model_selector], outputs=gallery)
        btn.click(infer, inputs=[text, iteration_slider, model_selector], outputs=gallery)

    gr.Markdown(
        """___
<p style='text-align: center'>
Created by CompVis and Stability AI
<br/>
</p>"""
    )

demo.launch(debug=True)
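The `preprocess` helper in `app.py` does two pieces of arithmetic: it rounds each image dimension down to a multiple of 32, and it maps pixel values from [0, 255] to [-1, 1]. The sketch below works through both steps on plain numbers; `round_down_to_multiple` and `to_model_range` are hypothetical names used only for illustration.

```python
def round_down_to_multiple(value, multiple=32):
    """Largest multiple of `multiple` that does not exceed `value`."""
    return value - value % multiple


def to_model_range(pixel_value):
    """Map a pixel value in [0, 255] to the [-1, 1] range."""
    return 2.0 * (pixel_value / 255.0) - 1.0


width = round_down_to_multiple(513)
height = round_down_to_multiple(770)
print(width, height)
print(to_model_range(0), to_model_range(255))
```

A 513 x 770 input would therefore be resized to 512 x 768 before conversion to a tensor.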
@@ -0,0 +1,6 @@
diffusers==0.29.2
gradio==3.13.2
numpy==1.26.4
Pillow==10.4.0
scipy==1.14.0
transformers==4.42.3
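Because the requirements are pinned to exact versions, a quick comparison against what is actually installed can explain unexpected behavior. The following is a minimal sketch using only the standard library; `check_pins` is a hypothetical helper, and the `PINS` dictionary simply copies the versions listed above.

```python
from importlib import metadata

# Versions copied from requirements.txt above.
PINS = {
    "diffusers": "0.29.2",
    "gradio": "3.13.2",
    "numpy": "1.26.4",
    "Pillow": "10.4.0",
    "scipy": "1.14.0",
    "transformers": "4.42.3",
}


def check_pins(pins):
    """Return {package: (wanted, installed or None, matches)} for each pin."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (wanted, installed, installed == wanted)
    return report


for name, (wanted, installed, ok) in check_pins(PINS).items():
    status = "ok" if ok else "MISMATCH"
    print(f"{name}: wanted {wanted}, installed {installed} [{status}]")
```

Any line marked `MISMATCH` points at a package to reinstall with `pip install -r requirements.txt`.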