Skip to content

Commit

Permalink
Merge pull request #3125 from mozilla/utf8-regressions
Browse files Browse the repository at this point in the history
Fix some regressions from Alphabet refactoring (Fixes #3123)
  • Loading branch information
reuben authored Jul 4, 2020
2 parents 1964e80 + 03ed4a4 commit 66d1f16
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 9 deletions.
14 changes: 14 additions & 0 deletions native_client/alphabet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ Alphabet::init(const char *config_file)
if (line == " ") {
space_label_ = label;
}
if (line.length() == 0) {
continue;
}
label_to_str_[label] = line;
str_to_label_[line] = label;
++label;
Expand Down Expand Up @@ -187,3 +190,14 @@ Alphabet::Encode(const std::string& input) const
}
return result;
}

std::vector<unsigned int>
UTF8Alphabet::Encode(const std::string& input) const
{
std::vector<unsigned int> result;
for (auto byte_char : input) {
std::string byte_str(1, byte_char);
result.push_back(EncodeSingle(byte_str));
}
return result;
}
5 changes: 3 additions & 2 deletions native_client/alphabet.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Alphabet {

// Encode a sequence of character/output classes into a sequence of labels.
// Characters are assumed to always take a single Unicode codepoint.
std::vector<unsigned int> Encode(const std::string& input) const;
virtual std::vector<unsigned int> Encode(const std::string& input) const;

protected:
size_t size_;
Expand All @@ -77,7 +77,8 @@ class UTF8Alphabet : public Alphabet
int init(const char*) override {
return 0;
}
};

std::vector<unsigned int> Encode(const std::string& input) const override;
};

#endif //ALPHABET_H
24 changes: 19 additions & 5 deletions native_client/ctcdecode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import UTF8Alphabet

__version__ = swigwrapper.__version__
# This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle
# string encoding explicitly, here and throughout this file.
__version__ = swigwrapper.__version__.decode('utf-8')

# Hack: import error codes by matching on their names, as SWIG unfortunately
# does not support binding enums to Python in a scoped manner yet.
Expand All @@ -30,7 +32,7 @@ def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
assert beta is not None, 'beta parameter is required'
assert scorer_path, 'scorer_path parameter is required'

err = self.init(scorer_path, alphabet)
err = self.init(scorer_path.encode('utf-8'), alphabet)
if err != 0:
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))

Expand All @@ -41,15 +43,27 @@ class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self, config_path):
super(Alphabet, self).__init__()
err = self.init(config_path)
err = self.init(config_path.encode('utf-8'))
if err != 0:
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))

def EncodeSingle(self, input):
return super(Alphabet, self).EncodeSingle(input.encode('utf-8'))

def Encode(self, input):
"""Convert SWIG's UnsignedIntVec to a Python list"""
res = super(Alphabet, self).Encode(input)
# Convert SWIG's UnsignedIntVec to a Python list
res = super(Alphabet, self).Encode(input.encode('utf-8'))
return [el for el in res]

def DecodeSingle(self, input):
res = super(Alphabet, self).DecodeSingle(input)
return res.decode('utf-8')

def Decode(self, input):
res = super(Alphabet, self).Decode(input)
return res.decode('utf-8')



def ctc_beam_search_decoder(probs_seq,
alphabet,
Expand Down
5 changes: 3 additions & 2 deletions native_client/ctcdecode/ctc_beam_search_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ DecoderState::decode(size_t num_results) const
// score the last word of each prefix that doesn't end with space
if (ext_scorer_) {
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i];
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
PathTrie* prefix = prefixes_copy[i];
PathTrie* prefix_boundary = ext_scorer_->is_utf8_mode() ? prefix : prefix->parent;
if (prefix_boundary && !ext_scorer_->is_scoring_boundary(prefix_boundary, prefix->character)) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
bool bos = ngram.size() < ext_scorer_->get_max_order();
Expand Down
1 change: 1 addition & 0 deletions native_client/ctcdecode/swigwrapper.i
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
%{
#include "ctc_beam_search_decoder.h"
#define SWIG_FILE_WITH_INIT
#define SWIG_PYTHON_STRICT_BYTE_CHAR
#include "workspace_status.h"
%}

Expand Down

0 comments on commit 66d1f16

Please sign in to comment.