Skip to content

Commit

Permalink
parse model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Oct 9, 2023
1 parent f62bf09 commit a06c81e
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@
parser.add_argument("--use-f16", type=bool, default=True)


def parse_codec_model(checkpoint, out_dir, use_f16):
def parse_codec_model(checkpoint, outfile, use_f16):
"""Load encodec model checkpoint."""
outfile = open(out_dir, "wb")
outfile.write(struct.pack("i", 0x67676d6c)) # ggml magic

for name in checkpoint.keys():
if "weight_g" in name:
# the tensor has already been parsed with the corresponding "weight_v"
Expand Down Expand Up @@ -107,6 +104,27 @@ def parse_codec_model(checkpoint, out_dir, use_f16):
outfile.close()


def parse_hparams(outfile, use_f16):
# for now this is hardcoded as we only support the 24Khz model
in_channels = 1
hidden_dim = 128
n_filters = 32
kernel_size = 7
residual_kernel_size = 3
n_q = 32
n_bins = 1024
ftype = int(use_f16)

outfile.write(struct.pack("i", in_channels))
outfile.write(struct.pack("i", hidden_dim))
outfile.write(struct.pack("i", n_filters))
outfile.write(struct.pack("i", kernel_size))
outfile.write(struct.pack("i", residual_kernel_size))
outfile.write(struct.pack("i", n_q))
outfile.write(struct.pack("i", n_bins))
outfile.write(struct.pack("i", ftype))


if __name__ == "__main__":
args = parser.parse_args()

Expand All @@ -118,6 +136,15 @@ def parse_codec_model(checkpoint, out_dir, use_f16):
outfile = Path(out_dir / "ggml-model.bin")

checkpoint = torch.load(dir_model / "encodec_24khz-d7cc33bc.th", map_location="cpu")

# Step 1: insert ggml magic
outfile = open(out_dir, "wb")
outfile.write(struct.pack("i", 0x67676d6c))

# Step 2: insert hyperparameters
parse_hparams(outfile, args.use_f16)

# Step 3: insert weights
parse_codec_model(checkpoint, outfile, args.use_f16)

print("Done.")

0 comments on commit a06c81e

Please sign in to comment.