Skip to content

Commit

Permalink
-Change string to char* in the API
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardohenz committed Sep 30, 2020
1 parent 8ea719d commit 5a5ef46
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 37 deletions.
7 changes: 5 additions & 2 deletions native_client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,14 +417,16 @@ main(int argc, char **argv)

// Initialise DeepSpeech
ModelState* ctx;
std::string buffer_model_str;
// sphinx-doc: c_ref_model_start
int status;
if (init_from_array_of_bytes){
// Reading model file to a char * buffer
std::ifstream is_model( model, std::ios::binary );
std::stringstream buffer_model;
buffer_model << is_model.rdbuf();
status = DS_CreateModelFromBuffer(buffer_model.str(), &ctx);
buffer_model_str = buffer_model.str();
status = DS_CreateModelFromBuffer(buffer_model_str.c_str(), buffer_model_str.size(), &ctx);
}else {
// Keep old method due to backwards compatibility
status = DS_CreateModel(model, &ctx);
Expand All @@ -451,7 +453,8 @@ main(int argc, char **argv)
std::ifstream is_scorer(scorer, std::ios::binary );
std::stringstream buffer_scorer;
buffer_scorer << is_scorer.rdbuf();
status = DS_EnableExternalScorerFromBuffer(ctx, buffer_scorer.str());
std::string tmp_str_scorer = buffer_scorer.str();
status = DS_EnableExternalScorerFromBuffer(ctx, tmp_str_scorer.c_str(), tmp_str_scorer.size());
} else {
// Keep old method due to backwards compatibility
status = DS_EnableExternalScorer(ctx, scorer);
Expand Down
31 changes: 20 additions & 11 deletions native_client/deepspeech.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,10 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
}

