Skip to content

Commit

Permalink
Allow flexible forwarding of NLP columns to SBS (#56)
Browse files Browse the repository at this point in the history
* initial attempt

* refactor token properties inside stitching

* read from CLI

* fix test

* handle empty confidence

* set confidence when present

* use pinned kaldi dockerhub image

* update version
  • Loading branch information
nishchalb authored Oct 31, 2024
1 parent dcf2655 commit 3446afd
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 50 deletions.
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Using kaldi image for pre-built OpenFST, version is 1.7.2
FROM kaldiasr/kaldi:latest as kaldi-base
FROM kaldiasr/kaldi:cpu-debian10-2024-07-29 as kaldi-base

FROM debian:11

COPY --from=kaldi-base /opt/kaldi/tools/openfst /opt/openfst
Expand Down
65 changes: 40 additions & 25 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,15 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
stitches.emplace_back();
Stitching &part = stitches.back();
part.classLabel = tk_classLabel;
part.reftk = ref_tk;
part.hyptk = hyp_tk;
part.reftk = {ref_tk};
part.hyptk = {hyp_tk};
bool del = false, ins = false, sub = false;
if (ref_tk == INS) {
part.comment = "ins";
} else if (hyp_tk == DEL) {
part.comment = "del";
} else if (hyp_tk != ref_tk) {
part.comment = "sub(" + part.hyptk + ")";
part.comment = "sub(" + part.hyptk.token + ")";
}

// for classes, we will have only one token in the global vector
Expand Down Expand Up @@ -281,10 +281,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h

if (!hyp_ctm_rows.empty()) {
auto ctmPart = hyp_ctm_rows[hypRowIndex];
part.start_ts = ctmPart.start_time_secs;
part.duration = ctmPart.duration_secs;
part.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
part.confidence = ctmPart.confidence;
part.hyptk.start_ts = ctmPart.start_time_secs;
part.hyptk.duration = ctmPart.duration_secs;
part.hyptk.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
part.hyptk.confidence = ctmPart.confidence;

part.hyp_orig = ctmPart.word;
// sanity check
Expand All @@ -308,21 +308,24 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
float ts = stof(hypNlpPart.ts);
float endTs = stof(hypNlpPart.endTs);

part.start_ts = ts;
part.end_ts = endTs;
part.duration = endTs - ts;
part.hyptk.start_ts = ts;
part.hyptk.end_ts = endTs;
part.hyptk.duration = endTs - ts;
} else if (!hypNlpPart.ts.empty()) {
float ts = stof(hypNlpPart.ts);

part.start_ts = ts;
part.end_ts = ts;
part.duration = 0.0;
part.hyptk.start_ts = ts;
part.hyptk.end_ts = ts;
part.hyptk.duration = 0.0;
} else if (!hypNlpPart.endTs.empty()) {
float endTs = stof(hypNlpPart.endTs);

part.start_ts = endTs;
part.end_ts = endTs;
part.duration = 0.0;
part.hyptk.start_ts = endTs;
part.hyptk.end_ts = endTs;
part.hyptk.duration = 0.0;
}
if (!hypNlpPart.confidence.empty()) {
part.hyptk.confidence = stof(hypNlpPart.confidence);
}
}

Expand Down Expand Up @@ -575,15 +578,15 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
// if the comment starts with 'ins'
if (stitch.comment.find("ins") == 0 && !add_inserts) {
// there's no nlp row info for such case, let's skip over it
if (stitch.confidence >= 1) {
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts);
if (stitch.hyptk.confidence >= 1) {
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk.token, stitch.hyptk.start_ts);
}

continue;
}

string original_nlp_token = stitch.nlpRow.token;
string ref_tk = stitch.reftk;
string ref_tk = stitch.reftk.token;

// trying to salvage some of the original punctuation in a relatively safe manner
if (iequals(ref_tk, original_nlp_token)) {
Expand All @@ -597,21 +600,21 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
ref_tk = original_nlp_token;
} else if (stitch.comment.find("ins") == 0) {
assert(add_inserts);
logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment);
logger->debug("an insertion was found for {} {}", stitch.hyptk.token, stitch.comment);
ref_tk = "";
stitch.comment = "ins(" + stitch.hyptk + ")";
stitch.comment = "ins(" + stitch.hyptk.token + ")";
}

if (ref_tk == NOOP) {
continue;
}

