-
Notifications
You must be signed in to change notification settings - Fork 1
/
finetune.py
79 lines (65 loc) · 2.89 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#
# Fine-tune Stable Diffusion (SDXL) on a set of images using Replicate
#
import os
import replicate
from dotenv import load_dotenv, find_dotenv
# Check if a model exists on Replicate
def model_exists(modelname):
    """Check whether a model exists on Replicate.

    Args:
        modelname: Model identifier handed to ``replicate.models.get``.
            The caller in this file passes ``"<username>/<model>"``.

    Returns:
        True if the model could be fetched and its versions listed,
        False if either call raised an exception.
    """
    try:
        # Hack: fetching the model and listing its versions is used as the
        # existence probe; a failure in either call means "not found".
        model = replicate.models.get(modelname)
        model.versions.list()
    except Exception:
        # Deliberately broad so that any client/API error maps to False,
        # but not a bare ``except`` so KeyboardInterrupt and SystemExit
        # still propagate to the caller.
        return False
    return True
# Start an SDXL fine-tuning job (LoRA or DreamBooth) on Replicate
def train_model(token, tmpdir, modelname, masktarget=None, captionprefix="a photo of", dreambooth=False, use_face_detection_instead=True):
    """Start a Stable Diffusion XL fine-tuning job on Replicate.

    The function only checks that ``<tmpdir>/<token>.zip`` exists locally;
    the training itself reads the images from a fixed URL that ends in
    ``<token>.zip``, so the archive must already be reachable there.

    Args:
        token: Token string for the fine-tune. Also the stem of the zip file
            name and the suffix appended to ``captionprefix``.
        tmpdir: Directory expected to contain ``<token>.zip``.
        modelname: Name of the destination model. It is combined with the
            ``REPLICATE_USERNAME`` environment variable and lowercased to
            form ``<username>/<modelname>``.
        masktarget: Optional mask prompt sent as ``mask_target_prompts``.
            When truthy it switches face detection off.
        captionprefix: Text placed before the token to build the caption.
        dreambooth: If True, train with ``is_lora=False`` using 4000 steps
            and a U-Net learning rate of 2e-6; otherwise train a LoRA with
            1000 steps and a learning rate of 1e-6.
        use_face_detection_instead: Sent to the trainer unchanged unless
            ``masktarget`` is given, in which case it is forced to False.

    Returns:
        None if the zip file is missing, False if ``REPLICATE_USERNAME`` is
        not set or the destination model cannot be found, and None after the
        training has been created (its status is printed).
    """
    filename = f'{tmpdir}/{token}.zip'
    if not os.path.exists(filename):
        print(f"File {filename} does not exist. Did you run prepare first?")
        return None

    username = os.getenv('REPLICATE_USERNAME')
    if not username:
        # Without this guard the lookup below would be made for
        # "none/<modelname>" and report a misleading "model not found".
        print("Environment variable REPLICATE_USERNAME is not set, so the target model name cannot be built.")
        return False

    caption_text = f'{captionprefix} {token}'
    modelname_full = f'{username}/{modelname}'.lower()

    # The destination model must already exist; this function does not create it.
    if not model_exists(modelname_full):
        print(f"I can't find the model '{modelname_full}' on Replicate, you need to create it first.")
        print(" 1. Browse to this page on the Replicate web site: https://replicate.com/create")
        print(f" 2. Create a model with the name '{modelname.lower()}' (all lowercase) and run this script again.")
        return False
    print(f'Found target model {modelname_full}.')

    # DreamBooth runs use more steps and a higher learning rate than LoRA runs.
    if dreambooth:
        max_train_steps = 4000
        unet_learning_rate = 2E-6
    else:
        max_train_steps = 1000
        unet_learning_rate = 1E-6
    resolution = 1024

    # An explicit mask prompt takes precedence over face detection.
    if masktarget:
        use_face_detection_instead = False

    mode = 'DreamBooth' if dreambooth else 'LoRA'
    maskmode = f'Mask prompt: "{masktarget}"' if not use_face_detection_instead else 'Face Detection: True'
    print(f'Training {mode} model. Token: "{token}" Caption: "{caption_text}" {maskmode}')

    training = replicate.trainings.create(
        version="stability-ai/sdxl:7ca7f0d3a51cd993449541539270971d38a24d9a0d42f073caf25190d41346d7",
        input={
            # Images are fetched by Replicate from this fixed location.
            "input_images": f"https://guido.appenzeller.net/wp-content/uploads/tmp/{token}.zip",
            "caption_prefix": caption_text,
            "token_string": token,
            "mask_target_prompts": masktarget,
            "resolution": resolution,
            "use_face_detection_instead": use_face_detection_instead,
            "is_lora": not dreambooth,
            "max_train_steps": max_train_steps,
            "unet_learning_rate": unet_learning_rate,
        },
        destination=modelname_full,
    )
    print(training.status)