Skip to content

Commit

Permalink
Update deployment code to use explicit downloader (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Aug 9, 2023
1 parent 131f30e commit 41f49c3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
19 changes: 13 additions & 6 deletions diffusion/inference/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import base64
import io
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

import torch
from composer.utils.file_helpers import get_file
Expand All @@ -17,6 +17,15 @@
LOCAL_CHECKPOINT_PATH = '/tmp/model.pt'


def download_checkpoint(chkpt_path: str):
"""Downloads the Stable Diffusion checkpoint to the local filesystem.
Args:
chkpt_path (str): The path to the local folder, URL or object score that contains the checkpoint.
"""
get_file(path=chkpt_path, destination=LOCAL_CHECKPOINT_PATH)


class StableDiffusionInference():
"""Inference endpoint class for Stable Diffusion.
Expand All @@ -26,13 +35,11 @@ class StableDiffusionInference():
Default: ``None``.
"""

def __init__(self, chkpt_path: Optional[str] = None):
pretrained_flag = chkpt_path is None
def __init__(self, pretrained: bool = False):
self.device = torch.cuda.current_device()

model = stable_diffusion_2(pretrained=pretrained_flag, encode_latents_in_fp16=True, fsdp=False)
if not pretrained_flag:
get_file(path=chkpt_path, destination=LOCAL_CHECKPOINT_PATH)
model = stable_diffusion_2(pretrained=pretrained, encode_latents_in_fp16=True, fsdp=False)
if not pretrained:
state_dict = torch.load(LOCAL_CHECKPOINT_PATH)
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
Expand Down
5 changes: 5 additions & 0 deletions diffusion/inference/mosaic_inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ integrations:
git_branch: main
pip_install: .[all]
model:
downloader: diffusion.inference.inference_model.download_checkpoint
download_parameters:
chkpt_path: # Path to download the checkpoint to evaluate
model_handler: diffusion.inference.inference_model.StableDiffusionInference
model_parameters:
pretrained: false
command: |
export PYTHONPATH=$PYTHONPATH:/code/diffusion
rm /usr/lib/python3/dist-packages/packaging-23.1.dist-info/REQUESTED
Expand Down

0 comments on commit 41f49c3

Please sign in to comment.