Skip to content

Commit

Permalink
refactor: split pipeline and inference
Browse files Browse the repository at this point in the history
  • Loading branch information
fboulnois committed Oct 26, 2022
1 parent debace7 commit d9689c4
Showing 1 changed file with 29 additions and 35 deletions.
64 changes: 29 additions & 35 deletions docker-entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from diffusers import StableDiffusionPipeline


def cuda_device():
    """Return the identifier of the accelerator device used throughout."""
    device_name = "cuda"
    return device_name


def iso_date_time():
    """Current local time as an ISO-8601 string, used for log timestamps."""
    now = datetime.datetime.now()
    return now.isoformat()

Expand All @@ -13,45 +17,42 @@ def skip_safety_checker(images, *args, **kwargs):
return images, False


def stable_diffusion(
model,
prompt,
samples,
iters,
height,
width,
steps,
scale,
seed,
half,
skip,
do_slice,
token,
):
device = "cuda"
def stable_diffusion_pipeline(model, half, skip, do_slice, token):
    """Build and configure a StableDiffusionPipeline on the CUDA device.

    Args:
        model: Hugging Face model id or local path passed to ``from_pretrained``.
        half: if truthy, load fp16 weights (``revision="fp16"``); otherwise fp32/main.
        skip: if truthy, replace the safety checker with the pass-through
            ``skip_safety_checker`` so no images are filtered.
        do_slice: if truthy, enable attention slicing to reduce VRAM usage.
        token: Hugging Face auth token; when None it is read from ``token.txt``
            in the current working directory.

    Returns:
        The configured pipeline, already moved to the CUDA device.
    """
    # Fall back to a token file on disk when none was supplied on the CLI.
    if token is None:
        with open("token.txt") as f:
            token = f.read().replace("\n", "")

    # fp16 weights live on the "fp16" branch of the model repo; fp32 on "main".
    dtype, rev = (torch.float16, "fp16") if half else (torch.float32, "main")

    print("load pipeline start:", iso_date_time())

    pipeline = StableDiffusionPipeline.from_pretrained(
        model, torch_dtype=dtype, revision=rev, use_auth_token=token
    ).to(cuda_device())

    if skip:
        # Pass-through checker: returns images unmodified, flags nothing as NSFW.
        pipeline.safety_checker = skip_safety_checker

    if do_slice:
        pipeline.enable_attention_slicing()

    print("loaded models after:", iso_date_time())

    return pipeline


def stable_diffusion_inference(
pipeline, prompt, samples, iters, height, width, steps, scale, seed
):
if seed == 0:
seed = torch.random.seed()

prefix = prompt.replace(" ", "_")[:170]

generator = torch.Generator(device=device).manual_seed(seed)
generator = torch.Generator(device=cuda_device()).manual_seed(seed)
for j in range(iters):
with autocast(device):
result = pipe(
with autocast(cuda_device()):
result = pipeline(
[prompt] * samples,
height=height,
width=width,
Expand Down Expand Up @@ -154,15 +155,12 @@ def main():
if args.prompt0 is not None:
args.prompt = args.prompt0

if args.seed == 0:
args.seed = torch.random.seed()

if args.token is None:
with open("token.txt") as f:
args.token = f.read().replace("\n", "")
pipeline = stable_diffusion_pipeline(
args.model, args.half, args.skip, args.attention_slicing, args.token
)

stable_diffusion(
args.model,
stable_diffusion_inference(
pipeline,
args.prompt,
args.n_samples,
args.n_iter,
Expand All @@ -171,10 +169,6 @@ def main():
args.ddim_steps,
args.scale,
args.seed,
args.half,
args.skip,
args.attention_slicing,
args.token,
)


Expand Down

0 comments on commit d9689c4

Please sign in to comment.