int
DS_CreateModel_(const std::string &aModelString,
DS_CreateModel_(const char* aModelString,
bool init_from_bytes,
ModelState** retval)
ModelState** retval,
size_t bufferSize=0)
{
*retval = nullptr;

Expand All @@ -278,7 +279,7 @@ DS_CreateModel_(const std::string &aModelString,
LOGD("DeepSpeech: %s", ds_git_version());
#endif

if (aModelString.length() < 1) {
if ( (!init_from_bytes && (strlen(aModelString) < 1)) || (init_from_bytes && (bufferSize<1))) {
std::cerr << "No model specified, cannot continue." << std::endl;
return DS_ERR_NO_MODEL;
}
Expand All @@ -296,7 +297,7 @@ DS_CreateModel_(const std::string &aModelString,
return DS_ERR_FAIL_CREATE_MODEL;
}

int err = model->init(aModelString, init_from_bytes);
int err = model->init(aModelString, init_from_bytes, bufferSize);
if (err != DS_ERR_OK) {
return err;
}
Expand All @@ -313,10 +314,11 @@ DS_CreateModel(const char* aModelPath,
}

int
DS_CreateModelFromBuffer(const std::string &aModelBuffer,
DS_CreateModelFromBuffer(const char* aModelBuffer,
size_t bufferSize,
ModelState** retval)
{
return DS_CreateModel_(aModelBuffer, true, retval);
return DS_CreateModel_(aModelBuffer, true, retval, bufferSize);
}


Expand Down Expand Up @@ -347,12 +349,18 @@ DS_FreeModel(ModelState* ctx)

int
DS_EnableExternalScorer_(ModelState* aCtx,
const std::string &aScorerString,
bool init_from_bytes)
const char* aScorerString,
bool init_from_bytes,
size_t bufferSize=0)
{
std::unique_ptr<Scorer> scorer(new Scorer());

int err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_);
int err;
if (init_from_bytes)
err = scorer->init(std::string(aScorerString, bufferSize), init_from_bytes, aCtx->alphabet_);
else
err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_);


if (err != 0) {
return DS_ERR_INVALID_SCORER;
Expand All @@ -370,9 +378,10 @@ DS_EnableExternalScorer(ModelState* aCtx,

int
DS_EnableExternalScorerFromBuffer(ModelState* aCtx,
const std::string &aScorerBuffer)
const char* aScorerBuffer,
size_t bufferSize)
{
return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true);
return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true, bufferSize);
}

int
Expand Down
7 changes: 5 additions & 2 deletions native_client/deepspeech.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ typedef struct Metadata {
APPLY(DS_ERR_SCORER_NO_TRIE, 0x2007, "Reached end of scorer file before loading vocabulary trie.") \
APPLY(DS_ERR_SCORER_INVALID_TRIE, 0x2008, "Invalid magic in trie header.") \
APPLY(DS_ERR_SCORER_VERSION_MISMATCH, 0x2009, "Scorer file version does not match expected version.") \
APPLY(DS_ERR_MODEL_NOT_SUP_BUFFER, 0x2010, "Load from buffer does not support memorymaped models.") \
APPLY(DS_ERR_FAIL_INIT_MMAP, 0x3000, "Failed to initialize memory mapped model.") \
APPLY(DS_ERR_FAIL_INIT_SESS, 0x3001, "Failed to initialize the session.") \
APPLY(DS_ERR_FAIL_INTERPRETER, 0x3002, "Interpreter failed.") \
Expand Down Expand Up @@ -118,7 +119,8 @@ int DS_CreateModel(const char* aModelPath,
* @return Zero on success, non-zero on failure.
*/
DEEPSPEECH_EXPORT
int DS_CreateModelFromBuffer(const std::string &aModelBuffer,
int DS_CreateModelFromBuffer(const char* aModelBuffer,
size_t bufferSize,
ModelState** retval);


Expand Down Expand Up @@ -185,7 +187,8 @@ int DS_EnableExternalScorer(ModelState* aCtx,
*/
DEEPSPEECH_EXPORT
int DS_EnableExternalScorerFromBuffer(ModelState* aCtx,
const std::string &aScorerBuffer);
const char* aScorerBuffer,
size_t bufferSize);


/**
Expand Down
2 changes: 1 addition & 1 deletion native_client/modelstate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ModelState::~ModelState()
}

int
ModelState::init(const std::string &model_string, bool init_from_bytes)
ModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize)
{
return DS_ERR_OK;
}
Expand Down
2 changes: 1 addition & 1 deletion native_client/modelstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct ModelState {
ModelState();
virtual ~ModelState();

virtual int init(const std::string &model_string, bool init_from_bytes);
virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize);

virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;

Expand Down
14 changes: 5 additions & 9 deletions native_client/tflitemodelstate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,21 @@ getTfliteDelegates()
}

int
TFLiteModelState::init(const std::string &model_string, bool init_from_bytes)
TFLiteModelState::init(const char *model_string, bool init_from_bytes, size_t bufferSize)
{
int err = ModelState::init(model_string, init_from_bytes);
int err = ModelState::init(model_string, init_from_bytes, bufferSize);
if (err != DS_ERR_OK) {
return err;
}

if (init_from_bytes){
char *tmp_buffer = new char[model_string.size()];
std::copy(model_string.begin(), model_string.end(), tmp_buffer);
// Using c_str does not work
fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(tmp_buffer,model_string.size());
fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(model_string, bufferSize);
if (!fbmodel_) {
std::cerr << "Error at reading model buffer " << std::endl;
return DS_ERR_FAIL_INIT_MMAP;
}
} else {
fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string.c_str());
fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string);
if (!fbmodel_) {
std::cerr << "Error at reading model file " << model_string << std::endl;
return DS_ERR_FAIL_INIT_MMAP;
Expand Down Expand Up @@ -334,7 +331,6 @@ TFLiteModelState::init(const std::string &model_string, bool init_from_bytes)
assert(dims_c->data[1] == dims_h->data[1]);
assert(state_size_ > 0);
state_size_ = dims_c->data[1];

return DS_ERR_OK;
}

Expand Down
2 changes: 1 addition & 1 deletion native_client/tflitemodelstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct TFLiteModelState : public ModelState
TFLiteModelState();
virtual ~TFLiteModelState();

virtual int init(const std::string &model_string, bool init_from_bytes) override;
virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override;

virtual void compute_mfcc(const std::vector<float>& audio_buffer,
std::vector<float>& mfcc_output) override;
Expand Down
22 changes: 13 additions & 9 deletions native_client/tfmodelstate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@ TFModelState::~TFModelState()
}
}

int loadGraphFromBinaryData(Env* env, const std::string& data,
int loadGraphFromBinaryData(Env* env, const char* data, size_t bufferSize,
::tensorflow::protobuf::MessageLite* proto) {

if (!proto->ParseFromString(data)) {
std::string dataString(data, bufferSize);
if (!proto->ParseFromString(dataString)) {
std::cerr << "Can't parse data as binary proto" << std::endl;
return -1;
}
return 0;
}

int
TFModelState::init(const std::string &model_string, bool init_from_bytes)
TFModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize)
{
int err = ModelState::init(model_string, init_from_bytes);
int err = ModelState::init(model_string, init_from_bytes, bufferSize);
if (err != DS_ERR_OK) {
return err;
}
Expand All @@ -46,16 +47,16 @@ TFModelState::init(const std::string &model_string, bool init_from_bytes)
mmap_env_.reset(new MemmappedEnv(Env::Default()));
bool is_mmap = false;
if (init_from_bytes) {
int loadGraphStatus = loadGraphFromBinaryData(Env::Default(), model_string, &graph_def_);
int loadGraphStatus = loadGraphFromBinaryData(mmap_env_.get(), model_string, bufferSize, &graph_def_);
if (loadGraphStatus != 0) {
return DS_ERR_FAIL_CREATE_SESS;
}
} else {
is_mmap = model_string.find(".pbmm") != std::string::npos;
is_mmap = std::string(model_string).find(".pbmm") != std::string::npos;
if (!is_mmap) {
std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
} else {
status = mmap_env_->InitializeFromFile(model_string.c_str());
status = mmap_env_->InitializeFromFile(model_string);
if (!status.ok()) {
std::cerr << status << std::endl;
return DS_ERR_FAIL_INIT_MMAP;
Expand All @@ -77,14 +78,17 @@ TFModelState::init(const std::string &model_string, bool init_from_bytes)
session_.reset(session);

if (init_from_bytes){
// Need some help
if( is_mmap) {
std::cerr << "Load from buffer does not support .pbmm models." << std::endl;
return DS_ERR_MODEL_NOT_SUP_BUFFER;
}
} else {
if (is_mmap) {
status = ReadBinaryProto(mmap_env_.get(),
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&graph_def_);
} else {
status = ReadBinaryProto(Env::Default(), model_string.c_str(), &graph_def_);
status = ReadBinaryProto(Env::Default(), model_string, &graph_def_);
}
}

Expand Down
2 changes: 1 addition & 1 deletion native_client/tfmodelstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct TFModelState : public ModelState
TFModelState();
virtual ~TFModelState();

virtual int init(const std::string &model_string, bool init_from_bytes) override;
virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override;

virtual void infer(const std::vector<float>& mfcc,
unsigned int n_frames,
Expand Down

0 comments on commit 5a5ef46

Please sign in to comment.