output_nlp_file << ref_tk << "|" << stitch.nlpRow.speakerId << "|";
if (stitch.hyptk == DEL) {
if (stitch.hyptk.token == DEL) {
// we have no ts/endTs data to put...
output_nlp_file << "||";
} else {
output_nlp_file << fmt::format("{0:.4f}", stitch.start_ts) << "|" << fmt::format("{0:.4f}", stitch.end_ts)
output_nlp_file << fmt::format("{0:.4f}", stitch.hyptk.start_ts) << "|" << fmt::format("{0:.4f}", stitch.hyptk.end_ts)
<< "|";
}

Expand All @@ -632,7 +635,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
}

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case) {
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<string> ref_extra_columns, std::vector<string> hyp_extra_columns) {
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
// string composition_approach, bool record_case_stats) {
auto logger = logger::GetOrCreateLogger("fstalign");
Expand Down Expand Up @@ -698,7 +701,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
JsonLogUnigramBigramStats(topAlignment);
if (!output_sbs.empty()) {
logger->info("output_sbs = {}", output_sbs);
WriteSbs(topAlignment, stitches, output_sbs);
WriteSbs(topAlignment, stitches, output_sbs, ref_extra_columns, hyp_extra_columns);
}

if (!output_nlp.empty() && !nlp_ref_loader) {
Expand All @@ -720,3 +723,15 @@ void HandleAlign(NlpFstLoader& refLoader, CtmFstLoader& hypLoader, SynonymEngine
align_stitches_to_nlp(refLoader, stitches);
write_stitches_to_nlp(stitches, output_nlp_file, refLoader.mJsonNorm);
}

// Returns one property of a stitch's reference or hypothesis token, rendered
// as a string for inclusion as an extra SBS output column.
//
// @param stitch   the stitch whose token is queried
// @param refToken true to read from the reference token, false for hypothesis
// @param property column name: "speaker", "ts", "endTs", or "confidence"
// @return the property value as a string; "" for an unrecognized property
string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property) {
  // Built once (static): maps a CLI column name to an accessor on Token.
  // Numeric fields use to_string, i.e. default float formatting.
  static const std::unordered_map<std::string, std::function<string(const Token &)>> col_name_to_val = {
      {"speaker", [](const Token &tk) { return tk.speaker; }},
      {"ts", [](const Token &tk) { return to_string(tk.start_ts); }},
      {"endTs", [](const Token &tk) { return to_string(tk.end_ts); }},
      {"confidence", [](const Token &tk) { return to_string(tk.confidence); }},
  };
  auto entry = col_name_to_val.find(property);
  if (entry == col_name_to_val.end()) {
    // Unknown column name (e.g. a typo in --ref-extra-cols). Previously this
    // used operator[], which default-inserts an empty std::function and then
    // throws std::bad_function_call when invoked; emit an empty field instead.
    return "";
  }
  return entry->second(refToken ? stitch.reftk : stitch.hyptk);
}
29 changes: 15 additions & 14 deletions src/fstalign.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ fstalign.h
using namespace std;
using namespace fst;

// Represents the information associated with a single reference or hypothesis
// token in a stitch: the token text plus timing, confidence, and speaker
// metadata pulled from the CTM/NLP input when available.
struct Token {
  // The token text itself (word, or a marker such as INS/DEL).
  string token;
  // Start time in seconds (populated from CTM start_time_secs or NLP ts).
  float start_ts=0.0;
  // End time in seconds (start + duration, or NLP endTs).
  float end_ts=0.0;
  // Duration in seconds; 0.0 when only a single timestamp is known.
  float duration=0.0;
  // Token confidence; -1.0 appears to be the "not provided" sentinel, since
  // it is only overwritten when the input carries a confidence value —
  // TODO confirm downstream consumers treat negatives as absent.
  float confidence=-1.0;
  // Speaker identifier; empty when the input has no speaker column.
  string speaker;
};

// Stitchings will be used to represent fstalign output, combining reference,
// hypothesis, and error information into a record-like data structure.
struct Stitching {
string reftk;
string hyptk;
float start_ts;
float end_ts;
float duration;
float confidence;
Token reftk;
Token hyptk;
string classLabel;
RawNlpRecord nlpRow;
string hyp_orig;
Expand All @@ -42,17 +48,12 @@ struct AlignerOptions {
int levenstein_maximum_error_streak = 100;
};

// original
// void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine, string output_sbs, string
// output_nlp,
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
// string composition_approach, bool record_case_stats);
// void HandleAlign(NlpFstLoader *refLoader, CtmFstLoader *hypLoader, SynonymEngine *engine, ofstream &output_nlp_file,
// int numBests, string symbols_filename, string composition_approach);

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp = false, bool use_case = false);
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case, std::vector<string> ref_extra_columns, std::vector<string> hyp_extra_columns);
void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
AlignerOptions alignerOptions);

