Skip to content

Commit

Permalink
Loading model from both path or array of bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
danielefernandes authored and bernardohenz committed Sep 30, 2020
1 parent 0c020d1 commit feb33f8
Show file tree
Hide file tree
Showing 31 changed files with 647 additions and 104 deletions.
2 changes: 1 addition & 1 deletion native_client/Makefile
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ clean:
rm -f deepspeech

$(DEEPSPEECH_BIN): client.cc Makefile
$(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS)
$(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS) -llzma -lbz2
ifeq ($(OS),Darwin)
install_name_tool -change bazel-out/local-opt/bin/native_client/libdeepspeech.so @rpath/libdeepspeech.so deepspeech
endif
Expand Down
Empty file modified native_client/alphabet.cc
100644 → 100755
Empty file.
Empty file modified native_client/alphabet.h
100644 → 100755
Empty file.
8 changes: 8 additions & 0 deletions native_client/args.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ bool extended_metadata = false;

bool json_output = false;

bool init_from_array_of_bytes = false;

int json_candidate_transcripts = 3;

int stream_size = 0;
Expand All @@ -59,6 +61,7 @@ void PrintHelp(const char* bin)
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
"\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n"
"\t--init_from_bytes\t\tTest init model and scorer from array of bytes\n"
"\t--help\t\t\t\tShow help\n"
"\t--version\t\t\tPrint version and exits\n";
char* version = DS_Version();
Expand All @@ -80,6 +83,7 @@ bool ProcessArgs(int argc, char** argv)
{"t", no_argument, nullptr, 't'},
{"extended", no_argument, nullptr, 'e'},
{"json", no_argument, nullptr, 'j'},
{"init_from_bytes", no_argument, nullptr, 'B'},
{"candidate_transcripts", required_argument, nullptr, 150},
{"stream", required_argument, nullptr, 's'},
{"hot_words", required_argument, nullptr, 'w'},
Expand Down Expand Up @@ -135,6 +139,10 @@ bool ProcessArgs(int argc, char** argv)
case 'j':
json_output = true;
break;

case 'B':
init_from_array_of_bytes = true;
break;

case 150:
json_candidate_transcripts = atoi(optarg);
Expand Down
30 changes: 28 additions & 2 deletions native_client/client.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <unistd.h>
#endif // NO_DIR
#include <vector>
#include <iostream>
#include <fstream>

#include "deepspeech.h"
#include "args.h"
Expand Down Expand Up @@ -415,8 +417,21 @@ main(int argc, char **argv)

// Initialise DeepSpeech
ModelState* ctx;
std::string buffer_model_str;
// sphinx-doc: c_ref_model_start
int status = DS_CreateModel(model, &ctx);
int status;
if (init_from_array_of_bytes){
// Reading model file to a char * buffer
std::ifstream is_model( model, std::ios::binary );
std::stringstream buffer_model;
buffer_model << is_model.rdbuf();
buffer_model_str = buffer_model.str();
status = DS_CreateModelFromBuffer(buffer_model_str.c_str(), buffer_model_str.size(), &ctx);
}else {
// Keep old method due to backwards compatibility
status = DS_CreateModel(model, &ctx);
}

if (status != 0) {
char* error = DS_ErrorCodeToErrorMessage(status);
fprintf(stderr, "Could not create model: %s\n", error);
Expand All @@ -433,7 +448,18 @@ main(int argc, char **argv)
}

