Skip to content

Commit

Permalink
fix #136. Updated eole/bin/model/average_models.py to work with safet…
Browse files Browse the repository at this point in the history
…ensors model format. (#137)
  • Loading branch information
isanvicente authored Oct 30, 2024
1 parent 0ec1088 commit 146c8f9
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions eole/bin/model/average_models.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,37 @@
#!/usr/bin/env python
import torch
from eole.bin import BaseBin, register_bin
from eole.models import model_saver
from eole.config import recursive_model_fields_set
from safetensors.torch import load_file, save_file
import os
import json


def average_models(model_files, fp32=False):
def average_models(model_paths, fp32=False):
vocab = None
config = None
avg_model = None
avg_generator = None

for i, model_file in enumerate(model_files):
m = torch.load(model_file, map_location="cpu")
model_weights = m["model"]
generator_weights = m["generator"]
for i, model_path in enumerate(model_paths):
m = model_saver.load_checkpoint(model_path)
model_weights = load_file(os.path.join(model_path, "model.00.safetensors"))

if fp32:
for k, v in model_weights.items():
model_weights[k] = v.float()
for k, v in generator_weights.items():
generator_weights[k] = v.float()

if i == 0:
vocab, config = m["vocab"], m["config"]
vocab, config, optim = m["vocab"], m["config"], m["optim"]
avg_model = model_weights
avg_generator = generator_weights
else:
for k, v in avg_model.items():
avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)

for k, v in avg_generator.items():
avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)

final = {
"vocab": vocab,
"config": config,
"optim": None,
"generator": avg_generator,
"optim": optim,
"model": avg_model,
}
return final
Expand All @@ -56,4 +52,23 @@ def add_args(cls, parser):
@classmethod
def run(cls, args):
final = average_models(args.models, args.fp32)
torch.save(final, args.output)

if not os.path.isdir(args.output):
os.makedirs(args.output, exist_ok=True)

# this maybe better implemented using model_saver classes
# config
with open(os.path.join(args.output, "config.json"), "w") as f:
json.dump(
recursive_model_fields_set(final["config"]),
f,
indent=2,
ensure_ascii=False,
)
# vocab
with open(os.path.join(args.output, "vocab.json"), "w") as f:
json.dump(final["vocab"], f, indent=2, ensure_ascii=False)
# optimizer
torch.save(final["optim"], os.path.join(args.output, "optimizer.pt"))
# model weights
save_file(final["model"], os.path.join(args.output, "model.00.safetensors"))

0 comments on commit 146c8f9

Please sign in to comment.