string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property);

#endif // __FSTALIGN_H__
9 changes: 8 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ int main(int argc, char **argv) {
bool disable_cutoffs = false;
bool disable_hyphen_ignore = false;

std::vector<string> ref_extra_columns = std::vector<string>();
std::vector<string> hyp_extra_columns = std::vector<string>();

CLI::App app("Rev FST Align");
app.set_help_all_flag("--help-all", "Expand all help");
app.add_flag("--version", version, "Show fstalign version.");
Expand Down Expand Up @@ -97,6 +100,10 @@ int main(int argc, char **argv) {

c->add_option("--composition-approach", composition_approach,
"Desired composition logic. Choices are 'standard' or 'adapted'");
c->add_option("--ref-extra-cols", ref_extra_columns,
"Extra columns from the reference to include in SBS output.");
c->add_option("--hyp-extra-cols", hyp_extra_columns,
"Extra columns from the hypothesis to include in SBS output.");
}
get_wer->add_option("--wer-sidecar", wer_sidecar_filename,
"WER sidecar json file.");
Expand Down Expand Up @@ -180,7 +187,7 @@ int main(int argc, char **argv) {
}

if (command == "wer") {
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case);
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case, ref_extra_columns, hyp_extra_columns);
} else if (command == "align") {
if (output_nlp.empty()) {
console->error("the output nlp file must be specified");
Expand Down
2 changes: 1 addition & 1 deletion src/version.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once

#define FSTALIGNER_VERSION_MAJOR 1
#define FSTALIGNER_VERSION_MINOR 13
#define FSTALIGNER_VERSION_MINOR 14
#define FSTALIGNER_VERSION_PATCH 0
29 changes: 22 additions & 7 deletions src/wer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ void RecordCaseWer(const vector<Stitching> &aligned_stitches) {
for (const auto &stitch : aligned_stitches) {
const string &hyp = stitch.hyp_orig;
const string &ref = stitch.nlpRow.token;
const string &reftk = stitch.reftk;
const string &hyptk = stitch.hyptk;
const string &reftk = stitch.reftk.token;
const string &hyptk = stitch.hyptk.token;
const string &ref_casing = stitch.nlpRow.casing;

if (hyptk == DEL || reftk == INS) {
Expand Down Expand Up @@ -526,7 +526,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
hyp = "";
}

void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename) {
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns) {
auto logger = logger::GetOrCreateLogger("wer");
logger->set_level(spdlog::level::info);

Expand All @@ -536,7 +536,14 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
AlignmentTraversor visitor(topAlignment);
string prev_tk_classLabel = "";
logger->info("Side-by-Side alignment info going into {}", sbs_filename);
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities") << endl;
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities");
for (string col_name: extra_ref_columns) {
myfile << fmt::format("\tref_{0}", col_name);
}
for (string col_name: extra_hyp_columns) {
myfile << fmt::format("\thyp_{0}", col_name);
}
myfile << endl;

// keep track of error groupings
ErrorGroups groups_err;
Expand All @@ -554,8 +561,8 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
for (auto wer_tag: wer_tags) {
tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
}
string ref_tk = p_stitch.reftk;
string hyp_tk = p_stitch.hyptk;
string ref_tk = p_stitch.reftk.token;
string hyp_tk = p_stitch.hyptk.token;
string tag = "";

if (ref_tk == NOOP) {
Expand Down Expand Up @@ -587,7 +594,15 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
eff_class = tk_classLabel;
}

myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags) << endl;
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags);

for (string col_name: extra_ref_columns) {
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, true, col_name));
}
for (string col_name: extra_hyp_columns) {
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, false, col_name));
}
myfile << endl;
offset++;
}

Expand Down
2 changes: 1 addition & 1 deletion src/wer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ void CalculatePrecisionRecall(wer_alignment &topAlignment, int threshold);
typedef vector<pair<size_t, string>> ErrorGroups;

void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename);
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns);
void JsonLogUnigramBigramStats(wer_alignment &topAlignment);

0 comments on commit 3446afd

Please sign in to comment.