Skip to content

Commit

Permalink
cleaner loading context function
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Oct 1, 2023
1 parent 59ab00b commit db79e88
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 56 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ endif()

target_link_libraries(${ENCODEC_LIB} PUBLIC ggml)
target_include_directories(${ENCODEC_LIB} PUBLIC .)
target_compile_features(${ENCODEC_LIB} PUBLIC cxx_std_11)
target_compile_features(${ENCODEC_LIB} PUBLIC cxx_std_14)
29 changes: 16 additions & 13 deletions encodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stdexcept>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>

Expand Down Expand Up @@ -207,7 +208,7 @@ static struct ggml_tensor * strided_conv_transpose_1d(
return unpadded;
}

bool encodec_model_load(const std::string& fname, encodec_model& model) {
bool encodec_load_model_weights(const std::string& fname, encodec_model& model) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());

auto infile = std::ifstream(fname, std::ios::binary);
Expand Down Expand Up @@ -542,7 +543,7 @@ static struct ggml_cgraph * encodec_build_graph(
const std::vector<float> & inp_audio) {
const int32_t audio_length = inp_audio.size();

const auto & model = ectx.model;
const auto & model = *ectx.model;

struct ggml_init_params ggml_params = {
/*.mem_size =*/ ectx.buf_compute.size(),
Expand Down Expand Up @@ -788,7 +789,7 @@ static struct ggml_cgraph * encodec_build_graph(
return gf;
}

bool encodec_model_eval(
bool encodec_reconstruct_audio(
encodec_context & ectx,
std::vector<float> & raw_audio,
int n_threads) {
Expand Down Expand Up @@ -849,18 +850,20 @@ bool encodec_model_eval(
return true;
}

struct encodec_context encodec_new_context_with_model(encodec_model & model) {
encodec_context ctx = encodec_context(model);
return ctx;
}
std::shared_ptr<encodec_context> encodec_load_model(const std::string & model_path) {
int64_t t_start_load_us = ggml_time_us();

encodec_context ectx;

struct encodec_model encodec_load_model_from_file(std::string fname) {
encodec_model model;
if (!encodec_model_load(fname, model)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
exit(0);
ectx.model = std::make_unique<encodec_model>();
if (!encodec_load_model_weights(model_path, *ectx.model)) {
fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path.c_str());
return {};
}
return model;

ectx.t_load_us = ggml_time_us() - t_start_load_us;

return std::make_unique<encodec_context>(std::move(ectx));
}

void encodec_free(encodec_context & ectx) {
Expand Down
19 changes: 4 additions & 15 deletions encodec.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,7 @@ struct encodec_model {
};

struct encodec_context {
encodec_context(encodec_model & model) : model(model) {}

~encodec_context() {
if (model_owner) {
delete &model;
}
}

encodec_model & model;
bool model_owner = false;
std::unique_ptr<encodec_model> model;

struct ggml_context * ctx_audio;
struct ggml_tensor * reconstructed_audio;
Expand All @@ -155,15 +146,13 @@ struct encodec_context {
struct ggml_allocr * allocr = {};

// statistics
int64_t t_load_us = 0;
int64_t t_compute_ms = 0;
};

std::shared_ptr<encodec_context> encodec_load_model(const std::string & model_path);

struct encodec_model encodec_load_model_from_file(std::string fname);

struct encodec_context encodec_new_context_with_model(encodec_model & model);

bool encodec_model_eval(
bool encodec_reconstruct_audio(
encodec_context & ectx,
std::vector<float> & raw_audio,
int n_threads);
Expand Down
2 changes: 1 addition & 1 deletion examples/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ add_executable(${TARGET} main.cpp dr_wav.h)

install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE encodec.cpp ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
target_compile_features(${TARGET} PRIVATE cxx_std_14)

if(MSVC)
target_compile_definitions(${TARGET} PRIVATE -D_CRT_SECURE_NO_WARNINGS=1)
Expand Down
40 changes: 14 additions & 26 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ struct encodec_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

// weights location
std::string model_path = "./ggml_weights";
std::string model_path = "/Users/pbannier/Documents/encodec.cpp/ggml_weights/ggml-model.bin";

// input location
std::string original_audio_path = "./input.wav";
std::string original_audio_path = "/Users/pbannier/Documents/encodec/test_24k.wav";

// output location
std::string dest_wav_path = "output.wav";
Expand Down Expand Up @@ -76,7 +76,7 @@ bool read_wav_from_disk(std::string in_path, std::vector<float>& audio_arr) {
return false;
}

fprintf(stderr, "Number of frames read = %lld.\n", total_frame_count);
fprintf(stderr, "%s: Number of frames read = %lld.\n", __func__, total_frame_count);

audio_arr.resize(total_frame_count);
memcpy(audio_arr.data(), raw_audio, total_frame_count * sizeof(float));
Expand All @@ -102,13 +102,6 @@ void write_wav_on_disk(std::vector<float>& audio_arr, std::string dest_path) {
fprintf(stderr, "%s: Number of frames written = %lld.\n", __func__, frames);
}

struct encodec_context encodec_init_from_params(encodec_params & params) {
encodec_model model = encodec_load_model_from_file(params.model_path);
encodec_context ectx = encodec_new_context_with_model(model);

return ectx;
}

int main(int argc, char **argv) {
ggml_time_init();
const int64_t t_main_start_us = ggml_time_us();
Expand All @@ -120,15 +113,12 @@ int main(int argc, char **argv) {
return 1;
}

int64_t t_load_us = 0;
int64_t t_eval_us = 0;

// initialize encodec context
const int64_t t_start_us = ggml_time_us();
encodec_context ectx = encodec_init_from_params(params);
t_load_us = ggml_time_us() - t_start_us;

printf("\n");
std::shared_ptr<encodec_context> ectx = encodec_load_model(params.model_path);
if (!ectx) {
printf("%s: error during loading model\n", __func__);
return 1;
}

// read audio from disk
std::vector<float> original_audio_arr;
Expand All @@ -138,29 +128,27 @@ int main(int argc, char **argv) {
}

// reconstruct audio
const int64_t t_eval_us_start = ggml_time_us();
if (!encodec_model_eval(ectx, original_audio_arr, params.n_threads)) {
if (!encodec_reconstruct_audio(*ectx, original_audio_arr, params.n_threads)) {
printf("%s: error during inference\n", __func__);
return 1;
}
t_eval_us = ggml_time_us() - t_eval_us_start;

// write reconstructed audio on disk
std::vector<float> audio_arr(ectx.reconstructed_audio->ne[0]);
memcpy(ectx.reconstructed_audio->data, audio_arr.data(), audio_arr.size() * sizeof(float));
std::vector<float> audio_arr(ectx->reconstructed_audio->ne[0]);
memcpy(ectx->reconstructed_audio->data, audio_arr.data(), audio_arr.size() * sizeof(float));
write_wav_on_disk(audio_arr, params.dest_wav_path);

// report timing
{
const int64_t t_main_end_us = ggml_time_us();

printf("\n\n");
printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us/1000.0f);
printf("%s: load time = %8.2f ms\n", __func__, ectx->t_load_us/1000.0f);
printf("%s: eval time = %8.2f ms\n", __func__, ectx->t_compute_ms/1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
}

encodec_free(ectx);
encodec_free(*ectx);

return 0;
}

0 comments on commit db79e88

Please sign in to comment.