Update Torch-DirectML samples and docs for Torch-DirectML 2.3.0 (#610)
* Update Torch-DirectML samples and docs for torch-directml 2.3.0

* Update PyTorch/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/README.md

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/app.py

Co-authored-by: Dwayne Robinson <[email protected]>

* Update PyTorch/diffusion/sd/app.py

Co-authored-by: Dwayne Robinson <[email protected]>

---------

Co-authored-by: Sheil Kumar <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
3 people authored Jul 17, 2024
1 parent 61a1a50 commit 1d738f7
Showing 9 changed files with 165 additions and 2 deletions.
4 changes: 2 additions & 2 deletions PyTorch/README.md
@@ -15,16 +15,16 @@ pip install torch-directml
```

## Samples
For `torch-directml` samples find brief summaries below or explore the [cv](./cv/), [transformer](./transformer/) or [llm](./llm/) folders:
Try the `torch-directml` samples below, or explore the [cv](./cv/), [transformer](./transformer/), [llm](./llm/) and [diffusion](./diffusion/) folders:
* [attention is all you need - the original transformer model](./transformer/attention_is_all_you_need/)
* [yolov3 - a real-time object detection model](./cv/yolov3/)
* [squeezenet - a small image classification model](./cv/squeezenet)
* [resnet50 - an image classification model](./cv/resnet50)
* [maskrcnn - an object detection model](./cv/objectDetection/maskrcnn/)
* [llm - a text generation and chatbot app supporting various language models](./llm/)
* [whisper - a general-purpose speech recognition model](./audio/whisper/)
* [Stable Diffusion Turbo & XL Turbo - a text-to-image generation model](./diffusion/sd/)

## External Links

* [torch-directml PyPI project](https://pypi.org/project/torch-directml/)
* [PyTorch homepage](https://pytorch.org/)
Binary file added PyTorch/diffusion/sd/.DS_Store
Binary file not shown.
58 changes: 58 additions & 0 deletions PyTorch/diffusion/sd/README.md
@@ -0,0 +1,58 @@
# Stable Diffusion Turbo & XL Turbo
This sample provides a simple way to load and run Stability AI's text-to-image generation models, Stable Diffusion Turbo & XL Turbo, with the DirectML backend.

- [About the Models](#about-the-models)
- [Setup](#setup)
- [Run the App](#run-the-app)
- [External Links](#external-links)
- [Model License](#model-license)


## About the Models

Stable Diffusion Turbo & XL Turbo are distilled versions of SD 2.1 and SDXL 1.0 respectively. Both are fast generative text-to-image models that can synthesize photorealistic images from a text prompt in a single network evaluation.

Refer to the HuggingFace repositories for [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo) and [SD Turbo](https://huggingface.co/stabilityai/sd-turbo) for more information.


## Setup
Once you've set up `torch-directml` following our [Windows](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows) and [WSL](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-wsl) guidance, install the requirements by running:


```
pip install -r requirements.txt
```
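
After installing, it can help to confirm that `torch-directml` actually sees a DirectML device before launching the app. The snippet below is a minimal sketch and not part of the sample: `describe_directml` is a hypothetical helper name, and it assumes the `device_count()`, `device_name()` and `device()` functions exposed by recent `torch-directml` releases. If the package cannot be imported, it reports that instead of raising.

```python
def describe_directml() -> str:
    """Return a one-line summary of the DirectML devices visible to PyTorch."""
    try:
        import torch
        import torch_directml
    except (ImportError, OSError) as exc:
        # Package not installed, or its native library failed to load.
        return f"torch-directml is not importable: {exc}"

    count = torch_directml.device_count()
    if count == 0:
        return "torch-directml is installed, but no DirectML device was found"

    # Run a tiny computation on the default device to confirm it works end to end.
    device = torch_directml.device()
    total = (torch.ones(2, 2).to(device) * 2).sum().item()
    names = [torch_directml.device_name(i) for i in range(count)]
    return f"{count} DirectML device(s): {names}; test sum = {total}"


print(describe_directml())
```

If the summary lists at least one device and a test sum of 8.0, the backend is likely ready for the sample.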


## Run the App
To use Stable Diffusion with the text-to-image interface, run:
```bash
python app.py
```

When you run this code, a local URL will be displayed on the console. Open http://localhost:7860 (or the local URL you see) in a browser to interact with the text-to-image interface.

Within the interface, use the dropdown to switch between SD Turbo and SDXL Turbo. You can also use the slider to set the number of iteration steps (1 to 4) for image generation.

![slider_dropdown](assets/slider_dropdown.png)
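
Behind the dropdown, `app.py` keeps a single pipeline in memory and reloads it only when the selection changes. The sketch below illustrates that logic in isolation, under stated assumptions: `ModelSwitcher` and `fake_loader` are hypothetical names, and the loader is a stub standing in for the real `AutoPipelineForText2Image.from_pretrained(...)` call so the example runs without downloading any weights.

```python
from typing import Callable, Optional

# Dropdown labels used by app.py, mapped to their Hugging Face repository IDs.
MODEL_IDS = {
    "SD Turbo": "stabilityai/sd-turbo",
    "SD Turbo XL": "stabilityai/sdxl-turbo",
}


class ModelSwitcher:
    """Keeps one loaded pipeline and reloads only when the selection changes."""

    def __init__(self, loader: Callable[[str], object]):
        self._loader = loader
        self._model_id: Optional[str] = None
        self._pipe: Optional[object] = None
        self.load_count = 0  # number of (expensive) loads performed so far

    def get(self, label: str) -> object:
        model_id = MODEL_IDS[label]
        if model_id != self._model_id:
            # Selection changed (or first use): load the requested model.
            self._pipe = self._loader(model_id)
            self._model_id = model_id
            self.load_count += 1
        return self._pipe


def fake_loader(model_id: str) -> str:
    # Stub loader (assumption): returns a label instead of a real pipeline.
    return f"pipeline<{model_id}>"


switcher = ModelSwitcher(fake_loader)
switcher.get("SD Turbo")     # first use: loads
switcher.get("SD Turbo")     # same selection: reuses the loaded pipeline
switcher.get("SD Turbo XL")  # selection changed: reloads
print(switcher.load_count)   # prints 2
```

One design difference from the app: this sketch raises `KeyError` for an unknown label, whereas `app.py` treats any value other than "SD Turbo" as SDXL Turbo.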


Enter the desired prompt and click "Run" to generate an image:
```
Sample Prompt: A professional photo of a cat eating cake
```

Two sample images will be generated:
![image1](assets/t2i.png)



## External Links
- [SDXL Turbo HuggingFace Repo](https://huggingface.co/stabilityai/sdxl-turbo)
- [SD Turbo HuggingFace Repo](https://huggingface.co/stabilityai/sd-turbo)


## Model License
The models are intended for both non-commercial and commercial usage under the following licenses: [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo/blob/main/LICENSE.md), [SD Turbo](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE.md).

For commercial use, please refer to https://stability.ai/license.
99 changes: 99 additions & 0 deletions PyTorch/diffusion/sd/app.py
@@ -0,0 +1,99 @@
import torch
import torch_directml
import gradio as gr
from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline, LMSDiscreteScheduler
from PIL import Image
import numpy as np

def preprocess(image):
    """Convert a PIL image to a float tensor in [-1, 1], shaped (1, C, H, W)."""
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # round down to an integer multiple of 32
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with a batch axis
    image = torch.from_numpy(image)
    return 2. * image - 1.

lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)

device = torch_directml.device(torch_directml.default_device())

block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
num_samples = 2

def load_model(model_name):
    # Load the fp16 weights and move the pipeline to the DirectML device.
    return AutoPipelineForText2Image.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        variant="fp16"
    ).to(device)

model_name = "stabilityai/sd-turbo"
pipe = load_model("stabilityai/sd-turbo")

def infer(prompt, inference_step, model_selector):
    global model_name, pipe

    # Reload the pipeline only when the selected model differs from the loaded one.
    if model_selector == "SD Turbo":
        if model_name != "stabilityai/sd-turbo":
            model_name = "stabilityai/sd-turbo"
            pipe = load_model("stabilityai/sd-turbo")
    else:
        if model_name != "stabilityai/sdxl-turbo":
            model_name = "stabilityai/sdxl-turbo"
            pipe = load_model("stabilityai/sdxl-turbo")

    # Generate num_samples images for the same prompt; guidance is disabled (0.0).
    images = pipe(prompt=[prompt] * num_samples, num_inference_steps=inference_step, guidance_scale=0.0)[0]
    return images


with block as demo:
    gr.Markdown("<h1><center>Stable Diffusion Turbo and XL Turbo with DirectML Backend</center></h1>")

    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):

                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                iteration_slider = gr.Slider(
                    label="Steps",
                    step=1,
                    maximum=4,
                    minimum=1,
                    value=1
                )

                model_selector = gr.Dropdown(
                    ["SD Turbo", "SD Turbo XL"], label="Model", info="Select the SD model to use", value="SD Turbo"
                )

        gallery = gr.Gallery(label="Generated images", show_label=False).style(
            grid=[2], height="auto"
        )
        text.submit(infer, inputs=[text, iteration_slider, model_selector], outputs=gallery)
        btn.click(infer, inputs=[text, iteration_slider, model_selector], outputs=gallery)

    gr.Markdown(
        """___
<p style='text-align: center'>
Created by CompVis and Stability AI
<br/>
</p>"""
    )

demo.launch(debug=True)
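
The `preprocess` helper in `app.py` is not called by the text-to-image path, but its arithmetic is easy to check in isolation: image dimensions are rounded down to a multiple of 32, and 8-bit pixel values are mapped from [0, 255] to [-1, 1]. A minimal pure-Python sketch follows; the function names `snap_down` and `to_signed_unit` are hypothetical and chosen only for illustration.

```python
def snap_down(size, multiple=32):
    """Round each dimension down to the nearest multiple (as in preprocess)."""
    return tuple(x - x % multiple for x in size)


def to_signed_unit(pixel):
    """Map an 8-bit pixel value in [0, 255] to the range [-1, 1]."""
    return 2.0 * (pixel / 255.0) - 1.0


print(snap_down((515, 770)))                   # prints (512, 768)
print(to_signed_unit(0), to_signed_unit(255))  # prints -1.0 1.0
```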
Binary file added PyTorch/diffusion/sd/assets/slider_dropdown.png
Binary file added PyTorch/diffusion/sd/assets/t2i.png
Binary file added PyTorch/diffusion/sd/assets/tmp31loaztf.png
Binary file added PyTorch/diffusion/sd/assets/tmpl0cj8qg9.png
6 changes: 6 additions & 0 deletions PyTorch/diffusion/sd/requirements.txt
@@ -0,0 +1,6 @@
diffusers==0.29.2
gradio==3.13.2
numpy==1.26.4
Pillow==10.4.0
scipy==1.14.0
transformers==4.42.3
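
Because the app relies on APIs tied to these pinned versions (for example the `.style()` calls in Gradio 3.x), it can be useful to check what is actually installed. The following is a rough sketch using only the standard library; `check_pins` is a hypothetical helper, and the `PINS` table simply mirrors the requirements above.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from requirements.txt.
PINS = {
    "diffusers": "0.29.2",
    "gradio": "3.13.2",
    "numpy": "1.26.4",
    "Pillow": "10.4.0",
    "scipy": "1.14.0",
    "transformers": "4.42.3",
}


def check_pins(pins):
    """Map each package to 'ok', 'missing', or 'found X, expected Y'."""
    report = {}
    for name, expected in pins.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        if found == expected:
            report[name] = "ok"
        else:
            report[name] = f"found {found}, expected {expected}"
    return report


for name, status in check_pins(PINS).items():
    print(f"{name}: {status}")
```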
