Skip to content

Commit

Permalink
Add prediction type to inference (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
Landanjs authored Aug 25, 2023
1 parent ee8d0b2 commit 053d32a
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions diffusion/inference/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,16 @@ class StableDiffusionInference():
Default: ``None``.
"""

def __init__(self, pretrained: bool = False):
def __init__(self, pretrained: bool = False, prediction_type: str = 'epsilon'):
self.device = torch.cuda.current_device()

model = stable_diffusion_2(pretrained=pretrained, encode_latents_in_fp16=True, fsdp=False)
model = stable_diffusion_2(
pretrained=pretrained,
prediction_type=prediction_type,
encode_latents_in_fp16=True,
fsdp=False,
)

if not pretrained:
state_dict = torch.load(LOCAL_CHECKPOINT_PATH)
for key in list(state_dict['state']['model'].keys()):
Expand Down

0 comments on commit 053d32a

Please sign in to comment.