Skip to content

Commit

Permalink
Add slider to control DREAM lambda
Browse files Browse the repository at this point in the history
This allows adjusting the tradeoff between faster learning and better composition vs better preservation of detail.
  • Loading branch information
RossM committed Dec 23, 2023
1 parent 92d3219 commit 418ba6c
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 4 deletions.
1 change: 1 addition & 0 deletions dreambooth/dataclasses/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class DreamboothConfig(BaseModel):
max_token_length: int = 75
min_snr_gamma: float = 0.0
use_dream: bool = False
dream_detail: float = 0.5
mixed_precision: str = "fp16"
model_dir: str = ""
model_name: str = ""
Expand Down
7 changes: 3 additions & 4 deletions dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,9 +1621,8 @@ def lora_save_function(weights, filename):
sqrt_alpha_prod = alpha_prod ** 0.5
sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5

# The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments, but
# lambda = 1 seems to give better results for fine-tuning.
dream_lambda = 1
# The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
dream_lambda = (1 - alpha_prod) ** args.dream_detail

if args.model_type == "SDXL":
with accelerator.autocast():
Expand All @@ -1649,7 +1648,7 @@ def lora_save_function(weights, filename):
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

del alpha_prod, sqrt_alpha_prod, sqrt_one_minus_alpha_prod, dream_lambda, predicted_noise, delta_noise
del alpha_prod, sqrt_alpha_prod, sqrt_one_minus_alpha_prod, dream_lambda, model_pred, predicted_noise, delta_noise

if args.model_type == "SDXL":
with accelerator.autocast():
Expand Down
5 changes: 5 additions & 0 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@
id="use_dream" name="use_dream">
<label class="form-check-label" for="use_dream">Use DREAM</label>
</div>
</div>
<div class="form-group">
<div class="dbInput db-slider" data-min="0" data-max="1.0"
data-step="0.01" id="dream_detail" data-value="0.5"
data-label="DREAM detail preservation"></div>
</div>
<div class="form-group">
<div class="dbInput db-slider" data-min="75" data-max="300"
Expand Down
1 change: 1 addition & 0 deletions javascript/dreambooth.js
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ let db_titles = {
"Use CPU Only (SLOW)": "Guess what - this will be incredibly slow, but it will work for < 8GB GPUs.",
"Use Concepts List": "Train multiple concepts from a JSON file or string.",
"Use DREAM": "Enable DREAM (http://arxiv.org/abs/2312.00210). This may provide better results, but trains slower.",
"DREAM detail preservation": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM.",
"Use EMA": "Enabling this will provide better results and editability, but cost more VRAM.",
"Use EMA for prediction": "",
"Use EMA Weights for Inference": "Enabling this will save the EMA unet weights as the 'normal' model weights and ignore the regular unet weights.",
Expand Down
9 changes: 9 additions & 0 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,13 @@ def on_ui_tabs():
db_use_dream = gr.Checkbox(
label="Use DREAM", value=False
)
db_dream_detail_preservation = gr.Slider(
label="DREAM detail preservation",
minimum=0,
maximum=1,
step=0.01,
visible=True,
)
db_pad_tokens = gr.Checkbox(
label="Pad Tokens", value=True
)
Expand Down Expand Up @@ -1352,6 +1359,7 @@ def format_updates():
db_tenc_grad_clip_norm,
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_pad_tokens,
db_strict_tokens,
db_max_token_length,
Expand Down Expand Up @@ -1486,6 +1494,7 @@ def toggle_advanced():
db_max_token_length,
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_mixed_precision,
db_model_name,
db_model_path,
Expand Down
6 changes: 6 additions & 0 deletions templates/defaults/defaults.json
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@
"use_dream": {
"value": false
},
"dream_detail": {
"value": 0.5,
"min": 0.0,
"max": 1,
"step": 0.01
},
"max_token_length": {
"value": 75,
"min": 75,
Expand Down
1 change: 1 addition & 0 deletions templates/defaults/dreambooth_model_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"max_token_length": 75,
"min_snr_gamma": 0.0,
"use_dream": false,
"dream_detail": 0.5,
"mixed_precision": "fp16",
"noise_scheduler": "DDPM",
"num_train_epochs": 200,
Expand Down
5 changes: 5 additions & 0 deletions templates/locales/titles_en.json
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@
"title": "Enable Diffusion Rectification and Estimation-Adaptive Models (DREAM).",
"description": "Whether or not to use DREAM. DREAM performs an additional model evaluation at each training step, which increases training time but can help improve the stability and consistency of the generated images."
},
"dream_detail": {
"label": "DREAM detail",
"title": "Select how much detail DREAM preserves.",
"description": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM."
},
"train_unet": {
"label": "Train UNET",
"title": "Train UNET as an additional module.",
Expand Down

0 comments on commit 418ba6c

Please sign in to comment.