Skip to content

Commit

Permalink
refactor token properties inside stitching
Browse files Browse the repository at this point in the history
  • Loading branch information
nishchalb committed Oct 8, 2024
1 parent bcb9f65 commit cb346d7
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 50 deletions.
65 changes: 40 additions & 25 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,15 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
stitches.emplace_back();
Stitching &part = stitches.back();
part.classLabel = tk_classLabel;
part.reftk = ref_tk;
part.hyptk = hyp_tk;
part.reftk = {ref_tk};
part.hyptk = {hyp_tk};
bool del = false, ins = false, sub = false;
if (ref_tk == INS) {
part.comment = "ins";
} else if (hyp_tk == DEL) {
part.comment = "del";
} else if (hyp_tk != ref_tk) {
part.comment = "sub(" + part.hyptk + ")";
part.comment = "sub(" + part.hyptk.token + ")";
}

// for classes, we will have only one token in the global vector
Expand Down Expand Up @@ -281,10 +281,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h

if (!hyp_ctm_rows.empty()) {
auto ctmPart = hyp_ctm_rows[hypRowIndex];
part.start_ts = ctmPart.start_time_secs;
part.duration = ctmPart.duration_secs;
part.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
part.confidence = ctmPart.confidence;
part.hyptk.start_ts = ctmPart.start_time_secs;
part.hyptk.duration = ctmPart.duration_secs;
part.hyptk.end_ts = ctmPart.start_time_secs + ctmPart.duration_secs;
part.hyptk.confidence = ctmPart.confidence;

part.hyp_orig = ctmPart.word;
// sanity check
Expand All @@ -308,21 +308,24 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
float ts = stof(hypNlpPart.ts);
float endTs = stof(hypNlpPart.endTs);

part.start_ts = ts;
part.end_ts = endTs;
part.duration = endTs - ts;
part.hyptk.start_ts = ts;
part.hyptk.end_ts = endTs;
part.hyptk.duration = endTs - ts;
part.hyptk.confidence = stof(hypNlpPart.confidence);
} else if (!hypNlpPart.ts.empty()) {
float ts = stof(hypNlpPart.ts);

part.start_ts = ts;
part.end_ts = ts;
part.duration = 0.0;
part.hyptk.start_ts = ts;
part.hyptk.end_ts = ts;
part.hyptk.duration = 0.0;
part.hyptk.confidence = stof(hypNlpPart.confidence);
} else if (!hypNlpPart.endTs.empty()) {
float endTs = stof(hypNlpPart.endTs);

part.start_ts = endTs;
part.end_ts = endTs;
part.duration = 0.0;
part.hyptk.start_ts = endTs;
part.hyptk.end_ts = endTs;
part.hyptk.duration = 0.0;
part.hyptk.confidence = stof(hypNlpPart.confidence);
}
}

Expand Down Expand Up @@ -575,15 +578,15 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
// if the comment starts with 'ins'
if (stitch.comment.find("ins") == 0 && !add_inserts) {
// there's no nlp row info for such case, let's skip over it
if (stitch.confidence >= 1) {
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk, stitch.start_ts);
if (stitch.hyptk.confidence >= 1) {
logger->warn("an insertion with high confidence was found for {}@{}", stitch.hyptk.token, stitch.hyptk.start_ts);
}

continue;
}

string original_nlp_token = stitch.nlpRow.token;
string ref_tk = stitch.reftk;
string ref_tk = stitch.reftk.token;

// trying to salvage some of the original punctuation in a relatively safe manner
if (iequals(ref_tk, original_nlp_token)) {
Expand All @@ -597,21 +600,21 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
ref_tk = original_nlp_token;
} else if (stitch.comment.find("ins") == 0) {
assert(add_inserts);
logger->debug("an insertion was found for {} {}", stitch.hyptk, stitch.comment);
logger->debug("an insertion was found for {} {}", stitch.hyptk.token, stitch.comment);
ref_tk = "";
stitch.comment = "ins(" + stitch.hyptk + ")";
stitch.comment = "ins(" + stitch.hyptk.token + ")";
}

