
Update inference script and add checkpoint path argument (#52)
* Fix inference code and add ability to select model checkpoint
Landanjs authored Jul 18, 2023
1 parent 22e36b2 commit c872947
Showing 2 changed files with 54 additions and 38 deletions.
76 changes: 45 additions & 31 deletions diffusion/inference/inference_model.py
@@ -5,6 +5,7 @@

 import base64
 import io
+from typing import Any, Dict, List, Optional
 
 import torch
 from composer.utils.file_helpers import get_file
@@ -16,22 +17,22 @@
 LOCAL_CHECKPOINT_PATH = '/tmp/model.pt'
 
 
-def download_model():
-    """Download model from remote storage."""
-    model_uri = 'oci://mosaicml-internal-checkpoints/stable-diffusion-hero-run/4-13-512-ema/ep5-ba850000-rank0.pt'
-    get_file(path=model_uri, destination=LOCAL_CHECKPOINT_PATH)
-
-
 class StableDiffusionInference():
"""Inference endpoint class for Stable Diffusion."""
"""Inference endpoint class for Stable Diffusion.
Args:
chkpt_path (str, optional): The path to the local folder, URL or object score that contains the checkpoint.
If not specified, pulls the pretrained Stable Diffusion 2.0 base weights from HuggingFace.
Default: ``None``.
"""

-    def __init__(self):
-        pretrained_flag = False
+    def __init__(self, chkpt_path: Optional[str] = None):
+        pretrained_flag = chkpt_path is None
         self.device = torch.cuda.current_device()
 
         model = stable_diffusion_2(pretrained=pretrained_flag, encode_latents_in_fp16=True, fsdp=False)
         if not pretrained_flag:
-            download_model()
+            get_file(path=chkpt_path, destination=LOCAL_CHECKPOINT_PATH)
             state_dict = torch.load(LOCAL_CHECKPOINT_PATH)
             for key in list(state_dict['state']['model'].keys()):
                 if 'val_metrics.' in key:
@@ -40,29 +41,42 @@ def __init__(self):
         model.to(self.device)
         self.model = model.eval()
 
-    def predict(self, **inputs):
-        if 'prompt' not in inputs:
-            print('No prompt provided, returning nothing')
-            return
-
-        # Parse and cast args
-        kwargs = {}
-        for arg in ['prompt', 'negative_prompt']:
-            if arg in inputs:
-                kwargs[arg] = inputs[arg]
-        for arg in ['height', 'width', 'num_inference_steps', 'num_images_per_prompt', 'seed']:
-            if arg in inputs:
-                kwargs[arg] = int(inputs[arg])
-        for arg in ['guidance_scale']:
-            if arg in inputs:
-                kwargs[arg] = float(inputs[arg])
-
-        prompt = kwargs.pop('prompt')
-        prompts = [prompt] if isinstance(prompt, str) else prompt
+    def predict(self, model_requests: List[Dict[str, Any]]):
+        prompts = []
+        negative_prompts = []
+        generate_kwargs = {}
+
+        # assumes the same generate_kwargs across all samples
+        for req in model_requests:
+            if 'input' not in req:
+                raise RuntimeError('"input" must be provided to generate call')
+            inputs = req['input']
+
+            # Prompts and negative prompts if available
+            if isinstance(inputs, str):
+                prompts.append(inputs)
+            elif isinstance(inputs, Dict):
+                if 'prompt' not in inputs:
+                    raise RuntimeError('"prompt" must be provided to generate call if using a dict as input')
+                prompts.append(inputs['prompt'])
+                if 'negative_prompt' in inputs:
+                    negative_prompts.append(inputs['negative_prompt'])
+
+            generate_kwargs = req['parameters']
+
+        # Check for prompts
+        if len(prompts) == 0:
+            raise RuntimeError('No prompts provided, must be either a string or dictionary with "prompt"')
+
+        # Check negative prompt length
+        if len(negative_prompts) == 0:
+            negative_prompts = None
+        elif len(prompts) != len(negative_prompts):
+            raise RuntimeError('There must be the same number of negative prompts as prompts.')

         # Generate images
         with torch.cuda.amp.autocast(True):
-            imgs = self.model.generate(prompt=prompts, **kwargs).cpu()
+            imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu()
 
         # Send as bytes
         png_images = []
@@ -72,5 +86,5 @@ def predict(self, **inputs):
             img_byte_arr = io.BytesIO()
             pil_image.save(img_byte_arr, format='PNG')
             base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
-            png_images.append(bytes(base64_encoded_image, 'utf-8'))
+            png_images.append(base64_encoded_image)
         return png_images
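
For reference, the reworked predict method now takes a list of request dictionaries instead of raw keyword arguments. Below is a minimal caller-side sketch of that format; it is not part of the commit, and the checkpoint URI, prompts, and parameter values are illustrative placeholders:

# Hypothetical usage sketch of the updated handler (placeholder values throughout).
from diffusion.inference.inference_model import StableDiffusionInference

# chkpt_path is optional; leaving it as None falls back to the pretrained
# Stable Diffusion 2.0 base weights pulled from HuggingFace.
inference_model = StableDiffusionInference(chkpt_path='oci://my-bucket/checkpoints/my-checkpoint.pt')  # placeholder URI

model_requests = [
    {
        # 'input' may be a plain prompt string ...
        'input': 'a watercolor painting of a castle',
        'parameters': {'height': 512, 'width': 512, 'num_inference_steps': 50, 'guidance_scale': 7.5},
    },
    {
        # ... or a dict with a 'prompt' (and optionally 'negative_prompt') key.
        'input': {'prompt': 'a photo of a corgi on the beach'},
        'parameters': {'height': 512, 'width': 512, 'num_inference_steps': 50, 'guidance_scale': 7.5},
    },
]

# Returns a list of base64-encoded PNG strings, one per generated image.
png_images = inference_model.predict(model_requests)

Note that the handler assumes the same 'parameters' across the batch (the last request's values win), and if 'negative_prompt' is supplied it must be supplied for every prompt in the batch.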
16 changes: 9 additions & 7 deletions diffusion/inference/mosaic_inference.yaml
@@ -1,7 +1,7 @@
 name: sd-o1ema
-cluster: r7z13
-gpu_num: 1
-gpu_type: a100_40gb
+cluster:
+gpu_num:
+gpu_type:
 image: mosaicml/inference:latest
 replicas: 1
 integrations:
@@ -10,7 +10,9 @@ integrations:
     git_branch: main
     pip_install: .[all]
 model:
-  checkpoint_path: hkunlp/instructor-large
-  custom_model:
-    model_handler: diffusion.inference.inference_model.StableDiffusionInference
-  command: uvicorn serve:app --host 0.0.0.0 --port 8080
+  model_handler: diffusion.inference.inference_model.StableDiffusionInference
+command: |
+  export PYTHONPATH=$PYTHONPATH:/code/diffusion
+  rm /usr/lib/python3/dist-packages/packaging-23.1.dist-info/REQUESTED
+  pip install --force-reinstall --no-deps packaging==23.1
+  pip install --upgrade xformers
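
Assembled from the two hunks above, the updated mosaic_inference.yaml looks roughly like the following sketch; the cluster and GPU fields are deliberately left blank for the deployer to fill in, and the integration entries not shown in the diff are elided:

name: sd-o1ema
cluster:    # fill in with the target cluster
gpu_num:    # fill in with the number of GPUs
gpu_type:   # fill in with the GPU type
image: mosaicml/inference:latest
replicas: 1
integrations:
  - # ... git integration fields unchanged by this commit (elided in the diff) ...
    git_branch: main
    pip_install: .[all]
model:
  model_handler: diffusion.inference.inference_model.StableDiffusionInference
command: |
  export PYTHONPATH=$PYTHONPATH:/code/diffusion
  rm /usr/lib/python3/dist-packages/packaging-23.1.dist-info/REQUESTED
  pip install --force-reinstall --no-deps packaging==23.1
  pip install --upgrade xformers

The command block runs before the server starts; as written it reinstalls packaging 23.1 over the base image's copy and upgrades xformers.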
