Skip to content

Commit

Permalink
Add a debug feature to show ranking change by rescoring.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 543657356
  • Loading branch information
Noriyuki Takahashi authored and hiroyuki-komatsu committed Jun 27, 2023
1 parent be09920 commit 08e58a6
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/converter/segments.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ class Segment final {
std::vector<uint32_t> inner_segment_boundary;
// LINT.ThenChange(//converter/segments_matchers.h)

// The original cost before rescoring. Used for debugging purpose.
int32_t cost_before_rescoring = 0;
#ifdef MOZC_CANDIDATE_DEBUG
void Dlog(absl::string_view filename, int line,
absl::string_view message) const;
Expand Down
43 changes: 43 additions & 0 deletions src/prediction/dictionary_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,10 @@ bool DictionaryPredictor::AddPredictionToCandidates(

MaybeMoveLiteralCandidateToTop(request, segments);

if (rescorer_ != nullptr && IsDebug(request)) {
AddRescoringDebugDescription(segments);
}

return added > 0;
#undef MOZC_ADD_DEBUG_CANDIDATE
}
Expand Down Expand Up @@ -700,6 +704,7 @@ void DictionaryPredictor::FillCandidate(
auto it = merged_types.find(result.value);
SetDebugDescription(it == merged_types.end() ? 0 : it->second,
&candidate->description);
candidate->cost_before_rescoring = result.cost_before_rescoring;
}
#ifdef MOZC_DEBUG
candidate->log += "\n" + result.log;
Expand Down Expand Up @@ -1279,6 +1284,9 @@ void DictionaryPredictor::MaybeRescoreResults(
const ConversionRequest &request, const Segments &segments,
absl::Span<Result> results) const {
if (!rescorer_) return;
if (IsDebug(request)) {
for (Result &r : results) r.cost_before_rescoring = r.cost;
}
// Concatenate top values of history segments.
std::string history;
for (size_t i = 0; i < segments.history_segments_size(); ++i) {
Expand All @@ -1289,6 +1297,41 @@ void DictionaryPredictor::MaybeRescoreResults(
rescorer_->RescoreResults(request, history, results);
}

void DictionaryPredictor::AddRescoringDebugDescription(Segments *segments) {
if (segments->conversion_segments_size() == 0) {
return;
}
Segment *seg = segments->mutable_conversion_segment(0);
if (seg->candidates_size() == 0) {
return;
}
// Calculate the ranking by the original costs. Note: this can be slightly
// different from the actual ranking because, when the candidates were
// generated, `filter.ShouldRemove()` was applied to the results ordered by
// the rescored costs. To get the true original ranking, we need to apply
// `filter.ShouldRemove()` to the results ordered by the original cost.
// This is just for debugging, so such difference won't matter.
std::vector<const Segment::Candidate *> cands;
cands.reserve(seg->candidates_size());
for (int i = 0; i < seg->candidates_size(); ++i) {
cands.push_back(&seg->candidate(i));
}
std::sort(cands.begin(), cands.end(),
[](const Segment::Candidate *l, const Segment::Candidate *r) {
return l->cost_before_rescoring < r->cost_before_rescoring;
});
absl::flat_hash_map<const Segment::Candidate *, size_t> orig_rank;
for (size_t i = 0; i < cands.size(); ++i) orig_rank[cands[i]] = i + 1;

// Populate the debug description.
for (size_t i = 0; i < seg->candidates_size(); ++i) {
Segment::Candidate *c = seg->mutable_candidate(i);
const size_t rank = i + 1;
Util::AppendStringWithDelimiter(" ", absl::StrCat(orig_rank[c], "", rank),
&c->description);
}
}

} // namespace mozc::prediction

#undef MOZC_WORD_LOG_MESSAGE
Expand Down
1 change: 1 addition & 0 deletions src/prediction/dictionary_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ class DictionaryPredictor : public PredictorInterface {
void MaybeRescoreResults(const ConversionRequest &request,
const Segments &segments,
absl::Span<Result> results) const;
static void AddRescoringDebugDescription(Segments *segments);

// Test peer to access private methods
friend class DictionaryPredictorTestPeer;
Expand Down
30 changes: 30 additions & 0 deletions src/prediction/dictionary_predictor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ class DictionaryPredictorTestPeer {
DictionaryPredictor::MaybeMoveLiteralCandidateToTop(request, segments);
}

static void AddRescoringDebugDescription(Segments *segments) {
DictionaryPredictor::AddRescoringDebugDescription(segments);
}

private:
DictionaryPredictor predictor_;
};
Expand Down Expand Up @@ -1823,5 +1827,31 @@ TEST_F(DictionaryPredictorTest, Rescoring) {
}));
}

TEST_F(DictionaryPredictorTest, AddRescoringDebugDescription) {
Segments segments;
Segment *segment = segments.add_segment();

Segment::Candidate *cand1 = segment->push_back_candidate();
cand1->key = "Cand1";
cand1->cost = 1000;
cand1->cost_before_rescoring = 3000;

Segment::Candidate *cand2 = segment->push_back_candidate();
cand2->key = "Cand2";
cand2->cost = 2000;
cand2->cost_before_rescoring = 2000;

Segment::Candidate *cand3 = segment->push_back_candidate();
cand3->key = "Cand3";
cand3->cost = 3000;
cand3->cost_before_rescoring = 1000;

DictionaryPredictorTestPeer::AddRescoringDebugDescription(&segments);

EXPECT_EQ(cand1->description, "3→1");
EXPECT_EQ(cand2->description, "2→2");
EXPECT_EQ(cand3->description, "1→3");
}

} // namespace
} // namespace mozc::prediction
2 changes: 2 additions & 0 deletions src/prediction/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ struct Result {
size_t consumed_key_size = 0;
// The total penalty added to this result.
int penalty = 0;
// The original cost before rescoring. Used for debugging purpose.
int cost_before_rescoring = 0;
// If removed is true, this result is not used for a candidate.
bool removed = false;
#ifndef NDEBUG
Expand Down

0 comments on commit 08e58a6

Please sign in to comment.