Skip to content

Commit

Permalink
Make a separate download function
Browse files Browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Aug 3, 2023
1 parent 131f30e commit 655b596
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions diffusion/inference/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
LOCAL_CHECKPOINT_PATH = '/tmp/model.pt'


def download_checkpoint(chkpt_path: str):
"""Downloads the Stable Diffusion checkpoint to the local filesystem.
Args:
chkpt_path (str): The path to the local folder, URL or object score that contains the checkpoint.
"""
get_file(path=chkpt_path, destination=LOCAL_CHECKPOINT_PATH)


class StableDiffusionInference():
"""Inference endpoint class for Stable Diffusion.
Expand All @@ -26,13 +35,11 @@ class StableDiffusionInference():
Default: ``None``.
"""

def __init__(self, chkpt_path: Optional[str] = None):
pretrained_flag = chkpt_path is None
def __init__(self, pretrained: bool = False):
self.device = torch.cuda.current_device()

model = stable_diffusion_2(pretrained=pretrained_flag, encode_latents_in_fp16=True, fsdp=False)
if not pretrained_flag:
get_file(path=chkpt_path, destination=LOCAL_CHECKPOINT_PATH)
model = stable_diffusion_2(pretrained=pretrained, encode_latents_in_fp16=True, fsdp=False)
if not pretrained:
state_dict = torch.load(LOCAL_CHECKPOINT_PATH)
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
Expand Down

0 comments on commit 655b596

Please sign in to comment.