Skip to content

Commit

Permalink
Fix ruff errors
Browse files Browse the repository at this point in the history
  • Loading branch information
yorickvP committed Sep 27, 2024
1 parent e718818 commit 461db42
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional

import torch
import torch._dynamo as dynamo

torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
Expand Down Expand Up @@ -183,7 +184,7 @@ def compile_ae(self):
# torch.compile has to recompile if it makes invalid assumptions
# about the input sizes. Having higher input sizes first makes
# for fewer recompiles.
VAE_SIZES = [
vae_sizes = [
[1, 16, 192, 168],
[1, 16, 96, 96],
[1, 16, 96, 168],
Expand All @@ -202,19 +203,19 @@ def compile_ae(self):
]
print("compiling AE")
st = time.time()
device = torch.device('cuda')
device = torch.device("cuda")
if self.offload:
self.ae.decoder.to(device)

self.ae.decoder = torch.compile(self.ae.decoder)

# actual compilation happens when you give it inputs
for f in VAE_SIZES:
for f in vae_sizes:
print("Compiling AE for size", f)
x = torch.rand(f, device=device)
torch._dynamo.mark_dynamic(x, 0, min=1, max=4)
torch._dynamo.mark_dynamic(x, 2, min=80)
torch._dynamo.mark_dynamic(x, 3, min=80)
dynamo.mark_dynamic(x, 0, min=1, max=4)
dynamo.mark_dynamic(x, 2, min=80)
dynamo.mark_dynamic(x, 3, min=80)
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16, cache_enabled=False
):
Expand Down

0 comments on commit 461db42

Please sign in to comment.