Skip to content

Commit 467e2ea

Browse files
authored
Add sentence WER for NLP input (#35)
1 parent 57d9962 commit 467e2ea

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

src/fstalign.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,10 @@ void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine
678678
RecordSpeakerSwitchWer(stitches, alignerOptions.speaker_switch_context_size);
679679
}
680680

681-
// Calculate and record per-speaker WER
681+
// Calculate and record supplementary WER
682682
RecordSpeakerWer(stitches);
683683
RecordTagWer(stitches);
684+
RecordSentenceWer(stitches);
684685

685686
if (!output_nlp.empty()) {
686687
ofstream nlp_ostream(output_nlp);

src/wer.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,34 @@ void RecordWer(spWERA topAlignment) {
7474
}
7575
}
7676

77+
/// Records a per-sentence WER breakdown in the JSON log.
///
/// Stitches are walked in order and their edit counts accumulated into the
/// current sentence. A sentence ends at any stitch whose NLP row punctuation is
/// ".", "?" or "!". Each finished sentence is written, in order, to
/// root["wer"]["sentenceWer"][i] through RecordWerResult.
///
/// Stitches left over after the last end-of-sentence mark are recorded as a
/// final sentence only when they contain at least one reference word; a
/// trailing run made up solely of insertions is therefore not recorded.
///
/// @param stitches stitches to score, in document order.
///        NOTE(review): presumably these must already be aligned to NLP rows,
///        since nlpRow.punctuation is read for every stitch -- confirm.
void RecordSentenceWer(vector<shared_ptr<Stitching>> stitches) {
  const std::set<std::string> eos_punc{".", "?", "!"};
  vector<WerResult> sentence_wers;
  WerResult curr_wer = {0, 0, 0, 0, 0};

  for (const auto &stitch : stitches) {
    // rfind(prefix, 0) == 0 is a "starts with" test on the stitch comment.
    const bool is_insertion = stitch->comment.rfind("ins", 0) == 0;
    curr_wer.deletions += stitch->comment.rfind("del", 0) == 0;
    curr_wer.insertions += is_insertion;
    curr_wer.substitutions += stitch->comment.rfind("sub", 0) == 0;
    // Every stitch that is not an insertion counts as one reference word.
    curr_wer.numWordsInReference += !is_insertion;

    // Check if we hit EOS
    if (eos_punc.find(stitch->nlpRow.punctuation) != eos_punc.end()) {
      sentence_wers.push_back(curr_wer);
      curr_wer = {0, 0, 0, 0, 0};
    }
  }

  // Add the trailing, unterminated sentence if it is not empty
  if (curr_wer.numWordsInReference > 0) {
    sentence_wers.push_back(curr_wer);
  }

  // Add to log. The index stays an int so the same JSON operator[] overload
  // as before is selected; the size is converted once to avoid a
  // signed/unsigned comparison in the loop condition.
  const int sentence_count = static_cast<int>(sentence_wers.size());
  for (int i = 0; i < sentence_count; ++i) {
    RecordWerResult(jsonLogger::JsonLogger::getLogger().root["wer"]["sentenceWer"][i], sentence_wers[i]);
  }
}
103+
104+
77105
void RecordSpeakerWer(vector<shared_ptr<Stitching>> stitches) {
78106
// Note: stitches must have already been aligned to NLP rows
79107
// Logic for segment boundaries copied from speaker switch WER code

src/wer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ void RecordWerResult(Json::Value &json, WerResult wr);
3939
void RecordWer(spWERA topAlignment);
4040
void RecordSpeakerWer(vector<shared_ptr<Stitching>> stitches);
4141
void RecordSpeakerSwitchWer(vector<shared_ptr<Stitching>> stitches, int speaker_switch_context_size);
42+
void RecordSentenceWer(vector<shared_ptr<Stitching>> stitches);
4243
void RecordTagWer(vector<shared_ptr<Stitching>> stitches);
4344
void RecordCaseWer(vector<shared_ptr<Stitching>> aligned_stitches);
4445

0 commit comments

Comments
 (0)