diff --git a/src/fstalign.cpp b/src/fstalign.cpp index 5283d48..125cb52 100644 --- a/src/fstalign.cpp +++ b/src/fstalign.cpp @@ -242,15 +242,15 @@ vector make_stitches(wer_alignment &alignment, vector h stitches.emplace_back(); Stitching &part = stitches.back(); part.classLabel = tk_classLabel; - part.reftk = ref_tk; - part.hyptk = hyp_tk; + part.reftk = {ref_tk}; + part.hyptk = {hyp_tk}; bool del = false, ins = false, sub = false; if (ref_tk == INS) { part.comment = "ins"; } else if (hyp_tk == DEL) { part.comment = "del"; } else if (hyp_tk != ref_tk) { - part.comment = "sub(" + part.hyptk + ")"; + part.comment = "sub(" + part.hyptk.token + ")"; } // for classes, we will have only one token in the global vector @@ -281,10 +281,10 @@ vector make_stitches(wer_alignment &alignment, vector h if (!hyp_ctm_rows.empty()) { auto ctmPart = hyp_ctm_rows[hypRowIndex]; - part.start_ts = ctmPart.start_time_secs; - part.duration = ctmPart.duration_secs; - part.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs; - part.confidence = ctmPart.confidence; + part.hyptk.start_ts = ctmPart.start_time_secs; + part.hyptk.duration = ctmPart.duration_secs; + part.hyptk.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs; + part.hyptk.confidence = ctmPart.confidence; part.hyp_orig = ctmPart.word; // sanity check @@ -308,21 +308,24 @@ vector make_stitches(wer_alignment &alignment, vector h float ts = stof(hypNlpPart.ts); float endTs = stof(hypNlpPart.endTs); - part.start_ts = ts; - part.end_ts = endTs; - part.duration = endTs - ts; + part.hyptk.start_ts = ts; + part.hyptk.end_ts = endTs; + part.hyptk.duration = endTs - ts; + part.hyptk.confidence = stof(hypNlpPart.confidence); } else if (!hypNlpPart.ts.empty()) { float ts = stof(hypNlpPart.ts); - part.start_ts = ts; - part.end_ts = ts; - part.duration = 0.0; + part.hyptk.start_ts = ts; + part.hyptk.end_ts = ts; + part.hyptk.duration = 0.0; + part.hyptk.confidence = stof(hypNlpPart.confidence); } else if (!hypNlpPart.endTs.empty()) 
{ float endTs = stof(hypNlpPart.endTs); - part.start_ts = endTs; - part.end_ts = endTs; - part.duration = 0.0; + part.hyptk.start_ts = endTs; + part.hyptk.end_ts = endTs; + part.hyptk.duration = 0.0; + part.hyptk.confidence = stof(hypNlpPart.confidence); } } @@ -575,15 +578,15 @@ void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_fil // if the comment starts with 'ins' if (stitch.comment.find("ins") == 0 && !add_inserts) { // there's no nlp row info for such case, let's skip over it - if (stitch.confidence >= 1) { - logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts); + if (stitch.hyptk.confidence >= 1) { + logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk.token, stitch.hyptk.start_ts); } continue; } string original_nlp_token = stitch.nlpRow.token; - string ref_tk = stitch.reftk; + string ref_tk = stitch.reftk.token; // trying to salvage some of the original punctuation in a relatively safe manner if (iequals(ref_tk, original_nlp_token)) { @@ -597,9 +600,9 @@ void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_fil ref_tk = original_nlp_token; } else if (stitch.comment.find("ins") == 0) { assert(add_inserts); - logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment); + logger->debug("an insertion was found for {} {}", stitch.hyptk.token, stitch.comment); ref_tk = ""; - stitch.comment = "ins(" + stitch.hyptk + ")"; + stitch.comment = "ins(" + stitch.hyptk.token + ")"; } if (ref_tk == NOOP) { @@ -607,11 +610,11 @@ void write_stitches_to_nlp(vector& stitches, ofstream &output_nlp_fil } output_nlp_file << ref_tk << "|" << stitch.nlpRow.speakerId << "|"; - if (stitch.hyptk == DEL) { + if (stitch.hyptk.token == DEL) { // we have no ts/endTs data to put... 
output_nlp_file << "||"; } else { - output_nlp_file << fmt::format("{0:.4f}", stitch.start_ts) << "|" << fmt::format("{0:.4f}", stitch.end_ts) + output_nlp_file << fmt::format("{0:.4f}", stitch.hyptk.start_ts) << "|" << fmt::format("{0:.4f}", stitch.hyptk.end_ts) << "|"; } @@ -698,8 +701,8 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine JsonLogUnigramBigramStats(topAlignment); if (!output_sbs.empty()) { logger->info("output_sbs = {}", output_sbs); - std::vector extra_nlp_columns = {"confidence"}; - WriteSbs(topAlignment, stitches, output_sbs, extra_nlp_columns); + std::vector extra_hyp_columns = {"confidence"}; + WriteSbs(topAlignment, stitches, output_sbs, std::vector(),extra_hyp_columns); } if (!output_nlp.empty() && !nlp_ref_loader) { @@ -721,3 +724,17 @@ void HandleAlign(NlpFstLoader& refLoader, CtmFstLoader& hypLoader, SynonymEngine align_stitches_to_nlp(refLoader, stitches); write_stitches_to_nlp(stitches, output_nlp_file, refLoader.mJsonNorm); } + +// Returns one property of a stitch's reference or hypothesis token as text. +// refToken picks stitch.reftk when true and stitch.hyptk when false. +// property is "speaker", "ts", "endTs" or "confidence"; numeric fields are rendered with to_string. +// Any other property name yields an empty string. +string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property) { + const Token &tk = refToken ? stitch.reftk : stitch.hyptk; + if (property == "speaker") return tk.speaker; + if (property == "ts") return to_string(tk.start_ts); + if (property == "endTs") return to_string(tk.end_ts); + if (property == "confidence") return to_string(tk.confidence); + // Unknown name: indexing the old lookup table with operator[] inserted an empty std::function, and calling it threw std::bad_function_call. + return ""; +} diff --git a/src/fstalign.h b/src/fstalign.h index 0320785..5cbe130 100644 --- a/src/fstalign.h +++ b/src/fstalign.h @@ -15,15 +15,21 @@ fstalign.h using namespace std; using namespace fst; +// Represent information associated with a reference or hypothesis token +struct Token { + string token; + float start_ts=-1.0; + float end_ts=-1.0; + float duration=0.0; + float confidence=-1.0; + string speaker; +}; + // Stitchings will be used to represent fstalign output, combining
reference, // hypothesis, and error information into a record-like data structure. struct Stitching { - string reftk; - string hyptk; - float start_ts; - float end_ts; - float duration; - float confidence; + Token reftk; + Token hyptk; string classLabel; RawNlpRecord nlpRow; string hyp_orig; @@ -55,4 +61,6 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file, AlignerOptions alignerOptions); +string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property); + #endif // __FSTALIGN_H__ diff --git a/src/wer.cpp b/src/wer.cpp index eb03803..f326c5d 100644 --- a/src/wer.cpp +++ b/src/wer.cpp @@ -262,8 +262,8 @@ void RecordCaseWer(const vector &aligned_stitches) { for (const auto &stitch : aligned_stitches) { const string &hyp = stitch.hyp_orig; const string &ref = stitch.nlpRow.token; - const string &reftk = stitch.reftk; - const string &hyptk = stitch.hyptk; + const string &reftk = stitch.reftk.token; + const string &hyptk = stitch.hyptk.token; const string &ref_casing = stitch.nlpRow.casing; if (hyptk == DEL || reftk == INS) { @@ -526,7 +526,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp) hyp = ""; } -void WriteSbs(wer_alignment &topAlignment, const vector& stitches, string sbs_filename, const vector extra_nlp_columns) { +void WriteSbs(wer_alignment &topAlignment, const vector& stitches, string sbs_filename, const vector extra_ref_columns, const vector extra_hyp_columns) { auto logger = logger::GetOrCreateLogger("wer"); logger->set_level(spdlog::level::info); @@ -537,8 +537,11 @@ void WriteSbs(wer_alignment &topAlignment, const vector& stitches, st string prev_tk_classLabel = ""; logger->info("Side-by-Side alignment info going into {}", sbs_filename); myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities"); - 
for (string col_name: extra_nlp_columns) { - myfile << fmt::format("\t{0}", col_name); + for (string col_name: extra_ref_columns) { + myfile << fmt::format("\tref_{0}", col_name); + } + for (string col_name: extra_hyp_columns) { + myfile << fmt::format("\thyp_{0}", col_name); } myfile << endl; @@ -549,15 +552,6 @@ void WriteSbs(wer_alignment &topAlignment, const vector& stitches, st string hyp_err = ""; std::set op_set = {"", "", ""}; - std::unordered_map> nlp_name_to_val = { - {"speaker", [](RawNlpRecord row) {return row.speakerId;}}, - {"punctuation", [](RawNlpRecord row) {return row.punctuation;}}, - {"prepunctuation", [](RawNlpRecord row) {return row.prepunctuation;}}, - {"ts", [](RawNlpRecord row) {return row.ts;}}, - {"endTs", [](RawNlpRecord row) {return row.endTs;}}, - {"case", [](RawNlpRecord row) {return row.casing;}}, - {"confidence", [](RawNlpRecord row) {return row.confidence;}}, - }; size_t offset = 2; // line number in output file where first triple starts for (const auto &p_stitch: stitches) { @@ -567,8 +561,8 @@ void WriteSbs(wer_alignment &topAlignment, const vector& stitches, st for (auto wer_tag: wer_tags) { tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|"; } - string ref_tk = p_stitch.reftk; - string hyp_tk = p_stitch.hyptk; + string ref_tk = p_stitch.reftk.token; + string hyp_tk = p_stitch.hyptk.token; string tag = ""; if (ref_tk == NOOP) { @@ -602,8 +596,11 @@ void WriteSbs(wer_alignment &topAlignment, const vector& stitches, st myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags); - for (string col_name: extra_nlp_columns) { - myfile << fmt::format("\t{0}", nlp_name_to_val[col_name](p_stitch.nlpRow)); + for (string col_name: extra_ref_columns) { + myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, true, col_name)); + } + for (string col_name: extra_hyp_columns) { + myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, false, 
col_name)); } myfile << endl; offset++; diff --git a/src/wer.h b/src/wer.h index ffcd804..b66e978 100644 --- a/src/wer.h +++ b/src/wer.h @@ -49,5 +49,5 @@ void CalculatePrecisionRecall(wer_alignment &topAlignment, int threshold); typedef vector> ErrorGroups; void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp); -void WriteSbs(wer_alignment &topAlignment, const vector& stitches, string sbs_filename, const vector extra_nlp_columns); +void WriteSbs(wer_alignment &topAlignment, const vector& stitches, string sbs_filename, const vector extra_ref_columns, const vector extra_hyp_columns); void JsonLogUnigramBigramStats(wer_alignment &topAlignment);