if (ref_tk == NOOP) {
continue;
}

output_nlp_file << ref_tk << "|" << stitch.nlpRow.speakerId << "|";
if (stitch.hyptk == DEL) {
if (stitch.hyptk.token == DEL) {
// we have no ts/endTs data to put...
output_nlp_file << "||";
} else {
output_nlp_file << fmt::format("{0:.4f}", stitch.start_ts) << "|" << fmt::format("{0:.4f}", stitch.end_ts)
output_nlp_file << fmt::format("{0:.4f}", stitch.hyptk.start_ts) << "|" << fmt::format("{0:.4f}", stitch.hyptk.end_ts)
<< "|";
}

Expand Down Expand Up @@ -698,8 +701,8 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
JsonLogUnigramBigramStats(topAlignment);
if (!output_sbs.empty()) {
logger->info("output_sbs = {}", output_sbs);
std::vector<string> extra_nlp_columns = {"confidence"};
WriteSbs(topAlignment, stitches, output_sbs, extra_nlp_columns);
std::vector<string> extra_hyp_columns = {"confidence"};
WriteSbs(topAlignment, stitches, output_sbs, std::vector<string>(),extra_hyp_columns);
}

if (!output_nlp.empty() && !nlp_ref_loader) {
Expand All @@ -721,3 +724,15 @@ void HandleAlign(NlpFstLoader& refLoader, CtmFstLoader& hypLoader, SynonymEngine
align_stitches_to_nlp(refLoader, stitches);
write_stitches_to_nlp(stitches, output_nlp_file, refLoader.mJsonNorm);
}

// Returns one property of a stitch's reference or hypothesis token, rendered
// as a string.
//
// stitch:   the stitch whose token is inspected.
// refToken: true selects stitch.reftk, false selects stitch.hyptk.
// property: column name; one of "speaker", "ts", "endTs", "confidence".
//
// Numeric properties are formatted with to_string. An unrecognised property
// name yields an empty string (previously the lookup default-inserted an
// empty std::function and invoking it threw std::bad_function_call).
string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property) {
  // Built once and never mutated; the lookup below uses find() so that an
  // unknown key cannot insert an empty callable into the table.
  static const std::unordered_map<std::string, std::function<string(const Token &)>> col_name_to_val = {
      {"speaker", [](const Token &tk) { return tk.speaker; }},
      {"ts", [](const Token &tk) { return to_string(tk.start_ts); }},
      {"endTs", [](const Token &tk) { return to_string(tk.end_ts); }},
      {"confidence", [](const Token &tk) { return to_string(tk.confidence); }},
  };

  const auto entry = col_name_to_val.find(property);
  if (entry == col_name_to_val.end()) {
    return "";
  }

  const Token &selected = refToken ? stitch.reftk : stitch.hyptk;
  return entry->second(selected);
}
20 changes: 14 additions & 6 deletions src/fstalign.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ fstalign.h
using namespace std;
using namespace fst;

// Per-token record for either side of an alignment (reference or hypothesis):
// the token text plus optional timing, confidence and speaker metadata.
struct Token {
  // Token text.
  string token;
  // Start and end timestamps; both default to -1.0
  // (presumably meaning "not set" -- confirm against callers).
  float start_ts{-1.0f};
  float end_ts{-1.0f};
  // Duration of the token; defaults to 0.0.
  float duration{0.0f};
  // Confidence value; defaults to -1.0.
  float confidence{-1.0f};
  // Speaker label; empty by default.
  string speaker;
};

// Stitchings will be used to represent fstalign output, combining reference,
// hypothesis, and error information into a record-like data structure.
struct Stitching {
string reftk;
string hyptk;
float start_ts;
float end_ts;
float duration;
float confidence;
Token reftk;
Token hyptk;
string classLabel;
RawNlpRecord nlpRow;
string hyp_orig;
Expand Down Expand Up @@ -55,4 +61,6 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
AlignerOptions alignerOptions);

string GetTokenPropertyAsString(Stitching stitch, bool refToken, string property);

#endif // __FSTALIGN_H__
33 changes: 15 additions & 18 deletions src/wer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ void RecordCaseWer(const vector<Stitching> &aligned_stitches) {
for (const auto &stitch : aligned_stitches) {
const string &hyp = stitch.hyp_orig;
const string &ref = stitch.nlpRow.token;
const string &reftk = stitch.reftk;
const string &hyptk = stitch.hyptk;
const string &reftk = stitch.reftk.token;
const string &hyptk = stitch.hyptk.token;
const string &ref_casing = stitch.nlpRow.casing;

if (hyptk == DEL || reftk == INS) {
Expand Down Expand Up @@ -526,7 +526,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
hyp = "";
}

void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_nlp_columns) {
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns) {
auto logger = logger::GetOrCreateLogger("wer");
logger->set_level(spdlog::level::info);

Expand All @@ -537,8 +537,11 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
string prev_tk_classLabel = "";
logger->info("Side-by-Side alignment info going into {}", sbs_filename);
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities");
for (string col_name: extra_nlp_columns) {
myfile << fmt::format("\t{0}", col_name);
for (string col_name: extra_ref_columns) {
myfile << fmt::format("\tref_{0}", col_name);
}
for (string col_name: extra_hyp_columns) {
myfile << fmt::format("\thyp_{0}", col_name);
}
myfile << endl;

Expand All @@ -549,15 +552,6 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
string hyp_err = "";

std::set<std::string> op_set = {"<ins>", "<del>", "<sub>"};
std::unordered_map<std::string, std::function<string(RawNlpRecord)>> nlp_name_to_val = {
{"speaker", [](RawNlpRecord row) {return row.speakerId;}},
{"punctuation", [](RawNlpRecord row) {return row.punctuation;}},
{"prepunctuation", [](RawNlpRecord row) {return row.prepunctuation;}},
{"ts", [](RawNlpRecord row) {return row.ts;}},
{"endTs", [](RawNlpRecord row) {return row.endTs;}},
{"case", [](RawNlpRecord row) {return row.casing;}},
{"confidence", [](RawNlpRecord row) {return row.confidence;}},
};

size_t offset = 2; // line number in output file where first triple starts
for (const auto &p_stitch: stitches) {
Expand All @@ -567,8 +561,8 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
for (auto wer_tag: wer_tags) {
tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
}
string ref_tk = p_stitch.reftk;
string hyp_tk = p_stitch.hyptk;
string ref_tk = p_stitch.reftk.token;
string hyp_tk = p_stitch.hyptk.token;
string tag = "";

if (ref_tk == NOOP) {
Expand Down Expand Up @@ -602,8 +596,11 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st

myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags);

for (string col_name: extra_nlp_columns) {
myfile << fmt::format("\t{0}", nlp_name_to_val[col_name](p_stitch.nlpRow));
for (string col_name: extra_ref_columns) {
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, true, col_name));
}
for (string col_name: extra_hyp_columns) {
myfile << fmt::format("\t{0}", GetTokenPropertyAsString(p_stitch, false, col_name));
}
myfile << endl;
offset++;
Expand Down
2 changes: 1 addition & 1 deletion src/wer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ void CalculatePrecisionRecall(wer_alignment &topAlignment, int threshold);
typedef vector<pair<size_t, string>> ErrorGroups;

void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_nlp_columns);
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename, const vector<string> extra_ref_columns, const vector<string> extra_hyp_columns);
void JsonLogUnigramBigramStats(wer_alignment &topAlignment);

0 comments on commit cb346d7

Please sign in to comment.