Add more metadata, refactor output, add datasets
rockerBOO committed Aug 27, 2023
1 parent 9e4da55 commit baaea0e
Showing 5 changed files with 843 additions and 70 deletions.
40 changes: 28 additions & 12 deletions README.md
@@ -9,6 +9,7 @@
- [Save meta](#save-meta)
- [Average weights](#average-weights)
- [Tag frequency](#tag-frequency)
- [Dataset](#dataset)
- [Definition](#definition)
- [Update metadata](#update-metadata)
- [Usage](#usage)
@@ -201,22 +202,35 @@ arizona 6
enchanting and otherworldly 6
```

### Dataset

A basic view of the dataset, listing each directory and the number of images it contains.

```
$ python lora-inspector.py -d /mnt/900/lora/booscapes.safetensors
Dataset dirs: 2
[source] 50 images
[p7] 4 images
```

### Definition

- epoch: an epoch is seeing the entire dataset once
- batches: how many batches per each epoch (does not include gradient
- Batches per epoch: the number of batches in each epoch (does not include gradient
accumulation steps)
- train images: number of training images you have
- regularization images: number of regularization images
- scheduler: the learning rate scheduler.
- optimizer: the optimizer
- network dim/rank: the rank of the LoRA network
- alpha: the alpha to the rank of the LoRA network
- module: which python module was used to create the network (includes module
arguments)
- noise offset: noise offset option
- adaptive noise scale: adaptive noise scale
- multires noise discount: multires noise discount
- Gradient accumulation steps: the number of batches accumulated before each
  optimizer step (e.g., a value of 4 updates the weights once every 4 batches)
- Train images: number of training images you have
- Regularization images: number of regularization images
- Scheduler: the learning rate scheduler (cosine, cosine_with_restart, linear,
constant, …)
- Optimizer: the optimizer (Adam, Prodigy, DAdaptation, Lion, …)
- Network dim/rank: the rank of the LoRA network
- Alpha: the alpha value relative to the rank of the LoRA network (weights are
  scaled by alpha/rank)
- Module: the Python module that created the network
- Noise offset: noise offset option
- Adaptive noise scale: adaptive noise scale
- Multires noise discount: multires noise discount (see
  [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2))
- Multires noise scale: multires noise scale

- average magnitude: square each weight, add them up, get the square root
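
A minimal sketch of that calculation, assuming `weights` is the dict of LoRA
tensors loaded from the safetensors file (the function name is illustrative,
not the script's actual API):

```
import torch


def average_magnitude(weights: dict[str, torch.Tensor]) -> float:
    # square each weight, add them up, get the square root
    total = sum(t.float().pow(2).sum().item() for t in weights.values())
    return total ** 0.5
```
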
@@ -253,6 +267,8 @@ Saved to /mnt/900/lora/testing/armored-core-2023-08-02-173642-ddb4785e.safetenso

## Changelog

- 2023-08-27 — Add max_grad_norm, scale weight norms, gradient accumulation
steps, dropout, and datasets
- 2023-08-08 — Add simple metadata updater script
- 2023-07-31 — Add SDXL support
- 2023-07-17 — Add network dropout, scale weight norms, adaptive noise scale,
163 changes: 105 additions & 58 deletions lora-inspector.py
@@ -21,6 +21,7 @@ class NameSpace(argparse.ArgumentParser):
save_meta: bool
weights: bool
tags: bool
dataset: bool


parsers: dict[str, Callable] = {
@@ -183,7 +184,7 @@ def find_vectors_weights(vectors):
text_encoder1_weight_results = {}
text_encoder2_weight_results = {}

print(f"model key count: {len(vectors.keys())}")
# print(f"model key count: {len(vectors.keys())}")
#
# print(vectors.keys())

@@ -288,7 +289,7 @@ def process_safetensor_file(file: Path, args) -> dict[str, Any]:
parsed = {}

if metadata is not None:
parsed = parse_metadata(metadata)
parsed = parse_metadata(metadata, args)
else:
parsed = {}

@@ -305,11 +306,46 @@ def process_safetensor_file(file: Path, args) -> dict[str, Any]:
return parsed


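# Prints a summary of the dataset directories recorded in the training
# metadata: "Dataset dirs: N" followed by one "[dir] N images" line per
# directory. Only runs when the -d/--dataset flag is set.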
def process_datasets(metadata, args):
if "ss_dataset_dirs" not in metadata:
return

print(f"Dataset dirs: {len(metadata['ss_dataset_dirs'].keys())}")
for k, v in metadata["ss_dataset_dirs"].items():
print(f"\t[{k}] {v.get('img_count', 0)} images")


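# Prints the modelspec.* fields (date, title, resolution, architecture)
# when the metadata carries a non-empty modelspec.title.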
def process_modelspec(metadata, args):
if "modelspec.title" in metadata and metadata.get("modelspec.title", "") != "":
# item(items, "modelspec.implementation", "implementation"),
# item(items, "modelspec.sai_model_spec", "sai"),
# item(items, "modelspec.prediction_type", "prediction type"),
results = [
get_item(metadata, "modelspec.date", "Date"),
get_item(metadata, "modelspec.title", "Title"),
]

print_list(results)

results = [
get_item(metadata, "modelspec.resolution", "Resolution"),
get_item(metadata, "modelspec.architecture", "Architecture"),
]
print_list(results)


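# Prints the collected fields on a single line, stripping the leading and
# trailing spaces left by empty entries.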
def print_list(items):
    print(" ".join(items).strip(" "))


def parse_metadata(metadata):
def get_item(items, key, name):
if key in items and items.get(key) is not None and items.get(key) != "None":
return f"{name}: {items.get(key, '')}"

return ""


def parse_metadata(metadata, args):
if "sshs_model_hash" in metadata:
items = parse(metadata)

@@ -323,91 +359,89 @@ def parse_metadata(metadata):

# print(json.dumps(items, indent=4, sort_keys=True, default=str))

def item(items, key, name):
if key in items and items.get(key) is not None and items.get(key) != "None":
return f"{name}: {items.get(key, '')}"
process_modelspec(metadata, args)

return ""
results = [
get_item(items, "ss_network_dim", "Network Dim/Rank"),
get_item(items, "ss_network_alpha", "Alpha"),
get_item(items, "ss_network_dropout", "Dropout"),
get_item(items, "ss_network_module", "Module"),
]

print(
f"epoch: {items['ss_epoch']} batches: {items['ss_num_batches_per_epoch']}"
)
print_list(results)

print(
f"train images: {items['ss_num_train_images']} {item(items, 'ss_num_reg_images', 'regularization images')}"
)

if "modelspec.title" in items and items.get("modelspec.title", "") != "":
# item(items, "modelspec.implementation", "implementation"),
# item(items, "modelspec.resolution", "resolution"),
# item(items, "modelspec.sai_model_spec", "sai"),
# item(items, "modelspec.prediction_type", "prediction type"),
results = [
item(items, "modelspec.date", "date"),
item(items, "modelspec.title", "title"),
]
results = [
get_item(items, "ss_learning_rate", "Learning Rate (LR)"),
get_item(items, "ss_unet_lr", "UNet LR"),
get_item(items, "ss_text_encoder_lr", "TE LR"),
]

print_list(results)
print_list(results)

results = [
item(items, "modelspec.architecture", "architecture"),
]
print_list(results)
results = [
get_item(items, "ss_optimizer", "Optimizer"),
get_item(items, "ss_optimizer_args", "Optimizer args"),
]

# if (
# items["ss_num_reg_images"] > 0
# and items["ss_num_reg_images"] < items["ss_num_train_images"]
# ):
# print(
# f"Possibly not enough regularization images ({items['ss_num_reg_images']}) to training images ({items['ss_num_train_images']})."
# )
print_list(results)

print(
f"learning rate: {items['ss_learning_rate']} unet: {items['ss_unet_lr']} text encoder: {items['ss_text_encoder_lr']}"
)
results = [
get_item(items, "ss_lr_scheduler", "Scheduler"),
get_item(items, "ss_lr_scheduler_args", "Scheduler args"),
get_item(items, "ss_lr_warmup_steps", "Warmup steps"),
]

print_list(results)

results = [
item(items, "ss_lr_scheduler", "scheduler"),
item(items, "ss_lr_scheduler_args", "scheduler args"),
get_item(items, "ss_epoch", "Epoch"),
get_item(items, "ss_num_batches_per_epoch", "Batches per epoch"),
get_item(
items, "ss_gradient_accumulation_steps", "Gradient accumulation steps"
),
]

print_list(results)

results = [
item(items, "ss_optimizer", "optimizer"),
item(items, "ss_optimizer_args", "optimizer_args"),
get_item(items, "ss_num_train_images", "Train images"),
get_item(items, "ss_num_reg_images", "Regularization images"),
]

print_list(results)

if "loss_func" in items:
results = [
item(items, "ss_loss_func", "loss func"),
get_item(items, "ss_loss_func", "Loss func"),
]

print_list(results)

print(
f"network dim/rank: {items['ss_network_dim']} network alpha: {items['ss_network_alpha']} network module: {items['ss_network_module']} {items.get('ss_network_args')}"
)

results = [
item(items, "ss_noise_offset", "noise offset"),
item(items, "ss_adaptive_noise_scale", "adaptive noise scale"),
item(items, "ss_multires_noise_iterations", "multires noise iterations"),
item(items, "ss_multires_noise_discount", "multires noise discount"),
get_item(items, "ss_noise_offset", "Noise offset"),
get_item(items, "ss_adaptive_noise_scale", "Adaptive noise scale"),
get_item(
items, "ss_multires_noise_iterations", "Multires noise iterations"
),
get_item(items, "ss_multires_noise_discount", "Multires noise discount"),
]

print_list(results)

results = [
item(items, "ss_min_snr_gamma", "min snr gamma"),
item(items, "ss_zero_terminal_snr", "zero terminal snr"),
item(items, "ss_clip_skip", "clip skip"),
get_item(items, "ss_min_snr_gamma", "Min SNR gamma"),
get_item(items, "ss_zero_terminal_snr", "Zero terminal SNR"),
get_item(items, "ss_max_grad_norm", "Max grad norm"),
get_item(items, "ss_scale_weight_norms", "Scale weight norms"),
get_item(items, "ss_clip_skip", "Clip skip"),
]

print_list(results)

if args.dataset is True:
process_datasets(items, args)

return items
else:
print(
@@ -430,17 +464,23 @@ def print_tags(freq):
longest_tag = 0
for k in freq.keys():
for kitem in freq[k].keys():
if int(freq[k][kitem]) > 3:
tags.append((kitem, freq[k][kitem]))
# if int(freq[k][kitem]) > 3:
tags.append((kitem, freq[k][kitem]))

if len(kitem) > longest_tag:
longest_tag = len(kitem)
if len(kitem) > longest_tag:
longest_tag = len(kitem)

ordered = OrderedDict(reversed(sorted(tags, key=lambda t: t[1])))

justify_to = longest_tag + 1 if longest_tag < 60 else 60

for k, v in ordered.items():
for i, (k, v) in enumerate(ordered.items()):
# we can stop after 20
if i > 20:
remaining = len(ordered.items()) - i
print(f"{remaining} more tags...")
break

print(k.ljust(justify_to), v)


@@ -520,6 +560,13 @@ def process(args: type[NameSpace]):
help="Show the most common tags in the training set",
)

parser.add_argument(
"-d",
"--dataset",
action="store_true",
help="Show the dataset metadata including directory names and number of images",
)
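
# Example invocation (hypothetical path), mirroring the README's Dataset
# section: `python lora-inspector.py -d /path/to/lora.safetensors` prints
# "Dataset dirs: N" followed by one "[dir] N images" line per directory.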

args = parser.parse_args(namespace=NameSpace)
results = process(args)

