Commit 1d738f7

smk2007, Sheil Kumar, and fdwr authored
Update Torch-DirectML samples and docs for Torch-DirectML 2.3.0 (#610)
* Update Torch-DirectML samples and docs for torch-directml 2.3.0
* Update PyTorch/README.md
  Co-authored-by: Dwayne Robinson <[email protected]>
* Update PyTorch/diffusion/sd/README.md
  Co-authored-by: Dwayne Robinson <[email protected]>
* Update PyTorch/diffusion/sd/README.md
  Co-authored-by: Dwayne Robinson <[email protected]>
* Update PyTorch/diffusion/sd/README.md
  Co-authored-by: Dwayne Robinson <[email protected]>
* Update PyTorch/diffusion/sd/app.py
  Co-authored-by: Dwayne Robinson <[email protected]>
* Update PyTorch/diffusion/sd/app.py
  Co-authored-by: Dwayne Robinson <[email protected]>

---------

Co-authored-by: Sheil Kumar <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
1 parent 61a1a50 commit 1d738f7

9 files changed: +165 -2 lines changed

PyTorch/README.md

Lines changed: 2 additions & 2 deletions
@@ -15,16 +15,16 @@ pip install torch-directml
 ```
 
 ## Samples
-For `torch-directml` samples find brief summaries below or explore the [cv](./cv/), [transformer](./transformer/) or [llm](./llm/) folders:
+Try the `torch-directml` samples below, or explore the [cv](./cv/), [transformer](./transformer/), [llm](./llm/) and [diffusion](./diffusion/) folders:
 * [attention is all you need - the original transformer model](./transformer/attention_is_all_you_need/)
 * [yolov3 - a real-time object detection model](./cv/yolov3/)
 * [squeezenet - a small image classification model](./cv/squeezenet)
 * [resnet50 - an image classification model](./cv/resnet50)
 * [maskrcnn - an object detection model](./cv/objectDetection/maskrcnn/)
 * [llm - a text generation and chatbot app supporting various language models](./llm/)
 * [whisper - a general-purpose speech recognition model](./audio/whisper/)
+* [Stable Diffusion Turbo & XL Turbo - a text-to-image generation model](./diffusion/sd/)
 
 ## External Links
-
 * [torch-directml PyPI project](https://pypi.org/project/torch-directml/)
 * [PyTorch homepage](https://pytorch.org/)

PyTorch/diffusion/sd/.DS_Store

6 KB
Binary file not shown.

PyTorch/diffusion/sd/README.md

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# Stable Diffusion Turbo & XL Turbo
+This sample provides a simple way to load and run Stability AI's text-to-image generation models, Stable Diffusion Turbo & XL Turbo, with the DirectML backend.
+
+- [About the Models](#about-the-models)
+- [Setup](#setup)
+- [Run the App](#run-the-app)
+- [External Links](#external-links)
+- [Model License](#model-license)
+
+
+## About the Models
+
+Stable Diffusion Turbo & XL Turbo are distilled versions of SD 2.1 and SDXL 1.0 respectively. Both are fast generative text-to-image models that can synthesize photorealistic images from a text prompt in a single network evaluation.
+
+Refer to the HuggingFace repositories for [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo) and [SD Turbo](https://huggingface.co/stabilityai/sd-turbo) for more information.
+
+
+## Setup
+Once you've set up `torch-directml` following the [Windows](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows) or [WSL](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-wsl) guidance, install the requirements by running:
+
+
+```
+pip install -r requirements.txt
+```
+
+
+## Run the App
+To use Stable Diffusion with the text-to-image interface, run:
+```bash
+python app.py
+```
+
+Running the app prints a local URL to the console. Open http://localhost:7860 (or the local URL you see) in a browser to interact with the text-to-image interface.
+
+Within the interface, use the dropdown to switch between SD Turbo and SDXL Turbo. You can also use the slider to set the number of iteration steps (1 to 4) for image generation.
+
+![slider_dropdown](assets/slider_dropdown.png)
+
+
+Enter the desired prompt and click "Run" to generate an image:
+```
+Sample Prompt: A professional photo of a cat eating cake
+```
+
+Two sample images will be generated:
+![image1](assets/t2i.png)
+
+
+
+## External Links
+- [SDXL Turbo HuggingFace Repo](https://huggingface.co/stabilityai/sdxl-turbo)
+- [SD Turbo HuggingFace Repo](https://huggingface.co/stabilityai/sd-turbo)
+
+
+## Model License
+The models are intended for both non-commercial and commercial usage under the following licenses: [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo/blob/main/LICENSE.md), [SD Turbo](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE.md).
+
+For commercial use, please refer to https://stability.ai/license.
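
The README drives generation through the Gradio UI. As a rough sketch of how the same call might look from a plain script, the snippet below mirrors the arguments `app.py` passes to the pipeline (`guidance_scale=0.0`, two samples per prompt) and clamps the step count to the 1 to 4 range the slider exposes. The names `MODEL_IDS`, `clamp_steps`, `build_call_kwargs`, and `generate` are hypothetical helpers introduced for illustration and are not part of the sample; the heavy imports are deferred into `generate` so the helpers can be exercised without a GPU or a model download.

```python
# Hypothetical label-to-repository mapping, using the labels from the README text.
MODEL_IDS = {
    "SD Turbo": "stabilityai/sd-turbo",
    "SDXL Turbo": "stabilityai/sdxl-turbo",
}


def clamp_steps(steps, lo=1, hi=4):
    """Clamp the step count to the 1-4 range offered by the UI slider."""
    return max(lo, min(hi, int(steps)))


def build_call_kwargs(prompt, steps=1, num_samples=2):
    """Build the keyword arguments that app.py passes to the pipeline."""
    return {
        "prompt": [prompt] * num_samples,   # one copy of the prompt per image
        "num_inference_steps": clamp_steps(steps),
        "guidance_scale": 0.0,              # same value app.py uses
    }


def generate(prompt, model_label="SD Turbo", steps=1):
    """Load a pipeline on the default DirectML device and return PIL images."""
    # Imported lazily: these need torch-directml and diffusers installed.
    import torch
    import torch_directml
    from diffusers import AutoPipelineForText2Image

    device = torch_directml.device(torch_directml.default_device())
    pipe = AutoPipelineForText2Image.from_pretrained(
        MODEL_IDS[model_label], torch_dtype=torch.float16, variant="fp16"
    ).to(device)
    return pipe(**build_call_kwargs(prompt, steps)).images
```

With the requirements installed, `generate("A professional photo of a cat eating cake")[0].save("cat.png")` would be one way to write the first image to disk.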

PyTorch/diffusion/sd/app.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+import torch
+import torch_directml
+import gradio as gr
+from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline, LMSDiscreteScheduler
+from PIL import Image
+import numpy as np
+
+def preprocess(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # round down to an integer multiple of 32
+    image = image.resize((w, h), resample=Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with a batch axis
+    image = torch.from_numpy(image)
+    return 2. * image - 1.  # rescale [0, 1] to [-1, 1]
+
+lms = LMSDiscreteScheduler(
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear"
+)
+
+device = torch_directml.device(torch_directml.default_device())
+
+block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
+num_samples = 2
+
+def load_model(model_name):
+    return AutoPipelineForText2Image.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        variant="fp16"
+    ).to(device)
+
+model_name = "stabilityai/sd-turbo"
+pipe = load_model("stabilityai/sd-turbo")
+
+def infer(prompt, inference_step, model_selector):
+    global model_name, pipe
+
+    if model_selector == "SD Turbo":
+        if model_name != "stabilityai/sd-turbo":
+            model_name = "stabilityai/sd-turbo"
+            pipe = load_model("stabilityai/sd-turbo")
+    else:
+        if model_name != "stabilityai/sdxl-turbo":
+            model_name = "stabilityai/sdxl-turbo"
+            pipe = load_model("stabilityai/sdxl-turbo")
+
+    images = pipe(prompt=[prompt] * num_samples, num_inference_steps=inference_step, guidance_scale=0.0)[0]
+    return images
+
+
+with block as demo:
+    gr.Markdown("<h1><center>Stable Diffusion Turbo and XL Turbo with DirectML Backend</center></h1>")
+
+    with gr.Group():
+        with gr.Box():
+            with gr.Row().style(mobile_collapse=False, equal_height=True):
+
+                text = gr.Textbox(
+                    label="Enter your prompt", show_label=False, max_lines=1
+                ).style(
+                    border=(True, False, True, True),
+                    rounded=(True, False, False, True),
+                    container=False,
+                )
+                btn = gr.Button("Run").style(
+                    margin=False,
+                    rounded=(False, True, True, False),
+                )
+            with gr.Row().style(mobile_collapse=False, equal_height=True):
+                iteration_slider = gr.Slider(
+                    label="Steps",
+                    step=1,
+                    maximum=4,
+                    minimum=1,
+                    value=1
+                )
+
+                model_selector = gr.Dropdown(
+                    ["SD Turbo", "SD Turbo XL"], label="Model", info="Select the SD model to use", value="SD Turbo"
+                )
+
+        gallery = gr.Gallery(label="Generated images", show_label=False).style(
+            grid=[2], height="auto"
+        )
+        text.submit(infer, inputs=[text, iteration_slider, model_selector], outputs=gallery)
+        btn.click(infer, inputs=[text, iteration_slider, model_selector], outputs=gallery)
+
+    gr.Markdown(
+        """___
+<p style='text-align: center'>
+Created by CompVis and Stability AI
+<br/>
+</p>"""
+    )
+
+demo.launch(debug=True)
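
The `infer` function above keeps the current pipeline in module-level globals and reloads only when the dropdown selection changes. One possible way to express the same idea without globals is sketched below. `PipelineCache`, `resolve_model`, and `fake_loader` are hypothetical names introduced for illustration; the stand-in loader returns a string so the logic can be checked without loading a real model, and in the app the loader would be `load_model`.

```python
SD_TURBO = "stabilityai/sd-turbo"
SDXL_TURBO = "stabilityai/sdxl-turbo"


def resolve_model(label):
    """Mirror infer(): "SD Turbo" selects SD Turbo, any other label selects SDXL Turbo."""
    return SD_TURBO if label == "SD Turbo" else SDXL_TURBO


class PipelineCache:
    """Hold one loaded pipeline and reload only when the model name changes."""

    def __init__(self, loader):
        self._loader = loader   # callable: model name -> pipeline
        self._name = None
        self._pipe = None

    def get(self, model_name):
        if model_name != self._name:
            self._pipe = self._loader(model_name)
            self._name = model_name
        return self._pipe


# Stand-in loader (assumption): records each load instead of downloading weights.
load_calls = []


def fake_loader(name):
    load_calls.append(name)
    return f"pipeline<{name}>"


cache = PipelineCache(fake_loader)
cache.get(resolve_model("SD Turbo"))
cache.get(resolve_model("SD Turbo"))       # same model: no reload
cache.get(resolve_model("SD Turbo XL"))    # different model: one reload
print(load_calls)  # ['stabilityai/sd-turbo', 'stabilityai/sdxl-turbo']
```

Like the original, this keeps only one pipeline resident at a time, so switching models pays the full load cost but does not hold both sets of weights in device memory.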
25.3 KB

PyTorch/diffusion/sd/assets/t2i.png

1.09 MB
348 KB
380 KB

PyTorch/diffusion/sd/requirements.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+diffusers==0.29.2
+gradio==3.13.2
+numpy==1.26.4
+Pillow==10.4.0
+scipy==1.14.0
+transformers==4.42.3
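
Because the requirements are pinned exactly, a mismatched environment is a plausible source of errors when running the app. The following is a minimal sketch, using only the standard library's `importlib.metadata`, of how the pins might be compared against what is installed. `parse_pins` and `find_mismatches` are hypothetical helper names, and the parser assumes every line has the simple `name==version` form used above.

```python
from importlib import metadata

REQUIREMENTS = """\
diffusers==0.29.2
gradio==3.13.2
numpy==1.26.4
Pillow==10.4.0
scipy==1.14.0
transformers==4.42.3
"""


def parse_pins(text):
    """Parse `name==version` lines into a dict, skipping blanks and comments."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, version = line.partition("==")
        pins[name] = version
    return pins


def find_mismatches(pins, get_version=metadata.version):
    """Return {name: (wanted, installed)} for packages that differ or are missing."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, installed) in find_mismatches(parse_pins(REQUIREMENTS)).items():
        print(f"{name}: wanted {wanted}, found {installed or 'not installed'}")
```

The `get_version` parameter exists only so the check can be tested with a stand-in lookup; by default it queries the active environment.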
