diff --git a/eole/bin/model/average_models.py b/eole/bin/model/average_models.py
index 2d45104c..3a4b8b57 100755
--- a/eole/bin/model/average_models.py
+++ b/eole/bin/model/average_models.py
@@ -1,41 +1,37 @@
 #!/usr/bin/env python
 import torch
 from eole.bin import BaseBin, register_bin
+from eole.models import model_saver
+from eole.config import recursive_model_fields_set
+from safetensors.torch import load_file, save_file
+import os
+import json
 
 
-def average_models(model_files, fp32=False):
+def average_models(model_paths, fp32=False):
     vocab = None
     config = None
     avg_model = None
-    avg_generator = None
 
-    for i, model_file in enumerate(model_files):
-        m = torch.load(model_file, map_location="cpu")
-        model_weights = m["model"]
-        generator_weights = m["generator"]
+    for i, model_path in enumerate(model_paths):
+        m = model_saver.load_checkpoint(model_path)
+        model_weights = load_file(os.path.join(model_path, "model.00.safetensors"))
 
         if fp32:
             for k, v in model_weights.items():
                 model_weights[k] = v.float()
-            for k, v in generator_weights.items():
-                generator_weights[k] = v.float()
 
         if i == 0:
-            vocab, config = m["vocab"], m["config"]
+            vocab, config, optim = m["vocab"], m["config"], m["optim"]
             avg_model = model_weights
-            avg_generator = generator_weights
         else:
            for k, v in avg_model.items():
                avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
-            for k, v in avg_generator.items():
-                avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
-
 
     final = {
         "vocab": vocab,
         "config": config,
-        "optim": None,
-        "generator": avg_generator,
+        "optim": optim,
         "model": avg_model,
     }
     return final
@@ -56,4 +52,23 @@ def add_args(cls, parser):
     @classmethod
     def run(cls, args):
         final = average_models(args.models, args.fp32)
-        torch.save(final, args.output)
+
+        if not os.path.isdir(args.output):
+            os.makedirs(args.output, exist_ok=True)
+
+        # this may be better implemented using the model_saver classes
+        # config
+        with open(os.path.join(args.output, "config.json"), "w") as f:
+            json.dump(
+                recursive_model_fields_set(final["config"]),
+                f,
+                indent=2,
+                ensure_ascii=False,
+            )
+        # vocab
+        with open(os.path.join(args.output, "vocab.json"), "w") as f:
+            json.dump(final["vocab"], f, indent=2, ensure_ascii=False)
+        # optimizer
+        torch.save(final["optim"], os.path.join(args.output, "optimizer.pt"))
+        # model weights
+        save_file(final["model"], os.path.join(args.output, "model.00.safetensors"))
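
For reference, a minimal sketch of how the checkpoint directory written by the new `run()` could be read back. It assumes only the file names used in the diff (`config.json`, `vocab.json`, `optimizer.pt`, `model.00.safetensors`); the directory path is a placeholder for whatever is passed as `args.output`.

```python
# Sketch only: reads back the files produced by the averaging tool above.
# File names mirror the ones written in the diff; the directory is a placeholder.
import json
import os

import torch
from safetensors.torch import load_file

out_dir = "averaged_checkpoint"  # placeholder for the path given as args.output

# averaged weights, stored as safetensors
state_dict = load_file(os.path.join(out_dir, "model.00.safetensors"))

# config and vocab, stored as plain JSON
with open(os.path.join(out_dir, "config.json")) as f:
    config = json.load(f)
with open(os.path.join(out_dir, "vocab.json")) as f:
    vocab = json.load(f)

# optimizer state carried over from the first checkpoint, stored with torch.save
optim = torch.load(os.path.join(out_dir, "optimizer.pt"), map_location="cpu")
```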