if (scorer) {
status = DS_EnableExternalScorer(ctx, scorer);
if (init_from_array_of_bytes){
// Reading scorer file to a string buffer
std::ifstream is_scorer(scorer, std::ios::binary );
std::stringstream buffer_scorer;
buffer_scorer << is_scorer.rdbuf();
std::string tmp_str_scorer = buffer_scorer.str();
status = DS_EnableExternalScorerFromBuffer(ctx, tmp_str_scorer.c_str(), tmp_str_scorer.size());
} else {
// Keep old method due to backwards compatibility
status = DS_EnableExternalScorer(ctx, scorer);
}

if (status != 0) {
fprintf(stderr, "Could not enable external scorer.\n");
return 1;
Expand Down
85 changes: 53 additions & 32 deletions native_client/ctcdecode/scorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,24 @@ static const int32_t FILE_VERSION = 6;

int
Scorer::init(const std::string& lm_path,
bool load_from_bytes,
const Alphabet& alphabet)
{
set_alphabet(alphabet);
return load_lm(lm_path);
return load_lm(lm_path, load_from_bytes);
}

int
Scorer::init(const std::string& lm_path,
bool load_from_bytes,
const std::string& alphabet_config_path)
{
int err = alphabet_.init(alphabet_config_path.c_str());
int err = alphabet_.init(alphabet_config_path.c_str()); // Do we need to make this initiable from bytes?
if (err != 0) {
return err;
}
setup_char_map();
return load_lm(lm_path);
return load_lm(lm_path, load_from_bytes);
}

void
Expand All @@ -69,45 +71,60 @@ void Scorer::setup_char_map()
}
}

int Scorer::load_lm(const std::string& lm_path)
int Scorer::load_lm(const std::string& lm_string, bool load_from_bytes)
{
// Check if file is readable to avoid KenLM throwing an exception
const char* filename = lm_path.c_str();
if (access(filename, R_OK) != 0) {
return DS_ERR_SCORER_UNREADABLE;
}

// Check if the file format is valid to avoid KenLM throwing an exception
lm::ngram::ModelType model_type;
if (!lm::ngram::RecognizeBinary(filename, model_type)) {
return DS_ERR_SCORER_INVALID_LM;
if (!load_from_bytes){
// Check if file is readable to avoid KenLM throwing an exception
const char* filename = lm_string.c_str();
if (access(filename, R_OK) != 0) {
return DS_ERR_SCORER_UNREADABLE;
}

// Check if the file format is valid to avoid KenLM throwing an exception
lm::ngram::ModelType model_type;
if (!lm::ngram::RecognizeBinary(filename, model_type)) {
return DS_ERR_SCORER_INVALID_LM;
}
}

// Load the LM
lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
max_order_ = language_model_->Order();

uint64_t package_size;
{
util::scoped_fd fd(util::OpenReadOrThrow(filename));
package_size = util::SizeFile(fd.get());
if (load_from_bytes){
language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), lm_string.size(), config));
} else {
language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), config));
}

max_order_ = language_model_->Order();
std::stringstream stst;
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
if (package_size <= trie_offset) {
// File ends without a trie structure
return DS_ERR_SCORER_NO_TRIE;

if (!load_from_bytes){
uint64_t package_size;
{
util::scoped_fd fd(util::OpenReadOrThrow(lm_string.c_str()));
package_size = util::SizeFile(fd.get());
}

if (package_size <= trie_offset) {
// File ends without a trie structure
return DS_ERR_SCORER_NO_TRIE;
}
// Read metadata and trie from file
std::ifstream fin(lm_string.c_str(), std::ios::binary);
stst<<fin.rdbuf();
} else {
stst = std::stringstream(lm_string);
}

// Read metadata and trie from file
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
return load_trie(fin, lm_path);
stst.seekg(trie_offset);
return load_trie(stst, lm_string, load_from_bytes);
}

int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
int Scorer::load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes)
{

int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
Expand Down Expand Up @@ -140,9 +157,13 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
reset_params(alpha, beta);

fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt));
if (load_from_bytes) {
dictionary.reset(fst::ConstFst<fst::StdArc>::Read(fin, opt));
} else {
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt));
}
return DS_ERR_OK;
}

