
Commit 0e82a75

Implementation of Pipeline and TensorRT
1 parent cee1633

20 files changed, +1912 -292 lines

.gitignore

+3
@@ -1,5 +1,8 @@
 # https://github.com/github/gitignore/blob/main/Python.gitignore
 
+./model.safetensors
+./model.ckpt
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

examples/benchmark/main.py

+82
@@ -0,0 +1,82 @@
import io
from typing import *

import fire
import PIL.Image
import requests
import torch
from diffusers import AutoencoderTiny, LCMScheduler, StableDiffusionPipeline
from tqdm import tqdm

from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import pil2tensor, postprocess_image


def download_image(url: str):
    response = requests.get(url)
    image = PIL.Image.open(io.BytesIO(response.content))
    return image


def run(
    warmup: int = 10, iterations: int = 50, acceleration: Optional[Literal["xformers", "sfast", "tensorrt"]] = None
):
    pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file("./model.safetensors").to(
        device=torch.device("cuda"),
        dtype=torch.float16,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
    pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    pipe.fuse_lora()

    stream = StreamDiffusion(
        pipe,
        [32, 45],
        torch_dtype=torch.float16,
    )

    if acceleration == "xformers":
        pipe.enable_xformers_memory_efficient_attention()
    elif acceleration == "tensorrt":
        from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt

        stream = accelerate_with_tensorrt(stream)
    elif acceleration == "sfast":
        from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast

        stream = accelerate_with_stable_fast(stream)

    stream.prepare(
        "Girl with panda ears wearing a hood",
        num_inference_steps=50,
        generator=torch.manual_seed(2),
    )

    image = download_image("https://github.com/ddpn08.png").resize((512, 512))
    input_tensor = pil2tensor(image)

    # Warm up the pipeline so compilation and caching do not skew the timings.
    for _ in range(warmup):
        stream(input_tensor.detach().clone().to(device=stream.device, dtype=stream.dtype))

    results = []

    for _ in tqdm(range(iterations)):
        # Time one full frame with CUDA events; postprocessing to PIL is
        # deliberately included in the measured time.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        x_output = stream(input_tensor.detach().clone().to(device=stream.device, dtype=stream.dtype))
        postprocess_image(x_output, output_type="pil")[0]
        end.record()

        torch.cuda.synchronize()
        results.append(start.elapsed_time(end))

    print(f"Average time: {sum(results) / len(results)}ms")
    print(f"Average FPS: {1000 / (sum(results) / len(results))}")


if __name__ == "__main__":
    fire.Fire(run)

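Since run() is wrapped with fire.Fire, its parameters map directly to CLI flags. A typical invocation, assuming a checkpoint has been saved as ./model.safetensors in the working directory (the flag values here are illustrative):

    python examples/benchmark/main.py --iterations 100 --acceleration tensorrt
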
examples/img2img/main.py

+53
@@ -0,0 +1,53 @@
from typing import *

import fire
import PIL.Image
import torch
from diffusers import AutoencoderTiny, LCMScheduler, StableDiffusionPipeline

from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import pil2tensor, postprocess_image


def main(input: str, output: str, scale: int = 1):
    pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file("./model.safetensors").to(
        device=torch.device("cuda"),
        dtype=torch.float16,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
    pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    pipe.fuse_lora()
    pipe.enable_xformers_memory_efficient_attention()

    input_image = PIL.Image.open(input)
    width = int(input_image.width * scale)
    height = int(input_image.height * scale)

    stream = StreamDiffusion(
        pipe,
        [35, 45],
        torch_dtype=torch.float16,
        width=width,
        height=height,
    )
    stream.prepare(
        "Girl with panda ears wearing a hood",
        num_inference_steps=50,
        generator=torch.manual_seed(2),
    )

    input_image = input_image.resize((width, height))
    input_tensor = pil2tensor(input_image)

    # The stream is a pipelined batch: feed batch_size - 1 frames first so the
    # denoising queue is full before reading the first real output.
    for _ in range(stream.batch_size - 1):
        stream(input_tensor.detach().clone().to(device=stream.device, dtype=stream.dtype))

    output_x = stream(input_tensor.detach().clone().to(device=stream.device, dtype=stream.dtype))
    output_image = postprocess_image(output_x, output_type="pil")[0]
    output_image.save(output)


if __name__ == "__main__":
    fire.Fire(main)

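The batch_size - 1 loop in the img2img example reflects the pipelined batching this commit introduces: each call feeds one frame into the denoising queue and returns the output of a frame fed batch_size - 1 calls earlier, so the queue must be filled before the first real output can be read. A toy sketch of that latency pattern (illustrative only, not part of the commit; depth stands in for batch_size):

import collections


def pipelined(frames, depth):
    # Output lags input by depth - 1 calls, like stream() above.
    queue = collections.deque([None] * (depth - 1))
    for frame in frames:
        queue.append(frame)
        yield queue.popleft()


# Feeding depth - 1 priming frames first makes the next call return a real
# result, which is exactly what the loop in main() does.
print(list(pipelined(["f1", "f2", "f3"], depth=2)))  # [None, 'f1', 'f2']
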
examples/mov2mov/main.py

+85
@@ -0,0 +1,85 @@
import os
from typing import *

import ffmpeg
import fire
import PIL.Image
import torch
from diffusers import AutoencoderTiny, LCMScheduler, StableDiffusionPipeline
from tqdm import tqdm

from streamdiffusion import StreamDiffusion
from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast
from streamdiffusion.image_utils import pil2tensor, postprocess_image


def extract_frames(video_path: str, output_dir: str):
    os.makedirs(output_dir, exist_ok=True)
    ffmpeg.input(video_path).output(f"{output_dir}/%04d.png").run()


def get_frame_rate(video_path: str):
    probe = ffmpeg.probe(video_path)
    video_info = next(s for s in probe["streams"] if s["codec_type"] == "video")
    # r_frame_rate is a fraction such as "30000/1001", so divide instead of
    # truncating to the numerator.
    num, den = video_info["r_frame_rate"].split("/")
    return int(num) / int(den)


def main(input: str, output: str, scale: int = 1):
    if os.path.isdir(output):
        raise ValueError("Output directory already exists")
    frame_rate = get_frame_rate(input)
    extract_frames(input, os.path.join(output, "frames"))
    images = sorted(os.listdir(os.path.join(output, "frames")))

    pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file("./model.safetensors").to(
        device=torch.device("cuda"),
        dtype=torch.float16,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
    pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    pipe.fuse_lora()

    sample_image = PIL.Image.open(os.path.join(output, "frames", images[0]))
    width = int(sample_image.width * scale)
    height = int(sample_image.height * scale)

    stream = StreamDiffusion(
        pipe,
        [40, 49],
        torch_dtype=torch.float16,
        width=width,
        height=height,
    )
    stream = accelerate_with_stable_fast(stream)
    stream.prepare(
        "Girl with panda ears wearing a hood",
        num_inference_steps=50,
        generator=torch.manual_seed(2),
    )

    # Prime the pipelined batch with copies of the first frame so the
    # denoising queue is full before the real frames go in.
    for _ in range(stream.batch_size - 1):
        stream(
            pil2tensor(sample_image.resize((width, height)))
            .detach()
            .clone()
            .to(device=stream.device, dtype=stream.dtype)
        )

    # Pad the tail with batch_size - 1 extra frames to flush the queue. Each
    # output corresponds to the frame fed batch_size - 1 calls earlier, so
    # save it under that earlier frame's name and drop the warm-up outputs.
    frame_paths = images + [images[0]] * (stream.batch_size - 1)
    for i, image_path in enumerate(tqdm(frame_paths)):
        pil_image = PIL.Image.open(os.path.join(output, "frames", image_path))
        pil_image = pil_image.resize((width, height))
        input_tensor = pil2tensor(pil_image)
        output_x = stream(input_tensor.detach().clone().to(device=stream.device, dtype=stream.dtype))
        if i >= stream.batch_size - 1:
            output_image = postprocess_image(output_x, output_type="pil")[0]
            output_image.save(os.path.join(output, frame_paths[i - (stream.batch_size - 1)]))

    output_video_path = os.path.join(output, "output.mp4")

    ffmpeg.input(os.path.join(output, "%04d.png"), framerate=frame_rate).output(
        output_video_path, crf=17, pix_fmt="yuv420p", vcodec="libx264"
    ).run()


if __name__ == "__main__":
    fire.Fire(main)

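As with the other examples, the fire.Fire entry point turns the function parameters into flags; an illustrative run such as `python examples/mov2mov/main.py --input input.mp4 --output ./out` extracts the frames with ffmpeg, translates them one by one, and reassembles the video at the source frame rate.
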
examples/screen/main.py

+109
@@ -0,0 +1,109 @@
import io
import multiprocessing as mp
import threading
import time
from time import sleep
from typing import *

import fire
import mss
import PIL.Image
import torch
from diffusers import AutoencoderTiny, LCMScheduler, StableDiffusionPipeline
from matplotlib import pyplot as plt
from socks import UDP, receive_udp_data

from streamdiffusion import StreamDiffusion
from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt
from streamdiffusion.image_utils import pil2tensor, postprocess_image


# Latest captured frame, shared between the capture thread and the main loop.
input = None


def screen(
    height: int = 512,
    width: int = 512,
    monitor: Dict[str, int] = {"top": 300, "left": 200, "width": 512, "height": 512},
):
    global input
    with mss.mss() as sct:
        while True:
            img = sct.grab(monitor)
            img = PIL.Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
            # PIL's resize takes (width, height) and returns a new image.
            img = img.resize((width, height))
            input = pil2tensor(img)


def result_window(server_ip: str, server_port: int):
    plt.ion()
    fig, ax = plt.subplots(figsize=(8, 8))

    while True:
        received_data = receive_udp_data(server_ip, server_port)
        images = PIL.Image.open(io.BytesIO(received_data))
        ax.clear()
        ax.imshow(images)
        ax.axis("off")
        plt.pause(0.00001)


def run(address: str = "127.0.0.1", port: int = 8080):
    pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file("./model.safetensors").to(
        device=torch.device("cuda")
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
    pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    pipe.fuse_lora()

    stream = StreamDiffusion(
        pipe,
        [32, 45],
    )
    stream = accelerate_with_tensorrt(stream, "./engines", max_batch_size=2)
    stream.prepare(
        "Girl with panda ears wearing a hood",
        num_inference_steps=50,
        generator=torch.manual_seed(2),
    )

    output_window = mp.Process(target=result_window, args=(address, port))
    input_screen = threading.Thread(target=screen)

    output_window.start()
    print("Waiting for output window to start...")
    time.sleep(5)
    input_screen.start()

    udp = UDP(address, port)

    # Exponential moving average of the per-frame time, for a stable FPS readout.
    main_thread_time_cumulative = 0
    lowpass_alpha = 0.1

    while True:
        if input is None:
            sleep(0.01)
            continue

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()

        x_output = stream(input.to(device=stream.device, dtype=stream.dtype))
        output_images = postprocess_image(x_output, output_type="pil")[0]

        udp.send_udp_data(output_images)
        end.record()
        torch.cuda.synchronize()
        main_thread_time = start.elapsed_time(end) / 1000
        main_thread_time_cumulative = (
            lowpass_alpha * main_thread_time + (1 - lowpass_alpha) * main_thread_time_cumulative
        )
        fps = 1 / main_thread_time_cumulative
        print(f"fps: {fps}, main_thread_time: {main_thread_time_cumulative}")


if __name__ == "__main__":
    fire.Fire(run)

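To try the screen example, install the requirements below and run main.py from examples/screen so that socks.py is importable; an illustrative invocation is `python examples/screen/main.py --address 127.0.0.1 --port 8080`. The accelerate_with_tensorrt call is pointed at ./engines, presumably where the TensorRT engine files are kept, so the first run can take considerably longer while they are built.
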
examples/screen/requirements.txt

+3
@@ -0,0 +1,3 @@
matplotlib
pillow
mss

examples/screen/socks.py

+27
@@ -0,0 +1,27 @@
import io
import socket
from typing import *


class UDP:
    def __init__(self, ip, port):
        self.ip = ip
        self.port = port
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def __del__(self):
        self.sock.close()

    def send_udp_data(self, images):
        # JPEG-encode the PIL image and send it as a single datagram;
        # the encoded frame therefore has to fit in one UDP packet.
        img_byte_arr = io.BytesIO()
        images.save(img_byte_arr, format="JPEG")
        img_byte_arr = img_byte_arr.getvalue()
        self.sock.sendto(img_byte_arr, (self.ip, self.port))


def receive_udp_data(ip, port):
    # One socket per frame: bind, receive a single datagram, close.
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind((ip, port))
    data, addr = sock.recvfrom(65535)  # 65535 is the maximum UDP packet size
    sock.close()
    return data

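A minimal round-trip sketch for these helpers (a hypothetical demo, not part of the commit): start a receiver thread, send one JPEG-encoded frame over localhost UDP, and decode it back.

import io
import threading
import time

import PIL.Image

from socks import UDP, receive_udp_data


def main():
    ip, port = "127.0.0.1", 9999
    received = {}

    def rx():
        received["data"] = receive_udp_data(ip, port)

    t = threading.Thread(target=rx)
    t.start()
    time.sleep(0.2)  # give the receiver time to bind before sending
    UDP(ip, port).send_udp_data(PIL.Image.new("RGB", (64, 64), "red"))
    t.join()
    print(PIL.Image.open(io.BytesIO(received["data"])).size)  # (64, 64)


if __name__ == "__main__":
    main()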