From a06c81e17cd2194361cce95a6179edb5fdb12e78 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 9 Oct 2023 17:58:59 +0200 Subject: [PATCH] parse model weights --- convert.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/convert.py b/convert.py index 3572d44..d4dddc7 100644 --- a/convert.py +++ b/convert.py @@ -43,11 +43,8 @@ parser.add_argument("--use-f16", type=bool, default=True) -def parse_codec_model(checkpoint, out_dir, use_f16): +def parse_codec_model(checkpoint, outfile, use_f16): """Load encodec model checkpoint.""" - outfile = open(out_dir, "wb") - outfile.write(struct.pack("i", 0x67676d6c)) # ggml magic - for name in checkpoint.keys(): if "weight_g" in name: # the tensor has already been parsed with the corresponding "weight_v" @@ -107,6 +104,27 @@ def parse_codec_model(checkpoint, out_dir, use_f16): outfile.close() +def parse_hparams(outfile, use_f16): + # for now this is hardcoded as we only support the 24Khz model + in_channels = 1 + hidden_dim = 128 + n_filters = 32 + kernel_size = 7 + residual_kernel_size = 3 + n_q = 32 + n_bins = 1024 + ftype = int(use_f16) + + outfile.write(struct.pack("i", in_channels)) + outfile.write(struct.pack("i", hidden_dim)) + outfile.write(struct.pack("i", n_filters)) + outfile.write(struct.pack("i", kernel_size)) + outfile.write(struct.pack("i", residual_kernel_size)) + outfile.write(struct.pack("i", n_q)) + outfile.write(struct.pack("i", n_bins)) + outfile.write(struct.pack("i", ftype)) + + if __name__ == "__main__": args = parser.parse_args() @@ -118,6 +136,15 @@ def parse_codec_model(checkpoint, out_dir, use_f16): outfile = Path(out_dir / "ggml-model.bin") checkpoint = torch.load(dir_model / "encodec_24khz-d7cc33bc.th", map_location="cpu") + + # Step 1: insert ggml magic + outfile = open(out_dir, "wb") + outfile.write(struct.pack("i", 0x67676d6c)) + + # Step 2: insert hyperparameters + parse_hparams(outfile, args.use_f16) + + # Step 3: insert weights parse_codec_model(checkpoint, outfile, args.use_f16) print("Done.")