Expand Down
6 changes: 4 additions & 2 deletions native_client/ctcdecode/scorer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ class Scorer {
Scorer& operator=(const Scorer&) = delete;

int init(const std::string &lm_path,
bool load_from_bytes,
const Alphabet &alphabet);

int init(const std::string &lm_path,
bool load_from_bytes,
const std::string &alphabet_config_path);

double get_log_cond_prob(const std::vector<std::string> &words,
Expand Down Expand Up @@ -84,7 +86,7 @@ class Scorer {
void fill_dictionary(const std::unordered_set<std::string> &vocabulary);

// load language model from given path
int load_lm(const std::string &lm_path);
int load_lm(const std::string &lm_path, bool load_from_bytes=false);

// language model weight
double alpha = 0.;
Expand All @@ -98,7 +100,7 @@ class Scorer {
// necessary setup after setting alphabet
void setup_char_map();

int load_trie(std::ifstream& fin, const std::string& file_path);
int load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes=false);

private:
std::unique_ptr<lm::base::Model> language_model_;
Expand Down
56 changes: 48 additions & 8 deletions native_client/deepspeech.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,9 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
}

int
DS_CreateModel(const char* aModelPath,
ModelState** retval)
DS_CreateModel_(const std::string &aModelString,
bool init_from_bytes,
ModelState** retval)
{
*retval = nullptr;

Expand All @@ -277,7 +278,7 @@ DS_CreateModel(const char* aModelPath,
LOGD("DeepSpeech: %s", ds_git_version());
#endif

if (!aModelPath || strlen(aModelPath) < 1) {
if ( (!init_from_bytes && (strlen(aModelString) < 1)) || (init_from_bytes && (bufferSize<1))) {
std::cerr << "No model specified, cannot continue." << std::endl;
return DS_ERR_NO_MODEL;
}
Expand All @@ -294,8 +295,8 @@ DS_CreateModel(const char* aModelPath,
std::cerr << "Could not allocate model state." << std::endl;
return DS_ERR_FAIL_CREATE_MODEL;
}

int err = model->init(aModelPath);
int err = model->init(aModelString, init_from_bytes, bufferSize);
if (err != DS_ERR_OK) {
return err;
}
Expand All @@ -304,6 +305,22 @@ DS_CreateModel(const char* aModelPath,
return DS_ERR_OK;
}

int
DS_CreateModel(const char* aModelPath,
ModelState** retval)
{
return DS_CreateModel_(aModelPath, false, retval);
}

int
DS_CreateModelFromBuffer(const char* aModelBuffer,
size_t bufferSize,
ModelState** retval)
{
return DS_CreateModel_(aModelBuffer, true, retval, bufferSize);
}


unsigned int
DS_GetModelBeamWidth(const ModelState* aCtx)
{
Expand All @@ -330,18 +347,41 @@ DS_FreeModel(ModelState* ctx)
}

int
DS_EnableExternalScorer(ModelState* aCtx,
const char* aScorerPath)
DS_EnableExternalScorer_(ModelState* aCtx,
const std::string &aScorerString,
bool init_from_bytes)
{
std::unique_ptr<Scorer> scorer(new Scorer());
int err = scorer->init(aScorerPath, aCtx->alphabet_);

int err;
if (init_from_bytes)
err = scorer->init(std::string(aScorerString, bufferSize), init_from_bytes, aCtx->alphabet_);
else
err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_);


if (err != 0) {
return DS_ERR_INVALID_SCORER;
}
aCtx->scorer_ = std::move(scorer);
return DS_ERR_OK;
}

int
DS_EnableExternalScorer(ModelState* aCtx,
const char* aScorerPath)
{
return DS_EnableExternalScorer_(aCtx, aScorerPath, false);
}

int
DS_EnableExternalScorerFromBuffer(ModelState* aCtx,
const char* aScorerBuffer,
size_t bufferSize)
{
return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true, bufferSize);
}

int
DS_AddHotWord(ModelState* aCtx,
const char* word,
Expand Down
Loading

0 comments on commit feb33f8

Please sign in to comment.