Skip to content

Commit 363deb8

Browse files
authored
Bug: WER sidecar info not appearing in SBS (#55)
* add test
* add test
* fix
* Add and use wer tag data structure
* fix test
* Remove debug log
* remove unigram and bigram info from sbs output
* fix JSON log missing unigram/bigram info when output sbs is not set
* version bump
1 parent fced1d9 commit 363deb8

16 files changed

+129
-178
lines changed

src/Nlp.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,21 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
3030

3131
// fuse multiple rows that have the same id/label into one entry only
3232
for (auto &row : records) {
33-
mNlpRows.push_back(row);
3433
auto curr_tk = row.token;
3534
auto curr_label = row.best_label;
3635
auto curr_label_id = row.best_label_id;
3736
auto punctuation = row.punctuation;
3837
auto curr_row_tags = row.wer_tags;
3938

4039
// Update wer tags in records to real string labels
41-
vector<string> real_wer_tags;
4240
for (auto &tag : curr_row_tags) {
43-
auto real_tag = tag;
4441
if (mWerSidecar != Json::nullValue) {
45-
real_tag = "###" + real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
42+
tag.entity_type = mWerSidecar[tag.tag_id]["entity_type"].asString();
4643
}
47-
real_wer_tags.push_back(real_tag);
4844
}
49-
row.wer_tags = real_wer_tags;
45+
row.wer_tags = curr_row_tags;
5046
std::string speaker = row.speakerId;
47+
mNlpRows.push_back(row);
5148

5249
if (processLabels && curr_label != "") {
5350
if (firstTk || curr_label != last_label) {
@@ -411,17 +408,18 @@ std::string NlpReader::GetBestLabel(std::string &labels) {
411408
return labels;
412409
}
413410

414-
std::vector<std::string> NlpReader::GetWerTags(std::string &wer_tags_str) {
415-
std::vector<std::string> wer_tags;
411+
std::vector<WerTagEntry> NlpReader::GetWerTags(std::string &wer_tags_str) {
412+
std::vector<WerTagEntry> wer_tags;
416413
if (wer_tags_str == "[]") {
417414
return wer_tags;
418415
}
419416
// wer_tags_str looks like: ['89', '90', '100']
420417
int current_pos = 2;
421418
auto pos = wer_tags_str.find("'", current_pos);
422419
while (pos != -1) {
423-
std::string wer_tag = wer_tags_str.substr(current_pos, pos - current_pos);
424-
wer_tags.push_back(wer_tag);
420+
WerTagEntry entry;
421+
entry.tag_id = wer_tags_str.substr(current_pos, pos - current_pos);
422+
wer_tags.push_back(entry);
425423
current_pos = wer_tags_str.find("'", pos + 1) + 1;
426424
if (current_pos == 0) {
427425
break;

src/Nlp.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
using namespace std;
1717
using namespace fst;
1818

19+
struct WerTagEntry {
20+
string tag_id;
21+
string entity_type;
22+
};
23+
1924
struct RawNlpRecord {
2025
string token;
2126
string speakerId;
@@ -27,7 +32,7 @@ struct RawNlpRecord {
2732
string labels;
2833
string best_label;
2934
string best_label_id;
30-
vector<string> wer_tags;
35+
vector<WerTagEntry> wer_tags;
3136
string confidence;
3237
};
3338

@@ -37,7 +42,7 @@ class NlpReader {
3742
virtual ~NlpReader();
3843
vector<RawNlpRecord> read_from_disk(const std::string &filename);
3944
string GetBestLabel(std::string &labels);
40-
vector<string> GetWerTags(std::string &wer_tags_str);
45+
vector<WerTagEntry> GetWerTags(std::string &wer_tags_str);
4146
string GetLabelId(std::string &label);
4247
};
4348

src/fstalign.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
619619
<< "[";
620620
/* for (auto wer_tag : nlpRow.wer_tags) { */
621621
for (auto it = stitch.nlpRow.wer_tags.begin(); it != stitch.nlpRow.wer_tags.end(); ++it) {
622-
output_nlp_file << "'" << *it << "'";
622+
output_nlp_file << "'" << it->tag_id << "'";
623623
if (std::next(it) != stitch.nlpRow.wer_tags.end()) {
624624
output_nlp_file << ", ";
625625
}
@@ -695,6 +695,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
695695
}
696696
}
697697

698+
JsonLogUnigramBigramStats(topAlignment);
698699
if (!output_sbs.empty()) {
699700
logger->info("output_sbs = {}", output_sbs);
700701
WriteSbs(topAlignment, stitches, output_sbs);

src/version.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
22

33
#define FSTALIGNER_VERSION_MAJOR 1
4-
#define FSTALIGNER_VERSION_MINOR 12
4+
#define FSTALIGNER_VERSION_MINOR 13
55
#define FSTALIGNER_VERSION_PATCH 0

src/wer.cpp

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -350,19 +350,16 @@ void RecordTagWer(const vector<Stitching>& stitches) {
350350
for (const auto &stitch : stitches) {
351351
if (!stitch.nlpRow.wer_tags.empty()) {
352352
for (auto wer_tag : stitch.nlpRow.wer_tags) {
353-
int tag_start = wer_tag.find_first_not_of('#');
354-
int tag_end = wer_tag.find('_');
355-
string wer_tag_id = wer_tag.substr(tag_start, tag_end - tag_start);
356-
wer_results.insert(std::pair<std::string, WerResult>(wer_tag_id, {0, 0, 0, 0, 0}));
353+
wer_results.insert(std::pair<std::string, WerResult>(wer_tag.tag_id, {0, 0, 0, 0, 0}));
357354
// Check with rfind since other comments can be there
358355
bool del = stitch.comment.rfind("del", 0) == 0;
359356
bool ins = stitch.comment.rfind("ins", 0) == 0;
360357
bool sub = stitch.comment.rfind("sub", 0) == 0;
361-
wer_results[wer_tag_id].insertions += ins;
362-
wer_results[wer_tag_id].deletions += del;
363-
wer_results[wer_tag_id].substitutions += sub;
358+
wer_results[wer_tag.tag_id].insertions += ins;
359+
wer_results[wer_tag.tag_id].deletions += del;
360+
wer_results[wer_tag.tag_id].substitutions += sub;
364361
if (!ins) {
365-
wer_results[wer_tag_id].numWordsInReference += 1;
362+
wer_results[wer_tag.tag_id].numWordsInReference += 1;
366363
}
367364
}
368365
}
@@ -555,7 +552,7 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
555552
string tk_wer_tags = "";
556553
auto wer_tags = p_stitch.nlpRow.wer_tags;
557554
for (auto wer_tag: wer_tags) {
558-
tk_wer_tags = tk_wer_tags + wer_tag + "|";
555+
tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
559556
}
560557
string ref_tk = p_stitch.reftk;
561558
string hyp_tk = p_stitch.hyptk;
@@ -606,6 +603,10 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
606603
myfile << fmt::format("{0:>20}\t{1}", group.first, group.second) << endl;
607604
}
608605

606+
myfile.close();
607+
}
608+
609+
void JsonLogUnigramBigramStats(wer_alignment &topAlignment) {
609610
for (const auto &a : topAlignment.unigram_stats) {
610611
string word = a.first;
611612
gram_error_counter u = a.second;
@@ -617,18 +618,6 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
617618
jsonLogger::JsonLogger::getLogger().root["wer"]["unigrams"][word]["precision"] = u.precision;
618619
jsonLogger::JsonLogger::getLogger().root["wer"]["unigrams"][word]["recall"] = u.recall;
619620
}
620-
// output error unigrams
621-
myfile << string(60, '-') << endl << fmt::format("{0:>20}\t{1:10}\t{2:10}", "Unigram", "Prec.", "Recall") << endl;
622-
for (const auto &a : topAlignment.unigram_stats) {
623-
string word = a.first;
624-
gram_error_counter u = a.second;
625-
myfile << fmt::format("{0:>20}\t{1}/{2} ({3:.1f} %)\t{4}/{5} ({6:.1f} %)", word, u.correct,
626-
(u.correct + u.ins + u.subst_fp), (float)u.precision, u.correct, (u.correct + u.del + u.subst_fn),
627-
(float)u.recall)
628-
<< endl;
629-
}
630-
631-
myfile << string(60, '-') << endl << fmt::format("{0:>20}\t{1:20}\t{2:20}", "Bigram", "Precision", "Recall") << endl;
632621

633622
for (const auto &a : topAlignment.bigrams_stats) {
634623
string word = a.first;
@@ -641,14 +630,4 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
641630
jsonLogger::JsonLogger::getLogger().root["wer"]["bigrams"][word]["precision"] = u.precision;
642631
jsonLogger::JsonLogger::getLogger().root["wer"]["bigrams"][word]["recall"] = u.recall;
643632
}
644-
for (const auto &a : topAlignment.bigrams_stats) {
645-
string word = a.first;
646-
gram_error_counter u = a.second;
647-
myfile << fmt::format("{0:>20}\t{1}/{2} ({3:.1f} %)\t{4}/{5} ({6:.1f} %)", word, u.correct,
648-
(u.correct + u.ins + u.subst_fp), (float)u.precision, u.correct, (u.correct + u.del + u.subst_fn),
649-
(float)u.recall)
650-
<< endl;
651-
}
652-
653-
myfile.close();
654633
}

src/wer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ typedef vector<pair<size_t, string>> ErrorGroups;
5050

5151
void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
5252
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename);
53+
void JsonLogUnigramBigramStats(wer_alignment &topAlignment);

test/data/short.aligned.case.nlp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ sure|1|0.0000|0.0000|.||LC|[]|[]||||
2323
When|1|0.0000|0.0000|||UC|[]|[]||||
2424
I|1|0.0000|0.0000|||CA|[]|[]||||
2525
hear|1|0.0000|0.0000|||LC|[]|[]||||
26-
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
26+
Foobar|1|0.0000|0.0000|,||UC|[]|['1', '2']||||
2727
I|1|0.0000|0.0000|||CA|[]|[]||||
2828
think|1|0.0000|0.0000|||LC|[]|[]||||
2929
about|1|0.0000|0.0000|||LC|[]|[]||||

test/data/short.aligned.punc.nlp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ sure|1|0.0000|0.0000|.||LC|[]|[]||||
3131
When|1|0.0000|0.0000|||UC|[]|[]||||
3232
I|1|0.0000|0.0000|||CA|[]|[]||||
3333
hear|1|0.0000|0.0000|||LC|[]|[]||||
34-
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
34+
Foobar|1|0.0000|0.0000|,||UC|[]|['1', '2']||||
3535
,|1|0.0000|0.0000|||||[]||||
3636
I|1|0.0000|0.0000|||CA|[]|[]||||
3737
think|1|0.0000|0.0000|||LC|[]|[]||||

test/data/short.aligned.punc_case.nlp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ sure|1|0.0000|0.0000|.||LC|[]|[]||||
3131
When|1|0.0000|0.0000|||UC|[]|[]||||
3232
I|1|0.0000|0.0000|||CA|[]|[]||||
3333
hear|1|0.0000|0.0000|||LC|[]|[]||||
34-
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
34+
Foobar|1|0.0000|0.0000|,||UC|[]|['1', '2']||||
3535
,|1|0.0000|0.0000|||||[]||||
3636
I|1|0.0000|0.0000|||CA|[]|[]||||
3737
think|1|0.0000|0.0000|||LC|[]|[]||||

test/data/short.sbs.txt

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
ref_token hyp_token IsErr Class Wer_Tag_Entities
2+
<crosstalk> <crosstalk>
3+
Yeah Yeah
4+
, ,
5+
yeah <del> ERR
6+
, <del> ERR
7+
right right
8+
. <del> ERR
9+
Yeah <del> ERR
10+
, <del> ERR
11+
all <del> ERR
12+
right <del> ERR
13+
, I'll ERR
14+
probably do ERR
15+
just just
16+
that that
17+
. ? ERR
18+
Are Are
19+
there there
20+
any any
21+
visuals visuals
22+
that that
23+
come come
24+
to to
25+
mind mind
26+
or or ___100002_SYN_1-1___
27+
<ins> ? ERR
28+
Yeah Yeah
29+
, ,
30+
sure sure
31+
. .
32+
When When
33+
I I
34+
hear hear
35+
Foobar Foobar ###1_PROPER_NOUN###|###2_SPACY>ORG###|
36+
, ,
37+
I I
38+
think think
39+
about about
40+
just just
41+
that that
42+
: :
43+
<ins> Foobar ERR
44+
foo , ERR
45+
a a
46+
------------------------------------------------------------
47+
Line Group
48+
5 yeah , <-> ***
49+
8 . Yeah , all right , probably <-> I'll do
50+
17 . <-> ?
51+
27 *** <-> ?
52+
43 foo <-> Foobar ,

0 commit comments

Comments
 (0)