Skip to content

Commit

Permalink
patch lora bin to dump json config (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
francoishernandez authored Jun 12, 2024
1 parent 7d210ec commit 61c4b3f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion eole/bin/model/lora_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from eole.models.model_saver import load_checkpoint
from eole.models import get_model_class
from eole.inputters.inputter import dict_to_vocabs, vocabs_to_dict
from eole.config import recursive_model_fields_set
from safetensors import safe_open
from safetensors.torch import save_file
import glob
Expand Down Expand Up @@ -101,8 +102,9 @@ def run(cls, args):
with open(os.path.join(args.output, "vocab.json"), "w", encoding="utf-8") as f:
json.dump(vocab_dict, f, indent=2, ensure_ascii=False)
# save config
config_dict = recursive_model_fields_set(new_config)
with open(os.path.join(args.output, "config.json"), "w", encoding="utf-8") as f:
json.dump(new_config, f, indent=2, ensure_ascii=False)
json.dump(config_dict, f, indent=2, ensure_ascii=False)
shards = glob.glob(os.path.join(args.base_model, "model.*.safetensors"))
f = []
for i, shard in enumerate(shards):
Expand Down

0 comments on commit 61c4b3f

Please sign in to comment.