Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api : expose new load weight function #37

Merged
merged 1 commit into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions encodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,21 +257,13 @@ static struct ggml_tensor *forward_pass_lstm_unilayer(
return hs;
}

bool encodec_load_model_weights(const std::string &fname, encodec_model &model, int n_gpu_layers) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());

auto infile = std::ifstream(fname, std::ios::binary);
if (!infile) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}

bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int n_gpu_layers) {
// verify magic (i.e. ggml signature in hex format)
{
uint32_t magic;
read_safe(infile, magic);
if (magic != ENCODEC_FILE_MAGIC) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__);
return false;
}
}
Expand Down Expand Up @@ -312,8 +304,8 @@ bool encodec_load_model_weights(const std::string &fname, encodec_model &model,
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
fprintf(stderr, "%s: invalid model file (bad ftype value %d)\n",
__func__, model.hparams.ftype);
return 1;
}

Expand Down Expand Up @@ -1341,13 +1333,48 @@ bool encodec_decompress_audio(
return true;
}

struct encodec_context *encodec_load_model(std::ifstream &fin, int n_gpu_layers) {
    // Loads an encodec model from an already-opened stream; the stream's read
    // position must be at the start of the model data. Returns nullptr on failure.
    int64_t t_start_load_us = ggml_time_us();

    struct encodec_context *ectx = new encodec_context();

    ectx->model = encodec_model();
    if (!encodec_load_model_weights(fin, ectx->model, n_gpu_layers)) {
        fprintf(stderr, "%s: failed to load model weights\n", __func__);
        delete ectx;  // fix: do not leak the context on the failure path
        return {};
    }

    // pre-compute the number of codebooks required
    int bandwidth = ectx->model.hparams.bandwidth;
    int sr = ectx->model.hparams.sr;

    // hop length is the product of the stride ratios of the 4 conv stages
    int hop_length = 1;
    for (int i = 0; i < 4; i++) {
        hop_length *= ectx->model.hparams.ratios[i];
    }
    ectx->model.hparams.hop_length = hop_length;

    ectx->model.hparams.n_q = get_num_codebooks(bandwidth, hop_length, sr);
    fprintf(stderr, "%s: n_q = %d\n", __func__, ectx->model.hparams.n_q);

    ectx->t_load_us = ggml_time_us() - t_start_load_us;

    return ectx;
}

struct encodec_context *encodec_load_model(const std::string &model_path, int n_gpu_layers) {
int64_t t_start_load_us = ggml_time_us();

auto infile = std::ifstream(model_path, std::ios::binary);
if (!infile) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, model_path.c_str());
return nullptr;
}

struct encodec_context *ectx = new encodec_context();

ectx->model = encodec_model();
if (!encodec_load_model_weights(model_path, ectx->model, n_gpu_layers)) {
if (!encodec_load_model_weights(infile, ectx->model, n_gpu_layers)) {
fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path.c_str());
return {};
}
Expand Down
11 changes: 11 additions & 0 deletions encodec.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,17 @@ struct encodec_context *encodec_load_model(
const std::string &model_path,
int n_gpu_layers);

/**
* Loads an encodec model from an opened input file stream.
*
* @param fin The input file stream to read the encodec model from. The stream's read position must be at the start of the model data.
* @param n_gpu_layers The number of GPU layers to use.
* @return A pointer to the encodec context struct.
*/
struct encodec_context *encodec_load_model(
std::ifstream &fin,
int n_gpu_layers);

/**
* Sets the target bandwidth for the given encodec context.
*
Expand Down